-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
011529c
commit bddd65e
Showing
4 changed files
with
180 additions
and
42 deletions.
There are no files selected for viewing
47 changes: 47 additions & 0 deletions
47
prediction_market_agent_tooling/tools/relevant_news_analysis/data_models.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from pydantic import BaseModel, Field | ||
|
||
from prediction_market_agent_tooling.tools.tavily.tavily_models import TavilyResult | ||
|
||
|
||
class RelevantNewsAnalysis(BaseModel): | ||
reasoning: str = Field( | ||
..., | ||
description="The reason why the news contains information relevant to the given question. Or if no news is relevant, why not.", | ||
) | ||
contains_relevant_news: bool = Field( | ||
..., | ||
description="A boolean flag for whether the news contains information relevant to the given question.", | ||
) | ||
|
||
|
||
class RelevantNews(BaseModel): | ||
question: str | ||
url: str | ||
summary: str | ||
relevance_reasoning: str | ||
days_ago: int | ||
|
||
@staticmethod | ||
def from_tavily_result_and_analysis( | ||
question: str, | ||
days_ago: int, | ||
taviy_result: TavilyResult, | ||
relevant_news_analysis: RelevantNewsAnalysis, | ||
) -> "RelevantNews": | ||
return RelevantNews( | ||
question=question, | ||
url=taviy_result.url, | ||
summary=taviy_result.content, | ||
relevance_reasoning=relevant_news_analysis.reasoning, | ||
days_ago=days_ago, | ||
) | ||
|
||
|
||
class NoRelevantNews(BaseModel): | ||
""" | ||
A placeholder model for when no relevant news is found. Enables ability to | ||
distinguish between 'a cache hit with no news' and 'a cache miss'. | ||
""" | ||
|
||
# TODO not nice, but not sure how else to distinguish between no news and no | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
prediction_market_agent_tooling/tools/relevant_news_analysis/relevant_news_cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from datetime import datetime, timedelta | ||
|
||
from sqlmodel import Field, Session, SQLModel, create_engine, desc, select | ||
|
||
from prediction_market_agent_tooling.config import APIKeys | ||
from prediction_market_agent_tooling.loggers import logger | ||
from prediction_market_agent_tooling.tools.relevant_news_analysis.data_models import ( | ||
NoRelevantNews, | ||
RelevantNews, | ||
) | ||
from prediction_market_agent_tooling.tools.utils import check_not_none, utcnow | ||
|
||
|
||
class RelevantNewsCacheModel(SQLModel, table=True): | ||
__tablename__ = "relevant_news_response_cache" | ||
__table_args__ = {"extend_existing": True} | ||
id: int | None = Field(default=None, primary_key=True) | ||
question: str = Field(index=True) | ||
datetime_: datetime = Field(index=True) | ||
days_ago: int | ||
json_dump: str | None | ||
|
||
|
||
class RelevantNewsResponseCache: | ||
def __init__(self, sqlalchemy_db_url: str | None = None): | ||
self.engine = create_engine( | ||
sqlalchemy_db_url | ||
if sqlalchemy_db_url | ||
else APIKeys().sqlalchemy_db_url.get_secret_value() | ||
) | ||
self._initialize_db() | ||
|
||
def _initialize_db(self) -> None: | ||
""" | ||
Creates the tables if they don't exist | ||
""" | ||
|
||
# trick for making models import mandatory - models must be imported for metadata.create_all to work | ||
logger.debug(f"tables being added {RelevantNewsCacheModel}") | ||
SQLModel.metadata.create_all(self.engine) | ||
|
||
def find( | ||
self, | ||
question: str, | ||
days_ago: int, | ||
) -> RelevantNews | NoRelevantNews | None: | ||
with Session(self.engine) as session: | ||
query = ( | ||
select(RelevantNewsCacheModel) | ||
.where(RelevantNewsCacheModel.question == question) | ||
.where(RelevantNewsCacheModel.days_ago <= days_ago) | ||
.where(RelevantNewsCacheModel.datetime_ >= utcnow() - timedelta(days=1)) | ||
) | ||
item = session.exec( | ||
query.order_by(desc(RelevantNewsCacheModel.datetime_)) | ||
).first() | ||
|
||
if item == None: | ||
return None | ||
else: | ||
item = check_not_none(item) | ||
if item.json_dump is None: | ||
return NoRelevantNews() | ||
else: | ||
try: | ||
return RelevantNews.model_validate_json(item.json_dump) | ||
except ValueError as e: | ||
logger.error( | ||
f"Error deserializing RelevantNews from cache for {question=}, {days_ago=} and {item=}: {e}" | ||
) | ||
return None | ||
|
||
def save( | ||
self, | ||
question: str, | ||
days_ago: int, | ||
relevant_news: RelevantNews | None, | ||
) -> None: | ||
with Session(self.engine) as session: | ||
cached = RelevantNewsCacheModel( | ||
question=question, | ||
days_ago=days_ago, | ||
datetime_=utcnow(), # Assumes that the cache is being updated at the time the news is found | ||
json_dump=relevant_news.model_dump_json() if relevant_news else None, | ||
) | ||
session.add(cached) | ||
session.commit() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters