Skip to content
3 changes: 2 additions & 1 deletion src/aleph/sdk/client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from aleph.sdk.client.services.port_forwarder import PortForwarder
from aleph.sdk.client.services.pricing import Pricing
from aleph.sdk.client.services.scheduler import Scheduler
from aleph.sdk.client.services.settings import Settings as NetworkSettingsService
from aleph.sdk.client.services.voucher import Vouchers

from ..conf import settings
Expand Down Expand Up @@ -146,7 +147,7 @@ async def __aenter__(self):
self.instance = Instance(self)
self.pricing = Pricing(self)
self.voucher = Vouchers(self)

self.network_settings = NetworkSettingsService(self)
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
Expand Down
10 changes: 3 additions & 7 deletions src/aleph/sdk/client/services/crn.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,31 +97,27 @@ def find_gpu_on_network(self):

def filter_crn(
self,
latest_crn_version: bool = False,
crn_version: Optional[str] = None,
ipv6: bool = False,
stream_address: bool = False,
confidential: bool = False,
gpu: bool = False,
) -> list[CRN]:
"""Filter compute resource node list, unfiltered by default.
Args:
latest_crn_version (bool): Filter by latest crn version.
crn_version (str): Filter by specific crn version.
ipv6 (bool): Filter invalid IPv6 configuration.
stream_address (bool): Filter invalid payment receiver address.
confidential (bool): Filter by confidential computing support.
gpu (bool): Filter by GPU support.
Returns:
list[CRN]: List of compute resource nodes. (if no filter applied, return all)
"""
# current_crn_version = await fetch_latest_crn_version()
# Relax current filter to allow use aleph-vm versions since 1.5.1.
# TODO: Allow to specify that option on settings aggregate on maybe on GitHub
current_crn_version = "1.5.1"

filtered_crn: list[CRN] = []
for crn_ in self.crns:
# Check crn version
if latest_crn_version and (crn_.version or "0.0.0") < current_crn_version:
if crn_version and (crn_.version or "0.0.0") < crn_version:
continue

# Filter with ipv6 check
Expand Down
5 changes: 2 additions & 3 deletions src/aleph/sdk/client/services/pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from aleph.sdk.client.services.base import BaseService
from aleph.sdk.conf import settings

if TYPE_CHECKING:
pass
Expand Down Expand Up @@ -205,9 +206,7 @@ def __init__(self, client):
async def get_pricing_aggregate(
self,
) -> PricingModel:
result = await self.get_config(
address="0xFba561a84A537fCaa567bb7A2257e7142701ae2A"
)
result = await self.get_config(address=settings.ALEPH_AGGREGATE_ADDRESS)
return result.data[0]

async def get_pricing_for_services(
Expand Down
40 changes: 40 additions & 0 deletions src/aleph/sdk/client/services/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import List

from pydantic import BaseModel

from aleph.sdk.conf import settings

from .base import BaseService


class NetworkAvailableGpu(BaseModel):
name: str
model: str
vendor: str
device_id: str


class NetworkSettingsModel(BaseModel):
compatible_gpus: List[NetworkAvailableGpu]
last_crn_version: str
community_wallet_address: str
community_wallet_timestamp: int


class Settings(BaseService[NetworkSettingsModel]):
"""
This Service handle logic around Pricing
"""

aggregate_key = "settings"
model_cls = NetworkSettingsModel

def __init__(self, client):
super().__init__(client=client)

# Config from aggregate
async def get_settings_aggregate(
self,
) -> NetworkSettingsModel:
result = await self.get_config(address=settings.ALEPH_AGGREGATE_ADDRESS)
return result.data[0]
2 changes: 2 additions & 0 deletions src/aleph/sdk/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class Settings(BaseSettings):
VOUCHER_SOL_REGISTRY: str = "https://api.claim.twentysix.cloud/v1/registry/sol"
VOUCHER_ORIGIN_ADDRESS: str = "0xB34f25f2c935bCA437C061547eA12851d719dEFb"

ALEPH_AGGREGATE_ADDRESS: str = "0xFba561a84A537fCaa567bb7A2257e7142701ae2A"

# Web3Provider settings
TOKEN_DECIMALS: ClassVar[int] = 18
TX_TIMEOUT: ClassVar[int] = 60 * 3
Expand Down
200 changes: 200 additions & 0 deletions tests/unit/services/test_settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from unittest.mock import AsyncMock, MagicMock

import pytest

from aleph.sdk import AlephHttpClient
from aleph.sdk.client.services.settings import NetworkSettingsModel, Settings


@pytest.fixture
def mock_settings_aggregate_response():
return {
"compatible_gpus": [
{
"name": "AD102GL [L40S]",
"model": "L40S",
"vendor": "NVIDIA",
"device_id": "10de:26b9",
},
{
"name": "GB202 [GeForce RTX 5090]",
"model": "RTX 5090",
"vendor": "NVIDIA",
"device_id": "10de:2685",
},
{
"name": "GB202 [GeForce RTX 5090 D]",
"model": "RTX 5090",
"vendor": "NVIDIA",
"device_id": "10de:2687",
},
{
"name": "AD102 [GeForce RTX 4090]",
"model": "RTX 4090",
"vendor": "NVIDIA",
"device_id": "10de:2684",
},
{
"name": "AD102 [GeForce RTX 4090 D]",
"model": "RTX 4090",
"vendor": "NVIDIA",
"device_id": "10de:2685",
},
{
"name": "GA102 [GeForce RTX 3090]",
"model": "RTX 3090",
"vendor": "NVIDIA",
"device_id": "10de:2204",
},
{
"name": "GA102 [GeForce RTX 3090 Ti]",
"model": "RTX 3090",
"vendor": "NVIDIA",
"device_id": "10de:2203",
},
{
"name": "AD104GL [RTX 4000 SFF Ada Generation]",
"model": "RTX 4000 ADA",
"vendor": "NVIDIA",
"device_id": "10de:27b0",
},
{
"name": "AD104GL [RTX 4000 Ada Generation]",
"model": "RTX 4000 ADA",
"vendor": "NVIDIA",
"device_id": "10de:27b2",
},
{
"name": "GA102GL [RTX A5000]",
"model": "RTX A5000",
"vendor": "NVIDIA",
"device_id": "10de:2231",
},
{
"name": "GA102GL [RTX A6000]",
"model": "RTX A6000",
"vendor": "NVIDIA",
"device_id": "10de:2230",
},
{
"name": "GH100 [H100]",
"model": "H100",
"vendor": "NVIDIA",
"device_id": "10de:2336",
},
{
"name": "GH100 [H100 NVSwitch]",
"model": "H100",
"vendor": "NVIDIA",
"device_id": "10de:22a3",
},
{
"name": "GH100 [H100 CNX]",
"model": "H100",
"vendor": "NVIDIA",
"device_id": "10de:2313",
},
{
"name": "GH100 [H100 SXM5 80GB]",
"model": "H100",
"vendor": "NVIDIA",
"device_id": "10de:2330",
},
{
"name": "GH100 [H100 PCIe]",
"model": "H100",
"vendor": "NVIDIA",
"device_id": "10de:2331",
},
{
"name": "GA100",
"model": "A100",
"vendor": "NVIDIA",
"device_id": "10de:2080",
},
{
"name": "GA100",
"model": "A100",
"vendor": "NVIDIA",
"device_id": "10de:2081",
},
{
"name": "GA100 [A100 SXM4 80GB]",
"model": "A100",
"vendor": "NVIDIA",
"device_id": "10de:20b2",
},
{
"name": "GA100 [A100 PCIe 80GB]",
"model": "A100",
"vendor": "NVIDIA",
"device_id": "10de:20b5",
},
{
"name": "GA100 [A100X]",
"model": "A100",
"vendor": "NVIDIA",
"device_id": "10de:20b8",
},
{
"name": "GH100 [H200 SXM 141GB]",
"model": "H200",
"vendor": "NVIDIA",
"device_id": "10de:2335",
},
{
"name": "GH100 [H200 NVL]",
"model": "H200",
"vendor": "NVIDIA",
"device_id": "10de:233b",
},
{
"name": "AD102GL [RTX 6000 ADA]",
"model": "RTX 6000 ADA",
"vendor": "NVIDIA",
"device_id": "10de:26b1",
},
],
"last_crn_version": "1.7.2",
"community_wallet_address": "0x5aBd3258C5492fD378EBC2e0017416E199e5Da56",
"community_wallet_timestamp": 1739996239,
}


@pytest.mark.asyncio
async def test_get_settings_aggregate(
make_mock_aiohttp_session, mock_settings_aggregate_response
):
client = AlephHttpClient(api_server="http://localhost")

# Properly mock the fetch_aggregate method using monkeypatch
client._http_session = MagicMock()
monkeypatch = AsyncMock(return_value=mock_settings_aggregate_response)
setattr(client, "fetch_aggregate", monkeypatch)

settings_service = Settings(client)
result = await settings_service.get_settings_aggregate()

assert isinstance(result, NetworkSettingsModel)
assert len(result.compatible_gpus) == 24 # We have 24 GPUs in the mock data

rtx4000_gpu = next(
gpu for gpu in result.compatible_gpus if gpu.device_id == "10de:27b0"
)
assert rtx4000_gpu.name == "AD104GL [RTX 4000 SFF Ada Generation]"
assert rtx4000_gpu.model == "RTX 4000 ADA"
assert rtx4000_gpu.vendor == "NVIDIA"

assert result.last_crn_version == "1.7.2"
assert (
result.community_wallet_address == "0x5aBd3258C5492fD378EBC2e0017416E199e5Da56"
)
assert result.community_wallet_timestamp == 1739996239

# Verify that fetch_aggregate was called with the correct parameters
assert monkeypatch.call_count == 1
assert (
monkeypatch.call_args.kwargs["address"]
== "0xFba561a84A537fCaa567bb7A2257e7142701ae2A"
)
assert monkeypatch.call_args.kwargs["key"] == "settings"
Loading