diff --git a/abis/ModelRegistry.abi.json b/abis/ModelRegistry.abi.json index 361ac42..968aaa4 100644 --- a/abis/ModelRegistry.abi.json +++ b/abis/ModelRegistry.abi.json @@ -96,6 +96,56 @@ ], "stateMutability": "view" }, + { + "type": "function", + "name": "getActiveModelsWithDetails", + "inputs": [], + "outputs": [ + { + "name": "models", + "type": "tuple[]", + "internalType": "struct IModelRegistry.ModelDetails[]", + "components": [ + { + "name": "modelId", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "modelName", + "type": "string", + "internalType": "string" + }, + { + "name": "modelVerifier", + "type": "address", + "internalType": "address" + }, + { + "name": "verificationStrategy", + "type": "uint8", + "internalType": "enum IModelRegistry.VerificationStrategy" + }, + { + "name": "computeCost", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "requiredFUCUs", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "isActive", + "type": "bool", + "internalType": "bool" + } + ] + } + ], + "stateMutability": "view" + }, { "type": "function", "name": "getRandomActiveModel", diff --git a/abis/SertnNodesManager.abi.json b/abis/SertnNodesManager.abi.json index 95363ec..376cabd 100644 --- a/abis/SertnNodesManager.abi.json +++ b/abis/SertnNodesManager.abi.json @@ -77,6 +77,29 @@ ], "stateMutability": "view" }, + { + "type": "function", + "name": "getAllNodesWithDetails", + "inputs": [], + "outputs": [ + { + "name": "nodeDetails", + "type": "uint256[8][]", + "internalType": "uint256[8][]" + }, + { + "name": "supportedModels", + "type": "uint256[][]", + "internalType": "uint256[][]" + }, + { + "name": "modelAllocations", + "type": "uint256[][]", + "internalType": "uint256[][]" + } + ], + "stateMutability": "view" + }, { "type": "function", "name": "getAvailableFucusForOperatorModel", diff --git a/abis/SertnServiceManager.abi.json b/abis/SertnServiceManager.abi.json index 10add9f..c24d2c0 100644 --- a/abis/SertnServiceManager.abi.json +++ b/abis/SertnServiceManager.abi.json @@ -424,7 +424,7 @@ }, { "type": "function", - "name": "taskCompleted", + "name": "taskResolved", "inputs": [ { "name": "_operator", @@ -556,6 +556,12 @@ "indexed": false, "internalType": "uint256" }, + { + "name": "token", + "type": "address", + "indexed": false, + "internalType": "contract IERC20" + }, { "name": "currentInterval", "type": "uint32", diff --git a/abis/SertnTaskManager.abi.json b/abis/SertnTaskManager.abi.json index 3bfb695..495c6ec 100644 --- a/abis/SertnTaskManager.abi.json +++ b/abis/SertnTaskManager.abi.json @@ -127,6 +127,150 @@ ], "stateMutability": "view" }, + { + "type": "function", + "name": "getTaskHistoryStats", + "inputs": [], + "outputs": [ + { + "name": "totalTasks", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "resolvedTasks", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "rejectedTasks", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "pendingTasksCount", + "type": "uint256", + "internalType": "uint256" + } + ], + "stateMutability": "view" + }, + { + "type": "function", + "name": "getTasksByModel", + "inputs": [ + { + "name": "modelId", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "offset", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "limit", + "type": "uint256", + "internalType": "uint256" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256[]", + "internalType": "uint256[]" + } + ], + "stateMutability": "view" + }, + { + "type": "function", + "name": "getTasksByOperator", + "inputs": [ + { + "name": "operator", + "type": "address", + "internalType": "address" + }, + { + "name": "offset", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "limit", + "type": "uint256", + "internalType": "uint256" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256[]", + "internalType": "uint256[]" + } + ], + "stateMutability": "view" + }, + { + "type": "function", + "name": "getTasksByState", + "inputs": [ + { + "name": "state", + "type": "uint8", + "internalType": "enum ISertnTaskManager.TaskState" + }, + { + "name": "offset", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "limit", + "type": "uint256", + "internalType": "uint256" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256[]", + "internalType": "uint256[]" + } + ], + "stateMutability": "view" + }, + { + "type": "function", + "name": "getTasksByUser", + "inputs": [ + { + "name": "user", + "type": "address", + "internalType": "address" + }, + { + "name": "offset", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "limit", + "type": "uint256", + "internalType": "uint256" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256[]", + "internalType": "uint256[]" + } + ], + "stateMutability": "view" + }, { "type": "function", "name": "initialize", @@ -494,6 +638,102 @@ ], "stateMutability": "view" }, + { + "type": "function", + "name": "tasksByModel", + "inputs": [ + { + "name": "", + "type": "uint256", + "internalType": "uint256" + }, + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ], + "stateMutability": "view" + }, + { + "type": "function", + "name": "tasksByOperator", + "inputs": [ + { + "name": "", + "type": "address", + "internalType": "address" + }, + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ], + "stateMutability": "view" + }, + { + "type": "function", + "name": "tasksByState", + "inputs": [ + { + "name": "", + "type": "uint8", + "internalType": "enum ISertnTaskManager.TaskState" + }, + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ], + "stateMutability": "view" + }, + { + "type": "function", + "name": "tasksByUser", + "inputs": [ + { + "name": "", + "type": "address", + "internalType": "address" + }, + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ], + "outputs": [ + { + "name": "", + "type": "uint256", + "internalType": "uint256" + } + ], + "stateMutability": "view" + }, { "type": "function", "name": "transferOwnership", diff --git a/client/pyproject.toml b/client/pyproject.toml index 076a3f2..f9bc7d2 100644 --- a/client/pyproject.toml +++ b/client/pyproject.toml @@ -20,6 +20,9 @@ dependencies = [ "eth-abi>=5.2.0", "python-dotenv>=1.1.1", "gitpython>=3.1.45", + "cachetools>=6.2.1", + "pymemcache>=4.0.0", + "starlette>=0.49.1", ] [project.optional-dependencies] diff --git a/client/src/aggregator/main.py b/client/src/aggregator/main.py index cf8703b..6c7780b 100644 --- a/client/src/aggregator/main.py +++ b/client/src/aggregator/main.py @@ -12,6 +12,7 @@ from aggregator.errors import InvalidProofError from aggregator.server import AggregatorServer from common.abis import ERC20_ABI, STRATEGY_ABI +from common.addresses import addresses from common.auto_update import AutoUpdate from common.config import AggregatorConfig from common.constants import ( @@ -31,6 +32,7 @@ def run_aggregator(config: AggregatorConfig) -> None: + addresses.init_addresses(chain_id=config.chain_id) logger.info("Starting Sertn Aggregator...") aggregator = Aggregator(config=config) threading.Thread(target=aggregator.start_sending_new_tasks, args=[]).start() diff --git a/client/src/aggregator/server.py b/client/src/aggregator/server.py index 76e435e..c9f488c 100644 --- a/client/src/aggregator/server.py +++ b/client/src/aggregator/server.py @@ -1,15 +1,29 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import uvicorn -from fastapi import APIRouter, FastAPI, HTTPException +from fastapi import APIRouter, FastAPI, HTTPException, Query from pydantic import BaseModel from starlette.concurrency import run_in_threadpool -from .errors import InvalidProofError +from aggregator.errors import InvalidProofError +from common.abis import ( + ERC20_ABI, + STRATEGY_ABI, +) +from common.addresses import addresses +from common.cache import CacheBackend, cached_method, get_cache +from common.constants import ( + ADDRESS_REGEXP, + OPERATOR_SET_ID, +) +from common.contract_constants import TaskStateMap, TaskStructMap +from common.logging import get_logger if TYPE_CHECKING: from aggregator.main import Aggregator +logger = get_logger("aggregator_server") + class ProofRequest(BaseModel): """Pydantic model for operator-submitted proof.""" @@ -20,18 +34,59 @@ class ProofRequest(BaseModel): class AggregatorServer: + def __init__(self, aggregator: "Aggregator"): + # List of all strategy addresses + self.strategies = ( + addresses.STRATEGIES_ADDRESSES + addresses.ETH_STRATEGY_ADDRESSES + ) + # OperatorSet is a struct with (avs, id) + self.operator_set = (addresses.SERVICE_MANAGER_ADDRESS, OPERATOR_SET_ID) + self.aggregator = aggregator self.eth_client = aggregator.eth_client self.app = FastAPI() self.router = APIRouter() + + # Initialize cache + config = aggregator.config + self.cache: CacheBackend = get_cache(config=config.caching) + self._register_routes() + def start(self): + host, port = self.aggregator.config.aggregator_server_ip_port_address.split(":") + uvicorn.run(self.app, host=host, port=int(port)) + def _register_routes(self) -> None: - self.router.add_api_route("/proof", self.submit_proof, methods=["POST"]) self.router.add_api_route("/health", self.health, methods=["GET"]) + self.router.add_api_route("/proof", self.submit_proof, methods=["POST"]) + self.router.add_api_route("/models", self.models_list, methods=["GET"]) + self.router.add_api_route( + "/model-inference-history", self.model_inference_history, methods=["GET"] + ) + self.router.add_api_route( + "/operator-inference-history", + self.operator_inference_history, + methods=["GET"], + ) + self.router.add_api_route( + "/user-inference-history", self.user_inference_history, methods=["GET"] + ) + self.router.add_api_route( + "/state-inference-history", self.state_inference_history, methods=["GET"] + ) + self.router.add_api_route( + "/inference-stats", self.inference_stats, methods=["GET"] + ) + self.router.add_api_route("/nodes", self.nodes_list, methods=["GET"]) + self.router.add_api_route("/tvl", self.tvl, methods=["GET"]) + self.router.add_api_route("/fees", self.fees_accumulated, methods=["GET"]) self.app.include_router(self.router) + async def health(self): + return {"status": "running"} + async def submit_proof(self, data: ProofRequest): try: await run_in_threadpool( @@ -44,9 +99,628 @@ async def submit_proof(self, data: ProofRequest): except InvalidProofError as exc: raise HTTPException(status_code=400, detail=str(exc)) - async def health(self): - return {"status": "running"} + async def models_list(self): + try: + # Get all active models with their details in a single contract call + models_with_details = await run_in_threadpool( + self.eth_client.model_registry.functions.getActiveModelsWithDetails().call + ) - def start(self): - host, port = self.aggregator.config.aggregator_server_ip_port_address.split(":") - uvicorn.run(self.app, host=host, port=int(port)) + # Convert the contract response to a more readable format + models = [] + for model_data in models_with_details: + models.append( + { + "id": model_data[0], # modelId + "name": model_data[1], # modelName + "verifier": model_data[2], # modelVerifier + "verification_strategy": model_data[3], # verificationStrategy + "compute_cost": model_data[4], # computeCost + "required_fucus": model_data[5], # requiredFUCUs + "is_active": model_data[6], # isActive + } + ) + + return models + except Exception as exc: + raise HTTPException( + status_code=500, detail=f"Failed to retrieve models: {str(exc)}" + ) + + async def model_inference_history( + self, + model_id: int = Query(..., description="Model ID to get inference history for"), + offset: int = Query(0, ge=0, description="Pagination offset"), + limit: int = Query(20, ge=1, le=100, description="Pagination limit (max 100)"), + include_details: bool = Query(False, description="Include full task details"), + ): + """ + Get inference task history for a specific model with pagination. + + Returns paginated list of task IDs and optionally full task details. + """ + return await self.inference_history( + model_id=model_id, + offset=offset, + limit=limit, + include_details=include_details, + ) + + async def operator_inference_history( + self, + operator: str = Query( + ..., + pattern=ADDRESS_REGEXP, + description="Operator address to get inference history for", + ), + offset: int = Query(0, ge=0, description="Pagination offset"), + limit: int = Query(20, ge=1, le=100, description="Pagination limit (max 100)"), + include_details: bool = Query(False, description="Include full task details"), + ): + """ + Get inference task history for a specific operator with pagination. + + Returns paginated list of task IDs and optionally full task details. + """ + return await self.inference_history( + operator=operator, + offset=offset, + limit=limit, + include_details=include_details, + ) + + async def user_inference_history( + self, + user: str = Query( + ..., + pattern=ADDRESS_REGEXP, + description="User address to get inference history for", + ), + offset: int = Query(0, ge=0, description="Pagination offset"), + limit: int = Query(20, ge=1, le=100, description="Pagination limit (max 100)"), + include_details: bool = Query(False, description="Include full task details"), + ): + """ + Get inference task history for a specific user with pagination. + + Returns paginated list of task IDs and optionally full task details. + """ + return await self.inference_history( + user=user, offset=offset, limit=limit, include_details=include_details + ) + + async def state_inference_history( + self, + state: int = Query( + ..., + description=( + "Task state to get inference history for " + "(0=CREATED, 1=ASSIGNED, 2=COMPLETED, 3=CHALLENGED, 4=REJECTED, 5=RESOLVED)" + ), + ), + offset: int = Query(0, ge=0, description="Pagination offset"), + limit: int = Query(20, ge=1, le=100, description="Pagination limit (max 100)"), + include_details: bool = Query(False, description="Include full task details"), + ): + """ + Get inference task history for a specific task state with pagination. + + Returns paginated list of task IDs and optionally full task details. + """ + return await self.inference_history( + state=state, offset=offset, limit=limit, include_details=include_details + ) + + async def inference_history( + self, + model_id: Optional[int] = None, + operator: Optional[str] = None, + user: Optional[str] = None, + state: Optional[int] = None, + offset: int = 0, + limit: int = 20, + include_details: bool = False, + ): + """ + Get inference task history with filtering and pagination. + + Returns paginated list of task IDs and optionally full task details. + Can filter by model, operator, user, or task state. + """ + try: + + # Determine which filtering method to use based on provided parameters + task_ids = await run_in_threadpool( + self._get_filtered_task_ids, + model_id, + operator, + user, + state, + offset, + limit, + ) + + # If no task details requested, return just the IDs + if not include_details: + return { + "tasks": task_ids, + "pagination": { + "offset": offset, + "limit": limit, + "returned_count": len(task_ids), + }, + } + + # Get full task details for each task ID + tasks = [] + for task_id in task_ids: + task_data = await run_in_threadpool( + self.eth_client.task_manager.functions.getTask(task_id).call + ) + tasks.append(self._format_task_data(task_id, task_data)) + + return { + "tasks": tasks, + "pagination": { + "offset": offset, + "limit": limit, + "returned_count": len(tasks), + }, + } + + except Exception as exc: + raise HTTPException( + status_code=500, + detail=f"Failed to retrieve inference history: {str(exc)}", + ) + + async def inference_stats(self): + """ + Get inference task statistics with optional filtering. + + Returns aggregated statistics about tasks including totals, success rates, etc. + Can filter by model, operator, or user. + """ + try: + stats = await run_in_threadpool(self._get_task_history_stats) + return {"stats": stats} + except Exception as exc: + raise HTTPException( + status_code=500, + detail=f"Failed to retrieve inference statistics: {str(exc)}", + ) + + async def nodes_list(self): + """ + Get list of all active nodes with full details. + + Returns list of active nodes with complete details including + operator, name, metadata, FUCUS allocation, and supported models. + """ + try: + # Use the efficient contract function to get all active nodes with details + ( + node_details_arrays, + supported_models_arrays, + model_allocations_arrays, + ) = await run_in_threadpool( + self.eth_client.nodes_manager.functions.getAllNodesWithDetails().call + ) + + nodes = [] + for i, node_details in enumerate(node_details_arrays): + # Unpack the uint256[8] array + node_id = node_details[0] + operator_uint = node_details[1] + total_fucus = node_details[2] + allocated_fucus = node_details[3] + available_fucus = node_details[4] + is_active_int = node_details[5] + created_at = node_details[6] + supported_models_count = node_details[7] + + # Convert operator back to address + operator = f"0x{operator_uint:040x}" + is_active = is_active_int == 1 + + # Get name and metadata from individual contract call + node_data = await run_in_threadpool( + self.eth_client.nodes_manager.functions.nodes(node_id).call + ) + name = node_data[2] + metadata = node_data[3] + + # Build model configurations + model_configs = [] + for j, model_id in enumerate(supported_models_arrays[i]): + model_configs.append( + { + "model_id": model_id, + "allocated_fucus": model_allocations_arrays[i][j], + } + ) + + nodes.append( + { + "node_id": node_id, + "operator": operator, + "name": name, + "metadata": metadata, + "total_fucus": total_fucus, + "allocated_fucus": allocated_fucus, + "available_fucus": available_fucus, + "is_active": is_active, + "created_at": created_at, + "supported_models_count": supported_models_count, + "supported_models": model_configs, + } + ) + + return nodes + + except Exception as exc: + raise HTTPException( + status_code=500, detail=f"Failed to retrieve nodes: {str(exc)}" + ) + + async def tvl(self): + """ + Get Total Value Locked (TVL) for the AVS + Returns TVL by strategy, showing the total shares delegated. + """ + + try: + return await run_in_threadpool(self._get_avs_shares) + except Exception as exc: + raise HTTPException( + status_code=500, detail=f"Failed to calculate TVL: {str(exc)}" + ) + + async def fees_accumulated( + self, + hours: int = Query(24, description="Time window in hours to query events"), + ) -> dict: + """ + Get all fees accumulated by operators for a specified time period. + """ + try: + result = await run_in_threadpool( + self._get_rewards_accumulated_events, hours + ) + return result + except Exception as e: + raise HTTPException( + status_code=500, + detail=f"Error retrieving fees accumulated: {str(e)}", + ) + + @cached_method(ttl=3600) + def _get_rewards_accumulated_events(self, hours: int = 24) -> dict: + """ + Query TaskRewardAccumulated events from the ServiceManager contract. + Cached for 1 hour. + + Event signature: + TaskRewardAccumulated( + address indexed operator, + uint256 fee, + IERC20 token, + uint32 currentInterval + ) + """ + # Calculate block range for the last N hours + current_block = self.eth_client.w3.eth.block_number + + # Estimate blocks to go back (assuming ~12 seconds per block) + blocks_per_hour = 300 # 3600 seconds / 12 seconds per block + blocks_to_query = blocks_per_hour * hours + from_block = max(0, current_block - blocks_to_query) + + # Get TaskRewardAccumulated events + event_filter = ( + self.eth_client.service_manager.events.TaskRewardAccumulated.create_filter( + from_block=from_block, to_block=current_block + ) + ) + events = event_filter.get_all_entries() + + # Operator splits to calculate actual operator earnings + operator_splits = {} + # Aggregate rewards by operator and token + rewards_by_operator = {} + total_rewards_by_token = {} # token_address -> total_amount + # token_address -> sum of operators' shares + total_operators_rewards_by_token = {} + + for event in events: + operator = event.args.operator + fee = event.args.fee + token = event.args.token + interval = event.args.currentInterval + + # Initialize operator entry if not exists + if operator not in operator_splits: + operator_splits[operator] = self._get_operator_split(operator) + split_bips = operator_splits[operator] # Operator's share in bips (0-10000) + + # Initialize operator entry if not exists + if operator not in rewards_by_operator: + rewards_by_operator[operator] = {} + + # Initialize token entry for operator if not exists + if token not in rewards_by_operator[operator]: + rewards_by_operator[operator][token] = { + "operator_split_bips": split_bips, + "total_accumulated": 0, + "operator_share": 0, + "stakers_share": 0, + "count": 0, + "intervals": set(), + } + + # Accumulate rewards + rewards_by_operator[operator][token]["total_accumulated"] += fee + rewards_by_operator[operator][token]["operator_share"] += ( + fee * split_bips + ) // 10000 + rewards_by_operator[operator][token]["stakers_share"] += ( + fee - (fee * split_bips) // 10000 + ) + rewards_by_operator[operator][token]["count"] += 1 + rewards_by_operator[operator][token]["intervals"].add(interval) + + # Aggregate total by token + if token not in total_rewards_by_token: + total_rewards_by_token[token] = 0 + total_operators_rewards_by_token[token] = 0 + total_rewards_by_token[token] += fee + total_operators_rewards_by_token[token] += rewards_by_operator[operator][ + token + ]["operator_share"] + + return { + "total_rewards_by_token": total_rewards_by_token, + "rewards_by_operator": rewards_by_operator, + "operator_splits": operator_splits, + } + + @cached_method() + def _get_operator_split(self, operator: str) -> int: + """ + Get the operator's split (commission) in basis points for this AVS's operator set. + + The split represents the percentage of rewards the operator keeps vs what goes to stakers. + Returns value in basis points (bips) where 10000 = 100%. + For example: 1000 bips = 10% goes to operator, 90% to stakers + """ + try: + # Get the operator's split for this AVS's operator set + # operator_set is a tuple of (avs_address, operator_set_id) + split_bips = ( + self.eth_client.rewards_coordinator.functions.getOperatorSetSplit( + operator, + self.operator_set, # (SERVICE_MANAGER_ADDRESS, OPERATOR_SET_ID) + ).call() + ) + return split_bips + except Exception as exc: + logger.exception(f"Failed to get operator split for {operator}") + raise exc + + @cached_method(ttl=300) + def _get_filtered_task_ids( + self, + model_id: Optional[int], + operator: Optional[str], + user: Optional[str], + state: Optional[int], + offset: int = 0, + limit: int = 20, + ) -> list[int]: + """Get filtered task IDs based on provided filters.""" + + # Priority order for filtering (most specific first) + if state is not None: + # Filter by task state + return self.eth_client.task_manager.functions.getTasksByState( + state, offset, limit + ).call() + elif model_id is not None: + # Filter by model ID + return self.eth_client.task_manager.functions.getTasksByModel( + model_id, offset, limit + ).call() + elif operator is not None: + # Filter by operator address + return self.eth_client.task_manager.functions.getTasksByOperator( + operator, offset, limit + ).call() + elif user is not None: + # Filter by user address + return self.eth_client.task_manager.functions.getTasksByUser( + user, offset, limit + ).call() + else: + # No filters provided, return just latest tasks with pagination + tasksCount = self.eth_client.task_manager.functions.taskNonce().call() - 1 + if offset >= tasksCount: + return [] + returnListLength = ( + limit if (offset + limit) <= tasksCount else (tasksCount - offset) + ) + returnList = [] + for i in range(returnListLength): + returnList.append(tasksCount - offset - i) + return returnList + + @cached_method(ttl=300) + def _get_task_history_stats(self) -> dict: + """Get task history statistics. Cached for 5 minutes.""" + + # Get stats from contract + total_tasks, completed_tasks, rejected_tasks, pending_tasks = ( + self.eth_client.task_manager.functions.getTaskHistoryStats().call() + ) + + return { + "total_tasks": total_tasks, + "completed_tasks": completed_tasks, # RESOLVED tasks + "rejected_tasks": rejected_tasks, + "pending_tasks": pending_tasks, # ASSIGNED + CHALLENGED + COMPLETED + "success_rate": ( + round(completed_tasks / total_tasks * 100, 2) if total_tasks > 0 else 0 + ), + } + + def _format_task_data(self, task_id: int, task_data: tuple) -> dict: + """Format raw contract task data into a readable dictionary.""" + + return { + "task_id": task_id, + "start_block": task_data[TaskStructMap.START_BLOCK], + "start_timestamp": task_data[TaskStructMap.START_TIME], + "model_id": task_data[TaskStructMap.MODEL_ID], + "inputs": ( + task_data[TaskStructMap.INPUTS].decode("utf-8", errors="ignore") + if task_data[TaskStructMap.INPUTS] + else "" + ), + "proof_hash": ( + task_data[TaskStructMap.PROOF_HASH].hex() + if task_data[TaskStructMap.PROOF_HASH] + else "" + ), + "user": task_data[TaskStructMap.USER], + "nonce": task_data[TaskStructMap.NONCE], + "operator": task_data[TaskStructMap.OPERATOR], + "state": { + "value": task_data[TaskStructMap.STATE], + "name": TaskStateMap.from_int(task_data[TaskStructMap.STATE]).name, + }, + "output": ( + task_data[TaskStructMap.OUTPUT].decode("utf-8", errors="ignore") + if task_data[TaskStructMap.OUTPUT] + else "" + ), + "fee": task_data[TaskStructMap.FEE], + } + + @cached_method() + def _get_avs_shares(self) -> dict: + """ + Get AVS shares across all strategies and aggregate by strategy and operator. + Cached for 24 hours + """ + # Get all operators registered to the AVS operator set + operators = self.eth_client.allocation_manager.functions.getMembers( + self.operator_set + ).call() + + # Batch call to get all operators' shares across all strategies + # Returns uint256[][] - array of arrays where operators_shares[i][j] is + # operator[i]'s shares in strategy[j] + operators_shares = ( + ( + self.eth_client.delegation_manager.functions.getOperatorsShares( + operators, self.strategies + ).call() + ) + if operators + else [] + ) + + # Get strategy metadata (token details) for all strategies + strategy_metadata: dict[str, dict[str, str]] = self._get_strategies_details( + self.strategies + ) + + # Aggregate shares by strategy + # strategy_addr -> {total_shares, total_amount, token, symbol, decimals} + tvl_by_strategy = {} + # Also prepare breakdown by operator + # operator_addr -> [(token_addr, symbol, amount), ...] + tvl_by_operator = {} + + for op_idx, operator in enumerate(operators): + operator_tokens = [] + + for strategy_idx, strategy_address in enumerate(self.strategies): + shares = operators_shares[op_idx][strategy_idx] + + if shares > 0: + # Initialize strategy entry if not exists + if strategy_address not in tvl_by_strategy: + metadata = strategy_metadata.get(strategy_address, {}) + tvl_by_strategy[strategy_address] = { + "total_shares": 0, + "total_amount": 0.0, + "operators_count": 0, + **metadata, + } + + # Add to strategy total + tvl_by_strategy[strategy_address]["total_shares"] += shares + tvl_by_strategy[strategy_address]["operators_count"] += 1 + + # Add to operator's token breakdown + operator_tokens.append( + { + "strategy": strategy_address, + "token": tvl_by_strategy[strategy_address]["token"], + "symbol": tvl_by_strategy[strategy_address]["symbol"], + "shares": shares, + "amount": shares, + } + ) + + # Update human-readable amount if we have decimals + if "decimals" in tvl_by_strategy[strategy_address]: + decimals = tvl_by_strategy[strategy_address]["decimals"] + tvl_by_strategy[strategy_address]["total_amount"] = ( + tvl_by_strategy[strategy_address]["total_shares"] + / (10**decimals) + ) + operator_tokens[-1]["amount"] = shares / (10**decimals) + else: + tvl_by_strategy[strategy_address]["total_amount"] += shares + + if operator_tokens: + tvl_by_operator[operator] = operator_tokens + + return { + "total_operators": len(operators), + "active_operators": len(tvl_by_operator), + "strategies_count": len(self.strategies), + "tvl_by_strategy": tvl_by_strategy, + "tvl_by_operator": tvl_by_operator, + } + + @cached_method() + def _get_strategies_details( + self, strategies: list[str] + ) -> dict[str, dict[str, str]]: + """ + Get strategy details including underlying token symbol and decimals. + Returns a mapping of strategy address to its details (token, symbol, decimals). + Cached for 24 hours as this rarely changes. + """ + strategies_metadata = {} + for strategy_address in strategies: + strategy_contract = self.eth_client.w3.eth.contract( + address=strategy_address, + abi=STRATEGY_ABI, + ) + token_address = strategy_contract.functions.underlyingToken().call() + token_contract = self.eth_client.w3.eth.contract( + address=token_address, + abi=ERC20_ABI, + ) + symbol = token_contract.functions.symbol().call() + decimals = token_contract.functions.decimals().call() + + strategies_metadata[strategy_address] = { + "token": token_address, + "symbol": symbol, + "decimals": decimals, + } + return strategies_metadata diff --git a/client/src/avs_operator/main.py b/client/src/avs_operator/main.py index cd262b4..8ba5a23 100644 --- a/client/src/avs_operator/main.py +++ b/client/src/avs_operator/main.py @@ -12,11 +12,12 @@ from web3 import Web3 from avs_operator.nodes import OperatorNodesManager +from common.addresses import addresses from common.auto_update import AutoUpdate from common.config import OperatorConfig -from common.logging import get_logger from common.contract_constants import TaskStructMap from common.eth import EthereumClient, load_ecdsa_private_key +from common.logging import get_logger from models.onnx_run import run_onnx from models.proof.ezkl_handler import EZKLHandler @@ -39,6 +40,7 @@ def run_operator(config: OperatorConfig) -> None: class TaskOperator: def __init__(self, config: OperatorConfig): self.config = config + addresses.init_addresses(chain_id=self.config.chain_id) self.eth_client = EthereumClient( eth_rpc_url=self.config.eth_rpc_url, gas_strategy=self.config.gas_strategy ) diff --git a/client/src/common/addresses.py b/client/src/common/addresses.py new file mode 100644 index 0000000..482d4d5 --- /dev/null +++ b/client/src/common/addresses.py @@ -0,0 +1,96 @@ +import json +from typing import Callable, Optional + +from common.constants import CONTRACTS_DIR + + +def address_property(attr_name: str): + def decorator(func: Callable) -> property: + def getter(self) -> any: + self._check_initialized() + return getattr(self, attr_name) + + return property(getter) + + return decorator + + +class AddressManager: + """Singleton manager for contract addresses.""" + + _instance: Optional["AddressManager"] = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + self.task_manager: Optional[str] = None + self.service_manager: Optional[str] = None + self.allocation_manager: Optional[str] = None + self.strategies: Optional[list[str]] = None + self.eth_strategies: Optional[list[str]] = None + + def init_addresses(self, chain_id: int) -> None: + """Initialize contract addresses based on the given chain ID.""" + deployment_file = ( + CONTRACTS_DIR / "deployments" / f"sertnDeployment_{chain_id}.json" + ) + + try: + with open(deployment_file) as f: + deployment_info = json.load(f) + self.task_manager = deployment_info["sertnTaskManager"] + self.service_manager = deployment_info["sertnServiceManager"] + self.allocation_manager = deployment_info["allocationManager"] + self.strategies = [ + deployment_info["strategy_0"], + deployment_info["strategy_1"], + deployment_info["strategy_2"], + ] + self.eth_strategies = [ + deployment_info["eth_strategy_0"], + deployment_info["eth_strategy_1"], + ] + self._initialized = True + except FileNotFoundError: + raise FileNotFoundError( + f"Deployment file for chain ID {chain_id} not found: {deployment_file}. " + "Did you forget to deploy?" + ) + except KeyError as e: + raise KeyError( + f"Missing expected key in deployment file for chain ID {chain_id}: {e}" + ) + + def _check_initialized(self): + if not self._initialized: + raise RuntimeError( + "AddressManager not initialized. Call addresses.init(chain_id) first." + ) + + @address_property("task_manager") + def TASK_MANAGER_ADDRESS(self) -> str: + pass + + @address_property("service_manager") + def SERVICE_MANAGER_ADDRESS(self) -> str: + pass + + @address_property("allocation_manager") + def ALLOCATION_MANAGER_ADDRESS(self) -> str: + pass + + @address_property("strategies") + def STRATEGIES_ADDRESSES(self) -> list[str]: + pass + + @address_property("eth_strategies") + def ETH_STRATEGY_ADDRESSES(self) -> list[str]: + pass + + +# Global singleton instance +addresses = AddressManager() diff --git a/client/src/common/cache.py b/client/src/common/cache.py new file mode 100644 index 0000000..e753490 --- /dev/null +++ b/client/src/common/cache.py @@ -0,0 +1,481 @@ +""" +Hybrid caching solution supporting both Memcached (Cloud Memorystore) and in-memory caching. +Automatically falls back to in-memory cache if Memcached is unavailable. +""" + +import functools +import json +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional + +import requests +from cachetools import TTLCache +from pymemcache.client.base import Client as MemcacheClient +from pymemcache.exceptions import MemcacheError + +from common.config import CacheConfig +from common.logging import get_logger + +logger = get_logger("cache") + + +class CacheBackend(ABC): + """Abstract base class for cache backends""" + + @abstractmethod + def get(self, key: str) -> Optional[Any]: + """Get value from cache""" + pass + + @abstractmethod + def set(self, key: str, value: Any, ttl: int = 300) -> bool: + """Set value in cache with TTL in seconds""" + pass + + @abstractmethod + def delete(self, key: str) -> bool: + """Delete value from cache""" + pass + + @abstractmethod + def exists(self, key: str) -> bool: + """Check if key exists in cache""" + pass + + @abstractmethod + def clear(self) -> bool: + """Clear all cache entries""" + pass + + +class MemcachedCache(CacheBackend): + """Memcached-based cache backend for Cloud Memorystore""" + + def __init__( + self, memcached_host, memcached_port, connect_timeout=2.0, timeout=2.0 + ): + self.client = MemcacheClient( + (memcached_host, memcached_port), + connect_timeout=connect_timeout, + timeout=timeout, + ) + + # Test connection with a simple operation + self.client.version() + logger.debug( + f"Successfully connected to Memcached at {memcached_host}:{memcached_port}" + ) + + def get(self, key: str) -> Optional[Any]: + try: + return self.client.get(key) + except MemcacheError as e: + logger.error(f"Memcached get error for key {key}: {e}") + return None + except Exception as e: + logger.error(f"Unexpected error getting key {key}: {e}") + return None + + def set(self, key: str, value: Any, ttl: int = 300) -> bool: + try: + # Memcached expects TTL as expire time (0 = never expire, >0 = seconds) + self.client.set(key=key, value=value, expire=ttl) + return True + except MemcacheError as e: + logger.error(f"Memcached set error for key {key}: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error setting key {key}: {e}") + return False + + def delete(self, key: str) -> bool: + try: + self.client.delete(key) + return True + except MemcacheError as e: + logger.error(f"Memcached delete error for key {key}: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error deleting key {key}: {e}") + return False + + def exists(self, key: str) -> bool: + try: + # Memcached doesn't have a native exists operation + # We need to get the value to check existence + return self.client.get(key) is not None + except MemcacheError as e: + logger.error(f"Memcached exists error for key {key}: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error checking key {key}: {e}") + return False + + def clear(self) -> bool: + try: + self.client.flush_all() + return True + except MemcacheError as e: + logger.error(f"Memcached clear error: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error clearing cache: {e}") + return False + + +class InMemoryCache(CacheBackend): + """In-memory cache backend using cachetools""" + + def __init__(self, maxsize: int = 1000, default_ttl: int = 300): + self.cache = TTLCache(maxsize=maxsize, ttl=default_ttl) + logger.info( + f"In-memory cache backend initialized (maxsize={maxsize}, ttl={default_ttl}s)" + ) + + def get(self, key: str) -> Optional[Any]: + try: + return self.cache.get(key) + except Exception as e: + logger.error(f"In-memory get error for key {key}: {e}") + return None + + def set(self, key: str, value: Any, ttl: int = 300) -> bool: + try: + # Note: TTLCache uses a single TTL for all entries + # For per-key TTL, you'd need a more complex implementation + self.cache[key] = value + return True + except Exception as e: + logger.error(f"In-memory set error for key {key}: {e}") + return False + + def delete(self, key: str) -> bool: + try: + if key in self.cache: + del self.cache[key] + return True + except Exception as e: + logger.error(f"In-memory delete error for key {key}: {e}") + return False + + def exists(self, key: str) -> bool: + return key in self.cache + + def clear(self) -> bool: + try: + self.cache.clear() + return True + except Exception as e: + logger.error(f"In-memory clear error: {e}") + return False + + +class CloudflareKVCache(CacheBackend): + """ + Cloudflare Workers KV cache backend + """ + + def __init__( + self, + account_id: str, + namespace_id: str, + api_token: str, + timeout: float = 5.0, + ): + self.account_id = account_id + self.namespace_id = namespace_id + self.api_token = api_token + self.timeout = timeout + self.base_url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/storage/kv/namespaces/{namespace_id}" + self.headers = { + "Authorization": f"Bearer {api_token}", + "Content-Type": "application/json", + } + logger.info( + f"Cloudflare KV cache backend initialized (namespace: {namespace_id})" + ) + + def get(self, key: str) -> Optional[Any]: + try: + url = f"{self.base_url}/values/{key}" + response = requests.get(url, headers=self.headers, timeout=self.timeout) + + if response.status_code == 404: + return None + + if response.status_code != 200: + logger.error( + f"Cloudflare KV get error for key {key}: HTTP {response.status_code}" + ) + return None + + # Cloudflare KV stores raw bytes, deserialize with pickle + return json.loads(response.content) + + except requests.RequestException as e: + logger.error(f"Cloudflare KV get error for key {key}: {e}") + return None + except Exception as e: + logger.error(f"Unexpected error getting key {key}: {e}") + return None + + def set(self, key: str, value: Any, ttl: int = 300) -> bool: + try: + # Cloudflare KV uses metadata and expiration_ttl + params = {"expiration_ttl": ttl} if ttl > 0 else {} + + response = requests.put( + f"{self.base_url}/values/{key}", + data=json.dumps(value), + headers=self.headers, + params=params, + timeout=self.timeout, + ) + + if response.status_code not in (200, 201): + logger.error( + f"Cloudflare KV set error for key {key}: HTTP {response.status_code}" + ) + return False + + return True + + except requests.RequestException as e: + logger.error(f"Cloudflare KV set error for key {key}: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error setting key {key}: {e}") + return False + + def delete(self, key: str) -> bool: + try: + url = f"{self.base_url}/values/{key}" + response = requests.delete(url, headers=self.headers, timeout=self.timeout) + + if response.status_code not in (200, 404): + logger.error( + f"Cloudflare KV delete error for key {key}: HTTP {response.status_code}" + ) + return False + + return True + + except requests.RequestException as e: + logger.error(f"Cloudflare KV delete error for key {key}: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error deleting key {key}: {e}") + return False + + def exists(self, key: str) -> bool: + try: + url = f"{self.base_url}/values/{key}" + response = requests.head(url, headers=self.headers, timeout=self.timeout) + return response.status_code == 200 + + except requests.RequestException as e: + logger.error(f"Cloudflare KV exists error for key {key}: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error checking key {key}: {e}") + return False + + def clear(self) -> bool: + """ + Clear all entries in the namespace. + Note: This lists all keys and deletes them one by one. + For production use, consider using Cloudflare's bulk delete API. + """ + try: + # List all keys in the namespace + url = f"{self.base_url}/keys" + response = requests.get(url, headers=self.headers, timeout=self.timeout) + + if response.status_code != 200: + logger.error(f"Cloudflare KV clear error: HTTP {response.status_code}") + return False + + data = response.json() + keys = [item["name"] for item in data.get("result", [])] + + # Delete each key + for key in keys: + self.delete(key) + + logger.info(f"Cleared {len(keys)} keys from Cloudflare KV") + return True + + except requests.RequestException as e: + logger.error(f"Cloudflare KV clear error: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error clearing cache: {e}") + return False + + +class NoOpCache(CacheBackend): + """No-op cache backend that disables caching entirely""" + + def __init__(self): + logger.info("Caching is disabled (NoOpCache)") + + def get(self, key: str) -> Optional[Any]: + return None + + def set(self, key: str, value: Any, ttl: int = 300) -> bool: + return True + + def delete(self, key: str) -> bool: + return True + + def exists(self, key: str) -> bool: + return False + + def clear(self) -> bool: + return True + + +def get_cache(config: CacheConfig) -> CacheBackend: + """ + Factory function that returns a cache backend instance. + Priority order: Cloudflare KV > Memcached > In-memory cache. + + Args: + config: CacheConfig instance with cache configuration + + Returns: + CacheBackend: CloudflareKVCache, MemcachedCache, InMemoryCache, or NoOpCache instance + """ + # If caching is disabled, return no-op cache + if config.disable: + return NoOpCache() + + # Try Cloudflare KV first if configuration is provided + if ( + config.cloudflare_account_id + and config.cloudflare_namespace_id + and config.cloudflare_api_token + ): + try: + kv_cache = CloudflareKVCache( + account_id=config.cloudflare_account_id, + namespace_id=config.cloudflare_namespace_id, + api_token=config.cloudflare_api_token, + timeout=config.timeout, + ) + # Test connection with a simple operation + kv_cache.exists("__test_connection__") + logger.info("Successfully connected to Cloudflare KV") + return kv_cache + + except requests.RequestException as e: + logger.warning( + f"Failed to connect to Cloudflare KV: {e}. Trying Memcached..." + ) + except Exception as e: + logger.warning( + f"Unexpected error connecting to Cloudflare KV: {e}. Trying Memcached..." + ) + + # Try Memcached if configuration is provided + if config.memcached_host and config.memcached_port: + try: + return MemcachedCache( + config.memcached_host, + config.memcached_port, + config.connect_timeout, + config.timeout, + ) + + except MemcacheError as e: + logger.warning( + f"Failed to connect to Memcached: {e}. Falling back to in-memory cache." + ) + except Exception as e: + logger.warning( + f"Unexpected error connecting to Memcached: {e}. Falling back to in-memory cache." + ) + + # Fallback to in-memory cache + logger.info( + "Using in-memory cache (no external cache configuration or connection failed)" + ) + return InMemoryCache( + maxsize=config.fallback_maxsize, default_ttl=config.fallback_ttl + ) + + +def cached_method(ttl: int = 86400, key_prefix: Optional[str] = None): + """ + Decorator for caching method results. + + The cache key is automatically generated from: + - key_prefix (or method name if not provided) + - method arguments (args and kwargs) + + Usage: + class MyClass: + def __init__(self, cache: CacheBackend): + self.cache = cache + + @cached_method(ttl=600, key_prefix="operator_split") + def get_operator_split(self, operator: str) -> int: + # expensive operation + return result + + Args: + ttl: Time-to-live for cached value in seconds + key_prefix: Optional prefix for cache key (defaults to method name) + + Returns: + Decorated method that uses caching + """ + + def decorator(method: Callable) -> Callable: + method_name = method.__name__ + prefix = key_prefix or method_name + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + # Check if instance has a cache attribute + if not hasattr(self, "cache"): + # No cache available, just call the method + return method(self, *args, **kwargs) + + cache: CacheBackend = self.cache + + # Generate cache key from method name and arguments + # Convert args and kwargs to a hashable representation + try: + # Try to create a simple string key from arguments + args_str = "_".join(str(arg) for arg in args) + kwargs_str = "_".join(f"{k}={v}" for k, v in sorted(kwargs.items())) + key_parts = [prefix] + if args_str: + key_parts.append(args_str) + if kwargs_str: + key_parts.append(kwargs_str) + cache_key = ":".join(key_parts) + except Exception as e: + logger.warning(f"Failed to generate cache key for {method_name}: {e}") + # Fall back to calling method without caching + return method(self, *args, **kwargs) + + # Try to get from cache + cached_value = cache.get(cache_key) + if cached_value is not None: + logger.debug(f"Cache hit for {cache_key}") + return cached_value + + # Cache miss - call the method + logger.debug(f"Cache miss for {cache_key}") + result = method(self, *args, **kwargs) + + # Store in cache + cache.set(cache_key, result, ttl=ttl) + + return result + + return wrapper + + return decorator diff --git a/client/src/common/config.py b/client/src/common/config.py index 4e3b2b8..03a377b 100644 --- a/client/src/common/config.py +++ b/client/src/common/config.py @@ -73,6 +73,55 @@ def validate_fucus_allocation(self) -> "NodeConfig": return self +class CacheConfig(BaseModel): + """ + Configuration for caching backend. + User can use different cache options: Memcached, Cloudflare KV, or in-memory fallback + All caching options are optional. + """ + + disable: bool = Field( + default=False, + description="Enable caching (set to False to disable all caching)", + ) + memcached_host: Optional[str] = Field( + default=None, + description="Memcached host (e.g., IP address from Memorystore)", + ) + memcached_port: int = Field( + default=11211, + description="Memcached port", + ) + cloudflare_account_id: Optional[str] = Field( + default=None, + description="Cloudflare account ID for KV storage", + ) + cloudflare_namespace_id: Optional[str] = Field( + default=None, + description="Cloudflare KV namespace ID", + ) + cloudflare_api_token: Optional[str] = Field( + default=None, + description="Cloudflare API token with KV permissions", + ) + connect_timeout: float = Field( + default=2.0, + description="Cache connection timeout in seconds", + ) + timeout: float = Field( + default=2.0, + description="Cache operation timeout in seconds", + ) + fallback_maxsize: int = Field( + default=1000, + description="Max size for in-memory fallback cache", + ) + fallback_ttl: int = Field( + default=300, + description="TTL (in seconds) for in-memory fallback cache entries", + ) + + class BaseConfig(BaseModel): """Base configuration shared by both operator and aggregator.""" @@ -82,6 +131,11 @@ class BaseConfig(BaseModel): default=Environment.PRODUCTION, description="Environment setting for logging and behavior", ) + chain_id: int = Field( + description="Ethereum chain ID to connect to", + default=31337, # Hardhat local network + ge=1, + ) eth_rpc_url: str = Field(description="Ethereum RPC URL") ecdsa_private_key_store_path: Path = Field( ..., description="Path to ECDSA private key file" @@ -94,6 +148,10 @@ class BaseConfig(BaseModel): default=True, description="Enable automatic updates for the application", ) + caching: CacheConfig = Field( + default_factory=CacheConfig, + description="Caching configuration", + ) @field_validator("ecdsa_private_key_store_path") @classmethod diff --git a/client/src/common/constants.py b/client/src/common/constants.py index 23b8719..255a37d 100644 --- a/client/src/common/constants.py +++ b/client/src/common/constants.py @@ -2,6 +2,9 @@ import json from pathlib import Path +ZERO_ADDRESS = "0x0000000000000000000000000000000000000000" +ADDRESS_REGEXP = r"^0x[a-fA-F0-9]{40}$" + ROOT_DIR = Path(__file__).parent.parent.parent.parent CLIENT_PATH = ROOT_DIR / "client" CLIENT_SRC_PATH = CLIENT_PATH / "src" @@ -19,22 +22,6 @@ TEMP_FOLDER.mkdir(parents=True, exist_ok=True) PROOFS_FOLDER.mkdir(parents=True, exist_ok=True) -# contracts addresses: -with open(CONTRACTS_DIR / "deployments" / "sertnDeployment.json") as f: - deployment_info = json.load(f) - TASK_MANAGER_ADDRESS = deployment_info["sertnTaskManager"] - SERVICE_MANAGER_ADDRESS = deployment_info["sertnServiceManager"] - ALLOCATION_MANAGER_ADDRESS = deployment_info["allocationManager"] - STRATEGIES_ADDRESSES = [ - deployment_info["strategy_0"], - deployment_info["strategy_1"], - deployment_info["strategy_2"], - ] - ETH_STRATEGY_ADDRESSES = [ - deployment_info["eth_strategy_0"], - deployment_info["eth_strategy_1"], - ] - IGNORED_MODEL_HASHES = [] # Queue size limits @@ -52,4 +39,8 @@ CIRCUIT_TIMEOUT_SECONDS = 60 # Operator set ID +# `OperatorSet` is a struct in the AllocationManager contract that groups operators within AVS. +# It contains only an ID (uint256) and an AVS address (address). +# At the time of writing, we only have one AVS, so we use the default ID of 0. +# If multiple AVSs are introduced in the future, this will need to be updated accordingly. OPERATOR_SET_ID = 0 diff --git a/client/src/common/eth.py b/client/src/common/eth.py index f2ae0ff..e58ac43 100644 --- a/client/src/common/eth.py +++ b/client/src/common/eth.py @@ -19,11 +19,7 @@ ) from common.config import GasStrategy from common.logging import get_logger -from common.constants import ( - ALLOCATION_MANAGER_ADDRESS, - SERVICE_MANAGER_ADDRESS, - TASK_MANAGER_ADDRESS, -) +from common.addresses import addresses from common.gas_strategy import get_gas_config logger = get_logger("common") @@ -72,14 +68,14 @@ def init_contracts(self) -> None: Initialize contracts """ # Service manager - self.check_contract_deployed(SERVICE_MANAGER_ADDRESS) + self.check_contract_deployed(addresses.SERVICE_MANAGER_ADDRESS) self.service_manager = self.w3.eth.contract( - address=SERVICE_MANAGER_ADDRESS, abi=SERVICE_MANAGER_ABI + address=addresses.SERVICE_MANAGER_ADDRESS, abi=SERVICE_MANAGER_ABI ) # Task manager - self.check_contract_deployed(TASK_MANAGER_ADDRESS) + self.check_contract_deployed(addresses.TASK_MANAGER_ADDRESS) self.task_manager = self.w3.eth.contract( - address=TASK_MANAGER_ADDRESS, abi=TASK_MANAGER_ABI + address=addresses.TASK_MANAGER_ADDRESS, abi=TASK_MANAGER_ABI ) # Delegation manager delegation_manager_address = ( @@ -100,9 +96,9 @@ def init_contracts(self) -> None: abi=STRATEGY_MANAGER_ABI, ) # allocation manager - self.check_contract_deployed(ALLOCATION_MANAGER_ADDRESS) + self.check_contract_deployed(addresses.ALLOCATION_MANAGER_ADDRESS) self.allocation_manager = self.w3.eth.contract( - address=ALLOCATION_MANAGER_ADDRESS, + address=addresses.ALLOCATION_MANAGER_ADDRESS, abi=ALLOCATION_MANAGER_ABI, ) # Model registry diff --git a/client/src/main.py b/client/src/main.py index 96b78ce..ee4662a 100644 --- a/client/src/main.py +++ b/client/src/main.py @@ -57,13 +57,12 @@ def start( raise typer.Exit(1) ensure_external_files() + config_obj = load_config(config, mode) try: if mode == "operator": - config_obj = load_config(config, "operator") run_operator(config_obj) elif mode == "aggregator": - config_obj = load_config(config, "aggregator") run_aggregator(config_obj) else: logger.error(f"Invalid mode: {mode}. Use 'operator' or 'aggregator'") diff --git a/client/src/management/owner.py b/client/src/management/owner.py index b1687dc..51904e5 100644 --- a/client/src/management/owner.py +++ b/client/src/management/owner.py @@ -1,5 +1,6 @@ from eth_account import Account +from common.addresses import addresses from common.eth import EthereumClient from common.config import GasStrategy from common.logging import get_logger @@ -18,8 +19,10 @@ def __init__( self, private_key: str, eth_rpc_url: str, + chain_id: int = 31337, gas_strategy: GasStrategy = GasStrategy.STANDARD, ): + addresses.init_addresses(chain_id=chain_id) self.private_key = private_key self.owner_address = Account.from_key(self.private_key).address self.gas_strategy = gas_strategy diff --git a/client/tests/__init__.py b/client/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/conftest.py b/client/tests/conftest.py index 1f76429..61b55e1 100644 --- a/client/tests/conftest.py +++ b/client/tests/conftest.py @@ -1,18 +1,24 @@ +import asyncio import json import os import shutil import sys import tempfile +import threading +import time from pathlib import Path import pytest import requests +import uvicorn from dotenv import load_dotenv from eth_account import Account from web3 import Web3 from aggregator.main import Aggregator from avs_operator.main import TaskOperator +from common.abis import STRATEGY_ABI +from common.addresses import addresses from common.config import AggregatorConfig, OperatorConfig from common.constants import CLIENT_SRC_PATH, ROOT_DIR from management.owner import AvsOwner @@ -20,12 +26,36 @@ sys.path.insert(0, str(CLIENT_SRC_PATH)) load_dotenv(ROOT_DIR / ".env") # Load environment variables +OPERATOR_NODES = [ + { + "node_name": "node1", + "metadata": "optional metadata", + "total_fucus": 500, + "is_active": True, + "models": [ + {"model_name": "model_0", "allocated_fucus": 500}, + # {"model_name": "model_1", "allocated_fucus": 50}, + ], + }, + { + "node_name": "node2", + "metadata": "optional metadata", + "total_fucus": 900, + "is_active": True, + "models": [ + {"model_name": "model_0", "allocated_fucus": 900}, + # {"model_name": "model_1", "allocated_fucus": 10}, + ], + }, +] + @pytest.fixture(scope="session") def owner(): return AvsOwner( private_key=os.getenv("PRIVATE_KEY"), eth_rpc_url="http://localhost:8545", + chain_id=31337, ) @@ -37,10 +67,53 @@ def aggregator(): ecdsa_private_key_store_path="tests/keys/aggregator.ecdsa.key.json", proof_request_probability=1.0, # challenge every task auto_update=False, + caching={"disable": True}, ) return Aggregator(config) +@pytest.fixture(scope="session") +def aggregator_server(aggregator: Aggregator): + """Start aggregator server in a separate thread and provide cleanup.""" + server = None + server_thread = None + + def start_server(): + nonlocal server + config = uvicorn.Config( + app=aggregator.server.app, + host="0.0.0.0", + port=8090, + log_level="info", + ) + server = uvicorn.Server(config) + + # Run server until stop event is set + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + loop.run_until_complete(server.serve()) + except asyncio.CancelledError: + pass + finally: + loop.close() + + # Start server in a separate thread + server_thread = threading.Thread(target=start_server) + server_thread.start() + time.sleep(5) # Wait for server to start + + yield aggregator + + # Cleanup - stop server + if server: + server.should_exit = True + + if server_thread: + server_thread.join(timeout=5) # Wait up to 5 seconds + + @pytest.fixture(scope="function") def operator(): config = OperatorConfig( @@ -48,28 +121,8 @@ def operator(): aggregator_server_ip_port_address="localhost:8090", ecdsa_private_key_store_path="tests/keys/operator.ecdsa.key.json", auto_update=False, - nodes=[ - { - "node_name": "node1", - "metadata": "optional metadata", - "total_fucus": 500, - "is_active": True, - "models": [ - {"model_name": "model_0", "allocated_fucus": 500}, - # {"model_name": "model_1", "allocated_fucus": 50}, - ], - }, - { - "node_name": "node2", - "metadata": "optional metadata", - "total_fucus": 900, - "is_active": True, - "models": [ - {"model_name": "model_0", "allocated_fucus": 900}, - # {"model_name": "model_1", "allocated_fucus": 10}, - ], - }, - ], + nodes=OPERATOR_NODES, + caching={"disable": True}, ) operator = TaskOperator(config) operator.nodes_manager.sync_nodes() @@ -196,3 +249,12 @@ def dummy_address() -> str: """Return a valid EIP-55 Ethereum address for tests.""" raw = "0x" + os.urandom(20).hex() return Web3.to_checksum_address(raw) + + +@pytest.fixture(scope="function") +def strategies(aggregator: Aggregator): + """Return list of strategy contract objects.""" + return [ + aggregator.eth_client.w3.eth.contract(address=addr, abi=STRATEGY_ABI) + for addr in addresses.STRATEGIES_ADDRESSES + ] diff --git a/client/tests/test_caches.py b/client/tests/test_caches.py new file mode 100644 index 0000000..d532331 --- /dev/null +++ b/client/tests/test_caches.py @@ -0,0 +1,186 @@ +from unittest.mock import Mock, patch + +import pytest + +from common.cache import ( + CloudflareKVCache, + InMemoryCache, + MemcachedCache, + NoOpCache, + cached_method, + get_cache, +) +from common.config import CacheConfig + + +class TestBackends: + + def test_in_memory_cache(self): + """Test cache initialization with custom parameters.""" + cache = get_cache( + CacheConfig( + disable=False, + ) + ) + assert isinstance(cache, InMemoryCache) + + assert cache.exists("test_key") is False + assert cache.set("test_key", "test_value", ttl=60) is True + assert cache.exists("test_key") is True + assert cache.get("test_key") == "test_value" + assert cache.delete("test_key") is True + assert cache.get("test_key") is None + + @patch("common.cache.MemcacheClient") + def test_memcached(self, mock_client_class): + """Test Memcached initialization.""" + mock_client = Mock() + mock_client.version.return_value = b"1.6.0" + mock_client_class.return_value = mock_client + mock_client.get.return_value = "test_value" + + cache = get_cache( + CacheConfig( + disable=False, + memcached_host="localhost", + memcached_port=11211, + ) + ) + assert isinstance(cache, MemcachedCache) + + # Set a value + assert cache.set("test_key", "test_value", ttl=60) is True + mock_client.set.assert_called_once_with( + key="test_key", value="test_value", expire=60 + ) + + # Get the value + result = cache.get("test_key") + assert result == "test_value" + mock_client.get.assert_called_once_with("test_key") + + # Delete the key + assert cache.delete("test_key") is True + mock_client.delete.assert_called_once_with("test_key") + + # Test MemcacheError handling on get + from pymemcache.exceptions import MemcacheError + + mock_client.get.side_effect = MemcacheError("Connection error") + # Should return None on error, not raise exception + assert cache.get("test_key") is None + + @patch("common.cache.requests.get") + @patch("common.cache.requests.put") + @patch("common.cache.requests.delete") + @patch("common.cache.requests.head") + def test_cloudflare_kv_cache(self, mock_head, mock_delete, mock_put, mock_get): + """Test successful get operation.""" + + cache = get_cache( + CacheConfig( + disable=False, + cloudflare_account_id="acc", + cloudflare_namespace_id="ns", + cloudflare_api_token="token", + ) + ) + assert isinstance(cache, CloudflareKVCache) + + mock_response = Mock() + + # Key does not exist + mock_response.status_code = 404 + mock_get.return_value = mock_response + cache = CloudflareKVCache("acc", "ns", "token") + result = cache.get("nonexistent") + assert result is None + assert mock_get.call_count == 1 + + # Get the value + mock_response.status_code = 200 + mock_response.content = b'"test_value"' + mock_get.return_value = mock_response + result = cache.get("test_key") + assert result == "test_value" + assert mock_get.call_count == 2 + + # Set the value + mock_response.status_code = 200 + mock_put.return_value = mock_response + cache = CloudflareKVCache("acc", "ns", "token") + result = cache.set("test_key", "test_value", ttl=60) + assert result is True + mock_put.assert_called_once() + + # Delete the key + mock_response.status_code = 200 + mock_delete.return_value = mock_response + cache = CloudflareKVCache("acc", "ns", "token") + result = cache.delete("test_key") + assert result is True + mock_delete.assert_called_once() + + def test_no_op_cache(self): + """Test NoOpCache initialization.""" + cache = get_cache(CacheConfig(disable=True)) + assert isinstance(cache, NoOpCache) + assert cache.get("any_key") is None + assert cache.exists("any_key") is False + assert cache.delete("any_key") is True + assert cache.clear() is True + assert cache.set("key", "value") is True + assert cache.get("key") is None + + +class TestCachedMethodDecorator: + """Tests for the cached_method decorator.""" + + class SomeClass: + def __init__(self, cache=None): + self.call_count = 0 + if cache: + self.cache = cache + + @cached_method(ttl=60, key_prefix="test") + def method(self, arg1: str, arg2: int, kwarg1: str = "default") -> str: + self.call_count += 1 + return f"result_{arg1}_{arg2}_{kwarg1}" + + def test_works_without_cache_attribute(self): + """Test that decorator works when object has no cache attribute.""" + + obj = self.SomeClass() + + # Should work without caching + result1 = obj.method("value1", 42, kwarg1="kwarg") + assert result1 == "result_value1_42_kwarg" + assert obj.call_count == 1 + + # Should call method again (no caching) + result2 = obj.method("value1", 42, kwarg1="kwarg") + assert result2 == "result_value1_42_kwarg" + assert obj.call_count == 2 + + def test_handles_multiple_arguments(self): + """Test caching with multiple arguments.""" + cache = get_cache(CacheConfig(disable=False)) + assert isinstance(cache, InMemoryCache) + cache.clear() + + obj = self.SomeClass(cache=cache) + + # First call + result1 = obj.method("a", 1, kwarg1="x") + assert result1 == "result_a_1_x" + assert obj.call_count == 1 + + # Same args - should use cache + result2 = obj.method("a", 1, kwarg1="x") + assert result2 == "result_a_1_x" + assert obj.call_count == 1 + + # Different args - should execute + result3 = obj.method("a", 2, kwarg1="x") + assert result3 == "result_a_2_x" + assert obj.call_count == 2 diff --git a/client/tests/test_model_registry.py b/client/tests/test_model_registry.py index ddf8d11..c934617 100644 --- a/client/tests/test_model_registry.py +++ b/client/tests/test_model_registry.py @@ -1,6 +1,7 @@ import json import shutil import uuid +import requests from common.constants import MODELS_FOLDER from models.execution_layer.base_input import BaseInput @@ -20,6 +21,12 @@ def update_model_metadata(models_root, model_name, **kwargs): json.dump(metadata, f) +def request_models_from_api(): + resp = requests.get("http://localhost:8090/models") + assert resp.status_code == 200 + return {m["name"]: m for m in resp.json()} + + def test_load_circuit_input_class(models_root, write_model): write_model(models_root, "model_ok", is_active=True) @@ -50,7 +57,7 @@ def test_ensure_external_files(models_root, write_model): def test_sync_models_create_update_disable( - models_root, write_model, owner, dummy_address + models_root, write_model, owner, dummy_address, aggregator_server ): try: model_registry = ModelRegistry( @@ -73,9 +80,15 @@ def test_sync_models_create_update_disable( write_model( models_root, model_name, is_active=True, compute_cost=111, required_fucus=55 ) + + # smoke test `/models` endpoint - new model is not there yet + api_models = request_models_from_api() + assert model_name not in api_models + # and sync one more time model_registry.sync_models() + # check active models in contract active_chain_models = [ (model_id, model) for model_id, model in model_registry._get_blockchain_models().items() @@ -87,6 +100,12 @@ def test_sync_models_create_update_disable( assert model["compute_cost"] == 111 assert model["required_fucus"] == 55 + # and now the model should be in the API response + api_models = request_models_from_api() + assert model_name in api_models + assert api_models[model_name]["id"] == model_id + assert api_models[model_name]["compute_cost"] == 111 + # change the model cost and fucus update_model_metadata( models_root, @@ -101,6 +120,11 @@ def test_sync_models_create_update_disable( assert model["required_fucus"] == 77 assert model["verifier"] == dummy_address + # check that the model is updated in the API response + api_models = request_models_from_api() + assert api_models[model_name]["id"] == model_id + assert api_models[model_name]["compute_cost"] == 222 + # disable the model update_model_metadata(models_root, model_name, is_active=False) model_registry.sync_models() @@ -111,6 +135,9 @@ def test_sync_models_create_update_disable( ] assert len(active_chain_models) == 0 + # check the model is not anymore in the API response + assert model_name not in request_models_from_api() + # activate the model again update_model_metadata(models_root, model_name, is_active=True) model_registry.sync_models() @@ -120,6 +147,8 @@ def test_sync_models_create_update_disable( if model["active"] ] assert len(active_chain_models) == 1 + # check the model is again in the API response + assert model_name in request_models_from_api() # and again disable it removing the model directory shutil.rmtree(models_root / model_name, ignore_errors=True) @@ -130,6 +159,8 @@ def test_sync_models_create_update_disable( if model["active"] ] assert len(active_chain_models) == 0 + # and again nothing in the API response + assert len(request_models_from_api()) == 0 finally: # set "real" models from "real" models folder back to the chain model_registry = ModelRegistry( diff --git a/client/tests/test_nodes.py b/client/tests/test_nodes.py new file mode 100644 index 0000000..2590e7f --- /dev/null +++ b/client/tests/test_nodes.py @@ -0,0 +1,132 @@ +import requests + +from aggregator.main import Aggregator +from avs_operator.main import TaskOperator +from common.config import OperatorConfig +from tests.conftest import OPERATOR_NODES + + +def get_nodes_list(): + response = requests.get("http://localhost:8090/nodes") + return response.json() + + +def get_model_ids(): + response = requests.get("http://localhost:8090/models") + res = response.json() + return {m["name"]: m["id"] for m in res} + + +def get_expected_nodes(nodes_config=OPERATOR_NODES): + model_ids = get_model_ids() + return { + n["node_name"]: { + "metadata": n["metadata"], + "total_fucus": n["total_fucus"], + "is_active": n["is_active"], + "supported_models": [ + { + "model_id": model_ids[model["model_name"]], + "allocated_fucus": model["allocated_fucus"], + } + for model in n["models"] + ], + } + for n in nodes_config + if n["is_active"] + } + + +def test_operator_initialization(operator: TaskOperator, aggregator_server: Aggregator): + expected_result = get_expected_nodes(OPERATOR_NODES) + resp = get_nodes_list() + + assert len(resp) == len(expected_result), "Node count mismatch" + for node in resp: + node_name = node["name"] + assert node_name in expected_result, f"Unexpected node name: {node_name}" + expected_node = expected_result[node_name] + assert ( + node["metadata"] == expected_node["metadata"] + ), f"Metadata mismatch for {node_name}" + assert ( + node["total_fucus"] == expected_node["total_fucus"] + ), f"Total fucus mismatch for {node_name}" + assert ( + node["is_active"] == expected_node["is_active"] + ), f"Is active mismatch for {node_name}" + assert len(node["supported_models"]) == len( + expected_node["supported_models"] + ), f"Supported models count mismatch for {node_name}" + assert ( + node["operator"].lower() == operator.operator_address.lower() + ), f"Operator address mismatch for {node_name}" + for model in node["supported_models"]: + match = next( + ( + m + for m in expected_node["supported_models"] + if m["model_id"] == model["model_id"] + ), + None, + ) + assert ( + match is not None + ), f"Unexpected model ID {model['model_id']} for {node_name}" + assert ( + model["allocated_fucus"] == match["allocated_fucus"] + ), f"Allocated fucus mismatch for model ID {model['model_id']} in {node_name}" + + +def test_operator_nodes_update(aggregator_server: Aggregator): + nodes_config = OPERATOR_NODES.copy() + # Deactivate the first node + nodes_config[0]["is_active"] = False + # Update the second node's fucus + nodes_config[1]["total_fucus"] = 1000 + nodes_config[1]["models"][0]["allocated_fucus"] = 1000 + + # initialize a new operator instance with the updated config + operator = TaskOperator( + OperatorConfig( + eth_rpc_url="http://localhost:8545", + aggregator_server_ip_port_address="localhost:8090", + ecdsa_private_key_store_path="tests/keys/operator.ecdsa.key.json", + auto_update=False, + nodes=nodes_config, + ) + ) + # sync nodes with the updated config + operator.nodes_manager.sync_nodes() + + # get the nodes list from the server and verify the updates + expected_result = get_expected_nodes(nodes_config) + resp = get_nodes_list() + + assert len(resp) == 1, "Node count mismatch" + node = resp[0] + assert node["total_fucus"] == 1000, f"Total fucus mismatch" + + # Sync back to original config - reduce back allocated fucus + nodes_config[1]["models"][0]["allocated_fucus"] = 900 + TaskOperator( + OperatorConfig( + eth_rpc_url="http://localhost:8545", + aggregator_server_ip_port_address="localhost:8090", + ecdsa_private_key_store_path="tests/keys/operator.ecdsa.key.json", + auto_update=False, + nodes=nodes_config, + ) + ).nodes_manager.sync_nodes() + # reduce back total fucus + nodes_config[0]["is_active"] = True + nodes_config[1]["total_fucus"] = 900 + TaskOperator( + OperatorConfig( + eth_rpc_url="http://localhost:8545", + aggregator_server_ip_port_address="localhost:8090", + ecdsa_private_key_store_path="tests/keys/operator.ecdsa.key.json", + auto_update=False, + nodes=nodes_config, + ) + ).nodes_manager.sync_nodes() diff --git a/client/tests/test_workflow.py b/client/tests/test_workflow.py index 8af4385..28ba71b 100644 --- a/client/tests/test_workflow.py +++ b/client/tests/test_workflow.py @@ -1,73 +1,100 @@ -import asyncio -import threading import time -import uvicorn +import requests from aggregator.main import Aggregator from avs_operator.main import TaskOperator -from common.constants import STRATEGIES_ADDRESSES +from common.addresses import addresses from common.contract_constants import TaskStateMap, TaskStructMap from management.owner import AvsOwner class TestWorkflow: - def start_aggregator_server(self, aggregator: Aggregator): - self._stop_event = threading.Event() - config = uvicorn.Config( - app=aggregator.server.app, host="0.0.0.0", port=8090, log_level="info" - ) - self.server = uvicorn.Server(config) - - # Run server until stop event is set - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - try: - loop.run_until_complete(self.server.serve()) - except asyncio.CancelledError: - pass - finally: - loop.close() - - def stop_aggregator_server(self): - """Stop server thread""" - if self.server: - self.server.should_exit = True - self._stop_event.set() + def make_request(self, path, **params): + params = {"limit": 100, "offset": 0, "include_details": False, **params} + res = requests.get(f"http://localhost:8090/{path}", params=params) + assert res.status_code == 200 + return res.json() def test_process_task( self, aggregator: Aggregator, + aggregator_server: Aggregator, operator: TaskOperator, owner: AvsOwner, init_environment: dict, + strategies: list, ): - """Just a smoke test to ensure send_new_task runs without errors""" + """ + Simulate end-to-end task processing workflow: + - Create a new task + - Operator processes the task + - Aggregator challenges the task + - Operator generates proof + - Aggregator resolves the task + - Verify rewards distribution + """ + # get underlying token address for the future task + token_address = strategies[0].functions.underlyingToken().call() + + # Get the operator's fees before processing the task + res = self.make_request("fees", hours=24) + base_operator_fee = ( + res["rewards_by_operator"] + .get(operator.operator_address, {}) + .get(token_address, {}) + .get("operator_share", 0) + ) # create a new task task_id = aggregator.send_new_task(1) assert task_id is not None, "Task ID should not be None" - # here task should be assigned to the operator + + # here task should be assigned to the operator + res = self.make_request( + "operator-inference-history", operator=operator.operator_address + ) + assert task_id in res["tasks"] + + # check that the task is visible in the stats endpoint + res = self.make_request( + "state-inference-history", state=TaskStateMap.ASSIGNED.value + ) + assert task_id in res["tasks"] # process the task by the operator processed_count = operator.listen_for_events(loop_running=False) assert processed_count == 1, "Operator should process one task" # the task should be marked as completed + # check that the task is not visible in the assigned state + res = self.make_request( + "state-inference-history", state=TaskStateMap.ASSIGNED.value + ) + assert task_id not in res["tasks"] + # and is visible in the completed state + res = self.make_request( + "state-inference-history", state=TaskStateMap.COMPLETED.value + ) + assert task_id in res["tasks"] + # checkout the completed task processed_count = aggregator.listen_for_events(loop_running=False) assert processed_count == 1, "Aggregator should process one task" # At this point the task should be challenged - # start aggregator server in a separate thread - # This is necessary to allow the aggregator to listen for events and process the challenge - aggregator_server = threading.Thread( - target=self.start_aggregator_server, args=[aggregator] - ) - aggregator_server.start() time.sleep(5) + # Check that the task is not visible in the completed state + res = self.make_request( + "state-inference-history", state=TaskStateMap.COMPLETED.value + ) + assert task_id not in res["tasks"] + # and is visible in the challenged state + res = self.make_request( + "state-inference-history", state=TaskStateMap.CHALLENGED.value + ) + assert task_id in res["tasks"] # check events by the operator, it should process the challenge and generate proof processed_count = operator.listen_for_events(loop_running=False) @@ -78,6 +105,27 @@ def test_process_task( task = aggregator.eth_client.task_manager.functions.getTask(task_id).call() assert task[TaskStructMap.STATE] == TaskStateMap.RESOLVED model_id = task[TaskStructMap.MODEL_ID] + user = task[TaskStructMap.USER] + + # Check that the model is not visible in the challenged state + res = self.make_request( + "state-inference-history", state=TaskStateMap.CHALLENGED.value + ) + assert task_id not in res["tasks"] + # and is visible in the resolved state + res = self.make_request( + "state-inference-history", + state=TaskStateMap.RESOLVED.value, + limit=1, + offset=0, + ) + assert [task_id] == res["tasks"] + # and the task is visible in the model-specific history + res = self.make_request("model-inference-history", model_id=model_id) + assert task_id in res["tasks"] + # and the task is visible in the user-specific history + res = self.make_request("user-inference-history", user=user, limit=1, offset=0) + assert [task_id] == res["tasks"] # check rewards collected for the operator operators_in_interval: list = ( @@ -94,7 +142,7 @@ def test_process_task( rewards = aggregator.eth_client.service_manager.functions.getIntervalRewards( init_environment["current_interval"], operator.operator_address, - STRATEGIES_ADDRESSES[0], + addresses.STRATEGIES_ADDRESSES[0], ).call() model_cost: int = aggregator.eth_client.model_registry.functions.computeCost( model_id @@ -103,7 +151,7 @@ def test_process_task( operator.operator_address, ], "Operator should be in the current interval" assert strategies_in_interval == [ - STRATEGIES_ADDRESSES[0], + addresses.STRATEGIES_ADDRESSES[0], ], "Aggregator should be in the current interval" assert ( rewards == model_cost @@ -118,13 +166,22 @@ def test_process_task( # Submit rewards for the interval owner.submit_rewards_for_interval(init_environment["current_interval"]) - self.stop_aggregator_server() - if aggregator_server: - aggregator_server.join(timeout=5) # Wait up to 5 seconds + # Check the fees accumulated during the interval (smoke test) + res = self.make_request("fees", hours=24) + updated_operator_fee = ( + res["rewards_by_operator"] + .get(operator.operator_address, {}) + .get(token_address, {}) + .get("operator_share", 0) + ) + assert ( + updated_operator_fee > base_operator_fee + ), "Operator's fees should increase after processing the task" def test_task_incorrect_proof( self, aggregator: Aggregator, + aggregator_server: Aggregator, operator: TaskOperator, owner: AvsOwner, init_environment: dict, @@ -138,14 +195,14 @@ def mock_generate_proof_for_task(*args, **kwargs) -> str: operator.generate_proof_for_task = mock_generate_proof_for_task initial_shares = ( aggregator.eth_client.delegation_manager.functions.operatorShares( - operator.operator_address, STRATEGIES_ADDRESSES[0] + operator.operator_address, addresses.STRATEGIES_ADDRESSES[0] ).call() ) initial_rewards = ( aggregator.eth_client.service_manager.functions.getIntervalRewards( init_environment["current_interval"], operator.operator_address, - STRATEGIES_ADDRESSES[0], + addresses.STRATEGIES_ADDRESSES[0], ).call() ) @@ -164,12 +221,6 @@ def mock_generate_proof_for_task(*args, **kwargs) -> str: assert processed_count == 1, "Aggregator should process one task" # At this point the task should be challenged - # start aggregator server in a separate thread - # This is necessary to allow the aggregator to listen for events and process the challenge - aggregator_server = threading.Thread( - target=self.start_aggregator_server, args=[aggregator] - ) - aggregator_server.start() time.sleep(5) # check events by the operator, it should process the challenge and generate proof @@ -180,7 +231,17 @@ def mock_generate_proof_for_task(*args, **kwargs) -> str: # here the task should be resolved by the aggregator task = aggregator.eth_client.task_manager.functions.getTask(task_id).call() assert task[TaskStructMap.STATE] == TaskStateMap.REJECTED.value - # model_id = task[TaskStructMap.MODEL_ID] + + # Check that the task is not visible in the challenged state + res = self.make_request( + "state-inference-history", state=TaskStateMap.CHALLENGED.value + ) + assert task_id not in res["tasks"] + # and is visible in the rejected state + res = self.make_request( + "state-inference-history", state=TaskStateMap.REJECTED.value + ) + assert task_id in res["tasks"] # check rewards collected for the operator operators_in_interval: list = ( @@ -194,7 +255,7 @@ def mock_generate_proof_for_task(*args, **kwargs) -> str: final_shares = ( aggregator.eth_client.delegation_manager.functions.operatorShares( - operator.operator_address, STRATEGIES_ADDRESSES[0] + operator.operator_address, addresses.STRATEGIES_ADDRESSES[0] ).call() ) assert ( @@ -205,11 +266,20 @@ def mock_generate_proof_for_task(*args, **kwargs) -> str: aggregator.eth_client.service_manager.functions.getIntervalRewards( init_environment["current_interval"], operator.operator_address, - STRATEGIES_ADDRESSES[0], + addresses.STRATEGIES_ADDRESSES[0], ).call() ) assert final_rewards == initial_rewards, "No new rewards for the operator" - self.stop_aggregator_server() - if aggregator_server: - aggregator_server.join(timeout=5) # Wait up to 5 seconds + def test_tvl_endpoint(self, aggregator_server): + """ + Just a smoke test to ensure /tvl endpoint works and returns expected fields + """ + response = requests.get("http://localhost:8090/tvl") + assert response.status_code == 200 + data = response.json() + assert "tvl_by_operator" in data + assert "tvl_by_strategy" in data + assert data["total_operators"] == 1 + assert data["active_operators"] == 1 + assert data["strategies_count"] == 5 diff --git a/client/uv.lock b/client/uv.lock index e989244..a21655b 100644 --- a/client/uv.lock +++ b/client/uv.lock @@ -93,6 +93,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597, upload-time = "2024-12-13T17:10:38.469Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/a6/dc46877b911e40c00d395771ea710d5e77b6de7bacd5fdcd78d70cc5a48f/annotated_doc-0.0.3.tar.gz", hash = "sha256:e18370014c70187422c33e945053ff4c286f453a984eba84d0dbfa0c935adeda", size = 5535, upload-time = "2025-10-24T14:57:10.718Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/b7/cf592cb5de5cb3bade3357f8d2cf42bf103bbe39f459824b4939fd212911/annotated_doc-0.0.3-py3-none-any.whl", hash = "sha256:348ec6664a76f1fd3be81f43dffbee4c7e8ce931ba71ec67cc7f4ade7fbbb580", size = 5488, upload-time = "2025-10-24T14:57:09.462Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -191,6 +200,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7d/27/73ca754c226badce7f64a25964d990954f50ab672a401b81b7ee55e6c1bd/bitarray-3.1.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:8560480a743341f5720e1ed91234e2199ca4422a6afe849575562aa920468487", size = 132431, upload-time = "2025-03-06T21:56:57.082Z" }, ] +[[package]] +name = "cachetools" +version = "6.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/7e/b975b5814bd36faf009faebe22c1072a1fa1168db34d285ef0ba071ad78c/cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201", size = 31325, upload-time = "2025-10-12T14:55:30.139Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/c5/1e741d26306c42e2bf6ab740b2202872727e0f606033c9dd713f8b93f5a8/cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701", size = 11280, upload-time = "2025-10-12T14:55:28.382Z" }, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -574,16 +592,17 @@ wheels = [ [[package]] name = "fastapi" -version = "0.115.12" +version = "0.120.1" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "annotated-doc" }, { name = "pydantic" }, { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236, upload-time = "2025-03-23T22:55:43.822Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/cc/28aff6e246ee85bd571b26e4a793b84d42700e3bdc3008c3d747eda7b06d/fastapi-0.120.1.tar.gz", hash = "sha256:b5c6217e9ddca6dfcf54c97986180d4a1955e10c693d74943fc5327700178bff", size = 337616, upload-time = "2025-10-27T17:53:42.954Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164, upload-time = "2025-03-23T22:55:42.101Z" }, + { url = "https://files.pythonhosted.org/packages/7e/bb/1a74dbe87e9a595bf63052c886dfef965dc5b91d149456a8301eb3d41ce2/fastapi-0.120.1-py3-none-any.whl", hash = "sha256:0e8a2c328e96c117272d8c794d3a97d205f753cc2e69dd7ee387b7488a75601f", size = 108254, upload-time = "2025-10-27T17:53:40.076Z" }, ] [[package]] @@ -1190,6 +1209,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293, upload-time = "2025-01-06T17:26:25.553Z" }, ] +[[package]] +name = "pymemcache" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/b6/4541b664aeaad025dfb8e851dcddf8e25ab22607e674dd2b562ea3e3586f/pymemcache-4.0.0.tar.gz", hash = "sha256:27bf9bd1bbc1e20f83633208620d56de50f14185055e49504f4f5e94e94aff94", size = 70176, upload-time = "2022-10-17T16:53:07.726Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/ba/2f7b22d8135b51c4fefb041461f8431e1908778e6539ff5af6eeaaee367a/pymemcache-4.0.0-py2.py3-none-any.whl", hash = "sha256:f507bc20e0dc8d562f8df9d872107a278df049fa496805c1431b926f3ddd0eab", size = 60772, upload-time = "2022-10-17T16:53:04.388Z" }, +] + [[package]] name = "pyreadline3" version = "3.5.4" @@ -1422,6 +1450,7 @@ name = "sertn-avs" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "cachetools" }, { name = "eth-abi" }, { name = "eth-account" }, { name = "ezkl" }, @@ -1430,8 +1459,10 @@ dependencies = [ { name = "onnxruntime" }, { name = "packaging" }, { name = "pydantic" }, + { name = "pymemcache" }, { name = "python-dotenv" }, { name = "pyyaml" }, + { name = "starlette" }, { name = "torch", version = "2.4.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "torch", version = "2.4.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, { name = "tqdm" }, @@ -1456,6 +1487,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "cachetools", specifier = ">=6.2.1" }, { name = "eth-abi", specifier = ">=5.2.0" }, { name = "eth-account", specifier = ">=0.13.5" }, { name = "ezkl", specifier = "==19.0.7" }, @@ -1464,11 +1496,13 @@ requires-dist = [ { name = "onnxruntime", specifier = ">=1.21.1" }, { name = "packaging", specifier = ">=24.2" }, { name = "pydantic", specifier = ">=2.6.3" }, + { name = "pymemcache", specifier = ">=4.0.0" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.0.0" }, { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=0.23.5" }, { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=4.1.0" }, { name = "python-dotenv", specifier = ">=1.1.1" }, { name = "pyyaml", specifier = ">=6.0.2" }, + { name = "starlette", specifier = ">=0.49.1" }, { name = "torch", specifier = "==2.4.1", index = "https://download.pytorch.org/whl/cpu" }, { name = "tqdm", specifier = ">=4.66.2" }, { name = "typer", specifier = ">=0.9.0" }, @@ -1513,14 +1547,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.46.1" +version = "0.49.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/04/1b/52b27f2e13ceedc79a908e29eac426a63465a1a01248e5f24aa36a62aeb3/starlette-0.46.1.tar.gz", hash = "sha256:3c88d58ee4bd1bb807c0d1acb381838afc7752f9ddaec81bbe4383611d833230", size = 2580102, upload-time = "2025-03-08T10:55:34.504Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/3f/507c21db33b66fb027a332f2cb3abbbe924cc3a79ced12f01ed8645955c9/starlette-0.49.1.tar.gz", hash = "sha256:481a43b71e24ed8c43b11ea02f5353d77840e01480881b8cb5a26b8cae64a8cb", size = 2654703, upload-time = "2025-10-28T17:34:10.928Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/4b/528ccf7a982216885a1ff4908e886b8fb5f19862d1962f56a3fce2435a70/starlette-0.46.1-py3-none-any.whl", hash = "sha256:77c74ed9d2720138b25875133f3a2dae6d854af2ec37dceb56aef370c1d8a227", size = 71995, upload-time = "2025-03-08T10:55:32.662Z" }, + { url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" }, ] [[package]] diff --git a/contracts/anvil/deploy-sertn.sh b/contracts/anvil/deploy-sertn.sh index 7db1060..8186160 100755 --- a/contracts/anvil/deploy-sertn.sh +++ b/contracts/anvil/deploy-sertn.sh @@ -14,16 +14,21 @@ cd ../ forge script script/SertnDeployer.s.sol --rpc-url $RPC_HOST:$RPC_PORT --broadcast -# Format the JSON file using Python -if [ -f deployments/sertnDeployment.json ]; then - python3 -c " +# Format sertnDeployment_*.json files +shopt -s nullglob +json_files=(deployments/sertnDeployment_*.json) + +if [ ${#json_files[@]} -eq 0 ]; then + echo "No sertnDeployment_*.json files found in contracts/deployments/" +else + for json_file in "${json_files[@]}"; do + python3 -c " import json -with open('deployments/sertnDeployment.json', 'r') as f: +with open('$json_file', 'r') as f: data = json.load(f) -with open('deployments/sertnDeployment.json', 'w') as f: +with open('$json_file', 'w') as f: json.dump(data, f, indent=4) " - echo "Formatted contracts/deployments/sertnDeployment.json" -else - echo "contracts/deployments/sertnDeployment.json not found!" + echo "Formatted contracts/$json_file" + done fi diff --git a/contracts/deployments/sertnDeployment.json b/contracts/deployments/sertnDeployment.json deleted file mode 100644 index 8f52b56..0000000 --- a/contracts/deployments/sertnDeployment.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "allocationManager": "0x4826533B4897376654Bb4d4AD88B7faFD0C98528", - "eth_strategy_0": "0x33791c463B145298c575b4409d52c2BcF743BF67", - "eth_strategy_1": "0x9472EF1614f103Ae8f714cCeeF4B438D353Ce1Fa", - "rewardsCoordinator": "0x4c5859f0F772848b2D91F1D83E2Fe57935348029", - "sertnRegistrar": "0xBEc49fA140aCaA83533fB00A2BB19bDdd0290f25", - "sertnServiceManager": "0x172076E0166D1F9Cc711C77Adf8488051744980C", - "sertnTaskManager": "0xD84379CEae14AA33C123Af12424A37803F885889", - "strategy_0": "0xBe1DBc0F9DEef46efE57c08Be6C65aD72C1Fb527", - "strategy_1": "0x33791c463B145298c575b4409d52c2BcF743BF67", - "strategy_2": "0x9472EF1614f103Ae8f714cCeeF4B438D353Ce1Fa" -} \ No newline at end of file diff --git a/contracts/deployments/sertnDeployment_31337.json b/contracts/deployments/sertnDeployment_31337.json new file mode 100644 index 0000000..9691e9b --- /dev/null +++ b/contracts/deployments/sertnDeployment_31337.json @@ -0,0 +1,12 @@ +{ + "allocationManager": "0xDc64a140Aa3E981100a9becA4E685f962f0cF6C9", + "eth_strategy_0": "0x5E3d0fdE6f793B3115A9E7f5EBC195bbeeD35d6C", + "eth_strategy_1": "0xa12fFA0B9f159BB4C54bce579611927Addc51610", + "rewardsCoordinator": "0xA51c1fc2f0D1a1b8494Ed1FE312d7C3a78Ed91C0", + "sertnRegistrar": "0x9d4454B023096f34B160D6B654540c56A1F81688", + "sertnServiceManager": "0x0E801D84Fa97b50751Dbf25036d067dCf18858bF", + "sertnTaskManager": "0x5eb3Bc0a489C5A8288765d2336659EbCA68FCd00", + "strategy_0": "0x53839913417ebc7171723489F29B9B54F49b4EEA", + "strategy_1": "0x5E3d0fdE6f793B3115A9E7f5EBC195bbeeD35d6C", + "strategy_2": "0xa12fFA0B9f159BB4C54bce579611927Addc51610" +} \ No newline at end of file diff --git a/contracts/interfaces/IModelRegistry.sol b/contracts/interfaces/IModelRegistry.sol index ff5422f..d315101 100644 --- a/contracts/interfaces/IModelRegistry.sol +++ b/contracts/interfaces/IModelRegistry.sol @@ -11,6 +11,19 @@ interface IModelRegistry { Onchain, Offchain } + + /** + * @notice Struct containing all model details + */ + struct ModelDetails { + uint256 modelId; + string modelName; + address modelVerifier; + VerificationStrategy verificationStrategy; + uint256 computeCost; + uint256 requiredFUCUs; + bool isActive; + } /** * @notice The event emitted when a new model is created * @param modelId The id of the model @@ -169,4 +182,10 @@ interface IModelRegistry { * @notice Check whether a model is active */ function isActive(uint256 modelId) external view returns (bool); + + /** + * @notice Get all active models with their complete details in a single call + * @return models Array of ModelDetails structs containing all active models + */ + function getActiveModelsWithDetails() external view returns (ModelDetails[] memory models); } diff --git a/contracts/interfaces/ISertnNodesManager.sol b/contracts/interfaces/ISertnNodesManager.sol index 3e469b2..69af942 100644 --- a/contracts/interfaces/ISertnNodesManager.sol +++ b/contracts/interfaces/ISertnNodesManager.sol @@ -313,4 +313,19 @@ interface ISertnNodesManager { address operator, uint256 modelId ) external view returns (uint256); + + /** + * @notice Get all nodes with their details and supported models in a single call + * @return nodeDetails Array of node details (nodeId, operator, name, metadata, totalFucus, allocatedFucus, availableFucus, isActive, createdAt, supportedModelsCount) + * @return supportedModels Array of arrays containing supported model IDs for each node + * @return modelAllocations Array of arrays containing FUCUS allocations for each model on each node + */ + function getAllNodesWithDetails() + external + view + returns ( + uint256[8][] memory nodeDetails, + uint256[][] memory supportedModels, + uint256[][] memory modelAllocations + ); } diff --git a/contracts/interfaces/ISertnServiceManager.sol b/contracts/interfaces/ISertnServiceManager.sol index 1262e8a..c9f068f 100644 --- a/contracts/interfaces/ISertnServiceManager.sol +++ b/contracts/interfaces/ISertnServiceManager.sol @@ -18,7 +18,12 @@ interface ISertnServiceManager { error AggregatorAlreadyExists(); /// @notice Emitted when the task is completed and operator reward is accumulated - event TaskRewardAccumulated(address indexed operator, uint256 fee, uint32 currentInterval); + event TaskRewardAccumulated( + address indexed operator, + uint256 fee, + IERC20 token, + uint32 currentInterval + ); event RewardsSubmittedForInterval(uint32 interval, uint256 operators_quantity); @@ -63,7 +68,7 @@ interface ISertnServiceManager { /** * @notice Task completed */ - function taskCompleted( + function taskResolved( address _operator, uint256 _fee, IStrategy _strategy, diff --git a/contracts/interfaces/ISertnTaskManager.sol b/contracts/interfaces/ISertnTaskManager.sol index ed3e02d..cdc00c3 100644 --- a/contracts/interfaces/ISertnTaskManager.sol +++ b/contracts/interfaces/ISertnTaskManager.sol @@ -120,4 +120,75 @@ interface ISertnTaskManager { * @param proof The proof of completion */ function submitProofForTask(uint256 taskId, bytes calldata proof) external; + + // === TASK HISTORY QUERY FUNCTIONS === + + /** + * @notice Get paginated task IDs for a specific model + * @param modelId The model ID to query + * @param offset Starting index + * @param limit Maximum number of results + * @return Array of task IDs (paginated) + */ + function getTasksByModel( + uint256 modelId, + uint256 offset, + uint256 limit + ) external view returns (uint256[] memory); + + /** + * @notice Get paginated task IDs for a specific operator + * @param operator The operator address to query + * @param offset Starting index + * @param limit Maximum number of results + * @return Array of task IDs (paginated) + */ + function getTasksByOperator( + address operator, + uint256 offset, + uint256 limit + ) external view returns (uint256[] memory); + + /** + * @notice Get paginated task IDs for a specific user + * @param user The user address to query + * @param offset Starting index + * @param limit Maximum number of results + * @return Array of task IDs (paginated) + */ + function getTasksByUser( + address user, + uint256 offset, + uint256 limit + ) external view returns (uint256[] memory); + + /** + * @notice Get paginated task IDs for a specific state + * @param state The task state to query + * @param offset Starting index + * @param limit Maximum number of results + * @return Array of task IDs (paginated) + */ + function getTasksByState( + TaskState state, + uint256 offset, + uint256 limit + ) external view returns (uint256[] memory); + + /** + * @notice Get task history counts for overview statistics + * @return totalTasks Total number of tasks + * @return completedTasks Number of completed/resolved tasks + * @return rejectedTasks Number of rejected tasks + * @return pendingTasks Number of pending/assigned/challenged tasks + */ + function getTaskHistoryStats() + external + view + returns ( + uint256 totalTasks, + uint256 completedTasks, + uint256 rejectedTasks, + uint256 pendingTasks + ); } diff --git a/contracts/script/InitLocalWorkers.t.sol b/contracts/script/InitLocalWorkers.t.sol index 4ce0969..12d3d31 100644 --- a/contracts/script/InitLocalWorkers.t.sol +++ b/contracts/script/InitLocalWorkers.t.sol @@ -4,6 +4,8 @@ pragma solidity ^0.8.19; import "forge-std/Script.sol"; import "forge-std/console.sol"; +import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; +import {Strings} from "@openzeppelin/contracts/utils/Strings.sol"; import {IStrategy} from "@eigenlayer/contracts/interfaces/IStrategy.sol"; import {IStrategyManager} from "@eigenlayer/contracts/interfaces/IStrategyManager.sol"; import {IDelegationManager} from "@eigenlayer/contracts/interfaces/IDelegationManager.sol"; @@ -12,18 +14,17 @@ import {ISignatureUtilsMixinTypes} from "@eigenlayer/contracts/interfaces/ISigna import {IDelegationManagerTypes} from "@eigenlayer/contracts/interfaces/IDelegationManager.sol"; import {IAllocationManager} from "@eigenlayer/contracts/interfaces/IAllocationManager.sol"; import {IAllocationManagerTypes} from "@eigenlayer/contracts/interfaces/IAllocationManager.sol"; -import {IModelRegistry} from "../interfaces/IModelRegistry.sol"; import {DelegationManager} from "@eigenlayer/contracts/core/DelegationManager.sol"; import {RewardsCoordinator} from "@eigenlayer/contracts/core/RewardsCoordinator.sol"; import {AllocationManager} from "@eigenlayer/contracts/core/AllocationManager.sol"; import {OperatorSet} from "@eigenlayer/contracts/libraries/OperatorSetLib.sol"; -import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol"; import {console2 as console} from "forge-std/Test.sol"; import {CoreDeploymentLib} from "./utils/CoreDeploymentLib.sol"; +import {ISertnTaskManager} from "../interfaces/ISertnTaskManager.sol"; +import {IModelRegistry} from "../interfaces/IModelRegistry.sol"; import {SertnRegistrar} from "../src/SertnRegistrar.sol"; import {SertnServiceManager} from "../src/SertnServiceManager.sol"; -import {ISertnTaskManager} from "../interfaces/ISertnTaskManager.sol"; import {ERC20Mock} from "../test/mockContracts/ERC20Mock.sol"; import {MockVerifier} from "../test/mockContracts/VerifierMock.sol"; @@ -100,7 +101,13 @@ contract InitLocalEnvScript is Script { coreDeployment = CoreDeploymentLib.readDeploymentJson("deployments/core/", block.chainid); // Read deployment addresses from JSON file - string memory deploymentFile = vm.readFile("./deployments/sertnDeployment.json"); + string memory deploymentFile = vm.readFile( + string.concat( + "./deployments/sertnDeployment_", + Strings.toString(block.chainid), + ".json" + ) + ); address strategyAddress1 = vm.parseJsonAddress(deploymentFile, ".strategy_0"); address strategyAddress2 = vm.parseJsonAddress(deploymentFile, ".strategy_1"); diff --git a/contracts/script/SertnDeployer.s.sol b/contracts/script/SertnDeployer.s.sol index 194d443..cedca33 100644 --- a/contracts/script/SertnDeployer.s.sol +++ b/contracts/script/SertnDeployer.s.sol @@ -149,7 +149,10 @@ contract SertnDeployer is Script, Test { address(_ethStrategies[i]) ); } - vm.writeFile("deployments/sertnDeployment.json", json); + vm.writeFile( + string.concat("deployments/sertnDeployment_", Strings.toString(block.chainid), ".json"), + json + ); vm.stopBroadcast(); } diff --git a/contracts/src/ModelRegistry.sol b/contracts/src/ModelRegistry.sol index ec6e06c..91395e3 100644 --- a/contracts/src/ModelRegistry.sol +++ b/contracts/src/ModelRegistry.sol @@ -171,4 +171,23 @@ contract ModelRegistry is OwnableUpgradeable, IModelRegistry { function isActive(uint256 modelId) external view returns (bool) { return activeModelIds.contains(modelId); } + + /// @inheritdoc IModelRegistry + function getActiveModelsWithDetails() external view returns (ModelDetails[] memory models) { + uint256[] memory activeIds = activeModelIds.values(); + models = new ModelDetails[](activeIds.length); + + for (uint256 i = 0; i < activeIds.length; i++) { + uint256 modelId = activeIds[i]; + models[i] = ModelDetails({ + modelId: modelId, + modelName: modelName[modelId], + modelVerifier: modelVerifier[modelId], + verificationStrategy: verificationStrategy[modelId], + computeCost: computeCost[modelId], + requiredFUCUs: requiredFUCUs[modelId], + isActive: true // We know it's active since we got it from activeModelIds + }); + } + } } diff --git a/contracts/src/SertnNodesManager.sol b/contracts/src/SertnNodesManager.sol index 9732cff..74b44f8 100644 --- a/contracts/src/SertnNodesManager.sol +++ b/contracts/src/SertnNodesManager.sol @@ -524,4 +524,71 @@ contract SertnNodesManager is OwnableUpgradeable, ISertnNodesManager { availableFucus[i] = tempFucus[i]; } } + + /** + * @notice Get all nodes with their details and supported models in a single call + * @return nodeDetails Array of node details packed as uint256[8] arrays: + * [0] nodeId, [1] operator (as uint256), [2] totalFucus, [3] allocatedFucus, + * [4] availableFucus, [5] isActive (1/0), [6] createdAt, [7] supportedModelsCount + * @return supportedModels Array of arrays containing supported model IDs for each node + * @return modelAllocations Array of arrays containing FUCUS allocations for each model on each node + * @dev This function is gas-optimized for batch operations. String data (name, metadata) + * lengths are returned but not the strings themselves to save gas. Use getNodeDetails + * for individual nodes if string data is needed. + */ + function getAllNodesWithDetails() + external + view + returns ( + uint256[8][] memory nodeDetails, + uint256[][] memory supportedModels, + uint256[][] memory modelAllocations + ) + { + // First pass: count valid nodes + uint256 validCount = 0; + for (uint256 i = 1; i < nextNodeId; i++) { + if (nodes[i].operator != address(0) && nodes[i].isActive) { + validCount++; + } + } + + // Initialize return arrays + nodeDetails = new uint256[8][](validCount); + supportedModels = new uint256[][](validCount); + modelAllocations = new uint256[][](validCount); + + // Second pass: populate data + uint256 index = 0; + for (uint256 i = 1; i < nextNodeId; i++) { + if (nodes[i].operator != address(0) && nodes[i].isActive) { + Node memory node = nodes[i]; + uint256 allocated = getTotalAllocatedFucusForNode(i); + uint256 available = node.totalFucus > allocated ? node.totalFucus - allocated : 0; + + // Pack node details into uint256 array + nodeDetails[index][0] = i; // nodeId + nodeDetails[index][1] = uint256(uint160(node.operator)); // operator as uint256 + nodeDetails[index][2] = node.totalFucus; + nodeDetails[index][3] = allocated; + nodeDetails[index][4] = available; + nodeDetails[index][5] = node.isActive ? 1 : 0; + nodeDetails[index][6] = node.createdAt; + nodeDetails[index][7] = nodeSupportedModels[i].length(); // supportedModelsCount + + // Get supported models and their allocations + uint256 modelCount = nodeSupportedModels[i].length(); + supportedModels[index] = new uint256[](modelCount); + modelAllocations[index] = new uint256[](modelCount); + + for (uint256 j = 0; j < modelCount; j++) { + uint256 modelId = nodeSupportedModels[i].at(j); + supportedModels[index][j] = modelId; + modelAllocations[index][j] = nodeModelConfigs[i][modelId].allocatedFucus; + } + + index++; + } + } + } } diff --git a/contracts/src/SertnServiceManager.sol b/contracts/src/SertnServiceManager.sol index 8abcf30..a1a5e2e 100644 --- a/contracts/src/SertnServiceManager.sol +++ b/contracts/src/SertnServiceManager.sol @@ -43,15 +43,6 @@ contract SertnServiceManager is IModelRegistry public modelRegistry; ISertnRegistrar public sertnRegistrar; - // Operator info - // mapping(address => bytes) public opInfo; - // The number of nodes that a given operator has - // mapping(address => uint256) public operatorNodeCount; - // Compute units for a given operator-node - // mapping(address => mapping(uint256 => uint256)) public operatorNodeComputeUnits; - // Which models a given operator node supports - // mapping(address => mapping(uint256 => mapping(uint256 => bool))) public operatorNodeModelIds; - // Set of aggregators EnumerableSet.AddressSet internal aggregators; @@ -222,11 +213,11 @@ contract SertnServiceManager is } /// @inheritdoc ISertnServiceManager - function taskCompleted( + function taskResolved( address _operator, uint256 _fee, IStrategy _strategy, - uint32 _startTimestamp + uint32 _startTimestamp // TODO: do we need that? ) external onlyTaskManager nonReentrant { uint32 currentInterval = this.getCurrentInterval(); @@ -247,7 +238,7 @@ contract SertnServiceManager is intervalRewards[currentInterval][address(_strategy)][_operator] += _fee; - emit TaskRewardAccumulated(_operator, _fee, currentInterval); + emit TaskRewardAccumulated(_operator, _fee, _strategy.underlyingToken(), currentInterval); } /// @inheritdoc ISertnServiceManager diff --git a/contracts/src/SertnTaskManager.sol b/contracts/src/SertnTaskManager.sol index 42a14c6..c694f40 100644 --- a/contracts/src/SertnTaskManager.sol +++ b/contracts/src/SertnTaskManager.sol @@ -19,7 +19,6 @@ import {IModelRegistry} from "../interfaces/IModelRegistry.sol"; import {ModelRegistry} from "./ModelRegistry.sol"; import {SertnNodesManager} from "./SertnNodesManager.sol"; - contract SertnTaskManager is OwnableUpgradeable, ISertnTaskManager { using EnumerableSet for EnumerableSet.UintSet; // queue of tasks that are waiting to be assigned to an operator @@ -33,6 +32,19 @@ contract SertnTaskManager is OwnableUpgradeable, ISertnTaskManager { // all assigned tasks IDs, which are not resolved and not rejected EnumerableSet.UintSet private pendingTasks; + // History tracking mappings for efficient task history queries + // modelId => array of task IDs + mapping(uint256 => uint256[]) public tasksByModel; + // operator address => array of task IDs + mapping(address => uint256[]) public tasksByOperator; + // user/aggregator address => array of task IDs + mapping(address => uint256[]) public tasksByUser; + // TaskState => array of task IDs + mapping(TaskState => uint256[]) public tasksByState; + + // Mapping to track task indices for efficient removal (if needed later) + mapping(uint256 => mapping(uint8 => uint256)) private taskIndexInState; // taskId => state => index + IERC20 public ser; IAllocationManager public allocationManager; @@ -85,6 +97,9 @@ contract SertnTaskManager is OwnableUpgradeable, ISertnTaskManager { tasks[task.nonce].state = TaskState.ASSIGNED; pendingTasks.add(task.nonce); + // Add to history tracking + _addTaskToHistory(task.nonce, tasks[task.nonce]); + // Allocate FUCUs for this task _allocateFucusForTask(task.nonce); @@ -111,13 +126,9 @@ contract SertnTaskManager is OwnableUpgradeable, ISertnTaskManager { } tasks[taskId].output = output; + _updateTaskStateInHistory(taskId, tasks[taskId].state, TaskState.COMPLETED); tasks[taskId].state = TaskState.COMPLETED; emit TaskCompleted(taskId, task.operator); - - OperatorSet memory operatorSet = allocationManager.getAllocatedSets(task.operator)[0]; - IStrategy strategy = allocationManager.getAllocatedStrategies(task.operator, operatorSet)[ - 0 - ]; } function challengeTask(uint256 taskId) external onlyAggregators { @@ -141,8 +152,8 @@ contract SertnTaskManager is OwnableUpgradeable, ISertnTaskManager { if (task.state != TaskState.COMPLETED) { revert TaskStateIncorrect(TaskState.COMPLETED); } + _updateTaskStateInHistory(taskId, tasks[taskId].state, TaskState.CHALLENGED); tasks[taskId].state = TaskState.CHALLENGED; - // TODO: maybe operator address instead of msg.sender? emit TaskChallenged(taskId, msg.sender); } @@ -160,16 +171,18 @@ contract SertnTaskManager is OwnableUpgradeable, ISertnTaskManager { ]; if (success) { - sertnServiceManager.taskCompleted( + sertnServiceManager.taskResolved( task.operator, task.fee, strategy, task.startTimestamp ); + _updateTaskStateInHistory(task.nonce, task.state, TaskState.RESOLVED); tasks[taskId].state = TaskState.RESOLVED; emit TaskResolved(taskId, task.operator); } else { sertnServiceManager.slashOperator(task.operator, task.fee, operatorSet.id, strategy); + _updateTaskStateInHistory(task.nonce, task.state, TaskState.REJECTED); tasks[taskId].state = TaskState.REJECTED; emit TaskRejected(taskId, task.operator); } @@ -276,4 +289,178 @@ contract SertnTaskManager is OwnableUpgradeable, ISertnTaskManager { // Release the allocated FUCUs sertnNodesManager.releaseFucusForTask(task.operator, task.modelId, requiredFucus); } + + // === TASK HISTORY HELPER FUNCTIONS === + + /** + * @notice Internal helper function to paginate an array of task IDs + * @param taskArray Storage reference to the array to paginate + * @param offset Starting index (from the end of array - most recent first) + * @param limit Maximum number of results + * @return Paginated array of task IDs (most recent first) + */ + function _paginateTaskIds( + uint256[] storage taskArray, + uint256 offset, + uint256 limit + ) internal view returns (uint256[] memory) { + if (offset >= taskArray.length) { + return new uint256[](0); + } + + // Calculate how many items we can actually return + uint256 availableItems = taskArray.length - offset; + uint256 quantityToReturn = limit > availableItems ? availableItems : limit; + + // Calculate start index from the end (most recent first) + uint256 startIndex = availableItems - 1; + + uint256[] memory result = new uint256[](quantityToReturn); + + // Fill result array from most recent to oldest + for (uint256 i = 0; i < quantityToReturn; i++) { + result[i] = taskArray[startIndex - i]; + } + + return result; + } + + /** + * @notice Add a task to the history tracking mappings + * @param taskId The ID of the task + * @param task The task data + */ + function _addTaskToHistory(uint256 taskId, Task memory task) internal { + // Add to model history + tasksByModel[task.modelId].push(taskId); + + // Add to operator history + tasksByOperator[task.operator].push(taskId); + + // Add to user history + tasksByUser[task.user].push(taskId); + + // Add to state history + tasksByState[task.state].push(taskId); + taskIndexInState[taskId][uint8(task.state)] = tasksByState[task.state].length - 1; + } + + /** + * @notice Update task state in history tracking + * @param taskId The ID of the task + * @param oldState The previous state + * @param newState The new state + */ + function _updateTaskStateInHistory( + uint256 taskId, + TaskState oldState, + TaskState newState + ) internal { + // Remove from old state array + uint256 oldIndex = taskIndexInState[taskId][uint8(oldState)]; + uint256[] storage oldStateArray = tasksByState[oldState]; + uint256 lastTaskId = oldStateArray[oldStateArray.length - 1]; + + // Move last element to the position of the element to remove + oldStateArray[oldIndex] = lastTaskId; + taskIndexInState[lastTaskId][uint8(oldState)] = oldIndex; + + // Remove last element + oldStateArray.pop(); + + // Add to new state array + tasksByState[newState].push(taskId); + taskIndexInState[taskId][uint8(newState)] = tasksByState[newState].length - 1; + } + + // === TASK HISTORY QUERY FUNCTIONS === + + /** + * @notice Get paginated task IDs for a specific model + * @param modelId The model ID to query + * @param offset Starting index + * @param limit Maximum number of results + * @return Array of task IDs (paginated) + */ + function getTasksByModel( + uint256 modelId, + uint256 offset, + uint256 limit + ) external view returns (uint256[] memory) { + return _paginateTaskIds(tasksByModel[modelId], offset, limit); + } + + /** + * @notice Get paginated task IDs for a specific operator + * @param operator The operator address to query + * @param offset Starting index + * @param limit Maximum number of results + * @return Array of task IDs (paginated) + */ + function getTasksByOperator( + address operator, + uint256 offset, + uint256 limit + ) external view returns (uint256[] memory) { + return _paginateTaskIds(tasksByOperator[operator], offset, limit); + } + + /** + * @notice Get paginated task IDs for a specific user + * @param user The user address to query + * @param offset Starting index + * @param limit Maximum number of results + * @return Array of task IDs (paginated) + */ + function getTasksByUser( + address user, + uint256 offset, + uint256 limit + ) external view returns (uint256[] memory) { + return _paginateTaskIds(tasksByUser[user], offset, limit); + } + + /** + * @notice Get all task IDs with a specific state + * @param state The task state to query + * @param offset Starting index + * @param limit Maximum number of results + * @return Array of task IDs (paginated) + */ + function getTasksByState( + TaskState state, + uint256 offset, + uint256 limit + ) external view returns (uint256[] memory) { + return _paginateTaskIds(tasksByState[state], offset, limit); + } + + /** + * @notice Get task history counts for overview statistics + * @return totalTasks Total number of tasks + * @return resolvedTasks Number of completed/resolved tasks + * @return rejectedTasks Number of rejected tasks + * @return pendingTasksCount Number of pending/assigned/challenged tasks + */ + function getTaskHistoryStats() + external + view + returns ( + uint256 totalTasks, + uint256 resolvedTasks, + uint256 rejectedTasks, + uint256 pendingTasksCount + ) + { + totalTasks = taskNonce - 1; // -1 because nonce starts at 1 + + // For detailed stats, you would need to iterate through the tasks + // This is a basic implementation + resolvedTasks = tasksByState[TaskState.RESOLVED].length; + rejectedTasks = tasksByState[TaskState.REJECTED].length; + pendingTasksCount = + tasksByState[TaskState.ASSIGNED].length + + tasksByState[TaskState.CHALLENGED].length + + tasksByState[TaskState.COMPLETED].length; + } } diff --git a/contracts/test/ModelRegistry.t.sol b/contracts/test/ModelRegistry.t.sol index 61f40c9..7bd92e9 100644 --- a/contracts/test/ModelRegistry.t.sol +++ b/contracts/test/ModelRegistry.t.sol @@ -298,4 +298,86 @@ contract ModelRegistryTest is Test { vm.stopPrank(); } + + function test_getActiveModelsWithDetails_empty_when_no_active_models() public { + vm.startPrank(owner.key.addr); + + // Disable the only active model + modelRegistry.disableModel(modelId); + + // Get active models with details + IModelRegistry.ModelDetails[] memory models = modelRegistry.getActiveModelsWithDetails(); + + // Should have no models + assertEq(models.length, 0, "Should have no active models when all are disabled"); + + vm.stopPrank(); + } + + function test_getActiveModelsWithDetails_updates_after_model_changes() public { + vm.startPrank(owner.key.addr); + + // Create second model + uint256 modelId2 = modelRegistry.createNewModel( + address(mockVerifier2), + IModelRegistry.VerificationStrategy.Offchain, + "model2", + 200, + 15 + ); + + // Update first model's properties + modelRegistry.updateModelName(modelId, "updatedModel1"); + modelRegistry.updateComputeCost(modelId, 500); + modelRegistry.updateRequiredFUCUs(modelId, 25); + modelRegistry.updateVerificationStrategy( + modelId, + IModelRegistry.VerificationStrategy.Offchain + ); + + // Get active models with details + IModelRegistry.ModelDetails[] memory models = modelRegistry.getActiveModelsWithDetails(); + assertEq(models.length, 2, "Should have 2 active models"); + + // Verify both models are present + bool foundModel1 = false; + bool foundModel2 = false; + + // Find and verify the updated model + for (uint256 i = 0; i < models.length; i++) { + if (models[i].modelId == modelId) { + assertEq( + keccak256(abi.encodePacked(models[i].modelName)), + keccak256(abi.encodePacked("updatedModel1")), + "Updated model name should be reflected" + ); + assertEq(models[i].computeCost, 500, "Updated compute cost should be reflected"); + assertEq(models[i].requiredFUCUs, 25, "Updated FUCUs should be reflected"); + assertEq( + uint256(models[i].verificationStrategy), + uint256(IModelRegistry.VerificationStrategy.Offchain), + "Updated strategy should be reflected" + ); + foundModel1 = true; + } else if (models[i].modelId == modelId2) { + // Verify second model remains unchanged + assertEq( + keccak256(abi.encodePacked(models[i].modelName)), + keccak256(abi.encodePacked("model2")), + "Second model name should remain unchanged" + ); + assertEq( + models[i].computeCost, + 200, + "Second model compute cost should remain unchanged" + ); + assertEq(models[i].requiredFUCUs, 15, "Second model FUCUs should remain unchanged"); + foundModel2 = true; + } + } + assertTrue(foundModel1, "Updated model should be found with correct properties"); + assertTrue(foundModel2, "Second model should be found with original properties"); + + vm.stopPrank(); + } } diff --git a/contracts/test/SertnNodesManager.t.sol b/contracts/test/SertnNodesManager.t.sol index 107adff..6f7cbe1 100644 --- a/contracts/test/SertnNodesManager.t.sol +++ b/contracts/test/SertnNodesManager.t.sol @@ -802,4 +802,137 @@ contract SertnNodesManagerTest is Test { assertEq(nodesManager.getTotalFucusForOperatorModel(operator1.addr, modelId1), 0); } + + // ============ GET ALL NODES WITH DETAILS TESTS ============ + + function testGetAllNodesWithDetailsEmpty() public view { + ( + uint256[8][] memory nodeDetails, + uint256[][] memory supportedModels, + uint256[][] memory modelAllocations + ) = nodesManager.getAllNodesWithDetails(); + + assertEq(nodeDetails.length, 0); + assertEq(supportedModels.length, 0); + assertEq(modelAllocations.length, 0); + } + + function testGetAllNodesWithDetailsActiveOnly() public { + // Register multiple nodes with different operators + vm.startPrank(operator1.addr); + uint256 nodeId1 = nodesManager.registerNode("Node 1", "metadata1", 1000); + nodesManager.addModelSupport(nodeId1, modelId1, 300); + nodesManager.addModelSupport(nodeId1, modelId2, 200); + vm.stopPrank(); + + vm.startPrank(operator2.addr); + uint256 nodeId2 = nodesManager.registerNode("Node 2", "metadata2", 2000); + nodesManager.addModelSupport(nodeId2, modelId1, 500); + + // Create an inactive node (should not be returned) + uint256 nodeId3 = nodesManager.registerNode("Node 3", "metadata3", 1500); + nodesManager.deactivateNode(nodeId3); + vm.stopPrank(); + + // Test getAllNodesWithDetails (should return only active nodes) + ( + uint256[8][] memory nodeDetails, + uint256[][] memory supportedModels, + uint256[][] memory modelAllocations + ) = nodesManager.getAllNodesWithDetails(); + + // Should return 2 active nodes (nodeId1 and nodeId2), excluding inactive nodeId3 + assertEq(nodeDetails.length, 2); + assertEq(supportedModels.length, 2); + assertEq(modelAllocations.length, 2); + + // Verify first node (nodeId1) + assertEq(nodeDetails[0][0], nodeId1); // nodeId + assertEq(nodeDetails[0][1], uint256(uint160(operator1.addr))); // operator as uint256 + assertEq(nodeDetails[0][2], 1000); // totalFucus + assertEq(nodeDetails[0][3], 500); // allocatedFucus (300 + 200) + assertEq(nodeDetails[0][4], 500); // availableFucus (1000 - 500) + assertEq(nodeDetails[0][5], 1); // isActive (1 = true) + assertEq(nodeDetails[0][6], block.timestamp); // createdAt + assertEq(nodeDetails[0][7], 2); // supportedModelsCount + + // Verify supported models for first node + assertEq(supportedModels[0].length, 2); + assertEq(supportedModels[0][0], modelId1); + assertEq(supportedModels[0][1], modelId2); + assertEq(modelAllocations[0].length, 2); + assertEq(modelAllocations[0][0], 300); + assertEq(modelAllocations[0][1], 200); + + // Verify second node (nodeId2) + assertEq(nodeDetails[1][0], nodeId2); // nodeId + assertEq(nodeDetails[1][1], uint256(uint160(operator2.addr))); // operator as uint256 + assertEq(nodeDetails[1][2], 2000); // totalFucus + assertEq(nodeDetails[1][3], 500); // allocatedFucus + assertEq(nodeDetails[1][4], 1500); // availableFucus (2000 - 500) + assertEq(nodeDetails[1][5], 1); // isActive (1 = true) + assertEq(nodeDetails[1][6], block.timestamp); // createdAt + assertEq(nodeDetails[1][7], 1); // supportedModelsCount + + // Verify supported models for second node + assertEq(supportedModels[1].length, 1); + assertEq(supportedModels[1][0], modelId1); + assertEq(modelAllocations[1].length, 1); + assertEq(modelAllocations[1][0], 500); + } + + function testGetAllNodesWithDetailsNoModelSupport() public { + // Register a node without model support + vm.startPrank(operator1.addr); + uint256 nodeId = nodesManager.registerNode("Empty Node", "no_models", 1000); + vm.stopPrank(); + + ( + uint256[8][] memory nodeDetails, + uint256[][] memory supportedModels, + uint256[][] memory modelAllocations + ) = nodesManager.getAllNodesWithDetails(); + + assertEq(nodeDetails.length, 1); + assertEq(nodeDetails[0][0], nodeId); + assertEq(nodeDetails[0][3], 0); // allocatedFucus = 0 + assertEq(nodeDetails[0][4], 1000); // availableFucus = totalFucus + assertEq(nodeDetails[0][7], 0); // supportedModelsCount = 0 + + // Verify empty model support arrays + assertEq(supportedModels[0].length, 0); + assertEq(modelAllocations[0].length, 0); + } + + function testGetAllNodesWithDetailsMultipleModels() public { + // Register a node with multiple model supports + vm.startPrank(operator1.addr); + uint256 nodeId = nodesManager.registerNode("Multi Model Node", "supports_many", 2000); + nodesManager.addModelSupport(nodeId, modelId1, 800); + nodesManager.addModelSupport(nodeId, modelId2, 600); + vm.stopPrank(); + + ( + uint256[8][] memory nodeDetails, + uint256[][] memory supportedModels, + uint256[][] memory modelAllocations + ) = nodesManager.getAllNodesWithDetails(); + + assertEq(nodeDetails.length, 1); + assertEq(nodeDetails[0][0], nodeId); + assertEq(nodeDetails[0][2], 2000); // totalFucus + assertEq(nodeDetails[0][3], 1400); // allocatedFucus (800 + 600) + assertEq(nodeDetails[0][4], 600); // availableFucus (2000 - 1400) + assertEq(nodeDetails[0][7], 2); // supportedModelsCount + + // Verify model support arrays + assertEq(supportedModels[0].length, 2); + assertEq(modelAllocations[0].length, 2); + + // Models should be in the order they were added + assertEq(supportedModels[0][0], modelId1); + assertEq(supportedModels[0][1], modelId2); + assertEq(modelAllocations[0][0], 800); + assertEq(modelAllocations[0][1], 600); + } } diff --git a/contracts/test/SertnServiceManager.t.sol b/contracts/test/SertnServiceManager.t.sol index e69be81..387eef6 100644 --- a/contracts/test/SertnServiceManager.t.sol +++ b/contracts/test/SertnServiceManager.t.sol @@ -272,7 +272,7 @@ contract SertnServiceManagerTest is Test { serviceManager.pullFeeFromUser(user, IERC20(address(mockToken1)), 1000); } - function test_taskCompleted_success() public { + function test_taskResolved_success() public { // Setup: Add aggregator and mint tokens vm.prank(owner); serviceManager.addAggregator(aggregator1); @@ -287,8 +287,8 @@ contract SertnServiceManagerTest is Test { // Test: Complete task (called by task manager) vm.prank(address(mockTaskManager)); vm.expectEmit(true, true, false, true); - emit ISertnServiceManager.TaskRewardAccumulated(operator, feeAmount, 0); // currentInterval is mocked as 0 - serviceManager.taskCompleted( + emit ISertnServiceManager.TaskRewardAccumulated(operator, feeAmount, mockToken1, 0); // currentInterval is mocked as 0 + serviceManager.taskResolved( operator, feeAmount, IStrategy(address(mockStrategy1)), @@ -313,10 +313,10 @@ contract SertnServiceManagerTest is Test { // No revert means success } - function test_taskCompleted_revertNotTaskManager() public { + function test_taskResolved_revertNotTaskManager() public { vm.expectRevert(ISertnServiceManager.NotTaskManager.selector); vm.prank(aggregator1); - serviceManager.taskCompleted( + serviceManager.taskResolved( operator, 1000, IStrategy(address(mockStrategy1)), @@ -396,7 +396,7 @@ contract SertnServiceManagerTest is Test { ); // Verify model exists - assertEq(modelRegistry.modelName(modelId), "test_model"); + assertEq(modelRegistry.modelName(modelId), "test_model"); assertEq(address(serviceManager.modelRegistry()), address(modelRegistry)); vm.stopPrank(); @@ -429,7 +429,7 @@ contract SertnServiceManagerTest is Test { // Workflow: Task completed, rewards distributed vm.prank(address(mockTaskManager)); - serviceManager.taskCompleted( + serviceManager.taskResolved( operator, feeAmount, IStrategy(address(mockStrategy1)), @@ -439,7 +439,7 @@ contract SertnServiceManagerTest is Test { // Verify final state assertEq(mockToken1.balanceOf(address(serviceManager)), feeAmount); assertTrue(serviceManager.isAggregator(aggregator1)); - assertEq(modelRegistry.modelName(modelId), "workflow_model"); + assertEq(modelRegistry.modelName(modelId), "workflow_model"); uint32 currentInterval = serviceManager.getCurrentInterval(); diff --git a/contracts/test/SertnTaskManager.t.sol b/contracts/test/SertnTaskManager.t.sol index a2b4fc9..0abb084 100644 --- a/contracts/test/SertnTaskManager.t.sol +++ b/contracts/test/SertnTaskManager.t.sol @@ -428,6 +428,309 @@ contract SertnTaskManagerTest is Test { assertEq(taskManager.taskNonce(), task.nonce + 2); } + // === HISTORY TESTS === + + function test_getTasksByModel() public { + // Create a second model for testing + vm.startPrank(owner); + uint256 modelId2 = modelRegistry.createNewModel( + address(new MockVerifier()), + IModelRegistry.VerificationStrategy.Onchain, + "test_model_2", + 200, + 20 + ); + vm.stopPrank(); + + // Send tasks for both models + ISertnTaskManager.Task memory task1 = _createValidTask(); + vm.prank(aggregator); + taskManager.sendTask(task1); + + ISertnTaskManager.Task memory task2 = _createValidTask(); + task2.modelId = modelId2; + task2.nonce = 2; + vm.prank(aggregator); + taskManager.sendTask(task2); + + ISertnTaskManager.Task memory task3 = _createValidTask(); + task3.nonce = 3; + vm.prank(aggregator); + taskManager.sendTask(task3); + + // Get tasks for first model (should have 2 tasks: task1 and task3) + uint256[] memory tasksModel1 = taskManager.getTasksByModel(modelId, 0, 10); + assertEq(tasksModel1.length, 2, "Model 1 should have 2 tasks"); + assertTrue( + (tasksModel1[0] == 1 && tasksModel1[1] == 3) || + (tasksModel1[0] == 3 && tasksModel1[1] == 1), + "Should contain correct task IDs for model 1" + ); + + // Get tasks for second model (should have 1 task: task2) + uint256[] memory tasksModel2 = taskManager.getTasksByModel(modelId2, 0, 10); + assertEq(tasksModel2.length, 1, "Model 2 should have 1 task"); + assertEq(tasksModel2[0], 2, "Should contain correct task ID for model 2"); + } + + function test_getTasksByModel_pagination() public { + // Send 5 tasks with the same model + for (uint256 i = 0; i < 5; i++) { + ISertnTaskManager.Task memory task = _createValidTask(); + task.nonce = i + 1; + vm.prank(aggregator); + taskManager.sendTask(task); + } + + // Test pagination - first page (limit 3) + uint256[] memory page1 = taskManager.getTasksByModel(modelId, 0, 3); + assertEq(page1.length, 3, "First page should have 3 items"); + assertEq(page1[0], 5, "First item should be task 5"); + assertEq(page1[1], 4, "Second item should be task 4"); + assertEq(page1[2], 3, "Third item should be task 3"); + + // Test pagination - second page (offset 3, limit 3) + uint256[] memory page2 = taskManager.getTasksByModel(modelId, 3, 3); + assertEq(page2.length, 2, "Second page should have 2 remaining items"); + assertEq(page2[0], 2, "First item on page 2 should be task 2"); + assertEq(page2[1], 1, "Second item on page 2 should be task 1"); + + // Test pagination - beyond available data + uint256[] memory page3 = taskManager.getTasksByModel(modelId, 10, 3); + assertEq(page3.length, 0, "Should return empty array when offset is beyond data"); + } + + function test_getTasksByOperator() public { + address operator2 = vm.addr(100); + + // Setup allocation for second operator + vm.startPrank(owner); + OperatorSet[] memory sets = new OperatorSet[](1); + sets[0] = OperatorSet({id: 2, avs: address(0)}); + mockAllocationManager.setAllocatedSets(operator2, sets); + + IStrategy[] memory strategies = new IStrategy[](1); + strategies[0] = mockStrategy; + mockAllocationManager.setAllocatedStrategies(operator2, sets[0], strategies); + vm.stopPrank(); + + // Send tasks for different operators + ISertnTaskManager.Task memory task1 = _createValidTask(); + vm.prank(aggregator); + taskManager.sendTask(task1); + + ISertnTaskManager.Task memory task2 = _createValidTask(); + task2.operator = operator2; + task2.nonce = 2; + vm.prank(aggregator); + taskManager.sendTask(task2); + + ISertnTaskManager.Task memory task3 = _createValidTask(); + task3.nonce = 3; + vm.prank(aggregator); + taskManager.sendTask(task3); + + // Get tasks for first operator (should have 2 tasks) + uint256[] memory tasksOp1 = taskManager.getTasksByOperator(operator, 0, 10); + assertEq(tasksOp1.length, 2, "Operator 1 should have 2 tasks"); + + // Get tasks for second operator (should have 1 task) + uint256[] memory tasksOp2 = taskManager.getTasksByOperator(operator2, 0, 10); + assertEq(tasksOp2.length, 1, "Operator 2 should have 1 task"); + assertEq(tasksOp2[0], 2, "Should contain correct task ID for operator 2"); + + // Get tasks for an operator with no tasks + address unknownOperator = vm.addr(999); + uint256[] memory tasks = taskManager.getTasksByOperator(unknownOperator, 0, 10); + assertEq(tasks.length, 0, "Should return empty array for operator with no tasks"); + } + + function test_getTasksByUser() public { + address user2 = vm.addr(200); + + // Send tasks for different users + ISertnTaskManager.Task memory task1 = _createValidTask(); + vm.prank(aggregator); + taskManager.sendTask(task1); + + ISertnTaskManager.Task memory task2 = _createValidTask(); + task2.user = user2; + task2.nonce = 2; + vm.prank(aggregator); + taskManager.sendTask(task2); + + ISertnTaskManager.Task memory task3 = _createValidTask(); + task3.nonce = 3; + vm.prank(aggregator); + taskManager.sendTask(task3); + + // Get tasks for first user (should have 2 tasks) + uint256[] memory tasksUser1 = taskManager.getTasksByUser(user, 0, 10); + assertEq(tasksUser1.length, 2, "User 1 should have 2 tasks"); + + // Get tasks for second user (should have 1 task) + uint256[] memory tasksUser2 = taskManager.getTasksByUser(user2, 0, 10); + assertEq(tasksUser2.length, 1, "User 2 should have 1 task"); + assertEq(tasksUser2[0], 2, "Should contain correct task ID for user 2"); + + // Get tasks for an user with no tasks + address unknownUser = vm.addr(999); + uint256[] memory tasks = taskManager.getTasksByUser(unknownUser, 0, 10); + assertEq(tasks.length, 0, "Should return empty array for user with no tasks"); + } + + function test_getTasksByState_assigned() public { + // Send tasks with different states + ISertnTaskManager.Task memory task1 = _createValidTask(); + vm.prank(aggregator); + taskManager.sendTask(task1); // This will be in ASSIGNED state + + ISertnTaskManager.Task memory task2 = _createValidTask(); + task2.nonce = 2; + vm.prank(aggregator); + taskManager.sendTask(task2); // This will be in ASSIGNED state + + // Get tasks in ASSIGNED state (should be 2) + uint256[] memory assignedTasks = taskManager.getTasksByState( + ISertnTaskManager.TaskState.ASSIGNED, + 0, + 10 + ); + assertEq(assignedTasks.length, 2, "Should have 2 task in ASSIGNED state"); + + // Complete one task + vm.prank(operator); + taskManager.submitTaskOutput(1, "output"); + + // Get tasks in ASSIGNED state (should be 1) + assignedTasks = taskManager.getTasksByState(ISertnTaskManager.TaskState.ASSIGNED, 0, 10); + assertEq(assignedTasks.length, 1, "Should have 1 task in ASSIGNED state"); + assertEq(assignedTasks[0], 2, "Task 2 should be in ASSIGNED state"); + + // Get tasks in COMPLETED state (should be 1) + uint256[] memory completedTasks = taskManager.getTasksByState( + ISertnTaskManager.TaskState.COMPLETED, + 0, + 10 + ); + assertEq(completedTasks.length, 1, "Should have 1 task in COMPLETED state"); + assertEq(completedTasks[0], 1, "Task 1 should be in COMPLETED state"); + + // Get tasks in REJECTED state (should be 0) + uint256[] memory rejectedTasks = taskManager.getTasksByState( + ISertnTaskManager.TaskState.REJECTED, + 0, + 10 + ); + assertEq(rejectedTasks.length, 0, "Should have 0 tasks in REJECTED state"); + } + + function test_getTaskHistoryStats_global() public { + // Send and process several tasks with different outcomes + + // Task 1: Send and resolve + ISertnTaskManager.Task memory task1 = _createValidTask(); + task1.nonce = 1; + vm.prank(aggregator); + taskManager.sendTask(task1); + + vm.prank(operator); + taskManager.submitTaskOutput(task1.nonce, "output1"); + + vm.prank(aggregator); + taskManager.challengeTask(task1.nonce); + + vm.prank(aggregator); + taskManager.resolveTask(task1.nonce, true); // RESOLVED + + // Task 2: Send and reject + ISertnTaskManager.Task memory task2 = _createValidTask(); + task2.nonce = 2; + vm.prank(aggregator); + taskManager.sendTask(task2); + + vm.prank(operator); + taskManager.submitTaskOutput(2, "output2"); + + vm.prank(aggregator); + taskManager.challengeTask(2); + + vm.prank(aggregator); + taskManager.resolveTask(2, false); // REJECTED + + // Task 3: Send and leave pending + ISertnTaskManager.Task memory task3 = _createValidTask(); + task3.nonce = 3; + vm.prank(aggregator); + taskManager.sendTask(task3); // ASSIGNED (pending) + + // Get global stats (all parameters zero/null) + ( + uint256 totalTasks, + uint256 resolvedTasks, + uint256 rejectedTasks, + uint256 pendingTasksCount + ) = taskManager.getTaskHistoryStats(); + + assertEq(totalTasks, 3, "Should have 3 total tasks"); + assertEq(resolvedTasks, 1, "Should have 1 resolved task"); + assertEq(rejectedTasks, 1, "Should have 1 rejected task"); + assertEq(pendingTasksCount, 1, "Should have 1 pending task"); + } + + function test_task_history_tracking_on_send() public { + // Send a task and verify it's tracked in all relevant mappings + ISertnTaskManager.Task memory task = _createValidTask(); + vm.prank(aggregator); + taskManager.sendTask(task); + + // Verify task is in model history + uint256[] memory modelTasks = taskManager.getTasksByModel(modelId, 0, 10); + assertEq(modelTasks.length, 1, "Task should be tracked in model history"); + assertEq(modelTasks[0], 1, "Task ID should match"); + + // Verify task is in operator history + uint256[] memory operatorTasks = taskManager.getTasksByOperator(operator, 0, 10); + assertEq(operatorTasks.length, 1, "Task should be tracked in operator history"); + assertEq(operatorTasks[0], 1, "Task ID should match"); + + // Verify task is in user history + uint256[] memory userTasks = taskManager.getTasksByUser(user, 0, 10); + assertEq(userTasks.length, 1, "Task should be tracked in user history"); + assertEq(userTasks[0], 1, "Task ID should match"); + + // Verify task is in state history + uint256[] memory stateTasks = taskManager.getTasksByState( + ISertnTaskManager.TaskState.ASSIGNED, + 0, + 10 + ); + assertEq(stateTasks.length, 1, "Task should be tracked in state history"); + assertEq(stateTasks[0], 1, "Task ID should match"); + } + + function test_pagination_edge_cases() public { + // Test empty results with various offset/limit combinations + uint256[] memory empty1 = taskManager.getTasksByModel(modelId, 0, 0); + assertEq(empty1.length, 0, "Should return empty array with limit 0"); + + uint256[] memory empty2 = taskManager.getTasksByModel(modelId, 100, 10); + assertEq(empty2.length, 0, "Should return empty array with high offset"); + + // Send one task + ISertnTaskManager.Task memory task = _createValidTask(); + vm.prank(aggregator); + taskManager.sendTask(task); + + // Test with offset equal to array length + uint256[] memory edge1 = taskManager.getTasksByModel(modelId, 1, 10); + assertEq(edge1.length, 0, "Should return empty array when offset equals array length"); + + // Test with limit larger than remaining items + uint256[] memory edge2 = taskManager.getTasksByModel(modelId, 0, 100); + assertEq(edge2.length, 1, "Should return all available items when limit is larger"); + } + // Helper function to create a valid task function _createValidTask() internal view returns (ISertnTaskManager.Task memory) { return diff --git a/contracts/test/mockContracts/SertnNodesManagerMock.sol b/contracts/test/mockContracts/SertnNodesManagerMock.sol index 3023665..e716875 100644 --- a/contracts/test/mockContracts/SertnNodesManagerMock.sol +++ b/contracts/test/mockContracts/SertnNodesManagerMock.sol @@ -201,4 +201,36 @@ contract SertnNodesManagerMock is ISertnNodesManager { availableFucus = new uint256[](1); availableFucus[0] = 1000000; } + + function getAllNodesWithDetails() + external + pure + returns ( + uint256[8][] memory nodeDetails, + uint256[][] memory supportedModels, + uint256[][] memory modelAllocations + ) + { + nodeDetails = new uint256[8][](1); + nodeDetails[0] = [ + 1, // nodeId + uint256(uint160(address(0x1))), // operator + 1000000, // totalFucus + 0, // allocatedFucus + 1000000, // availableFucus + 1, // supportedModelsCount + 1, // active (bool as uint) + 0 // createdAt + ]; + + supportedModels = new uint256[][](1); + supportedModels[0] = new uint256[](1); + supportedModels[0][0] = 1; // modelId + + modelAllocations = new uint256[][](1); + modelAllocations[0] = new uint256[](1); + modelAllocations[0][0] = 1000000; // allocatedFucus + + return (nodeDetails, supportedModels, modelAllocations); + } } diff --git a/contracts/test/mockContracts/SertnServiceManagerMock.sol b/contracts/test/mockContracts/SertnServiceManagerMock.sol index c445b21..799f768 100644 --- a/contracts/test/mockContracts/SertnServiceManagerMock.sol +++ b/contracts/test/mockContracts/SertnServiceManagerMock.sol @@ -20,7 +20,7 @@ contract MockSertnServiceManager { emit FeesPulled(user, token, fee); } - function taskCompleted(address operator, uint256 fee, address strategy, uint32) external { + function taskResolved(address operator, uint256 fee, address strategy, uint32) external { emit TaskCompleted(operator, fee, strategy); } diff --git a/cspell.json b/cspell.json index 02b3dfd..d554211 100644 --- a/cspell.json +++ b/cspell.json @@ -7,6 +7,7 @@ "addfinalizer", "autouse", "avsregistry", + "bips", "blsagg", "CHAINID", "chainio", @@ -24,6 +25,7 @@ "keccak", "Localnet", "LSTM", + "Memorystore", "Offchain", "operatorsinfo", "pbar", @@ -38,6 +40,7 @@ "Slashable", "timespan", "tupple", + "urandom", "venv", "ZKML" ],