diff --git a/src/aleph/sdk/chains/ethereum.py b/src/aleph/sdk/chains/ethereum.py index ab93df56..c185d174 100644 --- a/src/aleph/sdk/chains/ethereum.py +++ b/src/aleph/sdk/chains/ethereum.py @@ -15,6 +15,7 @@ from web3.types import TxParams, TxReceipt from aleph.sdk.exceptions import InsufficientFundsError +from aleph.sdk.types import TokenType from ..conf import settings from ..connectors.superfluid import Superfluid @@ -22,12 +23,13 @@ BALANCEOF_ABI, MIN_ETH_BALANCE, MIN_ETH_BALANCE_WEI, + FlowUpdate, + from_wei_token, get_chain_id, get_chains_with_super_token, get_rpc, get_super_token_address, get_token_address, - to_human_readable_token, ) from ..exceptions import BadSignatureError from ..utils import bytes_from_hex @@ -106,8 +108,9 @@ def can_transact(self, block=True) -> bool: valid = balance > MIN_ETH_BALANCE_WEI if self.chain else False if not valid and block: raise InsufficientFundsError( + token_type=TokenType.GAS, required_funds=MIN_ETH_BALANCE, - available_funds=to_human_readable_token(balance), + available_funds=float(from_wei_token(balance)), ) return valid @@ -162,6 +165,12 @@ def get_super_token_balance(self) -> Decimal: return Decimal(contract.functions.balanceOf(self.get_address()).call()) return Decimal(0) + def can_start_flow(self, flow: Decimal) -> bool: + """Check if the account has enough funds to start a Superfluid flow of the given size.""" + if not self.superfluid_connector: + raise ValueError("Superfluid connector is required to check a flow") + return self.superfluid_connector.can_start_flow(flow) + def create_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]: """Creat a Superfluid flow between this account and the receiver address.""" if not self.superfluid_connector: @@ -188,6 +197,19 @@ def delete_flow(self, receiver: str) -> Awaitable[str]: raise ValueError("Superfluid connector is required to delete a flow") return self.superfluid_connector.delete_flow(receiver=receiver) + def manage_flow( + self, + receiver: str, + flow: Decimal, + update_type: FlowUpdate, + ) -> Awaitable[Optional[str]]: + """Manage the Superfluid flow between this account and the receiver address.""" + if not self.superfluid_connector: + raise ValueError("Superfluid connector is required to manage a flow") + return self.superfluid_connector.manage_flow( + receiver=receiver, flow=flow, update_type=update_type + ) + def get_fallback_account( path: Optional[Path] = None, chain: Optional[Chain] = None diff --git a/src/aleph/sdk/chains/evm.py b/src/aleph/sdk/chains/evm.py index 5bf66ef1..a5eeed84 100644 --- a/src/aleph/sdk/chains/evm.py +++ b/src/aleph/sdk/chains/evm.py @@ -5,6 +5,7 @@ from aleph_message.models import Chain from eth_account import Account # type: ignore +from ..evm_utils import FlowUpdate from .common import get_fallback_private_key from .ethereum import ETHAccount @@ -29,6 +30,9 @@ def get_token_balance(self) -> Decimal: def get_super_token_balance(self) -> Decimal: raise ValueError(f"Super token not implemented for this chain {self.CHAIN}") + def can_start_flow(self, flow: Decimal) -> bool: + raise ValueError(f"Flow checking not implemented for this chain {self.CHAIN}") + def create_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]: raise ValueError(f"Flow creation not implemented for this chain {self.CHAIN}") @@ -41,6 +45,11 @@ def update_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]: def delete_flow(self, receiver: str) -> Awaitable[str]: raise ValueError(f"Flow deletion not implemented for this chain {self.CHAIN}") + def manage_flow( + self, receiver: str, flow: Decimal, update_type: FlowUpdate + ) -> Awaitable[Optional[str]]: + raise ValueError(f"Flow management not implemented for this chain {self.CHAIN}") + def get_fallback_account( path: Optional[Path] = None, chain: Optional[Chain] = None diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index 025aae6a..7f9fed8e 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -20,9 +20,9 @@ from aleph_message.models import ( AlephMessage, + ExecutableContent, ItemHash, ItemType, - MessagesResponse, MessageType, Payment, PostMessage, @@ -41,7 +41,7 @@ from aleph.sdk.utils import extended_json_encoder from ..query.filters import MessageFilter, PostFilter -from ..query.responses import PostsResponse, PriceResponse +from ..query.responses import MessagesResponse, PostsResponse, PriceResponse from ..types import GenericMessage, StorageEnum from ..utils import Writable, compute_sha256 @@ -110,7 +110,7 @@ async def get_posts_iterator( ) page += 1 for post in resp.posts: - yield post + yield post # type: ignore @abstractmethod async def download_file(self, file_hash: str) -> bytes: @@ -242,6 +242,18 @@ def watch_messages( """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + @abstractmethod + def get_estimated_price( + self, + content: ExecutableContent, + ) -> Coroutine[Any, Any, PriceResponse]: + """ + Get Instance/Program content estimated price + + :param content: Instance or Program content + """ + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + @abstractmethod def get_program_price( self, @@ -265,7 +277,7 @@ async def create_post( post_type: str, ref: Optional[str] = None, address: Optional[str] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, inline: bool = True, storage_engine: StorageEnum = StorageEnum.storage, sync: bool = False, @@ -290,9 +302,9 @@ async def create_post( async def create_aggregate( self, key: str, - content: Mapping[str, Any], + content: dict[str, Any], address: Optional[str] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, inline: bool = True, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: @@ -302,7 +314,7 @@ async def create_aggregate( :param key: Key to use to store the content :param content: Content to store :param address: Address to use to sign the message - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param inline: Whether to write content inside the message (Default: True) :param sync: If true, waits for the message to be processed by the API server (Default: False) """ @@ -321,7 +333,7 @@ async def create_store( ref: Optional[str] = None, storage_engine: StorageEnum = StorageEnum.storage, extra_fields: Optional[dict] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: """ @@ -350,22 +362,22 @@ async def create_program( program_ref: str, entrypoint: str, runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, vcpus: Optional[int] = None, + memory: Optional[int] = None, timeout_seconds: Optional[float] = None, - persistent: bool = False, - allow_amend: bool = False, internet: bool = True, + allow_amend: bool = False, aleph_api: bool = True, encoding: Encoding = Encoding.zip, + persistent: bool = False, volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, + environment_variables: Optional[dict[str, str]] = None, + subscriptions: Optional[List[dict]] = None, + sync: bool = False, + channel: Optional[str] = settings.DEFAULT_CHANNEL, + storage_engine: StorageEnum = StorageEnum.storage, ) -> Tuple[AlephMessage, MessageStatus]: """ Post a (create) PROGRAM message. @@ -373,22 +385,22 @@ async def create_program( :param program_ref: Reference to the program to run :param entrypoint: Entrypoint to run :param runtime: Runtime to use - :param environment_variables: Environment variables to pass to the program - :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") + :param metadata: Metadata to attach to the message :param address: Address to use (Default: account.get_address()) - :param sync: If true, waits for the message to be processed by the API server - :param memory: Memory in MB for the VM to be allocated (Default: 128) :param vcpus: Number of vCPUs to allocate (Default: 1) + :param memory: Memory in MB for the VM to be allocated (Default: 128) :param timeout_seconds: Timeout in seconds (Default: 30.0) - :param persistent: Whether the program should be persistent or not (Default: False) - :param allow_amend: Whether the deployed VM image may be changed (Default: False) :param internet: Whether the VM should have internet connectivity. (Default: True) + :param allow_amend: Whether the deployed VM image may be changed (Default: False) :param aleph_api: Whether the VM needs access to Aleph messages API (Default: True) :param encoding: Encoding to use (Default: Encoding.zip) + :param persistent: Whether the program should be persistent or not (Default: False) :param volumes: Volumes to mount + :param environment_variables: Environment variables to pass to the program :param subscriptions: Patterns of aleph.im messages to forward to the program's event receiver - :param metadata: Metadata to attach to the message + :param sync: If true, waits for the message to be processed by the API server + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") + :param storage_engine: Storage engine to use (Default: "storage") """ raise NotImplementedError( "Did you mean to import `AuthenticatedAlephHttpClient`?" @@ -400,9 +412,9 @@ async def create_instance( rootfs: str, rootfs_size: int, payment: Optional[Payment] = None, - environment_variables: Optional[Mapping[str, str]] = None, + environment_variables: Optional[dict[str, str]] = None, storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, address: Optional[str] = None, sync: bool = False, memory: Optional[int] = None, @@ -416,7 +428,7 @@ async def create_instance( volumes: Optional[List[Mapping]] = None, volume_persistence: str = "host", ssh_keys: Optional[List[str]] = None, - metadata: Optional[Mapping[str, Any]] = None, + metadata: Optional[dict[str, Any]] = None, requirements: Optional[HostRequirements] = None, ) -> Tuple[AlephMessage, MessageStatus]: """ @@ -427,7 +439,7 @@ async def create_instance( :param payment: Payment method used to pay for the instance :param environment_variables: Environment variables to pass to the program :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param address: Address to use (Default: account.get_address()) :param sync: If true, waits for the message to be processed by the API server :param memory: Memory in MB for the VM to be allocated (Default: 2048) @@ -455,7 +467,7 @@ async def forget( hashes: List[ItemHash], reason: Optional[str], storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, address: Optional[str] = None, sync: bool = False, ) -> Tuple[AlephMessage, MessageStatus]: @@ -468,7 +480,7 @@ async def forget( :param hashes: Hashes of the messages to forget :param reason: Reason for forgetting the messages :param storage_engine: Storage engine to use (Default: "storage") - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param address: Address to use (Default: account.get_address()) :param sync: If true, waits for the message to be processed by the API server (Default: False) """ @@ -490,7 +502,7 @@ async def generate_signed_message( :param message_type: Type of the message (PostMessage, ...) :param content: User-defined content of the message - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param allow_inlining: Whether to allow inlining the content of the message (Default: True) :param storage_engine: Storage engine to use (Default: "storage") """ @@ -537,7 +549,7 @@ async def submit( self, content: Dict[str, Any], message_type: MessageType, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, storage_engine: StorageEnum = StorageEnum.storage, allow_inlining: bool = True, sync: bool = False, @@ -549,7 +561,7 @@ async def submit( :param content: Content of the message :param message_type: Type of the message - :param channel: Channel to use (Default: "TEST") + :param channel: Channel to use (Default: "ALEPH-CLOUDSOLUTIONS") :param storage_engine: Storage engine to use (Default: "storage") :param allow_inlining: Whether to allow inlining the content of the message (Default: True) :param sync: If true, waits for the message to be processed by the API server (Default: False) diff --git a/src/aleph/sdk/client/authenticated_http.py b/src/aleph/sdk/client/authenticated_http.py index f84b97ca..9bb9a1e7 100644 --- a/src/aleph/sdk/client/authenticated_http.py +++ b/src/aleph/sdk/client/authenticated_http.py @@ -5,45 +5,37 @@ import time from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Mapping, NoReturn, Optional, Tuple, Union +from typing import Any, Dict, Mapping, NoReturn, Optional, Tuple, Union import aiohttp from aleph_message.models import ( AggregateContent, AggregateMessage, AlephMessage, - Chain, ForgetContent, ForgetMessage, - InstanceContent, InstanceMessage, ItemHash, + ItemType, MessageType, PostContent, PostMessage, - ProgramContent, ProgramMessage, StoreContent, StoreMessage, ) -from aleph_message.models.execution.base import Encoding, Payment, PaymentType +from aleph_message.models.execution.base import Encoding, Payment from aleph_message.models.execution.environment import ( - FunctionEnvironment, HostRequirements, HypervisorType, - InstanceEnvironment, - MachineResources, TrustedExecutionEnvironment, ) -from aleph_message.models.execution.instance import RootfsVolume -from aleph_message.models.execution.program import CodeContent, FunctionRuntime -from aleph_message.models.execution.volume import MachineVolume, ParentVolume from aleph_message.status import MessageStatus from ..conf import settings from ..exceptions import BroadcastError, InsufficientFundsError, InvalidMessageError -from ..types import Account, StorageEnum -from ..utils import extended_json_encoder, parse_volume +from ..types import Account, StorageEnum, TokenType +from ..utils import extended_json_encoder, make_instance_content, make_program_content from .abstract import AuthenticatedAlephClient from .http import AlephHttpClient @@ -285,7 +277,7 @@ async def create_post( post_type: str, ref: Optional[str] = None, address: Optional[str] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, inline: bool = True, storage_engine: StorageEnum = StorageEnum.storage, sync: bool = False, @@ -308,14 +300,14 @@ async def create_post( storage_engine=storage_engine, sync=sync, ) - return message, status + return message, status # type: ignore async def create_aggregate( self, key: str, - content: Mapping[str, Any], + content: dict[str, Any], address: Optional[str] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, inline: bool = True, sync: bool = False, ) -> Tuple[AggregateMessage, MessageStatus]: @@ -335,7 +327,7 @@ async def create_aggregate( allow_inlining=inline, sync=sync, ) - return message, status + return message, status # type: ignore async def create_store( self, @@ -347,7 +339,7 @@ async def create_store( ref: Optional[str] = None, storage_engine: StorageEnum = StorageEnum.storage, extra_fields: Optional[dict] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() @@ -400,7 +392,7 @@ async def create_store( if extra_fields is not None: values.update(extra_fields) - content = StoreContent(**values) + content = StoreContent.parse_obj(values) message, status, _ = await self.submit( content=content.dict(exclude_none=True), @@ -409,109 +401,89 @@ async def create_store( allow_inlining=True, sync=sync, ) - return message, status + return message, status # type: ignore async def create_program( self, program_ref: str, entrypoint: str, runtime: str, - environment_variables: Optional[Mapping[str, str]] = None, - storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, address: Optional[str] = None, - sync: bool = False, - memory: Optional[int] = None, vcpus: Optional[int] = None, + memory: Optional[int] = None, timeout_seconds: Optional[float] = None, - persistent: bool = False, - allow_amend: bool = False, internet: bool = True, + allow_amend: bool = False, aleph_api: bool = True, encoding: Encoding = Encoding.zip, - volumes: Optional[List[Mapping]] = None, - subscriptions: Optional[List[Mapping]] = None, - metadata: Optional[Mapping[str, Any]] = None, + persistent: bool = False, + volumes: Optional[list[Mapping]] = None, + environment_variables: Optional[dict[str, str]] = None, + subscriptions: Optional[list[dict]] = None, + sync: bool = False, + channel: Optional[str] = settings.DEFAULT_CHANNEL, + storage_engine: StorageEnum = StorageEnum.storage, ) -> Tuple[ProgramMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() - volumes = volumes if volumes is not None else [] - memory = memory or settings.DEFAULT_VM_MEMORY - vcpus = vcpus or settings.DEFAULT_VM_VCPUS - timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT - - # TODO: Check that program_ref, runtime and data_ref exist - - # Register the different ways to trigger a VM - if subscriptions: - # Trigger on HTTP calls and on aleph.im message subscriptions. - triggers = { - "http": True, - "persistent": persistent, - "message": subscriptions, - } - else: - # Trigger on HTTP calls. - triggers = {"http": True, "persistent": persistent} - - volumes: List[MachineVolume] = [parse_volume(volume) for volume in volumes] - - content = ProgramContent( - type="vm-function", + content = make_program_content( + program_ref=program_ref, + entrypoint=entrypoint, + runtime=runtime, + metadata=metadata, address=address, + vcpus=vcpus, + memory=memory, + timeout_seconds=timeout_seconds, + internet=internet, + aleph_api=aleph_api, allow_amend=allow_amend, - code=CodeContent( - encoding=encoding, - entrypoint=entrypoint, - ref=program_ref, - use_latest=True, - ), - on=triggers, - environment=FunctionEnvironment( - reproducible=False, - internet=internet, - aleph_api=aleph_api, - ), - variables=environment_variables, - resources=MachineResources( - vcpus=vcpus, - memory=memory, - seconds=timeout_seconds, - ), - runtime=FunctionRuntime( - ref=runtime, - use_latest=True, - comment=( - "Official aleph.im runtime" - if runtime == settings.DEFAULT_RUNTIME_ID - else "" - ), - ), - volumes=[parse_volume(volume) for volume in volumes], - time=time.time(), - metadata=metadata, + encoding=encoding, + persistent=persistent, + volumes=volumes, + environment_variables=environment_variables, + subscriptions=subscriptions, ) - # Ensure that the version of aleph-message used supports the field. - assert content.on.persistent == persistent - message, status, _ = await self.submit( content=content.dict(exclude_none=True), message_type=MessageType.program, channel=channel, storage_engine=storage_engine, sync=sync, + raise_on_rejected=False, ) - return message, status + if status in (MessageStatus.PROCESSED, MessageStatus.PENDING): + return message, status # type: ignore + + # get the reason for rejection + rejected_message = await self.get_message_error(message.item_hash) + assert rejected_message, "No rejected message found" + error_code = rejected_message["error_code"] + if error_code == 5: + # not enough balance + details = rejected_message["details"] + errors = details["errors"] + error = errors[0] + account_balance = float(error["account_balance"]) + required_balance = float(error["required_balance"]) + raise InsufficientFundsError( + token_type=TokenType.ALEPH, + required_funds=required_balance, + available_funds=account_balance, + ) + else: + raise ValueError(f"Unknown error code {error_code}: {rejected_message}") async def create_instance( self, rootfs: str, rootfs_size: int, payment: Optional[Payment] = None, - environment_variables: Optional[Mapping[str, str]] = None, + environment_variables: Optional[dict[str, str]] = None, storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, address: Optional[str] = None, sync: bool = False, memory: Optional[int] = None, @@ -522,55 +494,34 @@ async def create_instance( aleph_api: bool = True, hypervisor: Optional[HypervisorType] = None, trusted_execution: Optional[TrustedExecutionEnvironment] = None, - volumes: Optional[List[Mapping]] = None, + volumes: Optional[list[Mapping]] = None, volume_persistence: str = "host", - ssh_keys: Optional[List[str]] = None, - metadata: Optional[Mapping[str, Any]] = None, + ssh_keys: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, requirements: Optional[HostRequirements] = None, ) -> Tuple[InstanceMessage, MessageStatus]: address = address or settings.ADDRESS_TO_USE or self.account.get_address() - volumes = volumes if volumes is not None else [] - memory = memory or settings.DEFAULT_VM_MEMORY - vcpus = vcpus or settings.DEFAULT_VM_VCPUS - timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT - - payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold) - - # Default to the QEMU hypervisor for instances. - selected_hypervisor: HypervisorType = hypervisor or HypervisorType.qemu - - content = InstanceContent( + content = make_instance_content( + rootfs=rootfs, + rootfs_size=rootfs_size, + payment=payment, + environment_variables=environment_variables, address=address, + memory=memory, + vcpus=vcpus, + timeout_seconds=timeout_seconds, allow_amend=allow_amend, - environment=InstanceEnvironment( - internet=internet, - aleph_api=aleph_api, - hypervisor=selected_hypervisor, - trusted_execution=trusted_execution, - ), - variables=environment_variables, - resources=MachineResources( - vcpus=vcpus, - memory=memory, - seconds=timeout_seconds, - ), - rootfs=RootfsVolume( - parent=ParentVolume( - ref=rootfs, - use_latest=True, - ), - size_mib=rootfs_size, - persistence="host", - use_latest=True, - ), - volumes=[parse_volume(volume) for volume in volumes], - requirements=requirements, - time=time.time(), - authorized_keys=ssh_keys, + internet=internet, + aleph_api=aleph_api, + hypervisor=hypervisor, + trusted_execution=trusted_execution, + volumes=volumes, + ssh_keys=ssh_keys, metadata=metadata, - payment=payment, + requirements=requirements, ) + message, status, response = await self.submit( content=content.dict(exclude_none=True), message_type=MessageType.instance, @@ -580,7 +531,7 @@ async def create_instance( raise_on_rejected=False, ) if status in (MessageStatus.PROCESSED, MessageStatus.PENDING): - return message, status + return message, status # type: ignore # get the reason for rejection rejected_message = await self.get_message_error(message.item_hash) @@ -594,17 +545,19 @@ async def create_instance( account_balance = float(error["account_balance"]) required_balance = float(error["required_balance"]) raise InsufficientFundsError( - required_funds=required_balance, available_funds=account_balance + token_type=TokenType.ALEPH, + required_funds=required_balance, + available_funds=account_balance, ) else: raise ValueError(f"Unknown error code {error_code}: {rejected_message}") async def forget( self, - hashes: List[ItemHash], + hashes: list[ItemHash], reason: Optional[str], storage_engine: StorageEnum = StorageEnum.storage, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, address: Optional[str] = None, sync: bool = False, ) -> Tuple[ForgetMessage, MessageStatus]: @@ -625,13 +578,13 @@ async def forget( allow_inlining=True, sync=sync, ) - return message, status + return message, status # type: ignore async def submit( self, content: Dict[str, Any], message_type: MessageType, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, storage_engine: StorageEnum = StorageEnum.storage, allow_inlining: bool = True, sync: bool = False, @@ -653,7 +606,7 @@ async def _storage_push_file_with_message( self, file_content: bytes, store_content: StoreContent, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: """Push a file to the storage service.""" @@ -685,7 +638,7 @@ async def _storage_push_file_with_message( message_status = ( MessageStatus.PENDING if resp.status == 202 else MessageStatus.PROCESSED ) - return message, message_status + return message, message_status # type: ignore async def _upload_file_native( self, @@ -694,7 +647,7 @@ async def _upload_file_native( guess_mime_type: bool = False, ref: Optional[str] = None, extra_fields: Optional[dict] = None, - channel: Optional[str] = None, + channel: Optional[str] = settings.DEFAULT_CHANNEL, sync: bool = False, ) -> Tuple[StoreMessage, MessageStatus]: file_hash = hashlib.sha256(file_content).hexdigest() @@ -706,9 +659,9 @@ async def _upload_file_native( store_content = StoreContent( address=address, ref=ref, - item_type=StorageEnum.storage, - item_hash=file_hash, - mime_type=mime_type, + item_type=ItemType.storage, + item_hash=ItemHash(file_hash), + mime_type=mime_type, # type: ignore time=time.time(), **extra_fields, ) diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index 4b42f08a..f4e8b898 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -2,6 +2,7 @@ import logging import os.path import ssl +import time from io import BytesIO from pathlib import Path from typing import ( @@ -20,7 +21,15 @@ import aiohttp from aiohttp.web import HTTPNotFound from aleph_message import parse_message -from aleph_message.models import AlephMessage, ItemHash, ItemType, MessageType +from aleph_message.models import ( + AlephMessage, + Chain, + ExecutableContent, + ItemHash, + ItemType, + MessageType, + ProgramContent, +) from aleph_message.status import MessageStatus from pydantic import ValidationError @@ -37,6 +46,7 @@ from ..utils import ( Writable, check_unix_socket_valid, + compute_sha256, copy_async_readable_to_buffer, extended_json_encoder, get_message_type_value, @@ -358,7 +368,7 @@ async def get_messages( ) @overload - async def get_message( + async def get_message( # type: ignore self, item_hash: str, message_type: Optional[Type[GenericMessage]] = None, @@ -383,7 +393,7 @@ async def get_message( resp.raise_for_status() except aiohttp.ClientResponseError as e: if e.status == 404: - raise MessageNotFoundError(f"No such hash {item_hash}") + raise MessageNotFoundError(f"No such hash {item_hash}") from e raise e message_raw = await resp.json() if message_raw["status"] == "forgotten": @@ -399,9 +409,9 @@ async def get_message( f"does not match the expected type '{expected_type}'" ) if with_status: - return message, message_raw["status"] + return message, message_raw["status"] # type: ignore else: - return message + return message # type: ignore async def get_message_error( self, @@ -448,6 +458,47 @@ async def watch_messages( elif msg.type == aiohttp.WSMsgType.ERROR: break + async def get_estimated_price( + self, + content: ExecutableContent, + ) -> PriceResponse: + cleaned_content = content.dict(exclude_none=True) + item_content: str = json.dumps( + cleaned_content, + separators=(",", ":"), + default=extended_json_encoder, + ) + message = parse_message( + dict( + sender=content.address, + chain=Chain.ETH, + type=( + MessageType.program + if isinstance(content, ProgramContent) + else MessageType.instance + ), + content=cleaned_content, + item_content=item_content, + time=time.time(), + channel=settings.DEFAULT_CHANNEL, + item_type=ItemType.inline, + item_hash=compute_sha256(item_content), + ) + ) + + async with self.http_session.post( + "/api/v0/price/estimate", json=dict(message=message) + ) as resp: + try: + resp.raise_for_status() + response_json = await resp.json() + return PriceResponse( + required_tokens=response_json["required_tokens"], + payment_type=response_json["payment_type"], + ) + except aiohttp.ClientResponseError as e: + raise e + async def get_program_price(self, item_hash: str) -> PriceResponse: async with self.http_session.get(f"/api/v0/price/{item_hash}") as resp: try: @@ -491,15 +542,21 @@ async def get_stored_content( resp = f"Invalid CID: {message.content.item_hash}" else: filename = safe_getattr(message.content, "metadata.name") - hash = message.content.item_hash + item_hash = message.content.item_hash url = ( f"{self.api_server}/api/v0/storage/raw/" - if len(hash) == 64 + if len(item_hash) == 64 else settings.IPFS_GATEWAY - ) + hash - result = StoredContent(filename=filename, hash=hash, url=url) + ) + item_hash + result = StoredContent( + filename=filename, hash=item_hash, url=url, error=None + ) except MessageNotFoundError: resp = f"Message not found: {item_hash}" except ForgottenMessageError: resp = f"Message forgotten: {item_hash}" - return result if result else StoredContent(error=resp) + return ( + result + if result + else StoredContent(error=resp, filename=None, hash=None, url=None) + ) diff --git a/src/aleph/sdk/conf.py b/src/aleph/sdk/conf.py index 4dc7c9e7..c925a05e 100644 --- a/src/aleph/sdk/conf.py +++ b/src/aleph/sdk/conf.py @@ -44,27 +44,22 @@ class Settings(BaseSettings): HTTP_REQUEST_TIMEOUT = 15.0 DEFAULT_CHANNEL: str = "ALEPH-CLOUDSOLUTIONS" + + # Firecracker runtime for programs DEFAULT_RUNTIME_ID: str = ( "63f07193e6ee9d207b7d1fcf8286f9aee34e6f12f101d2ec77c1229f92964696" ) - DEBIAN_11_ROOTFS_ID: str = ( - "887957042bb0e360da3485ed33175882ce72a70d79f1ba599400ff4802b7cee7" - ) - DEBIAN_12_ROOTFS_ID: str = ( - "6e30de68c6cedfa6b45240c2b51e52495ac6fb1bd4b36457b3d5ca307594d595" - ) - UBUNTU_22_ROOTFS_ID: str = ( - "77fef271aa6ff9825efa3186ca2e715d19e7108279b817201c69c34cedc74c27" - ) - DEBIAN_11_QEMU_ROOTFS_ID: str = ( - "f7e68c568906b4ebcd3cd3c4bfdff96c489cd2a9ef73ba2d7503f244dfd578de" - ) + + # Qemu rootfs for instances DEBIAN_12_QEMU_ROOTFS_ID: str = ( "b6ff5c3a8205d1ca4c7c3369300eeafff498b558f71b851aa2114afd0a532717" ) UBUNTU_22_QEMU_ROOTFS_ID: str = ( "4a0f62da42f4478544616519e6f5d58adb1096e069b392b151d47c3609492d0c" ) + UBUNTU_24_QEMU_ROOTFS_ID: str = ( + "5330dcefe1857bcd97b7b7f24d1420a7d46232d53f27be280c8a7071d88bd84e" + ) DEFAULT_CONFIDENTIAL_FIRMWARE: str = ( "ba5bb13f3abca960b101a759be162b229e2b7e93ecad9d1307e54de887f177ff" @@ -86,6 +81,7 @@ class Settings(BaseSettings): VM_URL_PATH = "https://aleph.sh/vm/{hash}" VM_URL_HOST = "https://{hash_base32}.aleph.sh" IPFS_GATEWAY = "https://ipfs.aleph.cloud/ipfs/" + CRN_URL_FOR_PROGRAMS = "https://dchq.staging.aleph.sh/" # Web3Provider settings TOKEN_DECIMALS = 18 diff --git a/src/aleph/sdk/connectors/superfluid.py b/src/aleph/sdk/connectors/superfluid.py index 4b7274f8..76bbf907 100644 --- a/src/aleph/sdk/connectors/superfluid.py +++ b/src/aleph/sdk/connectors/superfluid.py @@ -1,14 +1,19 @@ from __future__ import annotations from decimal import Decimal -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from eth_utils import to_normalized_address from superfluid import CFA_V1, Operation, Web3FlowInfo +from aleph.sdk.evm_utils import ( + FlowUpdate, + from_wei_token, + get_super_token_address, + to_wei_token, +) from aleph.sdk.exceptions import InsufficientFundsError - -from ..evm_utils import get_super_token_address, to_human_readable_token, to_wei_token +from aleph.sdk.types import TokenType if TYPE_CHECKING: from aleph.sdk.chains.ethereum import ETHAccount @@ -44,6 +49,7 @@ async def _execute_operation_with_account(self, operation: Operation) -> str: return await self.account._sign_and_send_transaction(populated_transaction) def can_start_flow(self, flow: Decimal, block=True) -> bool: + """Check if the account has enough funds to start a Superfluid flow of the given size.""" valid = False if self.account.can_transact(block=block): balance = self.account.get_super_token_balance() @@ -51,8 +57,9 @@ def can_start_flow(self, flow: Decimal, block=True) -> bool: valid = balance > MIN_FLOW_4H if not valid and block: raise InsufficientFundsError( - required_funds=float(MIN_FLOW_4H), - available_funds=to_human_readable_token(balance), + token_type=TokenType.ALEPH, + required_funds=float(from_wei_token(MIN_FLOW_4H)), + available_funds=float(from_wei_token(balance)), ) return valid @@ -96,3 +103,51 @@ async def update_flow(self, receiver: str, flow: Decimal) -> str: flow_rate=int(to_wei_token(flow)), ), ) + + async def manage_flow( + self, + receiver: str, + flow: Decimal, + update_type: FlowUpdate, + ) -> Optional[str]: + """ + Update the flow of a Superfluid stream between a sender and receiver. + This function either increases or decreases the flow rate between the sender and receiver, + based on the update_type. If no flow exists and the update type is augmentation, it creates a new flow + with the specified rate. If the update type is reduction and the reduction amount brings the flow to zero + or below, the flow is deleted. + + :param receiver: Address of the receiver in hexadecimal format. + :param flow: The flow rate to be added or removed (in ether). + :param update_type: The type of update to perform (augmentation or reduction). + :return: The transaction hash of the executed operation (create, update, or delete flow). + """ + + # Retrieve current flow info + flow_info: Web3FlowInfo = await self.account.get_flow(receiver) + + current_flow_rate_wei: Decimal = Decimal(flow_info["flowRate"] or 0) + flow_rate_wei: int = int(to_wei_token(flow)) + + if update_type == FlowUpdate.INCREASE: + if current_flow_rate_wei > 0: + # Update existing flow by increasing the rate + new_flow_rate_wei = current_flow_rate_wei + flow_rate_wei + new_flow_rate_ether = from_wei_token(new_flow_rate_wei) + return await self.account.update_flow(receiver, new_flow_rate_ether) + else: + # Create a new flow if none exists + return await self.account.create_flow(receiver, flow) + else: + if current_flow_rate_wei > 0: + # Reduce the existing flow + new_flow_rate_wei = current_flow_rate_wei - flow_rate_wei + # Ensure to not leave infinitesimal flows + # Often, there were 1-10 wei remaining in the flow rate, which prevented the flow from being deleted + if new_flow_rate_wei > 99: + new_flow_rate_ether = from_wei_token(new_flow_rate_wei) + return await self.account.update_flow(receiver, new_flow_rate_ether) + else: + # Delete the flow if the new flow rate is zero or negative + return await self.account.delete_flow(receiver) + return None diff --git a/src/aleph/sdk/evm_utils.py b/src/aleph/sdk/evm_utils.py index 4d2026ef..a425d580 100644 --- a/src/aleph/sdk/evm_utils.py +++ b/src/aleph/sdk/evm_utils.py @@ -1,4 +1,5 @@ -from decimal import Decimal +from decimal import ROUND_CEILING, Context, Decimal +from enum import Enum from typing import List, Optional, Union from aleph_message.models import Chain @@ -21,12 +22,26 @@ }]""" -def to_human_readable_token(amount: Decimal) -> float: - return float(amount / (Decimal(10) ** Decimal(settings.TOKEN_DECIMALS))) +class FlowUpdate(str, Enum): + REDUCE = "reduce" + INCREASE = "increase" + + +def ether_rounding(amount: Decimal) -> Decimal: + """Rounds the given value to 18 decimals.""" + return amount.quantize( + Decimal(1) / Decimal(10**18), rounding=ROUND_CEILING, context=Context(prec=36) + ) + + +def from_wei_token(amount: Decimal) -> Decimal: + """Converts the given wei value to ether.""" + return ether_rounding(amount / Decimal(10) ** Decimal(settings.TOKEN_DECIMALS)) def to_wei_token(amount: Decimal) -> Decimal: - return amount * Decimal(10) ** Decimal(settings.TOKEN_DECIMALS) + """Converts the given ether value to wei.""" + return Decimal(int(amount * Decimal(10) ** Decimal(settings.TOKEN_DECIMALS))) def get_chain_id(chain: Union[Chain, str, None]) -> Optional[int]: diff --git a/src/aleph/sdk/exceptions.py b/src/aleph/sdk/exceptions.py index a538a31c..05ed755f 100644 --- a/src/aleph/sdk/exceptions.py +++ b/src/aleph/sdk/exceptions.py @@ -1,5 +1,8 @@ from abc import ABC +from .types import TokenType +from .utils import displayable_amount + class QueryError(ABC, ValueError): """The result of an API query is inconsistent.""" @@ -69,14 +72,18 @@ class ForgottenMessageError(QueryError): class InsufficientFundsError(Exception): """Raised when the account does not have enough funds to perform an action""" + token_type: TokenType required_funds: float available_funds: float - def __init__(self, required_funds: float, available_funds: float): + def __init__( + self, token_type: TokenType, required_funds: float, available_funds: float + ): + self.token_type = token_type self.required_funds = required_funds self.available_funds = available_funds super().__init__( - f"Insufficient funds: required {required_funds}, available {available_funds}" + f"Insufficient funds ({self.token_type.value}): required {displayable_amount(self.required_funds, decimals=8)}, available {displayable_amount(self.available_funds, decimals=8)}" ) diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index c698da5d..05fa9815 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -83,7 +83,20 @@ class ChainInfo(BaseModel): class StoredContent(BaseModel): + """ + A stored content. + """ + filename: Optional[str] hash: Optional[str] url: Optional[str] error: Optional[str] + + +class TokenType(str, Enum): + """ + A token type. + """ + + GAS = "GAS" + ALEPH = "ALEPH" diff --git a/src/aleph/sdk/utils.py b/src/aleph/sdk/utils.py index c3fc154a..5cbc1e8c 100644 --- a/src/aleph/sdk/utils.py +++ b/src/aleph/sdk/utils.py @@ -8,6 +8,7 @@ import os import subprocess from datetime import date, datetime, time +from decimal import Context, Decimal, InvalidOperation from enum import Enum from pathlib import Path from shutil import make_archive @@ -15,7 +16,6 @@ Any, Dict, Iterable, - List, Mapping, Optional, Protocol, @@ -28,9 +28,38 @@ from uuid import UUID from zipfile import BadZipFile, ZipFile -from aleph_message.models import ItemHash, MessageType -from aleph_message.models.execution.program import Encoding -from aleph_message.models.execution.volume import MachineVolume +from aleph_message.models import ( + Chain, + InstanceContent, + ItemHash, + MachineType, + MessageType, + ProgramContent, +) +from aleph_message.models.execution.base import Payment, PaymentType +from aleph_message.models.execution.environment import ( + FunctionEnvironment, + FunctionTriggers, + HostRequirements, + HypervisorType, + InstanceEnvironment, + MachineResources, + Subscription, + TrustedExecutionEnvironment, +) +from aleph_message.models.execution.instance import RootfsVolume +from aleph_message.models.execution.program import ( + CodeContent, + Encoding, + FunctionRuntime, +) +from aleph_message.models.execution.volume import ( + MachineVolume, + ParentVolume, + PersistentVolumeSizeMib, + VolumePersistence, +) +from aleph_message.utils import Mebibytes from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from jwcrypto.jwa import JWA @@ -177,19 +206,17 @@ def extended_json_encoder(obj: Any) -> Any: def parse_volume(volume_dict: Union[Mapping, MachineVolume]) -> MachineVolume: - # Python 3.9 does not support `isinstance(volume_dict, MachineVolume)`, - # so we need to iterate over all types. if any( isinstance(volume_dict, volume_type) for volume_type in get_args(MachineVolume) ): - return volume_dict + return volume_dict # type: ignore + for volume_type in get_args(MachineVolume): try: return volume_type.parse_obj(volume_dict) except ValueError: - continue - else: - raise ValueError(f"Could not parse volume: {volume_dict}") + pass + raise ValueError(f"Could not parse volume: {volume_dict}") def compute_sha256(s: str) -> str: @@ -234,7 +261,7 @@ def sign_vm_control_payload(payload: Dict[str, str], ephemeral_key) -> str: async def run_in_subprocess( - command: List[str], check: bool = True, stdin_input: Optional[bytes] = None + command: list[str], check: bool = True, stdin_input: Optional[bytes] = None ) -> bytes: """Run the specified command in a subprocess, returns the stdout of the process.""" logger.debug(f"command: {' '.join(command)}") @@ -401,3 +428,166 @@ def safe_getattr(obj, attr, default=None): if obj is default: break return obj + + +def displayable_amount( + amount: Union[str, int, float, Decimal], decimals: int = 18 +) -> str: + """Returns the amount as a string without unnecessary decimals.""" + + str_amount = "" + try: + dec_amount = Decimal(amount) + if decimals: + dec_amount = dec_amount.quantize( + Decimal(1) / Decimal(10**decimals), context=Context(prec=36) + ) + str_amount = str(format(dec_amount.normalize(), "f")) + except ValueError: + logger.error(f"Invalid amount to display: {amount}") + exit(1) + except InvalidOperation: + logger.error(f"Invalid operation on amount to display: {amount}") + exit(1) + return str_amount + + +def make_instance_content( + rootfs: str, + rootfs_size: int, + payment: Optional[Payment] = None, + environment_variables: Optional[dict[str, str]] = None, + address: Optional[str] = None, + memory: Optional[int] = None, + vcpus: Optional[int] = None, + timeout_seconds: Optional[float] = None, + allow_amend: bool = False, + internet: bool = True, + aleph_api: bool = True, + hypervisor: Optional[HypervisorType] = None, + trusted_execution: Optional[TrustedExecutionEnvironment] = None, + volumes: Optional[list[Mapping]] = None, + ssh_keys: Optional[list[str]] = None, + metadata: Optional[dict[str, Any]] = None, + requirements: Optional[HostRequirements] = None, +) -> InstanceContent: + """ + Create InstanceContent object given the provided fields. + """ + + address = address or "0x0000000000000000000000000000000000000000" + payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold, receiver=None) + selected_hypervisor: HypervisorType = hypervisor or HypervisorType.qemu + vcpus = vcpus or settings.DEFAULT_VM_VCPUS + memory = memory or settings.DEFAULT_VM_MEMORY + timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT + volumes = volumes if volumes is not None else [] + + return InstanceContent( + address=address, + allow_amend=allow_amend, + environment=InstanceEnvironment( + internet=internet, + aleph_api=aleph_api, + hypervisor=selected_hypervisor, + trusted_execution=trusted_execution, + ), + variables=environment_variables, + resources=MachineResources( + vcpus=vcpus, + memory=Mebibytes(memory), + seconds=int(timeout_seconds), + ), + rootfs=RootfsVolume( + parent=ParentVolume( + ref=ItemHash(rootfs), + use_latest=True, + ), + size_mib=PersistentVolumeSizeMib(rootfs_size), + persistence=VolumePersistence.host, + ), + volumes=[parse_volume(volume) for volume in volumes], + requirements=requirements, + time=datetime.now().timestamp(), + authorized_keys=ssh_keys, + metadata=metadata, + payment=payment, + ) + + +def make_program_content( + program_ref: str, + entrypoint: str, + runtime: str, + metadata: Optional[dict[str, Any]] = None, + address: Optional[str] = None, + vcpus: Optional[int] = None, + memory: Optional[int] = None, + timeout_seconds: Optional[float] = None, + internet: bool = False, + aleph_api: bool = True, + allow_amend: bool = False, + encoding: Encoding = Encoding.zip, + persistent: bool = False, + volumes: Optional[list[Mapping]] = None, + environment_variables: Optional[dict[str, str]] = None, + subscriptions: Optional[list[dict]] = None, + payment: Optional[Payment] = None, +) -> ProgramContent: + """ + Create ProgramContent object given the provided fields. + """ + + address = address or "0x0000000000000000000000000000000000000000" + payment = payment or Payment(chain=Chain.ETH, type=PaymentType.hold, receiver=None) + vcpus = vcpus or settings.DEFAULT_VM_VCPUS + memory = memory or settings.DEFAULT_VM_MEMORY + timeout_seconds = timeout_seconds or settings.DEFAULT_VM_TIMEOUT + volumes = volumes if volumes is not None else [] + subscriptions = ( + [Subscription(**sub) for sub in subscriptions] + if subscriptions is not None + else None + ) + + return ProgramContent( + type=MachineType.vm_function, + address=address, + allow_amend=allow_amend, + code=CodeContent( + encoding=encoding, + entrypoint=entrypoint, + ref=ItemHash(program_ref), + use_latest=True, + ), + on=FunctionTriggers( + http=True, + persistent=persistent, + message=subscriptions, + ), + environment=FunctionEnvironment( + reproducible=False, + internet=internet, + aleph_api=aleph_api, + ), + variables=environment_variables, + resources=MachineResources( + vcpus=vcpus, + memory=Mebibytes(memory), + seconds=int(timeout_seconds), + ), + runtime=FunctionRuntime( + ref=ItemHash(runtime), + use_latest=True, + comment=( + "Official aleph.im runtime" + if runtime == settings.DEFAULT_RUNTIME_ID + else "" + ), + ), + volumes=[parse_volume(volume) for volume in volumes], + time=datetime.now().timestamp(), + metadata=metadata, + authorized_keys=[], + payment=payment, + ) diff --git a/tests/unit/aleph_vm_authentication.py b/tests/unit/aleph_vm_authentication.py index 491da51a..6083a119 100644 --- a/tests/unit/aleph_vm_authentication.py +++ b/tests/unit/aleph_vm_authentication.py @@ -263,7 +263,7 @@ async def authenticate_websocket_message( signed_operation = SignedOperation.parse_obj(message["X-SignedOperation"]) if signed_operation.content.domain != domain_name: logger.debug( - f"Invalid domain '{signed_pubkey.content.domain}' != '{domain_name}'" + f"Invalid domain '{signed_operation.content.domain}' != '{domain_name}'" ) raise web.HTTPUnauthorized(reason="Invalid domain") return verify_signed_operation(signed_operation, signed_pubkey) diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py index b044e170..e2647590 100644 --- a/tests/unit/test_asynchronous.py +++ b/tests/unit/test_asynchronous.py @@ -7,6 +7,7 @@ Chain, ForgetMessage, InstanceMessage, + ItemHash, MessageType, Payment, PaymentType, @@ -184,12 +185,16 @@ async def test_create_confidential_instance(mock_session_with_post_success): ), hypervisor=HypervisorType.qemu, trusted_execution=TrustedExecutionEnvironment( - firmware="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + firmware=ItemHash( + "cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe" + ), policy=0b1, ), requirements=HostRequirements( node=NodeRequirements( - node_hash="cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe", + node_hash=ItemHash( + "cafecafecafecafecafecafecafecafecafecafecafecafecafecafecafecafe" + ), ) ), ) @@ -285,5 +290,6 @@ async def test_create_instance_insufficient_funds_error( payment=Payment( chain=Chain.ETH, type=PaymentType.hold, + receiver=None, ), ) diff --git a/tests/unit/test_price.py b/tests/unit/test_price.py index bed9304a..fe9e3468 100644 --- a/tests/unit/test_price.py +++ b/tests/unit/test_price.py @@ -11,14 +11,14 @@ async def test_get_program_price_valid(): Test that the get_program_price method returns the correct PriceResponse when given a valid item hash. """ - expected_response = { - "required_tokens": 3.0555555555555556e-06, - "payment_type": "superfluid", - } - mock_session = make_mock_get_session(expected_response) + expected = PriceResponse( + required_tokens=3.0555555555555556e-06, + payment_type="superfluid", + ) + mock_session = make_mock_get_session(expected.dict()) async with mock_session: response = await mock_session.get_program_price("cacacacacacaca") - assert response == PriceResponse(**expected_response) + assert response == expected @pytest.mark.asyncio diff --git a/tests/unit/test_superfluid.py b/tests/unit/test_superfluid.py index c2f853bd..74bcc38e 100644 --- a/tests/unit/test_superfluid.py +++ b/tests/unit/test_superfluid.py @@ -7,6 +7,7 @@ from eth_utils import to_checksum_address from aleph.sdk.chains.ethereum import ETHAccount +from aleph.sdk.evm_utils import FlowUpdate def generate_fake_eth_address(): @@ -24,6 +25,7 @@ def mock_superfluid(): mock_superfluid.create_flow = AsyncMock(return_value="0xTransactionHash") mock_superfluid.delete_flow = AsyncMock(return_value="0xTransactionHash") mock_superfluid.update_flow = AsyncMock(return_value="0xTransactionHash") + mock_superfluid.manage_flow = AsyncMock(return_value="0xTransactionHash") # Mock get_flow to return a mock Web3FlowInfo mock_flow_info = {"timestamp": 0, "flowRate": 0, "deposit": 0, "owedDeposit": 0} @@ -98,3 +100,14 @@ async def test_get_flow(eth_account, mock_superfluid): assert flow_info["flowRate"] == 0 assert flow_info["deposit"] == 0 assert flow_info["owedDeposit"] == 0 + + +@pytest.mark.asyncio +async def test_manage_flow(eth_account, mock_superfluid): + receiver = generate_fake_eth_address() + flow = Decimal("0.005") + + tx_hash = await eth_account.manage_flow(receiver, flow, FlowUpdate.INCREASE) + + assert tx_hash == "0xTransactionHash" + mock_superfluid.manage_flow.assert_awaited_once() diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index bfca23a5..c560455d 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,5 +1,6 @@ import base64 import datetime +from unittest.mock import MagicMock import pytest as pytest from aleph_message.models import ( @@ -158,6 +159,7 @@ def test_parse_immutable_volume(): def test_parse_ephemeral_volume(): volume_dict = { "comment": "Dummy hash", + "mount": "/opt/data", "ephemeral": True, "size_mib": 1, } @@ -169,6 +171,8 @@ def test_parse_ephemeral_volume(): def test_parse_persistent_volume(): volume_dict = { + "comment": "Dummy hash", + "mount": "/opt/data", "parent": { "ref": "QmX8K1c22WmQBAww5ShWQqwMiFif7XFrJD6iFBj7skQZXW", "use_latest": True, @@ -184,9 +188,9 @@ def test_parse_persistent_volume(): assert isinstance(volume, PersistentVolume) -def test_calculate_firmware_hash(mocker): - mock_path = mocker.Mock( - read_bytes=mocker.Mock(return_value=b"abc"), +def test_calculate_firmware_hash(): + mock_path = MagicMock( + read_bytes=MagicMock(return_value=b"abc"), ) assert (