128 lines
4.4 KiB
Python
128 lines
4.4 KiB
Python
"""Library to interface with Solana public keys."""
|
|
from __future__ import annotations

from hashlib import sha256
from typing import Any, Iterable, List, Optional, Tuple, Union

from based58 import b58decode, b58encode
from nacl.signing import VerifyKey

from solana.utils import ed25519_base, helpers
|
|
|
|
|
|
class OnCurveException(Exception):
    """Signal that a derived address is a point on the ed25519 curve.

    ``PublicKey.create_program_address`` raises this when the hashed seeds land
    on the curve; ``PublicKey.find_program_address`` catches it and retries
    with the next nonce.
    """
|
|
|
|
|
|
class PublicKey:
    """The public key of a keypair.

    The key is stored as at most 32 raw bytes; shorter inputs are left-padded
    with zero bytes when converted with ``bytes()``.

    Example:
        >>> # An arbitrary public key:
        >>> pubkey = PublicKey(1)
        >>> str(pubkey)  # String representation in base58 form.
        '11111111111111111111111111111112'
        >>> bytes(pubkey).hex()
        '0000000000000000000000000000000000000000000000000000000000000001'
    """

    #: Constant for standard length of a public key.
    LENGTH = 32

    #: Maximum length in bytes of a single seed (Solana's ``MAX_SEED_LEN``).
    MAX_SEED_LENGTH = 32

    #: Maximum number of seeds, bump seed included (Solana's ``MAX_SEEDS``).
    MAX_SEEDS = 16

    def __init__(self, value: Union[bytearray, bytes, int, str, List[int], VerifyKey]):
        """Init PublicKey object.

        Args:
            value: The key material.
                - A ``str`` is base58-decoded and must yield exactly 32 bytes.
                - An ``int`` is encoded big-endian, so ``PublicKey(1)`` is
                  ``00..01``.
                - Anything else is passed to ``bytes()``.

        Raises:
            ValueError: If a ``str`` cannot be ASCII-encoded or base58-decoded,
                or does not decode to exactly 32 bytes. Also raised if an
                ``int`` is negative, or if the key would exceed 32 bytes.
        """
        self._key: Optional[bytes] = None
        if isinstance(value, str):
            try:
                self._key = b58decode(value.encode("ascii"))
            except ValueError as err:  # UnicodeEncodeError is a ValueError too
                raise ValueError("invalid public key input:", value) from err
            if len(self._key) != self.LENGTH:
                raise ValueError("invalid public key input:", value)
        elif isinstance(value, int):
            if value < 0:
                raise ValueError("invalid public key input:", value)
            # Minimal big-endian width (at least one byte). For 0..255 this is
            # identical to ``bytes([value])``. Oversized ints are rejected by
            # the length check below instead of failing inside ``bytes()``.
            width = max(1, (value.bit_length() + 7) // 8)
            self._key = value.to_bytes(width, "big")
        else:
            self._key = bytes(value)

        if len(self._key) > self.LENGTH:
            raise ValueError("invalid public key input:", value)

    def __bytes__(self) -> bytes:
        """Public key as exactly ``LENGTH`` bytes, left-padded with zeros."""
        if not self._key:
            return bytes(self.LENGTH)
        return self._key if len(self._key) == self.LENGTH else self._key.rjust(self.LENGTH, b"\0")

    def __eq__(self, other: Any) -> bool:
        """Equality definition for PublicKeys (compares the padded 32-byte form)."""
        if not isinstance(other, PublicKey):
            # Let Python try the other operand's __eq__. It still ends up
            # False for unrelated types.
            return NotImplemented
        return bytes(self) == bytes(other)

    def __hash__(self) -> int:
        """Hash of the padded bytes, consistent with ``__eq__``."""
        return hash(self.__bytes__())

    def __repr__(self) -> str:
        """Representation of a PublicKey (its base58 string)."""
        return str(self)

    def __str__(self) -> str:
        """String definition for PublicKey (base58)."""
        return self.to_base58().decode("utf-8")

    def to_base58(self) -> bytes:
        """Public key in base58.

        Returns:
            The base58-encoded public key.
        """
        return b58encode(bytes(self))

    @staticmethod
    def create_with_seed(from_public_key: PublicKey, seed: str, program_id: PublicKey) -> PublicKey:
        """Derive a public key from another key, a seed, and a program ID.

        The result is ``sha256(from_public_key || utf8(seed) || program_id)``.

        Args:
            from_public_key: The base key.
            seed: A text seed, UTF-8 encoded before hashing.
            program_id: The program ID mixed into the hash.

        Returns:
            The derived public key.
        """
        buf = bytes(from_public_key) + seed.encode("utf-8") + bytes(program_id)
        return PublicKey(sha256(buf).digest())

    @staticmethod
    def create_program_address(seeds: Iterable[bytes], program_id: PublicKey) -> PublicKey:
        """Derive a program address from seeds and a program ID.

        The address is
        ``sha256(seed_0 || ... || seed_n || program_id || b"ProgramDerivedAddress")``.
        It is only valid if it does not lie on the ed25519 curve.

        Args:
            seeds: Bytes-like seeds. Any iterable (list, tuple, generator) is
                accepted. At most ``MAX_SEEDS`` seeds are allowed, each at most
                ``MAX_SEED_LENGTH`` bytes.
            program_id: The program ID mixed into the hash.

        Returns:
            The derived program address.

        Raises:
            ValueError: If there are too many seeds or any seed is too long.
            OnCurveException: If the derived address lies on the curve.
        """
        seed_list = list(seeds)  # materialize once: validated, then joined
        if len(seed_list) > PublicKey.MAX_SEEDS:
            raise ValueError(f"max number of seeds ({PublicKey.MAX_SEEDS}) exceeded")
        for seed in seed_list:
            if len(seed) > PublicKey.MAX_SEED_LENGTH:
                raise ValueError(f"max seed length ({PublicKey.MAX_SEED_LENGTH} bytes) exceeded")

        buffer = b"".join([*seed_list, bytes(program_id), b"ProgramDerivedAddress"])
        hashbytes: bytes = sha256(buffer).digest()
        if PublicKey._is_on_curve(hashbytes):
            raise OnCurveException("Invalid seeds, address must fall off the curve")
        return PublicKey(hashbytes)

    @staticmethod
    def find_program_address(seeds: Iterable[bytes], program_id: PublicKey) -> Tuple[PublicKey, int]:
        """Find a valid program address.

        Valid program addresses must fall off the ed25519 curve. This function
        appends a one-byte nonce to the seeds, trying 255 down to 1 (0 is
        never tried). It returns the first combination that yields a valid
        program address.

        Args:
            seeds: Bytes-like seeds; any iterable is accepted.
            program_id: The program ID mixed into the hash.

        Returns:
            The program address and nonce used.

        Raises:
            KeyError: If no nonce in 255..1 produces an off-curve address.
            ValueError: If the seeds violate the limits enforced by
                ``create_program_address``.
        """
        seed_list = list(seeds)
        for nonce in range(255, 0, -1):
            try:
                address = PublicKey.create_program_address(
                    [*seed_list, helpers.to_uint8_bytes(nonce)], program_id
                )
            except OnCurveException:
                continue
            return address, nonce
        raise KeyError("Unable to find a viable program address nonce")

    @staticmethod
    def _is_on_curve(pubkey_bytes: bytes) -> bool:
        """Verify the point is on curve or not."""
        return ed25519_base.is_on_curve(pubkey_bytes)
|