Skip to content

Commit

Permalink
Add cache
Browse files Browse the repository at this point in the history
  • Loading branch information
evangriffiths committed Oct 23, 2024
1 parent 011529c commit bddd65e
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pydantic import BaseModel, Field

from prediction_market_agent_tooling.tools.tavily.tavily_models import TavilyResult


class RelevantNewsAnalysis(BaseModel):
reasoning: str = Field(
...,
description="The reason why the news contains information relevant to the given question. Or if no news is relevant, why not.",
)
contains_relevant_news: bool = Field(
...,
description="A boolean flag for whether the news contains information relevant to the given question.",
)


class RelevantNews(BaseModel):
question: str
url: str
summary: str
relevance_reasoning: str
days_ago: int

@staticmethod
def from_tavily_result_and_analysis(
question: str,
days_ago: int,
taviy_result: TavilyResult,
relevant_news_analysis: RelevantNewsAnalysis,
) -> "RelevantNews":
return RelevantNews(
question=question,
url=taviy_result.url,
summary=taviy_result.content,
relevance_reasoning=relevant_news_analysis.reasoning,
days_ago=days_ago,
)


class NoRelevantNews(BaseModel):
"""
A placeholder model for when no relevant news is found. Enables ability to
distinguish between 'a cache hit with no news' and 'a cache miss'.
"""

# TODO not nice, but not sure how else to distinguish between no news and no
pass
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,26 @@
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.langfuse_ import (
get_langfuse_langchain_config,
observe,
)
from prediction_market_agent_tooling.tools.tavily.tavily_models import TavilyResult
from prediction_market_agent_tooling.tools.relevant_news_analysis.data_models import (
NoRelevantNews,
RelevantNews,
RelevantNewsAnalysis,
)
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_cache import (
RelevantNewsResponseCache,
)
from prediction_market_agent_tooling.tools.tavily.tavily_search import (
get_relevant_news_since,
)
from prediction_market_agent_tooling.tools.tavily.tavily_storage import TavilyStorage
from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow


class RelevantNewsAnalysis(BaseModel):
reasoning: str = Field(
...,
description="The reason why the news contains information relevant to the given question. Or if no news is relevant, why not.",
)
contains_relevant_news: bool = Field(
...,
description="A boolean flag for whether the news contains information relevant to the given question.",
)


class RelevantNews(BaseModel):
url: str
summary: str
relevance_reasoning: str

@staticmethod
def from_tavily_result_and_analysis(
taviy_result: TavilyResult,
relevant_news_analysis: RelevantNewsAnalysis,
) -> "RelevantNews":
return RelevantNews(
url=taviy_result.url,
summary=taviy_result.content,
relevance_reasoning=relevant_news_analysis.reasoning,
)


SUMMARISE_RELEVANT_NEWS_PROMPT_TEMPLATE = """
You are an expert news analyst, tracking stories that may affect your prediction to the outcome of a particular QUESTION.
Expand Down Expand Up @@ -114,24 +91,18 @@ def analyse_news_relevance(
def get_certified_relevant_news_since(
question: str,
days_ago: int,
model: str = "gpt-4o",
temperature: float = 0.0,
max_search_results: int = 3,
tavily_storage: TavilyStorage | None = None,
) -> RelevantNews | None:
"""
Get relevant news since a given date for a given question. Retrieves
possibly relevant news from tavily, then checks that it is relevant via
an LLM call.
TODO save/restore from a cache
TODO generate subquestions and get relevant news for each
"""
results = get_relevant_news_since(
question=question,
days_ago=days_ago,
score_threshold=0.0, # Be conservative to avoid missing relevant information
max_results=max_search_results,
max_results=3,
tavily_storage=tavily_storage,
)

Expand All @@ -148,16 +119,44 @@ def get_certified_relevant_news_since(
raw_content=check_not_none(result.raw_content),
question=question,
date_of_interest=utcnow() - timedelta(days=days_ago),
model=model,
temperature=temperature,
model="gpt-4o",
temperature=0.0,
)

# Return first relevant news found
if relevant_news_analysis.contains_relevant_news:
return RelevantNews.from_tavily_result_and_analysis(
question=question,
days_ago=days_ago,
taviy_result=result,
relevant_news_analysis=relevant_news_analysis,
)

# No relevant news found
return None


def get_certified_relevant_news_since_cached(
question: str,
days_ago: int,
cache: RelevantNewsResponseCache,
tavily_storage: TavilyStorage | None = None,
) -> RelevantNews | None:
cached = cache.find(question=question, days_ago=days_ago)

if isinstance(cached, NoRelevantNews):
return None
elif cached is None:
relevant_news = get_certified_relevant_news_since(
question=question,
days_ago=days_ago,
tavily_storage=tavily_storage,
)
cache.save(
question=question,
days_ago=days_ago,
relevant_news=relevant_news,
)
return relevant_news
else:
return cached
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from datetime import datetime, timedelta

from sqlmodel import Field, Session, SQLModel, create_engine, desc, select

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.tools.relevant_news_analysis.data_models import (
NoRelevantNews,
RelevantNews,
)
from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow


class RelevantNewsCacheModel(SQLModel, table=True):
__tablename__ = "relevant_news_response_cache"
__table_args__ = {"extend_existing": True}
id: int | None = Field(default=None, primary_key=True)
question: str = Field(index=True)
datetime_: datetime = Field(index=True)
days_ago: int
json_dump: str | None


class RelevantNewsResponseCache:
def __init__(self, sqlalchemy_db_url: str | None = None):
self.engine = create_engine(
sqlalchemy_db_url
if sqlalchemy_db_url
else APIKeys().sqlalchemy_db_url.get_secret_value()
)
self._initialize_db()

def _initialize_db(self) -> None:
"""
Creates the tables if they don't exist
"""

# trick for making models import mandatory - models must be imported for metadata.create_all to work
logger.debug(f"tables being added {RelevantNewsCacheModel}")
SQLModel.metadata.create_all(self.engine)

def find(
self,
question: str,
days_ago: int,
) -> RelevantNews | NoRelevantNews | None:
with Session(self.engine) as session:
query = (
select(RelevantNewsCacheModel)
.where(RelevantNewsCacheModel.question == question)
.where(RelevantNewsCacheModel.days_ago <= days_ago)
.where(RelevantNewsCacheModel.datetime_ >= utcnow() - timedelta(days=1))
)
item = session.exec(
query.order_by(desc(RelevantNewsCacheModel.datetime_))
).first()

if item == None:
return None
else:
item = check_not_none(item)
if item.json_dump is None:
return NoRelevantNews()
else:
try:
return RelevantNews.model_validate_json(item.json_dump)
except ValueError as e:
logger.error(
f"Error deserializing RelevantNews from cache for {question=}, {days_ago=} and {item=}: {e}"
)
return None

def save(
self,
question: str,
days_ago: int,
relevant_news: RelevantNews | None,
) -> None:
with Session(self.engine) as session:
cached = RelevantNewsCacheModel(
question=question,
days_ago=days_ago,
datetime_=utcnow(), # Assumes that the cache is being updated at the time the news is found
json_dump=relevant_news.model_dump_json() if relevant_news else None,
)
session.add(cached)
session.commit()
9 changes: 7 additions & 2 deletions tests_integration/tools/test_relevant_news_analysis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import pytest
from langchain_community.callbacks import get_openai_callback

from prediction_market_agent_tooling.tools.relevant_news_analysis import (
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_analysis import (
get_certified_relevant_news_since,
)
from tests.utils import RUN_PAID_TESTS


@pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.")
def test_get_certified_relevant_news_since() -> None:
questions_days_ago_expected_results = [
(
Expand Down Expand Up @@ -36,7 +39,6 @@ def test_get_certified_relevant_news_since() -> None:
news = get_certified_relevant_news_since(
question=question,
days_ago=days_ago,
model="gpt-4o",
)
running_cost += cb.total_cost
iterations += 1
Expand All @@ -48,3 +50,6 @@ def test_get_certified_relevant_news_since() -> None:

average_cost = running_cost / iterations
assert average_cost < 0.03, f"Expected average: {average_cost}. Expected < 0.03"


# TODO test cache and get_certified_relevant_news_since_cached

0 comments on commit bddd65e

Please sign in to comment.