diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index acdeb810..a7195b90 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -41,7 +41,7 @@ jobs: - name: Set up Python for macOS if: startsWith(matrix.os, 'macos') - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: 3.11 diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index a440d1c9..81308d86 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -20,7 +20,7 @@ jobs: runs-on: ${{matrix.os}} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Workaround github issue https://github.com/actions/runner-images/issues/7192 if: startsWith(matrix.os, 'ubuntu-') @@ -35,7 +35,7 @@ jobs: - name: Set up Python for macOS if: startsWith(matrix.os, 'macos') - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: 3.11 diff --git a/.github/workflows/test-docker.yml b/.github/workflows/test-docker.yml index a97d722a..38b58517 100644 --- a/.github/workflows/test-docker.yml +++ b/.github/workflows/test-docker.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 # Use GitHub's Docker registry to cache intermediate layers - run: echo ${{ secrets.GITHUB_TOKEN }} | docker login docker.pkg.github.com diff --git a/pyproject.toml b/pyproject.toml index 260754d6..993d5a47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,8 +30,8 @@ dynamic = [ "version" ] dependencies = [ "aiodns==3.2", "aiohttp==3.11.12", - "aleph-message>=0.6", - "aleph-sdk-python>=1.3,<2", + "aleph-message>=0.6.1", + "aleph-sdk-python>=1.4,<2", "base58==2.1.1", # Needed now as default with _load_account changement "py-sr25519-bindings==0.2", # Needed for DOT signatures "pygments==2.19.1", diff --git a/src/aleph_client/__main__.py b/src/aleph_client/__main__.py index 60775ed6..c86b1f15 100644 --- a/src/aleph_client/__main__.py +++ b/src/aleph_client/__main__.py @@ -11,6 +11,7 @@ instance, message, node, + pricing, program, ) from aleph_client.utils import AsyncTyper @@ -32,6 +33,7 @@ app.add_typer(domain.app, name="domain", help="Manage custom domain (DNS) on aleph.im & twentysix.cloud") app.add_typer(node.app, name="node", help="Get node info on aleph.im & twentysix.cloud") app.add_typer(about.app, name="about", help="Display the informations of Aleph CLI") +app.command("pricing")(pricing.prices_for_service) if __name__ == "__main__": app() diff --git a/src/aleph_client/commands/account.py b/src/aleph_client/commands/account.py index 68025626..3ac4d9b0 100644 --- a/src/aleph_client/commands/account.py +++ b/src/aleph_client/commands/account.py @@ -21,7 +21,7 @@ get_chains_with_super_token, get_compatible_chains, ) -from aleph.sdk.utils import bytes_from_hex +from aleph.sdk.utils import bytes_from_hex, displayable_amount from aleph_message.models import Chain from rich.console import Console from rich.panel import Panel @@ -241,6 +241,20 @@ def sign_bytes( typer.echo("\nSignature: " + signature.hex()) +async def get_balance(address: str) -> dict: + balance_data: dict = {} + uri = f"{settings.API_HOST}/api/v0/addresses/{address}/balance" + async with aiohttp.ClientSession() as session: + response = await session.get(uri) + if response.status == 200: + balance_data = await response.json() + balance_data["available_amount"] = balance_data["balance"] - balance_data["locked_amount"] + else: + error = f"Failed to retrieve balance for address {address}. Status code: {response.status}" + raise Exception(error) + return balance_data + + @app.command() async def balance( address: Optional[str] = typer.Option(None, help="Address"), @@ -255,54 +269,46 @@ async def balance( address = account.get_address() if address: - uri = f"{settings.API_HOST}/api/v0/addresses/{address}/balance" - - async with aiohttp.ClientSession() as session: - response = await session.get(uri) - if response.status == 200: - balance_data = await response.json() - balance_data["available_amount"] = balance_data["balance"] - balance_data["locked_amount"] - - infos = [ - Text.from_markup(f"Address: [bright_cyan]{balance_data['address']}[/bright_cyan]"), - Text.from_markup( - f"\nBalance: [bright_cyan]{balance_data['balance']:.2f}".rstrip("0").rstrip(".") - + "[/bright_cyan]" - ), - ] - details = balance_data.get("details") - if details: - infos += [Text("\n ā†³ Details")] - for chain_, chain_balance in details.items(): - infos += [ - Text.from_markup( - f"\n {chain_}: [orange3]{chain_balance:.2f}".rstrip("0").rstrip(".") + "[/orange3]" - ) - ] - available_color = "bright_cyan" if balance_data["available_amount"] >= 0 else "red" - infos += [ - Text.from_markup( - f"\n - Locked: [bright_cyan]{balance_data['locked_amount']:.2f}".rstrip("0").rstrip(".") - + "[/bright_cyan]" - ), - Text.from_markup( - f"\n - Available: [{available_color}]{balance_data['available_amount']:.2f}".rstrip("0").rstrip( - "." + try: + balance_data = await get_balance(address) + infos = [ + Text.from_markup(f"Address: [bright_cyan]{balance_data['address']}[/bright_cyan]"), + Text.from_markup( + f"\nBalance: [bright_cyan]{displayable_amount(balance_data['balance'], decimals=2)}[/bright_cyan]" + ), + ] + details = balance_data.get("details") + if details: + infos += [Text("\n ā†³ Details")] + for chain_, chain_balance in details.items(): + infos += [ + Text.from_markup( + f"\n {chain_}: [orange3]{displayable_amount(chain_balance, decimals=2)}[/orange3]" ) - + f"[/{available_color}]" - ), - ] - console.print( - Panel( - Text.assemble(*infos), - title="Account Infos", - border_style="bright_cyan", - expand=False, - title_align="left", - ) + ] + available_color = "bright_cyan" if balance_data["available_amount"] >= 0 else "red" + infos += [ + Text.from_markup( + f"\n - Locked: [bright_cyan]{displayable_amount(balance_data['locked_amount'], decimals=2)}" + "[/bright_cyan]" + ), + Text.from_markup( + f"\n - Available: [{available_color}]" + f"{displayable_amount(balance_data['available_amount'], decimals=2)}" + f"[/{available_color}]" + ), + ] + console.print( + Panel( + Text.assemble(*infos), + title="Account Infos", + border_style="bright_cyan", + expand=False, + title_align="left", ) - else: - typer.echo(f"Failed to retrieve balance for address {address}. Status code: {response.status}") + ) + except Exception as e: + typer.echo(e) else: typer.echo("Error: Please provide either a private key, private key file, or an address.") diff --git a/src/aleph_client/commands/files.py b/src/aleph_client/commands/files.py index 8eb00eec..d0f725cb 100644 --- a/src/aleph_client/commands/files.py +++ b/src/aleph_client/commands/files.py @@ -142,7 +142,9 @@ async def download( @app.command() async def forget( - item_hash: str = typer.Argument(..., help="Hash to forget"), + item_hash: str = typer.Argument( + ..., help="Hash(es) to forget. Must be a comma separated list. Example: `123...abc` or `123...abc,456...xyz`" + ), reason: str = typer.Argument("User deletion", help="reason to forget"), channel: Optional[str] = typer.Option(default=settings.DEFAULT_CHANNEL, help=help_strings.CHANNEL), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), @@ -155,8 +157,10 @@ async def forget( account: AccountFromPrivateKey = _load_account(private_key, private_key_file) + hashes = [ItemHash(item_hash) for item_hash in item_hash.split(",")] + async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: - value = await client.forget(hashes=[ItemHash(item_hash)], reason=reason, channel=channel) + value = await client.forget(hashes=hashes, reason=reason, channel=channel) typer.echo(f"{value[0].json(indent=4)}") diff --git a/src/aleph_client/commands/help_strings.py b/src/aleph_client/commands/help_strings.py index 79d7e2b2..3cdbbf2f 100644 --- a/src/aleph_client/commands/help_strings.py +++ b/src/aleph_client/commands/help_strings.py @@ -17,16 +17,15 @@ ASK_FOR_CONFIRMATION = "Prompt user for confirmation" IPFS_CATCH_ALL_PATH = "Choose a relative path to catch all unmatched route or a 404 error" PAYMENT_TYPE = "Payment method, either holding tokens, NFTs, or Pay-As-You-Go via token streaming" -HYPERVISOR = "Hypervisor to use to launch your instance. Defaults to QEMU" +HYPERVISOR = "Hypervisor to use to launch your instance. Always defaults to QEMU, since Firecracker is now deprecated for instances" INSTANCE_NAME = "Name of your new instance" ROOTFS = ( "Hash of the rootfs to use for your instance. Defaults to Ubuntu 22. You can also create your own rootfs and pin it" ) -ROOTFS_SIZE = ( - "Size of the rootfs to use for your instance. If not set, content.size of the --rootfs store message will be used" -) +COMPUTE_UNITS = "Number of compute units to allocate. Compute units correspond to a tier that includes vcpus, memory, disk and gpu presets. For reference, run: `aleph pricing --help`" +ROOTFS_SIZE = "Rootfs size in MiB to allocate" VCPUS = "Number of virtual CPUs to allocate" -MEMORY = "Maximum memory (RAM) allocation on VM in MiB" +MEMORY = "Maximum memory (RAM) in MiB to allocate" TIMEOUT_SECONDS = "If vm is not called after [timeout_seconds] it will shutdown" SSH_PUBKEY_FILE = "Path to a public ssh key to be added to the instance" CRN_HASH = "Hash of the CRN to deploy to (only applicable for confidential and/or Pay-As-You-Go instances)" @@ -37,6 +36,7 @@ CONFIDENTIAL_FIRMWARE_HASH = "Hash of the UEFI Firmware content, to validate measure (ignored if path is provided)" CONFIDENTIAL_FIRMWARE_PATH = "Path to the UEFI Firmware content, to validate measure (instead of the hash)" GPU_OPTION = "Launch an instance attaching a GPU to it" +GPU_PREMIUM_OPTION = "Use Premium GPUs (VRAM > 48GiB)" KEEP_SESSION = "Keeping the already initiated session" VM_SECRET = "Secret password to start the VM" CRN_URL_VM_DELETION = "Domain of the CRN where an associated VM is running. It ensures your VM will be stopped and erased on the CRN before the instance message is actually deleted" @@ -51,6 +51,7 @@ PAYMENT_CHAIN_USED = "Chain you are using to pay for your instance" ORIGIN_CHAIN = "Chain of origin of your private key (ensuring correct parsing)" ADDRESS_CHAIN = "Chain for the address" +ADDRESS_PAYER = "Address of the payer. In order to delegate the payment, your account must be authorized beforehand to publish on the behalf of this address. See the docs for more info: https://docs.aleph.im/protocol/permissions/" CREATE_REPLACE = "Overwrites private key file if it already exists" CREATE_ACTIVE = "Loads the new private key after creation" PROMPT_CRN_URL = "URL of the CRN (Compute node) on which the instance is running" @@ -58,8 +59,10 @@ PROGRAM_PATH = "Path to your source code. Can be a directory, a .squashfs file or a .zip archive" PROGRAM_ENTRYPOINT = "Your program entrypoint. Example: `main:app` for Python programs, else `run.sh` for a script containing your launch command" PROGRAM_RUNTIME = "Hash of the runtime to use for your program. You can also create your own runtime and pin it. Currently defaults to `{runtime_id}` (Use `aleph program runtime-checker` to inspect it)" -PROGRAM_BETA = "If true, you will be prompted to add message subscriptions to your program" +PROGRAM_INTERNET = "Enable internet access for your program. By default, internet access is disabled" +PROGRAM_PERSISTENT = "Create your program as persistent. By default, programs are ephemeral (serverless): they only start when called and then shutdown after the defined timeout delay." PROGRAM_UPDATABLE = "Allow program updates. By default, only the source code can be modified without requiring redeployement (same item hash). When enabled (set to True), this option allows to update any other field. However, such modifications will require a program redeployment (new item hash)" +PROGRAM_BETA = "If true, you will be prompted to add message subscriptions to your program" PROGRAM_KEEP_CODE = "Keep the source code intact instead of deleting it" PROGRAM_KEEP_PREV = "Keep the previous program intact instead of deleting it" TARGET_ADDRESS = "Target address. Defaults to current account address" diff --git a/src/aleph_client/commands/instance/__init__.py b/src/aleph_client/commands/instance/__init__.py index 2be90815..3e5385e8 100644 --- a/src/aleph_client/commands/instance/__init__.py +++ b/src/aleph_client/commands/instance/__init__.py @@ -6,9 +6,8 @@ import logging import shutil from decimal import Decimal -from math import ceil from pathlib import Path -from typing import Optional, cast +from typing import Any, Optional, Union, cast import aiohttp import typer @@ -18,7 +17,11 @@ from aleph.sdk.client.vm_client import VmClient from aleph.sdk.client.vm_confidential_client import VmConfidentialClient from aleph.sdk.conf import load_main_configuration, settings -from aleph.sdk.evm_utils import get_chains_with_holding, get_chains_with_super_token +from aleph.sdk.evm_utils import ( + FlowUpdate, + get_chains_with_holding, + get_chains_with_super_token, +) from aleph.sdk.exceptions import ( ForgottenMessageError, InsufficientFundsError, @@ -26,8 +29,13 @@ ) from aleph.sdk.query.filters import MessageFilter from aleph.sdk.query.responses import PriceResponse -from aleph.sdk.types import StorageEnum -from aleph.sdk.utils import calculate_firmware_hash, safe_getattr +from aleph.sdk.types import StorageEnum, TokenType +from aleph.sdk.utils import ( + calculate_firmware_hash, + displayable_amount, + make_instance_content, + safe_getattr, +) from aleph_message.models import Chain, InstanceMessage, MessageType, StoreMessage from aleph_message.models.execution.base import Payment, PaymentType from aleph_message.models.execution.environment import ( @@ -47,17 +55,20 @@ from rich.text import Text from aleph_client.commands import help_strings +from aleph_client.commands.account import get_balance from aleph_client.commands.instance.display import CRNTable from aleph_client.commands.instance.network import ( fetch_crn_info, + fetch_crn_list, + fetch_settings, fetch_vm_info, find_crn_of_vm, ) -from aleph_client.commands.instance.superfluid import FlowUpdate, update_flow -from aleph_client.commands.node import NodeInfo, _fetch_nodes +from aleph_client.commands.pricing import PricingEntity, SelectedTier, fetch_pricing from aleph_client.commands.utils import ( filter_only_valid_messages, find_sevctl_or_exit, + found_gpus_by_model, get_or_prompt_volumes, setup_logging, str_to_datetime, @@ -68,19 +79,16 @@ wait_for_processed_instance, yes_no_input, ) -from aleph_client.models import CRNInfo from aleph_client.utils import AsyncTyper, sanitize_url logger = logging.getLogger(__name__) app = AsyncTyper(no_args_is_help=True) -# TODO: This should be put on the API to get always from there -FLOW_INSTANCE_PRICE_PER_SECOND = Decimal("0.0000155") # 0.055/h - +metavar_valid_payment_types = f"[{'|'.join(PaymentType)}|nft]" hold_chains = [*get_chains_with_holding(), Chain.SOL] -super_token_chains = get_chains_with_super_token() metavar_valid_chains = f"[{'|'.join(hold_chains)}]" -metavar_valid_payment_types = f"[{'|'.join(PaymentType)}|nft]" +super_token_chains = get_chains_with_super_token() +metavar_valid_payg_chains = f"[{'|'.join(super_token_chains)}]" @app.command() @@ -98,12 +106,13 @@ async def create( metavar=metavar_valid_chains, case_sensitive=False, ), - hypervisor: Optional[HypervisorType] = typer.Option(HypervisorType.qemu, help=help_strings.HYPERVISOR), + hypervisor: HypervisorType = typer.Option(HypervisorType.qemu, help=help_strings.HYPERVISOR), name: Optional[str] = typer.Option(None, help=help_strings.INSTANCE_NAME), rootfs: Optional[str] = typer.Option(None, help=help_strings.ROOTFS), - rootfs_size: Optional[int] = typer.Option(None, help=help_strings.ROOTFS_SIZE), + compute_units: Optional[int] = typer.Option(None, help=help_strings.COMPUTE_UNITS), vcpus: Optional[int] = typer.Option(None, help=help_strings.VCPUS), memory: Optional[int] = typer.Option(None, help=help_strings.MEMORY), + rootfs_size: Optional[int] = typer.Option(None, help=help_strings.ROOTFS_SIZE), timeout_seconds: float = typer.Option( settings.DEFAULT_VM_TIMEOUT, help=help_strings.TIMEOUT_SECONDS, @@ -112,6 +121,7 @@ async def create( Path("~/.ssh/id_rsa.pub").expanduser(), help=help_strings.SSH_PUBKEY_FILE, ), + address: Optional[str] = typer.Option(None, help=help_strings.ADDRESS_PAYER), crn_hash: Optional[str] = typer.Option(None, help=help_strings.CRN_HASH), crn_url: Optional[str] = typer.Option(None, help=help_strings.CRN_URL), confidential: bool = typer.Option(False, help=help_strings.CONFIDENTIAL_OPTION), @@ -119,6 +129,7 @@ async def create( default=settings.DEFAULT_CONFIDENTIAL_FIRMWARE, help=help_strings.CONFIDENTIAL_FIRMWARE ), gpu: bool = typer.Option(False, help=help_strings.GPU_OPTION), + premium: Optional[bool] = typer.Option(None, help=help_strings.GPU_PREMIUM_OPTION), skip_volume: bool = typer.Option(False, help=help_strings.SKIP_VOLUME), persistent_volume: Optional[list[str]] = typer.Option(None, help=help_strings.PERSISTENT_VOLUME), ephemeral_volume: Optional[list[str]] = typer.Option(None, help=help_strings.EPHEMERAL_VOLUME), @@ -150,6 +161,10 @@ async def create( ) ssh_pubkey: str = ssh_pubkey_file.read_text(encoding="utf-8").strip() + # Populates account / address + account = _load_account(private_key, private_key_file, chain=payment_chain) + address = address or settings.ADDRESS_TO_USE or account.get_address() + # Loads default configuration if no chain is set if payment_chain is None: config = load_main_configuration(settings.CONFIG_FILE) @@ -168,15 +183,22 @@ async def create( ) # Force-switches if NFT payment-type + nft_chains = [Chain.AVAX, Chain.BASE, Chain.SOL] if payment_type == "nft": payment_type = PaymentType.hold - payment_chain = Chain( - Prompt.ask( - "On which chain did you claim your NFT voucher?", - choices=[Chain.AVAX.value, Chain.BASE.value, Chain.SOL.value], - default=Chain.AVAX.value, + if payment_chain is None or payment_chain not in nft_chains: + if payment_chain: + console.print( + f"[red]{safe_getattr(payment_chain, 'value') or payment_chain}[/red]" + " incompatible with NFT vouchers." + ) + payment_chain = Chain( + Prompt.ask( + "On which chain did you claim your NFT voucher?", + choices=[nft_chain.value for nft_chain in nft_chains], + default=Chain.AVAX.value, + ) ) - ) elif payment_type in [ptype.value for ptype in PaymentType]: payment_type = PaymentType(payment_type) else: @@ -186,6 +208,9 @@ async def create( # Checks if payment-chain is compatible with PAYG is_stream = payment_type != PaymentType.hold if is_stream: + if address != account.get_address(): + console.print("Payment delegation is incompatible with Pay-As-You-Go.") + raise typer.Exit(code=1) if payment_chain is None or payment_chain not in super_token_chains: if payment_chain: console.print( @@ -199,6 +224,7 @@ async def create( default=Chain.AVAX.value, ) ) + # Fallback for Hold-tier if no config / no chain is set / chain not in hold_chains elif payment_chain is None or payment_chain not in hold_chains: if payment_chain: @@ -213,58 +239,18 @@ async def create( ) ) - # Populates account - account = _load_account(private_key, private_key_file, chain=payment_chain) - - # Checks required balances (Gas + Aleph ERC20) for superfluid payment - if is_stream and isinstance(account, ETHAccount): - if account.CHAIN != payment_chain: - account.switch_chain(payment_chain) - if account.superfluid_connector and hasattr(account.superfluid_connector, "can_start_flow"): - try: # Quick check with theoretical min price - account.superfluid_connector.can_start_flow(FLOW_INSTANCE_PRICE_PER_SECOND) # 0.055/h - except Exception as e: - echo(e) - raise typer.Exit(code=1) from e - else: - echo("Superfluid connector not available on this chain.") - raise typer.Exit(code=1) + # Ensure hypervisor is compatible + if hypervisor != HypervisorType.qemu: + console.print("QEMU is now the only supported hypervisor. Firecracker has been deprecated for instances.") + raise typer.Exit(code=1) - # Checks if Hypervisor is compatible with confidential or with GPU support - if confidential or gpu: - if hypervisor and hypervisor != HypervisorType.qemu: - echo("Only QEMU is supported as an hypervisor for confidential") - raise typer.Exit(code=1) - elif not hypervisor: - echo("Using QEMU as hypervisor for confidential or GPU support") - hypervisor = HypervisorType.qemu - - available_hypervisors = { - HypervisorType.firecracker: { - "ubuntu22": settings.UBUNTU_22_ROOTFS_ID, - "debian12": settings.DEBIAN_12_ROOTFS_ID, - "debian11": settings.DEBIAN_11_ROOTFS_ID, - }, - HypervisorType.qemu: { - "ubuntu22": settings.UBUNTU_22_QEMU_ROOTFS_ID, - "debian12": settings.DEBIAN_12_QEMU_ROOTFS_ID, - "debian11": settings.DEBIAN_11_QEMU_ROOTFS_ID, - }, + os_choices = { + "ubuntu22": settings.UBUNTU_22_QEMU_ROOTFS_ID, + "ubuntu24": settings.UBUNTU_24_QEMU_ROOTFS_ID, + "debian12": settings.DEBIAN_12_QEMU_ROOTFS_ID, } - if hypervisor is None: - hypervisor_choice = HypervisorType[ - Prompt.ask( - "Which hypervisor you want to use?", - default=settings.DEFAULT_HYPERVISOR.name, - choices=[x.name for x in available_hypervisors], - ) - ] - hypervisor = HypervisorType(hypervisor_choice) - is_qemu = hypervisor == HypervisorType.qemu - - os_choices = available_hypervisors[hypervisor] - + # Rootfs selection if not rootfs or len(rootfs) != 64: if confidential: # Confidential only support custom rootfs @@ -295,8 +281,6 @@ async def create( echo(f"Given rootfs volume {rootfs} has been deleted on aleph.im") if not rootfs_message: raise typer.Exit(code=1) - elif rootfs_size is None: - rootfs_size = safe_getattr(rootfs_message, "content.size") # Validate confidential firmware message exist confidential_firmware_as_hash = None @@ -313,20 +297,45 @@ async def create( if not firmware_message: raise typer.Exit(code=1) - name = name or validated_prompt("Instance name", lambda x: len(x) < 65) - rootfs_size = rootfs_size or validated_int_prompt( - "Disk size in MiB", default=settings.DEFAULT_ROOTFS_SIZE, min_value=10_240, max_value=542_288 - ) - vcpus = vcpus or validated_int_prompt( - "Number of virtual cpus to allocate", default=settings.DEFAULT_VM_VCPUS, min_value=1, max_value=12 + # Filter and prepare the list of available GPUs + crn_list = None + found_gpu_models: Optional[dict[str, dict[str, dict[str, int]]]] = None + if gpu: + echo("Fetching available GPU list...") + crn_list = await fetch_crn_list(latest_crn_version=True, ipv6=True, stream_address=True, gpu=True) + found_gpu_models = found_gpus_by_model(crn_list) + if not found_gpu_models: + echo("No available GPU found. Try again later.") + raise typer.Exit(code=1) + premium = yes_no_input(f"{help_strings.GPU_PREMIUM_OPTION}?", default=False) if premium is None else premium + + pricing = await fetch_pricing() + pricing_entity = ( + PricingEntity.INSTANCE_CONFIDENTIAL + if confidential + else ( + PricingEntity.INSTANCE_GPU_PREMIUM + if gpu and premium + else PricingEntity.INSTANCE_GPU_STANDARD if gpu else PricingEntity.INSTANCE + ) ) - memory = memory or validated_int_prompt( - "Maximum memory allocation on vm in MiB", - default=settings.DEFAULT_INSTANCE_MEMORY, - min_value=2_048, - max_value=24_576, + tier = cast( # Safe cast + SelectedTier, + pricing.display_table_for( + pricing_entity, + compute_units=compute_units or 0, + vcpus=vcpus or 0, + memory=memory or 0, + disk=rootfs_size or 0, + gpu_models=found_gpu_models, + selector=True, + ), ) - + name = name or validated_prompt("Instance name", lambda x: x and len(x) < 65) + vcpus = tier.vcpus + memory = tier.memory + rootfs_size = tier.disk + gpu_model = tier.gpu_model volumes = [] if not skip_volume: volumes = get_or_prompt_volumes( @@ -335,67 +344,61 @@ async def create( immutable_volume=immutable_volume, ) + # Early check with minimal cost (Gas + Aleph ERC20) + available_funds = Decimal(0 if is_stream else (await get_balance(address))["available_amount"]) + try: + if is_stream and isinstance(account, ETHAccount): + if account.CHAIN != payment_chain: + account.switch_chain(payment_chain) + if safe_getattr(account, "superfluid_connector"): + account.can_start_flow(tier.price.payg) + else: + echo("Superfluid connector not available on this chain.") + raise typer.Exit(code=1) + elif available_funds < tier.price.hold: + raise InsufficientFundsError(TokenType.ALEPH, float(tier.price.hold), float(available_funds)) + except InsufficientFundsError as e: + echo(e) + raise typer.Exit(code=1) from e + stream_reward_address = None crn = None if is_stream or confidential or gpu: - if crn_url and crn_hash: - crn_url = sanitize_url(crn_url) + if crn_url: try: - crn_name, score, reward_addr, terms_and_conditions = "?", 0, "", None - nodes: NodeInfo = await _fetch_nodes() - for node in nodes.nodes: - found_node, hash_match = None, False - try: - if sanitize_url(node["address"]) == crn_url: - found_node = node - if found_node["hash"] == crn_hash: - hash_match = True - except aiohttp.InvalidURL: - logger.debug(f"Invalid URL for node `{node['hash']}`: {node['address']}") - if found_node: - if hash_match: - crn_name = found_node["name"] - score = found_node["score"] - reward_addr = found_node["stream_reward"] - terms_and_conditions = node["terms_and_conditions"] - break - else: - echo( - f"* Provided CRN *\nUrl: {crn_url}\nHash: {crn_hash}\n\n* Found CRN *\nUrl: " - f"{found_node['address']}\nHash: {found_node['hash']}\n\nMismatch between provided CRN " - "and found CRN" - ) - raise typer.Exit(1) - if crn_name == "?": - echo(f"* Provided CRN *\nUrl: {crn_url}\nHash: {crn_hash}\n\nCRN not found in aggregate") - raise typer.Exit(1) - crn_info = await fetch_crn_info(crn_url) - if crn_info: - crn = CRNInfo( - hash=ItemHash(crn_hash), - name=crn_name or "?", - url=crn_url, - version=crn_info.get("version", ""), - score=score, - stream_reward_address=str(crn_info.get("payment", {}).get("PAYMENT_RECEIVER_ADDRESS")) - or reward_addr - or "", - machine_usage=crn_info.get("machine_usage"), - qemu_support=bool(crn_info.get("computing", {}).get("ENABLE_QEMU_SUPPORT", False)), - confidential_computing=bool( - crn_info.get("computing", {}).get("ENABLE_CONFIDENTIAL_COMPUTING", False) - ), - gpu_support=bool(crn_info.get("computing", {}).get("ENABLE_GPU_SUPPORT", False)), - terms_and_conditions=terms_and_conditions, - ) + crn_url = sanitize_url(crn_url) + except aiohttp.InvalidURL as e: + echo(f"Invalid URL provided: {crn_url}") + raise typer.Exit(1) from e + + echo("Fetching compute resource node's list...") + crn_list = await fetch_crn_list() # Precache CRN list + + if (crn_url or crn_hash) and not gpu: + try: + crn = await fetch_crn_info(crn_url, crn_hash) + if crn: + if (crn_hash and crn_hash != crn.hash) or (crn_url and crn_url != crn.url): + echo( + f"* Provided CRN *\nUrl: {crn_url}\nHash: {crn_hash}\n\n* Found CRN *\nUrl: " + f"{crn.url}\nHash: {crn.hash}\n\nMismatch between provided CRN and found CRN" + ) + raise typer.Exit(1) crn.display_crn_specs() + else: + echo(f"* Provided CRN *\nUrl: {crn_url}\nHash: {crn_hash}\n\nProvided CRN not found") + raise typer.Exit(1) except Exception as e: - echo(f"Unable to fetch CRN config: {e}") raise typer.Exit(1) from e while not crn: crn_table = CRNTable( - only_reward_address=is_stream, only_qemu=is_qemu, only_confidentials=confidential, only_gpu=gpu + only_latest_crn_version=True, + only_reward_address=is_stream, + only_qemu=True, + only_confidentials=confidential, + only_gpu=gpu, + only_gpu_model=gpu_model, ) crn = await crn_table.run_async() if not crn: @@ -411,13 +414,13 @@ async def create( "instances are scheduled automatically on available CRNs by the Aleph.im network." ) - requirements, trusted_execution, gpu_requirement = None, None, None + requirements, trusted_execution, gpu_requirement, tac_accepted = None, None, None, None if crn: stream_reward_address = safe_getattr(crn, "stream_reward_address") or "" if is_stream and not stream_reward_address: - echo("Selected CRN does not have a defined receiver address.") + echo("Selected CRN does not have a defined or valid receiver address.") raise typer.Exit(1) - if is_qemu and not safe_getattr(crn, "qemu_support"): + if not safe_getattr(crn, "qemu_support"): echo("Selected CRN does not support QEMU hypervisor.") raise typer.Exit(1) if confidential: @@ -470,51 +473,77 @@ async def create( ) ] if crn.terms_and_conditions: - accepted = await crn.display_terms_and_conditions(auto_accept=crn_auto_tac) - if accepted is None: + tac_accepted = await crn.display_terms_and_conditions(auto_accept=crn_auto_tac) + if tac_accepted is None: echo("Failed to fetch terms and conditions.\nContact support or use a different CRN.") raise typer.Exit(1) - elif not accepted: + elif not tac_accepted: echo("Terms & Conditions rejected: instance creation aborted.") raise typer.Exit(1) echo("Terms & Conditions accepted.") + requirements = HostRequirements( node=NodeRequirements( node_hash=crn.hash, - terms_and_conditions=(ItemHash(crn.terms_and_conditions) if crn.terms_and_conditions else None), + terms_and_conditions=(ItemHash(crn.terms_and_conditions) if tac_accepted else None), ), gpu=gpu_requirement, ) + payment = Payment( + chain=payment_chain, + receiver=stream_reward_address if stream_reward_address else None, + type=payment_type, + ) + + content_dict: dict[str, Any] = { + "address": address, + "rootfs": rootfs, + "rootfs_size": rootfs_size, + "metadata": {"name": name}, + "memory": memory, + "vcpus": vcpus, + "timeout_seconds": timeout_seconds, + "volumes": volumes, + "ssh_keys": [ssh_pubkey], + "hypervisor": hypervisor, + "payment": payment, + "requirements": requirements, + "trusted_execution": trusted_execution, + } + + # Estimate cost and check required balances (Gas + Aleph ERC20) + required_tokens: Decimal + async with AlephHttpClient(api_server=settings.API_HOST) as client: + try: + content = make_instance_content(**content_dict) + price: PriceResponse = await client.get_estimated_price(content) + required_tokens = Decimal(price.required_tokens) + except Exception as e: + echo(f"Failed to estimate instance cost, error: {e}") + raise typer.Exit(code=1) from e + + try: + if is_stream and isinstance(account, ETHAccount): + account.can_start_flow(required_tokens) + elif available_funds < required_tokens: + raise InsufficientFundsError(TokenType.ALEPH, float(required_tokens), float(available_funds)) + except InsufficientFundsError as e: + echo(e) + raise typer.Exit(code=1) from e + async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: - payment = Payment( - chain=payment_chain, - receiver=stream_reward_address if stream_reward_address else None, - type=payment_type, - ) try: message, status = await client.create_instance( - sync=True, - rootfs=rootfs, - rootfs_size=rootfs_size, - storage_engine=StorageEnum.storage, + **content_dict, channel=channel, - metadata={"name": name}, - memory=memory, - vcpus=vcpus, - timeout_seconds=timeout_seconds, - volumes=volumes, - ssh_keys=[ssh_pubkey], - hypervisor=hypervisor, - payment=payment, - requirements=requirements, - trusted_execution=trusted_execution, + storage_engine=StorageEnum.storage, + sync=True, ) except InsufficientFundsError as e: echo( f"Instance creation failed due to insufficient funds.\n" - f"{account.get_address()} on {account.CHAIN} has {e.available_funds} ALEPH but " - f"needs {e.required_funds} ALEPH." + f"{address} on {account.CHAIN} has {e.available_funds} ALEPH but needs {e.required_funds} ALEPH." ) raise typer.Exit(code=1) from e except Exception as e: @@ -539,45 +568,53 @@ async def create( await wait_for_processed_instance(session, item_hash) # Pay-As-You-Go - if payment_type == PaymentType.superfluid: - price: PriceResponse = await client.get_program_price(item_hash) - ceil_factor = 10**18 - required_tokens = ceil(Decimal(price.required_tokens) * ceil_factor) / ceil_factor - if isinstance(account, ETHAccount) and account.superfluid_connector: - try: # Double check with effective price - account.superfluid_connector.can_start_flow(FLOW_INSTANCE_PRICE_PER_SECOND) # Min for 0.11/h - except Exception as e: - echo(e) - raise typer.Exit(code=1) from e - flow_hash = await update_flow( - account=account, - receiver=crn.stream_reward_address, - flow=Decimal(required_tokens), + if is_stream and isinstance(account, ETHAccount): + # Start the flows + echo("Starting the flows...") + fetched_settings = await fetch_settings() + community_wallet_address = fetched_settings.get("community_wallet_address") + flow_crn_amount = required_tokens * Decimal("0.8") + flow_hash_crn = await account.manage_flow( + receiver=crn.stream_reward_address, + flow=flow_crn_amount, + update_type=FlowUpdate.INCREASE, + ) + if flow_hash_crn: + await asyncio.sleep(5) # 2nd flow tx fails if no delay + flow_hash_community = await account.manage_flow( + receiver=community_wallet_address, + flow=required_tokens - flow_crn_amount, update_type=FlowUpdate.INCREASE, ) - # Wait for the flow transaction to be confirmed - await wait_for_confirmed_flow(account, message.content.payment.receiver) - if flow_hash: - flow_info = "\n".join( - f"[orange3]{key}[/orange3]: {value}" - for key, value in { - "Hash": flow_hash, - "Aleph cost": ( - f"{price.required_tokens:.7f}/sec | {3600*price.required_tokens:.2f}/hour | " - f"{86400*price.required_tokens:.2f}/day | {2592000*price.required_tokens:.2f}/month" - ), - "CRN receiver address": crn.stream_reward_address, - }.items() - ) - console.print( - Panel( - flow_info, - title="Flow Created", - border_style="violet", - expand=False, - title_align="left", - ) + else: + echo("Flow creation failed. Check your wallet balance and try recreate the VM.") + raise typer.Exit(code=1) + # Wait for the flow transactions to be confirmed + await wait_for_confirmed_flow(account, crn.stream_reward_address) + await wait_for_confirmed_flow(account, community_wallet_address) + if flow_hash_crn and flow_hash_community: + flow_info = "\n".join( + f"[orange3]{key}[/orange3]: {value}" + for key, value in { + "$ALEPH": f"[violet]{displayable_amount(required_tokens, decimals=8)}/sec" + f" | {displayable_amount(3600*required_tokens, decimals=3)}/hour" + f" | {displayable_amount(86400*required_tokens, decimals=3)}/day" + f" | {displayable_amount(2628000*required_tokens, decimals=3)}/month[/violet]", + "Flow Distribution": "\n[bright_cyan]80% -> CRN wallet[/bright_cyan]" + f"\n Address: {crn.stream_reward_address}\n Tx: {flow_hash_crn}" + f"\n[bright_cyan]20% -> Community wallet[/bright_cyan]" + f"\n Address: {community_wallet_address}\n Tx: {flow_hash_community}", + }.items() + ) + console.print( + Panel( + Text.from_markup(flow_info), + title="Flows Created", + border_style="violet", + expand=False, + title_align="left", ) + ) # Notify CRN async with VmClient(account, crn.url) as crn_client: @@ -678,13 +715,14 @@ async def delete( echo("Instance does not exist") raise typer.Exit(code=1) from None except ForgottenMessageError: - echo("Instance already forgotten") + echo("Instance already deleted") raise typer.Exit(code=1) from None if existing_message.sender != account.get_address(): echo("You are not the owner of this instance") raise typer.Exit(code=1) - # If PAYG, retrieve flow price + # If PAYG, retrieve creation time & flow price + creation_time: float = existing_message.content.time payment: Optional[Payment] = existing_message.content.payment price: Optional[PriceResponse] = None if safe_getattr(payment, "type") == PaymentType.superfluid: @@ -694,8 +732,7 @@ async def delete( chain = existing_message.content.payment.chain # type: ignore # Check status of the instance and eventually erase associated VM - node_list: NodeInfo = await _fetch_nodes() - _, info = await fetch_vm_info(existing_message, node_list) + _, info = await fetch_vm_info(existing_message) auto_scheduled = info["allocation_type"] == help_strings.ALLOCATION_AUTO crn_url = (info["crn_url"] not in [help_strings.CRN_PENDING, help_strings.CRN_UNKNOWN] and info["crn_url"]) or ( domain and sanitize_url(domain) @@ -721,12 +758,36 @@ async def delete( if payment and payment.type == PaymentType.superfluid and payment.receiver and isinstance(account, ETHAccount): if account.CHAIN != payment.chain: account.switch_chain(payment.chain) - if account.superfluid_connector and price: - flow_hash = await update_flow( - account, payment.receiver, Decimal(price.required_tokens), FlowUpdate.REDUCE + if safe_getattr(account, "superfluid_connector") and price: + fetched_settings = await fetch_settings() + community_wallet_timestamp = fetched_settings.get("community_wallet_timestamp") + community_wallet_address = fetched_settings.get("community_wallet_address") + try: # Safety check to ensure account can transact + account.can_transact() + except Exception as e: + echo(e) + raise typer.Exit(code=1) from e + echo("Deleting the flows...") + flow_crn_percent = Decimal("0.8") if community_wallet_timestamp < creation_time else Decimal("1") + flow_com_percent = Decimal("1") - flow_crn_percent + flow_hash_crn = await account.manage_flow( + payment.receiver, Decimal(price.required_tokens) * flow_crn_percent, FlowUpdate.REDUCE ) - if flow_hash: - echo(f"Flow {flow_hash} has been deleted.") + if flow_hash_crn: + echo(f"CRN flow has been deleted successfully (Tx: {flow_hash_crn})") + if flow_com_percent > Decimal("0"): + await asyncio.sleep(5) + flow_hash_community = await account.manage_flow( + community_wallet_address, + Decimal(price.required_tokens) * flow_com_percent, + FlowUpdate.REDUCE, + ) + if flow_hash_community: + echo(f"Community flow has been deleted successfully (Tx: {flow_hash_community})") + else: + echo("No community flow to delete (legacy instance). Skipping...") + else: + echo("No flow to delete. Skipping...") message, status = await client.forget(hashes=[ItemHash(item_hash)], reason=reason) if print_message: @@ -734,13 +795,14 @@ async def delete( echo(f"Instance {item_hash} has been deleted.") -async def _show_instances(messages: builtins.list[InstanceMessage], node_list: NodeInfo): +async def _show_instances(messages: builtins.list[InstanceMessage]): table = Table(box=box.ROUNDED, style="blue_violet") table.add_column(f"Instances [{len(messages)}]", style="blue", overflow="fold") table.add_column("Specifications", style="blue") table.add_column("Logs", style="blue", overflow="fold") - scheduler_responses = dict(await asyncio.gather(*[fetch_vm_info(message, node_list) for message in messages])) + await fetch_crn_list() # Precache CRN list + scheduler_responses = dict(await asyncio.gather(*[fetch_vm_info(message) for message in messages])) uninitialized_confidential_found = False for message in messages: info = scheduler_responses[message.item_hash] @@ -759,12 +821,11 @@ async def _show_instances(messages: builtins.list[InstanceMessage], node_list: N link = f"https://explorer.aleph.im/address/ETH/{message.sender}/message/INSTANCE/{message.item_hash}" # link = f"{settings.API_HOST}/api/v0/messages/{message.item_hash}" item_hash_link = Text.from_markup(f"[link={link}]{message.item_hash}[/link]", style="bright_cyan") - is_hold = info["payment"] == "hold" payment = Text.assemble( "Payment: ", Text( info["payment"].capitalize().ljust(12), - style="red" if is_hold else "orange3", + style="red" if info["payment"] == PaymentType.hold.value else "orange3", ), ) confidential = Text.assemble( @@ -774,15 +835,21 @@ async def _show_instances(messages: builtins.list[InstanceMessage], node_list: N created_at = Text.assemble( "Created at: ", Text(str(str_to_datetime(info["created_at"])).split(".", maxsplit=1)[0], style="orchid") ) - cost: Text | str = "" - if not is_hold: - async with AlephHttpClient(api_server=settings.API_HOST) as client: - price: PriceResponse = await client.get_program_price(message.item_hash) - psec = Text(f"{price.required_tokens:.7f}/sec", style="magenta3") - phour = Text(f"{3600*price.required_tokens:.2f}/hour", style="magenta3") - pday = Text(f"{86400*price.required_tokens:.2f}/day", style="magenta3") - pmonth = Text(f"{2592000*price.required_tokens:.2f}/month", style="magenta3") - cost = Text.assemble("\nAleph cost: ", psec, " | ", phour, " | ", pday, " | ", pmonth) + async with AlephHttpClient(api_server=settings.API_HOST) as client: + price: PriceResponse = await client.get_program_price(message.item_hash) + required_tokens = Decimal(price.required_tokens) + if price.payment_type == PaymentType.hold.value: + aleph_price = Text(f"{displayable_amount(required_tokens, decimals=3)} (fixed)", style="violet") + else: + psec = f"{displayable_amount(required_tokens, decimals=8)}/sec" + phour = f"{displayable_amount(3600*required_tokens, decimals=3)}/hour" + pday = f"{displayable_amount(86400*required_tokens, decimals=3)}/day" + pmonth = f"{displayable_amount(2628000*required_tokens, decimals=3)}/month" + aleph_price = Text.assemble(psec, " | ", phour, " | ", pday, " | ", pmonth, style="violet") + cost = Text.assemble("\n$ALEPH: ", aleph_price) + payer: Union[str, Text] = "" + if message.sender != message.content.address: + payer = Text.assemble("\nPayer: ", Text(str(message.sender), style="orange1")) instance = Text.assemble( "Item Hash ā†“\t Name: ", name, @@ -795,6 +862,7 @@ async def _show_instances(messages: builtins.list[InstanceMessage], node_list: N chain, created_at, cost, + payer, ) hypervisor = safe_getattr(message, "content.environment.hypervisor") specs = [ @@ -856,7 +924,7 @@ async def _show_instances(messages: builtins.list[InstanceMessage], node_list: N console = Console() console.print(table) - infos = [Text.from_markup(f"[bold]Address:[/bold] [bright_cyan]{messages[0].content.address}[/bright_cyan]")] + infos = [Text.from_markup(f"[bold]Address:[/bold] [bright_cyan]{messages[0].sender}[/bright_cyan]")] if uninitialized_confidential_found: infos += [ Text.assemble( @@ -893,9 +961,8 @@ async def list_instances( setup_logging(debug) - if address is None: - account = _load_account(private_key, private_key_file, chain=chain) - address = account.get_address() + account = _load_account(private_key, private_key_file, chain=chain) + address = address or settings.ADDRESS_TO_USE or account.get_address() async with AlephHttpClient(api_server=settings.API_HOST) as client: resp = await client.get_messages( @@ -915,8 +982,7 @@ async def list_instances( else: # Since we filtered on message type, we can safely cast as InstanceMessage. messages = cast(builtins.list[InstanceMessage], messages) - resource_nodes: NodeInfo = await _fetch_nodes() - await _show_instances(messages, resource_nodes) + await _show_instances(messages) @app.command() @@ -1232,9 +1298,10 @@ async def confidential_create( ), name: Optional[str] = typer.Option(None, help=help_strings.INSTANCE_NAME), rootfs: Optional[str] = typer.Option(None, help=help_strings.ROOTFS), - rootfs_size: Optional[int] = typer.Option(None, help=help_strings.ROOTFS_SIZE), + compute_units: Optional[int] = typer.Option(None, help=help_strings.COMPUTE_UNITS), vcpus: Optional[int] = typer.Option(None, help=help_strings.VCPUS), memory: Optional[int] = typer.Option(None, help=help_strings.MEMORY), + rootfs_size: Optional[int] = typer.Option(None, help=help_strings.ROOTFS_SIZE), timeout_seconds: float = typer.Option( settings.DEFAULT_VM_TIMEOUT, help=help_strings.TIMEOUT_SECONDS, @@ -1243,7 +1310,9 @@ async def confidential_create( Path("~/.ssh/id_rsa.pub").expanduser(), help=help_strings.SSH_PUBKEY_FILE, ), + address: Optional[str] = typer.Option(None, help=help_strings.ADDRESS_PAYER), gpu: bool = typer.Option(False, help=help_strings.GPU_OPTION), + premium: Optional[bool] = typer.Option(None, help=help_strings.GPU_PREMIUM_OPTION), skip_volume: bool = typer.Option(False, help=help_strings.SKIP_VOLUME), persistent_volume: Optional[list[str]] = typer.Option(None, help=help_strings.PERSISTENT_VOLUME), ephemeral_volume: Optional[list[str]] = typer.Option(None, help=help_strings.EPHEMERAL_VOLUME), @@ -1277,17 +1346,20 @@ async def confidential_create( hypervisor=HypervisorType.qemu, name=name, rootfs=rootfs, - rootfs_size=rootfs_size, + compute_units=compute_units, vcpus=vcpus, memory=memory, + rootfs_size=rootfs_size, timeout_seconds=timeout_seconds, ssh_pubkey_file=ssh_pubkey_file, + address=address, crn_hash=crn_hash, crn_url=crn_url, crn_auto_tac=crn_auto_tac, confidential=True, confidential_firmware=confidential_firmware, gpu=gpu, + premium=premium, skip_volume=skip_volume, persistent_volume=persistent_volume, ephemeral_volume=ephemeral_volume, @@ -1363,3 +1435,79 @@ async def confidential_create( verbose=True, debug=debug, ) + + +@app.command(name="gpu") +async def gpu_create( + payment_chain: Optional[Chain] = typer.Option( + None, + help=help_strings.PAYMENT_CHAIN, + metavar=metavar_valid_payg_chains, + case_sensitive=False, + ), + name: Optional[str] = typer.Option(None, help=help_strings.INSTANCE_NAME), + rootfs: Optional[str] = typer.Option(None, help=help_strings.ROOTFS), + compute_units: Optional[int] = typer.Option(None, help=help_strings.COMPUTE_UNITS), + vcpus: Optional[int] = typer.Option(None, help=help_strings.VCPUS), + memory: Optional[int] = typer.Option(None, help=help_strings.MEMORY), + rootfs_size: Optional[int] = typer.Option(None, help=help_strings.ROOTFS_SIZE), + premium: Optional[bool] = typer.Option(None, help=help_strings.GPU_PREMIUM_OPTION), + timeout_seconds: float = typer.Option( + settings.DEFAULT_VM_TIMEOUT, + help=help_strings.TIMEOUT_SECONDS, + ), + ssh_pubkey_file: Path = typer.Option( + Path("~/.ssh/id_rsa.pub").expanduser(), + help=help_strings.SSH_PUBKEY_FILE, + ), + address: Optional[str] = typer.Option(None, help=help_strings.ADDRESS_PAYER), + crn_hash: Optional[str] = typer.Option(None, help=help_strings.CRN_HASH), + crn_url: Optional[str] = typer.Option(None, help=help_strings.CRN_URL), + skip_volume: bool = typer.Option(False, help=help_strings.SKIP_VOLUME), + persistent_volume: Optional[list[str]] = typer.Option(None, help=help_strings.PERSISTENT_VOLUME), + ephemeral_volume: Optional[list[str]] = typer.Option(None, help=help_strings.EPHEMERAL_VOLUME), + immutable_volume: Optional[list[str]] = typer.Option( + None, + help=help_strings.IMMUTABLE_VOLUME, + ), + crn_auto_tac: bool = typer.Option(False, help=help_strings.CRN_AUTO_TAC), + channel: Optional[str] = typer.Option(default=settings.DEFAULT_CHANNEL, help=help_strings.CHANNEL), + private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), + private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), + print_message: bool = typer.Option(False), + verbose: bool = typer.Option(True), + debug: bool = False, +): + """Create and register a new GPU instance on aleph.im""" + + await create( + payment_type=PaymentType.superfluid, + payment_chain=payment_chain, + hypervisor=HypervisorType.qemu, + name=name, + rootfs=rootfs, + compute_units=compute_units, + vcpus=vcpus, + memory=memory, + rootfs_size=rootfs_size, + timeout_seconds=timeout_seconds, + ssh_pubkey_file=ssh_pubkey_file, + address=address, + crn_hash=crn_hash, + crn_url=crn_url, + crn_auto_tac=crn_auto_tac, + confidential=False, + confidential_firmware=None, + gpu=True, + premium=premium, + skip_volume=skip_volume, + persistent_volume=persistent_volume, + ephemeral_volume=ephemeral_volume, + immutable_volume=immutable_volume, + channel=channel, + private_key=private_key, + private_key_file=private_key_file, + print_message=print_message, + verbose=verbose, + debug=debug, + ) diff --git a/src/aleph_client/commands/instance/display.py b/src/aleph_client/commands/instance/display.py index 63d13566..2d6e9dcb 100644 --- a/src/aleph_client/commands/instance/display.py +++ b/src/aleph_client/commands/instance/display.py @@ -11,10 +11,12 @@ from textual.widgets import DataTable, Footer, Label, ProgressBar from textual.widgets._data_table import RowKey -from aleph_client.commands.instance.network import fetch_crn_info -from aleph_client.commands.node import NodeInfo, _fetch_nodes, _format_score +from aleph_client.commands.instance.network import ( + fetch_crn_list, + fetch_latest_crn_version, +) +from aleph_client.commands.node import _format_score from aleph_client.models import CRNInfo -from aleph_client.utils import extract_valid_eth_address logger = logging.getLogger(__name__) @@ -23,6 +25,7 @@ class CRNTable(App[CRNInfo]): table: DataTable tasks: set[asyncio.Task] = set() crns: dict[RowKey, CRNInfo] = {} + current_crn_version: str total_crns: int active_crns: int = 0 filtered_crns: int = 0 @@ -32,7 +35,12 @@ class CRNTable(App[CRNInfo]): only_qemu: bool = False only_confidentials: bool = False only_gpu: bool = False + only_gpu_model: Optional[str] = None current_sorts: set = set() + loader_label_start: Label + loader_label_end: Label + progress_bar: ProgressBar + BINDINGS = [ ("s", "sort_by_score", "Sort By Score"), ("n", "sort_by_name", "Sort By Name"), @@ -47,16 +55,20 @@ class CRNTable(App[CRNInfo]): def __init__( self, + only_latest_crn_version: bool = False, only_reward_address: bool = False, only_qemu: bool = False, only_confidentials: bool = False, only_gpu: bool = False, + only_gpu_model: Optional[str] = None, ): super().__init__() + self.only_latest_crn_version = only_latest_crn_version self.only_reward_address = only_reward_address self.only_qemu = only_qemu self.only_confidentials = only_confidentials self.only_gpu = only_gpu + self.only_gpu_model = only_gpu_model def compose(self): """Create child widgets for the app.""" @@ -67,7 +79,7 @@ def compose(self): self.table.add_column("Reward Address", key="stream_reward_address") self.table.add_column("šŸ”’", key="confidential_computing") self.table.add_column("GPU", key="gpu_support") - ## self.table.add_column("Qemu", key="qemu_support") ## Qemu computing enabled by default on nodes + ## self.table.add_column("Qemu", key="qemu_support") ## Qemu computing enabled by default on CRNs self.table.add_column("Cores", key="cpu") self.table.add_column("Free RAM šŸŒ”", key="ram") self.table.add_column("Free Disk šŸ’æ", key="hdd") @@ -91,97 +103,80 @@ async def on_mount(self): task.add_done_callback(self.tasks.discard) async def fetch_node_list(self): - nodes: NodeInfo = await _fetch_nodes() - for node in nodes.nodes: - self.crns[RowKey(node["hash"])] = CRNInfo( - hash=node["hash"], - name=node["name"], - url=node["address"].rstrip("/"), - version=None, - score=node["score"], - stream_reward_address=node["stream_reward"], - machine_usage=None, - qemu_support=None, - confidential_computing=None, - gpu_support=None, - terms_and_conditions=node["terms_and_conditions"], - ) + self.crns: dict[RowKey, CRNInfo] = {RowKey(crn.hash): crn for crn in await fetch_crn_list()} + self.current_crn_version = await fetch_latest_crn_version() # Initialize the progress bar self.total_crns = len(self.crns) self.progress_bar.total = self.total_crns - self.loader_label_start.update(f"Fetching data of {self.total_crns} nodes ") + self.loader_label_start.update(f"Fetching data of {self.total_crns} CRNs ") self.tasks = set() # Fetch all CRNs - for node in list(self.crns.values()): - task = asyncio.create_task(self.fetch_node_info(node)) + for crn in list(self.crns.values()): + task = asyncio.create_task(self.add_crn_info(crn)) self.tasks.add(task) task.add_done_callback(self.make_progress) task.add_done_callback(self.tasks.discard) - async def fetch_node_info(self, node: CRNInfo): - try: - crn_info = await fetch_crn_info(node.url) - except Exception as e: - logger.debug(e) + async def add_crn_info(self, crn: CRNInfo): + self.active_crns += 1 + # Skip CRNs with legacy version + if self.only_latest_crn_version and crn.version < self.current_crn_version: + logger.debug(f"Skipping CRN {crn.hash}, legacy version") + return + # Skip CRNs without machine usage + if not crn.machine_usage: + logger.debug(f"Skipping CRN {crn.hash}, no machine usage") + return + # Skip CRNs without ipv6 connectivity + if not crn.ipv6: + logger.debug(f"Skipping CRN {crn.hash}, no ipv6 connectivity") + return + # Skip CRNs without reward address if only_reward_address is set + if self.only_reward_address and not crn.stream_reward_address: + logger.debug(f"Skipping CRN {crn.hash}, no reward address") return - if crn_info: - node.version = crn_info.get("version", "") - node.stream_reward_address = extract_valid_eth_address( - crn_info.get("payment", {}).get("PAYMENT_RECEIVER_ADDRESS") or node.stream_reward_address or "" - ) - node.qemu_support = crn_info.get("computing", {}).get("ENABLE_QEMU_SUPPORT", False) - node.confidential_computing = crn_info.get("computing", {}).get("ENABLE_CONFIDENTIAL_COMPUTING", False) - node.gpu_support = crn_info.get("computing", {}).get("ENABLE_GPU_SUPPORT", False) - node.machine_usage = crn_info.get("machine_usage") - - # Skip nodes without machine usage - if not node.machine_usage: - logger.debug(f"Skipping node {node.hash}, no machine usage") - return - - self.active_crns += 1 - # Skip nodes without reward address if only_reward_address is set - if self.only_reward_address and not node.stream_reward_address: - logger.debug(f"Skipping node {node.hash}, no reward address") - return - # Skip non-qemu nodes if only_qemu is set - if self.only_qemu and not node.qemu_support: - logger.debug(f"Skipping node {node.hash}, no qemu support") - return - # Skip non-confidential nodes if only_confidentials is set - if self.only_confidentials and not node.confidential_computing: - logger.debug(f"Skipping node {node.hash}, no confidential support") - return - # Skip non-gpu nodes if only-gpu is set - if ( - self.only_gpu - and not node.gpu_support - and not (node.machine_usage.gpu and len(node.machine_usage.gpu.available_devices) < 1) - ): - logger.debug(f"Skipping node {node.hash}, no GPU support or without GPU available") - return - self.filtered_crns += 1 - - # Fetch terms and conditions - tac = await node.terms_and_conditions_content - - self.table.add_row( - _format_score(node.score), - node.name, - node.version, - node.stream_reward_address, - "āœ…" if node.confidential_computing else "āœ–", - # "āœ…" if node.qemu_support else "āœ–", ## Qemu computing enabled by default on nodes - "āœ…" if node.gpu_support else "āœ–", - node.display_cpu, - node.display_ram, - node.display_hdd, - node.url, - tac.url if tac else "āœ–", - key=node.hash, - ) + # Skip non-qemu CRNs if only_qemu is set + if self.only_qemu and not crn.qemu_support: + logger.debug(f"Skipping CRN {crn.hash}, no qemu support") + return + # Skip non-confidential CRNs if only_confidentials is set + if self.only_confidentials and not crn.confidential_computing: + logger.debug(f"Skipping CRN {crn.hash}, no confidential support") + return + # Skip non-gpu CRNs if only-gpu is set + if self.only_gpu and not (crn.gpu_support and crn.compatible_available_gpus): + logger.debug(f"Skipping CRN {crn.hash}, no GPU support or without GPU available") + return + # Skip CRNs without compatible GPU if only-gpu-model is set + elif ( + self.only_gpu + and self.only_gpu_model + and self.only_gpu_model not in [gpu["model"] for gpu in crn.compatible_available_gpus] + ): + logger.debug(f"Skipping CRN {crn.hash}, no {self.only_gpu_model} GPU support") + return + self.filtered_crns += 1 + + # Fetch terms and conditions + tac = await crn.terms_and_conditions_content + + self.table.add_row( + _format_score(crn.score), + crn.name, + crn.version, + crn.stream_reward_address, + "āœ…" if crn.confidential_computing else "āœ–", + # "āœ…" if crn.qemu_support else "āœ–", ## Qemu computing enabled by default on crns + "āœ…" if crn.gpu_support else "āœ–", + crn.display_cpu, + crn.display_ram, + crn.display_hdd, + crn.url, + tac.url if tac else "āœ–", + key=crn.hash, + ) def make_progress(self, task): """Called automatically to advance the progress bar.""" @@ -191,7 +186,7 @@ def make_progress(self, task): except NoMatches: pass if len(self.tasks) == 0: - self.loader_label_start.update(f"Fetched {self.total_crns} nodes ") + self.loader_label_start.update(f"Fetched {self.total_crns} CRNs ") def on_data_table_row_selected(self, message: DataTable.RowSelected): """Return the selected row""" diff --git a/src/aleph_client/commands/instance/network.py b/src/aleph_client/commands/instance/network.py index a4f75f28..bc57580b 100644 --- a/src/aleph_client/commands/instance/network.py +++ b/src/aleph_client/commands/instance/network.py @@ -5,7 +5,13 @@ from json import JSONDecodeError from typing import Optional -import aiohttp +from aiohttp import ( + ClientConnectorError, + ClientResponseError, + ClientSession, + ClientTimeout, + InvalidURL, +) from aleph.sdk import AlephHttpClient from aleph.sdk.conf import settings from aleph.sdk.exceptions import ForgottenMessageError, MessageNotFoundError @@ -19,70 +25,155 @@ from aleph_client.commands import help_strings from aleph_client.commands.files import download -from aleph_client.commands.node import NodeInfo, _fetch_nodes -from aleph_client.models import MachineUsage -from aleph_client.utils import fetch_json, sanitize_url +from aleph_client.models import CRNInfo +from aleph_client.utils import ( + async_lru_cache, + extract_valid_eth_address, + fetch_json, + sanitize_url, +) logger = logging.getLogger(__name__) +latest_crn_version_link = "https://api.github.com/repos/aleph-im/aleph-vm/releases/latest" -PATH_STATUS_CONFIG = "/status/config" -PATH_ABOUT_USAGE_SYSTEM = "/about/usage/system" +settings_link = ( + f"{sanitize_url(settings.API_HOST)}" + "/api/v0/aggregates/0xFba561a84A537fCaa567bb7A2257e7142701ae2A.json?keys=settings" +) +crn_list_link = ( + f"{sanitize_url(settings.CRN_URL_FOR_PROGRAMS)}" + "/vm/bec08b08bb9f9685880f3aeb9c1533951ad56abef2a39c97f5a93683bdaa5e30/crns.json" +) -async def fetch_crn_info(node_url: str) -> dict | None: - """ - Fetches compute node usage information and version. +PATH_ABOUT_EXECUTIONS_LIST = "/about/executions/list" + + +@async_lru_cache +async def call_program_crn_list() -> Optional[dict]: + """Call program to fetch the compute resource node list. - Args: - node_url: URL of the compute node. Returns: - CRN information. + dict: Dictionary containing the compute resource node list. """ - url = "" + try: - base_url: str = sanitize_url(node_url) - timeout = aiohttp.ClientTimeout(total=settings.HTTP_REQUEST_TIMEOUT) - async with aiohttp.ClientSession(timeout=timeout) as session: - info: dict - url = base_url + PATH_STATUS_CONFIG - async with session.get(url) as resp: - resp.raise_for_status() - info = await resp.json() - url = base_url + PATH_ABOUT_USAGE_SYSTEM - async with session.get(url) as resp: - resp.raise_for_status() - system: dict = await resp.json() - info["machine_usage"] = MachineUsage.parse_obj(system) - return info - except aiohttp.InvalidURL as e: - logger.debug(f"Invalid CRN URL: {url}: {e}") + async with ClientSession(timeout=ClientTimeout(total=60)) as session: + logger.debug("Fetching crn list...") + async with session.get(crn_list_link) as resp: + if resp.status != 200: + error = "Unable to fetch crn list from program" + raise Exception(error) + return await resp.json() + except InvalidURL as e: + error = f"Invalid URL: {crn_list_link}: {e}" except TimeoutError as e: - logger.debug(f"Timeout while fetching CRN: {url}: {e}") - except aiohttp.ClientConnectionError as e: - logger.debug(f"Error on CRN connection: {url}: {e}") - except aiohttp.ClientResponseError as e: - logger.debug(f"Error on CRN response: {url}: {e}") + error = f"Timeout while fetching: {crn_list_link}: {e}" + except ClientConnectorError as e: + error = f"Error on connection: {crn_list_link}: {e}" + except ClientResponseError as e: + error = f"Error on response: {crn_list_link}: {e}" except JSONDecodeError as e: - logger.debug(f"Error decoding CRN JSON: {url}: {e}") - except ValidationError as e: - logger.debug(f"Validation error when fetching CRN: {url}: {e}") + error = f"Error when decoding JSON: {crn_list_link}: {e}" except Exception as e: - logger.debug(f"Unexpected error when fetching CRN: {url}: {e}") - return None + error = f"Unexpected error while fetching: {crn_list_link}: {e}" + raise Exception(error) -async def fetch_vm_info(message: InstanceMessage, node_list: NodeInfo) -> tuple[str, dict[str, str]]: +@async_lru_cache +async def fetch_latest_crn_version() -> str: + """Fetch the latest crn version. + + Returns: + str: Latest crn version as x.x.x. """ - Fetches VM information given an instance message and the node list. + + async with ClientSession() as session: + try: + data = await fetch_json(session, latest_crn_version_link) + version = data.get("tag_name") + if not version: + msg = "No tag_name found in GitHub release data" + raise ValueError(msg) + return version + except Exception as e: + logger.error(f"Error while fetching latest crn version: {e}") + raise Exit(code=1) from e + + +@async_lru_cache +async def fetch_crn_list( + latest_crn_version: bool = False, + ipv6: bool = False, + stream_address: bool = False, + confidential: bool = False, + gpu: bool = False, +) -> list[CRNInfo]: + """Fetch compute resource node list, unfiltered by default. + + Args: + latest_crn_version (bool): Filter by latest crn version. + ipv6 (bool): Filter invalid IPv6 configuration. + stream_address (bool): Filter invalid payment receiver address. + confidential (bool): Filter by confidential computing support. + gpu (bool): Filter by GPU support. + Returns: + list[CRNInfo]: List of compute resource nodes. + """ + + data = await call_program_crn_list() + current_crn_version = await fetch_latest_crn_version() + crns = [] + for crn in data.get("crns"): + if latest_crn_version and (crn.get("version") or "0.0.0") < current_crn_version: + continue + if ipv6: + ipv6_check = crn.get("ipv6_check") + if not ipv6_check or not all(ipv6_check.values()): + continue + if stream_address and not extract_valid_eth_address(crn.get("payment_receiver_address") or ""): + continue + if confidential and not crn.get("confidential_support"): + continue + if gpu and not (crn.get("gpu_support") and crn.get("compatible_available_gpus")): + continue + try: + crns.append(CRNInfo.from_unsanitized_input(crn)) + except ValidationError: + logger.debug(f"Invalid CRN: {crn}") + continue + return crns + + +async def fetch_crn_info(crn_url: Optional[str] = None, crn_hash: Optional[str] = None) -> Optional[CRNInfo]: + """Retrieve a compute resource node by URL. + + Args: + crn_url (Optional[str]): URL of the compute resource node. + crn_hash (Optional[str]): Hash of the compute resource node. + Returns: + Union[CRNInfo, None]: The compute resource node or None if not found. + """ + + crn_url = sanitize_url(crn_url) + crn_list = await fetch_crn_list() + for crn in crn_list: + if crn.url == crn_url or crn.hash == crn_hash: + return crn + return None + + +async def fetch_vm_info(message: InstanceMessage) -> tuple[str, dict[str, str]]: + """Fetches VM information given an instance message. Args: message: Instance message. - node_list: Node list. Returns: VM information. """ - async with aiohttp.ClientSession() as session: + + async with ClientSession() as session: chain = safe_getattr(message, "content.payment.chain.value") hold = safe_getattr(message, "content.payment.type.value") crn_hash = safe_getattr(message, "content.requirements.node.node_hash") @@ -119,21 +210,22 @@ async def fetch_vm_info(message: InstanceMessage, node_list: NodeInfo) -> tuple[ info["ipv6_logs"] = allocation["vm_ipv6"] for node in nodes["nodes"]: if node["ipv6"].split("::")[0] == ":".join(str(info["ipv6_logs"]).split(":")[:4]): - info["crn_url"] = node["url"].rstrip("/") + info["crn_url"] = sanitize_url(node["url"]) break - except (aiohttp.ClientResponseError, aiohttp.ClientConnectorError) as e: + except (ClientResponseError, ClientConnectorError) as e: info["crn_url"] = help_strings.CRN_PENDING info["ipv6_logs"] = help_strings.VM_SCHEDULED logger.debug(f"Error while calling Scheduler API ({url}): {e}") else: - # Fetch from the CRN API if PAYG-tier or confidential or GPU + # Fetch from the CRN program endpoint if PAYG-tier or confidential or GPU info["allocation_type"] = help_strings.ALLOCATION_MANUAL - for node in node_list.nodes: - if node["hash"] == crn_hash: - info["crn_url"] = node["address"].rstrip("/") + node_list = await fetch_crn_list() + for node in node_list: + if node.hash == crn_hash: + info["crn_url"] = node.url break if info["crn_url"]: - path = f"{info['crn_url']}/about/executions/list" + path = f"{info['crn_url']}{PATH_ABOUT_EXECUTIONS_LIST}" executions = await fetch_json(session, path) if message.item_hash in executions: interface = IPv6Interface(executions[message.item_hash]["networking"]["ipv6"]) @@ -147,12 +239,20 @@ async def fetch_vm_info(message: InstanceMessage, node_list: NodeInfo) -> tuple[ tac = await download(tac_hash, only_info=True, verbose=False) tac_url = safe_getattr(tac, "url") or f"missing ā†’ {tac_hash}" info.update({"tac_url": tac_url, "tac_accepted": "Yes"}) - except (aiohttp.ClientResponseError, aiohttp.ClientConnectorError) as e: + except (ClientResponseError, ClientConnectorError) as e: info["ipv6_logs"] = f"Not available. Server error: {e}" return message.item_hash, info async def find_crn_of_vm(vm_id: str) -> Optional[str]: + """Finds the CRN where the VM is running given its item hash. + + Args: + vm_id (str): Item hash of the VM. + Returns: + str: CRN url or None if not found. + """ + async with AlephHttpClient(api_server=settings.API_HOST) as client: message: Optional[InstanceMessage] = None try: @@ -163,7 +263,23 @@ async def find_crn_of_vm(vm_id: str) -> Optional[str]: echo("Instance has been deleted on aleph.im") if not message: raise Exit(code=1) - node_list: NodeInfo = await _fetch_nodes() - _, info = await fetch_vm_info(message, node_list) + _, info = await fetch_vm_info(message) is_valid = info["crn_url"] not in [help_strings.CRN_PENDING, help_strings.CRN_UNKNOWN] return str(info["crn_url"]) if is_valid else None + + +@async_lru_cache +async def fetch_settings() -> dict: + """Fetch the settings from aggregate for flows and gpu instances. + + Returns: + dict: Dictionary containing the settings. + """ + + async with ClientSession() as session: + try: + data = await fetch_json(session, settings_link) + return data.get("data", {}).get("settings") + except Exception as e: + logger.error(f"Error while fetching settings: {e}") + raise Exit(code=1) from e diff --git a/src/aleph_client/commands/instance/superfluid.py b/src/aleph_client/commands/instance/superfluid.py deleted file mode 100644 index 09a67d81..00000000 --- a/src/aleph_client/commands/instance/superfluid.py +++ /dev/null @@ -1,68 +0,0 @@ -import logging -from decimal import Decimal -from enum import Enum - -from aleph.sdk.chains.ethereum import ETHAccount -from aleph.sdk.conf import settings -from click import echo -from eth_utils.currency import to_wei -from superfluid import Web3FlowInfo - -logger = logging.getLogger(__name__) - - -def from_wei(wei_value: Decimal) -> Decimal: - """Converts the given wei value to ether.""" - return wei_value / Decimal(10**settings.TOKEN_DECIMALS) - - -class FlowUpdate(str, Enum): - REDUCE = "reduce" - INCREASE = "increase" - - -async def update_flow(account: ETHAccount, receiver: str, flow: Decimal, update_type: FlowUpdate): - """ - Update the flow of a Superfluid stream between a sender and receiver. - This function either increases or decreases the flow rate between the sender and receiver, - based on the update_type. If no flow exists and the update type is augmentation, it creates a new flow - with the specified rate. If the update type is reduction and the reduction amount brings the flow to zero - or below, the flow is deleted. - - :param account: The SuperFluid account instance used to interact with the blockchain. - :param chain: The blockchain chain to interact with. - :param receiver: Address of the receiver in hexadecimal format. - :param flow: The flow rate to be added or removed (in ether). - :param update_type: The type of update to perform (augmentation or reduction). - :return: The transaction hash of the executed operation (create, update, or delete flow). - """ - - # Retrieve current flow info - flow_info: Web3FlowInfo = await account.get_flow(receiver) - - current_flow_rate_wei: Decimal = Decimal(flow_info["flowRate"] or "0") - flow_rate_wei: int = to_wei(flow, "ether") - - if update_type == FlowUpdate.INCREASE: - if current_flow_rate_wei > 0: - # Update existing flow by augmenting the rate - new_flow_rate_wei = current_flow_rate_wei + flow_rate_wei - new_flow_rate_ether = from_wei(new_flow_rate_wei) - return await account.update_flow(receiver, new_flow_rate_ether) - else: - # Create a new flow if none exists - return await account.create_flow(receiver, flow) - elif update_type == FlowUpdate.REDUCE: - if current_flow_rate_wei > 0: - # Reduce the existing flow - new_flow_rate_wei = current_flow_rate_wei - flow_rate_wei - # Ensure to not leave infinitesimal flows - # Often, there were 1-10 wei remaining in the flow rate, which prevented the flow from being deleted - if new_flow_rate_wei > 99: - new_flow_rate_ether = from_wei(new_flow_rate_wei) - return await account.update_flow(receiver, new_flow_rate_ether) - else: - # Delete the flow if the new flow rate is zero or negative - return await account.delete_flow(receiver) - else: - echo("No existing flow to stop. Skipping...") diff --git a/src/aleph_client/commands/pricing.py b/src/aleph_client/commands/pricing.py new file mode 100644 index 00000000..c72f3a85 --- /dev/null +++ b/src/aleph_client/commands/pricing.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import logging +from decimal import Decimal +from enum import Enum +from typing import Annotated, Optional + +import aiohttp +import typer +from aleph.sdk.conf import settings +from aleph.sdk.utils import displayable_amount, safe_getattr +from pydantic import BaseModel +from rich import box +from rich.console import Console, Group +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from aleph_client.commands.utils import setup_logging, validated_prompt +from aleph_client.utils import async_lru_cache, sanitize_url + +logger = logging.getLogger(__name__) + +pricing_link = ( + f"{sanitize_url(settings.API_HOST)}/api/v0/aggregates/0xFba561a84A537fCaa567bb7A2257e7142701ae2A.json?keys=pricing" +) + + +class PricingEntity(str, Enum): + STORAGE = "storage" + WEB3_HOSTING = "web3_hosting" + PROGRAM = "program" + PROGRAM_PERSISTENT = "program_persistent" + INSTANCE = "instance" + INSTANCE_CONFIDENTIAL = "instance_confidential" + INSTANCE_GPU_STANDARD = "instance_gpu_standard" + INSTANCE_GPU_PREMIUM = "instance_gpu_premium" + + +class GroupEntity(str, Enum): + STORAGE = "storage" + WEBSITE = "website" + PROGRAM = "program" + INSTANCE = "instance" + CONFIDENTIAL = "confidential" + GPU = "gpu" + ALL = "all" + + +PRICING_GROUPS: dict[str, list[PricingEntity]] = { + GroupEntity.STORAGE: [PricingEntity.STORAGE], + GroupEntity.WEBSITE: [PricingEntity.WEB3_HOSTING], + GroupEntity.PROGRAM: [PricingEntity.PROGRAM, PricingEntity.PROGRAM_PERSISTENT], + GroupEntity.INSTANCE: [PricingEntity.INSTANCE], + GroupEntity.CONFIDENTIAL: [PricingEntity.INSTANCE_CONFIDENTIAL], + GroupEntity.GPU: [PricingEntity.INSTANCE_GPU_STANDARD, PricingEntity.INSTANCE_GPU_PREMIUM], + GroupEntity.ALL: list(PricingEntity), +} + +PAYG_GROUP: list[PricingEntity] = [ + PricingEntity.INSTANCE, + PricingEntity.INSTANCE_CONFIDENTIAL, + PricingEntity.INSTANCE_GPU_STANDARD, + PricingEntity.INSTANCE_GPU_PREMIUM, +] + +MAX_VALUE = Decimal(999_999_999) + + +class SelectedTierPrice(BaseModel): + hold: Decimal + payg: Decimal # Token by second + storage: Optional[SelectedTierPrice] + + +class SelectedTier(BaseModel): + tier: int + compute_units: int + vcpus: int + memory: int + disk: int + gpu_model: Optional[str] + price: SelectedTierPrice + + +class Pricing: + def __init__(self, **kwargs): + self.data = kwargs.get("data", {}).get("pricing", {}) + + def display_table_for( + self, + pricing_entity: Optional[PricingEntity] = None, + compute_units: int = 0, + vcpus: int = 0, + memory: int = 0, + disk: int = 0, + gpu_models: Optional[dict[str, dict[str, dict[str, int]]]] = None, + persistent: Optional[bool] = None, + selector: bool = False, + exit_on_error: bool = True, + verbose: bool = True, + ) -> Optional[SelectedTier]: + """Display pricing table for an entity""" + + if not pricing_entity: + if persistent is not None: + # Program entity selection: Persistent or Non-Persistent + pricing_entity = PricingEntity.PROGRAM_PERSISTENT if persistent else PricingEntity.PROGRAM + + entity_name = safe_getattr(pricing_entity, "value") + if pricing_entity: + entity = self.data.get(entity_name) + label = entity_name.replace("_", " ").title() + else: + logger.error(f"Entity {entity_name} not found") + if exit_on_error: + raise typer.Exit(1) + else: + return None + + unit = entity.get("compute_unit", {}) + unit_vcpus = unit.get("vcpus") + unit_memory = unit.get("memory_mib") + unit_disk = unit.get("disk_mib") + price = entity.get("price", {}) + price_unit = price.get("compute_unit") + price_storage = price.get("storage") + price_fixed = price.get("fixed") + tiers = entity.get("tiers", []) + + displayable_group = None + tier_data: dict[int, SelectedTier] = {} + auto_selected = (compute_units or vcpus or memory or disk) and not gpu_models + if tiers: + if auto_selected: + tiers = [ + tier + for tier in tiers + if compute_units <= tier["compute_units"] + and vcpus <= unit_vcpus * tier["compute_units"] + and memory <= unit_memory * tier["compute_units"] + and disk <= unit_disk * tier["compute_units"] + ] + if tiers: + tiers = tiers[:1] + else: + requirements = [] + if compute_units: + requirements.append(f"compute_units>={compute_units}") + if vcpus: + requirements.append(f"vcpus>={vcpus}") + if memory: + requirements.append(f"memory>={memory}") + if disk: + requirements.append(f"disk>={disk}") + typer.echo( + f"Minimum tier with required {' & '.join(requirements)}" + f" not found for {pricing_entity.value}" + ) + if exit_on_error: + raise typer.Exit(1) + else: + return None + + table = Table( + border_style="magenta", + box=box.MINIMAL, + ) + table.add_column("Tier", style="cyan") + table.add_column("Compute Units", style="orchid") + table.add_column("vCPUs", style="bright_cyan") + table.add_column("RAM (GiB)", style="bright_cyan") + table.add_column("Disk (GiB)", style="bright_cyan") + if "model" in tiers[0]: + table.add_column("GPU Model", style="orange1") + if "vram" in tiers[0]: + table.add_column("VRAM (GiB)", style="orange1") + if "holding" in price_unit: + table.add_column("$ALEPH (Holding)", style="red", justify="center") + if "payg" in price_unit and pricing_entity in PAYG_GROUP: + table.add_column("$ALEPH (Pay-As-You-Go)", style="green", justify="center") + if pricing_entity in PRICING_GROUPS[GroupEntity.PROGRAM]: + table.add_column("+ Internet Access", style="orange1", justify="center") + + for tier in tiers: + tier_id = tier["id"].split("-", 1)[1] + current_units = tier["compute_units"] + table.add_section() + row = [ + tier_id, + str(current_units), + str(unit_vcpus * current_units), + f"{unit_memory * current_units / 1024:.0f}", + f"{unit_disk * current_units / 1024:.0f}", + ] + if "model" in tier: + if gpu_models is None: + row.append(tier["model"]) + elif tier["model"] in gpu_models: + gpu_line = tier["model"] + for device, details in gpu_models[tier["model"]].items(): + gpu_line += f"\n[bright_yellow]ā€¢ {device}[/bright_yellow]\n" + gpu_line += f" [grey50]ā†³ [white]{details['count']}[/white]" + gpu_line += f" available on [white]{details['on_crns']}[/white] CRN(s)[/grey50]" + row.append(Text.from_markup(gpu_line)) + else: + continue + if "vram" in tier: + row.append(f"{tier['vram'] / 1024:.0f}") + if "holding" in price_unit: + row.append( + f"{displayable_amount(Decimal(price_unit['holding']) * current_units, decimals=3)} tokens" + ) + if "payg" in price_unit and pricing_entity in PAYG_GROUP: + payg_hourly = Decimal(price_unit["payg"]) * current_units + row.append( + f"{displayable_amount(payg_hourly, decimals=3)} token/hour" + f"\n{displayable_amount(payg_hourly*24, decimals=3)} token/day" + ) + if pricing_entity in PRICING_GROUPS[GroupEntity.PROGRAM]: + internet_cell = ( + "āœ… Included" + if pricing_entity == PricingEntity.PROGRAM_PERSISTENT + else f"{displayable_amount(Decimal(price_unit['holding']) * current_units * 2)} tokens" + ) + row.append(internet_cell) + table.add_row(*row) + + tier_data[tier_id] = SelectedTier( + tier=tier_id, + compute_units=current_units, + vcpus=unit_vcpus * current_units, + memory=unit_memory * current_units, + disk=unit_disk * current_units, + gpu_model=tier.get("model"), + price=SelectedTierPrice( + hold=Decimal(price_unit["holding"]) * current_units if "holding" in price_unit else MAX_VALUE, + payg=Decimal(price_unit["payg"]) / 3600 * current_units if "payg" in price_unit else MAX_VALUE, + storage=SelectedTierPrice( + hold=Decimal(price_storage["holding"]) if "holding" in price_storage else MAX_VALUE, + payg=Decimal(price_storage["payg"]) / 3600 if "payg" in price_storage else MAX_VALUE, + storage=None, + ), + ), + ) + + extra_price_holding = ( + f"[red]{displayable_amount(Decimal(price_storage['holding'])*1024, decimals=5)}" + " token/GiB[/red] (Holding) -or- " + if "holding" in price_storage + else "" + ) + infos = [ + Text.from_markup( + f"Extra Volume Cost: {extra_price_holding}" + f"[green]{displayable_amount(Decimal(price_storage['payg'])*1024*24, decimals=5)}" + " token/GiB/day[/green] (Pay-As-You-Go)" + ) + ] + displayable_group = Group( + table, + Text.assemble(*infos), + ) + else: + infos = [Text("\n")] + if price_fixed: + infos.append( + Text.from_markup( + f"Service & Availability (Holding): [orange1]{displayable_amount(price_fixed, decimals=3)}" + " tokens[/orange1]\n\n+ " + ) + ) + infos.append( + Text.from_markup( + "$ALEPH (Holding): [bright_cyan]" + f"{displayable_amount(Decimal(price_storage['holding']), decimals=5)}" + " token/Mib[/bright_cyan] -or- [bright_cyan]" + f"{displayable_amount(Decimal(price_storage['holding'])*1024, decimals=5)}" + " token/GiB[/bright_cyan]" + ) + ) + displayable_group = Group( + Text.assemble(*infos), + ) + + if gpu_models and not tier_data: + typer.echo(f"No GPU available for {label} at the moment.") + raise typer.Exit(1) + elif verbose: + console = Console() + console.print( + Panel( + displayable_group, + title=f"Pricing: {'Selected ' if compute_units else ''}{label}", + border_style="orchid", + expand=False, + title_align="left", + ) + ) + + if selector and pricing_entity not in [PricingEntity.STORAGE, PricingEntity.WEB3_HOSTING]: + if not auto_selected: + tier_id = validated_prompt("Select a tier by index", lambda tier_id: tier_id in tier_data) + return next(iter(tier_data.values())) if auto_selected else tier_data[tier_id] + + return None + + +@async_lru_cache +async def fetch_pricing() -> Pricing: + """Fetch pricing aggregate and format it as Pricing""" + + async with aiohttp.ClientSession() as session: + async with session.get(pricing_link) as resp: + if resp.status != 200: + logger.error("Unable to fetch pricing aggregate") + raise typer.Exit(1) + + data = await resp.json() + return Pricing(**data) + + +async def prices_for_service( + service: Annotated[GroupEntity, typer.Argument(help="Service to display pricing for")], + compute_units: Annotated[int, typer.Option(help="Compute units to display pricing for")] = 0, + debug: bool = False, +): + """Display pricing for services available on aleph.im & twentysix.cloud""" + + setup_logging(debug) + + group = PRICING_GROUPS[service] + pricing = await fetch_pricing() + for entity in group: + pricing.display_table_for(entity, compute_units=compute_units, exit_on_error=False) diff --git a/src/aleph_client/commands/program.py b/src/aleph_client/commands/program.py index cd8af38d..3942931e 100644 --- a/src/aleph_client/commands/program.py +++ b/src/aleph_client/commands/program.py @@ -5,8 +5,9 @@ import re from base64 import b16decode, b32encode from collections.abc import Mapping +from decimal import Decimal from pathlib import Path -from typing import Optional, cast +from typing import Any, Optional, cast from zipfile import BadZipFile import aiohttp @@ -15,10 +16,15 @@ from aleph.sdk.account import _load_account from aleph.sdk.client.vm_client import VmClient from aleph.sdk.conf import settings -from aleph.sdk.exceptions import ForgottenMessageError, MessageNotFoundError +from aleph.sdk.exceptions import ( + ForgottenMessageError, + InsufficientFundsError, + MessageNotFoundError, +) from aleph.sdk.query.filters import MessageFilter -from aleph.sdk.types import AccountFromPrivateKey, StorageEnum -from aleph.sdk.utils import safe_getattr +from aleph.sdk.query.responses import PriceResponse +from aleph.sdk.types import AccountFromPrivateKey, StorageEnum, TokenType +from aleph.sdk.utils import make_program_content, safe_getattr from aleph_message.models import Chain, MessageType, ProgramMessage, StoreMessage from aleph_message.models.execution.program import ProgramContent from aleph_message.models.item_hash import ItemHash @@ -32,6 +38,8 @@ from rich.text import Text from aleph_client.commands import help_strings +from aleph_client.commands.account import get_balance +from aleph_client.commands.pricing import PricingEntity, SelectedTier, fetch_pricing from aleph_client.commands.utils import ( filter_only_valid_messages, get_or_prompt_environment_variables, @@ -56,24 +64,28 @@ async def upload( ..., help=help_strings.PROGRAM_ENTRYPOINT, ), - channel: Optional[str] = typer.Option(default=settings.DEFAULT_CHANNEL, help=help_strings.CHANNEL), - memory: int = typer.Option(settings.DEFAULT_VM_MEMORY, help=help_strings.MEMORY), - vcpus: int = typer.Option(settings.DEFAULT_VM_VCPUS, help=help_strings.VCPUS), - timeout_seconds: float = typer.Option( - settings.DEFAULT_VM_TIMEOUT, - help=help_strings.TIMEOUT_SECONDS, - ), name: Optional[str] = typer.Option(None, help="Name for your program"), runtime: str = typer.Option( None, help=help_strings.PROGRAM_RUNTIME.format(runtime_id=settings.DEFAULT_RUNTIME_ID), ), + compute_units: Optional[int] = typer.Option(None, help=help_strings.COMPUTE_UNITS), + vcpus: Optional[int] = typer.Option(None, help=help_strings.VCPUS), + memory: Optional[int] = typer.Option(None, help=help_strings.MEMORY), + timeout_seconds: float = typer.Option( + settings.DEFAULT_VM_TIMEOUT, + help=help_strings.TIMEOUT_SECONDS, + ), + internet: bool = typer.Option( + False, + help=help_strings.PROGRAM_INTERNET, + ), + updatable: bool = typer.Option(False, help=help_strings.PROGRAM_UPDATABLE), beta: bool = typer.Option( False, help=help_strings.PROGRAM_BETA, ), - persistent: bool = False, - updatable: bool = typer.Option(False, help=help_strings.PROGRAM_UPDATABLE), + persistent: bool = typer.Option(False, help=help_strings.PROGRAM_PERSISTENT), skip_volume: bool = typer.Option(False, help=help_strings.SKIP_VOLUME), persistent_volume: Optional[list[str]] = typer.Option(None, help=help_strings.PERSISTENT_VOLUME), ephemeral_volume: Optional[list[str]] = typer.Option(None, help=help_strings.EPHEMERAL_VOLUME), @@ -83,6 +95,8 @@ async def upload( ), skip_env_var: bool = typer.Option(False, help=help_strings.SKIP_ENV_VAR), env_vars: Optional[str] = typer.Option(None, help=help_strings.ENVIRONMENT_VARIABLES), + address: Optional[str] = typer.Option(None, help=help_strings.ADDRESS_PAYER), + channel: Optional[str] = typer.Option(default=settings.DEFAULT_CHANNEL, help=help_strings.CHANNEL), private_key: Optional[str] = typer.Option(settings.PRIVATE_KEY_STRING, help=help_strings.PRIVATE_KEY), private_key_file: Optional[Path] = typer.Option(settings.PRIVATE_KEY_FILE, help=help_strings.PRIVATE_KEY_FILE), print_messages: bool = typer.Option(False), @@ -109,32 +123,7 @@ async def upload( raise typer.Exit(code=4) from error account: AccountFromPrivateKey = _load_account(private_key, private_key_file) - - name = name or validated_prompt("Program name", lambda x: len(x) < 65) - runtime = runtime or input(f"Ref of runtime? [{settings.DEFAULT_RUNTIME_ID}] ") or settings.DEFAULT_RUNTIME_ID - - volumes = [] - if not skip_volume: - volumes = get_or_prompt_volumes( - persistent_volume=persistent_volume, - ephemeral_volume=ephemeral_volume, - immutable_volume=immutable_volume, - ) - - environment_variables = None - if not skip_env_var: - environment_variables = get_or_prompt_environment_variables(env_vars) - - subscriptions: Optional[list[Mapping]] = None - if beta and yes_no_input("Subscribe to messages?", default=False): - content_raw = input_multiline() - try: - subscriptions = json.loads(content_raw) - except json.decoder.JSONDecodeError as error: - typer.echo("Not valid JSON") - raise typer.Exit(code=2) from error - else: - subscriptions = None + address = address or settings.ADDRESS_TO_USE or account.get_address() async with AuthenticatedAlephHttpClient(account=account, api_server=settings.API_HOST) as client: # Upload the source code @@ -153,30 +142,105 @@ async def upload( guess_mime_type=True, ref=None, ) - logger.debug("Upload finished") + logger.debug("Code upload finished") if print_messages or print_code_message: typer.echo(f"{user_code.json(indent=4)}") program_ref = user_code.item_hash - # Register the program - message, status = await client.create_program( - program_ref=program_ref, - entrypoint=entrypoint, - metadata={"name": name}, - allow_amend=updatable, - runtime=runtime, - storage_engine=StorageEnum.storage, - channel=channel, - memory=memory, - vcpus=vcpus, - timeout_seconds=timeout_seconds, - persistent=persistent, - encoding=encoding, - volumes=volumes, - environment_variables=environment_variables, - subscriptions=subscriptions, + pricing = await fetch_pricing() + pricing_entity = PricingEntity.PROGRAM_PERSISTENT if persistent else PricingEntity.PROGRAM + tier = cast( # Safe cast + SelectedTier, + pricing.display_table_for( + pricing_entity, + compute_units=compute_units or 0, + vcpus=vcpus or 0, + memory=memory or 0, + disk=0, + selector=True, + verbose=verbose, + ), ) - logger.debug("Upload finished") + name = name or validated_prompt("Program name", lambda x: x and len(x) < 65) + vcpus = tier.vcpus + memory = tier.memory + runtime = runtime or input(f"Ref of runtime? [{settings.DEFAULT_RUNTIME_ID}] ") or settings.DEFAULT_RUNTIME_ID + + volumes = [] + if not skip_volume: + volumes = get_or_prompt_volumes( + persistent_volume=persistent_volume, + ephemeral_volume=ephemeral_volume, + immutable_volume=immutable_volume, + ) + + environment_variables = None + if not skip_env_var: + environment_variables = get_or_prompt_environment_variables(env_vars) + + subscriptions: Optional[list[Mapping]] = None + if beta and yes_no_input("Subscribe to messages?", default=False): + content_raw = input_multiline() + try: + subscriptions = json.loads(content_raw) + except json.decoder.JSONDecodeError as error: + typer.echo("Not valid JSON") + raise typer.Exit(code=2) from error + else: + subscriptions = None + + content_dict: dict[str, Any] = { + "program_ref": program_ref, + "entrypoint": entrypoint, + "runtime": runtime, + "metadata": {"name": name}, + "address": address, + "vcpus": vcpus, + "memory": memory, + "timeout_seconds": timeout_seconds, + "internet": internet, + "allow_amend": updatable, + "encoding": encoding, + "persistent": persistent, + "volumes": volumes, + "environment_variables": environment_variables, + "subscriptions": subscriptions, + } + + # Estimate cost and check required balances (Aleph ERC20) + required_tokens: Decimal + try: + content = make_program_content(**content_dict) + price: PriceResponse = await client.get_estimated_price(content) + required_tokens = Decimal(price.required_tokens) + except Exception as e: + typer.echo(f"Failed to estimate program cost, error: {e}") + raise typer.Exit(code=1) from e + + available_funds = Decimal((await get_balance(address))["available_amount"]) + try: + if available_funds < required_tokens: + raise InsufficientFundsError(TokenType.ALEPH, float(required_tokens), float(available_funds)) + except InsufficientFundsError as e: + typer.echo(e) + raise typer.Exit(code=1) from e + + # Register the program + try: + message, status = await client.create_program( + **content_dict, + channel=channel, + storage_engine=StorageEnum.storage, + sync=True, + ) + except InsufficientFundsError as e: + typer.echo( + f"Program creation failed due to insufficient funds.\n" + f"{address} has {e.available_funds} ALEPH but needs {e.required_funds} ALEPH." + ) + raise typer.Exit(code=1) from e + + logger.debug("Program upload finished") if print_messages or print_program_message: typer.echo(f"{message.json(indent=4)}") @@ -287,7 +351,7 @@ async def update( guess_mime_type=True, ref=code_message.item_hash, ) - logger.debug("Upload finished") + logger.debug("Code upload finished") if print_message: typer.echo(f"{message.json(indent=4)}") @@ -395,9 +459,8 @@ async def list_programs( setup_logging(debug) - if address is None: - account = _load_account(private_key, private_key_file) - address = account.get_address() + account = _load_account(private_key, private_key_file) + address = address or settings.ADDRESS_TO_USE or account.get_address() async with AlephHttpClient(api_server=settings.API_HOST) as client: resp = await client.get_messages( @@ -459,6 +522,7 @@ async def list_programs( f"RAM: [magenta3]{message.content.resources.memory / 1_024:.2f} GiB[/magenta3]\n", "HyperV: [magenta3]Firecracker[/magenta3]\n", f"Timeout: [orange3]{message.content.resources.seconds}s[/orange3]\n", + f"Internet: {'[green]Yes[/green]' if message.content.environment.internet else '[red]No[/red]'}\n", f"Persistent: {'[green]Yes[/green]' if message.content.on.persistent else '[red]No[/red]'}\n", f"Updatable: {'[green]Yes[/green]' if message.content.allow_amend else '[orange3]Code only[/orange3]'}", ] @@ -494,7 +558,7 @@ async def list_programs( console.print(table) infos = [ Text.from_markup( - f"[bold]Address:[/bold] [bright_cyan]{messages[0].content.address}[/bright_cyan]\n\nTo access any " + f"[bold]Address:[/bold] [bright_cyan]{messages[0].sender}[/bright_cyan]\n\nTo access any " "program's logs, use:\n" ), Text.from_markup( @@ -755,17 +819,20 @@ async def runtime_checker( program_hash = await upload( path=Path(__file__).resolve().parent / "program_utils/runtime_checker.squashfs", entrypoint="main:app", - channel=settings.DEFAULT_CHANNEL, - memory=settings.DEFAULT_VM_MEMORY, - vcpus=settings.DEFAULT_VM_VCPUS, - timeout_seconds=settings.DEFAULT_VM_TIMEOUT, name="runtime_checker", runtime=item_hash, - beta=False, + compute_units=1, + vcpus=None, + memory=None, + timeout_seconds=None, + internet=False, persistent=False, updatable=False, + beta=False, skip_volume=True, skip_env_var=True, + address=None, + channel=settings.DEFAULT_CHANNEL, private_key=private_key, private_key_file=private_key_file, print_messages=False, @@ -778,7 +845,7 @@ async def runtime_checker( msg = "No program hash" raise Exception(msg) except Exception as e: - echo(f"Failed to deploy the runtime checker program: {e}") + echo("Failed to deploy the runtime checker program") raise typer.Exit(code=1) from e program_url = settings.VM_URL_PATH.format(hash=program_hash) diff --git a/src/aleph_client/commands/utils.py b/src/aleph_client/commands/utils.py index 1f942f39..dc428bfe 100644 --- a/src/aleph_client/commands/utils.py +++ b/src/aleph_client/commands/utils.py @@ -82,28 +82,37 @@ def yes_no_input(text: str, default: str | bool) -> bool: def prompt_for_volumes(): while yes_no_input("Add volume?", default=False): mount = validated_prompt("Mount path (ex: /opt/data): ", lambda text: len(text) > 0) - name = validated_prompt("Name: ", lambda text: len(text) > 0) - comment = Prompt.ask("Comment: ") - persistent = yes_no_input("Persist on VM host?", default=False) - if persistent: - size_mib = validated_int_prompt("Size (MiB): ", min_value=1) - yield { - "comment": comment, - "mount": mount, - "name": name, - "persistence": "host", - "size_mib": size_mib, - } - else: + comment = Prompt.ask("Comment (description): ") + base_volume = {"mount": mount, "comment": comment} + + if yes_no_input("Use an immutable volume?", default=False): ref = validated_prompt("Item hash: ", lambda text: len(text) == 64) use_latest = yes_no_input("Use latest version?", default=True) yield { - "comment": comment, - "mount": mount, - "name": name, + **base_volume, "ref": ref, "use_latest": use_latest, } + elif yes_no_input("Persist on VM host?", default=False): + parent = None + if yes_no_input("Copy from a parent volume?", default=False): + parent = {"ref": validated_prompt("Item hash: ", lambda text: len(text) == 64), "use_latest": True} + name = validated_prompt("Name: ", lambda text: len(text) > 0) + size_mib = validated_int_prompt("Size (MiB): ", min_value=1, max_value=2048000) + yield { + **base_volume, + "parent": parent, + "persistence": "host", + "name": name, + "size_mib": size_mib, + } + else: # Ephemeral + size_mib = validated_int_prompt("Size (MiB): ", min_value=1, max_value=1024) + yield { + **base_volume, + "ephemeral": True, + "size_mib": size_mib, + } def volume_to_dict(volume: list[str]) -> Optional[dict[str, Union[str, int]]]: @@ -174,7 +183,7 @@ def get_or_prompt_environment_variables(env_vars: Optional[str]) -> Optional[dic def str_to_datetime(date: Optional[str]) -> Optional[datetime]: """ - Converts a string representation of a date/time to a datetime object. + Converts a string representation of a date/time to a datetime object in local time. The function can accept either a timestamp or an ISO format datetime string as the input. """ @@ -182,10 +191,13 @@ def str_to_datetime(date: Optional[str]) -> Optional[datetime]: return None try: date_f = float(date) - return datetime.fromtimestamp(date_f, tz=timezone.utc) + utc_dt = datetime.fromtimestamp(date_f, tz=timezone.utc) + return utc_dt.astimezone() except ValueError: - pass - return datetime.fromisoformat(date) + dt = datetime.fromisoformat(date) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.astimezone() T = TypeVar("T") @@ -227,7 +239,7 @@ def validated_int_prompt( while True: try: value = IntPrompt.ask( - prompt + f" [min: {min_value or '-'}, max: {max_value or '-'}]", + prompt + f" [orange1][/orange1]", default=default, ) except PromptError: @@ -324,3 +336,29 @@ def find_sevctl_or_exit() -> Path: echo("Instructions for setup https://docs.aleph.im/computing/confidential/requirements/") raise Exit(code=1) return Path(sevctl_path) + + +def found_gpus_by_model(crn_list: list) -> dict[str, dict[str, dict[str, int]]]: + found_gpu_models: dict[str, dict[str, dict[str, int]]] = {} + for crn_ in crn_list: + found_gpus: dict[str, dict[str, dict[str, int]]] = {} + for gpu_ in crn_.compatible_available_gpus: + model = gpu_["model"] + device = gpu_["device_name"] + if model not in found_gpus: + found_gpus[model] = {device: {"count": 1, "on_crns": 1}} + elif device not in found_gpus[model]: + found_gpus[model][device] = {"count": 1, "on_crns": 1} + else: + found_gpus[model][device]["count"] += 1 + for model, devices in found_gpus.items(): + if model not in found_gpu_models: + found_gpu_models[model] = devices + else: + for device, details in devices.items(): + if device not in found_gpu_models[model]: + found_gpu_models[model][device] = details + else: + found_gpu_models[model][device]["count"] += details["count"] + found_gpu_models[model][device]["on_crns"] += details["on_crns"] + return found_gpu_models diff --git a/src/aleph_client/models.py b/src/aleph_client/models.py index 6703d474..93214127 100644 --- a/src/aleph_client/models.py +++ b/src/aleph_client/models.py @@ -1,6 +1,7 @@ from datetime import datetime -from typing import Optional +from typing import Any, Optional +from aiohttp import InvalidURL from aleph.sdk.types import StoredContent from aleph_message.models import ItemHash from aleph_message.models.execution.environment import CpuProperties, GpuDeviceClass @@ -13,6 +14,7 @@ from aleph_client.commands.files import download from aleph_client.commands.node import _escape_and_normalize, _remove_ansi_escape +from aleph_client.utils import extract_valid_eth_address, sanitize_url class LoadAverage(BaseModel): @@ -53,10 +55,12 @@ class MachineProperties(BaseModel): class GpuDevice(BaseModel): vendor: str + model: str device_name: str device_class: GpuDeviceClass pci_host: str device_id: str + compatible: bool class GPUProperties(BaseModel): @@ -127,15 +131,55 @@ def from_unsanitized_input( class CRNInfo(BaseModel): hash: ItemHash name: str + owner: str url: str + ccn_hash: Optional[str] + status: Optional[str] version: Optional[str] score: float + reward_address: str stream_reward_address: str machine_usage: Optional[MachineUsage] - qemu_support: Optional[bool] - confidential_computing: Optional[bool] - gpu_support: Optional[bool] + ipv6: bool + qemu_support: bool + confidential_computing: bool + gpu_support: bool terms_and_conditions: Optional[str] + compatible_available_gpus: Optional[list] + + @staticmethod + def from_unsanitized_input( + crn: dict[str, Any], + ) -> "CRNInfo": + payment_receiver_address = crn.get("payment_receiver_address") + stream_reward_address = extract_valid_eth_address(payment_receiver_address) if payment_receiver_address else "" + system_usage = crn.get("system_usage") + machine_usage = MachineUsage.parse_obj(system_usage) if system_usage else None + ipv6_check = crn.get("ipv6_check") + ipv6 = bool(ipv6_check and all(ipv6_check.values())) + try: + url = sanitize_url(crn["address"]) + except InvalidURL: + url = "" + return CRNInfo( + hash=crn["hash"], + name=crn["name"], + owner=crn["owner"], + url=url, + version=crn["version"], + ccn_hash=crn["parent"], + status=crn["status"], + score=crn["score"], + reward_address=crn["reward"], + stream_reward_address=stream_reward_address, + machine_usage=machine_usage, + ipv6=ipv6, + qemu_support=bool(crn["qemu_support"]), + confidential_computing=bool(crn["confidential_support"]), + gpu_support=bool(crn["gpu_support"]), + terms_and_conditions=crn["terms_and_conditions"], + compatible_available_gpus=crn["compatible_available_gpus"], + ) @property def display_cpu(self) -> str: diff --git a/src/aleph_client/utils.py b/src/aleph_client/utils.py index dd4451a4..cc3c5aaa 100644 --- a/src/aleph_client/utils.py +++ b/src/aleph_client/utils.py @@ -6,7 +6,9 @@ import os import re import subprocess -from functools import partial, wraps +import sys +from asyncio import ensure_future +from functools import lru_cache, partial, wraps from pathlib import Path from shutil import make_archive from typing import Optional, Union @@ -179,3 +181,12 @@ def sanitize_url(url: str) -> str: msg = "Invalid URL host" raise aiohttp.InvalidURL(msg) return url.strip("/") + + +def async_lru_cache(async_function): + + @lru_cache(maxsize=0 if "pytest" in sys.modules else 1) + def cached_async_function(*args, **kwargs): + return ensure_future(async_function(*args, **kwargs)) + + return cached_async_function diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 785d8802..b61486c9 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -3,20 +3,24 @@ from aleph.sdk.chains.evm import EVMAccount from aleph.sdk.conf import settings -from eth_utils.currency import to_wei from pydantic import BaseModel from aleph_client.commands.node import NodeInfo # Change to Aleph testnet -settings.API_HOST = "https://api.twentysix.testnet.network" +# settings.API_HOST = "https://api.twentysix.testnet.network" +settings.API_HOST = "http://51.159.223.120:4024" # TODO: change it # Utils FAKE_PRIVATE_KEY = b"cafe" * 8 FAKE_PUBKEY_FILE = "/path/fake/pubkey" FAKE_ADDRESS_EVM = "0x00001A0e6B9a46Be48a294D74D897d9C48678862" -FAKE_STORE_HASH = "102682ea8bcc0cec9c42f32fbd2660286b4eb31003108440988343726304607a" # Has to exist on Aleph Testnet -FAKE_STORE_HASH_CONTENT_FILE_CID = "QmX8K1c22WmQBAww5ShWQqwMiFif7XFrJD6iFBj7skQZXW" # From FAKE_STORE_HASH message +# FAKE_STORE_HASH = "102682ea8bcc0cec9c42f32fbd2660286b4eb31003108440988343726304607a" # Has to exist on Aleph Testnet +# FAKE_STORE_HASH_CONTENT_FILE_CID = "QmX8K1c22WmQBAww5ShWQqwMiFif7XFrJD6iFBj7skQZXW" # From FAKE_STORE_HASH message +# FAKE_STORE_HASH_PUBLISHER = "0x74F82AC22C1EB20dDb9799284FD8D60eaf48A8fb" # From FAKE_STORE_HASH message +FAKE_STORE_HASH = "5b868dc8c2df0dd9bb810b7a31cc50c8ad1e6569905e45ab4fd2eee36fecc4d2" # TODO: change it +FAKE_STORE_HASH_CONTENT_FILE_CID = "QmXSEnpQCnUfeGFoSjY1XAK1Cuad5CtAaqyachGTtsFSuA" # TODO: change it +FAKE_STORE_HASH_PUBLISHER = "0xe0aaF578B287de16852dbc54Ae34a263FF2F4b9E" # TODO: change it FAKE_VM_HASH = "ab12" * 16 FAKE_PROGRAM_HASH = "cd34" * 16 FAKE_PROGRAM_HASH_2 = "ef56" * 16 @@ -39,11 +43,8 @@ def create_mock_load_account(): mock_loader = MagicMock(return_value=mock_account) mock_loader.return_value.get_super_token_balance = MagicMock(return_value=Decimal(10000 * (10**18))) mock_loader.return_value.can_transact = MagicMock(return_value=True) - mock_loader.return_value.superfluid_connector = MagicMock(can_start_flow=MagicMock(return_value=True)) - mock_loader.return_value.get_flow = AsyncMock(return_value={"flowRate": to_wei(0.0001, unit="ether")}) - mock_loader.return_value.create_flow = AsyncMock(return_value=FAKE_FLOW_HASH) - mock_loader.return_value.update_flow = AsyncMock(return_value=FAKE_FLOW_HASH) - mock_loader.return_value.delete_flow = AsyncMock(return_value=FAKE_FLOW_HASH) + mock_loader.return_value.can_start_flow = MagicMock(return_value=True) + mock_loader.return_value.manage_flow = AsyncMock(return_value=FAKE_FLOW_HASH) return mock_loader diff --git a/tests/unit/test_commands.py b/tests/unit/test_commands.py index 369d338a..2a3e1581 100644 --- a/tests/unit/test_commands.py +++ b/tests/unit/test_commands.py @@ -9,7 +9,11 @@ from aleph_client.__main__ import app -from .mocks import FAKE_STORE_HASH, FAKE_STORE_HASH_CONTENT_FILE_CID +from .mocks import ( + FAKE_STORE_HASH, + FAKE_STORE_HASH_CONTENT_FILE_CID, + FAKE_STORE_HASH_PUBLISHER, +) runner = CliRunner() @@ -158,11 +162,11 @@ def test_message_get(): [ "message", "get", - "102682ea8bcc0cec9c42f32fbd2660286b4eb31003108440988343726304607a", + FAKE_STORE_HASH, ], ) assert result.exit_code == 0 - assert "0x74F82AC22C1EB20dDb9799284FD8D60eaf48A8fb" in result.stdout + assert FAKE_STORE_HASH_PUBLISHER in result.stdout def test_message_find(): @@ -175,12 +179,12 @@ def test_message_find(): "--page=1", "--start-date=1234", "--chains=ETH", - "--hashes=102682ea8bcc0cec9c42f32fbd2660286b4eb31003108440988343726304607a", + f"--hashes={FAKE_STORE_HASH}", ], ) assert result.exit_code == 0 - assert "0x74F82AC22C1EB20dDb9799284FD8D60eaf48A8fb" in result.stdout - assert "102682ea8bcc0cec9c42f32fbd2660286b4eb31003108440988343726304607a" in result.stdout + assert FAKE_STORE_HASH_PUBLISHER in result.stdout + assert FAKE_STORE_HASH in result.stdout def test_post_message(env_files): diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 05bdde22..94d6aba1 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -63,10 +63,12 @@ def dummy_gpu_device() -> GpuDevice: return GpuDevice( vendor="NVIDIA", + model="RTX 4090", device_name="RTX 4090", device_class=GpuDeviceClass.VGA_COMPATIBLE_CONTROLLER, pci_host="01:00.0", device_id="abcd:1234", + compatible=True, ) @@ -78,7 +80,7 @@ def dummy_machine_info() -> MachineInfo: hash=FAKE_CRN_HASH, name="Mock CRN", url="https://example.com", - version="v420.69", + version="123.420.69", score=0.5, reward_address=FAKE_ADDRESS_EVM, machine_usage=MachineUsage( @@ -113,44 +115,33 @@ def dummy_machine_info() -> MachineInfo: ) -def create_mock_crn_info(): - mock_machine_info = dummy_machine_info() - return MagicMock( - return_value=CRNInfo( - hash=ItemHash(FAKE_CRN_HASH), - name="Mock CRN", - url=FAKE_CRN_URL, - version="v420.69", - score=0.5, - stream_reward_address=mock_machine_info.reward_address, - machine_usage=mock_machine_info.machine_usage, - qemu_support=True, - confidential_computing=True, - gpu_support=True, - terms_and_conditions=FAKE_STORE_HASH, - ) - ) - - def dict_to_ci_multi_dict_proxy(d: dict) -> CIMultiDictProxy: """Return a read-only proxy to a case-insensitive multi-dict created from a dict.""" return CIMultiDictProxy(CIMultiDict(d)) +def create_mock_fetch_latest_crn_version(): + return AsyncMock(return_value="123.420.69") + + @pytest.mark.asyncio -async def test_fetch_crn_info() -> None: +async def test_fetch_crn_info(): + mock_fetch_latest_crn_version = create_mock_fetch_latest_crn_version() + + @patch("aleph_client.commands.instance.network.fetch_latest_crn_version", mock_fetch_latest_crn_version) + async def fetch_crn_info_with_mock(url): + print() # For better display when pytest -v -s + return await fetch_crn_info(url) + # Test with valid node - # TODO: Mock the response from the node, don't rely on a real node - node_url = "https://ovh.staging.aleph.sh" - info = await fetch_crn_info(node_url) + node_url = "https://crn-lon04.omega-aleph.com/" # Always prefer a top score CRN here + info = await fetch_crn_info_with_mock(node_url) assert info - assert info["machine_usage"] - + assert info.machine_usage # Test with invalid node invalid_node_url = "https://coconut.example.org/" - assert not (await fetch_crn_info(invalid_node_url)) - - # TODO: Test different error handling + assert not (await fetch_crn_info_with_mock(invalid_node_url)) + mock_fetch_latest_crn_version.assert_called() def test_sanitize_url_with_empty_url(): @@ -193,7 +184,7 @@ def create_mock_instance_message(mock_account, payg=False, coco=False, gpu=False item_hash=vm_item_hash, content=Dict( address=mock_account.get_address(), - time=1734037086.2333803, + time=2999999999.1234567, metadata={"name": "mock_instance"}, authorized_keys=["ssh-rsa ..."], environment=Dict(hypervisor=HypervisorType.qemu, trusted_execution=None), @@ -251,6 +242,31 @@ def create_mock_validate_ssh_pubkey_file(): ) +def create_mock_fetch_crn_info(): + mock_machine_info = dummy_machine_info() + return AsyncMock( + return_value=CRNInfo( + hash=ItemHash(FAKE_CRN_HASH), + name="Mock CRN", + owner=FAKE_ADDRESS_EVM, + url=FAKE_CRN_URL, + ccn_hash=FAKE_CRN_HASH, + status="linked", + version="123.420.69", + score=0.9, + reward_address=FAKE_ADDRESS_EVM, + stream_reward_address=mock_machine_info.reward_address, + machine_usage=mock_machine_info.machine_usage, + ipv6=True, + qemu_support=True, + confidential_computing=True, + gpu_support=True, + terms_and_conditions=FAKE_STORE_HASH, + compatible_available_gpus=[dummy_gpu_device()], + ) + ) + + def create_mock_fetch_vm_info(): return AsyncMock( return_value=[FAKE_VM_HASH, {"crn_url": FAKE_CRN_URL, "allocation_type": help_strings.ALLOCATION_MANUAL}] @@ -261,28 +277,48 @@ def create_mock_shutil(): return MagicMock(which=MagicMock(return_value="/root/.cargo/bin/sevctl", move=MagicMock(return_value="/fake/path"))) -def create_mock_client(): +def create_mock_client(payment_type="superfluid"): mock_client = AsyncMock( get_message=AsyncMock(return_value=True), get_stored_content=AsyncMock( return_value=Dict(filename="fake_tac", hash="0xfake_tac", url="https://fake.tac.com") ), + get_estimated_price=AsyncMock( + return_value=MagicMock( + required_tokens=0.00001527777777777777 if payment_type == "superfluid" else 1000, + payment_type=payment_type, + ) + ), ) mock_client_class = MagicMock() mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client) return mock_client_class, mock_client -def create_mock_auth_client(mock_account): +def create_mock_auth_client(mock_account, payment_type="superfluid", payment_types=None): + + def response_get_program_price(ptype): + return MagicMock( + required_tokens=0.00001527777777777777 if ptype == "superfluid" else 1000, + payment_type=ptype, + ) + mock_response_get_message = create_mock_instance_message(mock_account, payg=True) mock_response_create_instance = MagicMock(item_hash=FAKE_VM_HASH) mock_auth_client = AsyncMock( get_messages=AsyncMock(), get_message=AsyncMock(return_value=mock_response_get_message), create_instance=AsyncMock(return_value=[mock_response_create_instance, 200]), - get_program_price=AsyncMock(return_value=MagicMock(required_tokens=0.0001)), + get_program_price=None, forget=AsyncMock(return_value=(MagicMock(), 200)), ) + if payment_types: + mock_auth_client.get_program_price = AsyncMock( + side_effect=[response_get_program_price(pt) for pt in payment_types] + ) + else: + mock_auth_client.get_program_price = AsyncMock(return_value=response_get_program_price(payment_type)) + mock_auth_client_class = MagicMock() mock_auth_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_auth_client) return mock_auth_client_class, mock_auth_client @@ -328,6 +364,20 @@ def create_mock_vm_coco_client(): return mock_vm_coco_client_class, mock_vm_coco_client +# TODO: GPU test requires a rework +""" ( # gpu_superfluid_evm + { + "payment_type": "superfluid", + "payment_chain": "BASE", + "rootfs": "debian12", + "crn_hash": FAKE_CRN_HASH, + "crn_url": FAKE_CRN_URL, + "gpu": True, + }, + (FAKE_VM_HASH, FAKE_CRN_URL, "BASE"), + ), """ + + @pytest.mark.parametrize( ids=[ "regular_hold_evm", @@ -336,7 +386,7 @@ def create_mock_vm_coco_client(): "coco_hold_sol", "coco_hold_evm", "coco_superfluid_evm", - "gpu_superfluid_evm", + # "gpu_superfluid_evm", ], argnames="args, expected", argvalues=[ @@ -402,17 +452,6 @@ def create_mock_vm_coco_client(): }, (FAKE_VM_HASH, FAKE_CRN_URL, "BASE"), ), - ( # gpu_superfluid_evm - { - "payment_type": "superfluid", - "payment_chain": "BASE", - "rootfs": "debian12", - "crn_hash": FAKE_CRN_HASH, - "crn_url": FAKE_CRN_URL, - "gpu": True, - }, - (FAKE_VM_HASH, FAKE_CRN_URL, "BASE"), - ), ], ) @pytest.mark.asyncio @@ -420,23 +459,26 @@ async def test_create_instance(args, expected): mock_validate_ssh_pubkey_file = create_mock_validate_ssh_pubkey_file() mock_load_account = create_mock_load_account() mock_account = mock_load_account.return_value - mock_client_class, _ = create_mock_client() - mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_get_balance = AsyncMock(return_value={"available_amount": 100000}) + mock_client_class, mock_client = create_mock_client(payment_type=args["payment_type"]) + mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account, payment_type=args["payment_type"]) mock_vm_client_class, mock_vm_client = create_mock_vm_client() - mock_crn_info = create_mock_crn_info() + mock_fetch_latest_crn_version = create_mock_fetch_latest_crn_version() + mock_fetch_crn_info = create_mock_fetch_crn_info() mock_validated_int_prompt = MagicMock(return_value=1) mock_wait_for_processed_instance = AsyncMock() - mock_update_flow = AsyncMock(return_value="fake_flow_hash") mock_wait_for_confirmed_flow = AsyncMock() @patch("aleph_client.commands.instance.validate_ssh_pubkey_file", mock_validate_ssh_pubkey_file) @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.get_balance", mock_get_balance) @patch("aleph_client.commands.instance.AlephHttpClient", mock_client_class) @patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class) - @patch("aleph_client.commands.instance.CRNInfo", mock_crn_info) + @patch("aleph_client.commands.instance.network.fetch_latest_crn_version", mock_fetch_latest_crn_version) + @patch("aleph_client.commands.instance.fetch_crn_info", mock_fetch_crn_info) @patch("aleph_client.commands.instance.validated_int_prompt", mock_validated_int_prompt) @patch("aleph_client.commands.instance.wait_for_processed_instance", mock_wait_for_processed_instance) - @patch("aleph_client.commands.instance.update_flow", mock_update_flow) + @patch.object(asyncio, "sleep", AsyncMock()) @patch("aleph_client.commands.instance.wait_for_confirmed_flow", mock_wait_for_confirmed_flow) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) async def create_instance(instance_spec): @@ -445,9 +487,10 @@ async def create_instance(instance_spec): "ssh_pubkey_file": FAKE_PUBKEY_FILE, "name": "mock_instance", "hypervisor": HypervisorType.qemu, - "rootfs_size": 20480, - "vcpus": 1, - "memory": 2048, + "compute_units": 1, + "vcpus": None, + "memory": None, + "rootfs_size": None, "timeout_seconds": settings.DEFAULT_VM_TIMEOUT, "skip_volume": True, "persistent_volume": None, @@ -455,10 +498,12 @@ async def create_instance(instance_spec): "immutable_volume": None, "crn_auto_tac": True, "channel": settings.DEFAULT_CHANNEL, + "address": None, "crn_hash": None, "crn_url": None, "confidential": False, "gpu": False, + "premium": None, "private_key": None, "private_key_file": None, "print_message": False, @@ -468,13 +513,22 @@ async def create_instance(instance_spec): return await create(**all_args) returned = await create_instance(args) + # Basic assertions for all cases mock_load_account.assert_called_once() mock_validate_ssh_pubkey_file.return_value.read_text.assert_called_once() + mock_client.get_estimated_price.assert_called_once() mock_auth_client.create_instance.assert_called_once() - if args["payment_type"] == "superfluid": + # Payment type specific assertions + if args["payment_type"] == "hold": + mock_get_balance.assert_called_once() + elif args["payment_type"] == "superfluid": + assert mock_account.manage_flow.call_count == 2 + assert mock_wait_for_confirmed_flow.call_count == 2 + # CRN related assertions + if args["payment_type"] == "superfluid" or args.get("confidential") or args.get("gpu"): + mock_fetch_latest_crn_version.assert_called() + mock_fetch_crn_info.assert_called_once() mock_wait_for_processed_instance.assert_called_once() - mock_update_flow.assert_called_once() - mock_wait_for_confirmed_flow.assert_called_once() mock_vm_client.start_instance.assert_called_once() assert returned == expected @@ -483,11 +537,15 @@ async def create_instance(instance_spec): async def test_list_instances(): mock_load_account = create_mock_load_account() mock_account = mock_load_account.return_value + mock_fetch_latest_crn_version = create_mock_fetch_latest_crn_version() mock_client_class, mock_client = create_mock_client() - mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) mock_instance_messages = create_mock_instance_messages(mock_account) + mock_auth_client_class, mock_auth_client = create_mock_auth_client( + mock_account, payment_types=[vm.content.payment.type for vm in mock_instance_messages.return_value] + ) @patch("aleph_client.commands.instance._load_account", mock_load_account) + @patch("aleph_client.commands.instance.network.fetch_latest_crn_version", mock_fetch_latest_crn_version) @patch("aleph_client.commands.files.AlephHttpClient", mock_client_class) @patch("aleph_client.commands.instance.AlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.filter_only_valid_messages", mock_instance_messages) @@ -500,9 +558,10 @@ async def list_instance(): debug=False, ) mock_instance_messages.assert_called_once() + mock_fetch_latest_crn_version.assert_called() mock_auth_client.get_messages.assert_called_once() mock_auth_client.get_program_price.assert_called() - assert mock_auth_client.get_program_price.call_count == 4 + assert mock_auth_client.get_program_price.call_count == 5 assert mock_client.get_stored_content.call_count == 1 await list_instance() @@ -520,6 +579,7 @@ async def test_delete_instance(): @patch("aleph_client.commands.instance.AuthenticatedAlephHttpClient", mock_auth_client_class) @patch("aleph_client.commands.instance.fetch_vm_info", mock_fetch_vm_info) @patch("aleph_client.commands.instance.VmClient", mock_vm_client_class) + @patch.object(asyncio, "sleep", AsyncMock()) async def delete_instance(): print() # For better display when pytest -v -s await delete( @@ -530,7 +590,7 @@ async def delete_instance(): ) mock_auth_client.get_message.assert_called_once() mock_vm_client.erase_instance.assert_called_once() - mock_account.delete_flow.assert_awaited_once() + assert mock_account.manage_flow.call_count == 2 mock_auth_client.forget.assert_called_once() await delete_instance() @@ -731,10 +791,8 @@ async def coco_start(): "payment_chain": "AVAX", "crn_hash": FAKE_CRN_HASH, "crn_url": FAKE_CRN_URL, - "vcpus": 1, - "memory": 2048, "rootfs": FAKE_STORE_HASH, - "rootfs_size": 20480, + "compute_units": 1, }, {"vm_id": FAKE_VM_HASH}, # coco_from_hash ], @@ -770,14 +828,17 @@ async def coco_create(instance_spec): "crn_hash": None, "crn_url": None, "ssh_pubkey_file": FAKE_PUBKEY_FILE, + "address": None, "name": "mock_instance", "vm_secret": "fake_secret", + "compute_units": None, "vcpus": None, "memory": None, + "rootfs_size": None, "timeout_seconds": settings.DEFAULT_VM_TIMEOUT, "gpu": False, + "premium": None, "rootfs": None, - "rootfs_size": None, "skip_volume": True, "persistent_volume": None, "ephemeral_volume": None, diff --git a/tests/unit/test_pricing.py b/tests/unit/test_pricing.py new file mode 100644 index 00000000..2f13606e --- /dev/null +++ b/tests/unit/test_pricing.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +import pytest + +from aleph_client.commands.pricing import GroupEntity, prices_for_service + + +@pytest.mark.parametrize( + ids=list(GroupEntity), + argnames="args", + argvalues=list(GroupEntity), +) +@pytest.mark.asyncio +async def test_prices_for_service(capsys, args): + print() # For better display when pytest -v -s + await prices_for_service(service=args) + captured = capsys.readouterr() + assert captured.out.startswith("\nā•­ā”€ Pricing:") diff --git a/tests/unit/test_program.py b/tests/unit/test_program.py index 94db51f2..7ef1a07e 100644 --- a/tests/unit/test_program.py +++ b/tests/unit/test_program.py @@ -22,6 +22,7 @@ ) from .mocks import ( + FAKE_ADDRESS_EVM, FAKE_PROGRAM_HASH, FAKE_PROGRAM_HASH_2, FAKE_STORE_HASH, @@ -31,7 +32,9 @@ ) -def create_mock_program_message(mock_account, program_item_hash=None, persistent=False, allow_amend=True): +def create_mock_program_message( + mock_account, program_item_hash=None, internet=False, persistent=False, allow_amend=True +): if not program_item_hash: tmp = list(FAKE_PROGRAM_HASH) random.shuffle(tmp) @@ -39,7 +42,7 @@ def create_mock_program_message(mock_account, program_item_hash=None, persistent program = Dict( chain=Chain.ETH, sender=mock_account.get_address(), - type="program", + type="vm-function", channel="ALEPH-CLOUDSOLUTIONS", confirmed=True, item_type="inline", @@ -49,7 +52,12 @@ def create_mock_program_message(mock_account, program_item_hash=None, persistent type="vm-function", address=mock_account.get_address(), time=1734037086.2333803, - metadata={"name": "mock_program"}, + metadata={ + "name": f"mock_program{'_internet' if internet else ''}" + f"{'_persistent' if persistent else ''}" + f"{'_updatable' if allow_amend else ''}", + }, + environment=Dict(internet=internet), resources=Dict(vcpus=1, memory=1024, seconds=30), volumes=[ Dict(name="immutable", mount="/opt/packages", ref=FAKE_STORE_HASH), @@ -68,8 +76,10 @@ def create_mock_program_message(mock_account, program_item_hash=None, persistent def create_mock_program_messages(mock_account): return AsyncMock( return_value=[ + create_mock_program_message(mock_account, allow_amend=False), + create_mock_program_message(mock_account, internet=True, allow_amend=False), + create_mock_program_message(mock_account, persistent=True, allow_amend=False), create_mock_program_message(mock_account), - create_mock_program_message(mock_account, persistent=True), ] ) @@ -86,6 +96,12 @@ def create_mock_auth_client(mock_account, swap_persistent=False): create_program=AsyncMock(return_value=[MagicMock(item_hash=FAKE_PROGRAM_HASH), 200]), forget=AsyncMock(return_value=(MagicMock(), 200)), submit=AsyncMock(return_value=[mock_response_get_message_2, 200, MagicMock()]), + get_estimated_price=AsyncMock( + return_value=MagicMock( + required_tokens=1000, + payment_type="hold", + ) + ), ) mock_auth_client_class = MagicMock() mock_auth_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_auth_client) @@ -137,27 +153,32 @@ async def test_upload_program(): mock_load_account = create_mock_load_account() mock_account = mock_load_account.return_value mock_auth_client_class, mock_auth_client = create_mock_auth_client(mock_account) + mock_get_balance = AsyncMock(return_value={"available_amount": 100000}) @patch("aleph_client.commands.program._load_account", mock_load_account) @patch("aleph_client.utils.os.path.isfile", MagicMock(return_value=True)) @patch("aleph_client.commands.program.AuthenticatedAlephHttpClient", mock_auth_client_class) + @patch("aleph_client.commands.program.get_balance", mock_get_balance) @patch("aleph_client.commands.program.open", MagicMock()) async def upload_program(): print() # For better display when pytest -v -s returned = await upload( + address=FAKE_ADDRESS_EVM, path=Path("/fake/file.squashfs"), entrypoint="main:app", - channel=settings.DEFAULT_CHANNEL, - memory=settings.DEFAULT_VM_MEMORY, - vcpus=settings.DEFAULT_VM_VCPUS, - timeout_seconds=settings.DEFAULT_VM_TIMEOUT, name="mock_program", runtime=settings.DEFAULT_RUNTIME_ID, + compute_units=1, + vcpus=None, + memory=None, + timeout_seconds=None, + internet=False, + updatable=True, beta=False, persistent=False, - updatable=True, skip_volume=True, skip_env_var=True, + channel=settings.DEFAULT_CHANNEL, private_key=None, private_key_file=None, print_messages=False, @@ -168,6 +189,8 @@ async def upload_program(): ) mock_load_account.assert_called_once() mock_auth_client.create_store.assert_called_once() + mock_get_balance.assert_called_once() + mock_auth_client.get_estimated_price.assert_called_once() mock_auth_client.create_program.assert_called_once() assert returned == FAKE_PROGRAM_HASH