Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add models for output and improve typing #63

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ readme = "README.md"
packages = [
{ include = "src", from = "." }
]

[tool.poetry.dependencies]
python = "^3.9"
langchain = "0.3.0"
Expand All @@ -19,13 +20,18 @@ matplotlib = "^3.9.2"
tabulate = "^0.9.0"
colorama = "^0.4.6"
questionary = "^2.1.0"
pydantic = "^2.10.5"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
black = "^23.7.0"
isort = "^5.12.0"
flake8 = "^6.1.0"

[tool.black]
line-length = 100
target-version = ['py310']

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
1 change: 1 addition & 0 deletions src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.1.0"
27 changes: 17 additions & 10 deletions src/backtester.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from datetime import datetime, timedelta
from typing import Protocol

from dateutil.relativedelta import relativedelta
import questionary

import matplotlib.pyplot as plt
import pandas as pd
from tabulate import tabulate
from colorama import Fore, Back, Style, init
from colorama import Fore, Style, init

from utils.analysts import ANALYST_ORDER
from main import run_hedge_fund
from models.outputs import Analysts, RootResultModel
from tools.api import (
get_price_data,
get_prices,
Expand All @@ -21,8 +23,13 @@
init(autoreset=True)


class AgentFn(Protocol):
def __call__(self, ticker: str, start_date: str, end_date: str, portfolio: dict, show_reasoning: bool = False, selected_analysts: list[Analysts] | None = None) -> RootResultModel:
...


class Backtester:
def __init__(self, agent, ticker, start_date, end_date, initial_capital, selected_analysts=None):
def __init__(self, agent: AgentFn, ticker: str, start_date: str, end_date: str, initial_capital: float, selected_analysts: list[Analysts] | None = None):
self.agent = agent
self.ticker = ticker
self.start_date = start_date
Expand Down Expand Up @@ -124,8 +131,8 @@ def run_backtest(self):
selected_analysts=self.selected_analysts,
)

agent_decision = output["decision"]
action, quantity = agent_decision["action"], agent_decision["quantity"]
agent_decision = output.decision
action, quantity = agent_decision.action, agent_decision.quantity
df = get_price_data(self.ticker, lookback_start, current_date_str)
current_price = df.iloc[-1]["close"]

Expand All @@ -137,13 +144,13 @@ def run_backtest(self):
self.portfolio["portfolio_value"] = total_value

# Count signals from selected analysts only
analyst_signals = output["analyst_signals"]
analyst_signals = output.analyst_signals.signals

# Count signals
bullish_count = len([s for s in analyst_signals.values() if s.get("signal", "").lower() == "bullish"])
bearish_count = len([s for s in analyst_signals.values() if s.get("signal", "").lower() == "bearish"])
neutral_count = len([s for s in analyst_signals.values() if s.get("signal", "").lower() == "neutral"])

bullish_count = len([s for s in analyst_signals if s.signal == "bullish"])
bearish_count = len([s for s in analyst_signals if s.signal == "bearish"])
neutral_count = len([s for s in analyst_signals if s.signal == "neutral"])
print(f"Signal counts - Bullish: {bullish_count}, Bearish: {bearish_count}, Neutral: {neutral_count}")

# Format and add row
Expand Down
26 changes: 13 additions & 13 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import argparse
from datetime import datetime

import questionary
from colorama import Fore, Style, init
from dateutil.relativedelta import relativedelta
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from langgraph.graph import END, StateGraph
from colorama import Fore, Back, Style, init
import questionary

from agents.fundamentals import fundamentals_agent
from agents.portfolio_manager import portfolio_management_agent
from agents.technicals import technical_analyst_agent
from agents.risk_manager import risk_management_agent
from agents.sentiment import sentiment_agent
from graph.state import AgentState
from agents.technicals import technical_analyst_agent
from agents.valuation import valuation_agent
from graph.state import AgentState
from models.outputs import RootResultModel, Analysts
from utils.display import print_trading_output
from utils.analysts import ANALYST_ORDER

import argparse
from datetime import datetime
from dateutil.relativedelta import relativedelta
from tabulate import tabulate

# Load environment variables from .env file
load_dotenv()

Expand All @@ -39,8 +39,8 @@ def run_hedge_fund(
end_date: str,
portfolio: dict,
show_reasoning: bool = False,
selected_analysts: list = None,
):
selected_analysts: list[Analysts] = None,
) -> RootResultModel:
# Create a new workflow if analysts are customized
if selected_analysts is not None:
workflow = create_workflow(selected_analysts)
Expand All @@ -67,10 +67,10 @@ def run_hedge_fund(
},
},
)
return {
return RootResultModel.model_validate({
"decision": parse_hedge_fund_response(final_state["messages"][-1].content),
"analyst_signals": final_state["data"]["analyst_signals"],
}
})


def start(state: AgentState):
Expand Down
Empty file added src/models/__init__.py
Empty file.
167 changes: 167 additions & 0 deletions src/models/outputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from typing import Literal

from pydantic import BaseModel

Signal_Type = Literal["neutral", "bullish", "bearish"]
Analysts = Literal[
"fundamentals_agent",
"sentiment_agent",
"technical_analyst_agent",
"valuation_agent",
]


class Signal(BaseModel):
confidence: float
signal: Signal_Type


class DetailsSignal(BaseModel):
details: str
signal: Signal_Type


class ReasoningAgent(BaseModel):
financial_health_signal: DetailsSignal | None = None
growth_signal: DetailsSignal | None = None
price_ratios_signal: DetailsSignal | None = None
profitability_signal: DetailsSignal | None = None


class MeanReversionMetrics(BaseModel):
price_vs_bb: float
rsi_14: float
rsi_28: float
z_score: float


class MeanReversionTechnicalSignal(Signal):
metrics: MeanReversionMetrics


class MomentumMetrics(BaseModel):
momentum_1m: float
momentum_3m: float
momentum_6m: float
volume_momentum: float


class MomentumTechnicalSignal(Signal):
metrics: MomentumMetrics


class StatisticalArbitrageMetrics(BaseModel):
hurst_exponent: float
kurtosis: float
skewness: float


class StatisticalArbitrageTechnicalSignal(Signal):
metrics: StatisticalArbitrageMetrics


class TrendFollowingMetrics(BaseModel):
adx: float
trend_strength: float


class TrendFollowingTechnicalSignal(Signal):
metrics: TrendFollowingMetrics


class VolatilityMetrics(BaseModel):
atr_ratio: float
historical_volatility: float
volatility_regime: float
volatility_z_score: float


class VolatilityTechnicalSignal(Signal):
metrics: VolatilityMetrics


class ReasoningTechnical(BaseModel):
mean_reversion: MeanReversionTechnicalSignal | None = None
momentum: MomentumTechnicalSignal | None = None
statistical_arbitrage: StatisticalArbitrageTechnicalSignal | None = None
trend_following: TrendFollowingTechnicalSignal | None = None
volatility: VolatilityTechnicalSignal | None = None


class ValuationDetails(BaseModel):
details: str
signal: Signal_Type


class ReasoningValuation(BaseModel):
dcf_analysis: ValuationDetails | None = None
owner_earnings_analysis: ValuationDetails | None = None


class FundamentalsAgent(Signal):
reasoning: ReasoningAgent # Specialized reasoning for fundamentals

def __str__(self):
return "Fundamentals"


class RiskManagementAgent(BaseModel):
max_position_size: float
reasoning: str

def __str__(self):
return "Risk Management"


class SentimentAgent(Signal):
pass

def __str__(self):
return "Sentiment"


class TechnicalAnalystAgent(Signal):
reasoning: ReasoningTechnical

def __str__(self):
return "Technical Analyst"


class ValuationAgent(Signal):
reasoning: ReasoningValuation

def __str__(self):
return "Valuation"


class AnalystSignals(BaseModel):
fundamentals_agent: FundamentalsAgent | None = None
risk_management_agent: RiskManagementAgent
sentiment_agent: SentimentAgent | None = None
technical_analyst_agent: TechnicalAnalystAgent | None = None
valuation_agent: ValuationAgent | None = None

@property
def signals(self) -> list[Signal]:
return [
agent
for agent in [
self.fundamentals_agent,
self.sentiment_agent,
self.technical_analyst_agent,
self.valuation_agent,
]
if agent
]


class Decision(BaseModel):
action: str
confidence: float | None
quantity: int
reasoning: str


class RootResultModel(BaseModel):
analyst_signals: AnalystSignals
decision: Decision | None
Loading