Skip to content

Commit

Permalink
feat; new sources
Browse files Browse the repository at this point in the history
  • Loading branch information
andoludo committed Dec 1, 2024
1 parent e34c9a9 commit bef016f
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Add columns
Revision ID: ae5d63df3c8b
Revision ID: 8947a52d321a
Revises:
Create Date: 2024-11-28 20:41:40.572630
Create Date: 2024-12-01 16:23:44.342936
"""

Expand All @@ -13,7 +13,7 @@
import sqlmodel

# revision identifiers, used by Alembic.
revision: str = "ae5d63df3c8b"
revision: str = "8947a52d321a"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
Expand Down Expand Up @@ -257,7 +257,6 @@ def upgrade() -> None:
sa.Column("isin", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("country", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("category_group", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("category", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("family", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("three_year_earnings_growth", sa.Float(), nullable=True),
sa.Column("annual_holdings_turnover", sa.Float(), nullable=True),
Expand All @@ -271,6 +270,7 @@ def upgrade() -> None:
sa.Column("sector_weightings", sa.JSON(), nullable=True),
sa.Column("total_net_assets", sa.Float(), nullable=True),
sa.Column("beta_3_year", sa.Float(), nullable=True),
sa.Column("category", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("fund_family", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column("fund_inception_date", sa.Float(), nullable=True),
sa.Column("legal_type", sqlmodel.sql.sqltypes.AutoString(), nullable=True),
Expand Down Expand Up @@ -349,11 +349,23 @@ def upgrade() -> None:
batch_op.create_index(batch_op.f("ix_price_source"), ["source"], unique=False)
batch_op.create_index(batch_op.f("ix_price_symbol"), ["symbol"], unique=False)

op.create_table(
"sources",
sa.Column("source", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("source"),
)
with op.batch_alter_table("sources", schema=None) as batch_op:
batch_op.create_index(batch_op.f("ix_sources_source"), ["source"], unique=False)

# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("sources", schema=None) as batch_op:
batch_op.drop_index(batch_op.f("ix_sources_source"))

op.drop_table("sources")
with op.batch_alter_table("price", schema=None) as batch_op:
batch_op.drop_index(batch_op.f("ix_price_symbol"))
batch_op.drop_index(batch_op.f("ix_price_source"))
Expand Down
33 changes: 26 additions & 7 deletions bearish/database/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
CashFlowORM,
BalanceSheetORM,
PriceORM,
SourcesORM,
)
from bearish.database.scripts.upgrade import upgrade
from bearish.interface.interface import BearishDbBase
from bearish.models.financials.balance_sheet import BalanceSheet

from bearish.models.financials.base import Financials
Expand All @@ -32,7 +34,7 @@
from bearish.models.query.query import AssetQuery


class BearishDb(BaseModel):
class BearishDb(BearishDbBase):
model_config = ConfigDict(arbitrary_types_allowed=True)
database_path: Path

Expand All @@ -46,7 +48,7 @@ def _engine(self) -> Engine:
def model_post_init(self, __context: Any) -> None: # noqa: ANN401
self._engine # noqa: B018

def write_assets(self, assets: Assets) -> None:
def _write_assets(self, assets: Assets) -> None:
with Session(self._engine) as session:
objects_orm = (
[EquityORM(**object.model_dump()) for object in assets.equities]
Expand All @@ -58,7 +60,7 @@ def write_assets(self, assets: Assets) -> None:
session.add_all(objects_orm)
session.commit()

def write_series(self, series: List["Price"]) -> None:
def _write_series(self, series: List["Price"]) -> None:
with Session(self._engine) as session:
stmt = (
insert(PriceORM)
Expand All @@ -69,7 +71,7 @@ def write_series(self, series: List["Price"]) -> None:
session.exec(stmt) # type: ignore
session.commit()

def write_financials(self, financials: Financials) -> None:
def _write_financials(self, financials: Financials) -> None:
self._write_financials_series(financials.financial_metrics, FinancialMetricsORM)
self._write_financials_series(financials.cash_flows, CashFlowORM)
self._write_financials_series(financials.balance_sheets, BalanceSheetORM)
Expand All @@ -90,7 +92,7 @@ def _write_financials_series(
session.exec(stmt) # type: ignore
session.commit()

def read_series(self, query: "AssetQuery", months: int = 1) -> List[Price]:
def _read_series(self, query: "AssetQuery", months: int = 1) -> List[Price]:
end_date = datetime.now()
start_date = end_date - relativedelta(month=months)
with Session(self._engine) as session:
Expand All @@ -101,7 +103,7 @@ def read_series(self, query: "AssetQuery", months: int = 1) -> List[Price]:
series = session.exec(query_).all()
return [Price.model_validate(serie) for serie in series]

def read_financials(self, query: "AssetQuery") -> Financials:
def _read_financials(self, query: "AssetQuery") -> Financials:
with Session(self._engine) as session:
financial_metrics = self._read_asset_type(
session, FinancialMetrics, FinancialMetricsORM, query
Expand All @@ -116,7 +118,7 @@ def read_financials(self, query: "AssetQuery") -> Financials:
balance_sheets=balance_sheets,
)

def read_assets(self, query: "AssetQuery") -> Assets:
def _read_assets(self, query: "AssetQuery") -> Assets:
with Session(self._engine) as session:
from bearish.models.assets.equity import Equity
from bearish.models.assets.crypto import Crypto
Expand Down Expand Up @@ -157,3 +159,20 @@ def _read_asset_type(

assets = session.exec(query_).all()
return [table.model_validate(asset) for asset in assets]

def _read_sources(self) -> List[str]:
with Session(self._engine) as session:
query_ = select(SourcesORM).distinct()
sources = session.exec(query_).all()
return [source.source for source in sources]

def _write_source(self, source: str) -> None:
with Session(self._engine) as session:
stmt = (
insert(SourcesORM)
.prefix_with("OR REPLACE")
.values([{"source": source}])
)

session.exec(stmt) # type: ignore
session.commit()
5 changes: 5 additions & 0 deletions bearish/database/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,8 @@ class BalanceSheetORM(BaseFinancials, BalanceSheet, table=True): # type: ignore

class CashFlowORM(BaseFinancials, CashFlow, table=True): # type: ignore
__tablename__ = "cashflow"


class SourcesORM(SQLModel, table=True):
__tablename__ = "sources"
source: str = Field(primary_key=True, index=True)
Empty file added bearish/interface/__init__.py
Empty file.
71 changes: 71 additions & 0 deletions bearish/interface/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import abc
from pathlib import Path
from typing import List

from pydantic import BaseModel, ConfigDict, validate_call

from bearish.models.assets.assets import Assets
from bearish.models.financials.base import Financials
from bearish.models.price.price import Price
from bearish.models.query.query import AssetQuery


class BearishDbBase(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
database_path: Path

@validate_call
def write_assets(self, assets: Assets) -> None:
return self._write_assets(assets)

@validate_call
def write_series(self, series: List[Price]) -> None:
return self._write_series(series)

@validate_call
def write_financials(self, financials: Financials) -> None:
return self._write_financials(financials)

@validate_call
def read_series(self, query: AssetQuery, months: int = 1) -> List[Price]:
return self._read_series(query, months)

@validate_call
def read_financials(self, query: AssetQuery) -> Financials:
return self._read_financials(query)

@validate_call
def read_assets(self, query: AssetQuery) -> Assets:
return self._read_assets(query)

@validate_call
def read_sources(self) -> List[str]:
return self._read_sources()

@validate_call
def write_source(self, source: str) -> None:
return self._write_source(source)

@abc.abstractmethod
def _write_assets(self, assets: Assets) -> None: ...

@abc.abstractmethod
def _write_series(self, series: List[Price]) -> None: ...

@abc.abstractmethod
def _write_financials(self, financials: Financials) -> None: ...

@abc.abstractmethod
def _read_series(self, query: AssetQuery, months: int = 1) -> List[Price]: ...

@abc.abstractmethod
def _read_financials(self, query: AssetQuery) -> Financials: ...

@abc.abstractmethod
def _read_assets(self, query: AssetQuery) -> Assets: ...

@abc.abstractmethod
def _write_source(self, source: str) -> None: ...

@abc.abstractmethod
def _read_sources(self) -> List[str]: ...
28 changes: 24 additions & 4 deletions bearish/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@

from bearish.database.crud import BearishDb
from bearish.exceptions import InvalidApiKeyError
from bearish.interface.interface import BearishDbBase
from bearish.models.api_keys.api_keys import SourceApiKeys
from bearish.models.assets.assets import Assets
from bearish.models.financials.base import Financials
from bearish.models.price.price import Price
from bearish.models.query.query import AssetQuery
from bearish.sources.base import AbstractSource
from bearish.sources.financedatabase import FinanceDatabaseSource
from bearish.sources.financial_modelling_prep import FmpAssetsSource, FmpSource
from bearish.sources.investpy import InvestPySource
from bearish.sources.tiingo import TiingoSource
from bearish.sources.yfinance import yFinanceSource

logger = logging.getLogger(__name__)
Expand All @@ -31,18 +34,25 @@ class Bearish(BaseModel):
model_config = ConfigDict(extra="forbid")
path: Path
api_keys: SourceApiKeys = Field(default_factory=SourceApiKeys)
_bearish_db: BearishDb = PrivateAttr()
sources: List[AbstractSource] = Field(
_bearish_db: BearishDbBase = PrivateAttr()
asset_sources: List[AbstractSource] = Field(
default_factory=lambda: [
FinanceDatabaseSource(),
InvestPySource(),
FmpAssetsSource(),
]
)
sources: List[AbstractSource] = Field(
default_factory=lambda: [
yFinanceSource(),
FmpSource(),
TiingoSource(),
]
)

def model_post_init(self, __context: Any) -> None: # noqa: ANN401
self._bearish_db = BearishDb(database_path=self.path)
for source in self.sources:
for source in self.sources + self.asset_sources:
try:
source.set_api_key(
self.api_keys.keys.get(
Expand All @@ -56,7 +66,13 @@ def model_post_init(self, __context: Any) -> None: # noqa: ANN401
self.sources.remove(source)

def write_assets(self, query: Optional[AssetQuery] = None) -> None:
for source in self.sources:
existing_sources = self._bearish_db.read_sources()
asset_sources = [
asset_source
for asset_source in self.asset_sources
if asset_source.__source__ not in existing_sources
]
for source in asset_sources + self.sources:
if query:
cached_assets = self.read_assets(AssetQuery(countries=query.countries))
query.update_symbols(cached_assets)
Expand All @@ -66,6 +82,7 @@ def write_assets(self, query: Optional[AssetQuery] = None) -> None:
logger.warning(f"No assets found from {type(source).__name__}")
continue
self._bearish_db.write_assets(assets_)
self._bearish_db.write_source(source.__source__)

def read_assets(self, assets_query: AssetQuery) -> Assets:
return self._bearish_db.read_assets(assets_query)
Expand Down Expand Up @@ -113,6 +130,9 @@ def write_series(self, ticker: str, type: str) -> None:
if series:
self._bearish_db.write_series(series)

def read_sources(self) -> List[str]:
return self._bearish_db.read_sources()


@app.command()
def assets(path: Path, countries: List[str]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/sources/test_alphavantage.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_alphavantage_read_series():
assert series


@pytest.mark.order(2)
@pytest.mark.skip("issue with api key")
def test_api_key():
alpha = AlphaVantageSource()
alpha.set_api_key("test")
Expand Down
Loading

0 comments on commit bef016f

Please sign in to comment.