From a546c65a386b6ccef3c518aee292cf839a071e06 Mon Sep 17 00:00:00 2001 From: Derek Johnston Date: Sat, 9 Aug 2025 13:21:18 +0100 Subject: [PATCH 1/3] prototype archon --- CLAUDE.md | 97 ++++ experimental/__init__.py | 0 experimental/archon/__init__.py | 0 experimental/archon/src/__init__.py | 0 experimental/archon/src/graph_executor.py | 111 +++++ experimental/archon/src/graph_loader.py | 431 ++++++++++++++++++ experimental/archon/src/graph_models.py | 229 ++++++++++ experimental/archon/tests/__init__.py | 0 experimental/archon/tests/conftest.py | 81 ++++ .../archon/tests/test_graph_executor.py | 133 ++++++ .../archon/tests/test_graph_loader.py | 188 ++++++++ .../archon/tests/test_type_preservation.py | 112 +++++ experimental/archon/tests/test_validation.py | 75 +++ pyproject.toml | 6 + requirements.txt | 2 +- 15 files changed, 1464 insertions(+), 1 deletion(-) create mode 100644 CLAUDE.md create mode 100644 experimental/__init__.py create mode 100644 experimental/archon/__init__.py create mode 100644 experimental/archon/src/__init__.py create mode 100644 experimental/archon/src/graph_executor.py create mode 100644 experimental/archon/src/graph_loader.py create mode 100644 experimental/archon/src/graph_models.py create mode 100644 experimental/archon/tests/__init__.py create mode 100644 experimental/archon/tests/conftest.py create mode 100644 experimental/archon/tests/test_graph_executor.py create mode 100644 experimental/archon/tests/test_graph_loader.py create mode 100644 experimental/archon/tests/test_type_preservation.py create mode 100644 experimental/archon/tests/test_validation.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..3130744d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,97 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Commands + +### Setup and Dependencies +- `uv venv` - Create virtual environment +- `source .venv/bin/activate` - Activate virtual environment +- `./scripts/install_deps.sh` - Install dependencies +- `uv pip install -e .[dev]` - Install project in development mode with dev dependencies + +### Code Quality and Testing +- `./scripts/run_checks.sh` - Run complete test suite (ruff format, ruff check, mypy, pytest) +- `uv run ruff format .` - Format code +- `uv run ruff check .` - Lint code +- `uv run mypy src` - Type check source code +- `uv run pytest` - Run all tests +- `uv run pytest tests/test_specific.py` - Run specific test file +- `uv run isort .` - Sort imports + +### Running the Application +- `uv run talos` - Start interactive CLI (requires OPENAI_API_KEY, PINATA_API_KEY, PINATA_SECRET_API_KEY environment variables) +- `uv run talos proposals eval --file ` - Evaluate a proposal from file +- `uv run talos twitter get-user-prompt ` - Get Twitter user persona prompt + +## Architecture Overview + +Talos is an AI protocol owner system built with a sophisticated multi-agent architecture: + +### Core Components + +**MainAgent** (`src/talos/core/main_agent.py`): The top-level orchestrating agent that: +- Delegates tasks using a Router to different services and skills +- Manages scheduled jobs for autonomous execution +- Integrates with Hypervisor for action approval +- Handles dataset management for contextual information retrieval + +**Agent Base Class** (`src/talos/core/agent.py`): Foundation for all agents providing: +- LangChain integration with chat models +- Message history management and memory persistence +- Tool management and supervised tool execution +- Context building with dataset search integration + +**Hypervisor** (`src/talos/hypervisor/hypervisor.py`): Security and governance layer that: +- Monitors and approves/denies agent actions +- Uses dedicated prompts to evaluate action safety +- Maintains agent history for context-aware decisions + +### Key Architectural Patterns + +**Router Pattern** (`src/talos/core/router.py`): Delegates tasks to appropriate services and skills based on request classification. + +**Skill-based Architecture** (`src/talos/skills/`): Modular capabilities including: +- ProposalsSkill - Governance proposal evaluation +- TwitterSentimentSkill - Social media sentiment analysis +- CryptographySkill - Cryptographic operations +- TwitterInfluenceSkill - Influence analysis + +**Service Layer** (`src/talos/services/`): Abstract and concrete implementations for external integrations: +- GitHub service for repository management +- Twitter service for social media interaction +- Proposal service for governance workflows +- Yield management for DeFi operations + +**Tool Management** (`src/talos/tools/`): Extensible tool system with: +- SupervisedTool base class requiring hypervisor approval +- Platform-specific tools (Twitter, GitHub, DexScreener) +- Document loading and search capabilities +- Memory management tools + +### Data Flow + +1. User input → MainAgent → Router → appropriate Service/Skill +2. Actions requiring approval → Hypervisor → approval/denial +3. Context building → DatasetManager → relevant document retrieval +4. Tool execution → ToolManager → supervised execution + +## Code Style Guidelines + +- Use `from __future__ import annotations` for forward references +- Prefer `list` and `dict` over `List` and `Dict` in type hints +- Use `model_post_init` for Pydantic model initialization logic +- Set `arbitrary_types_allowed=True` in ConfigDict for LangChain integration +- Default LLM model is `gpt-4o` +- Line length limit is 120 characters (configured in ruff) +- All function signatures require type hints + +## Environment Variables + +Required for full functionality: +- `OPENAI_API_KEY` - OpenAI API access +- `GITHUB_TOKEN` - GitHub API access +- `PINATA_API_KEY` / `PINATA_SECRET_API_KEY` - IPFS storage +- Additional service-specific keys as needed +- Always run scripts/run_checks.sh after performing a large code change. +- ⏺ Code commenting principle: Only include comments that explain why something is done or provide context that isn't obvious from the code itself - avoid comments that simply restate what the code is doing, as these add noise without value. \ No newline at end of file diff --git a/experimental/__init__.py b/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/archon/__init__.py b/experimental/archon/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/archon/src/__init__.py b/experimental/archon/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/archon/src/graph_executor.py b/experimental/archon/src/graph_executor.py new file mode 100644 index 00000000..67a4603c --- /dev/null +++ b/experimental/archon/src/graph_executor.py @@ -0,0 +1,111 @@ +""" +LangGraph Executor Module + +Provides GraphExecutor class for loading and executing stored graphs from IPFS. +This encapsulates the process of retrieving a graph by its IPFS hash and running +it with provided input, handling both sync and async execution automatically. +""" + +from __future__ import annotations + +from typing import Any, Union + +from langgraph.graph.graph import CompiledGraph +from pydantic import BaseModel + +from .graph_loader import GraphLoader + + +class LoadedGraph: + """Container for a loaded graph with its state class.""" + + def __init__(self, compiled_graph: CompiledGraph, state_class: type[BaseModel]): + self.compiled_graph = compiled_graph + self.state_class = state_class + + async def execute(self, input_state: Union[BaseModel, dict[str, Any]]) -> dict[str, Any]: + """Execute the graph with type-aware input handling.""" + # Convert Pydantic model to dict if needed (LangGraph expects dict) + if isinstance(input_state, BaseModel): + state_dict = input_state.model_dump() + else: + state_dict = input_state + + return await self.compiled_graph.ainvoke(state_dict) + + def create_state(self, **kwargs) -> BaseModel: + """Create a properly typed state instance.""" + return self.state_class(**kwargs) + + def validate_result(self, result: dict[str, Any]) -> BaseModel: + """Validate and convert result dict back to typed state.""" + return self.state_class.model_validate(result) + + +class GraphExecutor: + """ + Executes stored LangGraph workflows loaded from IPFS with enhanced type safety. + + Provides both simple execution and advanced type-aware methods. + + Example: + >>> executor = GraphExecutor() + >>> # Simple execution + >>> result = await executor.execute_graph(ipfs_hash, {"input": "data"}) + >>> + >>> # Type-aware execution + >>> loaded = executor.load_graph(ipfs_hash) + >>> state = loaded.create_state(input_text="Hello world!") + >>> result_dict = await loaded.execute(state) + >>> typed_result = loaded.validate_result(result_dict) + """ + + def __init__(self) -> None: + """Initialize the GraphExecutor with a GraphLoader instance.""" + self.loader = GraphLoader() + + def load_graph(self, ipfs_hash: str) -> LoadedGraph: + """ + Load a graph and expose its state class for type-safe usage. + + Args: + ipfs_hash: IPFS hash of the stored graph definition + + Returns: + LoadedGraph with both compiled graph and state class + + Raises: + ImportError: If graph functions or state class cannot be loaded + ValueError: If graph structure is invalid + """ + # Load the stored definition first to get state class + stored_definition = self.loader.retrieve_from_ipfs(ipfs_hash) + state_class = self.loader._load_class_from_reference(stored_definition.state_schema.class_reference) + + # Load the compiled graph + compiled_graph = self.loader.load_graph(ipfs_hash) + + return LoadedGraph(compiled_graph, state_class) + + async def execute_graph(self, ipfs_hash: str, input_state: dict[str, Any]) -> dict[str, Any]: + """ + Load and execute a graph from IPFS with the given input state. + + Args: + ipfs_hash: IPFS hash of the stored graph definition + input_state: Input state dictionary matching the graph's state schema + + Returns: + Final state dictionary after graph execution + + Raises: + ImportError: If graph functions cannot be loaded + ValueError: If graph structure is invalid + Exception: If graph execution fails + + Note: + This method always uses async execution (ainvoke) which works + universally for sync, async, and mixed graphs. + """ + loaded_graph = self.load_graph(ipfs_hash) + return await loaded_graph.execute(input_state) diff --git a/experimental/archon/src/graph_loader.py b/experimental/archon/src/graph_loader.py new file mode 100644 index 00000000..c97040bd --- /dev/null +++ b/experimental/archon/src/graph_loader.py @@ -0,0 +1,431 @@ +""" +LangGraph IPFS Storage Module + +Provides GraphLoader class for serializing LangGraph workflows and storing them +on IPFS via Pinata for decentralized, governance-ready AI agent management. + +This module combines: +- LangGraph's native graph serialization +- Pydantic models for validation and structure +- IPFS storage via Pinata for immutable, content-addressed storage +""" + +from __future__ import annotations + +import importlib +import os +import tempfile +from typing import Any, Awaitable, Callable, Hashable, Union + +import requests +from langgraph.graph import END, START, StateGraph +from langgraph.graph.graph import CompiledGraph +from pinata_python.pinning import Pinning + +from .graph_models import ( + ConditionalEdgeDefinition, + ExecutionConfig, + GraphEdgeDefinition, + GraphMetadata, + GraphNodeDefinition, + SerializableGraphDefinition, + StateChannelDefinition, + StateSchema, + StoredGraphDefinition, +) + + +class GraphLoader: + """ + Manages serialization, IPFS storage, and deserialization of LangGraph workflows. + + Combines LangGraph's native serialization with Pydantic validation and + IPFS storage via Pinata for decentralized, governance-ready AI workflows. + + Example: + >>> loader = GraphLoader() + >>> ipfs_hash = loader.save_graph(compiled_graph, "my_workflow", "Description") + >>> recreated_graph = loader.load_graph(ipfs_hash) + """ + + def __init__(self) -> None: + """Initialize the GraphLoader.""" + pass + + def serialize_graph_from_builder( + self, + state_graph: StateGraph, + name: str, + description: str, + created_by: str = "experimental_poc", + ) -> str: + """ + Serialize graph from StateGraph builder (before compilation) to capture serializable references. + + Args: + state_graph: The StateGraph builder (before calling .compile()) + name: Name for the workflow + description: Description of what the workflow does + created_by: Who created this workflow + + Returns: + JSON string representation of the serializable graph definition + + This approach extracts function references and graph structure from the StateGraph + builder before compilation, allowing for proper deserialization. + """ + # Extract serializable node definitions + nodes = [] + for node_name, node_spec in state_graph.nodes.items(): + # Extract function information from StateNodeSpec + if not hasattr(node_spec, "runnable"): + raise ValueError(f"Node '{node_name}' missing expected 'runnable' attribute") + + runnable = node_spec.runnable + + # For sync functions: use runnable.func + # For async functions: runnable.func is None, use runnable.afunc + if hasattr(runnable, "func") and runnable.func is not None: + func = runnable.func + elif hasattr(runnable, "afunc") and runnable.afunc is not None: + func = runnable.afunc + else: + raise ValueError( + f"Node '{node_name}' has no valid function reference. " + f"Expected runnable.func or runnable.afunc to be set" + ) + + if not hasattr(func, "__module__") or not hasattr(func, "__name__"): + raise ValueError(f"Function for node '{node_name}' missing __module__ or __name__ attributes") + + function_reference = f"{func.__module__}:{func.__name__}" + + nodes.append(GraphNodeDefinition(name=node_name, function_reference=function_reference)) + + # Extract simple edges + edges = [] + for edge_tuple in getattr(state_graph, "edges", set()): + source, target = edge_tuple + edges.append(GraphEdgeDefinition(source=source, target=target)) + + # Extract conditional edges from branches + conditional_edges = [] + branches = getattr(state_graph, "branches", {}) + for source_node, branch_dict in branches.items(): + # The branch dict has function names as keys + for func_name, branch_obj in branch_dict.items(): + if not hasattr(branch_obj, "path"): + raise ValueError(f"Conditional edge from '{source_node}' missing expected 'path' attribute") + + # Extract condition function similar to node functions + if hasattr(branch_obj.path, "func") and branch_obj.path.func is not None: + condition_func = branch_obj.path.func + elif hasattr(branch_obj.path, "afunc") and branch_obj.path.afunc is not None: + condition_func = branch_obj.path.afunc + else: + raise ValueError( + f"Conditional edge from '{source_node}' has no valid condition function. " + f"Expected branch_obj.path.func or branch_obj.path.afunc to be set" + ) + + if not hasattr(condition_func, "__module__") or not hasattr(condition_func, "__name__"): + raise ValueError( + f"Condition function for edge from '{source_node}' missing __module__ or __name__ attributes" + ) + + if not hasattr(branch_obj, "ends"): + raise ValueError(f"Conditional edge from '{source_node}' missing expected 'ends' attribute") + + condition_function_reference = f"{condition_func.__module__}:{condition_func.__name__}" + conditional_edges.append( + ConditionalEdgeDefinition( + source_node=source_node, + condition_function_reference=condition_function_reference, + target_mapping=branch_obj.ends, + ) + ) + + # Extract state channel information + state_channels: list[StateChannelDefinition] = [] + if hasattr(state_graph, "state_schema"): + # This would need to be implemented based on StateGraph internals + pass + + # Create serializable graph definition + serializable_def = SerializableGraphDefinition( + nodes=nodes, + edges=edges, + conditional_edges=conditional_edges, + state_channels=state_channels, + state_type_name=( + state_graph.state_schema.__name__ if hasattr(state_graph, "state_schema") else "UnknownState" + ), + ) + + # Create state schema from the StateGraph's state schema + state_schema = self._extract_state_schema(state_graph) + + # Create complete stored definition + stored_definition = StoredGraphDefinition( + metadata=GraphMetadata( + name=name, + description=description, + created_by=created_by, + ), + graph_definition=serializable_def, + state_schema=state_schema, + execution_config=ExecutionConfig(), + ) + + return stored_definition.model_dump_json(indent=2) + + def _extract_state_schema(self, state_graph: StateGraph) -> StateSchema: + """ + Extract Pydantic state schema information from StateGraph builder. + + Args: + state_graph: StateGraph builder instance + + Returns: + StateSchema with complete Pydantic model information + + Raises: + ValueError: If state schema is not a Pydantic BaseModel + """ + # Get the state schema class from the StateGraph + if not hasattr(state_graph, "schemas") or not state_graph.schemas: + raise ValueError("StateGraph must have a state schema defined") + + # Get the first (and should be only) schema class + schema_class = next(iter(state_graph.schemas.keys())) + + # Verify it's a Pydantic BaseModel + try: + from pydantic import BaseModel + + if not issubclass(schema_class, BaseModel): + raise ValueError(f"State schema must be a Pydantic BaseModel, got {type(schema_class)}") + except (TypeError, ImportError) as e: + raise ValueError(f"Invalid state schema class: {e}") + + # Extract module and class information + if not hasattr(schema_class, "__module__") or not hasattr(schema_class, "__name__"): + raise ValueError("State schema class missing __module__ or __name__ attributes") + + class_reference = f"{schema_class.__module__}:{schema_class.__name__}" + + return StateSchema( + name=schema_class.__name__, + class_reference=class_reference, + description=f"State schema for {schema_class.__name__}", + ) + + def store_to_ipfs(self, graph_json: str) -> str: + """ + Store serialized graph on IPFS via Pinata and return hash. + + Args: + graph_json: JSON string representation of the graph + + Returns: + IPFS hash of the stored content + """ + # Get credentials from environment when needed + api_key = os.getenv("PINATA_API_KEY") + secret_key = os.getenv("PINATA_SECRET_API_KEY") + + if not api_key or not secret_key: + raise ValueError("PINATA_API_KEY and PINATA_SECRET_API_KEY environment variables required for IPFS storage") + + # Initialize Pinata client + pinata = Pinning(PINATA_API_KEY=api_key, PINATA_API_SECRET=secret_key) + + # Pin JSON content to IPFS via Pinata + # First save to temp file since pinata-python expects a file path + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write(graph_json) + temp_path: str = f.name + + try: + response: dict[str, Any] = pinata.pin_file_to_ipfs(temp_path) + return response["IpfsHash"] + finally: + os.unlink(temp_path) + + def retrieve_from_ipfs(self, ipfs_hash: str) -> StoredGraphDefinition: + """ + Retrieve and deserialize graph from IPFS via Pinata gateway. + + Args: + ipfs_hash: IPFS hash of the stored graph definition + + Returns: + StoredGraphDefinition object with validated data + """ + gateway_url = f"https://gateway.pinata.cloud/ipfs/{ipfs_hash}" + response = requests.get(gateway_url) + response.raise_for_status() + + return StoredGraphDefinition.model_validate(response.json()) + + def recreate_graph_from_definition(self, stored_definition: StoredGraphDefinition) -> StateGraph: + """ + Recreate a StateGraph from its serializable definition. + + Args: + stored_definition: StoredGraphDefinition with SerializableGraphDefinition + + Returns: + Recreated StateGraph ready for compilation + + Note: + This implementation dynamically imports and loads functions based on + the serialized function references (module + name). + """ + + # Dynamically load the Pydantic model class + state_class = self._load_class_from_reference(stored_definition.state_schema.class_reference) + + # Verify it's a Pydantic BaseModel + try: + from pydantic import BaseModel + + if not issubclass(state_class, BaseModel): + raise ValueError(f"Loaded state class must be a Pydantic BaseModel, got {type(state_class)}") + except (TypeError, ImportError) as e: + raise ValueError(f"Invalid loaded state class: {e}") + + # Create StateGraph with the proper Pydantic model + builder = StateGraph(state_class) + + for node_def in stored_definition.graph_definition.nodes: + try: + func = self._load_function_from_reference(node_def.function_reference) + builder.add_node(node_def.name, func) + except Exception as e: + raise ImportError(f"Failed to load node {node_def.name}: {e}") + + for edge_def in stored_definition.graph_definition.edges: + source = START if edge_def.source == "__start__" else edge_def.source + target = END if edge_def.target == "__end__" else edge_def.target + builder.add_edge(source, target) + for cond_edge_def in stored_definition.graph_definition.conditional_edges: + try: + condition_func = self._load_function_from_reference(cond_edge_def.condition_function_reference) + target_mapping: dict[Hashable, str] = {} + for key, value in cond_edge_def.target_mapping.items(): + target_value = END if value == "__end__" else value + target_mapping[key] = target_value + builder.add_conditional_edges(cond_edge_def.source_node, condition_func, target_mapping) + except Exception as e: + raise ImportError(f"Failed to load conditional edge from {cond_edge_def.source_node}: {e}") + + return builder + + def _load_function_from_reference( + self, function_reference: str + ) -> Union[Callable[..., Any], Callable[..., Awaitable[Any]]]: + """ + Dynamically load a function from a module:function reference string. + + Args: + function_reference: Function reference in format "module.path:function_name" + + Returns: + The loaded function object (can be sync or async) + """ + try: + module_name, function_name = function_reference.split(":", 1) + module = importlib.import_module(module_name) + return getattr(module, function_name) + except (ImportError, AttributeError, ValueError) as e: + raise ImportError(f"Could not load function from reference '{function_reference}': {e}") + + def _load_class_from_reference(self, class_reference: str) -> type: + """ + Dynamically load a class from a module:class reference string. + + Args: + class_reference: Class reference in format "module.path:ClassName" + + Returns: + The loaded class object + + Raises: + ImportError: If the class cannot be loaded + """ + try: + module_name, class_name = class_reference.split(":", 1) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + if not isinstance(cls, type): + raise ValueError(f"'{class_name}' is not a class") + return cls + except (ImportError, AttributeError, ValueError) as e: + raise ImportError(f"Could not load class from reference '{class_reference}': {e}") + + def recreate_graph(self, stored_definition: StoredGraphDefinition) -> CompiledGraph: + """ + Recreate and compile a LangGraph from its stored definition. + + Args: + stored_definition: StoredGraphDefinition object from IPFS + + Returns: + Recreated CompiledGraph ready for execution + """ + state_graph = self.recreate_graph_from_definition(stored_definition) + return state_graph.compile() + + def save_graph_from_builder( + self, + state_graph: StateGraph, + name: str, + description: str, + created_by: str = "experimental_poc", + ) -> str: + """ + High-level method: serialize StateGraph builder and store to IPFS in one operation. + + Args: + state_graph: The StateGraph builder (before compilation) + name: Name for the workflow + description: Description of what the workflow does + created_by: Who created this workflow + + Returns: + IPFS hash of the stored graph definition + """ + graph_json = self.serialize_graph_from_builder(state_graph, name, description, created_by) + return self.store_to_ipfs(graph_json) + + def load_graph(self, ipfs_hash: str) -> CompiledGraph: + """ + High-level method: retrieve from IPFS and recreate graph in one operation. + + Args: + ipfs_hash: IPFS hash of the stored graph definition + + Returns: + Recreated CompiledGraph ready for execution + + Note: + The returned graph can contain both sync and async functions. + Use `await graph.ainvoke(input)` for execution - this works + universally for pure sync, pure async, and mixed graphs. + """ + stored_definition = self.retrieve_from_ipfs(ipfs_hash) + return self.recreate_graph(stored_definition) + + def get_graph_info(self, ipfs_hash: str) -> GraphMetadata: + """ + Get metadata about a stored graph without recreating it. + + Args: + ipfs_hash: IPFS hash of the stored graph definition + + Returns: + GraphMetadata with information about the stored graph + """ + stored_definition = self.retrieve_from_ipfs(ipfs_hash) + return stored_definition.metadata diff --git a/experimental/archon/src/graph_models.py b/experimental/archon/src/graph_models.py new file mode 100644 index 00000000..bd172515 --- /dev/null +++ b/experimental/archon/src/graph_models.py @@ -0,0 +1,229 @@ +""" +Pydantic models for LangGraph workflow storage and serialization. + +This module defines the data structures used for storing LangGraph workflows +on IPFS with proper validation and type safety. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from langchain_core.runnables.graph import Edge, Node +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class Immutable: + """Mixin class that makes Pydantic models immutable using frozen=True.""" + + model_config = ConfigDict(frozen=True) + + +class GraphModel(BaseModel, Immutable): + """ + Base class for all graph-related models. + """ + + @staticmethod + def _validate_function_reference(value: str) -> str: + """Validate function reference format and security constraints.""" + if not isinstance(value, str): + raise ValueError("Function reference must be a string") + + if ":" not in value: + raise ValueError("Function reference must be in format 'module.path:function_name'") + + if value.count(":") != 1: + raise ValueError("Function reference must have exactly one ':' separator") + + module_path, function_name = value.split(":", 1) + + if not module_path or not function_name: + raise ValueError("Both module path and function name must be non-empty") + + # Security: No relative path components + if ".." in module_path: + raise ValueError("Relative path components (..) are not allowed in module path") + + # Security: Must start with 'experimental' for now + if not module_path.startswith("experimental"): + raise ValueError("Module path must start with 'experimental' for security") + + # Validate function name is a valid Python identifier + if not function_name.isidentifier(): + raise ValueError(f"Function name '{function_name}' is not a valid Python identifier") + + return value + + +class GraphMetadata(GraphModel): + """Metadata for a stored graph definition.""" + + name: str + version: str = "1.0.0" + description: str + created_by: str + created_at: datetime = Field(default_factory=datetime.now) + langgraph_version: str = "0.2.60" + + +# Note: We create Pydantic-compatible wrappers for LangGraph's Node and Edge types +# to ensure compatibility with our immutable model architecture while maintaining +# compatibility with LangGraph's structure. + + +class LangGraphNode(GraphModel): + """Pydantic-compatible wrapper for LangGraph's Node type.""" + + id: str + name: str + data: Any + metadata: dict[str, Any] | None = None + + @classmethod + def from_langgraph_node(cls, node: Node) -> "LangGraphNode": + """Create from a LangGraph Node.""" + return cls(id=node.id, name=node.name, data=node.data, metadata=node.metadata) + + def to_langgraph_node(self) -> Node: + """Convert to a LangGraph Node.""" + return Node(id=self.id, name=self.name, data=self.data, metadata=self.metadata) + + +class LangGraphEdge(GraphModel): + """Pydantic-compatible wrapper for LangGraph's Edge type.""" + + source: str + target: str + data: Any | None = None + conditional: bool = False + + @classmethod + def from_langgraph_edge(cls, edge: Edge) -> "LangGraphEdge": + """Create from a LangGraph Edge.""" + return cls( + source=edge.source, + target=edge.target, + data=edge.data, + conditional=edge.conditional, + ) + + def to_langgraph_edge(self) -> Edge: + """Convert to a LangGraph Edge.""" + return Edge( + source=self.source, + target=self.target, + data=self.data, + conditional=self.conditional, + ) + + +class GraphNodeDefinition(GraphModel): + """Serializable representation of a graph node before compilation.""" + + name: str + function_reference: str = Field(description="Format: 'module.path:function_name'") + metadata: dict[str, Any] | None = None + input_type: str | None = None + + @field_validator("function_reference") + @classmethod + def validate_function_reference(cls, v: str) -> str: + return cls._validate_function_reference(v) + + +class GraphEdgeDefinition(GraphModel): + """Serializable representation of a simple graph edge.""" + + source: str + target: str + + +class ConditionalEdgeDefinition(GraphModel): + """Serializable representation of a conditional edge.""" + + source_node: str + condition_function_reference: str = Field(description="Format: 'module.path:function_name'") + target_mapping: dict[str, str] + + @field_validator("condition_function_reference") + @classmethod + def validate_condition_function_reference(cls, v: str) -> str: + return cls._validate_function_reference(v) + + +class StateChannelDefinition(GraphModel): + """Serializable representation of a state channel.""" + + name: str + channel_type: str + default_value: Any | None = None + + +class SerializableGraphDefinition(GraphModel): + """Complete serializable graph definition from StateGraph builder (before compilation).""" + + nodes: list[GraphNodeDefinition] + edges: list[GraphEdgeDefinition] + conditional_edges: list[ConditionalEdgeDefinition] + state_channels: list[StateChannelDefinition] + state_type_name: str + + +class LangGraphDefinition(GraphModel): + """Structured representation of LangGraph's native to_json() output using wrapper types.""" + + nodes: list[LangGraphNode] + edges: list[LangGraphEdge] + + +class StateSchema(GraphModel): + """Pydantic-only state schema representation.""" + + name: str + class_reference: str = Field(description="Format: 'module.path:ClassName'") + description: str | None = None + + @field_validator("class_reference") + @classmethod + def validate_class_reference(cls, v: str) -> str: + return cls._validate_function_reference(v) + + +class ExecutionConfig(GraphModel): + """Configuration options for graph execution.""" + + checkpointer: str | None = None + debug: bool = False + stream_mode: str = "values" + max_iterations: int | None = None + recursion_limit: int | None = None + + +class StoredGraphDefinition(GraphModel): + """Complete stored graph definition using serializable pre-compilation data.""" + + type: str = "serializable_graph_definition" + version: str = "2.0.0" # Updated version for new approach + metadata: GraphMetadata + + # Serializable graph definition from StateGraph builder (before compilation) + graph_definition: SerializableGraphDefinition = Field( + description="Complete serializable graph structure from StateGraph builder" + ) + + # LangGraph's native representation (optional, for reference) + langgraph_definition: LangGraphDefinition | None = Field( + default=None, + description="Optional: LangGraph's compiled representation for reference", + ) + + # Additional info that LangGraph doesn't capture + state_schema: StateSchema = Field(description="Structured information about the state schema") + + # Execution configuration + execution_config: ExecutionConfig = Field( + default_factory=ExecutionConfig, + description="Structured execution configuration options", + ) diff --git a/experimental/archon/tests/__init__.py b/experimental/archon/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/archon/tests/conftest.py b/experimental/archon/tests/conftest.py new file mode 100644 index 00000000..46e017b1 --- /dev/null +++ b/experimental/archon/tests/conftest.py @@ -0,0 +1,81 @@ +""" +Shared test fixtures for Archon graph storage tests. +""" + +from __future__ import annotations + +import pytest +from langgraph.graph import END, START, StateGraph +from pydantic import BaseModel, Field + + +class SentimentState(BaseModel): + """Pydantic state model for sentiment analysis workflow.""" + + input_text: str = Field(description="Input text to analyze") + sentiment_score: float = Field(default=0.0, description="Sentiment score between 0 and 1") + decision: str = Field(default="", description="Decision based on sentiment analysis") + final_action: str = Field(default="", description="Final action taken based on decision") + + +def analyze_sentiment(state: SentimentState) -> SentimentState: + """Mock sentiment analysis node.""" + text = state.input_text + mock_score = 0.8 if "good" in text.lower() else 0.3 + return state.model_copy(update={"sentiment_score": mock_score}) + + +async def make_decision(state: SentimentState) -> SentimentState: + """Decision node based on sentiment score.""" + decision = "positive" if state.sentiment_score > 0.6 else "negative" + return state.model_copy(update={"decision": decision}) + + +def take_action(state: SentimentState) -> SentimentState: + """Action node for positive sentiment.""" + action = f"Taking positive action based on score {state.sentiment_score}" + return state.model_copy(update={"final_action": action}) + + +def no_action(state: SentimentState) -> SentimentState: + """No action node for negative sentiment.""" + action = f"No action needed, score too low: {state.sentiment_score}" + return state.model_copy(update={"final_action": action}) + + +def should_take_action(state: SentimentState) -> str: + """Conditional edge function to decide next node.""" + return "action" if state.decision == "positive" else "no_action" + + +@pytest.fixture +def sentiment_graph_builder() -> StateGraph: + """ + Create the sentiment analysis graph from the POC. + Returns the StateGraph builder (before compilation). + """ + builder = StateGraph(SentimentState) + + # Add nodes + builder.add_node("analyze", analyze_sentiment) + builder.add_node("decide", make_decision) + builder.add_node("action", take_action) + builder.add_node("no_action", no_action) + + # Add edges + builder.add_edge(START, "analyze") + builder.add_edge("analyze", "decide") + builder.add_conditional_edges("decide", should_take_action, {"action": "action", "no_action": "no_action"}) + builder.add_edge("action", END) + builder.add_edge("no_action", END) + + return builder + + +@pytest.fixture +def mock_ipfs_storage(): + """ + Create an in-memory storage to mock IPFS. + Returns a dictionary that will store graph definitions by hash. + """ + return {} diff --git a/experimental/archon/tests/test_graph_executor.py b/experimental/archon/tests/test_graph_executor.py new file mode 100644 index 00000000..a05c8f92 --- /dev/null +++ b/experimental/archon/tests/test_graph_executor.py @@ -0,0 +1,133 @@ +""" +Tests for GraphExecutor - Graph execution from IPFS-stored definitions. +""" + +from __future__ import annotations + +import json + +import pytest + +from experimental.archon.src.graph_executor import GraphExecutor +from experimental.archon.src.graph_models import StoredGraphDefinition + + +class TestGraphExecutorWithMockedIPFS: + """Test GraphExecutor with mocked IPFS storage.""" + + @pytest.mark.asyncio + async def test_execute_stored_graph(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + """Test that we can execute a stored graph through the executor.""" + + def mock_store(graph_json: str) -> str: + fake_hash = f"Qm{hash(graph_json)}" + mock_ipfs_storage[fake_hash] = graph_json + return fake_hash + + def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: + json_data = mock_ipfs_storage[ipfs_hash] + return StoredGraphDefinition.model_validate(json.loads(json_data)) + + executor = GraphExecutor() + + # Use monkeypatch to replace methods - this is type-safe + monkeypatch.setattr(executor.loader, "store_to_ipfs", mock_store) + monkeypatch.setattr(executor.loader, "retrieve_from_ipfs", mock_retrieve) + + # Step 1: Store the graph using the underlying loader + ipfs_hash = executor.loader.save_graph_from_builder( + sentiment_graph_builder, + name="executor_test_workflow", + description="Test workflow for executor", + ) + + # Step 2: Execute the stored graph with positive sentiment input + positive_input = { + "input_text": "This is good news!", + "sentiment_score": 0.0, + "decision": "", + "final_action": "", + } + + positive_result = await executor.execute_graph(ipfs_hash, positive_input) + + # Verify positive sentiment path + assert positive_result["input_text"] == "This is good news!" + assert positive_result["sentiment_score"] == 0.8 # "good" triggers positive score + assert positive_result["decision"] == "positive" + assert "Taking positive action" in positive_result["final_action"] + assert "0.8" in positive_result["final_action"] + + # Step 3: Execute with negative sentiment input + negative_input = { + "input_text": "This is terrible news!", + "sentiment_score": 0.0, + "decision": "", + "final_action": "", + } + + negative_result = await executor.execute_graph(ipfs_hash, negative_input) + + # Verify negative sentiment path + assert negative_result["input_text"] == "This is terrible news!" + assert negative_result["sentiment_score"] == 0.3 # "terrible" doesn't contain "good" + assert negative_result["decision"] == "negative" + assert "No action needed" in negative_result["final_action"] + assert "0.3" in negative_result["final_action"] + + @pytest.mark.asyncio + async def test_type_aware_execution(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + """Test the new type-aware LoadedGraph functionality.""" + + def mock_store(graph_json: str) -> str: + fake_hash = f"Qm{hash(graph_json)}" + mock_ipfs_storage[fake_hash] = graph_json + return fake_hash + + def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: + json_data = mock_ipfs_storage[ipfs_hash] + return StoredGraphDefinition.model_validate(json.loads(json_data)) + + executor = GraphExecutor() + monkeypatch.setattr(executor.loader, "store_to_ipfs", mock_store) + monkeypatch.setattr(executor.loader, "retrieve_from_ipfs", mock_retrieve) + + # Store the graph + ipfs_hash = executor.loader.save_graph_from_builder( + sentiment_graph_builder, + name="type_aware_test", + description="Test type-aware execution", + ) + + # Load the graph with type information + loaded_graph = executor.load_graph(ipfs_hash) + + # Verify we have access to the state class + assert hasattr(loaded_graph, "state_class") + assert loaded_graph.state_class.__name__ == "SentimentState" + + # Create typed input using the state class + typed_input = loaded_graph.create_state( + input_text="This is good news!", + sentiment_score=0.0, + decision="", + final_action="", + ) + + # Verify the created state is properly typed + from experimental.archon.tests.conftest import SentimentState + + assert isinstance(typed_input, SentimentState) + assert typed_input.input_text == "This is good news!" + + # Execute with the Pydantic model directly + result = await loaded_graph.execute(typed_input) + + # Validate and convert result back to typed state + typed_result = loaded_graph.validate_result(result) + + assert isinstance(typed_result, SentimentState) + assert typed_result.input_text == "This is good news!" + assert typed_result.sentiment_score == 0.8 # "good" triggers positive score + assert typed_result.decision == "positive" + assert "Taking positive action" in typed_result.final_action diff --git a/experimental/archon/tests/test_graph_loader.py b/experimental/archon/tests/test_graph_loader.py new file mode 100644 index 00000000..41a3ce16 --- /dev/null +++ b/experimental/archon/tests/test_graph_loader.py @@ -0,0 +1,188 @@ +""" +Tests for GraphLoader - Graph storage and retrieval from IPFS. +""" + +from __future__ import annotations + +import json + +import pytest + +from experimental.archon.src.graph_loader import GraphLoader +from experimental.archon.src.graph_models import StoredGraphDefinition + +from .conftest import analyze_sentiment, make_decision, no_action, take_action + + +class TestGraphLoaderWithMockedIPFS: + """Test GraphLoader with mocked IPFS storage.""" + + def test_end_to_end_save_load(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + """Test complete workflow: save graph to 'IPFS', retrieve it, and execute.""" + + def mock_store(graph_json: str) -> str: + fake_hash = f"Qm{hash(graph_json)}" + mock_ipfs_storage[fake_hash] = graph_json + return fake_hash + + def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: + json_data = mock_ipfs_storage[ipfs_hash] + return StoredGraphDefinition.model_validate(json.loads(json_data)) + + # Create loader - no credentials needed since we mock the IPFS methods + loader = GraphLoader() + monkeypatch.setattr(loader, "store_to_ipfs", mock_store) + monkeypatch.setattr(loader, "retrieve_from_ipfs", mock_retrieve) + + # Step 1: Save the graph + ipfs_hash = loader.save_graph_from_builder( + sentiment_graph_builder, + name="sentiment_workflow", + description="Analyzes sentiment and takes conditional actions", + ) + + assert ipfs_hash in mock_ipfs_storage + + # Step 2: Retrieve the graph definition + retrieved_def = loader.retrieve_from_ipfs(ipfs_hash) + + assert retrieved_def.metadata.name == "sentiment_workflow" + assert retrieved_def.metadata.description == "Analyzes sentiment and takes conditional actions" + assert len(retrieved_def.graph_definition.nodes) == 4 + + # Step 3: Verify the serialized structure + nodes = retrieved_def.graph_definition.nodes + node_names = {node.name for node in nodes} + assert node_names == {"analyze", "decide", "action", "no_action"} + + # Verify edges + edges = retrieved_def.graph_definition.edges + assert len(edges) >= 3 + + # Verify conditional edges + cond_edges = retrieved_def.graph_definition.conditional_edges + assert len(cond_edges) == 1 + assert cond_edges[0].source_node == "decide" + assert cond_edges[0].target_mapping == { + "action": "action", + "no_action": "no_action", + } + + # Step 4: Verify the function references have correct structure + # Function references should be in format "module.path:function_name" + + # Create a mapping of expected function references + # The module will be __main__ when running tests or test_graph_loader + expected_references = { + "analyze": ("analyze_sentiment", analyze_sentiment), + "decide": ("make_decision", make_decision), + "action": ("take_action", take_action), + "no_action": ("no_action", no_action), + } + + # Verify each node's function reference + for node in nodes: + assert ":" in node.function_reference, ( + f"Invalid reference format for {node.name}: {node.function_reference}" + ) + + module_path, func_name = node.function_reference.split(":", 1) + expected_func_name, expected_func = expected_references[node.name] + + # Verify function name matches + assert func_name == expected_func_name, ( + f"Function name mismatch for {node.name}: expected {expected_func_name}, got {func_name}" + ) + + # Verify module path is reasonable (should be __main__ or test module) + assert module_path in [ + "__main__", + "test_graph_loader", + "experimental.archon.tests.test_graph_loader", + "experimental.archon.tests.conftest", + ], f"Unexpected module path for {node.name}: {module_path}" + + # Verify conditional edge function reference + assert len(cond_edges) == 1 + cond_edge = cond_edges[0] + assert ":" in cond_edge.condition_function_reference + + cond_module, cond_func = cond_edge.condition_function_reference.split(":", 1) + assert cond_func == "should_take_action", f"Conditional function name mismatch: {cond_func}" + assert cond_module in [ + "__main__", + "test_graph_loader", + "experimental.archon.tests.test_graph_loader", + "experimental.archon.tests.conftest", + ], f"Unexpected module for conditional function: {cond_module}" + + # Verify that all function references are complete and valid + all_refs = [node.function_reference for node in nodes] + [cond_edge.condition_function_reference] + for ref in all_refs: + # Each reference should have exactly one colon separator + assert ref.count(":") == 1, f"Invalid reference format: {ref}" + # Neither part should be empty + module_part, func_part = ref.split(":") + assert module_part, f"Empty module in reference: {ref}" + assert func_part, f"Empty function name in reference: {ref}" + + @pytest.mark.asyncio + async def test_save_load_and_execute_graph(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + """Test that we can save a graph, load it back, and execute it successfully.""" + + def mock_store(graph_json: str) -> str: + fake_hash = f"Qm{hash(graph_json)}" + mock_ipfs_storage[fake_hash] = graph_json + return fake_hash + + def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: + json_data = mock_ipfs_storage[ipfs_hash] + return StoredGraphDefinition.model_validate(json.loads(json_data)) + + loader = GraphLoader() + monkeypatch.setattr(loader, "store_to_ipfs", mock_store) + monkeypatch.setattr(loader, "retrieve_from_ipfs", mock_retrieve) + + # Step 1: Save the graph to IPFS + ipfs_hash = loader.save_graph_from_builder( + sentiment_graph_builder, + name="executable_workflow", + description="Test executable workflow", + ) + + # Step 2: Load the graph back from IPFS + recreated_graph = loader.load_graph(ipfs_hash) + + # Step 3: Execute the recreated graph with positive sentiment input + positive_input = { + "input_text": "This is good news!", + "sentiment_score": 0.0, + "decision": "", + "final_action": "", + } + + positive_result = await recreated_graph.ainvoke(positive_input) + + # Verify positive sentiment path + assert positive_result["input_text"] == "This is good news!" + assert positive_result["sentiment_score"] == 0.8 # "good" triggers positive score + assert positive_result["decision"] == "positive" + assert "Taking positive action" in positive_result["final_action"] + assert "0.8" in positive_result["final_action"] + + # Step 4: Execute with negative sentiment input + negative_input = { + "input_text": "This is terrible news!", + "sentiment_score": 0.0, + "decision": "", + "final_action": "", + } + + negative_result = await recreated_graph.ainvoke(negative_input) + + # Verify negative sentiment path + assert negative_result["input_text"] == "This is terrible news!" + assert negative_result["sentiment_score"] == 0.3 # "terrible" doesn't contain "good" + assert negative_result["decision"] == "negative" + assert "No action needed" in negative_result["final_action"] + assert "0.3" in negative_result["final_action"] diff --git a/experimental/archon/tests/test_type_preservation.py b/experimental/archon/tests/test_type_preservation.py new file mode 100644 index 00000000..249f1e54 --- /dev/null +++ b/experimental/archon/tests/test_type_preservation.py @@ -0,0 +1,112 @@ +""" +Tests specifically for Pydantic type preservation in graph serialization/deserialization. +""" + +from __future__ import annotations + +import json + +import pytest + +from experimental.archon.src.graph_loader import GraphLoader +from experimental.archon.tests.conftest import SentimentState + + +class TestTypePreservation: + """Test that Pydantic types are properly preserved during serialization.""" + + def test_state_schema_extraction_preserves_types(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + """Test that state schema extraction captures complete Pydantic information.""" + + def mock_store(graph_json: str) -> str: + fake_hash = f"Qm{hash(graph_json)}" + mock_ipfs_storage[fake_hash] = graph_json + return fake_hash + + loader = GraphLoader() + monkeypatch.setattr(loader, "store_to_ipfs", mock_store) + + # Serialize the graph + graph_json = loader.serialize_graph_from_builder( + sentiment_graph_builder, + name="type_test_workflow", + description="Test workflow for type preservation", + ) + + # Parse and verify the JSON contains proper type information + stored_data = json.loads(graph_json) + state_schema = stored_data["state_schema"] + + # Verify it's recognized as a Pydantic model with proper class reference + assert "class_reference" in state_schema + assert state_schema["class_reference"] == "experimental.archon.tests.conftest:SentimentState" + assert state_schema["name"] == "SentimentState" + + # Verify we can load the class and generate the schema on demand + from experimental.archon.tests.conftest import SentimentState + + json_schema = SentimentState.model_json_schema() + + # Check that field types can be retrieved from the loaded class + properties = json_schema["properties"] + + # input_text should be string type + assert properties["input_text"]["type"] == "string" + assert properties["input_text"]["description"] == "Input text to analyze" + + # sentiment_score should be number type with default + assert properties["sentiment_score"]["type"] == "number" + assert properties["sentiment_score"]["default"] == 0.0 + assert properties["sentiment_score"]["description"] == "Sentiment score between 0 and 1" + + # decision should be string with default + assert properties["decision"]["type"] == "string" + assert properties["decision"]["default"] == "" + + # final_action should be string with default + assert properties["final_action"]["type"] == "string" + assert properties["final_action"]["default"] == "" + + @pytest.mark.asyncio + async def test_type_preservation_round_trip(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + """Test that types are preserved through complete serialization round trip.""" + + def mock_store(graph_json: str) -> str: + fake_hash = f"Qm{hash(graph_json)}" + mock_ipfs_storage[fake_hash] = graph_json + return fake_hash + + def mock_retrieve(ipfs_hash: str): + from experimental.archon.src.graph_models import StoredGraphDefinition + + json_data = mock_ipfs_storage[ipfs_hash] + return StoredGraphDefinition.model_validate(json.loads(json_data)) + + loader = GraphLoader() + monkeypatch.setattr(loader, "store_to_ipfs", mock_store) + monkeypatch.setattr(loader, "retrieve_from_ipfs", mock_retrieve) + + # Save the graph + ipfs_hash = loader.save_graph_from_builder( + sentiment_graph_builder, + name="round_trip_test", + description="Test round trip type preservation", + ) + + # Load the graph back + recreated_graph = loader.load_graph(ipfs_hash) + + # Create test input with proper Pydantic model + test_input = SentimentState( + input_text="This is good news!", + sentiment_score=0.0, + decision="", + final_action="", + ) + + # Execute and verify types are preserved (use ainvoke for async compatibility) + result = await recreated_graph.ainvoke(test_input) + + # We can reconstruct the Pydantic model from the result + pydantic_result = SentimentState.model_validate(result) + assert isinstance(pydantic_result, SentimentState) diff --git a/experimental/archon/tests/test_validation.py b/experimental/archon/tests/test_validation.py new file mode 100644 index 00000000..f02b0444 --- /dev/null +++ b/experimental/archon/tests/test_validation.py @@ -0,0 +1,75 @@ +""" +Test Pydantic field validators for function references. +""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from experimental.archon.src.graph_models import ConditionalEdgeDefinition, GraphNodeDefinition + + +class TestFunctionReferenceValidation: + """Test validation of function references.""" + + def test_valid_function_reference(self): + """Test that valid function references are accepted.""" + # Valid reference + node = GraphNodeDefinition( + name="test_node", function_reference="experimental.archon.tests.test_validation:test_function" + ) + assert node.function_reference == "experimental.archon.tests.test_validation:test_function" + + def test_invalid_format_no_colon(self): + """Test that function references without colons are rejected.""" + with pytest.raises(ValidationError, match="must be in format"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon.tests.test_validation") + + def test_invalid_format_multiple_colons(self): + """Test that function references with multiple colons are rejected.""" + with pytest.raises(ValidationError, match="exactly one ':' separator"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon:test:function") + + def test_empty_module_path(self): + """Test that empty module paths are rejected.""" + with pytest.raises(ValidationError, match="must be non-empty"): + GraphNodeDefinition(name="test_node", function_reference=":test_function") + + def test_empty_function_name(self): + """Test that empty function names are rejected.""" + with pytest.raises(ValidationError, match="must be non-empty"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon.tests:") + + def test_relative_paths_rejected(self): + """Test that relative path components are rejected.""" + with pytest.raises(ValidationError, match="Relative path components"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon.../malicious:bad_function") + + def test_non_experimental_module_rejected(self): + """Test that non-experimental modules are rejected.""" + with pytest.raises(ValidationError, match="must start with 'experimental'"): + GraphNodeDefinition(name="test_node", function_reference="malicious.module:bad_function") + + def test_invalid_python_identifier(self): + """Test that invalid Python identifiers for function names are rejected.""" + with pytest.raises(ValidationError, match="not a valid Python identifier"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon.tests:123invalid-name") + + def test_conditional_edge_validation(self): + """Test that conditional edges also validate function references.""" + # Valid conditional edge + cond_edge = ConditionalEdgeDefinition( + source_node="test_node", + condition_function_reference="experimental.archon.tests.test_validation:condition_func", + target_mapping={"true": "next_node", "false": "end"}, + ) + assert cond_edge.condition_function_reference == "experimental.archon.tests.test_validation:condition_func" + + # Invalid conditional edge + with pytest.raises(ValidationError, match="must start with 'experimental'"): + ConditionalEdgeDefinition( + source_node="test_node", + condition_function_reference="malicious.module:bad_condition", + target_mapping={"true": "next_node", "false": "end"}, + ) diff --git a/pyproject.toml b/pyproject.toml index 74503645..cddb7610 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "langchain==0.3.26", "langchain-community==0.3.27", "langchain-openai==0.3.28", + "langgraph==0.2.60", "duckduckgo-search==8.1.1", "faiss-cpu==1.11.0.post1", "tiktoken==0.9.0", @@ -53,3 +54,8 @@ line-length = 120 [tool.mypy] strict = true + +[dependency-groups] +dev = [ + "pytest-asyncio>=1.1.0", +] diff --git a/requirements.txt b/requirements.txt index 3cb60ec4..61edb229 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,6 +46,7 @@ langchain-community==0.3.27 langchain-core==0.3.69 langchain-openai==0.3.28 langchain-text-splitters==0.3.8 +langgraph==0.2.60 langsmith==0.4.8 lxml==6.0.0 markdown-it-py==3.0.0 @@ -100,7 +101,6 @@ six==1.17.0 sniffio==1.3.1 soupsieve==2.7 sqlalchemy==2.0.41 --e file:///app tenacity==9.1.2 textblob==0.19.0 tiktoken==0.9.0 From d3c5ab844dc13c9778dae7cf15f2e78cc28cdfb8 Mon Sep 17 00:00:00 2001 From: Derek Date: Sat, 9 Aug 2025 16:11:12 +0100 Subject: [PATCH 2/3] simplify loader --- experimental/archon/src/graph_executor.py | 24 ---- experimental/archon/src/graph_loader.py | 120 +++--------------- experimental/archon/src/graph_models.py | 37 +----- experimental/archon/tests/conftest.py | 87 +++++++++++++ .../archon/tests/test_graph_executor.py | 36 +----- .../archon/tests/test_graph_loader.py | 84 +----------- .../archon/tests/test_type_preservation.py | 27 +--- experimental/archon/tests/test_validation.py | 1 - 8 files changed, 126 insertions(+), 290 deletions(-) diff --git a/experimental/archon/src/graph_executor.py b/experimental/archon/src/graph_executor.py index 67a4603c..365bf171 100644 --- a/experimental/archon/src/graph_executor.py +++ b/experimental/archon/src/graph_executor.py @@ -65,24 +65,9 @@ def __init__(self) -> None: self.loader = GraphLoader() def load_graph(self, ipfs_hash: str) -> LoadedGraph: - """ - Load a graph and expose its state class for type-safe usage. - - Args: - ipfs_hash: IPFS hash of the stored graph definition - - Returns: - LoadedGraph with both compiled graph and state class - - Raises: - ImportError: If graph functions or state class cannot be loaded - ValueError: If graph structure is invalid - """ - # Load the stored definition first to get state class stored_definition = self.loader.retrieve_from_ipfs(ipfs_hash) state_class = self.loader._load_class_from_reference(stored_definition.state_schema.class_reference) - # Load the compiled graph compiled_graph = self.loader.load_graph(ipfs_hash) return LoadedGraph(compiled_graph, state_class) @@ -97,15 +82,6 @@ async def execute_graph(self, ipfs_hash: str, input_state: dict[str, Any]) -> di Returns: Final state dictionary after graph execution - - Raises: - ImportError: If graph functions cannot be loaded - ValueError: If graph structure is invalid - Exception: If graph execution fails - - Note: - This method always uses async execution (ainvoke) which works - universally for sync, async, and mixed graphs. """ loaded_graph = self.load_graph(ipfs_hash) return await loaded_graph.execute(input_state) diff --git a/experimental/archon/src/graph_loader.py b/experimental/archon/src/graph_loader.py index c97040bd..f607c1b6 100644 --- a/experimental/archon/src/graph_loader.py +++ b/experimental/archon/src/graph_loader.py @@ -20,7 +20,9 @@ import requests from langgraph.graph import END, START, StateGraph from langgraph.graph.graph import CompiledGraph +from langgraph.utils.runnable import RunnableCallable from pinata_python.pinning import Pinning +from pydantic import BaseModel from .graph_models import ( ConditionalEdgeDefinition, @@ -48,10 +50,6 @@ class GraphLoader: >>> recreated_graph = loader.load_graph(ipfs_hash) """ - def __init__(self) -> None: - """Initialize the GraphLoader.""" - pass - def serialize_graph_from_builder( self, state_graph: StateGraph, @@ -77,17 +75,12 @@ def serialize_graph_from_builder( # Extract serializable node definitions nodes = [] for node_name, node_spec in state_graph.nodes.items(): - # Extract function information from StateNodeSpec - if not hasattr(node_spec, "runnable"): - raise ValueError(f"Node '{node_name}' missing expected 'runnable' attribute") - runnable = node_spec.runnable + assert isinstance(runnable, RunnableCallable), "Only RunnableCallables are currently supported." - # For sync functions: use runnable.func - # For async functions: runnable.func is None, use runnable.afunc - if hasattr(runnable, "func") and runnable.func is not None: + if runnable.func is not None: func = runnable.func - elif hasattr(runnable, "afunc") and runnable.afunc is not None: + elif runnable.afunc is not None: func = runnable.afunc else: raise ValueError( @@ -95,32 +88,24 @@ def serialize_graph_from_builder( f"Expected runnable.func or runnable.afunc to be set" ) - if not hasattr(func, "__module__") or not hasattr(func, "__name__"): - raise ValueError(f"Function for node '{node_name}' missing __module__ or __name__ attributes") - function_reference = f"{func.__module__}:{func.__name__}" - nodes.append(GraphNodeDefinition(name=node_name, function_reference=function_reference)) # Extract simple edges edges = [] - for edge_tuple in getattr(state_graph, "edges", set()): + for edge_tuple in state_graph.edges: source, target = edge_tuple edges.append(GraphEdgeDefinition(source=source, target=target)) # Extract conditional edges from branches conditional_edges = [] - branches = getattr(state_graph, "branches", {}) + branches = state_graph.branches for source_node, branch_dict in branches.items(): - # The branch dict has function names as keys for func_name, branch_obj in branch_dict.items(): - if not hasattr(branch_obj, "path"): - raise ValueError(f"Conditional edge from '{source_node}' missing expected 'path' attribute") - - # Extract condition function similar to node functions - if hasattr(branch_obj.path, "func") and branch_obj.path.func is not None: + assert isinstance(branch_obj.path, RunnableCallable), "Only RunnableCallables are currently supported." + if branch_obj.path.func is not None: condition_func = branch_obj.path.func - elif hasattr(branch_obj.path, "afunc") and branch_obj.path.afunc is not None: + elif branch_obj.path.afunc is not None: condition_func = branch_obj.path.afunc else: raise ValueError( @@ -128,20 +113,12 @@ def serialize_graph_from_builder( f"Expected branch_obj.path.func or branch_obj.path.afunc to be set" ) - if not hasattr(condition_func, "__module__") or not hasattr(condition_func, "__name__"): - raise ValueError( - f"Condition function for edge from '{source_node}' missing __module__ or __name__ attributes" - ) - - if not hasattr(branch_obj, "ends"): - raise ValueError(f"Conditional edge from '{source_node}' missing expected 'ends' attribute") - condition_function_reference = f"{condition_func.__module__}:{condition_func.__name__}" conditional_edges.append( ConditionalEdgeDefinition( source_node=source_node, condition_function_reference=condition_function_reference, - target_mapping=branch_obj.ends, + target_mapping=branch_obj.ends or {}, ) ) @@ -157,9 +134,7 @@ def serialize_graph_from_builder( edges=edges, conditional_edges=conditional_edges, state_channels=state_channels, - state_type_name=( - state_graph.state_schema.__name__ if hasattr(state_graph, "state_schema") else "UnknownState" - ), + state_type_name=state_graph.schema.__name__, ) # Create state schema from the StateGraph's state schema @@ -192,32 +167,18 @@ def _extract_state_schema(self, state_graph: StateGraph) -> StateSchema: Raises: ValueError: If state schema is not a Pydantic BaseModel """ - # Get the state schema class from the StateGraph - if not hasattr(state_graph, "schemas") or not state_graph.schemas: - raise ValueError("StateGraph must have a state schema defined") - # Get the first (and should be only) schema class schema_class = next(iter(state_graph.schemas.keys())) - # Verify it's a Pydantic BaseModel - try: - from pydantic import BaseModel - - if not issubclass(schema_class, BaseModel): - raise ValueError(f"State schema must be a Pydantic BaseModel, got {type(schema_class)}") - except (TypeError, ImportError) as e: - raise ValueError(f"Invalid state schema class: {e}") - - # Extract module and class information - if not hasattr(schema_class, "__module__") or not hasattr(schema_class, "__name__"): - raise ValueError("State schema class missing __module__ or __name__ attributes") + assert issubclass(schema_class, BaseModel), ( + f"State schema must be a Pydantic BaseModel, got {type(schema_class)}" + ) class_reference = f"{schema_class.__module__}:{schema_class.__name__}" return StateSchema( name=schema_class.__name__, class_reference=class_reference, - description=f"State schema for {schema_class.__name__}", ) def store_to_ipfs(self, graph_json: str) -> str: @@ -230,18 +191,15 @@ def store_to_ipfs(self, graph_json: str) -> str: Returns: IPFS hash of the stored content """ - # Get credentials from environment when needed api_key = os.getenv("PINATA_API_KEY") secret_key = os.getenv("PINATA_SECRET_API_KEY") if not api_key or not secret_key: raise ValueError("PINATA_API_KEY and PINATA_SECRET_API_KEY environment variables required for IPFS storage") - # Initialize Pinata client pinata = Pinning(PINATA_API_KEY=api_key, PINATA_API_SECRET=secret_key) # Pin JSON content to IPFS via Pinata - # First save to temp file since pinata-python expects a file path with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: f.write(graph_json) temp_path: str = f.name @@ -283,19 +241,11 @@ def recreate_graph_from_definition(self, stored_definition: StoredGraphDefinitio the serialized function references (module + name). """ - # Dynamically load the Pydantic model class state_class = self._load_class_from_reference(stored_definition.state_schema.class_reference) + assert issubclass(state_class, BaseModel), ( + f"Loaded state class must be a Pydantic BaseModel, got {type(state_class)}" + ) - # Verify it's a Pydantic BaseModel - try: - from pydantic import BaseModel - - if not issubclass(state_class, BaseModel): - raise ValueError(f"Loaded state class must be a Pydantic BaseModel, got {type(state_class)}") - except (TypeError, ImportError) as e: - raise ValueError(f"Invalid loaded state class: {e}") - - # Create StateGraph with the proper Pydantic model builder = StateGraph(state_class) for node_def in stored_definition.graph_definition.nodes: @@ -312,10 +262,9 @@ def recreate_graph_from_definition(self, stored_definition: StoredGraphDefinitio for cond_edge_def in stored_definition.graph_definition.conditional_edges: try: condition_func = self._load_function_from_reference(cond_edge_def.condition_function_reference) - target_mapping: dict[Hashable, str] = {} - for key, value in cond_edge_def.target_mapping.items(): - target_value = END if value == "__end__" else value - target_mapping[key] = target_value + target_mapping: dict[Hashable, str] = { + k: END if v == "__end__" else v for k, v in cond_edge_def.target_mapping.items() + } builder.add_conditional_edges(cond_edge_def.source_node, condition_func, target_mapping) except Exception as e: raise ImportError(f"Failed to load conditional edge from {cond_edge_def.source_node}: {e}") @@ -327,12 +276,6 @@ def _load_function_from_reference( ) -> Union[Callable[..., Any], Callable[..., Awaitable[Any]]]: """ Dynamically load a function from a module:function reference string. - - Args: - function_reference: Function reference in format "module.path:function_name" - - Returns: - The loaded function object (can be sync or async) """ try: module_name, function_name = function_reference.split(":", 1) @@ -344,15 +287,6 @@ def _load_function_from_reference( def _load_class_from_reference(self, class_reference: str) -> type: """ Dynamically load a class from a module:class reference string. - - Args: - class_reference: Class reference in format "module.path:ClassName" - - Returns: - The loaded class object - - Raises: - ImportError: If the class cannot be loaded """ try: module_name, class_name = class_reference.split(":", 1) @@ -367,12 +301,6 @@ def _load_class_from_reference(self, class_reference: str) -> type: def recreate_graph(self, stored_definition: StoredGraphDefinition) -> CompiledGraph: """ Recreate and compile a LangGraph from its stored definition. - - Args: - stored_definition: StoredGraphDefinition object from IPFS - - Returns: - Recreated CompiledGraph ready for execution """ state_graph = self.recreate_graph_from_definition(stored_definition) return state_graph.compile() @@ -420,12 +348,6 @@ def load_graph(self, ipfs_hash: str) -> CompiledGraph: def get_graph_info(self, ipfs_hash: str) -> GraphMetadata: """ Get metadata about a stored graph without recreating it. - - Args: - ipfs_hash: IPFS hash of the stored graph definition - - Returns: - GraphMetadata with information about the stored graph """ stored_definition = self.retrieve_from_ipfs(ipfs_hash) return stored_definition.metadata diff --git a/experimental/archon/src/graph_models.py b/experimental/archon/src/graph_models.py index bd172515..0b355b86 100644 --- a/experimental/archon/src/graph_models.py +++ b/experimental/archon/src/graph_models.py @@ -8,7 +8,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any +from typing import Any, Hashable from langchain_core.runnables.graph import Edge, Node from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -28,14 +28,10 @@ class GraphModel(BaseModel, Immutable): @staticmethod def _validate_function_reference(value: str) -> str: """Validate function reference format and security constraints.""" - if not isinstance(value, str): - raise ValueError("Function reference must be a string") - - if ":" not in value: - raise ValueError("Function reference must be in format 'module.path:function_name'") - if value.count(":") != 1: - raise ValueError("Function reference must have exactly one ':' separator") + raise ValueError( + "Function reference must be in format 'module.path:function_name' with exactly one ':' separator" + ) module_path, function_name = value.split(":", 1) @@ -68,9 +64,7 @@ class GraphMetadata(GraphModel): langgraph_version: str = "0.2.60" -# Note: We create Pydantic-compatible wrappers for LangGraph's Node and Edge types -# to ensure compatibility with our immutable model architecture while maintaining -# compatibility with LangGraph's structure. +# Note: It's a bit duplicative, but I prefer to fully separate our internal classes from LangGraph's classes. class LangGraphNode(GraphModel): @@ -145,7 +139,7 @@ class ConditionalEdgeDefinition(GraphModel): source_node: str condition_function_reference: str = Field(description="Format: 'module.path:function_name'") - target_mapping: dict[str, str] + target_mapping: dict[Hashable, str] @field_validator("condition_function_reference") @classmethod @@ -171,19 +165,11 @@ class SerializableGraphDefinition(GraphModel): state_type_name: str -class LangGraphDefinition(GraphModel): - """Structured representation of LangGraph's native to_json() output using wrapper types.""" - - nodes: list[LangGraphNode] - edges: list[LangGraphEdge] - - class StateSchema(GraphModel): """Pydantic-only state schema representation.""" name: str class_reference: str = Field(description="Format: 'module.path:ClassName'") - description: str | None = None @field_validator("class_reference") @classmethod @@ -204,8 +190,7 @@ class ExecutionConfig(GraphModel): class StoredGraphDefinition(GraphModel): """Complete stored graph definition using serializable pre-compilation data.""" - type: str = "serializable_graph_definition" - version: str = "2.0.0" # Updated version for new approach + version: str = "0.0.1" metadata: GraphMetadata # Serializable graph definition from StateGraph builder (before compilation) @@ -213,16 +198,8 @@ class StoredGraphDefinition(GraphModel): description="Complete serializable graph structure from StateGraph builder" ) - # LangGraph's native representation (optional, for reference) - langgraph_definition: LangGraphDefinition | None = Field( - default=None, - description="Optional: LangGraph's compiled representation for reference", - ) - - # Additional info that LangGraph doesn't capture state_schema: StateSchema = Field(description="Structured information about the state schema") - # Execution configuration execution_config: ExecutionConfig = Field( default_factory=ExecutionConfig, description="Structured execution configuration options", diff --git a/experimental/archon/tests/conftest.py b/experimental/archon/tests/conftest.py index 46e017b1..ebf4b192 100644 --- a/experimental/archon/tests/conftest.py +++ b/experimental/archon/tests/conftest.py @@ -4,10 +4,16 @@ from __future__ import annotations +import json +from typing import Any, Callable + import pytest from langgraph.graph import END, START, StateGraph from pydantic import BaseModel, Field +from experimental.archon.src.graph_loader import GraphLoader +from experimental.archon.src.graph_models import StoredGraphDefinition + class SentimentState(BaseModel): """Pydantic state model for sentiment analysis workflow.""" @@ -79,3 +85,84 @@ def mock_ipfs_storage(): Returns a dictionary that will store graph definitions by hash. """ return {} + + +class IPFSMockHelpers: + """Helper functions for mocking IPFS operations.""" + + @staticmethod + def create_mock_store(storage: dict[str, str]) -> Callable[[str], str]: + """Create a mock store function that saves to the provided storage.""" + + def mock_store(graph_json: str) -> str: + fake_hash = f"Qm{hash(graph_json)}" + storage[fake_hash] = graph_json + return fake_hash + + return mock_store + + @staticmethod + def create_mock_retrieve(storage: dict[str, str]) -> Callable[[str], StoredGraphDefinition]: + """Create a mock retrieve function that loads from the provided storage.""" + + def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: + json_data = storage[ipfs_hash] + return StoredGraphDefinition.model_validate(json.loads(json_data)) + + return mock_retrieve + + +@pytest.fixture +def mock_ipfs_functions(mock_ipfs_storage): + """ + Provide mock IPFS store and retrieve functions. + Returns a tuple of (mock_store, mock_retrieve) functions. + """ + mock_store = IPFSMockHelpers.create_mock_store(mock_ipfs_storage) + mock_retrieve = IPFSMockHelpers.create_mock_retrieve(mock_ipfs_storage) + return mock_store, mock_retrieve + + +@pytest.fixture +def setup_ipfs_mocks(monkeypatch, mock_ipfs_functions): + """ + Fixture that sets up IPFS mocks for a GraphLoader or GraphExecutor. + Returns a function that applies the mocks to a given loader/executor. + """ + mock_store, mock_retrieve = mock_ipfs_functions + + def apply_mocks(obj: Any) -> None: + """Apply IPFS mocks to a GraphLoader or GraphExecutor instance.""" + if hasattr(obj, "loader"): # GraphExecutor + target = obj.loader + else: # GraphLoader + target = obj + + monkeypatch.setattr(target, "store_to_ipfs", mock_store) + monkeypatch.setattr(target, "retrieve_from_ipfs", mock_retrieve) + + return apply_mocks + + +@pytest.fixture +def loader_with_ipfs_mocks(setup_ipfs_mocks): + """ + Create a GraphLoader with IPFS mocks already applied. + This fixture can be used directly instead of creating and mocking separately. + """ + loader = GraphLoader() + setup_ipfs_mocks(loader) + return loader + + +@pytest.fixture +def executor_with_ipfs_mocks(setup_ipfs_mocks): + """ + Create a GraphExecutor with IPFS mocks already applied. + This fixture can be used directly instead of creating and mocking separately. + """ + from experimental.archon.src.graph_executor import GraphExecutor + + executor = GraphExecutor() + setup_ipfs_mocks(executor) + return executor diff --git a/experimental/archon/tests/test_graph_executor.py b/experimental/archon/tests/test_graph_executor.py index a05c8f92..7bbe286c 100644 --- a/experimental/archon/tests/test_graph_executor.py +++ b/experimental/archon/tests/test_graph_executor.py @@ -4,35 +4,21 @@ from __future__ import annotations -import json - import pytest from experimental.archon.src.graph_executor import GraphExecutor -from experimental.archon.src.graph_models import StoredGraphDefinition +from experimental.archon.tests.conftest import SentimentState class TestGraphExecutorWithMockedIPFS: """Test GraphExecutor with mocked IPFS storage.""" @pytest.mark.asyncio - async def test_execute_stored_graph(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + async def test_execute_stored_graph(self, sentiment_graph_builder, setup_ipfs_mocks): """Test that we can execute a stored graph through the executor.""" - def mock_store(graph_json: str) -> str: - fake_hash = f"Qm{hash(graph_json)}" - mock_ipfs_storage[fake_hash] = graph_json - return fake_hash - - def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: - json_data = mock_ipfs_storage[ipfs_hash] - return StoredGraphDefinition.model_validate(json.loads(json_data)) - executor = GraphExecutor() - - # Use monkeypatch to replace methods - this is type-safe - monkeypatch.setattr(executor.loader, "store_to_ipfs", mock_store) - monkeypatch.setattr(executor.loader, "retrieve_from_ipfs", mock_retrieve) + setup_ipfs_mocks(executor) # Step 1: Store the graph using the underlying loader ipfs_hash = executor.loader.save_graph_from_builder( @@ -76,21 +62,11 @@ def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: assert "0.3" in negative_result["final_action"] @pytest.mark.asyncio - async def test_type_aware_execution(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + async def test_type_aware_execution(self, sentiment_graph_builder, setup_ipfs_mocks): """Test the new type-aware LoadedGraph functionality.""" - def mock_store(graph_json: str) -> str: - fake_hash = f"Qm{hash(graph_json)}" - mock_ipfs_storage[fake_hash] = graph_json - return fake_hash - - def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: - json_data = mock_ipfs_storage[ipfs_hash] - return StoredGraphDefinition.model_validate(json.loads(json_data)) - executor = GraphExecutor() - monkeypatch.setattr(executor.loader, "store_to_ipfs", mock_store) - monkeypatch.setattr(executor.loader, "retrieve_from_ipfs", mock_retrieve) + setup_ipfs_mocks(executor) # Store the graph ipfs_hash = executor.loader.save_graph_from_builder( @@ -115,8 +91,6 @@ def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: ) # Verify the created state is properly typed - from experimental.archon.tests.conftest import SentimentState - assert isinstance(typed_input, SentimentState) assert typed_input.input_text == "This is good news!" diff --git a/experimental/archon/tests/test_graph_loader.py b/experimental/archon/tests/test_graph_loader.py index 41a3ce16..cc8bf89a 100644 --- a/experimental/archon/tests/test_graph_loader.py +++ b/experimental/archon/tests/test_graph_loader.py @@ -4,35 +4,16 @@ from __future__ import annotations -import json - -import pytest - -from experimental.archon.src.graph_loader import GraphLoader -from experimental.archon.src.graph_models import StoredGraphDefinition - from .conftest import analyze_sentiment, make_decision, no_action, take_action class TestGraphLoaderWithMockedIPFS: """Test GraphLoader with mocked IPFS storage.""" - def test_end_to_end_save_load(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + def test_end_to_end_save_load(self, sentiment_graph_builder, mock_ipfs_storage, loader_with_ipfs_mocks): """Test complete workflow: save graph to 'IPFS', retrieve it, and execute.""" - def mock_store(graph_json: str) -> str: - fake_hash = f"Qm{hash(graph_json)}" - mock_ipfs_storage[fake_hash] = graph_json - return fake_hash - - def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: - json_data = mock_ipfs_storage[ipfs_hash] - return StoredGraphDefinition.model_validate(json.loads(json_data)) - - # Create loader - no credentials needed since we mock the IPFS methods - loader = GraphLoader() - monkeypatch.setattr(loader, "store_to_ipfs", mock_store) - monkeypatch.setattr(loader, "retrieve_from_ipfs", mock_retrieve) + loader = loader_with_ipfs_mocks # Step 1: Save the graph ipfs_hash = loader.save_graph_from_builder( @@ -125,64 +106,3 @@ def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: module_part, func_part = ref.split(":") assert module_part, f"Empty module in reference: {ref}" assert func_part, f"Empty function name in reference: {ref}" - - @pytest.mark.asyncio - async def test_save_load_and_execute_graph(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): - """Test that we can save a graph, load it back, and execute it successfully.""" - - def mock_store(graph_json: str) -> str: - fake_hash = f"Qm{hash(graph_json)}" - mock_ipfs_storage[fake_hash] = graph_json - return fake_hash - - def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition: - json_data = mock_ipfs_storage[ipfs_hash] - return StoredGraphDefinition.model_validate(json.loads(json_data)) - - loader = GraphLoader() - monkeypatch.setattr(loader, "store_to_ipfs", mock_store) - monkeypatch.setattr(loader, "retrieve_from_ipfs", mock_retrieve) - - # Step 1: Save the graph to IPFS - ipfs_hash = loader.save_graph_from_builder( - sentiment_graph_builder, - name="executable_workflow", - description="Test executable workflow", - ) - - # Step 2: Load the graph back from IPFS - recreated_graph = loader.load_graph(ipfs_hash) - - # Step 3: Execute the recreated graph with positive sentiment input - positive_input = { - "input_text": "This is good news!", - "sentiment_score": 0.0, - "decision": "", - "final_action": "", - } - - positive_result = await recreated_graph.ainvoke(positive_input) - - # Verify positive sentiment path - assert positive_result["input_text"] == "This is good news!" - assert positive_result["sentiment_score"] == 0.8 # "good" triggers positive score - assert positive_result["decision"] == "positive" - assert "Taking positive action" in positive_result["final_action"] - assert "0.8" in positive_result["final_action"] - - # Step 4: Execute with negative sentiment input - negative_input = { - "input_text": "This is terrible news!", - "sentiment_score": 0.0, - "decision": "", - "final_action": "", - } - - negative_result = await recreated_graph.ainvoke(negative_input) - - # Verify negative sentiment path - assert negative_result["input_text"] == "This is terrible news!" - assert negative_result["sentiment_score"] == 0.3 # "terrible" doesn't contain "good" - assert negative_result["decision"] == "negative" - assert "No action needed" in negative_result["final_action"] - assert "0.3" in negative_result["final_action"] diff --git a/experimental/archon/tests/test_type_preservation.py b/experimental/archon/tests/test_type_preservation.py index 249f1e54..4288f3f4 100644 --- a/experimental/archon/tests/test_type_preservation.py +++ b/experimental/archon/tests/test_type_preservation.py @@ -15,16 +15,11 @@ class TestTypePreservation: """Test that Pydantic types are properly preserved during serialization.""" - def test_state_schema_extraction_preserves_types(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + def test_state_schema_extraction_preserves_types(self, sentiment_graph_builder, setup_ipfs_mocks): """Test that state schema extraction captures complete Pydantic information.""" - def mock_store(graph_json: str) -> str: - fake_hash = f"Qm{hash(graph_json)}" - mock_ipfs_storage[fake_hash] = graph_json - return fake_hash - loader = GraphLoader() - monkeypatch.setattr(loader, "store_to_ipfs", mock_store) + setup_ipfs_mocks(loader) # Serialize the graph graph_json = loader.serialize_graph_from_builder( @@ -43,8 +38,6 @@ def mock_store(graph_json: str) -> str: assert state_schema["name"] == "SentimentState" # Verify we can load the class and generate the schema on demand - from experimental.archon.tests.conftest import SentimentState - json_schema = SentimentState.model_json_schema() # Check that field types can be retrieved from the loaded class @@ -68,23 +61,11 @@ def mock_store(graph_json: str) -> str: assert properties["final_action"]["default"] == "" @pytest.mark.asyncio - async def test_type_preservation_round_trip(self, sentiment_graph_builder, mock_ipfs_storage, monkeypatch): + async def test_type_preservation_round_trip(self, sentiment_graph_builder, setup_ipfs_mocks): """Test that types are preserved through complete serialization round trip.""" - def mock_store(graph_json: str) -> str: - fake_hash = f"Qm{hash(graph_json)}" - mock_ipfs_storage[fake_hash] = graph_json - return fake_hash - - def mock_retrieve(ipfs_hash: str): - from experimental.archon.src.graph_models import StoredGraphDefinition - - json_data = mock_ipfs_storage[ipfs_hash] - return StoredGraphDefinition.model_validate(json.loads(json_data)) - loader = GraphLoader() - monkeypatch.setattr(loader, "store_to_ipfs", mock_store) - monkeypatch.setattr(loader, "retrieve_from_ipfs", mock_retrieve) + setup_ipfs_mocks(loader) # Save the graph ipfs_hash = loader.save_graph_from_builder( diff --git a/experimental/archon/tests/test_validation.py b/experimental/archon/tests/test_validation.py index f02b0444..7446c8bc 100644 --- a/experimental/archon/tests/test_validation.py +++ b/experimental/archon/tests/test_validation.py @@ -15,7 +15,6 @@ class TestFunctionReferenceValidation: def test_valid_function_reference(self): """Test that valid function references are accepted.""" - # Valid reference node = GraphNodeDefinition( name="test_node", function_reference="experimental.archon.tests.test_validation:test_function" ) From b3ceabdcd77c3ba26b66645db0706bd29cb292ba Mon Sep 17 00:00:00 2001 From: Derek Date: Sat, 30 Aug 2025 13:56:07 +0100 Subject: [PATCH 3/3] undo formatting --- src/talos/core/agent.py | 8 +- src/talos/core/job_scheduler.py | 80 +++++----- src/talos/core/main_agent.py | 40 ++--- src/talos/core/scheduled_job.py | 30 ++-- src/talos/data/dataset_manager.py | 56 +++---- src/talos/jobs/example_jobs.py | 60 +++---- src/talos/models/twitter.py | 9 +- .../services/implementations/yield_manager.py | 86 +++++----- src/talos/skills/twitter_influence.py | 6 +- src/talos/tools/document_loader.py | 24 +-- .../tools/general_influence_evaluator.py | 123 +++++++-------- src/talos/tools/twitter_client.py | 69 +++++--- tests/test_document_loader.py | 18 +-- tests/test_scheduled_jobs.py | 148 +++++++++--------- tests/test_yield_manager.py | 16 +- 15 files changed, 394 insertions(+), 379 deletions(-) diff --git a/src/talos/core/agent.py b/src/talos/core/agent.py index 4afa5951..ffb9443d 100644 --- a/src/talos/core/agent.py +++ b/src/talos/core/agent.py @@ -49,11 +49,11 @@ def model_post_init(self, __context: Any) -> None: def set_prompt(self, name: str | list[str]): if not self.prompt_manager: raise ValueError("Prompt manager not initialized.") - + prompt_names = name if isinstance(name, list) else [name] if self.dataset_manager: prompt_names.append("relevant_documents_prompt") - + prompt = self.prompt_manager.get_prompt(prompt_names) if not prompt: raise ValueError(f"The prompt '{prompt_names}' is not defined.") @@ -92,11 +92,11 @@ def _build_context(self, query: str, **kwargs) -> dict: A base method for adding context to the query. """ context = {} - + if self.dataset_manager: relevant_documents = self.dataset_manager.search(query, k=5) context["relevant_documents"] = relevant_documents - + return context def run(self, message: str, history: list[BaseMessage] | None = None, **kwargs) -> BaseModel: diff --git a/src/talos/core/job_scheduler.py b/src/talos/core/job_scheduler.py index dc5f47a6..ade631e5 100644 --- a/src/talos/core/job_scheduler.py +++ b/src/talos/core/job_scheduler.py @@ -17,44 +17,44 @@ class JobScheduler(BaseModel): """ Manages scheduled jobs for the MainAgent using APScheduler. - + Provides functionality to: - Register and manage scheduled jobs - Execute jobs with supervision - Handle job lifecycle (start, stop, pause, resume) """ - + model_config = ConfigDict(arbitrary_types_allowed=True) - + supervisor: Optional[Supervisor] = Field(None, description="Supervisor for approving job executions") timezone: str = Field("UTC", description="Timezone for job scheduling") - + _scheduler: AsyncIOScheduler = PrivateAttr() _jobs: Dict[str, ScheduledJob] = PrivateAttr(default_factory=dict) _running: bool = PrivateAttr(default=False) - + def model_post_init(self, __context: Any) -> None: self._scheduler = AsyncIOScheduler(timezone=self.timezone) self._jobs = {} self._running = False - + def register_job(self, job: ScheduledJob) -> None: """ Register a scheduled job with the scheduler. - + Args: job: The ScheduledJob instance to register """ if job.name in self._jobs: logger.warning(f"Job '{job.name}' already registered, replacing existing job") self.unregister_job(job.name) - + self._jobs[job.name] = job - + if not job.enabled: logger.info(f"Job '{job.name}' registered but disabled") return - + if job.is_recurring() and job.cron_expression: trigger = CronTrigger.from_crontab(job.cron_expression, timezone=self.timezone) self._scheduler.add_job( @@ -63,10 +63,10 @@ def register_job(self, job: ScheduledJob) -> None: args=[job.name], id=job.name, max_instances=job.max_instances, - replace_existing=True + replace_existing=True, ) logger.info(f"Registered recurring job '{job.name}' with cron: {job.cron_expression}") - + elif job.is_one_time() and job.execute_at: trigger = DateTrigger(run_date=job.execute_at, timezone=self.timezone) self._scheduler.add_job( @@ -75,68 +75,68 @@ def register_job(self, job: ScheduledJob) -> None: args=[job.name], id=job.name, max_instances=job.max_instances, - replace_existing=True + replace_existing=True, ) logger.info(f"Registered one-time job '{job.name}' for: {job.execute_at}") - + def unregister_job(self, job_name: str) -> bool: """ Unregister a scheduled job. - + Args: job_name: Name of the job to unregister - + Returns: True if job was found and removed, False otherwise """ if job_name not in self._jobs: logger.warning(f"Job '{job_name}' not found for unregistration") return False - + try: self._scheduler.remove_job(job_name) except Exception as e: logger.warning(f"Failed to remove job '{job_name}' from scheduler: {e}") - + del self._jobs[job_name] logger.info(f"Unregistered job '{job_name}'") return True - + def get_job(self, job_name: str) -> Optional[ScheduledJob]: """Get a registered job by name.""" return self._jobs.get(job_name) - + def list_jobs(self) -> List[ScheduledJob]: """Get all registered jobs.""" return list(self._jobs.values()) - + def start(self) -> None: """Start the job scheduler.""" if self._running: logger.warning("Job scheduler is already running") return - + self._scheduler.start() self._running = True logger.info("Job scheduler started") - + def stop(self) -> None: """Stop the job scheduler.""" if not self._running: logger.warning("Job scheduler is not running") return - + self._scheduler.shutdown() self._running = False logger.info("Job scheduler stopped") - + def pause_job(self, job_name: str) -> bool: """ Pause a specific job. - + Args: job_name: Name of the job to pause - + Returns: True if job was found and paused, False otherwise """ @@ -147,14 +147,14 @@ def pause_job(self, job_name: str) -> bool: except Exception as e: logger.error(f"Failed to pause job '{job_name}': {e}") return False - + def resume_job(self, job_name: str) -> bool: """ Resume a specific job. - + Args: job_name: Name of the job to resume - + Returns: True if job was found and resumed, False otherwise """ @@ -165,15 +165,15 @@ def resume_job(self, job_name: str) -> bool: except Exception as e: logger.error(f"Failed to resume job '{job_name}': {e}") return False - + def is_running(self) -> bool: """Check if the scheduler is running.""" return self._running - + async def _execute_job_with_supervision(self, job_name: str) -> None: """ Execute a job with optional supervision. - + Args: job_name: Name of the job to execute """ @@ -181,27 +181,27 @@ async def _execute_job_with_supervision(self, job_name: str) -> None: if not job: logger.error(f"Job '{job_name}' not found for execution") return - + if not job.enabled: logger.info(f"Job '{job_name}' is disabled, skipping execution") return - + logger.info(f"Executing scheduled job: {job_name}") - + try: if self.supervisor: logger.info(f"Requesting supervision approval for job: {job_name}") - + await job.run() logger.info(f"Job '{job_name}' completed successfully") - + if job.is_one_time(): self.unregister_job(job_name) logger.info(f"One-time job '{job_name}' removed after execution") - + except Exception as e: logger.error(f"Job '{job_name}' failed with error: {e}") - + if job.is_one_time(): self.unregister_job(job_name) logger.info(f"Failed one-time job '{job_name}' removed") diff --git a/src/talos/core/main_agent.py b/src/talos/core/main_agent.py index 5f8795f5..8137b272 100644 --- a/src/talos/core/main_agent.py +++ b/src/talos/core/main_agent.py @@ -2,15 +2,16 @@ import os from datetime import datetime -from typing import Any, Optional, List +from typing import Any, List, Optional from langchain_core.language_models import BaseChatModel from langchain_core.tools import BaseTool, tool from talos.core.agent import Agent +from talos.core.job_scheduler import JobScheduler from talos.core.router import Router from talos.core.scheduled_job import ScheduledJob -from talos.core.job_scheduler import JobScheduler +from talos.data.dataset_manager import DatasetManager from talos.hypervisor.hypervisor import Hypervisor from talos.models.services import Ticket from talos.prompts.prompt_manager import PromptManager @@ -19,11 +20,10 @@ from talos.skills.base import Skill from talos.skills.cryptography import CryptographySkill from talos.skills.proposals import ProposalsSkill -from talos.skills.twitter_sentiment import TwitterSentimentSkill from talos.skills.twitter_influence import TwitterInfluenceSkill +from talos.skills.twitter_sentiment import TwitterSentimentSkill +from talos.tools.document_loader import DatasetSearchTool, DocumentLoaderTool from talos.tools.tool_manager import ToolManager -from talos.tools.document_loader import DocumentLoaderTool, DatasetSearchTool -from talos.data.dataset_manager import DatasetManager class MainAgent(Agent): @@ -91,53 +91,53 @@ def _setup_tool_manager(self) -> None: tool_manager.register_tool(skill.create_ticket_tool()) tool_manager.register_tool(self._get_ticket_status_tool()) tool_manager.register_tool(self._add_memory_tool()) - + if self.dataset_manager: tool_manager.register_tool(DocumentLoaderTool(self.dataset_manager)) tool_manager.register_tool(DatasetSearchTool(self.dataset_manager)) - + self.tool_manager = tool_manager def _setup_job_scheduler(self) -> None: """Initialize the job scheduler and register any predefined scheduled jobs.""" if not self.job_scheduler: self.job_scheduler = JobScheduler(supervisor=self.supervisor, timezone="UTC") - + for job in self.scheduled_jobs: self.job_scheduler.register_job(job) - + self.job_scheduler.start() def add_scheduled_job(self, job: ScheduledJob) -> None: """ Add a scheduled job to the agent. - + Args: job: The ScheduledJob instance to add """ if not self.job_scheduler: raise ValueError("Job scheduler not initialized") - + self.scheduled_jobs.append(job) self.job_scheduler.register_job(job) def remove_scheduled_job(self, job_name: str) -> bool: """ Remove a scheduled job from the agent. - + Args: job_name: Name of the job to remove - + Returns: True if job was found and removed, False otherwise """ if not self.job_scheduler: return False - + success = self.job_scheduler.unregister_job(job_name) - + self.scheduled_jobs = [job for job in self.scheduled_jobs if job.name != job_name] - + return success def list_scheduled_jobs(self) -> List[ScheduledJob]: @@ -208,16 +208,16 @@ def get_ticket_status(service_name: str, ticket_id: str) -> Ticket: def _build_context(self, query: str, **kwargs) -> dict: assert self.router is not None - + base_context = super()._build_context(query, **kwargs) - + active_tickets = self.router.get_all_tickets() ticket_info = [f"- {ticket.ticket_id}: last updated at {ticket.updated_at}" for ticket in active_tickets] - + main_agent_context = { "time": datetime.now().isoformat(), "available_services": ", ".join([service.name for service in self.router.services]), "active_tickets": " ".join(ticket_info), } - + return {**base_context, **main_agent_context} diff --git a/src/talos/core/scheduled_job.py b/src/talos/core/scheduled_job.py index 1ec99f63..6ec6ffa6 100644 --- a/src/talos/core/scheduled_job.py +++ b/src/talos/core/scheduled_job.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Any, Optional -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field logger = logging.getLogger(__name__) @@ -13,51 +13,53 @@ class ScheduledJob(BaseModel, ABC): """ Abstract base class for scheduled jobs that can be executed by the MainAgent. - + Jobs can be scheduled using either: - A cron expression for recurring execution - A specific datetime for one-time execution """ - + model_config = ConfigDict(arbitrary_types_allowed=True) - + name: str = Field(..., description="Unique name for this scheduled job") description: str = Field(..., description="Human-readable description of what this job does") - cron_expression: Optional[str] = Field(None, description="Cron expression for recurring jobs (e.g., '0 9 * * *' for daily at 9 AM)") + cron_expression: Optional[str] = Field( + None, description="Cron expression for recurring jobs (e.g., '0 9 * * *' for daily at 9 AM)" + ) execute_at: Optional[datetime] = Field(None, description="Specific datetime for one-time execution") enabled: bool = Field(True, description="Whether this job is enabled for execution") max_instances: int = Field(1, description="Maximum number of concurrent instances of this job") - + def model_post_init(self, __context: Any) -> None: if not self.cron_expression and not self.execute_at: raise ValueError("Either cron_expression or execute_at must be provided") if self.cron_expression and self.execute_at: raise ValueError("Only one of cron_expression or execute_at should be provided") - + @abstractmethod async def run(self, **kwargs: Any) -> Any: """ Execute the scheduled job. - + This method should contain the actual logic for the job. It will be called by the scheduler when the job is triggered. - + Args: **kwargs: Additional arguments that may be passed to the job - + Returns: Any result from the job execution """ pass - + def is_recurring(self) -> bool: """Check if this is a recurring job (has cron expression).""" return self.cron_expression is not None - + def is_one_time(self) -> bool: """Check if this is a one-time job (has execute_at datetime).""" return self.execute_at is not None - + def should_execute_now(self) -> bool: """ Check if this job should execute now (for one-time jobs). @@ -66,7 +68,7 @@ def should_execute_now(self) -> bool: if not self.is_one_time() or not self.execute_at: return False return datetime.now() >= self.execute_at - + def __str__(self) -> str: schedule_info = self.cron_expression if self.cron_expression else f"at {self.execute_at}" return f"ScheduledJob(name='{self.name}', schedule='{schedule_info}', enabled={self.enabled})" diff --git a/src/talos/data/dataset_manager.py b/src/talos/data/dataset_manager.py index fec422e0..d875d61c 100644 --- a/src/talos/data/dataset_manager.py +++ b/src/talos/data/dataset_manager.py @@ -80,10 +80,12 @@ def search(self, query: str, k: int = 5) -> list[str]: results = self.vector_store.similarity_search(query, k=k) return [doc.page_content for doc in results] - def add_document_from_ipfs(self, name: str, ipfs_hash: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> None: + def add_document_from_ipfs( + self, name: str, ipfs_hash: str, chunk_size: int = 1000, chunk_overlap: int = 200 + ) -> None: """ Loads a document from IPFS hash and adds it to the dataset with intelligent chunking. - + Args: name: Name for the dataset ipfs_hash: IPFS hash of the document @@ -92,14 +94,14 @@ def add_document_from_ipfs(self, name: str, ipfs_hash: str, chunk_size: int = 10 """ ipfs_tool = IpfsTool() content = ipfs_tool.get_content(ipfs_hash) - + chunks = self._process_and_chunk_content(content, chunk_size, chunk_overlap) self.add_dataset(name, chunks) - + def add_document_from_url(self, name: str, url: str, chunk_size: int = 1000, chunk_overlap: int = 200) -> None: """ Loads a document from URL and adds it to the dataset with intelligent chunking. - + Args: name: Name for the dataset url: URL of the document @@ -109,69 +111,69 @@ def add_document_from_url(self, name: str, url: str, chunk_size: int = 1000, chu content = self._fetch_content_from_url(url) chunks = self._process_and_chunk_content(content, chunk_size, chunk_overlap) self.add_dataset(name, chunks) - + def _fetch_content_from_url(self, url: str) -> str: """Fetch content from URL, handling different content types.""" response = requests.get(url, timeout=30) response.raise_for_status() - - content_type = response.headers.get('content-type', '').lower() - - if 'application/pdf' in content_type: + + content_type = response.headers.get("content-type", "").lower() + + if "application/pdf" in content_type: pdf_reader = PdfReader(BytesIO(response.content)) text = "" for page in pdf_reader.pages: text += page.extract_text() + "\n" return text else: - if 'text/html' in content_type: - soup = BeautifulSoup(response.text, 'html.parser') + if "text/html" in content_type: + soup = BeautifulSoup(response.text, "html.parser") for script in soup(["script", "style"]): script.decompose() return soup.get_text() else: return response.text - + def _process_and_chunk_content(self, content: str, chunk_size: int, chunk_overlap: int) -> list[str]: """Process content and split into intelligent chunks.""" content = self._clean_text(content) - + chunks = [] start = 0 - + while start < len(content): end = start + chunk_size - + if end < len(content): search_start = max(start + chunk_size - 200, start) sentence_end = self._find_sentence_boundary(content, search_start, end) if sentence_end > start: end = sentence_end - + chunk = content[start:end].strip() if chunk: chunks.append(chunk) - + start = max(start + chunk_size - chunk_overlap, end) - + if start >= len(content): break - + return chunks - + def _clean_text(self, text: str) -> str: """Clean and normalize text content.""" - text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text) - text = re.sub(r'[ \t]+', ' ', text) + text = re.sub(r"\n\s*\n\s*\n+", "\n\n", text) + text = re.sub(r"[ \t]+", " ", text) return text.strip() - + def _find_sentence_boundary(self, text: str, start: int, end: int) -> int: """Find the best sentence boundary within the given range.""" - sentence_pattern = r'[.!?]\s+' - + sentence_pattern = r"[.!?]\s+" + for match in re.finditer(sentence_pattern, text[start:end]): boundary = start + match.end() if boundary > start: return boundary - + return end diff --git a/src/talos/jobs/example_jobs.py b/src/talos/jobs/example_jobs.py index eb776d89..02e63b99 100644 --- a/src/talos/jobs/example_jobs.py +++ b/src/talos/jobs/example_jobs.py @@ -13,29 +13,29 @@ class HealthCheckJob(ScheduledJob): """ Example scheduled job that performs a health check every hour. """ - + def __init__(self, **kwargs): super().__init__( name="health_check", description="Performs a health check of the agent system", cron_expression="0 * * * *", # Every hour at minute 0 - **kwargs + **kwargs, ) - + async def run(self, **kwargs: Any) -> str: """ Perform a health check of the agent system. """ logger.info("Running health check job") - + current_time = datetime.now() health_status = { "timestamp": current_time.isoformat(), "status": "healthy", "uptime": "running", - "memory_usage": "normal" + "memory_usage": "normal", } - + logger.info(f"Health check completed: {health_status}") return f"Health check completed at {current_time}: System is healthy" @@ -44,29 +44,24 @@ class DailyReportJob(ScheduledJob): """ Example scheduled job that generates a daily report at 9 AM. """ - + def __init__(self, **kwargs): super().__init__( name="daily_report", description="Generates a daily activity report", cron_expression="0 9 * * *", # Daily at 9 AM - **kwargs + **kwargs, ) - + async def run(self, **kwargs: Any) -> str: """ Generate a daily activity report. """ logger.info("Running daily report job") - + current_date = datetime.now().strftime("%Y-%m-%d") - report_data = { - "date": current_date, - "tasks_completed": 0, - "skills_used": [], - "memory_entries": 0 - } - + report_data = {"date": current_date, "tasks_completed": 0, "skills_used": [], "memory_entries": 0} + logger.info(f"Daily report generated: {report_data}") return f"Daily report for {current_date} generated successfully" @@ -75,30 +70,23 @@ class OneTimeMaintenanceJob(ScheduledJob): """ Example one-time scheduled job for maintenance tasks. """ - + def __init__(self, execute_at: datetime, **kwargs): super().__init__( - name="maintenance_task", - description="Performs one-time maintenance task", - execute_at=execute_at, - **kwargs + name="maintenance_task", description="Performs one-time maintenance task", execute_at=execute_at, **kwargs ) - + async def run(self, **kwargs: Any) -> str: """ Perform a one-time maintenance task. """ logger.info("Running one-time maintenance job") - - maintenance_tasks = [ - "Clean temporary files", - "Optimize memory usage", - "Update internal metrics" - ] - + + maintenance_tasks = ["Clean temporary files", "Optimize memory usage", "Update internal metrics"] + for task in maintenance_tasks: logger.info(f"Executing maintenance task: {task}") - + completion_time = datetime.now() logger.info(f"Maintenance completed at {completion_time}") return f"Maintenance tasks completed at {completion_time}" @@ -107,14 +95,10 @@ async def run(self, **kwargs: Any) -> str: def create_example_jobs() -> list[ScheduledJob]: """ Create a list of example scheduled jobs for demonstration. - + Returns: List of example ScheduledJob instances """ - jobs = [ - HealthCheckJob(), - DailyReportJob(), - OneTimeMaintenanceJob(execute_at=datetime.now() + timedelta(minutes=5)) - ] - + jobs = [HealthCheckJob(), DailyReportJob(), OneTimeMaintenanceJob(execute_at=datetime.now() + timedelta(minutes=5))] + return jobs diff --git a/src/talos/models/twitter.py b/src/talos/models/twitter.py index 6611470a..24455f36 100644 --- a/src/talos/models/twitter.py +++ b/src/talos/models/twitter.py @@ -40,16 +40,13 @@ class Tweet(BaseModel): referenced_tweets: Optional[list[ReferencedTweet]] = None in_reply_to_user_id: Optional[str] = None edit_history_tweet_ids: Optional[list[str]] = None - + def is_reply_to(self, tweet_id: str) -> bool: """Check if this tweet is a reply to the specified tweet ID.""" if not self.referenced_tweets: return False - return any( - ref.type == "replied_to" and ref.id == tweet_id - for ref in self.referenced_tweets - ) - + return any(ref.type == "replied_to" and ref.id == tweet_id for ref in self.referenced_tweets) + def get_replied_to_id(self) -> Optional[int]: """Get the ID of the tweet this is replying to, if any.""" if not self.referenced_tweets: diff --git a/src/talos/services/implementations/yield_manager.py b/src/talos/services/implementations/yield_manager.py index 69622edb..af350909 100644 --- a/src/talos/services/implementations/yield_manager.py +++ b/src/talos/services/implementations/yield_manager.py @@ -23,7 +23,7 @@ def __init__( raise ValueError("Min and max yield must be positive") if min_yield >= max_yield: raise ValueError("Min yield must be less than max yield") - + self.dexscreener_client = dexscreener_client self.gecko_terminal_client = gecko_terminal_client self.llm_client = llm_client @@ -51,7 +51,7 @@ def update_staking_apr(self, sentiment: float, sentiment_report: str) -> float: dexscreener_data, ohlcv_data, sentiment, staked_supply_percentage ) logging.info(f"Data source scores: {data_scores}") - + weighted_apr = self._calculate_weighted_apr_recommendation(data_scores) logging.info(f"Weighted APR recommendation: {weighted_apr}") @@ -66,9 +66,9 @@ def update_staking_apr(self, sentiment: float, sentiment_report: str) -> float: staked_supply_percentage=staked_supply_percentage, ohlcv_data=ohlcv_data.model_dump_json(), ) - + enhanced_prompt = f"{prompt}\n\nBased on weighted analysis of the data sources, the recommended APR is {weighted_apr:.4f}. Please consider this recommendation along with the raw data. The APR must be between {self.min_yield} and {self.max_yield}." - + response = self.llm_client.reasoning(enhanced_prompt, web_search=True) try: response_json = json.loads(response) @@ -81,84 +81,86 @@ def update_staking_apr(self, sentiment: float, sentiment_report: str) -> float: return max(self.min_yield, min(self.max_yield, weighted_apr)) final_apr = max(self.min_yield, min(self.max_yield, llm_apr)) - + if final_apr != llm_apr: logging.info(f"APR bounded from {llm_apr} to {final_apr} (min: {self.min_yield}, max: {self.max_yield})") - + return final_apr def get_staked_supply_percentage(self) -> float: return 0.45 - def _calculate_data_source_scores(self, dexscreener_data, ohlcv_data, sentiment: float, staked_supply_percentage: float) -> dict: + def _calculate_data_source_scores( + self, dexscreener_data, ohlcv_data, sentiment: float, staked_supply_percentage: float + ) -> dict: scores = {} - + price_change = dexscreener_data.price_change_h24 if price_change > 0.1: - scores['price_trend'] = 0.8 + scores["price_trend"] = 0.8 elif price_change > 0.05: - scores['price_trend'] = 0.6 + scores["price_trend"] = 0.6 elif price_change > -0.05: - scores['price_trend'] = 0.5 + scores["price_trend"] = 0.5 elif price_change > -0.1: - scores['price_trend'] = 0.3 + scores["price_trend"] = 0.3 else: - scores['price_trend'] = 0.1 - + scores["price_trend"] = 0.1 + volume = dexscreener_data.volume_h24 if volume > 1000000: - scores['volume_confidence'] = 0.8 + scores["volume_confidence"] = 0.8 elif volume > 500000: - scores['volume_confidence'] = 0.6 + scores["volume_confidence"] = 0.6 elif volume > 100000: - scores['volume_confidence'] = 0.4 + scores["volume_confidence"] = 0.4 else: - scores['volume_confidence'] = 0.2 - - scores['sentiment'] = max(0.0, min(1.0, sentiment / 100.0)) - + scores["volume_confidence"] = 0.2 + + scores["sentiment"] = max(0.0, min(1.0, sentiment / 100.0)) + if staked_supply_percentage > 0.8: - scores['supply_pressure'] = 0.2 + scores["supply_pressure"] = 0.2 elif staked_supply_percentage > 0.6: - scores['supply_pressure'] = 0.4 + scores["supply_pressure"] = 0.4 elif staked_supply_percentage > 0.4: - scores['supply_pressure'] = 0.6 + scores["supply_pressure"] = 0.6 elif staked_supply_percentage > 0.2: - scores['supply_pressure'] = 0.8 + scores["supply_pressure"] = 0.8 else: - scores['supply_pressure'] = 1.0 - + scores["supply_pressure"] = 1.0 + if ohlcv_data.ohlcv_list: recent_ohlcv = ohlcv_data.ohlcv_list[-5:] if len(recent_ohlcv) >= 2: price_range = max(item.high for item in recent_ohlcv) - min(item.low for item in recent_ohlcv) avg_price = sum(item.close for item in recent_ohlcv) / len(recent_ohlcv) volatility = price_range / avg_price if avg_price > 0 else 0 - + if volatility > 0.2: - scores['volatility'] = 0.3 + scores["volatility"] = 0.3 elif volatility > 0.1: - scores['volatility'] = 0.5 + scores["volatility"] = 0.5 else: - scores['volatility'] = 0.7 + scores["volatility"] = 0.7 else: - scores['volatility'] = 0.5 + scores["volatility"] = 0.5 else: - scores['volatility'] = 0.5 - + scores["volatility"] = 0.5 + return scores def _calculate_weighted_apr_recommendation(self, scores: dict) -> float: weights = { - 'price_trend': 0.25, - 'volume_confidence': 0.15, - 'sentiment': 0.20, - 'supply_pressure': 0.25, - 'volatility': 0.15 + "price_trend": 0.25, + "volume_confidence": 0.15, + "sentiment": 0.20, + "supply_pressure": 0.25, + "volatility": 0.15, } - + weighted_score = sum(scores[factor] * weights[factor] for factor in weights.keys()) - + apr_recommendation = self.min_yield + (weighted_score * (self.max_yield - self.min_yield)) - + return apr_recommendation diff --git a/src/talos/skills/twitter_influence.py b/src/talos/skills/twitter_influence.py index f1c15988..cadd11e7 100644 --- a/src/talos/skills/twitter_influence.py +++ b/src/talos/skills/twitter_influence.py @@ -34,7 +34,11 @@ def model_post_init(self, __context: Any) -> None: self.memory = Memory(file_path=memory_path, embeddings_model=embeddings, auto_save=True) if self.evaluator is None: - file_prompt_manager = self.prompt_manager if isinstance(self.prompt_manager, FilePromptManager) else FilePromptManager("src/talos/prompts") + file_prompt_manager = ( + self.prompt_manager + if isinstance(self.prompt_manager, FilePromptManager) + else FilePromptManager("src/talos/prompts") + ) self.evaluator = GeneralInfluenceEvaluator(self.twitter_client, self.llm, file_prompt_manager) @property diff --git a/src/talos/tools/document_loader.py b/src/talos/tools/document_loader.py index 5a196b5a..3cf3f286 100644 --- a/src/talos/tools/document_loader.py +++ b/src/talos/tools/document_loader.py @@ -17,17 +17,19 @@ class DocumentLoaderArgs(BaseModel): class DocumentLoaderTool(SupervisedTool): """Tool for loading documents from IPFS or URLs into the DatasetManager.""" - + name: str = "document_loader" description: str = "Loads documents from IPFS hashes or URLs and adds them to the dataset manager with intelligent chunking for RAG" args_schema: type[BaseModel] = DocumentLoaderArgs _dataset_manager: DatasetManager = PrivateAttr() - + def __init__(self, dataset_manager: DatasetManager, **kwargs): super().__init__(**kwargs) self._dataset_manager = dataset_manager - - def _run_unsupervised(self, name: str, source: str, chunk_size: int = 1000, chunk_overlap: int = 200, **kwargs: Any) -> str: + + def _run_unsupervised( + self, name: str, source: str, chunk_size: int = 1000, chunk_overlap: int = 200, **kwargs: Any + ) -> str: """Load document from IPFS hash or URL.""" try: if self._is_ipfs_hash(source): @@ -38,14 +40,14 @@ def _run_unsupervised(self, name: str, source: str, chunk_size: int = 1000, chun return f"Successfully loaded document from URL {source} into dataset '{name}'" except Exception as e: return f"Failed to load document: {str(e)}" - + def _is_ipfs_hash(self, source: str) -> bool: """Check if source is an IPFS hash.""" - if source.startswith('Qm') and len(source) == 46: + if source.startswith("Qm") and len(source) == 46: return True - if source.startswith('b') and len(source) > 46: + if source.startswith("b") and len(source) > 46: return True - if source.startswith('ipfs://'): + if source.startswith("ipfs://"): return True return False @@ -57,16 +59,16 @@ class DatasetSearchArgs(BaseModel): class DatasetSearchTool(SupervisedTool): """Tool for searching datasets in the DatasetManager.""" - + name: str = "dataset_search" description: str = "Search for similar content in loaded datasets" args_schema: type[BaseModel] = DatasetSearchArgs _dataset_manager: DatasetManager = PrivateAttr() - + def __init__(self, dataset_manager: DatasetManager, **kwargs): super().__init__(**kwargs) self._dataset_manager = dataset_manager - + def _run_unsupervised(self, query: str, k: int = 5, **kwargs: Any) -> list[str]: """Search for similar content in the datasets.""" try: diff --git a/src/talos/tools/general_influence_evaluator.py b/src/talos/tools/general_influence_evaluator.py index 48d13c46..717f0d6d 100644 --- a/src/talos/tools/general_influence_evaluator.py +++ b/src/talos/tools/general_influence_evaluator.py @@ -132,7 +132,7 @@ def _fallback_content_analysis(self, tweets: List[Any]) -> int: total_length = sum(len(tweet.text) for tweet in tweets) avg_length = total_length / len(tweets) - + if avg_length >= 200: length_score = 80 elif avg_length >= 100: @@ -144,7 +144,7 @@ def _fallback_content_analysis(self, tweets: List[Any]) -> int: original_tweets = [tweet for tweet in tweets if not tweet.text.startswith("RT @")] originality_ratio = len(original_tweets) / len(tweets) if tweets else 0 - + if originality_ratio >= 0.8: originality_score = 80 elif originality_ratio >= 0.6: @@ -189,46 +189,43 @@ def _calculate_authenticity_score(self, user: Any, tweets: List[Any] | None = No """Calculate enhanced authenticity score with advanced bot detection (0-100)""" if tweets is None: tweets = [] - + base_score = self._calculate_base_authenticity(user) - + engagement_score = self._calculate_engagement_authenticity(user, tweets) - + content_score = self._calculate_content_authenticity(tweets) - + temporal_score = self._calculate_temporal_authenticity(tweets) - + composite_score = int( - base_score * 0.40 + - engagement_score * 0.25 + - content_score * 0.20 + - temporal_score * 0.15 + base_score * 0.40 + engagement_score * 0.25 + content_score * 0.20 + temporal_score * 0.15 ) - + return min(100, max(0, composite_score)) def _calculate_base_authenticity(self, user: Any) -> int: """Calculate base authenticity score from account indicators (0-100)""" score = 0 - + account_age_days = (datetime.now(timezone.utc) - user.created_at).days if account_age_days > 1825: # 5+ years score += 35 - elif account_age_days > 1095: # 3+ years + elif account_age_days > 1095: # 3+ years score += 30 - elif account_age_days > 730: # 2+ years + elif account_age_days > 730: # 2+ years score += 25 - elif account_age_days > 365: # 1+ year + elif account_age_days > 365: # 1+ year score += 20 - elif account_age_days > 180: # 6+ months + elif account_age_days > 180: # 6+ months score += 10 - elif account_age_days < 30: # Suspicious new accounts + elif account_age_days < 30: # Suspicious new accounts score -= 10 - + if user.verified: score += 25 - - if user.profile_image_url and not user.profile_image_url.endswith('default_profile_images/'): + + if user.profile_image_url and not user.profile_image_url.endswith("default_profile_images/"): score += 15 if user.description and len(user.description) > 20: score += 10 @@ -236,77 +233,77 @@ def _calculate_base_authenticity(self, user: Any) -> int: score += 5 if user.url: score += 5 - + following = user.public_metrics.get("following_count", 0) - + if following > 50000: score -= 15 elif following > 10000: score -= 5 - + return min(100, max(0, score)) def _calculate_engagement_authenticity(self, user: Any, tweets: List[Any]) -> int: """Analyze engagement patterns for authenticity indicators (0-100)""" if not tweets: return 50 # Neutral score when no data available - + score = 50 # Start with neutral followers = user.public_metrics.get("followers_count", 0) - + if followers == 0: return 20 # Very suspicious - + engagement_rates = [] for tweet in tweets[:20]: # Analyze recent tweets engagement = ( - tweet.public_metrics.get("like_count", 0) + - tweet.public_metrics.get("retweet_count", 0) + - tweet.public_metrics.get("reply_count", 0) + tweet.public_metrics.get("like_count", 0) + + tweet.public_metrics.get("retweet_count", 0) + + tweet.public_metrics.get("reply_count", 0) ) rate = (engagement / followers) * 100 engagement_rates.append(rate) - + if engagement_rates: avg_rate = sum(engagement_rates) / len(engagement_rates) rate_variance = sum((r - avg_rate) ** 2 for r in engagement_rates) / len(engagement_rates) - + if rate_variance < 0.1: # Very consistent score += 20 elif rate_variance < 1.0: # Reasonably consistent score += 10 elif rate_variance > 10.0: # Highly inconsistent (suspicious) score -= 15 - + if avg_rate > 10: # >10% engagement rate is unusual score -= 20 elif avg_rate > 5: score -= 10 elif avg_rate < 0.1: # Very low engagement also suspicious score -= 10 - + like_counts = [t.public_metrics.get("like_count", 0) for t in tweets[:10]] retweet_counts = [t.public_metrics.get("retweet_count", 0) for t in tweets[:10]] - + if sum(like_counts) > 0 and sum(retweet_counts) > 0: like_rt_ratio = sum(like_counts) / sum(retweet_counts) if 2 <= like_rt_ratio <= 20: # Normal range score += 15 else: # Unusual ratios score -= 10 - + return min(100, max(0, score)) def _calculate_content_authenticity(self, tweets: List[Any]) -> int: """Analyze content patterns for authenticity indicators (0-100)""" if not tweets: return 50 # Neutral score when no data available - + score = 50 # Start with neutral - + tweet_texts = [tweet.text for tweet in tweets[:20]] unique_texts = set(tweet_texts) - + if len(tweet_texts) > 0: uniqueness_ratio = len(unique_texts) / len(tweet_texts) if uniqueness_ratio > 0.9: # High uniqueness @@ -315,97 +312,97 @@ def _calculate_content_authenticity(self, tweets: List[Any]) -> int: score += 15 elif uniqueness_ratio < 0.5: # Low uniqueness (suspicious) score -= 20 - + original_tweets = [t for t in tweets if not t.text.startswith("RT @")] - + if len(tweets) > 0: original_ratio = len(original_tweets) / len(tweets) if original_ratio > 0.7: # Mostly original content score += 20 elif original_ratio < 0.3: # Mostly retweets (suspicious) score -= 15 - + hashtag_counts = [] for tweet in tweets[:10]: - hashtag_count = tweet.text.count('#') + hashtag_count = tweet.text.count("#") hashtag_counts.append(hashtag_count) - + if hashtag_counts: avg_hashtags = sum(hashtag_counts) / len(hashtag_counts) if avg_hashtags > 5: # Excessive hashtag use score -= 15 elif 1 <= avg_hashtags <= 3: # Normal hashtag use score += 10 - + if original_tweets: avg_length = sum(len(t.text) for t in original_tweets) / len(original_tweets) if avg_length > 100: # Longer, more thoughtful tweets score += 15 elif avg_length < 30: # Very short tweets (suspicious) score -= 10 - + return min(100, max(0, score)) def _calculate_temporal_authenticity(self, tweets: List[Any]) -> int: """Analyze temporal posting patterns for authenticity indicators (0-100)""" if not tweets: return 50 # Neutral score when no data available - + score = 50 # Start with neutral - + # Analyze posting frequency tweets_with_dates = [t for t in tweets if t.created_at] if len(tweets_with_dates) < 2: return score - + timestamps = [] for tweet in tweets_with_dates[:20]: try: if isinstance(tweet.created_at, str): - timestamp = datetime.fromisoformat(tweet.created_at.replace('Z', '+00:00')) + timestamp = datetime.fromisoformat(tweet.created_at.replace("Z", "+00:00")) else: timestamp = tweet.created_at timestamps.append(timestamp) except (ValueError, AttributeError, TypeError): continue - + if len(timestamps) < 2: return score - + timestamps.sort() intervals = [] for i in range(1, len(timestamps)): - interval = (timestamps[i] - timestamps[i-1]).total_seconds() + interval = (timestamps[i] - timestamps[i - 1]).total_seconds() intervals.append(interval) - + if intervals: avg_interval = sum(intervals) / len(intervals) interval_variance = sum((i - avg_interval) ** 2 for i in intervals) / len(intervals) - + if interval_variance < (avg_interval * 0.1) ** 2 and len(intervals) > 5: score -= 20 # Too regular elif interval_variance > (avg_interval * 2) ** 2: - score += 10 # Natural variance - + score += 10 # Natural variance + if avg_interval < 300: # Less than 5 minutes average score -= 25 elif avg_interval < 3600: # Less than 1 hour average score -= 10 - + return min(100, max(0, score)) def _calculate_influence_score(self, user: Any) -> int: """Calculate influence score based on follower metrics (0-100)""" followers = user.public_metrics.get("followers_count", 0) following = user.public_metrics.get("following_count", 0) - + if followers >= 1000000: # 1M+ follower_score = 100 elif followers >= 100000: # 100K+ follower_score = 80 - elif followers >= 10000: # 10K+ + elif followers >= 10000: # 10K+ follower_score = 60 - elif followers >= 1000: # 1K+ + elif followers >= 1000: # 1K+ follower_score = 40 else: follower_score = 20 @@ -444,7 +441,7 @@ def _calculate_credibility_score(self, user: Any, tweets: List[Any]) -> int: if tweets: tweet_count = user.public_metrics.get("tweet_count", 0) account_age_days = (datetime.now(timezone.utc) - user.created_at).days - + if account_age_days > 0: tweets_per_day = tweet_count / account_age_days if 0.5 <= tweets_per_day <= 10: # Reasonable posting frequency @@ -454,7 +451,7 @@ def _calculate_credibility_score(self, user: Any, tweets: List[Any]) -> int: else: # Too much or too little posting score += 5 - if user.url and any(domain in user.url for domain in ['.com', '.org', '.edu', '.gov']): + if user.url and any(domain in user.url for domain in [".com", ".org", ".edu", ".gov"]): score += 10 return min(100, score) diff --git a/src/talos/tools/twitter_client.py b/src/talos/tools/twitter_client.py index 5f1d6c30..3ead4efb 100644 --- a/src/talos/tools/twitter_client.py +++ b/src/talos/tools/twitter_client.py @@ -6,7 +6,7 @@ from pydantic_settings import BaseSettings from textblob import TextBlob -from talos.models.twitter import TwitterUser, Tweet, ReferencedTweet +from talos.models.twitter import ReferencedTweet, Tweet, TwitterUser class PaginatedTwitterResponse: @@ -91,6 +91,7 @@ def get_user(self, username: str) -> TwitterUser: ], ) from talos.models.twitter import TwitterPublicMetrics + user_data = response.data return TwitterUser( id=int(user_data.id), @@ -101,7 +102,7 @@ def get_user(self, username: str) -> TwitterUser: public_metrics=TwitterPublicMetrics(**user_data.public_metrics), description=user_data.description, url=user_data.url, - verified=getattr(user_data, 'verified', False) + verified=getattr(user_data, "verified", False), ) def search_tweets( @@ -148,7 +149,15 @@ def get_user_timeline(self, username: str) -> list[Tweet]: return [] response = self.client.get_users_tweets( id=user.id, - tweet_fields=["author_id", "in_reply_to_user_id", "public_metrics", "referenced_tweets", "conversation_id", "created_at", "edit_history_tweet_ids"], + tweet_fields=[ + "author_id", + "in_reply_to_user_id", + "public_metrics", + "referenced_tweets", + "conversation_id", + "created_at", + "edit_history_tweet_ids", + ], user_fields=[ "created_at", "public_metrics", @@ -167,7 +176,15 @@ def get_user_mentions(self, username: str) -> list[Tweet]: return [] response = self.client.get_users_mentions( id=user.id, - tweet_fields=["author_id", "in_reply_to_user_id", "public_metrics", "referenced_tweets", "conversation_id", "created_at", "edit_history_tweet_ids"], + tweet_fields=[ + "author_id", + "in_reply_to_user_id", + "public_metrics", + "referenced_tweets", + "conversation_id", + "created_at", + "edit_history_tweet_ids", + ], user_fields=[ "created_at", "public_metrics", @@ -183,7 +200,15 @@ def get_user_mentions(self, username: str) -> list[Tweet]: def get_tweet(self, tweet_id: int) -> Tweet: response = self.client.get_tweet( str(tweet_id), - tweet_fields=["author_id", "in_reply_to_user_id", "public_metrics", "referenced_tweets", "conversation_id", "created_at", "edit_history_tweet_ids"] + tweet_fields=[ + "author_id", + "in_reply_to_user_id", + "public_metrics", + "referenced_tweets", + "conversation_id", + "created_at", + "edit_history_tweet_ids", + ], ) return self._convert_to_tweet_model(response.data) @@ -209,27 +234,31 @@ def reply_to_tweet(self, tweet_id: str, tweet: str) -> Any: def _convert_to_tweet_model(self, tweet_data: Any) -> Tweet: """Convert raw tweepy tweet data to Tweet BaseModel""" referenced_tweets = [] - if hasattr(tweet_data, 'referenced_tweets') and tweet_data.referenced_tweets: + if hasattr(tweet_data, "referenced_tweets") and tweet_data.referenced_tweets: for ref in tweet_data.referenced_tweets: if isinstance(ref, dict): - referenced_tweets.append(ReferencedTweet( - type=ref.get('type', ''), - id=ref.get('id', 0) - )) + referenced_tweets.append(ReferencedTweet(type=ref.get("type", ""), id=ref.get("id", 0))) else: - referenced_tweets.append(ReferencedTweet( - type=getattr(ref, 'type', ''), - id=getattr(ref, 'id', 0) - )) - + referenced_tweets.append(ReferencedTweet(type=getattr(ref, "type", ""), id=getattr(ref, "id", 0))) + return Tweet( id=int(tweet_data.id), text=tweet_data.text, author_id=str(tweet_data.author_id), - created_at=str(tweet_data.created_at) if hasattr(tweet_data, 'created_at') and tweet_data.created_at else None, - conversation_id=str(tweet_data.conversation_id) if hasattr(tweet_data, 'conversation_id') and tweet_data.conversation_id else None, - public_metrics=dict(tweet_data.public_metrics) if hasattr(tweet_data, 'public_metrics') and tweet_data.public_metrics else {}, + created_at=str(tweet_data.created_at) + if hasattr(tweet_data, "created_at") and tweet_data.created_at + else None, + conversation_id=str(tweet_data.conversation_id) + if hasattr(tweet_data, "conversation_id") and tweet_data.conversation_id + else None, + public_metrics=dict(tweet_data.public_metrics) + if hasattr(tweet_data, "public_metrics") and tweet_data.public_metrics + else {}, referenced_tweets=referenced_tweets if referenced_tweets else None, - in_reply_to_user_id=str(tweet_data.in_reply_to_user_id) if hasattr(tweet_data, 'in_reply_to_user_id') and tweet_data.in_reply_to_user_id else None, - edit_history_tweet_ids=[str(id) for id in tweet_data.edit_history_tweet_ids] if hasattr(tweet_data, 'edit_history_tweet_ids') and tweet_data.edit_history_tweet_ids else None + in_reply_to_user_id=str(tweet_data.in_reply_to_user_id) + if hasattr(tweet_data, "in_reply_to_user_id") and tweet_data.in_reply_to_user_id + else None, + edit_history_tweet_ids=[str(id) for id in tweet_data.edit_history_tweet_ids] + if hasattr(tweet_data, "edit_history_tweet_ids") and tweet_data.edit_history_tweet_ids + else None, ) diff --git a/tests/test_document_loader.py b/tests/test_document_loader.py index b016c8a5..9969e2c8 100644 --- a/tests/test_document_loader.py +++ b/tests/test_document_loader.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, patch from talos.data.dataset_manager import DatasetManager -from talos.tools.document_loader import DocumentLoaderTool, DatasetSearchTool +from talos.tools.document_loader import DatasetSearchTool, DocumentLoaderTool class TestDocumentLoader(unittest.TestCase): @@ -11,28 +11,28 @@ def setUp(self): self.dataset_manager = DatasetManager() self.document_loader = DocumentLoaderTool(self.dataset_manager) self.dataset_search = DatasetSearchTool(self.dataset_manager) - + def test_is_ipfs_hash(self): self.assertTrue(self.document_loader._is_ipfs_hash("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG")) self.assertTrue(self.document_loader._is_ipfs_hash("ipfs://QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG")) self.assertFalse(self.document_loader._is_ipfs_hash("https://example.com/document.pdf")) - - @patch('talos.data.dataset_manager.requests.get') + + @patch("talos.data.dataset_manager.requests.get") def test_fetch_content_from_url_text(self, mock_get): mock_response = Mock() - mock_response.headers = {'content-type': 'text/plain'} + mock_response.headers = {"content-type": "text/plain"} mock_response.text = "This is a test document." mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - + content = self.dataset_manager._fetch_content_from_url("https://example.com/test.txt") self.assertEqual(content, "This is a test document.") - + def test_clean_text(self): dirty_text = "This is a\n\n\n\ntest document." clean_text = self.dataset_manager._clean_text(dirty_text) self.assertEqual(clean_text, "This is a\n\ntest document.") - + def test_chunk_content(self): content = "This is sentence one. This is sentence two. This is sentence three." chunks = self.dataset_manager._process_and_chunk_content(content, chunk_size=30, chunk_overlap=10) @@ -40,5 +40,5 @@ def test_chunk_content(self): self.assertTrue(all(len(chunk) <= 40 for chunk in chunks)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_scheduled_jobs.py b/tests/test_scheduled_jobs.py index 153cd059..8118e8fb 100644 --- a/tests/test_scheduled_jobs.py +++ b/tests/test_scheduled_jobs.py @@ -8,30 +8,27 @@ import pytest from langchain_openai import ChatOpenAI +from talos.core.job_scheduler import JobScheduler from talos.core.main_agent import MainAgent from talos.core.scheduled_job import ScheduledJob -from talos.core.job_scheduler import JobScheduler class MockScheduledJob(ScheduledJob): """Test implementation of ScheduledJob for testing purposes.""" - + execution_count: int = 0 last_execution: Optional[datetime] = None - + def __init__(self, name: str = "test_job", **kwargs): - if 'cron_expression' not in kwargs and 'execute_at' not in kwargs: - kwargs['cron_expression'] = "0 * * * *" # Every hour - - kwargs.setdefault('description', "Test scheduled job") - kwargs.setdefault('execution_count', 0) - kwargs.setdefault('last_execution', None) - - super().__init__( - name=name, - **kwargs - ) - + if "cron_expression" not in kwargs and "execute_at" not in kwargs: + kwargs["cron_expression"] = "0 * * * *" # Every hour + + kwargs.setdefault("description", "Test scheduled job") + kwargs.setdefault("execution_count", 0) + kwargs.setdefault("last_execution", None) + + super().__init__(name=name, **kwargs) + async def run(self, **kwargs) -> str: self.execution_count += 1 self.last_execution = datetime.now() @@ -40,18 +37,14 @@ async def run(self, **kwargs) -> str: class MockOneTimeJob(ScheduledJob): """Test implementation of one-time ScheduledJob.""" - + executed: bool = False - + def __init__(self, execute_at: datetime, **kwargs): super().__init__( - name="one_time_test", - description="One-time test job", - execute_at=execute_at, - executed=False, - **kwargs + name="one_time_test", description="One-time test job", execute_at=execute_at, executed=False, **kwargs ) - + async def run(self, **kwargs) -> str: self.executed = True return "One-time job executed" @@ -59,7 +52,7 @@ async def run(self, **kwargs) -> str: class TestScheduledJobValidation: """Test ScheduledJob validation and configuration.""" - + def test_cron_job_creation(self): """Test creating a job with cron expression.""" job = MockScheduledJob(name="cron_test", cron_expression="0 9 * * *") @@ -68,7 +61,7 @@ def test_cron_job_creation(self): assert job.execute_at is None assert job.is_recurring() assert not job.is_one_time() - + def test_one_time_job_creation(self): """Test creating a one-time job with datetime.""" future_time = datetime.now() + timedelta(hours=1) @@ -78,7 +71,7 @@ def test_one_time_job_creation(self): assert job.cron_expression is None assert job.is_one_time() assert not job.is_recurring() - + def test_job_validation_requires_schedule(self): """Test that job validation requires either cron or datetime.""" with pytest.raises(ValueError, match="Either cron_expression or execute_at must be provided"): @@ -88,9 +81,9 @@ def test_job_validation_requires_schedule(self): cron_expression=None, execute_at=None, execution_count=0, - last_execution=None + last_execution=None, ) - + def test_job_validation_exclusive_schedule(self): """Test that job validation prevents both cron and datetime.""" future_time = datetime.now() + timedelta(hours=1) @@ -101,18 +94,18 @@ def test_job_validation_exclusive_schedule(self): cron_expression="0 * * * *", execute_at=future_time, execution_count=0, - last_execution=None + last_execution=None, ) - + def test_should_execute_now(self): """Test should_execute_now method for one-time jobs.""" past_time = datetime.now() - timedelta(minutes=1) future_time = datetime.now() + timedelta(minutes=1) - + past_job = MockOneTimeJob(execute_at=past_time) future_job = MockOneTimeJob(execute_at=future_time) cron_job = MockScheduledJob() - + assert past_job.should_execute_now() assert not future_job.should_execute_now() assert not cron_job.should_execute_now() @@ -120,147 +113,152 @@ def test_should_execute_now(self): class TestJobScheduler: """Test JobScheduler functionality.""" - + @pytest.fixture def scheduler(self): """Create a JobScheduler instance for testing.""" return JobScheduler() - + def test_scheduler_initialization(self, scheduler): """Test scheduler initialization.""" assert scheduler.timezone == "UTC" assert not scheduler.is_running() assert len(scheduler.list_jobs()) == 0 - + def test_register_job(self, scheduler): """Test job registration.""" job = MockScheduledJob(name="test_register") scheduler.register_job(job) - + assert len(scheduler.list_jobs()) == 1 assert scheduler.get_job("test_register") == job - + def test_unregister_job(self, scheduler): """Test job unregistration.""" job = MockScheduledJob(name="test_unregister") scheduler.register_job(job) - + assert scheduler.unregister_job("test_unregister") assert len(scheduler.list_jobs()) == 0 assert scheduler.get_job("test_unregister") is None - + def test_unregister_nonexistent_job(self, scheduler): """Test unregistering a job that doesn't exist.""" assert not scheduler.unregister_job("nonexistent") - + def test_register_disabled_job(self, scheduler): """Test registering a disabled job.""" job = MockScheduledJob(name="disabled_job", enabled=False) scheduler.register_job(job) - + assert len(scheduler.list_jobs()) == 1 assert scheduler.get_job("disabled_job") == job class TestMainAgentIntegration: """Test MainAgent integration with scheduled jobs.""" - + @pytest.fixture def main_agent(self): """Create a MainAgent instance for testing.""" - with patch.dict('os.environ', { - 'GITHUB_TOKEN': 'test_token', - 'TWITTER_BEARER_TOKEN': 'test_twitter_token', - 'OPENAI_API_KEY': 'test_openai_key' - }): - agent = MainAgent( - model=ChatOpenAI(model="gpt-4o", api_key="test_key"), - prompts_dir="src/talos/prompts" - ) + with patch.dict( + "os.environ", + { + "GITHUB_TOKEN": "test_token", + "TWITTER_BEARER_TOKEN": "test_twitter_token", + "OPENAI_API_KEY": "test_openai_key", + }, + ): + agent = MainAgent(model=ChatOpenAI(model="gpt-4o", api_key="test_key"), prompts_dir="src/talos/prompts") if agent.job_scheduler: agent.job_scheduler.stop() return agent - + def test_main_agent_scheduler_initialization(self, main_agent): """Test that MainAgent initializes with a job scheduler.""" assert main_agent.job_scheduler is not None assert isinstance(main_agent.job_scheduler, JobScheduler) - + def test_add_scheduled_job(self, main_agent): """Test adding a scheduled job to MainAgent.""" job = MockScheduledJob(name="main_agent_test") main_agent.add_scheduled_job(job) - + assert len(main_agent.list_scheduled_jobs()) == 1 assert main_agent.get_scheduled_job("main_agent_test") == job - + def test_remove_scheduled_job(self, main_agent): """Test removing a scheduled job from MainAgent.""" job = MockScheduledJob(name="remove_test") main_agent.add_scheduled_job(job) - + assert main_agent.remove_scheduled_job("remove_test") assert len(main_agent.list_scheduled_jobs()) == 0 assert main_agent.get_scheduled_job("remove_test") is None - + def test_pause_resume_job(self, main_agent): """Test pausing and resuming jobs.""" job = MockScheduledJob(name="pause_test") main_agent.add_scheduled_job(job) - + main_agent.pause_scheduled_job("pause_test") main_agent.resume_scheduled_job("pause_test") - + def test_predefined_jobs_registration(self): """Test that predefined jobs are registered during initialization.""" job = MockScheduledJob(name="predefined_job") - - with patch.dict('os.environ', { - 'GITHUB_TOKEN': 'test_token', - 'TWITTER_BEARER_TOKEN': 'test_twitter_token', - 'OPENAI_API_KEY': 'test_openai_key' - }): + + with patch.dict( + "os.environ", + { + "GITHUB_TOKEN": "test_token", + "TWITTER_BEARER_TOKEN": "test_twitter_token", + "OPENAI_API_KEY": "test_openai_key", + }, + ): agent = MainAgent( model=ChatOpenAI(model="gpt-4o", api_key="test_key"), prompts_dir="src/talos/prompts", - scheduled_jobs=[job] + scheduled_jobs=[job], ) if agent.job_scheduler: agent.job_scheduler.stop() - + assert len(agent.list_scheduled_jobs()) == 1 assert agent.get_scheduled_job("predefined_job") == job def test_job_execution(): """Test that jobs can be executed.""" + async def run_test(): job = MockScheduledJob(name="execution_test") - + result = await job.run() - + assert job.execution_count == 1 assert job.last_execution is not None assert result == "Test job executed 1 times" - + result2 = await job.run() assert job.execution_count == 2 assert result2 == "Test job executed 2 times" - + asyncio.run(run_test()) def test_one_time_job_execution(): """Test one-time job execution.""" + async def run_test(): future_time = datetime.now() + timedelta(seconds=1) job = MockOneTimeJob(execute_at=future_time) - + assert not job.executed - + result = await job.run() - + assert job.executed assert result == "One-time job executed" - + asyncio.run(run_test()) diff --git a/tests/test_yield_manager.py b/tests/test_yield_manager.py index 0d8ece82..9ebd2952 100644 --- a/tests/test_yield_manager.py +++ b/tests/test_yield_manager.py @@ -51,7 +51,7 @@ def test_min_max_yield_validation(self, mock_tweepy_client): with self.assertRaises(ValueError): YieldManagerService(dexscreener_client, gecko_terminal_client, llm_client, min_yield=-0.01) - + with self.assertRaises(ValueError): YieldManagerService(dexscreener_client, gecko_terminal_client, llm_client, min_yield=0.2, max_yield=0.1) @@ -67,20 +67,18 @@ def test_apr_bounds_enforcement(self, mock_tweepy_client): volume=1000000, ) gecko_terminal_client.get_ohlcv_data.return_value = GeckoTerminalOHLCVData(ohlcv_list=[]) - - llm_client.reasoning.return_value = json.dumps( - {"apr": 0.25, "explanation": "High APR recommendation"} - ) - yield_manager = YieldManagerService(dexscreener_client, gecko_terminal_client, llm_client, min_yield=0.05, max_yield=0.20) + llm_client.reasoning.return_value = json.dumps({"apr": 0.25, "explanation": "High APR recommendation"}) + + yield_manager = YieldManagerService( + dexscreener_client, gecko_terminal_client, llm_client, min_yield=0.05, max_yield=0.20 + ) yield_manager.get_staked_supply_percentage = MagicMock(return_value=0.5) new_apr = yield_manager.update_staking_apr(75.0, "A report") self.assertEqual(new_apr, 0.20) - llm_client.reasoning.return_value = json.dumps( - {"apr": 0.01, "explanation": "Low APR recommendation"} - ) + llm_client.reasoning.return_value = json.dumps({"apr": 0.01, "explanation": "Low APR recommendation"}) new_apr = yield_manager.update_staking_apr(75.0, "A report") self.assertEqual(new_apr, 0.05)