128 lines
		
	
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			128 lines
		
	
	
	
		
			4.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Library to interface with Solana public keys."""
 | 
						|
from __future__ import annotations
 | 
						|
 | 
						|
from hashlib import sha256
 | 
						|
from typing import Any, List, Optional, Tuple, Union
 | 
						|
 | 
						|
from based58 import b58decode, b58encode
 | 
						|
from nacl.signing import VerifyKey
 | 
						|
 | 
						|
from solana.utils import ed25519_base, helpers
 | 
						|
 | 
						|
 | 
						|
class OnCurveException(Exception):
 | 
						|
    """Raise when generated address is on the curve."""
 | 
						|
 | 
						|
 | 
						|
class PublicKey:
 | 
						|
    """The public key of a keypair.
 | 
						|
 | 
						|
    Example:
 | 
						|
        >>> # An arbitary public key:
 | 
						|
        >>> pubkey = PublicKey(1)
 | 
						|
        >>> str(pubkey) # String representation in base58 form.
 | 
						|
        '11111111111111111111111111111112'
 | 
						|
        >>> bytes(pubkey).hex()
 | 
						|
        '0000000000000000000000000000000000000000000000000000000000000001'
 | 
						|
    """
 | 
						|
 | 
						|
    LENGTH = 32
 | 
						|
    """Constant for standard length of a public key."""
 | 
						|
 | 
						|
    def __init__(self, value: Union[bytearray, bytes, int, str, List[int], VerifyKey]):
 | 
						|
        """Init PublicKey object."""
 | 
						|
        self._key: Optional[bytes] = None
 | 
						|
        if isinstance(value, str):
 | 
						|
            try:
 | 
						|
                self._key = b58decode(value.encode("ascii"))
 | 
						|
            except ValueError as err:
 | 
						|
                raise ValueError("invalid public key input:", value) from err
 | 
						|
            if len(self._key) != self.LENGTH:
 | 
						|
                raise ValueError("invalid public key input:", value)
 | 
						|
        elif isinstance(value, int):
 | 
						|
            self._key = bytes([value])
 | 
						|
        else:
 | 
						|
            self._key = bytes(value)
 | 
						|
 | 
						|
        if len(self._key) > self.LENGTH:
 | 
						|
            raise ValueError("invalid public key input:", value)
 | 
						|
 | 
						|
    def __bytes__(self) -> bytes:
 | 
						|
        """Public key in bytes."""
 | 
						|
        if not self._key:
 | 
						|
            return bytes(self.LENGTH)
 | 
						|
        return self._key if len(self._key) == self.LENGTH else self._key.rjust(self.LENGTH, b"\0")
 | 
						|
 | 
						|
    def __eq__(self, other: Any) -> bool:
 | 
						|
        """Equality definition for PublicKeys."""
 | 
						|
        return False if not isinstance(other, PublicKey) else bytes(self) == bytes(other)
 | 
						|
 | 
						|
    def __hash__(self) -> int:
 | 
						|
        """Returns a unique hash for set operations."""
 | 
						|
        return hash(self.__bytes__())
 | 
						|
 | 
						|
    def __repr__(self) -> str:
 | 
						|
        """Representation of a PublicKey."""
 | 
						|
        return str(self)
 | 
						|
 | 
						|
    def __str__(self) -> str:
 | 
						|
        """String definition for PublicKey."""
 | 
						|
        return self.to_base58().decode("utf-8")
 | 
						|
 | 
						|
    def to_base58(self) -> bytes:
 | 
						|
        """Public key in base58.
 | 
						|
 | 
						|
        Returns:
 | 
						|
            The base58-encoded public key.
 | 
						|
        """
 | 
						|
        return b58encode(bytes(self))
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def create_with_seed(from_public_key: PublicKey, seed: str, program_id: PublicKey) -> PublicKey:
 | 
						|
        """Derive a public key from another key, a seed, and a program ID.
 | 
						|
 | 
						|
        Returns:
 | 
						|
            The derived public key.
 | 
						|
        """
 | 
						|
        buf = bytes(from_public_key) + seed.encode("utf-8") + bytes(program_id)
 | 
						|
        return PublicKey(sha256(buf).digest())
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def create_program_address(seeds: List[bytes], program_id: PublicKey) -> PublicKey:
 | 
						|
        """Derive a program address from seeds and a program ID.
 | 
						|
 | 
						|
        Returns:
 | 
						|
            The derived program address.
 | 
						|
        """
 | 
						|
        buffer = b"".join(seeds + [bytes(program_id), b"ProgramDerivedAddress"])
 | 
						|
        hashbytes: bytes = sha256(buffer).digest()
 | 
						|
        if not PublicKey._is_on_curve(hashbytes):
 | 
						|
            return PublicKey(hashbytes)
 | 
						|
        raise OnCurveException("Invalid seeds, address must fall off the curve")
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def find_program_address(seeds: List[bytes], program_id: PublicKey) -> Tuple[PublicKey, int]:
 | 
						|
        """Find a valid program address.
 | 
						|
 | 
						|
        Valid program addresses must fall off the ed25519 curve.  This function
 | 
						|
        iterates a nonce until it finds one that when combined with the seeds
 | 
						|
        results in a valid program address.
 | 
						|
 | 
						|
        Returns:
 | 
						|
            The program address and nonce used.
 | 
						|
        """
 | 
						|
        nonce = 255
 | 
						|
        while nonce != 0:
 | 
						|
            try:
 | 
						|
                buffer = seeds + [helpers.to_uint8_bytes(nonce)]
 | 
						|
                address = PublicKey.create_program_address(buffer, program_id)
 | 
						|
            except OnCurveException:
 | 
						|
                nonce -= 1
 | 
						|
                continue
 | 
						|
            return address, nonce
 | 
						|
        raise KeyError("Unable to find a viable program address nonce")
 | 
						|
 | 
						|
    @staticmethod
 | 
						|
    def _is_on_curve(pubkey_bytes: bytes) -> bool:
 | 
						|
        """Verify the point is on curve or not."""
 | 
						|
        return ed25519_base.is_on_curve(pubkey_bytes)
 |