diff --git a/src/agents/market_data.py b/src/agents/market_data.py index c8592e15..e4a78f85 100644 --- a/src/agents/market_data.py +++ b/src/agents/market_data.py @@ -1,74 +1,89 @@ +from datetime import datetime from langchain_openai.chat_models import ChatOpenAI from agents.state import AgentState -from tools.api import search_line_items, get_financial_metrics, get_insider_trades, get_market_cap, get_prices - -from datetime import datetime +from tools.api.financial_dataset import ( + FinancialDatasetAPI, +) llm = ChatOpenAI(model="gpt-4o") + def market_data_agent(state: AgentState): """Responsible for gathering and preprocessing market data""" messages = state["messages"] data = state["data"] # Set default dates - end_date = data["end_date"] or datetime.now().strftime('%Y-%m-%d') + end_date = data["end_date"] or datetime.now().strftime("%Y-%m-%d") if not data["start_date"]: # Calculate 3 months before end_date - end_date_obj = datetime.strptime(end_date, '%Y-%m-%d') - start_date = end_date_obj.replace(month=end_date_obj.month - 3) if end_date_obj.month > 3 else \ - end_date_obj.replace(year=end_date_obj.year - 1, month=end_date_obj.month + 9) - start_date = start_date.strftime('%Y-%m-%d') + end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") + start_date = ( + end_date_obj.replace(month=end_date_obj.month - 3) + if end_date_obj.month > 3 + else end_date_obj.replace( + year=end_date_obj.year - 1, month=end_date_obj.month + 9 + ) + ) + start_date = start_date.strftime("%Y-%m-%d") else: start_date = data["start_date"] + financial_api = FinancialDatasetAPI() + # Get the historical price data - prices = get_prices( - ticker=data["ticker"], - start_date=start_date, + prices = financial_api.get_prices( + ticker=data["ticker"], + start_date=start_date, end_date=end_date, ) # Get the financial metrics - financial_metrics = get_financial_metrics( - ticker=data["ticker"], - report_period=end_date, - period='ttm', + financial_metrics = financial_api.get_financial_metrics( + ticker=data["ticker"], + report_period=end_date, + period="ttm", limit=1, ) # Get the insider trades - insider_trades = get_insider_trades( - ticker=data["ticker"], + insider_trades = financial_api.get_insider_trades( + ticker=data["ticker"], end_date=end_date, limit=5, ) # Get the market cap - market_cap = get_market_cap( + market_cap = financial_api.get_market_cap( ticker=data["ticker"], ) # Get the line_items - financial_line_items = search_line_items( - ticker=data["ticker"], - line_items=["free_cash_flow", "net_income", "depreciation_and_amortization", "capital_expenditure", "working_capital"], - period='ttm', + financial_line_items = financial_api.search_line_items( + ticker=data["ticker"], + line_items=[ + "free_cash_flow", + "net_income", + "depreciation_and_amortization", + "capital_expenditure", + "working_capital", + ], + period="ttm", limit=2, ) return { "messages": messages, "data": { - **data, - "prices": prices, - "start_date": start_date, + **data, + "prices": prices, + "start_date": start_date, "end_date": end_date, "financial_metrics": financial_metrics, "insider_trades": insider_trades, "market_cap": market_cap, "financial_line_items": financial_line_items, - } - } \ No newline at end of file + }, + } diff --git a/src/agents/risk_manager.py b/src/agents/risk_manager.py index ba20bb91..2326bd0e 100644 --- a/src/agents/risk_manager.py +++ b/src/agents/risk_manager.py @@ -1,12 +1,12 @@ +import ast +import json import math from langchain_core.messages import HumanMessage from agents.state import AgentState, show_agent_reasoning -from tools.api import prices_to_df +from tools.api.financial_dataset import FinancialDatasetAPI -import json -import ast ##### Risk Management Agent ##### def risk_management_agent(state: AgentState): @@ -15,7 +15,7 @@ def risk_management_agent(state: AgentState): portfolio = state["data"]["portfolio"] data = state["data"] - prices_df = prices_to_df(data["prices"]) + prices_df = FinancialDatasetAPI.prices_to_df(data["prices"]) # Fetch messages from other agents technical_message = next(msg for msg in state["messages"] if msg.name == "technical_analyst_agent") diff --git a/src/agents/technicals.py b/src/agents/technicals.py index 1e5fcd78..c847406e 100644 --- a/src/agents/technicals.py +++ b/src/agents/technicals.py @@ -9,7 +9,7 @@ import pandas as pd import numpy as np -from tools.api import prices_to_df +from tools.api.financial_dataset import FinancialDatasetAPI ##### Technical Analyst ##### @@ -25,7 +25,7 @@ def technical_analyst_agent(state: AgentState): show_reasoning = state["metadata"]["show_reasoning"] data = state["data"] prices = data["prices"] - prices_df = prices_to_df(prices) + prices_df = FinancialDatasetAPI.prices_to_df(prices) # Calculate indicators # 1. MACD (Moving Average Convergence Divergence) diff --git a/src/backtester.py b/src/backtester.py index d0a5d5cf..86185d29 100644 --- a/src/backtester.py +++ b/src/backtester.py @@ -4,7 +4,7 @@ import pandas as pd from main import run_hedge_fund -from tools.api import get_price_data +from tools.api.financial_dataset import FinancialDatasetAPI class Backtester: def __init__(self, agent, ticker, start_date, end_date, initial_capital): @@ -70,7 +70,7 @@ def run_backtest(self): ) action, quantity = self.parse_action(agent_output) - df = get_price_data(self.ticker, lookback_start, current_date_str) + df = FinancialDatasetAPI().get_prices(self.ticker, lookback_start, current_date_str) current_price = df.iloc[-1]['close'] # Execute the trade with validation diff --git a/src/tools/api.py b/src/tools/api.py deleted file mode 100644 index 5d241358..00000000 --- a/src/tools/api.py +++ /dev/null @@ -1,152 +0,0 @@ -import os -from typing import Dict, Any, List -import pandas as pd -import requests - -import requests - -def get_financial_metrics( - ticker: str, - report_period: str, - period: str = 'ttm', - limit: int = 1 -) -> List[Dict[str, Any]]: - """Fetch financial metrics from the API.""" - headers = {"X-API-KEY": os.environ.get("FINANCIAL_DATASETS_API_KEY")} - url = ( - f"https://api.financialdatasets.ai/financial-metrics/" - f"?ticker={ticker}" - f"&report_period_lte={report_period}" - f"&limit={limit}" - f"&period={period}" - ) - response = requests.get(url, headers=headers) - if response.status_code != 200: - raise Exception( - f"Error fetching data: {response.status_code} - {response.text}" - ) - data = response.json() - financial_metrics = data.get("financial_metrics") - if not financial_metrics: - raise ValueError("No financial metrics returned") - return financial_metrics - -def search_line_items( - ticker: str, - line_items: List[str], - period: str = 'ttm', - limit: int = 1 -) -> List[Dict[str, Any]]: - """Fetch cash flow statements from the API.""" - headers = {"X-API-KEY": os.environ.get("FINANCIAL_DATASETS_API_KEY")} - url = "https://api.financialdatasets.ai/financials/search/line-items" - - body = { - "tickers": [ticker], - "line_items": line_items, - "period": period, - "limit": limit - } - response = requests.post(url, headers=headers, json=body) - if response.status_code != 200: - raise Exception( - f"Error fetching data: {response.status_code} - {response.text}" - ) - data = response.json() - search_results = data.get("search_results") - if not search_results: - raise ValueError("No search results returned") - return search_results - -def get_insider_trades( - ticker: str, - end_date: str, - limit: int = 5, -) -> List[Dict[str, Any]]: - """ - Fetch insider trades for a given ticker and date range. - """ - headers = {"X-API-KEY": os.environ.get("FINANCIAL_DATASETS_API_KEY")} - url = ( - f"https://api.financialdatasets.ai/insider-trades/" - f"?ticker={ticker}" - f"&filing_date_lte={end_date}" - f"&limit={limit}" - ) - response = requests.get(url, headers=headers) - if response.status_code != 200: - raise Exception( - f"Error fetching data: {response.status_code} - {response.text}" - ) - data = response.json() - insider_trades = data.get("insider_trades") - if not insider_trades: - raise ValueError("No insider trades returned") - return insider_trades - -def get_market_cap( - ticker: str, -) -> List[Dict[str, Any]]: - """Fetch market cap from the API.""" - headers = {"X-API-KEY": os.environ.get("FINANCIAL_DATASETS_API_KEY")} - url = ( - f'https://api.financialdatasets.ai/company/facts' - f'?ticker={ticker}' - ) - - response = requests.get(url, headers=headers) - if response.status_code != 200: - raise Exception( - f"Error fetching data: {response.status_code} - {response.text}" - ) - data = response.json() - company_facts = data.get('company_facts') - if not company_facts: - raise ValueError("No company facts returned") - return company_facts.get('market_cap') - -def get_prices( - ticker: str, - start_date: str, - end_date: str -) -> List[Dict[str, Any]]: - """Fetch price data from the API.""" - headers = {"X-API-KEY": os.environ.get("FINANCIAL_DATASETS_API_KEY")} - url = ( - f"https://api.financialdatasets.ai/prices/" - f"?ticker={ticker}" - f"&interval=day" - f"&interval_multiplier=1" - f"&start_date={start_date}" - f"&end_date={end_date}" - ) - response = requests.get(url, headers=headers) - if response.status_code != 200: - raise Exception( - f"Error fetching data: {response.status_code} - {response.text}" - ) - data = response.json() - prices = data.get("prices") - if not prices: - raise ValueError("No price data returned") - return prices - -def prices_to_df(prices: List[Dict[str, Any]]) -> pd.DataFrame: - """Convert prices to a DataFrame.""" - df = pd.DataFrame(prices) - df["Date"] = pd.to_datetime(df["time"]) - df.set_index("Date", inplace=True) - numeric_cols = ["open", "close", "high", "low", "volume"] - for col in numeric_cols: - df[col] = pd.to_numeric(df[col], errors="coerce") - df.sort_index(inplace=True) - return df - -# Update the get_price_data function to use the new functions -def get_price_data( - ticker: str, - start_date: str, - end_date: str -) -> pd.DataFrame: - prices = get_prices(ticker, start_date, end_date) - return prices_to_df(prices) diff --git a/src/tools/api/__init__.py b/src/tools/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/tools/api/base.py b/src/tools/api/base.py new file mode 100644 index 00000000..67773186 --- /dev/null +++ b/src/tools/api/base.py @@ -0,0 +1,60 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +import requests +from requests.exceptions import RequestException + +from .config import BaseAPIConfig + + +class BaseAPIClient(ABC): + """Abstract base class for API clients.""" + + def __init__(self, config: BaseAPIConfig): + self.config = config + self.headers = self._get_headers() + + @abstractmethod + def _get_headers(self) -> Dict[str, str]: + """Return headers required for API requests.""" + pass + + @abstractmethod + def _handle_response(self, response: requests.Response) -> Dict[str, Any]: + """Handle API response and return parsed data.""" + pass + + def _get_base_url(self) -> str: + """Implementation of abstract method for base URL.""" + return self.config.base_url + + def _make_request( + self, + endpoint: str, + method: str = "GET", + params: Optional[Dict[str, Any]] = None, + json_data: Optional[Dict[str, Any]] = None, + ) -> Dict[str, Any]: + """Make HTTP request to the API with error handling.""" + url = f"{self._get_base_url()}/{endpoint.lstrip('/')}" + + try: + response = requests.request( + method=method, + url=url, + headers=self.headers, + params=params, + json=json_data, + ) + return self._handle_response(response) + except RequestException as e: + raise Exception(f"API request failed: {str(e)}") from e + + def _get_data_or_raise( + self, response_data: Dict[str, Any], key: str, error_message: str + ) -> Any: + """Extract data from response or raise if not found.""" + data = response_data.get(key) + if not data: + raise Exception(error_message) + return data diff --git a/src/tools/api/config.py b/src/tools/api/config.py new file mode 100644 index 00000000..27ac05f7 --- /dev/null +++ b/src/tools/api/config.py @@ -0,0 +1,31 @@ +import os +from dataclasses import dataclass + + +@dataclass +class BaseAPIConfig: + """Configuration for the Financial Datasets API.""" + + api_key: str + base_url: str + + @classmethod + def from_env(cls) -> "BaseAPIConfig": + """Create configuration from environment variables.""" + pass + + +@dataclass +class FinancialDatasetAPIConfig(BaseAPIConfig): + """Configuration for the Financial Datasets API.""" + + api_key: str + base_url: str = "https://api.financialdatasets.ai" + + @classmethod + def from_env(cls) -> "FinancialDatasetAPIConfig": + """Create configuration from environment variables.""" + api_key = os.environ.get("FINANCIAL_DATASETS_API_KEY") + if not api_key: + raise ValueError("FINANCIAL_DATASETS_API_KEY environment variable not set") + return cls(api_key=api_key) diff --git a/src/tools/api/financial_dataset.py b/src/tools/api/financial_dataset.py new file mode 100644 index 00000000..911f0a3f --- /dev/null +++ b/src/tools/api/financial_dataset.py @@ -0,0 +1,185 @@ +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Union + +import pandas as pd +import requests + +from .base import BaseAPIClient +from .config import FinancialDatasetAPIConfig + + +class Period(str, Enum): + """Enumeration of valid reporting periods.""" + + TTM = "ttm" + QUARTERLY = "quarterly" + ANNUAL = "annual" + + +class FinancialDatasetAPI(BaseAPIClient): + """Client for accessing financial dataset API endpoints.""" + + def __init__(self, config: FinancialDatasetAPIConfig = None): + """Initialize API client with optional custom config.""" + self.config = config if config else FinancialDatasetAPIConfig.from_env() + super().__init__(self.config) + + def _get_headers(self) -> Dict[str, str]: + return {"X-API-KEY": self.config.api_key} + + def _handle_response(self, response: requests.Response) -> Dict[str, Any]: + response.raise_for_status() + return response.json() + + @staticmethod + def prices_to_df(prices: List[Dict[str, Any]]) -> pd.DataFrame: + df = pd.DataFrame(prices) + df["Date"] = pd.to_datetime(df["time"]) + df.set_index("Date", inplace=True) + + numeric_cols = ["open", "close", "high", "low", "volume"] + for col in numeric_cols: + df[col] = pd.to_numeric(df[col], errors="coerce") + + return df.sort_index() + + def get_financial_metrics( + self, + ticker: str, + report_period: Union[str, datetime], + period: Period = Period.TTM, + limit: int = 1, + ) -> List[Dict[str, Any]]: + """ + Fetch financial metrics for a given ticker. + + Args: + ticker: Stock ticker symbol + report_period: End date for the report period + period: Reporting period type (TTM, quarterly, annual) + limit: Maximum number of records to return + + Returns: + List of financial metrics dictionaries + """ + params = { + "ticker": ticker, + "report_period_lte": report_period, + "limit": limit, + "period": period, + } + + response = self._make_request(endpoint="/financial-metrics/", params=params) + + return self._get_data_or_raise( + response, "financial_metrics", "No financial metrics found" + ) + + def search_line_items( + self, + ticker: str, + line_items: List[str], + period: Period = Period.TTM, + limit: int = 1, + ) -> List[Dict[str, Any]]: + """ + Search for specific line items in financial statements. + + Args: + ticker: Stock ticker symbol + line_items: List of line items to search for + period: Reporting period type + limit: Maximum number of records to return + Returns: + List of line items dictionaries + """ + payload = { + "tickers": [ticker], + "line_items": line_items, + "period": period, + "limit": limit, + } + + response = self._make_request( + endpoint="/financials/search/line-items", method="POST", json_data=payload + ) + + return self._get_data_or_raise( + response, "search_results", "No line items found" + ) + + def get_insider_trades( + self, ticker: str, end_date: Union[str, datetime], limit: int = 5 + ) -> List[Dict[str, Any]]: + """ + Fetch insider trading data. + + Args: + ticker: Stock ticker symbol + end_date: End date for the search period + limit: Maximum number of trades to return + Returns: + List of insider trades dictionaries + """ + params = {"ticker": ticker, "filing_date_lte": end_date, "limit": limit} + + response = self._make_request(endpoint="/insider-trades/", params=params) + + return self._get_data_or_raise( + response, "insider_trades", "No insider trades found" + ) + + def get_market_cap(self, ticker: str) -> float: + """ + Fetch market capitalization for a company. + + Args: + ticker: Stock ticker symbol + Returns: + Market capitalization as a float + """ + response = self._make_request( + endpoint="/company/facts", params={"ticker": ticker} + ) + + company_facts = self._get_data_or_raise( + response, "company_facts", "No company facts found" + ) + + market_cap = company_facts.get("market_cap") + if not market_cap: + raise Exception("Market cap not available") + + return market_cap + + def get_prices( + self, + ticker: str, + start_date: Union[str, datetime], + end_date: Union[str, datetime], + ) -> pd.DataFrame: + """ + Fetch and format price data as a DataFrame. + + Args: + ticker: Stock ticker symbol + start_date: Start date for price data + end_date: End date for price data + + Returns: + DataFrame with price data indexed by date + """ + params = { + "ticker": ticker, + "interval": "day", + "interval_multiplier": 1, + "start_date": start_date, + "end_date": end_date, + } + + response = self._make_request(endpoint="/prices/", params=params) + + prices = self._get_data_or_raise(response, "prices", "No price data found") + + return FinancialDatasetAPI.prices_to_df(prices)