Commit

Add get_certified_relevant_news_since util function, and corresponding cache (#529)
evangriffiths authored Oct 24, 2024
1 parent f34e049 commit e9bce61
Showing 5 changed files with 394 additions and 1 deletion.
prediction_market_agent_tooling/tools/relevant_news_analysis/data_models.py
@@ -0,0 +1,44 @@
from pydantic import BaseModel, Field

from prediction_market_agent_tooling.tools.tavily.tavily_models import TavilyResult


class RelevantNewsAnalysis(BaseModel):
    reasoning: str = Field(
        ...,
        description="The reason why the news contains information relevant to the given question. Or if no news is relevant, why not.",
    )
    contains_relevant_news: bool = Field(
        ...,
        description="A boolean flag for whether the news contains information relevant to the given question.",
    )


class RelevantNews(BaseModel):
    question: str
    url: str
    summary: str
    relevance_reasoning: str
    days_ago: int

    @staticmethod
    def from_tavily_result_and_analysis(
        question: str,
        days_ago: int,
        tavily_result: TavilyResult,
        relevant_news_analysis: RelevantNewsAnalysis,
    ) -> "RelevantNews":
        return RelevantNews(
            question=question,
            url=tavily_result.url,
            summary=tavily_result.content,
            relevance_reasoning=relevant_news_analysis.reasoning,
            days_ago=days_ago,
        )
class NoRelevantNews(BaseModel):
    """
    A placeholder model for when no relevant news is found. Makes it possible
    to distinguish between 'a cache hit with no news' and 'a cache miss'.
    """
prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_analysis.py
@@ -0,0 +1,162 @@
from datetime import datetime, timedelta

from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.tools.langfuse_ import (
    get_langfuse_langchain_config,
    observe,
)
from prediction_market_agent_tooling.tools.relevant_news_analysis.data_models import (
    NoRelevantNews,
    RelevantNews,
    RelevantNewsAnalysis,
)
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_cache import (
    RelevantNewsResponseCache,
)
from prediction_market_agent_tooling.tools.tavily.tavily_search import (
    get_relevant_news_since,
)
from prediction_market_agent_tooling.tools.tavily.tavily_storage import TavilyStorage
from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow
SUMMARISE_RELEVANT_NEWS_PROMPT_TEMPLATE = """
You are an expert news analyst, tracking stories that may affect your prediction of the outcome of a particular QUESTION.
Your role is to identify only the relevant information from a scraped news site (RAW_CONTENT), analyse it, and determine whether it contains developments or announcements occurring **after** the DATE_OF_INTEREST that could affect the outcome of the QUESTION.
Note that the news article may be published after the DATE_OF_INTEREST, but reference information that is older than the DATE_OF_INTEREST.

[QUESTION]
{question}

[DATE_OF_INTEREST]
{date_of_interest}

[RAW_CONTENT]
{raw_content}

For your analysis, you should:
- Discard the 'noise' from the raw content (e.g. ads, irrelevant content)
- Consider ONLY information that would have a notable impact on the outcome of the question.
- Consider ONLY information relating to an announcement or development that occurred **after** the DATE_OF_INTEREST.
- Present this information concisely in your reasoning.
- In your reasoning, do not use the term 'DATE_OF_INTEREST' directly. Use the actual date you are referring to instead.
- In your reasoning, do not use the term 'RAW_CONTENT' directly. Refer to it as 'the article', or quote the content you are referring to.

{format_instructions}
"""


@observe()
def analyse_news_relevance(
    raw_content: str,
    question: str,
    date_of_interest: datetime,
    model: str,
    temperature: float,
) -> RelevantNewsAnalysis:
    """
    Analyse whether the news contains new (relative to the given date)
    information relevant to the given question.
    """
    parser = PydanticOutputParser(pydantic_object=RelevantNewsAnalysis)
    prompt = PromptTemplate(
        template=SUMMARISE_RELEVANT_NEWS_PROMPT_TEMPLATE,
        input_variables=["question", "date_of_interest", "raw_content"],
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )
    llm = ChatOpenAI(
        temperature=temperature,
        model=model,
        api_key=APIKeys().openai_api_key_secretstr_v1,
    )
    chain = prompt | llm | parser

    relevant_news_analysis: RelevantNewsAnalysis = chain.invoke(
        {
            "raw_content": raw_content,
            "question": question,
            "date_of_interest": str(date_of_interest),
        },
        config=get_langfuse_langchain_config(),
    )
    return relevant_news_analysis
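
A usage sketch for the function above, assuming OpenAI credentials are configured via APIKeys; the import path and the article text are assumptions for illustration:

```python
from datetime import timedelta

from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_analysis import (
    analyse_news_relevance,
)
from prediction_market_agent_tooling.tools.utils import utcnow

analysis = analyse_news_relevance(
    raw_content="<full text of a scraped news article>",  # hypothetical input
    question="Will candidate X win the election?",
    date_of_interest=utcnow() - timedelta(days=7),
    model="gpt-4o",
    temperature=0.0,
)
print(analysis.contains_relevant_news, analysis.reasoning)
```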


@observe()
def get_certified_relevant_news_since(
    question: str,
    days_ago: int,
    tavily_storage: TavilyStorage | None = None,
) -> RelevantNews | None:
    """
    Get relevant news since a given date for a given question. Retrieves
    possibly relevant news from tavily, then checks that it is relevant via
    an LLM call.
    """
    results = get_relevant_news_since(
        question=question,
        days_ago=days_ago,
        score_threshold=0.0,  # Be conservative to avoid missing relevant information
        max_results=3,  # A tradeoff between cost and quality. 3 seems to be a good balance.
        tavily_storage=tavily_storage,
    )

    # Sort results by descending 'relevance score' to maximise the chance of
    # finding relevant news early
    results = sorted(
        results,
        key=lambda result: result.score,
        reverse=True,
    )

    for result in results:
        relevant_news_analysis = analyse_news_relevance(
            raw_content=check_not_none(result.raw_content),
            question=question,
            date_of_interest=utcnow() - timedelta(days=days_ago),
            model="gpt-4o",  # 4o-mini isn't good enough, o1 and o1-mini are too expensive
            temperature=0.0,
        )

        # Return first relevant news found
        if relevant_news_analysis.contains_relevant_news:
            return RelevantNews.from_tavily_result_and_analysis(
                question=question,
                days_ago=days_ago,
                tavily_result=result,
                relevant_news_analysis=relevant_news_analysis,
            )

    # No relevant news found
    return None
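
Calling this entry point directly costs one Tavily search plus up to three gpt-4o relevance checks (max_results=3 above). A sketch, assuming Tavily and OpenAI credentials are configured via APIKeys and the inferred import path:

```python
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_analysis import (
    get_certified_relevant_news_since,
)

news = get_certified_relevant_news_since(
    question="Will candidate X win the election?",
    days_ago=7,
)
if news is None:
    print("No certified relevant news found in the last 7 days")
else:
    print(news.url, news.relevance_reasoning)
```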


def get_certified_relevant_news_since_cached(
    question: str,
    days_ago: int,
    cache: RelevantNewsResponseCache,
    tavily_storage: TavilyStorage | None = None,
) -> RelevantNews | None:
    cached = cache.find(question=question, days_ago=days_ago)

    if isinstance(cached, NoRelevantNews):
        return None
    elif cached is None:
        relevant_news = get_certified_relevant_news_since(
            question=question,
            days_ago=days_ago,
            tavily_storage=tavily_storage,
        )
        cache.save(
            question=question,
            days_ago=days_ago,
            relevant_news=relevant_news,
        )
        return relevant_news
    else:
        return cached
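
The cached variant has three outcomes: a cached NoRelevantNews short-circuits to None with no LLM calls, a cache miss runs the full pipeline and stores the result (including a 'no news' result), and a cached RelevantNews is returned directly. A usage sketch, with the same caveats about credentials and inferred import paths as above:

```python
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_analysis import (
    get_certified_relevant_news_since_cached,
)
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_cache import (
    RelevantNewsResponseCache,
)

cache = RelevantNewsResponseCache()  # connects to APIKeys().sqlalchemy_db_url
news = get_certified_relevant_news_since_cached(
    question="Will candidate X win the election?",
    days_ago=7,
    cache=cache,
)
```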
prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py
@@ -0,0 +1,90 @@
from datetime import datetime, timedelta

from pydantic import ValidationError
from sqlmodel import Field, Session, SQLModel, create_engine, desc, select

from prediction_market_agent_tooling.config import APIKeys
from prediction_market_agent_tooling.loggers import logger
from prediction_market_agent_tooling.tools.relevant_news_analysis.data_models import (
    NoRelevantNews,
    RelevantNews,
)
from prediction_market_agent_tooling.tools.utils import utcnow


class RelevantNewsCacheModel(SQLModel, table=True):
    __tablename__ = "relevant_news_response_cache"
    __table_args__ = {"extend_existing": True}
    id: int | None = Field(default=None, primary_key=True)
    question: str = Field(index=True)
    datetime_: datetime = Field(index=True)
    days_ago: int
    json_dump: str | None


class RelevantNewsResponseCache:
    def __init__(self, sqlalchemy_db_url: str | None = None):
        self.engine = create_engine(
            sqlalchemy_db_url
            if sqlalchemy_db_url
            else APIKeys().sqlalchemy_db_url.get_secret_value()
        )
        self._initialize_db()

    def _initialize_db(self) -> None:
        """
        Creates the tables if they don't exist
        """
        with self.engine.connect() as conn:
            SQLModel.metadata.create_all(
                conn,
                tables=[SQLModel.metadata.tables[RelevantNewsCacheModel.__tablename__]],
            )

    def find(
        self,
        question: str,
        days_ago: int,
    ) -> RelevantNews | NoRelevantNews | None:
        with Session(self.engine) as session:
            query = (
                select(RelevantNewsCacheModel)
                .where(RelevantNewsCacheModel.question == question)
                .where(RelevantNewsCacheModel.days_ago <= days_ago)
                .where(
                    RelevantNewsCacheModel.datetime_ >= utcnow() - timedelta(days=1)
                )  # Cache entries expire after 1 day
            )
            item = session.exec(
                query.order_by(desc(RelevantNewsCacheModel.datetime_))
            ).first()

            if item is None:
                return None
            else:
                if item.json_dump is None:
                    return NoRelevantNews()
                else:
                    try:
                        return RelevantNews.model_validate_json(item.json_dump)
                    except ValidationError as e:
                        logger.error(
                            f"Error deserializing RelevantNews from cache for {question=}, {days_ago=} and {item=}: {e}"
                        )
                        return None

    def save(
        self,
        question: str,
        days_ago: int,
        relevant_news: RelevantNews | None,
    ) -> None:
        with Session(self.engine) as session:
            cached = RelevantNewsCacheModel(
                question=question,
                days_ago=days_ago,
                datetime_=utcnow(),  # Assumes that the cache is being updated at the time the news is found
                json_dump=relevant_news.model_dump_json() if relevant_news else None,
            )
            session.add(cached)
            session.commit()
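
A round-trip sketch of the cache on its own. The SQLite URL is a hypothetical stand-in for the configured database; it exercises the 'cache hit with no news' path, where save() stores json_dump=None and a find() within one day returns NoRelevantNews rather than None:

```python
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_cache import (
    RelevantNewsResponseCache,
)

# Hypothetical local database for experimentation.
cache = RelevantNewsResponseCache(sqlalchemy_db_url="sqlite:///relevant_news_cache.db")

cache.save(question="Will it rain tomorrow?", days_ago=3, relevant_news=None)
hit = cache.find(question="Will it rain tomorrow?", days_ago=3)
print(type(hit).__name__)  # "NoRelevantNews": a hit with no news, not a miss
```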
prediction_market_agent_tooling/tools/tavily/tavily_search.py
@@ -129,7 +129,7 @@ def _tavily_search(
     return response


-def get_related_news_since(
+def get_relevant_news_since(
     question: str,
     days_ago: int,
     score_threshold: float = DEFAULT_SCORE_THRESHOLD,
(The fifth changed file did not load in this view.)
