class RelevantNewsAnalysis(BaseModel):
    """
    Structured verdict on whether a piece of news is relevant to a question.

    The `description` strings are runtime data (they end up in the model's
    schema), so they must not be reworded casually.
    """

    reasoning: str = Field(
        ...,
        description="The reason why the news contains information relevant to the given question. Or if no news is relevant, why not.",
    )
    contains_relevant_news: bool = Field(
        ...,
        description="A boolean flag for whether the news contains information relevant to the given question.",
    )


class RelevantNews(BaseModel):
    """
    A news item that was judged relevant to `question`, together with the
    reasoning for that judgement and the look-back window it was found in.
    """

    question: str
    url: str
    summary: str
    relevance_reasoning: str
    days_ago: int

    @classmethod
    def from_tavily_result_and_analysis(
        cls,
        question: str,
        days_ago: int,
        tavily_result: TavilyResult,
        relevant_news_analysis: RelevantNewsAnalysis,
    ) -> "RelevantNews":
        """
        Build an instance from a Tavily search result and its relevance analysis.

        The URL and summary are taken from `tavily_result` (`url` and `content`),
        the reasoning from `relevant_news_analysis.reasoning`.

        Declared as a classmethod (rather than a staticmethod hard-coding
        `RelevantNews`) so that subclasses construct their own type. Existing
        call sites of the form `RelevantNews.from_tavily_result_and_analysis(...)`
        are unaffected.
        """
        return cls(
            question=question,
            url=tavily_result.url,
            summary=tavily_result.content,
            relevance_reasoning=relevant_news_analysis.reasoning,
            days_ago=days_ago,
        )


class NoRelevantNews(BaseModel):
    """
    A placeholder model for when no relevant news is found. Enables ability to
    distinguish between 'a cache hit with no news' and 'a cache miss'.
    """
SUMMARISE_RELEVANT_NEWS_PROMPT_TEMPLATE = """
You are an expert news analyst, tracking stories that may affect your prediction to the outcome of a particular QUESTION.

Your role is to identify only the relevant information from a scraped news site (RAW_CONTENT), analyse it, and determine whether it contains developments or announcements occurring **after** the DATE_OF_INTEREST that could affect the outcome of the QUESTION.

Note that the news article may be published after the DATE_OF_INTEREST, but reference information that is older than the DATE_OF_INTEREST.

[QUESTION]
{question}

[DATE_OF_INTEREST]
{date_of_interest}

[RAW_CONTENT]
{raw_content}

For your analysis, you should:
- Discard the 'noise' from the raw content (e.g. ads, irrelevant content)
- Consider ONLY information that would have a notable impact on the outcome of the question.
- Consider ONLY information relating to an announcement or development that occurred **after** the DATE_OF_INTEREST.
- Present this information concisely in your reasoning.
- In your reasoning, do not use the term 'DATE_OF_INTEREST' directly. Use the actual date you are referring to instead.
- In your reasoning, do not use the term 'RAW_CONTENT' directly. Refer to it as 'the article', or quote the content you are referring to.

{format_instructions}
"""


@observe()
def analyse_news_relevance(
    raw_content: str,
    question: str,
    date_of_interest: datetime,
    model: str,
    temperature: float,
) -> RelevantNewsAnalysis:
    """
    Analyse whether the news contains new (relative to the given date)
    information relevant to the given question.

    The prompt is filled with `question`, `str(date_of_interest)` and
    `raw_content`, sent to the OpenAI chat model `model` at the given
    `temperature`, and the reply is parsed into a `RelevantNewsAnalysis`.
    """
    parser = PydanticOutputParser(pydantic_object=RelevantNewsAnalysis)
    prompt = PromptTemplate(
        template=SUMMARISE_RELEVANT_NEWS_PROMPT_TEMPLATE,
        input_variables=["question", "date_of_interest", "raw_content"],
        # The output-format instructions are derived from the pydantic model,
        # so they are bound once here rather than supplied per call.
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )
    llm = ChatOpenAI(
        temperature=temperature,
        model=model,
        api_key=APIKeys().openai_api_key_secretstr_v1,
    )
    chain = prompt | llm | parser

    relevant_news_analysis: RelevantNewsAnalysis = chain.invoke(
        {
            "raw_content": raw_content,
            "question": question,
            "date_of_interest": str(date_of_interest),
        },
        config=get_langfuse_langchain_config(),
    )
    return relevant_news_analysis


@observe()
def get_certified_relevant_news_since(
    question: str,
    days_ago: int,
    tavily_storage: TavilyStorage | None = None,
) -> RelevantNews | None:
    """
    Get relevant news since a given date for a given question. Retrieves
    possibly relevant news from tavily, then checks that it is relevant via
    an LLM call.

    Returns the first search result (in descending score order) that the LLM
    flags as containing relevant news, or None if none of them is flagged.
    """
    results = get_relevant_news_since(
        question=question,
        days_ago=days_ago,
        score_threshold=0.0,  # Be conservative to avoid missing relevant information
        max_results=3,  # A tradeoff between cost and quality. 3 seems to be a good balance.
        tavily_storage=tavily_storage,
    )

    # Sort results by descending 'relevance score' to maximise the chance of
    # finding relevant news early
    results = sorted(
        results,
        key=lambda result: result.score,
        reverse=True,
    )

    # Compute the cutoff once so that every article is judged against the
    # same date, instead of a value that drifts with each (slow) LLM call.
    date_of_interest = utcnow() - timedelta(days=days_ago)

    for result in results:
        relevant_news_analysis = analyse_news_relevance(
            raw_content=check_not_none(result.raw_content),
            question=question,
            date_of_interest=date_of_interest,
            model="gpt-4o",  # 4o-mini isn't good enough, o1 and o1-mini are too expensive
            temperature=0.0,
        )

        # Return first relevant news found
        if relevant_news_analysis.contains_relevant_news:
            return RelevantNews.from_tavily_result_and_analysis(
                question=question,
                days_ago=days_ago,
                tavily_result=result,
                relevant_news_analysis=relevant_news_analysis,
            )

    # No relevant news found
    return None


def get_certified_relevant_news_since_cached(
    question: str,
    days_ago: int,
    cache: RelevantNewsResponseCache,
    tavily_storage: TavilyStorage | None = None,
) -> RelevantNews | None:
    """
    Cached variant of `get_certified_relevant_news_since`.

    `cache.find` has three possible outcomes:
    - `NoRelevantNews`: a cached "nothing found" answer, returned as None.
    - `RelevantNews`: a cached hit, returned as-is.
    - `None`: a cache miss; the news is looked up, the outcome (including
      "nothing found") is saved to the cache, and then returned.
    """
    cached = cache.find(question=question, days_ago=days_ago)

    if isinstance(cached, NoRelevantNews):
        return None
    if cached is not None:
        return cached

    relevant_news = get_certified_relevant_news_since(
        question=question,
        days_ago=days_ago,
        tavily_storage=tavily_storage,
    )
    cache.save(
        question=question,
        days_ago=days_ago,
        relevant_news=relevant_news,
    )
    return relevant_news
class RelevantNewsCacheModel(SQLModel, table=True):
    """
    One cached answer to "is there relevant news for `question` in the last
    `days_ago` days?", recorded at `datetime_`.
    """

    __tablename__ = "relevant_news_response_cache"
    __table_args__ = {"extend_existing": True}
    id: int | None = Field(default=None, primary_key=True)
    question: str = Field(index=True)
    datetime_: datetime = Field(index=True)
    days_ago: int
    # JSON dump of a RelevantNews, or None when no relevant news was found.
    json_dump: str | None


class RelevantNewsResponseCache:
    """
    Database-backed cache of relevant-news lookups, keyed by question and
    look-back window. Entries older than one day are ignored by `find`.
    """

    def __init__(self, sqlalchemy_db_url: str | None = None):
        """
        Connect to `sqlalchemy_db_url`, falling back to the URL configured in
        `APIKeys` when it is not given, and make sure the cache table exists.
        """
        self.engine = create_engine(
            sqlalchemy_db_url
            if sqlalchemy_db_url
            else APIKeys().sqlalchemy_db_url.get_secret_value()
        )
        self._initialize_db()

    def _initialize_db(self) -> None:
        """
        Creates the tables if they don't exist
        """
        # Bind to the engine rather than to a manually opened connection:
        # create_all then runs the DDL in its own committed transaction. With
        # a bare `engine.connect()` and no explicit commit, SQLAlchemy 2.x
        # rolls the DDL back on close for databases with transactional DDL.
        SQLModel.metadata.create_all(
            self.engine,
            tables=[SQLModel.metadata.tables[RelevantNewsCacheModel.__tablename__]],
        )

    def find(
        self,
        question: str,
        days_ago: int,
    ) -> RelevantNews | NoRelevantNews | None:
        """
        Look up the most recent usable cache entry for `question`.

        An entry is usable for a query looking back `days_ago` days when:
        - it holds news and was found in a window of at most `days_ago` days
          (news from the last N days is also news from any longer window), or
        - it records "no news" for a window of at least `days_ago` days
          (no news in the last N days says nothing about a longer window).

        Returns the cached `RelevantNews`, `NoRelevantNews()` for a cached
        "nothing found", or None on a cache miss.
        """
        with Session(self.engine) as session:
            query = (
                select(RelevantNewsCacheModel)
                .where(RelevantNewsCacheModel.question == question)
                .where(
                    RelevantNewsCacheModel.datetime_ >= utcnow() - timedelta(days=1)
                )  # Cache entries expire after 1 day
                .order_by(desc(RelevantNewsCacheModel.datetime_))
            )
            # The window check depends on the kind of entry, so it is applied
            # here rather than in SQL; newest entries are considered first.
            for item in session.exec(query).all():
                if item.json_dump is None:
                    if item.days_ago >= days_ago:
                        return NoRelevantNews()
                    continue

                if item.days_ago > days_ago:
                    continue

                try:
                    return RelevantNews.model_validate_json(item.json_dump)
                except ValidationError as e:
                    # Treat a corrupt entry as absent and keep looking at
                    # older entries; if none is usable this is a cache miss.
                    logger.error(
                        f"Error deserializing RelevantNews from cache for {question=}, {days_ago=} and {item=}: {e}"
                    )

        return None

    def save(
        self,
        question: str,
        days_ago: int,
        relevant_news: RelevantNews | None,
    ) -> None:
        """
        Store the outcome of a lookup. `relevant_news=None` is stored as a
        "no relevant news" entry (null `json_dump`).
        """
        with Session(self.engine) as session:
            cached = RelevantNewsCacheModel(
                question=question,
                days_ago=days_ago,
                datetime_=utcnow(),  # Assumes that the cache is being updated at the time the news is found
                json_dump=(
                    relevant_news.model_dump_json()
                    if relevant_news is not None
                    else None
                ),
            )
            session.add(cached)
            session.commit()
@pytest.mark.skipif(not RUN_PAID_TESTS, reason="This test costs money to run.")
def test_get_certified_relevant_news_since() -> None:
    """
    Check, against the live search and LLM services, that relevant news is
    found (or not) as expected, and that the average cost per question stays
    below a fixed budget.
    """
    # Each entry is (question, expected_result, days_ago).
    questions_expected_results_days_ago = [
        (
            "Will the price of Bitcoin be higher than $100,000 by the end of the year?",
            True,
            5,
        ),
        (
            "Will the strength of the Earth's gravitational field change by more than 3% any time before the end of the calendar year?",
            False,
            2,
        ),
        (
            "Will the number of Chinese-made electric cars sold worldwide this year be higher than in the previous calendar year?",
            True,
            90,
        ),
        (
            "Will total UK cinema box office sales this month be higher than in the previous calendar month?",
            True,
            14,
        ),
    ]

    running_cost = 0.0
    iterations = 0
    for question, expected_result, days_ago in questions_expected_results_days_ago:
        with get_openai_callback() as cb:
            news = get_certified_relevant_news_since(
                question=question,
                days_ago=days_ago,
            )
            running_cost += cb.total_cost
            iterations += 1

        has_related_news = news is not None
        assert (
            has_related_news == expected_result
        ), f"Was relevant news found for question '{question}'?: {has_related_news}. Expected result {expected_result}"

    # $0.01289 when run on 2022-10-24
    # NOTE(review): the year in the date above looks like a typo - confirm.
    average_cost = running_cost / iterations
    assert average_cost < 0.02, f"Actual average cost: {average_cost}. Expected < 0.02"


def test_get_certified_relevant_news_since_cached() -> None:
    """
    Check that a cache miss triggers one lookup whose result is stored, and
    that a repeated call is then answered from the cache.
    """
    cache = RelevantNewsResponseCache(sqlalchemy_db_url="sqlite:///:memory:")

    question = (
        "Will the price of Bitcoin be higher than $100,000 by the end of the year?"
    )
    days_ago = 5
    assert (
        cache.find(question=question, days_ago=days_ago) is None
    ), "Cache should be empty"

    mock_news = RelevantNews(
        question=question,
        url="https://www.example.com",
        summary="This is a summary",
        relevance_reasoning="some reasoning",
        days_ago=days_ago,
    )
    # Named differently from the imported function so the import is not shadowed.
    with patch(
        "prediction_market_agent_tooling.tools.relevant_news_analysis.relevant_news_analysis.get_certified_relevant_news_since"
    ) as mock_get_news:
        # Mock the response
        mock_get_news.return_value = mock_news

        news = get_certified_relevant_news_since_cached(
            question=question,
            days_ago=days_ago,
            cache=cache,
        )
        assert news == mock_news
        assert (
            cache.find(question=question, days_ago=days_ago) == mock_news
        ), "Cache should contain the news"

        # The second call must be a cache hit: same answer, no new lookup.
        news_again = get_certified_relevant_news_since_cached(
            question=question,
            days_ago=days_ago,
            cache=cache,
        )
        assert news_again == mock_news
        mock_get_news.assert_called_once()