diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..3130744d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,97 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Development Commands + +### Setup and Dependencies +- `uv venv` - Create virtual environment +- `source .venv/bin/activate` - Activate virtual environment +- `./scripts/install_deps.sh` - Install dependencies +- `uv pip install -e .[dev]` - Install project in development mode with dev dependencies + +### Code Quality and Testing +- `./scripts/run_checks.sh` - Run complete test suite (ruff format, ruff check, mypy, pytest) +- `uv run ruff format .` - Format code +- `uv run ruff check .` - Lint code +- `uv run mypy src` - Type check source code +- `uv run pytest` - Run all tests +- `uv run pytest tests/test_specific.py` - Run specific test file +- `uv run isort .` - Sort imports + +### Running the Application +- `uv run talos` - Start interactive CLI (requires OPENAI_API_KEY, PINATA_API_KEY, PINATA_SECRET_API_KEY environment variables) +- `uv run talos proposals eval --file PROPOSAL_FILE` - Evaluate a proposal from file +- `uv run talos twitter get-user-prompt USERNAME` - Get Twitter user persona prompt + +## Architecture Overview + +Talos is an AI protocol owner system built with a sophisticated multi-agent architecture: + +### Core Components + +**MainAgent** (`src/talos/core/main_agent.py`): The top-level orchestrating agent that: +- Delegates tasks using a Router to different services and skills +- Manages scheduled jobs for autonomous execution +- Integrates with Hypervisor for action approval +- Handles dataset management for contextual information retrieval + +**Agent Base Class** (`src/talos/core/agent.py`): Foundation for all agents providing: +- LangChain integration with chat models +- Message history management and memory persistence +- Tool management and supervised tool execution +- Context building with dataset search 
integration + +**Hypervisor** (`src/talos/hypervisor/hypervisor.py`): Security and governance layer that: +- Monitors and approves/denies agent actions +- Uses dedicated prompts to evaluate action safety +- Maintains agent history for context-aware decisions + +### Key Architectural Patterns + +**Router Pattern** (`src/talos/core/router.py`): Delegates tasks to appropriate services and skills based on request classification. + +**Skill-based Architecture** (`src/talos/skills/`): Modular capabilities including: +- ProposalsSkill - Governance proposal evaluation +- TwitterSentimentSkill - Social media sentiment analysis +- CryptographySkill - Cryptographic operations +- TwitterInfluenceSkill - Influence analysis + +**Service Layer** (`src/talos/services/`): Abstract and concrete implementations for external integrations: +- GitHub service for repository management +- Twitter service for social media interaction +- Proposal service for governance workflows +- Yield management for DeFi operations + +**Tool Management** (`src/talos/tools/`): Extensible tool system with: +- SupervisedTool base class requiring hypervisor approval +- Platform-specific tools (Twitter, GitHub, DexScreener) +- Document loading and search capabilities +- Memory management tools + +### Data Flow + +1. User input → MainAgent → Router → appropriate Service/Skill +2. Actions requiring approval → Hypervisor → approval/denial +3. Context building → DatasetManager → relevant document retrieval +4. 
Tool execution → ToolManager → supervised execution + +## Code Style Guidelines + +- Use `from __future__ import annotations` for forward references +- Prefer `list` and `dict` over `List` and `Dict` in type hints +- Use `model_post_init` for Pydantic model initialization logic +- Set `arbitrary_types_allowed=True` in ConfigDict for LangChain integration +- Default LLM model is `gpt-4o` +- Line length limit is 120 characters (configured in ruff) +- All function signatures require type hints + +## Environment Variables + +Required for full functionality: +- `OPENAI_API_KEY` - OpenAI API access +- `GITHUB_TOKEN` - GitHub API access +- `PINATA_API_KEY` / `PINATA_SECRET_API_KEY` - IPFS storage +- Additional service-specific keys as needed +- Always run `./scripts/run_checks.sh` after performing a large code change. +- Code commenting principle: Only include comments that explain why something is done or provide context that isn't obvious from the code itself - avoid comments that simply restate what the code is doing, as these add noise without value. \ No newline at end of file diff --git a/experimental/__init__.py b/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/archon/__init__.py b/experimental/archon/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/archon/src/__init__.py b/experimental/archon/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/archon/src/graph_executor.py b/experimental/archon/src/graph_executor.py new file mode 100644 index 00000000..365bf171 --- /dev/null +++ b/experimental/archon/src/graph_executor.py @@ -0,0 +1,87 @@ +""" +LangGraph Executor Module + +Provides GraphExecutor class for loading and executing stored graphs from IPFS. +This encapsulates the process of retrieving a graph by its IPFS hash and running +it with provided input, handling both sync and async execution automatically. 
class LoadedGraph:
    """Container for a loaded graph together with its state class.

    Attributes are plain references: ``compiled_graph`` is the object whose
    ``ainvoke`` is awaited by :meth:`execute`, and ``state_class`` is the class
    used to build and validate state objects.
    """

    def __init__(self, compiled_graph: CompiledGraph, state_class: type[BaseModel]):
        self.compiled_graph = compiled_graph
        self.state_class = state_class

    async def execute(self, input_state: Union[BaseModel, dict[str, Any]]) -> dict[str, Any]:
        """Execute the graph with type-aware input handling.

        Args:
            input_state: Either a Pydantic state instance or an already-plain dict.

        Returns:
            Whatever ``compiled_graph.ainvoke`` returns for the given state.
        """
        # Convert Pydantic model to dict if needed (LangGraph expects dict)
        if isinstance(input_state, BaseModel):
            state_dict = input_state.model_dump()
        else:
            state_dict = input_state

        return await self.compiled_graph.ainvoke(state_dict)

    def create_state(self, **kwargs: Any) -> BaseModel:
        """Create a state instance by calling ``state_class`` with ``kwargs``."""
        return self.state_class(**kwargs)

    def validate_result(self, result: dict[str, Any]) -> BaseModel:
        """Validate a result dict and convert it back to the typed state."""
        return self.state_class.model_validate(result)


class GraphExecutor:
    """
    Executes stored LangGraph workflows loaded from IPFS with enhanced type safety.

    Provides both simple execution and advanced type-aware methods.

    Example:
        >>> executor = GraphExecutor()
        >>> # Simple execution
        >>> result = await executor.execute_graph(ipfs_hash, {"input": "data"})
        >>>
        >>> # Type-aware execution
        >>> loaded = executor.load_graph(ipfs_hash)
        >>> state = loaded.create_state(input_text="Hello world!")
        >>> result_dict = await loaded.execute(state)
        >>> typed_result = loaded.validate_result(result_dict)
    """

    def __init__(self) -> None:
        """Initialize the GraphExecutor with a GraphLoader instance."""
        self.loader = GraphLoader()

    def load_graph(self, ipfs_hash: str) -> LoadedGraph:
        """Fetch a stored definition once and turn it into a ``LoadedGraph``.

        Args:
            ipfs_hash: IPFS hash of the stored graph definition.

        Returns:
            The compiled graph paired with the state class named by the
            definition's ``state_schema.class_reference``.
        """
        stored_definition = self.loader.retrieve_from_ipfs(ipfs_hash)
        state_class = self.loader._load_class_from_reference(stored_definition.state_schema.class_reference)

        # Build from the definition already in hand. Going through
        # loader.load_graph(ipfs_hash) would download the same content a second
        # time, and the state class and graph must come from the same payload.
        compiled_graph = self.loader.recreate_graph(stored_definition)

        return LoadedGraph(compiled_graph, state_class)

    async def execute_graph(self, ipfs_hash: str, input_state: dict[str, Any]) -> dict[str, Any]:
        """
        Load and execute a graph from IPFS with the given input state.

        Args:
            ipfs_hash: IPFS hash of the stored graph definition
            input_state: Input state dictionary matching the graph's state schema

        Returns:
            Final state dictionary after graph execution
        """
        loaded_graph = self.load_graph(ipfs_hash)
        return await loaded_graph.execute(input_state)
class GraphLoader:
    """
    Manages serialization, IPFS storage, and deserialization of LangGraph workflows.

    Combines LangGraph's native serialization with Pydantic validation and
    IPFS storage via Pinata for decentralized, governance-ready AI workflows.

    Example:
        >>> loader = GraphLoader()
        >>> ipfs_hash = loader.save_graph_from_builder(state_graph, "my_workflow", "Description")
        >>> recreated_graph = loader.load_graph(ipfs_hash)
    """

    # Upper bound (seconds) for the gateway download; without one a stalled
    # gateway would block the caller indefinitely.
    GATEWAY_TIMEOUT_SECONDS: float = 30.0

    @staticmethod
    def _callable_reference(runnable: Any, missing_message: str) -> str:
        """Return the 'module:function' reference for the callable wrapped by *runnable*.

        Raises:
            TypeError: If *runnable* is not a RunnableCallable.
            ValueError: With *missing_message* if neither ``func`` nor ``afunc`` is set.
        """
        # Explicit raise instead of assert: asserts disappear under ``python -O``.
        if not isinstance(runnable, RunnableCallable):
            raise TypeError("Only RunnableCallables are currently supported.")

        func = runnable.func if runnable.func is not None else runnable.afunc
        if func is None:
            raise ValueError(missing_message)

        return f"{func.__module__}:{func.__name__}"

    def serialize_graph_from_builder(
        self,
        state_graph: StateGraph,
        name: str,
        description: str,
        created_by: str = "experimental_poc",
    ) -> str:
        """
        Serialize graph from StateGraph builder (before compilation) to capture serializable references.

        Args:
            state_graph: The StateGraph builder (before calling .compile())
            name: Name for the workflow
            description: Description of what the workflow does
            created_by: Who created this workflow

        Returns:
            JSON string representation of the serializable graph definition

        Raises:
            TypeError: If a node or branch condition is not a RunnableCallable
            ValueError: If a node or branch condition wraps no function, or the
                state schema is not a Pydantic BaseModel

        This approach extracts function references and graph structure from the StateGraph
        builder before compilation, allowing for proper deserialization.
        """
        # Extract serializable node definitions
        nodes: list[GraphNodeDefinition] = []
        for node_name, node_spec in state_graph.nodes.items():
            function_reference = self._callable_reference(
                node_spec.runnable,
                f"Node '{node_name}' has no valid function reference. "
                f"Expected runnable.func or runnable.afunc to be set",
            )
            nodes.append(GraphNodeDefinition(name=node_name, function_reference=function_reference))

        # Extract simple edges
        edges = [GraphEdgeDefinition(source=source, target=target) for source, target in state_graph.edges]

        # Extract conditional edges from branches
        conditional_edges: list[ConditionalEdgeDefinition] = []
        for source_node, branch_dict in state_graph.branches.items():
            for branch_obj in branch_dict.values():
                condition_function_reference = self._callable_reference(
                    branch_obj.path,
                    f"Conditional edge from '{source_node}' has no valid condition function. "
                    f"Expected branch_obj.path.func or branch_obj.path.afunc to be set",
                )
                conditional_edges.append(
                    ConditionalEdgeDefinition(
                        source_node=source_node,
                        condition_function_reference=condition_function_reference,
                        target_mapping=branch_obj.ends or {},
                    )
                )

        # State channel extraction is not implemented yet (it depends on
        # StateGraph internals), so the stored list is always empty for now.
        state_channels: list[StateChannelDefinition] = []

        # Create serializable graph definition
        serializable_def = SerializableGraphDefinition(
            nodes=nodes,
            edges=edges,
            conditional_edges=conditional_edges,
            state_channels=state_channels,
            state_type_name=state_graph.schema.__name__,
        )

        # Create state schema from the StateGraph's state schema
        state_schema = self._extract_state_schema(state_graph)

        # Create complete stored definition
        stored_definition = StoredGraphDefinition(
            metadata=GraphMetadata(
                name=name,
                description=description,
                created_by=created_by,
            ),
            graph_definition=serializable_def,
            state_schema=state_schema,
            execution_config=ExecutionConfig(),
        )

        return stored_definition.model_dump_json(indent=2)

    def _extract_state_schema(self, state_graph: StateGraph) -> StateSchema:
        """
        Extract Pydantic state schema information from StateGraph builder.

        Args:
            state_graph: StateGraph builder instance

        Returns:
            StateSchema with complete Pydantic model information

        Raises:
            ValueError: If state schema is not a Pydantic BaseModel
        """
        # Get the first (and should be only) schema class
        schema_class = next(iter(state_graph.schemas))

        # Raise the documented ValueError rather than asserting, so the check
        # survives ``python -O`` and non-class schemas are reported cleanly.
        if not (isinstance(schema_class, type) and issubclass(schema_class, BaseModel)):
            raise ValueError(f"State schema must be a Pydantic BaseModel, got {type(schema_class)}")

        class_reference = f"{schema_class.__module__}:{schema_class.__name__}"

        return StateSchema(
            name=schema_class.__name__,
            class_reference=class_reference,
        )

    def store_to_ipfs(self, graph_json: str) -> str:
        """
        Store serialized graph on IPFS via Pinata and return hash.

        Args:
            graph_json: JSON string representation of the graph

        Returns:
            IPFS hash of the stored content

        Raises:
            ValueError: If the Pinata credentials are missing from the environment
        """
        api_key = os.getenv("PINATA_API_KEY")
        secret_key = os.getenv("PINATA_SECRET_API_KEY")

        if not api_key or not secret_key:
            raise ValueError("PINATA_API_KEY and PINATA_SECRET_API_KEY environment variables required for IPFS storage")

        pinata = Pinning(PINATA_API_KEY=api_key, PINATA_API_SECRET=secret_key)

        # Pin JSON content to IPFS via Pinata. The file is written as UTF-8
        # explicitly so the pinned bytes do not depend on the platform's locale.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False, encoding="utf-8") as f:
            f.write(graph_json)
            temp_path: str = f.name

        try:
            response: dict[str, Any] = pinata.pin_file_to_ipfs(temp_path)
            return response["IpfsHash"]
        finally:
            # delete=False above means cleanup is our responsibility.
            os.unlink(temp_path)

    def retrieve_from_ipfs(self, ipfs_hash: str) -> StoredGraphDefinition:
        """
        Retrieve and deserialize graph from IPFS via Pinata gateway.

        Args:
            ipfs_hash: IPFS hash of the stored graph definition

        Returns:
            StoredGraphDefinition object with validated data
        """
        gateway_url = f"https://gateway.pinata.cloud/ipfs/{ipfs_hash}"
        response = requests.get(gateway_url, timeout=self.GATEWAY_TIMEOUT_SECONDS)
        response.raise_for_status()

        return StoredGraphDefinition.model_validate(response.json())

    def recreate_graph_from_definition(self, stored_definition: StoredGraphDefinition) -> StateGraph:
        """
        Recreate a StateGraph from its serializable definition.

        Args:
            stored_definition: StoredGraphDefinition with SerializableGraphDefinition

        Returns:
            Recreated StateGraph ready for compilation

        Raises:
            ImportError: If the state class, a node, or a conditional edge cannot be loaded
            ValueError: If the referenced state class is not a Pydantic BaseModel

        Note:
            This implementation dynamically imports and loads functions based on
            the serialized function references (module + name).
        """
        state_class = self._load_class_from_reference(stored_definition.state_schema.class_reference)
        if not issubclass(state_class, BaseModel):
            raise ValueError(f"Loaded state class must be a Pydantic BaseModel, got {type(state_class)}")

        builder = StateGraph(state_class)

        for node_def in stored_definition.graph_definition.nodes:
            try:
                func = self._load_function_from_reference(node_def.function_reference)
                builder.add_node(node_def.name, func)
            except Exception as e:
                raise ImportError(f"Failed to load node {node_def.name}: {e}") from e

        for edge_def in stored_definition.graph_definition.edges:
            source = START if edge_def.source == "__start__" else edge_def.source
            target = END if edge_def.target == "__end__" else edge_def.target
            builder.add_edge(source, target)

        for cond_edge_def in stored_definition.graph_definition.conditional_edges:
            try:
                condition_func = self._load_function_from_reference(cond_edge_def.condition_function_reference)
                target_mapping: dict[Hashable, str] = {
                    k: END if v == "__end__" else v for k, v in cond_edge_def.target_mapping.items()
                }
                builder.add_conditional_edges(cond_edge_def.source_node, condition_func, target_mapping)
            except Exception as e:
                raise ImportError(f"Failed to load conditional edge from {cond_edge_def.source_node}: {e}") from e

        return builder

    def _load_function_from_reference(
        self, function_reference: str
    ) -> Union[Callable[..., Any], Callable[..., Awaitable[Any]]]:
        """
        Dynamically load a function from a module:function reference string.

        Raises:
            ImportError: If the reference is malformed, the module cannot be
                imported, or the attribute does not exist (original error chained)
        """
        try:
            module_name, function_name = function_reference.split(":", 1)
            module = importlib.import_module(module_name)
            return getattr(module, function_name)
        except (ImportError, AttributeError, ValueError) as e:
            raise ImportError(f"Could not load function from reference '{function_reference}': {e}") from e

    def _load_class_from_reference(self, class_reference: str) -> type:
        """
        Dynamically load a class from a module:class reference string.

        Raises:
            ImportError: If the reference is malformed, cannot be imported, or
                resolves to something that is not a class (original error chained)
        """
        try:
            module_name, class_name = class_reference.split(":", 1)
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            if not isinstance(cls, type):
                raise ValueError(f"'{class_name}' is not a class")
            return cls
        except (ImportError, AttributeError, ValueError) as e:
            raise ImportError(f"Could not load class from reference '{class_reference}': {e}") from e

    def recreate_graph(self, stored_definition: StoredGraphDefinition) -> CompiledGraph:
        """
        Recreate and compile a LangGraph from its stored definition.
        """
        state_graph = self.recreate_graph_from_definition(stored_definition)
        return state_graph.compile()

    def save_graph_from_builder(
        self,
        state_graph: StateGraph,
        name: str,
        description: str,
        created_by: str = "experimental_poc",
    ) -> str:
        """
        High-level method: serialize StateGraph builder and store to IPFS in one operation.

        Args:
            state_graph: The StateGraph builder (before compilation)
            name: Name for the workflow
            description: Description of what the workflow does
            created_by: Who created this workflow

        Returns:
            IPFS hash of the stored graph definition
        """
        graph_json = self.serialize_graph_from_builder(state_graph, name, description, created_by)
        return self.store_to_ipfs(graph_json)

    def load_graph(self, ipfs_hash: str) -> CompiledGraph:
        """
        High-level method: retrieve from IPFS and recreate graph in one operation.

        Args:
            ipfs_hash: IPFS hash of the stored graph definition

        Returns:
            Recreated CompiledGraph ready for execution

        Note:
            The returned graph can contain both sync and async functions.
            Use `await graph.ainvoke(input)` for execution - this works
            universally for pure sync, pure async, and mixed graphs.
        """
        stored_definition = self.retrieve_from_ipfs(ipfs_hash)
        return self.recreate_graph(stored_definition)

    def get_graph_info(self, ipfs_hash: str) -> GraphMetadata:
        """
        Get metadata about a stored graph without recreating it.
        """
        stored_definition = self.retrieve_from_ipfs(ipfs_hash)
        return stored_definition.metadata
class Immutable:
    """Mixin class that makes Pydantic models immutable using frozen=True."""

    model_config = ConfigDict(frozen=True)


class GraphModel(BaseModel, Immutable):
    """
    Base class for all graph-related models.

    Subclasses inherit the frozen configuration from ``Immutable`` and share the
    reference validator below for their ``module:name`` string fields.
    """

    @staticmethod
    def _validate_function_reference(value: str) -> str:
        """Validate function reference format and security constraints.

        Args:
            value: Reference in the form ``'module.path:function_name'``.

        Returns:
            ``value`` unchanged when it passes every check.

        Raises:
            ValueError: If the format is wrong, the module path lies outside the
                ``experimental`` package, or either part is not made of valid
                Python identifiers.
        """
        if value.count(":") != 1:
            raise ValueError(
                "Function reference must be in format 'module.path:function_name' with exactly one ':' separator"
            )

        module_path, function_name = value.split(":", 1)

        if not module_path or not function_name:
            raise ValueError("Both module path and function name must be non-empty")

        # Security: No relative path components
        if ".." in module_path:
            raise ValueError("Relative path components (..) are not allowed in module path")

        # Security: the top-level package must be exactly 'experimental'.
        # A plain prefix test would also accept sibling packages such as
        # 'experimental_evil' or 'experimentalx', defeating the allow-list.
        segments = module_path.split(".")
        if segments[0] != "experimental":
            raise ValueError("Module path must start with 'experimental' for security")

        # Every dotted segment has to be an identifier; this also rejects
        # leading/trailing dots and path-like characters ('/', '-', spaces).
        if not all(segment.isidentifier() for segment in segments):
            raise ValueError(f"Module path '{module_path}' is not a valid dotted Python module path")

        # Validate function name is a valid Python identifier
        if not function_name.isidentifier():
            raise ValueError(f"Function name '{function_name}' is not a valid Python identifier")

        return value


class GraphMetadata(GraphModel):
    """Metadata for a stored graph definition."""

    name: str
    version: str = "1.0.0"
    description: str
    created_by: str
    # NOTE(review): datetime.now yields a naive local timestamp; confirm whether
    # stored definitions are expected to carry timezone-aware (UTC) times.
    created_at: datetime = Field(default_factory=datetime.now)
    # Hard-coded default, not read from the installed langgraph package.
    langgraph_version: str = "0.2.60"


# Note: It's a bit duplicative, but I prefer to fully separate our internal classes from LangGraph's classes.
class LangGraphNode(GraphModel):
    """Pydantic-compatible wrapper for LangGraph's Node type.

    Mirrors the four attributes copied by ``from_langgraph_node`` so a Node can
    be round-tripped through this model.
    """

    id: str
    name: str
    data: Any
    metadata: dict[str, Any] | None = None

    @classmethod
    def from_langgraph_node(cls, node: Node) -> "LangGraphNode":
        """Create from a LangGraph Node."""
        return cls(id=node.id, name=node.name, data=node.data, metadata=node.metadata)

    def to_langgraph_node(self) -> Node:
        """Convert to a LangGraph Node."""
        return Node(id=self.id, name=self.name, data=self.data, metadata=self.metadata)


class LangGraphEdge(GraphModel):
    """Pydantic-compatible wrapper for LangGraph's Edge type."""

    source: str
    target: str
    data: Any | None = None
    conditional: bool = False

    @classmethod
    def from_langgraph_edge(cls, edge: Edge) -> "LangGraphEdge":
        """Create from a LangGraph Edge."""
        return cls(
            source=edge.source,
            target=edge.target,
            data=edge.data,
            conditional=edge.conditional,
        )

    def to_langgraph_edge(self) -> Edge:
        """Convert to a LangGraph Edge."""
        return Edge(
            source=self.source,
            target=self.target,
            data=self.data,
            conditional=self.conditional,
        )


class GraphNodeDefinition(GraphModel):
    """Serializable representation of a graph node before compilation."""

    name: str
    # Checked by validate_function_reference below on every construction.
    function_reference: str = Field(description="Format: 'module.path:function_name'")
    metadata: dict[str, Any] | None = None
    input_type: str | None = None

    @field_validator("function_reference")
    @classmethod
    def validate_function_reference(cls, v: str) -> str:
        """Delegate to the shared reference validator on GraphModel."""
        return cls._validate_function_reference(v)


class GraphEdgeDefinition(GraphModel):
    """Serializable representation of a simple graph edge."""

    source: str
    target: str


class ConditionalEdgeDefinition(GraphModel):
    """Serializable representation of a conditional edge."""

    source_node: str
    condition_function_reference: str = Field(description="Format: 'module.path:function_name'")
    # Passed as the path map when the conditional edge is re-added to a builder.
    # NOTE(review): keys are typed Hashable, but after a JSON round trip they are
    # presumably strings only - confirm non-string keys are never stored.
    target_mapping: dict[Hashable, str]

    @field_validator("condition_function_reference")
    @classmethod
    def validate_condition_function_reference(cls, v: str) -> str:
        """Delegate to the shared reference validator on GraphModel."""
        return cls._validate_function_reference(v)


class StateChannelDefinition(GraphModel):
    """Serializable representation of a state channel."""

    name: str
    channel_type: str
    default_value: Any | None = None


class SerializableGraphDefinition(GraphModel):
    """Complete serializable graph definition from StateGraph builder (before compilation)."""

    nodes: list[GraphNodeDefinition]
    edges: list[GraphEdgeDefinition]
    conditional_edges: list[ConditionalEdgeDefinition]
    state_channels: list[StateChannelDefinition]
    state_type_name: str


class StateSchema(GraphModel):
    """Pydantic-only state schema representation."""

    name: str
    class_reference: str = Field(description="Format: 'module.path:ClassName'")

    @field_validator("class_reference")
    @classmethod
    def validate_class_reference(cls, v: str) -> str:
        """Apply the function-reference rules to the class reference.

        The shared validator is reused as-is, so its error messages speak of a
        "function name" even though the value names a class.
        """
        return cls._validate_function_reference(v)


class ExecutionConfig(GraphModel):
    """Configuration options for graph execution."""

    # NOTE(review): this model only declares the options; where (or whether)
    # they are applied at execution time is not visible here - confirm.
    checkpointer: str | None = None
    debug: bool = False
    stream_mode: str = "values"
    max_iterations: int | None = None
    recursion_limit: int | None = None


class StoredGraphDefinition(GraphModel):
    """Complete stored graph definition using serializable pre-compilation data."""

    # Version of this storage format; distinct from GraphMetadata.version.
    version: str = "0.0.1"
    metadata: GraphMetadata

    # Serializable graph definition from StateGraph builder (before compilation)
    graph_definition: SerializableGraphDefinition = Field(
        description="Complete serializable graph structure from StateGraph builder"
    )

    state_schema: StateSchema = Field(description="Structured information about the state schema")

    # default_factory gives each instance its own ExecutionConfig when omitted.
    execution_config: ExecutionConfig = Field(
        default_factory=ExecutionConfig,
        description="Structured execution configuration options",
    )
class SentimentState(BaseModel):
    """State object passed between the nodes of the sentiment test workflow."""

    input_text: str = Field(description="Input text to analyze")
    sentiment_score: float = Field(default=0.0, description="Sentiment score between 0 and 1")
    decision: str = Field(default="", description="Decision based on sentiment analysis")
    final_action: str = Field(default="", description="Final action taken based on decision")


def analyze_sentiment(state: SentimentState) -> SentimentState:
    """Assign a canned score: 0.8 when the text mentions "good", otherwise 0.3."""
    lowered = state.input_text.lower()
    if "good" in lowered:
        score = 0.8
    else:
        score = 0.3
    return state.model_copy(update={"sentiment_score": score})


async def make_decision(state: SentimentState) -> SentimentState:
    """Label the state "positive" when the score exceeds 0.6, else "negative"."""
    if state.sentiment_score > 0.6:
        verdict = "positive"
    else:
        verdict = "negative"
    return state.model_copy(update={"decision": verdict})


def take_action(state: SentimentState) -> SentimentState:
    """Record the message used on the positive branch."""
    message = f"Taking positive action based on score {state.sentiment_score}"
    return state.model_copy(update={"final_action": message})


def no_action(state: SentimentState) -> SentimentState:
    """Record the message used on the negative branch."""
    message = f"No action needed, score too low: {state.sentiment_score}"
    return state.model_copy(update={"final_action": message})


def should_take_action(state: SentimentState) -> str:
    """Pick the next node name from the decision stored on the state."""
    if state.decision == "positive":
        return "action"
    return "no_action"


@pytest.fixture
def sentiment_graph_builder() -> StateGraph:
    """
    Build the sentiment analysis graph from the POC.
    The uncompiled StateGraph builder is returned.
    """
    graph = StateGraph(SentimentState)

    # Register every node first, in workflow order.
    node_table = (
        ("analyze", analyze_sentiment),
        ("decide", make_decision),
        ("action", take_action),
        ("no_action", no_action),
    )
    for node_name, node_func in node_table:
        graph.add_node(node_name, node_func)

    # Linear part of the flow, then the branch, then both exits.
    graph.add_edge(START, "analyze")
    graph.add_edge("analyze", "decide")
    branch_targets = {"action": "action", "no_action": "no_action"}
    graph.add_conditional_edges("decide", should_take_action, branch_targets)
    for terminal_node in ("action", "no_action"):
        graph.add_edge(terminal_node, END)

    return graph


@pytest.fixture
def mock_ipfs_storage():
    """
    Provide an in-memory stand-in for IPFS.
    The returned dictionary holds graph definitions keyed by hash.
    """
    storage = {}
    return storage
class IPFSMockHelpers:
    """Helper functions for mocking IPFS operations."""

    @staticmethod
    def create_mock_store(storage: dict[str, str]) -> Callable[[str], str]:
        """Create a mock store function that saves to the provided storage.

        Args:
            storage: Dictionary that receives ``fake_hash -> graph_json`` entries.

        Returns:
            A function taking the graph JSON and returning its fake hash. The
            hash is ``"Qm"`` plus the SHA-256 hex digest of the UTF-8 payload.
        """
        import hashlib

        def mock_store(graph_json: str) -> str:
            # Builtin hash() is salted per interpreter run and can be negative,
            # so it gives neither reproducible nor hash-looking identifiers.
            # A real digest keeps the fake content-addressed: same JSON, same key.
            digest = hashlib.sha256(graph_json.encode("utf-8")).hexdigest()
            fake_hash = f"Qm{digest}"
            storage[fake_hash] = graph_json
            return fake_hash

        return mock_store

    @staticmethod
    def create_mock_retrieve(storage: dict[str, str]) -> Callable[[str], StoredGraphDefinition]:
        """Create a mock retrieve function that loads from the provided storage.

        The returned function raises ``KeyError`` for a hash that was never stored.
        """

        def mock_retrieve(ipfs_hash: str) -> StoredGraphDefinition:
            json_data = storage[ipfs_hash]
            return StoredGraphDefinition.model_validate(json.loads(json_data))

        return mock_retrieve
@pytest.mark.asyncio
async def test_execute_stored_graph(self, sentiment_graph_builder, setup_ipfs_mocks):
    """Store the sentiment workflow once, then run it down each branch."""
    graph_executor = GraphExecutor()
    setup_ipfs_mocks(graph_executor)

    # Persist through the executor's own loader so both share the mocked storage.
    stored_hash = graph_executor.loader.save_graph_from_builder(
        sentiment_graph_builder,
        name="executor_test_workflow",
        description="Test workflow for executor",
    )

    # (input text, expected score, expected decision, expected action fragment);
    # the positive case runs first, then the negative one.
    scenarios = (
        ("This is good news!", 0.8, "positive", "Taking positive action"),
        ("This is terrible news!", 0.3, "negative", "No action needed"),
    )
    for text, expected_score, expected_decision, action_fragment in scenarios:
        initial_state = {
            "input_text": text,
            "sentiment_score": 0.0,
            "decision": "",
            "final_action": "",
        }

        outcome = await graph_executor.execute_graph(stored_hash, initial_state)

        assert outcome["input_text"] == text
        assert outcome["sentiment_score"] == expected_score
        assert outcome["decision"] == expected_decision
        assert action_fragment in outcome["final_action"]
        # The action message embeds the score that drove the decision.
        assert str(expected_score) in outcome["final_action"]
+ + # Execute with the Pydantic model directly + result = await loaded_graph.execute(typed_input) + + # Validate and convert result back to typed state + typed_result = loaded_graph.validate_result(result) + + assert isinstance(typed_result, SentimentState) + assert typed_result.input_text == "This is good news!" + assert typed_result.sentiment_score == 0.8 # "good" triggers positive score + assert typed_result.decision == "positive" + assert "Taking positive action" in typed_result.final_action diff --git a/experimental/archon/tests/test_graph_loader.py b/experimental/archon/tests/test_graph_loader.py new file mode 100644 index 00000000..cc8bf89a --- /dev/null +++ b/experimental/archon/tests/test_graph_loader.py @@ -0,0 +1,108 @@ +""" +Tests for GraphLoader - Graph storage and retrieval from IPFS. +""" + +from __future__ import annotations + +from .conftest import analyze_sentiment, make_decision, no_action, take_action + + +class TestGraphLoaderWithMockedIPFS: + """Test GraphLoader with mocked IPFS storage.""" + + def test_end_to_end_save_load(self, sentiment_graph_builder, mock_ipfs_storage, loader_with_ipfs_mocks): + """Test complete workflow: save graph to 'IPFS', retrieve it, and execute.""" + + loader = loader_with_ipfs_mocks + + # Step 1: Save the graph + ipfs_hash = loader.save_graph_from_builder( + sentiment_graph_builder, + name="sentiment_workflow", + description="Analyzes sentiment and takes conditional actions", + ) + + assert ipfs_hash in mock_ipfs_storage + + # Step 2: Retrieve the graph definition + retrieved_def = loader.retrieve_from_ipfs(ipfs_hash) + + assert retrieved_def.metadata.name == "sentiment_workflow" + assert retrieved_def.metadata.description == "Analyzes sentiment and takes conditional actions" + assert len(retrieved_def.graph_definition.nodes) == 4 + + # Step 3: Verify the serialized structure + nodes = retrieved_def.graph_definition.nodes + node_names = {node.name for node in nodes} + assert node_names == {"analyze", "decide", 
"action", "no_action"} + + # Verify edges + edges = retrieved_def.graph_definition.edges + assert len(edges) >= 3 + + # Verify conditional edges + cond_edges = retrieved_def.graph_definition.conditional_edges + assert len(cond_edges) == 1 + assert cond_edges[0].source_node == "decide" + assert cond_edges[0].target_mapping == { + "action": "action", + "no_action": "no_action", + } + + # Step 4: Verify the function references have correct structure + # Function references should be in format "module.path:function_name" + + # Create a mapping of expected function references + # The module will be __main__ when running tests or test_graph_loader + expected_references = { + "analyze": ("analyze_sentiment", analyze_sentiment), + "decide": ("make_decision", make_decision), + "action": ("take_action", take_action), + "no_action": ("no_action", no_action), + } + + # Verify each node's function reference + for node in nodes: + assert ":" in node.function_reference, ( + f"Invalid reference format for {node.name}: {node.function_reference}" + ) + + module_path, func_name = node.function_reference.split(":", 1) + expected_func_name, expected_func = expected_references[node.name] + + # Verify function name matches + assert func_name == expected_func_name, ( + f"Function name mismatch for {node.name}: expected {expected_func_name}, got {func_name}" + ) + + # Verify module path is reasonable (should be __main__ or test module) + assert module_path in [ + "__main__", + "test_graph_loader", + "experimental.archon.tests.test_graph_loader", + "experimental.archon.tests.conftest", + ], f"Unexpected module path for {node.name}: {module_path}" + + # Verify conditional edge function reference + assert len(cond_edges) == 1 + cond_edge = cond_edges[0] + assert ":" in cond_edge.condition_function_reference + + cond_module, cond_func = cond_edge.condition_function_reference.split(":", 1) + assert cond_func == "should_take_action", f"Conditional function name mismatch: {cond_func}" + assert 
cond_module in [ + "__main__", + "test_graph_loader", + "experimental.archon.tests.test_graph_loader", + "experimental.archon.tests.conftest", + ], f"Unexpected module for conditional function: {cond_module}" + + # Verify that all function references are complete and valid + all_refs = [node.function_reference for node in nodes] + [cond_edge.condition_function_reference] + for ref in all_refs: + # Each reference should have exactly one colon separator + assert ref.count(":") == 1, f"Invalid reference format: {ref}" + # Neither part should be empty + module_part, func_part = ref.split(":") + assert module_part, f"Empty module in reference: {ref}" + assert func_part, f"Empty function name in reference: {ref}" diff --git a/experimental/archon/tests/test_type_preservation.py b/experimental/archon/tests/test_type_preservation.py new file mode 100644 index 00000000..4288f3f4 --- /dev/null +++ b/experimental/archon/tests/test_type_preservation.py @@ -0,0 +1,93 @@ +""" +Tests specifically for Pydantic type preservation in graph serialization/deserialization. 
+""" + +from __future__ import annotations + +import json + +import pytest + +from experimental.archon.src.graph_loader import GraphLoader +from experimental.archon.tests.conftest import SentimentState + + +class TestTypePreservation: + """Test that Pydantic types are properly preserved during serialization.""" + + def test_state_schema_extraction_preserves_types(self, sentiment_graph_builder, setup_ipfs_mocks): + """Test that state schema extraction captures complete Pydantic information.""" + + loader = GraphLoader() + setup_ipfs_mocks(loader) + + # Serialize the graph + graph_json = loader.serialize_graph_from_builder( + sentiment_graph_builder, + name="type_test_workflow", + description="Test workflow for type preservation", + ) + + # Parse and verify the JSON contains proper type information + stored_data = json.loads(graph_json) + state_schema = stored_data["state_schema"] + + # Verify it's recognized as a Pydantic model with proper class reference + assert "class_reference" in state_schema + assert state_schema["class_reference"] == "experimental.archon.tests.conftest:SentimentState" + assert state_schema["name"] == "SentimentState" + + # Verify we can load the class and generate the schema on demand + json_schema = SentimentState.model_json_schema() + + # Check that field types can be retrieved from the loaded class + properties = json_schema["properties"] + + # input_text should be string type + assert properties["input_text"]["type"] == "string" + assert properties["input_text"]["description"] == "Input text to analyze" + + # sentiment_score should be number type with default + assert properties["sentiment_score"]["type"] == "number" + assert properties["sentiment_score"]["default"] == 0.0 + assert properties["sentiment_score"]["description"] == "Sentiment score between 0 and 1" + + # decision should be string with default + assert properties["decision"]["type"] == "string" + assert properties["decision"]["default"] == "" + + # final_action should be 
string with default + assert properties["final_action"]["type"] == "string" + assert properties["final_action"]["default"] == "" + + @pytest.mark.asyncio + async def test_type_preservation_round_trip(self, sentiment_graph_builder, setup_ipfs_mocks): + """Test that types are preserved through complete serialization round trip.""" + + loader = GraphLoader() + setup_ipfs_mocks(loader) + + # Save the graph + ipfs_hash = loader.save_graph_from_builder( + sentiment_graph_builder, + name="round_trip_test", + description="Test round trip type preservation", + ) + + # Load the graph back + recreated_graph = loader.load_graph(ipfs_hash) + + # Create test input with proper Pydantic model + test_input = SentimentState( + input_text="This is good news!", + sentiment_score=0.0, + decision="", + final_action="", + ) + + # Execute and verify types are preserved (use ainvoke for async compatibility) + result = await recreated_graph.ainvoke(test_input) + + # We can reconstruct the Pydantic model from the result + pydantic_result = SentimentState.model_validate(result) + assert isinstance(pydantic_result, SentimentState) diff --git a/experimental/archon/tests/test_validation.py b/experimental/archon/tests/test_validation.py new file mode 100644 index 00000000..7446c8bc --- /dev/null +++ b/experimental/archon/tests/test_validation.py @@ -0,0 +1,74 @@ +""" +Test Pydantic field validators for function references. 
+""" + +from __future__ import annotations + +import pytest +from pydantic import ValidationError + +from experimental.archon.src.graph_models import ConditionalEdgeDefinition, GraphNodeDefinition + + +class TestFunctionReferenceValidation: + """Test validation of function references.""" + + def test_valid_function_reference(self): + """Test that valid function references are accepted.""" + node = GraphNodeDefinition( + name="test_node", function_reference="experimental.archon.tests.test_validation:test_function" + ) + assert node.function_reference == "experimental.archon.tests.test_validation:test_function" + + def test_invalid_format_no_colon(self): + """Test that function references without colons are rejected.""" + with pytest.raises(ValidationError, match="must be in format"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon.tests.test_validation") + + def test_invalid_format_multiple_colons(self): + """Test that function references with multiple colons are rejected.""" + with pytest.raises(ValidationError, match="exactly one ':' separator"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon:test:function") + + def test_empty_module_path(self): + """Test that empty module paths are rejected.""" + with pytest.raises(ValidationError, match="must be non-empty"): + GraphNodeDefinition(name="test_node", function_reference=":test_function") + + def test_empty_function_name(self): + """Test that empty function names are rejected.""" + with pytest.raises(ValidationError, match="must be non-empty"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon.tests:") + + def test_relative_paths_rejected(self): + """Test that relative path components are rejected.""" + with pytest.raises(ValidationError, match="Relative path components"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon.../malicious:bad_function") + + def 
test_non_experimental_module_rejected(self): + """Test that non-experimental modules are rejected.""" + with pytest.raises(ValidationError, match="must start with 'experimental'"): + GraphNodeDefinition(name="test_node", function_reference="malicious.module:bad_function") + + def test_invalid_python_identifier(self): + """Test that invalid Python identifiers for function names are rejected.""" + with pytest.raises(ValidationError, match="not a valid Python identifier"): + GraphNodeDefinition(name="test_node", function_reference="experimental.archon.tests:123invalid-name") + + def test_conditional_edge_validation(self): + """Test that conditional edges also validate function references.""" + # Valid conditional edge + cond_edge = ConditionalEdgeDefinition( + source_node="test_node", + condition_function_reference="experimental.archon.tests.test_validation:condition_func", + target_mapping={"true": "next_node", "false": "end"}, + ) + assert cond_edge.condition_function_reference == "experimental.archon.tests.test_validation:condition_func" + + # Invalid conditional edge + with pytest.raises(ValidationError, match="must start with 'experimental'"): + ConditionalEdgeDefinition( + source_node="test_node", + condition_function_reference="malicious.module:bad_condition", + target_mapping={"true": "next_node", "false": "end"}, + ) diff --git a/pyproject.toml b/pyproject.toml index 2bcc2ee0..bd8d951e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "langchain==0.3.26", "langchain-community==0.3.27", "langchain-openai==0.3.28", + "langgraph==0.2.60", "langgraph>=0.2.0", "langsmith>=0.1.0", "duckduckgo-search==8.1.1", @@ -56,6 +57,10 @@ line-length = 120 [tool.mypy] strict = true +[dependency-groups] +dev = [ + "pytest-asyncio>=1.1.0", +] [tool.pytest.ini_options] testpaths = ["tests"] addopts = "--ignore=integration_tests" diff --git a/requirements.txt b/requirements.txt index f05eb9fc..94d56ead 100644 --- a/requirements.txt +++ 
b/requirements.txt @@ -46,6 +46,7 @@ langchain-community==0.3.27 langchain-core==0.3.69 langchain-openai==0.3.28 langchain-text-splitters==0.3.8 +langgraph==0.2.60 langsmith==0.4.8 lxml==6.0.0 markdown-it-py==3.0.0 @@ -100,7 +101,6 @@ six==1.17.0 sniffio==1.3.1 soupsieve==2.7 sqlalchemy==2.0.41 --e file:///app tenacity==9.1.2 textblob==0.19.0 tiktoken==0.9.0 diff --git a/src/talos/core/agent.py b/src/talos/core/agent.py index 0f4c9a26..18fd7ddd 100644 --- a/src/talos/core/agent.py +++ b/src/talos/core/agent.py @@ -55,7 +55,7 @@ def model_post_init(self, __context: Any) -> None: def set_prompt(self, name: str | list[str]): if not self.prompt_manager: raise ValueError("Prompt manager not initialized.") - + prompt_names = name if isinstance(name, list) else [name] if self.dataset_manager: prompt_names.append("relevant_documents_prompt") diff --git a/src/talos/core/job_scheduler.py b/src/talos/core/job_scheduler.py index edf75079..112b2f18 100644 --- a/src/talos/core/job_scheduler.py +++ b/src/talos/core/job_scheduler.py @@ -17,44 +17,44 @@ class JobScheduler(BaseModel): """ Manages scheduled jobs for the MainAgent using APScheduler. - + Provides functionality to: - Register and manage scheduled jobs - Execute jobs with supervision - Handle job lifecycle (start, stop, pause, resume) """ - + model_config = ConfigDict(arbitrary_types_allowed=True) - + supervisor: Optional[Supervisor] = Field(None, description="Supervisor for approving job executions") timezone: str = Field("UTC", description="Timezone for job scheduling") - + _scheduler: AsyncIOScheduler = PrivateAttr() _jobs: Dict[str, ScheduledJob] = PrivateAttr(default_factory=dict) _running: bool = PrivateAttr(default=False) - + def model_post_init(self, __context: Any) -> None: self._scheduler = AsyncIOScheduler(timezone=self.timezone) self._jobs = {} self._running = False - + def register_job(self, job: ScheduledJob) -> None: """ Register a scheduled job with the scheduler. 
- + Args: job: The ScheduledJob instance to register """ if job.name in self._jobs: logger.warning(f"Job '{job.name}' already registered, replacing existing job") self.unregister_job(job.name) - + self._jobs[job.name] = job - + if not job.enabled: logger.info(f"Job '{job.name}' registered but disabled") return - + if job.is_recurring() and job.cron_expression: trigger = CronTrigger.from_crontab(job.cron_expression, timezone=self.timezone) self._scheduler.add_job( @@ -63,10 +63,10 @@ def register_job(self, job: ScheduledJob) -> None: args=[job.name], id=job.name, max_instances=job.max_instances, - replace_existing=True + replace_existing=True, ) logger.info(f"Registered recurring job '{job.name}' with cron: {job.cron_expression}") - + elif job.is_one_time() and job.execute_at: trigger = DateTrigger(run_date=job.execute_at, timezone=self.timezone) self._scheduler.add_job( @@ -75,41 +75,41 @@ def register_job(self, job: ScheduledJob) -> None: args=[job.name], id=job.name, max_instances=job.max_instances, - replace_existing=True + replace_existing=True, ) logger.info(f"Registered one-time job '{job.name}' for: {job.execute_at}") - + def unregister_job(self, job_name: str) -> bool: """ Unregister a scheduled job. 
- + Args: job_name: Name of the job to unregister - + Returns: True if job was found and removed, False otherwise """ if job_name not in self._jobs: logger.warning(f"Job '{job_name}' not found for unregistration") return False - + try: self._scheduler.remove_job(job_name) except Exception as e: logger.warning(f"Failed to remove job '{job_name}' from scheduler: {e}") - + del self._jobs[job_name] logger.info(f"Unregistered job '{job_name}'") return True - + def get_job(self, job_name: str) -> Optional[ScheduledJob]: """Get a registered job by name.""" return self._jobs.get(job_name) - + def list_jobs(self) -> List[ScheduledJob]: """Get all registered jobs.""" return list(self._jobs.values()) - + def start(self) -> None: """Start the job scheduler.""" if self._running: @@ -133,18 +133,18 @@ def stop(self) -> None: if not self._running: logger.warning("Job scheduler is not running") return - + self._scheduler.shutdown() self._running = False logger.info("Job scheduler stopped") - + def pause_job(self, job_name: str) -> bool: """ Pause a specific job. - + Args: job_name: Name of the job to pause - + Returns: True if job was found and paused, False otherwise """ @@ -159,14 +159,14 @@ def pause_job(self, job_name: str) -> bool: except Exception as e: logger.error(f"Failed to pause job '{job_name}': {e}") return False - + def resume_job(self, job_name: str) -> bool: """ Resume a specific job. - + Args: job_name: Name of the job to resume - + Returns: True if job was found and resumed, False otherwise """ @@ -181,15 +181,15 @@ def resume_job(self, job_name: str) -> bool: except Exception as e: logger.error(f"Failed to resume job '{job_name}': {e}") return False - + def is_running(self) -> bool: """Check if the scheduler is running.""" return self._running - + async def _execute_job_with_supervision(self, job_name: str) -> None: """ Execute a job with optional supervision. 
- + Args: job_name: Name of the job to execute """ @@ -197,27 +197,27 @@ async def _execute_job_with_supervision(self, job_name: str) -> None: if not job: logger.error(f"Job '{job_name}' not found for execution") return - + if not job.enabled: logger.info(f"Job '{job_name}' is disabled, skipping execution") return - + logger.info(f"Executing scheduled job: {job_name}") - + try: if self.supervisor: logger.info(f"Requesting supervision approval for job: {job_name}") - + await job.run() logger.info(f"Job '{job_name}' completed successfully") - + if job.is_one_time(): self.unregister_job(job_name) logger.info(f"One-time job '{job_name}' removed after execution") - + except Exception as e: logger.error(f"Job '{job_name}' failed with error: {e}") - + if job.is_one_time(): self.unregister_job(job_name) logger.info(f"Failed one-time job '{job_name}' removed") diff --git a/src/talos/core/scheduled_job.py b/src/talos/core/scheduled_job.py index 1ec99f63..6ec6ffa6 100644 --- a/src/talos/core/scheduled_job.py +++ b/src/talos/core/scheduled_job.py @@ -5,7 +5,7 @@ from datetime import datetime from typing import Any, Optional -from pydantic import BaseModel, Field, ConfigDict +from pydantic import BaseModel, ConfigDict, Field logger = logging.getLogger(__name__) @@ -13,51 +13,53 @@ class ScheduledJob(BaseModel, ABC): """ Abstract base class for scheduled jobs that can be executed by the MainAgent. 
- + Jobs can be scheduled using either: - A cron expression for recurring execution - A specific datetime for one-time execution """ - + model_config = ConfigDict(arbitrary_types_allowed=True) - + name: str = Field(..., description="Unique name for this scheduled job") description: str = Field(..., description="Human-readable description of what this job does") - cron_expression: Optional[str] = Field(None, description="Cron expression for recurring jobs (e.g., '0 9 * * *' for daily at 9 AM)") + cron_expression: Optional[str] = Field( + None, description="Cron expression for recurring jobs (e.g., '0 9 * * *' for daily at 9 AM)" + ) execute_at: Optional[datetime] = Field(None, description="Specific datetime for one-time execution") enabled: bool = Field(True, description="Whether this job is enabled for execution") max_instances: int = Field(1, description="Maximum number of concurrent instances of this job") - + def model_post_init(self, __context: Any) -> None: if not self.cron_expression and not self.execute_at: raise ValueError("Either cron_expression or execute_at must be provided") if self.cron_expression and self.execute_at: raise ValueError("Only one of cron_expression or execute_at should be provided") - + @abstractmethod async def run(self, **kwargs: Any) -> Any: """ Execute the scheduled job. - + This method should contain the actual logic for the job. It will be called by the scheduler when the job is triggered. - + Args: **kwargs: Additional arguments that may be passed to the job - + Returns: Any result from the job execution """ pass - + def is_recurring(self) -> bool: """Check if this is a recurring job (has cron expression).""" return self.cron_expression is not None - + def is_one_time(self) -> bool: """Check if this is a one-time job (has execute_at datetime).""" return self.execute_at is not None - + def should_execute_now(self) -> bool: """ Check if this job should execute now (for one-time jobs). 
@@ -66,7 +68,7 @@ def should_execute_now(self) -> bool: if not self.is_one_time() or not self.execute_at: return False return datetime.now() >= self.execute_at - + def __str__(self) -> str: schedule_info = self.cron_expression if self.cron_expression else f"at {self.execute_at}" return f"ScheduledJob(name='{self.name}', schedule='{schedule_info}', enabled={self.enabled})" diff --git a/src/talos/jobs/example_jobs.py b/src/talos/jobs/example_jobs.py index 5da0f89c..5ca7b53c 100644 --- a/src/talos/jobs/example_jobs.py +++ b/src/talos/jobs/example_jobs.py @@ -13,29 +13,29 @@ class HealthCheckJob(ScheduledJob): """ Example scheduled job that performs a health check every hour. """ - + def __init__(self, **kwargs): super().__init__( name="health_check", description="Performs a health check of the agent system", cron_expression="0 * * * *", # Every hour at minute 0 - **kwargs + **kwargs, ) - + async def run(self, **kwargs: Any) -> str: """ Perform a health check of the agent system. """ logger.info("Running health check job") - + current_time = datetime.now() health_status = { "timestamp": current_time.isoformat(), "status": "healthy", "uptime": "running", - "memory_usage": "normal" + "memory_usage": "normal", } - + logger.info(f"Health check completed: {health_status}") return f"Health check completed at {current_time}: System is healthy" @@ -44,29 +44,24 @@ class DailyReportJob(ScheduledJob): """ Example scheduled job that generates a daily report at 9 AM. """ - + def __init__(self, **kwargs): super().__init__( name="daily_report", description="Generates a daily activity report", cron_expression="0 9 * * *", # Daily at 9 AM - **kwargs + **kwargs, ) - + async def run(self, **kwargs: Any) -> str: """ Generate a daily activity report. 
""" logger.info("Running daily report job") - + current_date = datetime.now().strftime("%Y-%m-%d") - report_data = { - "date": current_date, - "tasks_completed": 0, - "skills_used": [], - "memory_entries": 0 - } - + report_data = {"date": current_date, "tasks_completed": 0, "skills_used": [], "memory_entries": 0} + logger.info(f"Daily report generated: {report_data}") return f"Daily report for {current_date} completed with {report_data['tasks_completed']} tasks and {report_data['memory_entries']} memory entries" @@ -75,30 +70,23 @@ class OneTimeMaintenanceJob(ScheduledJob): """ Example one-time scheduled job for maintenance tasks. """ - + def __init__(self, execute_at: datetime, **kwargs): super().__init__( - name="maintenance_task", - description="Performs one-time maintenance task", - execute_at=execute_at, - **kwargs + name="maintenance_task", description="Performs one-time maintenance task", execute_at=execute_at, **kwargs ) - + async def run(self, **kwargs: Any) -> str: """ Perform a one-time maintenance task. """ logger.info("Running one-time maintenance job") - - maintenance_tasks = [ - "Clean temporary files", - "Optimize memory usage", - "Update internal metrics" - ] - + + maintenance_tasks = ["Clean temporary files", "Optimize memory usage", "Update internal metrics"] + for task in maintenance_tasks: logger.info(f"Executing maintenance task: {task}") - + completion_time = datetime.now() logger.info(f"Maintenance completed at {completion_time}") return f"Maintenance tasks completed at {completion_time}" @@ -107,14 +95,10 @@ async def run(self, **kwargs: Any) -> str: def create_example_jobs() -> list[ScheduledJob]: """ Create a list of example scheduled jobs for demonstration. 
- + Returns: List of example ScheduledJob instances """ - jobs = [ - HealthCheckJob(), - DailyReportJob(), - OneTimeMaintenanceJob(execute_at=datetime.now() + timedelta(minutes=5)) - ] - + jobs = [HealthCheckJob(), DailyReportJob(), OneTimeMaintenanceJob(execute_at=datetime.now() + timedelta(minutes=5))] + return jobs diff --git a/src/talos/models/twitter.py b/src/talos/models/twitter.py index 3d3a4d07..0ceb8b11 100644 --- a/src/talos/models/twitter.py +++ b/src/talos/models/twitter.py @@ -40,16 +40,13 @@ class Tweet(BaseModel): referenced_tweets: Optional[list[ReferencedTweet]] = None in_reply_to_user_id: Optional[str] = None edit_history_tweet_ids: Optional[list[str]] = None - + def is_reply_to(self, tweet_id: str) -> bool: """Check if this tweet is a reply to the specified tweet ID.""" if not self.referenced_tweets: return False - return any( - ref.type == "replied_to" and ref.id == tweet_id - for ref in self.referenced_tweets - ) - + return any(ref.type == "replied_to" and ref.id == tweet_id for ref in self.referenced_tweets) + def get_replied_to_id(self) -> Optional[int]: """Get the ID of the tweet this is replying to, if any.""" if not self.referenced_tweets: diff --git a/src/talos/services/implementations/yield_manager.py b/src/talos/services/implementations/yield_manager.py index 69622edb..af350909 100644 --- a/src/talos/services/implementations/yield_manager.py +++ b/src/talos/services/implementations/yield_manager.py @@ -23,7 +23,7 @@ def __init__( raise ValueError("Min and max yield must be positive") if min_yield >= max_yield: raise ValueError("Min yield must be less than max yield") - + self.dexscreener_client = dexscreener_client self.gecko_terminal_client = gecko_terminal_client self.llm_client = llm_client @@ -51,7 +51,7 @@ def update_staking_apr(self, sentiment: float, sentiment_report: str) -> float: dexscreener_data, ohlcv_data, sentiment, staked_supply_percentage ) logging.info(f"Data source scores: {data_scores}") - + weighted_apr = 
self._calculate_weighted_apr_recommendation(data_scores) logging.info(f"Weighted APR recommendation: {weighted_apr}") @@ -66,9 +66,9 @@ def update_staking_apr(self, sentiment: float, sentiment_report: str) -> float: staked_supply_percentage=staked_supply_percentage, ohlcv_data=ohlcv_data.model_dump_json(), ) - + enhanced_prompt = f"{prompt}\n\nBased on weighted analysis of the data sources, the recommended APR is {weighted_apr:.4f}. Please consider this recommendation along with the raw data. The APR must be between {self.min_yield} and {self.max_yield}." - + response = self.llm_client.reasoning(enhanced_prompt, web_search=True) try: response_json = json.loads(response) @@ -81,84 +81,86 @@ def update_staking_apr(self, sentiment: float, sentiment_report: str) -> float: return max(self.min_yield, min(self.max_yield, weighted_apr)) final_apr = max(self.min_yield, min(self.max_yield, llm_apr)) - + if final_apr != llm_apr: logging.info(f"APR bounded from {llm_apr} to {final_apr} (min: {self.min_yield}, max: {self.max_yield})") - + return final_apr def get_staked_supply_percentage(self) -> float: return 0.45 - def _calculate_data_source_scores(self, dexscreener_data, ohlcv_data, sentiment: float, staked_supply_percentage: float) -> dict: + def _calculate_data_source_scores( + self, dexscreener_data, ohlcv_data, sentiment: float, staked_supply_percentage: float + ) -> dict: scores = {} - + price_change = dexscreener_data.price_change_h24 if price_change > 0.1: - scores['price_trend'] = 0.8 + scores["price_trend"] = 0.8 elif price_change > 0.05: - scores['price_trend'] = 0.6 + scores["price_trend"] = 0.6 elif price_change > -0.05: - scores['price_trend'] = 0.5 + scores["price_trend"] = 0.5 elif price_change > -0.1: - scores['price_trend'] = 0.3 + scores["price_trend"] = 0.3 else: - scores['price_trend'] = 0.1 - + scores["price_trend"] = 0.1 + volume = dexscreener_data.volume_h24 if volume > 1000000: - scores['volume_confidence'] = 0.8 + scores["volume_confidence"] = 0.8 
elif volume > 500000: - scores['volume_confidence'] = 0.6 + scores["volume_confidence"] = 0.6 elif volume > 100000: - scores['volume_confidence'] = 0.4 + scores["volume_confidence"] = 0.4 else: - scores['volume_confidence'] = 0.2 - - scores['sentiment'] = max(0.0, min(1.0, sentiment / 100.0)) - + scores["volume_confidence"] = 0.2 + + scores["sentiment"] = max(0.0, min(1.0, sentiment / 100.0)) + if staked_supply_percentage > 0.8: - scores['supply_pressure'] = 0.2 + scores["supply_pressure"] = 0.2 elif staked_supply_percentage > 0.6: - scores['supply_pressure'] = 0.4 + scores["supply_pressure"] = 0.4 elif staked_supply_percentage > 0.4: - scores['supply_pressure'] = 0.6 + scores["supply_pressure"] = 0.6 elif staked_supply_percentage > 0.2: - scores['supply_pressure'] = 0.8 + scores["supply_pressure"] = 0.8 else: - scores['supply_pressure'] = 1.0 - + scores["supply_pressure"] = 1.0 + if ohlcv_data.ohlcv_list: recent_ohlcv = ohlcv_data.ohlcv_list[-5:] if len(recent_ohlcv) >= 2: price_range = max(item.high for item in recent_ohlcv) - min(item.low for item in recent_ohlcv) avg_price = sum(item.close for item in recent_ohlcv) / len(recent_ohlcv) volatility = price_range / avg_price if avg_price > 0 else 0 - + if volatility > 0.2: - scores['volatility'] = 0.3 + scores["volatility"] = 0.3 elif volatility > 0.1: - scores['volatility'] = 0.5 + scores["volatility"] = 0.5 else: - scores['volatility'] = 0.7 + scores["volatility"] = 0.7 else: - scores['volatility'] = 0.5 + scores["volatility"] = 0.5 else: - scores['volatility'] = 0.5 - + scores["volatility"] = 0.5 + return scores def _calculate_weighted_apr_recommendation(self, scores: dict) -> float: weights = { - 'price_trend': 0.25, - 'volume_confidence': 0.15, - 'sentiment': 0.20, - 'supply_pressure': 0.25, - 'volatility': 0.15 + "price_trend": 0.25, + "volume_confidence": 0.15, + "sentiment": 0.20, + "supply_pressure": 0.25, + "volatility": 0.15, } - + weighted_score = sum(scores[factor] * weights[factor] for factor in 
weights.keys()) - + apr_recommendation = self.min_yield + (weighted_score * (self.max_yield - self.min_yield)) - + return apr_recommendation diff --git a/src/talos/skills/twitter_influence.py b/src/talos/skills/twitter_influence.py index f1c15988..cadd11e7 100644 --- a/src/talos/skills/twitter_influence.py +++ b/src/talos/skills/twitter_influence.py @@ -34,7 +34,11 @@ def model_post_init(self, __context: Any) -> None: self.memory = Memory(file_path=memory_path, embeddings_model=embeddings, auto_save=True) if self.evaluator is None: - file_prompt_manager = self.prompt_manager if isinstance(self.prompt_manager, FilePromptManager) else FilePromptManager("src/talos/prompts") + file_prompt_manager = ( + self.prompt_manager + if isinstance(self.prompt_manager, FilePromptManager) + else FilePromptManager("src/talos/prompts") + ) self.evaluator = GeneralInfluenceEvaluator(self.twitter_client, self.llm, file_prompt_manager) @property diff --git a/src/talos/tools/general_influence_evaluator.py b/src/talos/tools/general_influence_evaluator.py index 48d13c46..717f0d6d 100644 --- a/src/talos/tools/general_influence_evaluator.py +++ b/src/talos/tools/general_influence_evaluator.py @@ -132,7 +132,7 @@ def _fallback_content_analysis(self, tweets: List[Any]) -> int: total_length = sum(len(tweet.text) for tweet in tweets) avg_length = total_length / len(tweets) - + if avg_length >= 200: length_score = 80 elif avg_length >= 100: @@ -144,7 +144,7 @@ def _fallback_content_analysis(self, tweets: List[Any]) -> int: original_tweets = [tweet for tweet in tweets if not tweet.text.startswith("RT @")] originality_ratio = len(original_tweets) / len(tweets) if tweets else 0 - + if originality_ratio >= 0.8: originality_score = 80 elif originality_ratio >= 0.6: @@ -189,46 +189,43 @@ def _calculate_authenticity_score(self, user: Any, tweets: List[Any] | None = No """Calculate enhanced authenticity score with advanced bot detection (0-100)""" if tweets is None: tweets = [] - + base_score = 
self._calculate_base_authenticity(user) - + engagement_score = self._calculate_engagement_authenticity(user, tweets) - + content_score = self._calculate_content_authenticity(tweets) - + temporal_score = self._calculate_temporal_authenticity(tweets) - + composite_score = int( - base_score * 0.40 + - engagement_score * 0.25 + - content_score * 0.20 + - temporal_score * 0.15 + base_score * 0.40 + engagement_score * 0.25 + content_score * 0.20 + temporal_score * 0.15 ) - + return min(100, max(0, composite_score)) def _calculate_base_authenticity(self, user: Any) -> int: """Calculate base authenticity score from account indicators (0-100)""" score = 0 - + account_age_days = (datetime.now(timezone.utc) - user.created_at).days if account_age_days > 1825: # 5+ years score += 35 - elif account_age_days > 1095: # 3+ years + elif account_age_days > 1095: # 3+ years score += 30 - elif account_age_days > 730: # 2+ years + elif account_age_days > 730: # 2+ years score += 25 - elif account_age_days > 365: # 1+ year + elif account_age_days > 365: # 1+ year score += 20 - elif account_age_days > 180: # 6+ months + elif account_age_days > 180: # 6+ months score += 10 - elif account_age_days < 30: # Suspicious new accounts + elif account_age_days < 30: # Suspicious new accounts score -= 10 - + if user.verified: score += 25 - - if user.profile_image_url and not user.profile_image_url.endswith('default_profile_images/'): + + if user.profile_image_url and not user.profile_image_url.endswith("default_profile_images/"): score += 15 if user.description and len(user.description) > 20: score += 10 @@ -236,77 +233,77 @@ def _calculate_base_authenticity(self, user: Any) -> int: score += 5 if user.url: score += 5 - + following = user.public_metrics.get("following_count", 0) - + if following > 50000: score -= 15 elif following > 10000: score -= 5 - + return min(100, max(0, score)) def _calculate_engagement_authenticity(self, user: Any, tweets: List[Any]) -> int: """Analyze engagement patterns for 
authenticity indicators (0-100)""" if not tweets: return 50 # Neutral score when no data available - + score = 50 # Start with neutral followers = user.public_metrics.get("followers_count", 0) - + if followers == 0: return 20 # Very suspicious - + engagement_rates = [] for tweet in tweets[:20]: # Analyze recent tweets engagement = ( - tweet.public_metrics.get("like_count", 0) + - tweet.public_metrics.get("retweet_count", 0) + - tweet.public_metrics.get("reply_count", 0) + tweet.public_metrics.get("like_count", 0) + + tweet.public_metrics.get("retweet_count", 0) + + tweet.public_metrics.get("reply_count", 0) ) rate = (engagement / followers) * 100 engagement_rates.append(rate) - + if engagement_rates: avg_rate = sum(engagement_rates) / len(engagement_rates) rate_variance = sum((r - avg_rate) ** 2 for r in engagement_rates) / len(engagement_rates) - + if rate_variance < 0.1: # Very consistent score += 20 elif rate_variance < 1.0: # Reasonably consistent score += 10 elif rate_variance > 10.0: # Highly inconsistent (suspicious) score -= 15 - + if avg_rate > 10: # >10% engagement rate is unusual score -= 20 elif avg_rate > 5: score -= 10 elif avg_rate < 0.1: # Very low engagement also suspicious score -= 10 - + like_counts = [t.public_metrics.get("like_count", 0) for t in tweets[:10]] retweet_counts = [t.public_metrics.get("retweet_count", 0) for t in tweets[:10]] - + if sum(like_counts) > 0 and sum(retweet_counts) > 0: like_rt_ratio = sum(like_counts) / sum(retweet_counts) if 2 <= like_rt_ratio <= 20: # Normal range score += 15 else: # Unusual ratios score -= 10 - + return min(100, max(0, score)) def _calculate_content_authenticity(self, tweets: List[Any]) -> int: """Analyze content patterns for authenticity indicators (0-100)""" if not tweets: return 50 # Neutral score when no data available - + score = 50 # Start with neutral - + tweet_texts = [tweet.text for tweet in tweets[:20]] unique_texts = set(tweet_texts) - + if len(tweet_texts) > 0: uniqueness_ratio = 
len(unique_texts) / len(tweet_texts) if uniqueness_ratio > 0.9: # High uniqueness @@ -315,97 +312,97 @@ def _calculate_content_authenticity(self, tweets: List[Any]) -> int: score += 15 elif uniqueness_ratio < 0.5: # Low uniqueness (suspicious) score -= 20 - + original_tweets = [t for t in tweets if not t.text.startswith("RT @")] - + if len(tweets) > 0: original_ratio = len(original_tweets) / len(tweets) if original_ratio > 0.7: # Mostly original content score += 20 elif original_ratio < 0.3: # Mostly retweets (suspicious) score -= 15 - + hashtag_counts = [] for tweet in tweets[:10]: - hashtag_count = tweet.text.count('#') + hashtag_count = tweet.text.count("#") hashtag_counts.append(hashtag_count) - + if hashtag_counts: avg_hashtags = sum(hashtag_counts) / len(hashtag_counts) if avg_hashtags > 5: # Excessive hashtag use score -= 15 elif 1 <= avg_hashtags <= 3: # Normal hashtag use score += 10 - + if original_tweets: avg_length = sum(len(t.text) for t in original_tweets) / len(original_tweets) if avg_length > 100: # Longer, more thoughtful tweets score += 15 elif avg_length < 30: # Very short tweets (suspicious) score -= 10 - + return min(100, max(0, score)) def _calculate_temporal_authenticity(self, tweets: List[Any]) -> int: """Analyze temporal posting patterns for authenticity indicators (0-100)""" if not tweets: return 50 # Neutral score when no data available - + score = 50 # Start with neutral - + # Analyze posting frequency tweets_with_dates = [t for t in tweets if t.created_at] if len(tweets_with_dates) < 2: return score - + timestamps = [] for tweet in tweets_with_dates[:20]: try: if isinstance(tweet.created_at, str): - timestamp = datetime.fromisoformat(tweet.created_at.replace('Z', '+00:00')) + timestamp = datetime.fromisoformat(tweet.created_at.replace("Z", "+00:00")) else: timestamp = tweet.created_at timestamps.append(timestamp) except (ValueError, AttributeError, TypeError): continue - + if len(timestamps) < 2: return score - + timestamps.sort() 
intervals = [] for i in range(1, len(timestamps)): - interval = (timestamps[i] - timestamps[i-1]).total_seconds() + interval = (timestamps[i] - timestamps[i - 1]).total_seconds() intervals.append(interval) - + if intervals: avg_interval = sum(intervals) / len(intervals) interval_variance = sum((i - avg_interval) ** 2 for i in intervals) / len(intervals) - + if interval_variance < (avg_interval * 0.1) ** 2 and len(intervals) > 5: score -= 20 # Too regular elif interval_variance > (avg_interval * 2) ** 2: - score += 10 # Natural variance - + score += 10 # Natural variance + if avg_interval < 300: # Less than 5 minutes average score -= 25 elif avg_interval < 3600: # Less than 1 hour average score -= 10 - + return min(100, max(0, score)) def _calculate_influence_score(self, user: Any) -> int: """Calculate influence score based on follower metrics (0-100)""" followers = user.public_metrics.get("followers_count", 0) following = user.public_metrics.get("following_count", 0) - + if followers >= 1000000: # 1M+ follower_score = 100 elif followers >= 100000: # 100K+ follower_score = 80 - elif followers >= 10000: # 10K+ + elif followers >= 10000: # 10K+ follower_score = 60 - elif followers >= 1000: # 1K+ + elif followers >= 1000: # 1K+ follower_score = 40 else: follower_score = 20 @@ -444,7 +441,7 @@ def _calculate_credibility_score(self, user: Any, tweets: List[Any]) -> int: if tweets: tweet_count = user.public_metrics.get("tweet_count", 0) account_age_days = (datetime.now(timezone.utc) - user.created_at).days - + if account_age_days > 0: tweets_per_day = tweet_count / account_age_days if 0.5 <= tweets_per_day <= 10: # Reasonable posting frequency @@ -454,7 +451,7 @@ def _calculate_credibility_score(self, user: Any, tweets: List[Any]) -> int: else: # Too much or too little posting score += 5 - if user.url and any(domain in user.url for domain in ['.com', '.org', '.edu', '.gov']): + if user.url and any(domain in user.url for domain in [".com", ".org", ".edu", ".gov"]): score 
+= 10 return min(100, score) diff --git a/src/talos/tools/twitter_client.py b/src/talos/tools/twitter_client.py index 344826f5..3538fb99 100644 --- a/src/talos/tools/twitter_client.py +++ b/src/talos/tools/twitter_client.py @@ -7,7 +7,7 @@ from pydantic_settings import BaseSettings from textblob import TextBlob -from talos.models.twitter import TwitterUser, Tweet, ReferencedTweet +from talos.models.twitter import ReferencedTweet, Tweet, TwitterUser logger = logging.getLogger(__name__) @@ -106,6 +106,7 @@ def get_user(self, username: str) -> TwitterUser: ], ) from talos.models.twitter import TwitterPublicMetrics + user_data = response.data return TwitterUser( id=int(user_data.id), @@ -116,7 +117,7 @@ def get_user(self, username: str) -> TwitterUser: public_metrics=TwitterPublicMetrics(**user_data.public_metrics), description=user_data.description, url=user_data.url, - verified=getattr(user_data, 'verified', False) + verified=getattr(user_data, "verified", False), ) def search_tweets( @@ -173,7 +174,15 @@ def get_user_timeline(self, username: str) -> list[Tweet]: return [] response = self.client.get_users_tweets( id=user.id, - tweet_fields=["author_id", "in_reply_to_user_id", "public_metrics", "referenced_tweets", "conversation_id", "created_at", "edit_history_tweet_ids"], + tweet_fields=[ + "author_id", + "in_reply_to_user_id", + "public_metrics", + "referenced_tweets", + "conversation_id", + "created_at", + "edit_history_tweet_ids", + ], user_fields=[ "created_at", "public_metrics", @@ -196,7 +205,15 @@ def get_user_mentions(self, username: str) -> list[Tweet]: return [] response = self.client.get_users_mentions( id=user.id, - tweet_fields=["author_id", "in_reply_to_user_id", "public_metrics", "referenced_tweets", "conversation_id", "created_at", "edit_history_tweet_ids"], + tweet_fields=[ + "author_id", + "in_reply_to_user_id", + "public_metrics", + "referenced_tweets", + "conversation_id", + "created_at", + "edit_history_tweet_ids", + ], user_fields=[ 
"created_at", "public_metrics", @@ -215,7 +232,15 @@ def get_tweet(self, tweet_id: int) -> Tweet: response = self.client.get_tweet( str(tweet_id), - tweet_fields=["author_id", "in_reply_to_user_id", "public_metrics", "referenced_tweets", "conversation_id", "created_at", "edit_history_tweet_ids"] + tweet_fields=[ + "author_id", + "in_reply_to_user_id", + "public_metrics", + "referenced_tweets", + "conversation_id", + "created_at", + "edit_history_tweet_ids", + ], ) return self._convert_to_tweet_model(response.data) @@ -260,27 +285,31 @@ def reply_to_tweet(self, tweet_id: str, tweet: str) -> Any: def _convert_to_tweet_model(self, tweet_data: Any) -> Tweet: """Convert raw tweepy tweet data to Tweet BaseModel""" referenced_tweets = [] - if hasattr(tweet_data, 'referenced_tweets') and tweet_data.referenced_tweets: + if hasattr(tweet_data, "referenced_tweets") and tweet_data.referenced_tweets: for ref in tweet_data.referenced_tweets: if isinstance(ref, dict): - referenced_tweets.append(ReferencedTweet( - type=ref.get('type', ''), - id=ref.get('id', 0) - )) + referenced_tweets.append(ReferencedTweet(type=ref.get("type", ""), id=ref.get("id", 0))) else: - referenced_tweets.append(ReferencedTweet( - type=getattr(ref, 'type', ''), - id=getattr(ref, 'id', 0) - )) - + referenced_tweets.append(ReferencedTweet(type=getattr(ref, "type", ""), id=getattr(ref, "id", 0))) + return Tweet( id=int(tweet_data.id), text=tweet_data.text, author_id=str(tweet_data.author_id), - created_at=str(tweet_data.created_at) if hasattr(tweet_data, 'created_at') and tweet_data.created_at else None, - conversation_id=str(tweet_data.conversation_id) if hasattr(tweet_data, 'conversation_id') and tweet_data.conversation_id else None, - public_metrics=dict(tweet_data.public_metrics) if hasattr(tweet_data, 'public_metrics') and tweet_data.public_metrics else {}, + created_at=str(tweet_data.created_at) + if hasattr(tweet_data, "created_at") and tweet_data.created_at + else None, + 
conversation_id=str(tweet_data.conversation_id) + if hasattr(tweet_data, "conversation_id") and tweet_data.conversation_id + else None, + public_metrics=dict(tweet_data.public_metrics) + if hasattr(tweet_data, "public_metrics") and tweet_data.public_metrics + else {}, referenced_tweets=referenced_tweets if referenced_tweets else None, - in_reply_to_user_id=str(tweet_data.in_reply_to_user_id) if hasattr(tweet_data, 'in_reply_to_user_id') and tweet_data.in_reply_to_user_id else None, - edit_history_tweet_ids=[str(id) for id in tweet_data.edit_history_tweet_ids] if hasattr(tweet_data, 'edit_history_tweet_ids') and tweet_data.edit_history_tweet_ids else None + in_reply_to_user_id=str(tweet_data.in_reply_to_user_id) + if hasattr(tweet_data, "in_reply_to_user_id") and tweet_data.in_reply_to_user_id + else None, + edit_history_tweet_ids=[str(id) for id in tweet_data.edit_history_tweet_ids] + if hasattr(tweet_data, "edit_history_tweet_ids") and tweet_data.edit_history_tweet_ids + else None, ) diff --git a/tests/test_document_loader.py b/tests/test_document_loader.py index 0727517f..4dc9056e 100644 --- a/tests/test_document_loader.py +++ b/tests/test_document_loader.py @@ -2,7 +2,7 @@ from unittest.mock import Mock, patch from talos.data.dataset_manager import DatasetManager -from talos.tools.document_loader import DocumentLoaderTool, DatasetSearchTool +from talos.tools.document_loader import DatasetSearchTool, DocumentLoaderTool class TestDocumentLoader(unittest.TestCase): @@ -11,7 +11,7 @@ def setUp(self): self.dataset_manager = DatasetManager() self.document_loader = DocumentLoaderTool(self.dataset_manager) self.dataset_search = DatasetSearchTool(self.dataset_manager) - + def test_is_ipfs_hash(self): self.assertTrue(self.document_loader._is_ipfs_hash("QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG")) self.assertTrue(self.document_loader._is_ipfs_hash("ipfs://QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG")) @@ -20,19 +20,19 @@ def test_is_ipfs_hash(self): 
@patch('talos.utils.http_client.SecureHTTPClient.get') def test_fetch_content_from_url_text(self, mock_get): mock_response = Mock() - mock_response.headers = {'content-type': 'text/plain'} + mock_response.headers = {"content-type": "text/plain"} mock_response.text = "This is a test document." mock_response.raise_for_status.return_value = None mock_get.return_value = mock_response - + content = self.dataset_manager._fetch_content_from_url("https://example.com/test.txt") self.assertEqual(content, "This is a test document.") - + def test_clean_text(self): dirty_text = "This is a\n\n\n\ntest document." clean_text = self.dataset_manager._clean_text(dirty_text) self.assertEqual(clean_text, "This is a\n\ntest document.") - + def test_chunk_content(self): content = "This is sentence one. This is sentence two. This is sentence three." chunks = self.dataset_manager._process_and_chunk_content(content, chunk_size=30, chunk_overlap=10) @@ -40,5 +40,5 @@ def test_chunk_content(self): self.assertTrue(all(len(chunk) <= 40 for chunk in chunks)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_scheduled_jobs.py b/tests/test_scheduled_jobs.py index 12f482f3..2b823d47 100644 --- a/tests/test_scheduled_jobs.py +++ b/tests/test_scheduled_jobs.py @@ -8,30 +8,27 @@ import pytest from langchain_openai import ChatOpenAI +from talos.core.job_scheduler import JobScheduler from talos.core.main_agent import MainAgent from talos.core.scheduled_job import ScheduledJob -from talos.core.job_scheduler import JobScheduler class MockScheduledJob(ScheduledJob): """Test implementation of ScheduledJob for testing purposes.""" - + execution_count: int = 0 last_execution: Optional[datetime] = None - + def __init__(self, name: str = "test_job", **kwargs): - if 'cron_expression' not in kwargs and 'execute_at' not in kwargs: - kwargs['cron_expression'] = "0 * * * *" # Every hour - - kwargs.setdefault('description', "Test scheduled job") - 
kwargs.setdefault('execution_count', 0) - kwargs.setdefault('last_execution', None) - - super().__init__( - name=name, - **kwargs - ) - + if "cron_expression" not in kwargs and "execute_at" not in kwargs: + kwargs["cron_expression"] = "0 * * * *" # Every hour + + kwargs.setdefault("description", "Test scheduled job") + kwargs.setdefault("execution_count", 0) + kwargs.setdefault("last_execution", None) + + super().__init__(name=name, **kwargs) + async def run(self, **kwargs) -> str: self.execution_count += 1 self.last_execution = datetime.now() @@ -40,18 +37,14 @@ async def run(self, **kwargs) -> str: class MockOneTimeJob(ScheduledJob): """Test implementation of one-time ScheduledJob.""" - + executed: bool = False - + def __init__(self, execute_at: datetime, **kwargs): super().__init__( - name="one_time_test", - description="One-time test job", - execute_at=execute_at, - executed=False, - **kwargs + name="one_time_test", description="One-time test job", execute_at=execute_at, executed=False, **kwargs ) - + async def run(self, **kwargs) -> str: self.executed = True return "One-time job executed" @@ -59,7 +52,7 @@ async def run(self, **kwargs) -> str: class TestScheduledJobValidation: """Test ScheduledJob validation and configuration.""" - + def test_cron_job_creation(self): """Test creating a job with cron expression.""" job = MockScheduledJob(name="cron_test", cron_expression="0 9 * * *") @@ -68,7 +61,7 @@ def test_cron_job_creation(self): assert job.execute_at is None assert job.is_recurring() assert not job.is_one_time() - + def test_one_time_job_creation(self): """Test creating a one-time job with datetime.""" future_time = datetime.now() + timedelta(hours=1) @@ -78,7 +71,7 @@ def test_one_time_job_creation(self): assert job.cron_expression is None assert job.is_one_time() assert not job.is_recurring() - + def test_job_validation_requires_schedule(self): """Test that job validation requires either cron or datetime.""" with pytest.raises(ValueError, match="Either 
cron_expression or execute_at must be provided"): @@ -88,9 +81,9 @@ def test_job_validation_requires_schedule(self): cron_expression=None, execute_at=None, execution_count=0, - last_execution=None + last_execution=None, ) - + def test_job_validation_exclusive_schedule(self): """Test that job validation prevents both cron and datetime.""" future_time = datetime.now() + timedelta(hours=1) @@ -101,18 +94,18 @@ def test_job_validation_exclusive_schedule(self): cron_expression="0 * * * *", execute_at=future_time, execution_count=0, - last_execution=None + last_execution=None, ) - + def test_should_execute_now(self): """Test should_execute_now method for one-time jobs.""" past_time = datetime.now() - timedelta(minutes=1) future_time = datetime.now() + timedelta(minutes=1) - + past_job = MockOneTimeJob(execute_at=past_time) future_job = MockOneTimeJob(execute_at=future_time) cron_job = MockScheduledJob() - + assert past_job.should_execute_now() assert not future_job.should_execute_now() assert not cron_job.should_execute_now() @@ -120,51 +113,51 @@ def test_should_execute_now(self): class TestJobScheduler: """Test JobScheduler functionality.""" - + @pytest.fixture def scheduler(self): """Create a JobScheduler instance for testing.""" return JobScheduler() - + def test_scheduler_initialization(self, scheduler): """Test scheduler initialization.""" assert scheduler.timezone == "UTC" assert not scheduler.is_running() assert len(scheduler.list_jobs()) == 0 - + def test_register_job(self, scheduler): """Test job registration.""" job = MockScheduledJob(name="test_register") scheduler.register_job(job) - + assert len(scheduler.list_jobs()) == 1 assert scheduler.get_job("test_register") == job - + def test_unregister_job(self, scheduler): """Test job unregistration.""" job = MockScheduledJob(name="test_unregister") scheduler.register_job(job) - + assert scheduler.unregister_job("test_unregister") assert len(scheduler.list_jobs()) == 0 assert scheduler.get_job("test_unregister") 
is None - + def test_unregister_nonexistent_job(self, scheduler): """Test unregistering a job that doesn't exist.""" assert not scheduler.unregister_job("nonexistent") - + def test_register_disabled_job(self, scheduler): """Test registering a disabled job.""" job = MockScheduledJob(name="disabled_job", enabled=False) scheduler.register_job(job) - + assert len(scheduler.list_jobs()) == 1 assert scheduler.get_job("disabled_job") == job class TestMainAgentIntegration: """Test MainAgent integration with scheduled jobs.""" - + @pytest.fixture def main_agent(self): """Create a MainAgent instance for testing.""" @@ -180,37 +173,37 @@ def main_agent(self): if agent.job_scheduler: agent.job_scheduler.stop() return agent - + def test_main_agent_scheduler_initialization(self, main_agent): """Test that MainAgent initializes with a job scheduler.""" assert main_agent.job_scheduler is not None assert isinstance(main_agent.job_scheduler, JobScheduler) - + def test_add_scheduled_job(self, main_agent): """Test adding a scheduled job to MainAgent.""" job = MockScheduledJob(name="main_agent_test") main_agent.add_scheduled_job(job) - + assert len(main_agent.list_scheduled_jobs()) == 1 assert main_agent.get_scheduled_job("main_agent_test") == job - + def test_remove_scheduled_job(self, main_agent): """Test removing a scheduled job from MainAgent.""" job = MockScheduledJob(name="remove_test") main_agent.add_scheduled_job(job) - + assert main_agent.remove_scheduled_job("remove_test") assert len(main_agent.list_scheduled_jobs()) == 0 assert main_agent.get_scheduled_job("remove_test") is None - + def test_pause_resume_job(self, main_agent): """Test pausing and resuming jobs.""" job = MockScheduledJob(name="pause_test") main_agent.add_scheduled_job(job) - + main_agent.pause_scheduled_job("pause_test") main_agent.resume_scheduled_job("pause_test") - + def test_predefined_jobs_registration(self): """Test that predefined jobs are registered during initialization.""" job = 
MockScheduledJob(name="predefined_job") @@ -223,44 +216,46 @@ def test_predefined_jobs_registration(self): agent = MainAgent( model=ChatOpenAI(model="gpt-5", api_key="test_key"), prompts_dir="src/talos/prompts", - scheduled_jobs=[job] + scheduled_jobs=[job], ) if agent.job_scheduler: agent.job_scheduler.stop() - + assert len(agent.list_scheduled_jobs()) == 1 assert agent.get_scheduled_job("predefined_job") == job def test_job_execution(): """Test that jobs can be executed.""" + async def run_test(): job = MockScheduledJob(name="execution_test") - + result = await job.run() - + assert job.execution_count == 1 assert job.last_execution is not None assert result == "Test job executed 1 times" - + result2 = await job.run() assert job.execution_count == 2 assert result2 == "Test job executed 2 times" - + asyncio.run(run_test()) def test_one_time_job_execution(): """Test one-time job execution.""" + async def run_test(): future_time = datetime.now() + timedelta(seconds=1) job = MockOneTimeJob(execute_at=future_time) - + assert not job.executed - + result = await job.run() - + assert job.executed assert result == "One-time job executed" - + asyncio.run(run_test()) diff --git a/tests/test_yield_manager.py b/tests/test_yield_manager.py index 0d8ece82..9ebd2952 100644 --- a/tests/test_yield_manager.py +++ b/tests/test_yield_manager.py @@ -51,7 +51,7 @@ def test_min_max_yield_validation(self, mock_tweepy_client): with self.assertRaises(ValueError): YieldManagerService(dexscreener_client, gecko_terminal_client, llm_client, min_yield=-0.01) - + with self.assertRaises(ValueError): YieldManagerService(dexscreener_client, gecko_terminal_client, llm_client, min_yield=0.2, max_yield=0.1) @@ -67,20 +67,18 @@ def test_apr_bounds_enforcement(self, mock_tweepy_client): volume=1000000, ) gecko_terminal_client.get_ohlcv_data.return_value = GeckoTerminalOHLCVData(ohlcv_list=[]) - - llm_client.reasoning.return_value = json.dumps( - {"apr": 0.25, "explanation": "High APR recommendation"} - ) 
- yield_manager = YieldManagerService(dexscreener_client, gecko_terminal_client, llm_client, min_yield=0.05, max_yield=0.20) + llm_client.reasoning.return_value = json.dumps({"apr": 0.25, "explanation": "High APR recommendation"}) + + yield_manager = YieldManagerService( + dexscreener_client, gecko_terminal_client, llm_client, min_yield=0.05, max_yield=0.20 + ) yield_manager.get_staked_supply_percentage = MagicMock(return_value=0.5) new_apr = yield_manager.update_staking_apr(75.0, "A report") self.assertEqual(new_apr, 0.20) - llm_client.reasoning.return_value = json.dumps( - {"apr": 0.01, "explanation": "Low APR recommendation"} - ) + llm_client.reasoning.return_value = json.dumps({"apr": 0.01, "explanation": "Low APR recommendation"}) new_apr = yield_manager.update_staking_apr(75.0, "A report") self.assertEqual(new_apr, 0.05)