diff --git a/src/aleph/sdk/account.py b/src/aleph/sdk/account.py index 872ee3c4..aedae1d7 100644 --- a/src/aleph/sdk/account.py +++ b/src/aleph/sdk/account.py @@ -6,6 +6,7 @@ from aleph_message.models import Chain from aleph.sdk.chains.common import get_fallback_private_key +from aleph.sdk.chains.cosmos import CSDKAccount from aleph.sdk.chains.ethereum import ETHAccount from aleph.sdk.chains.evm import EVMAccount from aleph.sdk.chains.remote import RemoteAccount @@ -38,6 +39,7 @@ Chain.SOL: SOLAccount, Chain.WORLDCHAIN: EVMAccount, Chain.ZORA: EVMAccount, + Chain.CSDK: CSDKAccount, } diff --git a/src/aleph/sdk/chains/cosmos.py b/src/aleph/sdk/chains/cosmos.py index d6bba627..fcaa045e 100644 --- a/src/aleph/sdk/chains/cosmos.py +++ b/src/aleph/sdk/chains/cosmos.py @@ -1,10 +1,12 @@ import base64 import hashlib import json -from typing import Union +from pathlib import Path +from typing import Optional, Union import ecdsa from cosmospy._wallet import privkey_to_address, privkey_to_pubkey +from ecdsa import BadSignatureError from .common import BaseAccount, get_fallback_private_key, get_verification_buffer @@ -52,7 +54,7 @@ def __init__(self, private_key=None, hrp=DEFAULT_HRP): async def sign_message(self, message): message = self._setup_sender(message) verif = get_verification_string(message) - base64_pubkey = base64.b64encode(self.get_public_key().encode()).decode("utf-8") + base64_pubkey = base64.b64encode(self.get_public_key()).decode("utf-8") signature = await self.sign_raw(verif.encode("utf-8")) sig = { @@ -78,17 +80,43 @@ def get_address(self) -> str: return privkey_to_address(self.private_key) def get_public_key(self) -> str: - return privkey_to_pubkey(self.private_key).decode() + return privkey_to_pubkey(self.private_key) -def get_fallback_account(hrp=DEFAULT_HRP): - return CSDKAccount(private_key=get_fallback_private_key(), hrp=hrp) +def get_fallback_account(path: Optional[Path] = None, hrp=DEFAULT_HRP): + return CSDKAccount(private_key=get_fallback_private_key(path=path), hrp=hrp) def verify_signature( signature: Union[bytes, str], public_key: Union[bytes, str], message: Union[bytes, str], -) -> bool: - """TODO: Implement this""" - raise NotImplementedError("Not implemented yet") +): + """ + Verifies a signature. + Args: + signature: The signature to verify. Can be a base64 encoded string or bytes. + public_key: The public key to use for verification. Can be a base64 encoded string or bytes. + message: The message to verify. Can be an utf-8 string or bytes. + Raises: + BadSignatureError: If the signature is invalid.! + """ + + if isinstance(signature, str): + signature = base64.b64decode(signature.encode("utf-8")) + if isinstance(public_key, str): + public_key = base64.b64decode(public_key) + if isinstance(message, str): + message = message.encode("utf-8") + + vk = ecdsa.VerifyingKey.from_string(public_key, curve=ecdsa.SECP256k1) + + try: + vk.verify( + signature, + message, + hashfunc=hashlib.sha256, + ) + return True + except Exception as e: + raise BadSignatureError from e diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index 4b42f08a..1b837fbf 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -362,7 +362,8 @@ async def get_message( self, item_hash: str, message_type: Optional[Type[GenericMessage]] = None, - ) -> GenericMessage: ... + ) -> GenericMessage: + ... @overload async def get_message( @@ -370,7 +371,8 @@ async def get_message( item_hash: str, message_type: Optional[Type[GenericMessage]] = None, with_status: bool = False, - ) -> Tuple[GenericMessage, MessageStatus]: ... + ) -> Tuple[GenericMessage, MessageStatus]: + ... async def get_message( self, diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index c698da5d..ad5671c2 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -20,28 +20,36 @@ class Account(Protocol): CURVE: str @abstractmethod - async def sign_message(self, message: Dict) -> Dict: ... + async def sign_message(self, message: Dict) -> Dict: + ... @abstractmethod - async def sign_raw(self, buffer: bytes) -> bytes: ... + async def sign_raw(self, buffer: bytes) -> bytes: + ... @abstractmethod - def get_address(self) -> str: ... + def get_address(self) -> str: + ... @abstractmethod - def get_public_key(self) -> str: ... + def get_public_key(self) -> str: + ... class AccountFromPrivateKey(Account, Protocol): """Only accounts that are initialized from a private key string are supported.""" - def __init__(self, private_key: bytes, chain: Chain): ... + def __init__(self, private_key: bytes, chain: Chain): + ... - async def sign_raw(self, buffer: bytes) -> bytes: ... + async def sign_raw(self, buffer: bytes) -> bytes: + ... - def export_private_key(self) -> str: ... + def export_private_key(self) -> str: + ... - def switch_chain(self, chain: Optional[str] = None) -> None: ... + def switch_chain(self, chain: Optional[str] = None) -> None: + ... GenericMessage = TypeVar("GenericMessage", bound=AlephMessage) diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index c3fc154a..a7c01b2b 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -110,11 +110,13 @@ def check_unix_socket_valid(unix_socket_path: str) -> bool: class AsyncReadable(Protocol[T]): - async def read(self, n: int = -1) -> T: ... + async def read(self, n: int = -1) -> T: + ... class Writable(Protocol[U]): - def write(self, buffer: U) -> int: ... + def write(self, buffer: U) -> int: + ... async def copy_async_readable_to_buffer( diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index c1c56fcd..c8a0d0c5 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -11,6 +11,7 @@ from aiohttp import ClientResponseError from aleph_message.models import AggregateMessage, AlephMessage, PostMessage +import aleph.sdk.chains.cosmos as cosmos import aleph.sdk.chains.ethereum as ethereum import aleph.sdk.chains.solana as solana import aleph.sdk.chains.substrate as substrate @@ -54,6 +55,13 @@ def substrate_account() -> substrate.DOTAccount: yield substrate.get_fallback_account(path=Path(private_key_file.name)) +@pytest.fixture +def cosmos_account() -> cosmos.CSDKAccount: + with NamedTemporaryFile(delete=False) as private_key_file: + private_key_file.close() + yield cosmos.get_fallback_account(path=Path(private_key_file.name)) + + @pytest.fixture def json_messages(): messages_path = Path(__file__).parent / "messages.json" @@ -162,9 +170,11 @@ def __init__(self, sync: bool): async def __aenter__(self): return self - async def __aexit__(self, exc_type, exc_val, exc_tb): ... + async def __aexit__(self, exc_type, exc_val, exc_tb): + ... - async def raise_for_status(self): ... + async def raise_for_status(self): + ... @property def status(self): diff --git a/tests/unit/test_chain_cosmos.py b/tests/unit/test_chain_cosmos.py new file mode 100644 index 00000000..9d871d2c --- /dev/null +++ b/tests/unit/test_chain_cosmos.py @@ -0,0 +1,91 @@ +import base64 +import json +from dataclasses import asdict, dataclass + +import pytest +from ecdsa import BadSignatureError + +from aleph.sdk.chains.common import get_verification_buffer +from aleph.sdk.chains.cosmos import get_verification_string, verify_signature + + +@dataclass +class Message: + chain: str + sender: str + type: str + item_hash: str + + +@pytest.mark.asyncio +async def test_verify_signature(cosmos_account): + message = asdict( + Message( + "CSDK", + cosmos_account.get_address(), + "POST", + "SomeHash", + ) + ) + await cosmos_account.sign_message(message) + assert message["signature"] + signature = json.loads(message["signature"]) + raw_signature = signature["signature"] + assert isinstance(raw_signature, str) + + pub_key = base64.b64decode(signature["pub_key"]["value"]) + + verify_signature( + raw_signature, + pub_key, + get_verification_string(message), + ) + + +@pytest.mark.asyncio +async def test_verify_signature_raw(cosmos_account): + message = asdict( + Message( + "CSDK", + cosmos_account.get_address(), + "POST", + "SomeHash", + ) + ) + await cosmos_account.sign_message(message) + raw_message = get_verification_buffer(message) + raw_signature = await cosmos_account.sign_raw(raw_message) + assert isinstance(raw_signature, bytes) + + pub_key = cosmos_account.get_public_key() + verify_signature( + raw_signature.decode(), + pub_key, + raw_message, + ) + + +@pytest.mark.asyncio +async def test_bad_signature(cosmos_account): + message = asdict( + Message( + "CSDK", + cosmos_account.get_address(), + "POST", + "SomeHash", + ) + ) + await cosmos_account.sign_message(message) + assert message["signature"] + signature = json.loads(message["signature"]) + raw_signature = "1" + signature["signature"] + assert isinstance(raw_signature, str) + + pub_key = base64.b64decode(signature["pub_key"]["value"]) + + with pytest.raises(BadSignatureError): + verify_signature( + raw_signature, + pub_key, + get_verification_string(message), + )