Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor FinancialDatasetAPI integration and add configuration classes for API setup. #42

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 42 additions & 27 deletions src/agents/market_data.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,89 @@
from datetime import datetime

from langchain_openai.chat_models import ChatOpenAI

from agents.state import AgentState
from tools.api import search_line_items, get_financial_metrics, get_insider_trades, get_market_cap, get_prices

from datetime import datetime
from tools.api.financial_dataset import (
FinancialDatasetAPI,
)

llm = ChatOpenAI(model="gpt-4o")


def market_data_agent(state: AgentState):
"""Responsible for gathering and preprocessing market data"""
messages = state["messages"]
data = state["data"]

# Set default dates
end_date = data["end_date"] or datetime.now().strftime('%Y-%m-%d')
end_date = data["end_date"] or datetime.now().strftime("%Y-%m-%d")
if not data["start_date"]:
# Calculate 3 months before end_date
end_date_obj = datetime.strptime(end_date, '%Y-%m-%d')
start_date = end_date_obj.replace(month=end_date_obj.month - 3) if end_date_obj.month > 3 else \
end_date_obj.replace(year=end_date_obj.year - 1, month=end_date_obj.month + 9)
start_date = start_date.strftime('%Y-%m-%d')
end_date_obj = datetime.strptime(end_date, "%Y-%m-%d")
start_date = (
end_date_obj.replace(month=end_date_obj.month - 3)
if end_date_obj.month > 3
else end_date_obj.replace(
year=end_date_obj.year - 1, month=end_date_obj.month + 9
)
)
start_date = start_date.strftime("%Y-%m-%d")
else:
start_date = data["start_date"]

financial_api = FinancialDatasetAPI()

# Get the historical price data
prices = get_prices(
ticker=data["ticker"],
start_date=start_date,
prices = financial_api.get_prices(
ticker=data["ticker"],
start_date=start_date,
end_date=end_date,
)

# Get the financial metrics
financial_metrics = get_financial_metrics(
ticker=data["ticker"],
report_period=end_date,
period='ttm',
financial_metrics = financial_api.get_financial_metrics(
ticker=data["ticker"],
report_period=end_date,
period="ttm",
limit=1,
)

# Get the insider trades
insider_trades = get_insider_trades(
ticker=data["ticker"],
insider_trades = financial_api.get_insider_trades(
ticker=data["ticker"],
end_date=end_date,
limit=5,
)

# Get the market cap
market_cap = get_market_cap(
market_cap = financial_api.get_market_cap(
ticker=data["ticker"],
)

# Get the line_items
financial_line_items = search_line_items(
ticker=data["ticker"],
line_items=["free_cash_flow", "net_income", "depreciation_and_amortization", "capital_expenditure", "working_capital"],
period='ttm',
financial_line_items = financial_api.search_line_items(
ticker=data["ticker"],
line_items=[
"free_cash_flow",
"net_income",
"depreciation_and_amortization",
"capital_expenditure",
"working_capital",
],
period="ttm",
limit=2,
)

return {
"messages": messages,
"data": {
**data,
"prices": prices,
"start_date": start_date,
**data,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.9+ syntax is preferred, i.e. data | { ... }, rather than { **data, ... }

"prices": prices,
"start_date": start_date,
"end_date": end_date,
"financial_metrics": financial_metrics,
"insider_trades": insider_trades,
"market_cap": market_cap,
"financial_line_items": financial_line_items,
}
}
},
}
8 changes: 4 additions & 4 deletions src/agents/risk_manager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import ast
import json
import math

from langchain_core.messages import HumanMessage

from agents.state import AgentState, show_agent_reasoning
from tools.api import prices_to_df
from tools.api.financial_dataset import FinancialDatasetAPI

import json
import ast

##### Risk Management Agent #####
def risk_management_agent(state: AgentState):
Expand All @@ -15,7 +15,7 @@ def risk_management_agent(state: AgentState):
portfolio = state["data"]["portfolio"]
data = state["data"]

prices_df = prices_to_df(data["prices"])
prices_df = FinancialDatasetAPI.prices_to_df(data["prices"])

# Fetch messages from other agents
technical_message = next(msg for msg in state["messages"] if msg.name == "technical_analyst_agent")
Expand Down
4 changes: 2 additions & 2 deletions src/agents/technicals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import numpy as np

from tools.api import prices_to_df
from tools.api.financial_dataset import FinancialDatasetAPI


##### Technical Analyst #####
Expand All @@ -25,7 +25,7 @@ def technical_analyst_agent(state: AgentState):
show_reasoning = state["metadata"]["show_reasoning"]
data = state["data"]
prices = data["prices"]
prices_df = prices_to_df(prices)
prices_df = FinancialDatasetAPI.prices_to_df(prices)

# Calculate indicators
# 1. MACD (Moving Average Convergence Divergence)
Expand Down
4 changes: 2 additions & 2 deletions src/backtester.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd

from main import run_hedge_fund
from tools.api import get_price_data
from tools.api.financial_dataset import FinancialDatasetAPI

class Backtester:
def __init__(self, agent, ticker, start_date, end_date, initial_capital):
Expand Down Expand Up @@ -70,7 +70,7 @@ def run_backtest(self):
)

action, quantity = self.parse_action(agent_output)
df = get_price_data(self.ticker, lookback_start, current_date_str)
df = FinancialDatasetAPI().get_prices(self.ticker, lookback_start, current_date_str)
current_price = df.iloc[-1]['close']

# Execute the trade with validation
Expand Down
152 changes: 0 additions & 152 deletions src/tools/api.py

This file was deleted.

Empty file added src/tools/api/__init__.py
Empty file.
60 changes: 60 additions & 0 deletions src/tools/api/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import requests
from requests.exceptions import RequestException

from .config import BaseAPIConfig


class BaseAPIClient(ABC):
"""Abstract base class for API clients."""

def __init__(self, config: BaseAPIConfig):
self.config = config
self.headers = self._get_headers()

@abstractmethod
def _get_headers(self) -> Dict[str, str]:
"""Return headers required for API requests."""
pass

@abstractmethod
def _handle_response(self, response: requests.Response) -> Dict[str, Any]:
"""Handle API response and return parsed data."""
pass

def _get_base_url(self) -> str:
"""Implementation of abstract method for base URL."""
return self.config.base_url

def _make_request(
self,
endpoint: str,
method: str = "GET",
params: Optional[Dict[str, Any]] = None,
json_data: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Make HTTP request to the API with error handling."""
url = f"{self._get_base_url()}/{endpoint.lstrip('/')}"

try:
response = requests.request(
method=method,
url=url,
headers=self.headers,
params=params,
json=json_data,
)
return self._handle_response(response)
except RequestException as e:
raise Exception(f"API request failed: {str(e)}") from e

def _get_data_or_raise(
self, response_data: Dict[str, Any], key: str, error_message: str
) -> Any:
"""Extract data from response or raise if not found."""
data = response_data.get(key)
if not data:
raise Exception(error_message)
return data
Loading