-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add get_certified_relevant_news_since util function, and corresponding cache (#529)
- Loading branch information
1 parent
f34e049
commit e9bce61
Showing
5 changed files
with
394 additions
and
1 deletion.
There are no files selected for viewing
44 changes: 44 additions & 0 deletions
44
prediction_market_agent_tooling/tools/relevant_news_analysis/data_models.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from pydantic import BaseModel, Field | ||
|
||
from prediction_market_agent_tooling.tools.tavily.tavily_models import TavilyResult | ||
|
||
|
||
# Structured verdict on whether a piece of news is relevant to a question.
#
# NOTE(review): the `description` strings below are runtime text, not
# comments - pydantic exposes them in the model's JSON schema, and presumably
# that schema is what an LLM is shown when this model is used as a
# structured-output target. Reword them only deliberately. For the same reason
# this class intentionally has no docstring (pydantic would presumably surface
# it as the schema description) - confirm before adding one.
class RelevantNewsAnalysis(BaseModel):
    # Free-text justification; per its description it is filled in for both
    # the "relevant" and the "not relevant" verdict.
    reasoning: str = Field(
        ...,
        description="The reason why the news contains information relevant to the given question. Or if no news is relevant, why not.",
    )
    # The verdict itself: True when the news is relevant to the question.
    contains_relevant_news: bool = Field(
        ...,
        description="A boolean flag for whether the news contains information relevant to the given question.",
    )
|
||
|
||
class RelevantNews(BaseModel):
    """
    A news article together with the reasoning for why it is relevant to a
    question.
    """

    question: str  # The question the article was checked against
    url: str  # Where the article was found
    summary: str  # Article content, as carried by the search result
    relevance_reasoning: str  # Explanation produced by the relevance analysis
    days_ago: int  # The `days_ago` value the news was retrieved with

    @staticmethod
    def from_tavily_result_and_analysis(
        question: str,
        days_ago: int,
        tavily_result: TavilyResult,
        relevant_news_analysis: RelevantNewsAnalysis,
    ) -> "RelevantNews":
        """
        Build a RelevantNews from a tavily search result and the analysis
        that was made of it. The url and summary are taken from the search
        result, the reasoning from the analysis.
        """
        article_url = tavily_result.url
        article_summary = tavily_result.content
        reasoning = relevant_news_analysis.reasoning
        return RelevantNews(
            question=question,
            days_ago=days_ago,
            url=article_url,
            summary=article_summary,
            relevance_reasoning=reasoning,
        )
|
||
|
||
class NoRelevantNews(BaseModel):
    """
    Marker model standing for 'no relevant news was found'.

    Having a dedicated type makes it possible to tell 'a cache hit with no
    news' apart from 'a cache miss'.
    """
162 changes: 162 additions & 0 deletions
162
prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_analysis.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
from datetime import datetime, timedelta | ||
|
||
from langchain_core.output_parsers import PydanticOutputParser | ||
from langchain_core.prompts import PromptTemplate | ||
from langchain_openai import ChatOpenAI | ||
|
||
from prediction_market_agent_tooling.config import APIKeys | ||
from prediction_market_agent_tooling.tools.langfuse_ import ( | ||
get_langfuse_langchain_config, | ||
observe, | ||
) | ||
from prediction_market_agent_tooling.tools.relevant_news_analysis.data_models import ( | ||
NoRelevantNews, | ||
RelevantNews, | ||
RelevantNewsAnalysis, | ||
) | ||
from prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_cache import ( | ||
RelevantNewsResponseCache, | ||
) | ||
from prediction_market_agent_tooling.tools.tavily.tavily_search import ( | ||
get_relevant_news_since, | ||
) | ||
from prediction_market_agent_tooling.tools.tavily.tavily_storage import TavilyStorage | ||
from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow | ||
|
||
# Prompt template for the news relevance check. It has four placeholders:
# {question}, {date_of_interest}, {raw_content} and {format_instructions}.
# The text is sent to the model as-is, so any edit here changes runtime
# behaviour.
SUMMARISE_RELEVANT_NEWS_PROMPT_TEMPLATE = """
You are an expert news analyst, tracking stories that may affect your prediction to the outcome of a particular QUESTION.
Your role is to identify only the relevant information from a scraped news site (RAW_CONTENT), analyse it, and determine whether it contains developments or announcements occurring **after** the DATE_OF_INTEREST that could affect the outcome of the QUESTION.
Note that the news article may be published after the DATE_OF_INTEREST, but reference information that is older than the DATE_OF_INTEREST.
[QUESTION]
{question}
[DATE_OF_INTEREST]
{date_of_interest}
[RAW_CONTENT]
{raw_content}
For your analysis, you should:
- Discard the 'noise' from the raw content (e.g. ads, irrelevant content)
- Consider ONLY information that would have a notable impact on the outcome of the question.
- Consider ONLY information relating to an announcement or development that occurred **after** the DATE_OF_INTEREST.
- Present this information concisely in your reasoning.
- In your reasoning, do not use the term 'DATE_OF_INTEREST' directly. Use the actual date you are referring to instead.
- In your reasoning, do not use the term 'RAW_CONTENT' directly. Refer to it as 'the article', or quote the content you are referring to.
{format_instructions}
"""
|
||
|
||
@observe()
def analyse_news_relevance(
    raw_content: str,
    question: str,
    date_of_interest: datetime,
    model: str,
    temperature: float,
) -> RelevantNewsAnalysis:
    """
    Ask an LLM whether an article contains information that is both new
    (relative to `date_of_interest`) and relevant to `question`.

    Args:
        raw_content: The article text to analyse.
        question: The question the article is checked against.
        date_of_interest: Cutoff date; it is passed to the prompt as `str(...)`.
        model: Name of the OpenAI chat model to use.
        temperature: Sampling temperature for the model.

    Returns:
        The model's answer parsed into a RelevantNewsAnalysis.
    """
    output_parser = PydanticOutputParser(pydantic_object=RelevantNewsAnalysis)
    format_instructions = output_parser.get_format_instructions()
    prompt_template = PromptTemplate(
        template=SUMMARISE_RELEVANT_NEWS_PROMPT_TEMPLATE,
        input_variables=["question", "date_of_interest", "raw_content"],
        partial_variables={"format_instructions": format_instructions},
    )
    chat_model = ChatOpenAI(
        temperature=temperature,
        model=model,
        api_key=APIKeys().openai_api_key_secretstr_v1,
    )

    # Values for the template placeholders that are not pre-filled above.
    prompt_inputs = {
        "raw_content": raw_content,
        "question": question,
        "date_of_interest": str(date_of_interest),
    }
    analysis: RelevantNewsAnalysis = (
        prompt_template | chat_model | output_parser
    ).invoke(
        prompt_inputs,
        config=get_langfuse_langchain_config(),
    )
    return analysis
|
||
|
||
@observe()
def get_certified_relevant_news_since(
    question: str,
    days_ago: int,
    tavily_storage: TavilyStorage | None = None,
) -> RelevantNews | None:
    """
    Get relevant news since a given date for a given question. Retrieves
    possibly relevant news from tavily, then checks that it is relevant via
    an LLM call.

    Args:
        question: The question the news should be relevant to.
        days_ago: How many days back from now the news search should reach.
        tavily_storage: Optional storage object forwarded to the tavily search.

    Returns:
        The first search result, in descending score order, that the LLM
        judges relevant; None if no candidate is judged relevant.
    """
    results = get_relevant_news_since(
        question=question,
        days_ago=days_ago,
        score_threshold=0.0,  # Be conservative to avoid missing relevant information
        max_results=3,  # A tradeoff between cost and quality. 3 seems to be a good balance.
        tavily_storage=tavily_storage,
    )

    # Sort results by descending 'relevance score' to maximise the chance of
    # finding relevant news early
    results = sorted(
        results,
        key=lambda result: result.score,
        reverse=True,
    )

    # Compute the cutoff once, so that every candidate is judged against the
    # same date instead of a timestamp that drifts with each LLM call.
    date_of_interest = utcnow() - timedelta(days=days_ago)

    for result in results:
        relevant_news_analysis = analyse_news_relevance(
            # NOTE(review): check_not_none presumably raises when the search
            # result carries no raw content - confirm that aborting the whole
            # lookup (rather than skipping the result) is intended.
            raw_content=check_not_none(result.raw_content),
            question=question,
            date_of_interest=date_of_interest,
            model="gpt-4o",  # 4o-mini isn't good enough, o1 and o1-mini are too expensive
            temperature=0.0,
        )

        # Return first relevant news found
        if relevant_news_analysis.contains_relevant_news:
            return RelevantNews.from_tavily_result_and_analysis(
                question=question,
                days_ago=days_ago,
                tavily_result=result,
                relevant_news_analysis=relevant_news_analysis,
            )

    # No relevant news found
    return None
|
||
|
||
def get_certified_relevant_news_since_cached(
    question: str,
    days_ago: int,
    cache: RelevantNewsResponseCache,
    tavily_storage: TavilyStorage | None = None,
) -> RelevantNews | None:
    """
    Cached variant of `get_certified_relevant_news_since`.

    Looks the question up in `cache` first. A cached 'no relevant news'
    entry yields None and a cached news entry is returned directly, both
    without a new search. On a cache miss the search is run and its outcome
    (news or None) is saved to the cache before being returned.
    """
    cache_entry = cache.find(question=question, days_ago=days_ago)

    # Cache hit: a previous lookup found nothing relevant.
    if isinstance(cache_entry, NoRelevantNews):
        return None

    # Cache hit: a previous lookup found relevant news.
    if cache_entry is not None:
        return cache_entry

    # Cache miss: do the real lookup and remember its outcome.
    fresh_news = get_certified_relevant_news_since(
        question=question,
        days_ago=days_ago,
        tavily_storage=tavily_storage,
    )
    cache.save(
        question=question,
        days_ago=days_ago,
        relevant_news=fresh_news,
    )
    return fresh_news
90 changes: 90 additions & 0 deletions
90
prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from datetime import datetime, timedelta | ||
|
||
from pydantic import ValidationError | ||
from sqlmodel import Field, Session, SQLModel, create_engine, desc, select | ||
|
||
from prediction_market_agent_tooling.config import APIKeys | ||
from prediction_market_agent_tooling.loggers import logger | ||
from prediction_market_agent_tooling.tools.relevant_news_analysis.data_models import ( | ||
NoRelevantNews, | ||
RelevantNews, | ||
) | ||
from prediction_market_agent_tooling.tools.utils import utcnow | ||
|
||
|
||
class RelevantNewsCacheModel(SQLModel, table=True):
    """
    Table row holding the cached outcome of one relevant-news lookup.
    """

    __tablename__ = "relevant_news_response_cache"
    __table_args__ = {"extend_existing": True}
    # Primary key; None until a value is assigned.
    id: int | None = Field(default=None, primary_key=True)
    # The question the lookup was made for (indexed).
    question: str = Field(index=True)
    # Timestamp stored with the row (indexed).
    datetime_: datetime = Field(index=True)
    # The `days_ago` value the lookup was made with.
    days_ago: int
    # Serialised result of the lookup, or None.
    # NOTE(review): presumably None stands for "no relevant news found" -
    # confirm against the code that reads and writes this table.
    json_dump: str | None
|
||
|
||
class RelevantNewsResponseCache:
    """
    Database-backed cache for the outcome of relevant-news lookups.

    Each `save` adds a row holding the question, the `days_ago` value, the
    current time and the serialised news (or None when there was no news).
    `find` only considers rows written within the last day.
    """

    def __init__(self, sqlalchemy_db_url: str | None = None):
        """
        Create the engine and make sure the cache table exists.

        Args:
            sqlalchemy_db_url: Database URL to connect to. When omitted (or
                empty), the URL configured in APIKeys is used instead.
        """
        self.engine = create_engine(
            sqlalchemy_db_url
            if sqlalchemy_db_url
            else APIKeys().sqlalchemy_db_url.get_secret_value()
        )
        self._initialize_db()

    def _initialize_db(self) -> None:
        """
        Creates the tables if they don't exist
        """
        # `engine.begin()` commits when the block exits successfully. With a
        # plain `engine.connect()` the CREATE TABLE stays in an uncommitted
        # transaction that is rolled back when the connection is closed, so
        # on databases with transactional DDL the table is never persisted.
        with self.engine.begin() as conn:
            SQLModel.metadata.create_all(
                conn,
                tables=[SQLModel.metadata.tables[RelevantNewsCacheModel.__tablename__]],
            )

    def find(
        self,
        question: str,
        days_ago: int,
    ) -> RelevantNews | NoRelevantNews | None:
        """
        Look up the most recent non-expired cache entry for a question.

        Args:
            question: The question to look up (exact match).
            days_ago: Only entries saved with a `days_ago` less than or
                equal to this value are considered.

        Returns:
            - RelevantNews when the matching entry holds news,
            - NoRelevantNews when the matching entry recorded that there was
              no news,
            - None on a cache miss, or when the stored entry cannot be
              deserialised (the error is logged).
        """
        with Session(self.engine) as session:
            # NOTE(review): the `days_ago` filter also matches cached
            # 'no news' entries that were saved for a *shorter* window than
            # the one requested, although such an entry says nothing about
            # the extra days - confirm this is intended.
            query = (
                select(RelevantNewsCacheModel)
                .where(RelevantNewsCacheModel.question == question)
                .where(RelevantNewsCacheModel.days_ago <= days_ago)
                .where(
                    RelevantNewsCacheModel.datetime_ >= utcnow() - timedelta(days=1)
                )  # Cache entries expire after 1 day
            )
            # Newest entry wins when there are several matches.
            item = session.exec(
                query.order_by(desc(RelevantNewsCacheModel.datetime_))
            ).first()

            if item is None:
                return None

            if item.json_dump is None:
                return NoRelevantNews()

            try:
                return RelevantNews.model_validate_json(item.json_dump)
            except ValidationError as e:
                # Treat an unreadable entry like a cache miss.
                logger.error(
                    f"Error deserializing RelevantNews from cache for {question=}, {days_ago=} and {item=}: {e}"
                )
                return None

    def save(
        self,
        question: str,
        days_ago: int,
        relevant_news: RelevantNews | None,
    ) -> None:
        """
        Store the outcome of a lookup as a new cache entry.

        Args:
            question: The question the lookup was made for.
            days_ago: The `days_ago` value the lookup was made with.
            relevant_news: The news that was found, or None when the lookup
                found nothing; None is stored as an empty `json_dump`.
        """
        with Session(self.engine) as session:
            cached = RelevantNewsCacheModel(
                question=question,
                days_ago=days_ago,
                datetime_=utcnow(),  # Assumes that the cache is being updated at the time the news is found
                json_dump=(
                    relevant_news.model_dump_json()
                    if relevant_news is not None
                    else None
                ),
            )
            session.add(cached)
            session.commit()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.