diff --git a/bearish/database/alembic/versions/ae5d63df3c8b_add_columns.py b/bearish/database/alembic/versions/8947a52d321a_add_columns.py similarity index 97% rename from bearish/database/alembic/versions/ae5d63df3c8b_add_columns.py rename to bearish/database/alembic/versions/8947a52d321a_add_columns.py index af483a0..1f4b51e 100644 --- a/bearish/database/alembic/versions/ae5d63df3c8b_add_columns.py +++ b/bearish/database/alembic/versions/8947a52d321a_add_columns.py @@ -1,8 +1,8 @@ """Add columns -Revision ID: ae5d63df3c8b +Revision ID: 8947a52d321a Revises: -Create Date: 2024-11-28 20:41:40.572630 +Create Date: 2024-12-01 16:23:44.342936 """ @@ -13,7 +13,7 @@ import sqlmodel # revision identifiers, used by Alembic. -revision: str = "ae5d63df3c8b" +revision: str = "8947a52d321a" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -257,7 +257,6 @@ def upgrade() -> None: sa.Column("isin", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("country", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("category_group", sqlmodel.sql.sqltypes.AutoString(), nullable=True), - sa.Column("category", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("family", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("three_year_earnings_growth", sa.Float(), nullable=True), sa.Column("annual_holdings_turnover", sa.Float(), nullable=True), @@ -271,6 +270,7 @@ def upgrade() -> None: sa.Column("sector_weightings", sa.JSON(), nullable=True), sa.Column("total_net_assets", sa.Float(), nullable=True), sa.Column("beta_3_year", sa.Float(), nullable=True), + sa.Column("category", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("fund_family", sqlmodel.sql.sqltypes.AutoString(), nullable=True), sa.Column("fund_inception_date", sa.Float(), nullable=True), sa.Column("legal_type", sqlmodel.sql.sqltypes.AutoString(), nullable=True), @@ -349,11 +349,23 @@ def upgrade() -> None: batch_op.create_index(batch_op.f("ix_price_source"), ["source"], unique=False) batch_op.create_index(batch_op.f("ix_price_symbol"), ["symbol"], unique=False) + op.create_table( + "sources", + sa.Column("source", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint("source"), + ) + with op.batch_alter_table("sources", schema=None) as batch_op: + batch_op.create_index(batch_op.f("ix_sources_source"), ["source"], unique=False) + # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("sources", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_sources_source")) + + op.drop_table("sources") with op.batch_alter_table("price", schema=None) as batch_op: batch_op.drop_index(batch_op.f("ix_price_symbol")) batch_op.drop_index(batch_op.f("ix_price_source")) diff --git a/bearish/database/crud.py b/bearish/database/crud.py index 19bcfa6..028f17f 100644 --- a/bearish/database/crud.py +++ b/bearish/database/crud.py @@ -18,8 +18,10 @@ CashFlowORM, BalanceSheetORM, PriceORM, + SourcesORM, ) from bearish.database.scripts.upgrade import upgrade +from bearish.interface.interface import BearishDbBase from bearish.models.financials.balance_sheet import BalanceSheet from bearish.models.financials.base import Financials @@ -32,7 +34,7 @@ from bearish.models.query.query import AssetQuery -class BearishDb(BaseModel): +class BearishDb(BearishDbBase): model_config = ConfigDict(arbitrary_types_allowed=True) database_path: Path @@ -46,7 +48,7 @@ def _engine(self) -> Engine: def model_post_init(self, __context: Any) -> None: # noqa: ANN401 self._engine # noqa: B018 - def write_assets(self, assets: Assets) -> None: + def _write_assets(self, assets: Assets) -> None: with Session(self._engine) as session: objects_orm = ( [EquityORM(**object.model_dump()) for object in assets.equities] @@ -58,7 +60,7 @@ def write_assets(self, assets: Assets) -> None: session.add_all(objects_orm) session.commit() - def write_series(self, series: List["Price"]) -> None: + def _write_series(self, series: List["Price"]) -> None: with Session(self._engine) as session: stmt = ( insert(PriceORM) @@ -69,7 +71,7 @@ def write_series(self, series: List["Price"]) -> None: session.exec(stmt) # type: ignore session.commit() - def write_financials(self, financials: Financials) -> None: + def _write_financials(self, financials: Financials) -> None: self._write_financials_series(financials.financial_metrics, FinancialMetricsORM) self._write_financials_series(financials.cash_flows, CashFlowORM) self._write_financials_series(financials.balance_sheets, BalanceSheetORM) @@ -90,7 +92,7 @@ def _write_financials_series( session.exec(stmt) # type: ignore session.commit() - def read_series(self, query: "AssetQuery", months: int = 1) -> List[Price]: + def _read_series(self, query: "AssetQuery", months: int = 1) -> List[Price]: end_date = datetime.now() start_date = end_date - relativedelta(month=months) with Session(self._engine) as session: @@ -101,7 +103,7 @@ def read_series(self, query: "AssetQuery", months: int = 1) -> List[Price]: series = session.exec(query_).all() return [Price.model_validate(serie) for serie in series] - def read_financials(self, query: "AssetQuery") -> Financials: + def _read_financials(self, query: "AssetQuery") -> Financials: with Session(self._engine) as session: financial_metrics = self._read_asset_type( session, FinancialMetrics, FinancialMetricsORM, query @@ -116,7 +118,7 @@ def read_financials(self, query: "AssetQuery") -> Financials: balance_sheets=balance_sheets, ) - def read_assets(self, query: "AssetQuery") -> Assets: + def _read_assets(self, query: "AssetQuery") -> Assets: with Session(self._engine) as session: from bearish.models.assets.equity import Equity from bearish.models.assets.crypto import Crypto @@ -157,3 +159,20 @@ def _read_asset_type( assets = session.exec(query_).all() return [table.model_validate(asset) for asset in assets] + + def _read_sources(self) -> List[str]: + with Session(self._engine) as session: + query_ = select(SourcesORM).distinct() + sources = session.exec(query_).all() + return [source.source for source in sources] + + def _write_source(self, source: str) -> None: + with Session(self._engine) as session: + stmt = ( + insert(SourcesORM) + .prefix_with("OR REPLACE") + .values([{"source": source}]) + ) + + session.exec(stmt) # type: ignore + session.commit() diff --git a/bearish/database/schemas.py b/bearish/database/schemas.py index fdd7554..6050440 100644 --- a/bearish/database/schemas.py +++ b/bearish/database/schemas.py @@ -72,3 +72,8 @@ class BalanceSheetORM(BaseFinancials, BalanceSheet, table=True): # type: ignore class CashFlowORM(BaseFinancials, CashFlow, table=True): # type: ignore __tablename__ = "cashflow" + + +class SourcesORM(SQLModel, table=True): + __tablename__ = "sources" + source: str = Field(primary_key=True, index=True) diff --git a/bearish/interface/__init__.py b/bearish/interface/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bearish/interface/interface.py b/bearish/interface/interface.py new file mode 100644 index 0000000..7ae4f51 --- /dev/null +++ b/bearish/interface/interface.py @@ -0,0 +1,71 @@ +import abc +from pathlib import Path +from typing import List + +from pydantic import BaseModel, ConfigDict, validate_call + +from bearish.models.assets.assets import Assets +from bearish.models.financials.base import Financials +from bearish.models.price.price import Price +from bearish.models.query.query import AssetQuery + + +class BearishDbBase(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + database_path: Path + + @validate_call + def write_assets(self, assets: Assets) -> None: + return self._write_assets(assets) + + @validate_call + def write_series(self, series: List[Price]) -> None: + return self._write_series(series) + + @validate_call + def write_financials(self, financials: Financials) -> None: + return self._write_financials(financials) + + @validate_call + def read_series(self, query: AssetQuery, months: int = 1) -> List[Price]: + return self._read_series(query, months) + + @validate_call + def read_financials(self, query: AssetQuery) -> Financials: + return self._read_financials(query) + + @validate_call + def read_assets(self, query: AssetQuery) -> Assets: + return self._read_assets(query) + + @validate_call + def read_sources(self) -> List[str]: + return self._read_sources() + + @validate_call + def write_source(self, source: str) -> None: + return self._write_source(source) + + @abc.abstractmethod + def _write_assets(self, assets: Assets) -> None: ... + + @abc.abstractmethod + def _write_series(self, series: List[Price]) -> None: ... + + @abc.abstractmethod + def _write_financials(self, financials: Financials) -> None: ... + + @abc.abstractmethod + def _read_series(self, query: AssetQuery, months: int = 1) -> List[Price]: ... + + @abc.abstractmethod + def _read_financials(self, query: AssetQuery) -> Financials: ... + + @abc.abstractmethod + def _read_assets(self, query: AssetQuery) -> Assets: ... + + @abc.abstractmethod + def _write_source(self, source: str) -> None: ... + + @abc.abstractmethod + def _read_sources(self) -> List[str]: ... diff --git a/bearish/main.py b/bearish/main.py index ae23d5b..f45b071 100644 --- a/bearish/main.py +++ b/bearish/main.py @@ -13,6 +13,7 @@ from bearish.database.crud import BearishDb from bearish.exceptions import InvalidApiKeyError +from bearish.interface.interface import BearishDbBase from bearish.models.api_keys.api_keys import SourceApiKeys from bearish.models.assets.assets import Assets from bearish.models.financials.base import Financials @@ -20,7 +21,9 @@ from bearish.models.query.query import AssetQuery from bearish.sources.base import AbstractSource from bearish.sources.financedatabase import FinanceDatabaseSource +from bearish.sources.financial_modelling_prep import FmpAssetsSource, FmpSource from bearish.sources.investpy import InvestPySource +from bearish.sources.tiingo import TiingoSource from bearish.sources.yfinance import yFinanceSource logger = logging.getLogger(__name__) @@ -31,18 +34,25 @@ class Bearish(BaseModel): model_config = ConfigDict(extra="forbid") path: Path api_keys: SourceApiKeys = Field(default_factory=SourceApiKeys) - _bearish_db: BearishDb = PrivateAttr() - sources: List[AbstractSource] = Field( + _bearish_db: BearishDbBase = PrivateAttr() + asset_sources: List[AbstractSource] = Field( default_factory=lambda: [ FinanceDatabaseSource(), InvestPySource(), + FmpAssetsSource(), + ] + ) + sources: List[AbstractSource] = Field( + default_factory=lambda: [ yFinanceSource(), + FmpSource(), + TiingoSource(), ] ) def model_post_init(self, __context: Any) -> None: # noqa: ANN401 self._bearish_db = BearishDb(database_path=self.path) - for source in self.sources: + for source in self.sources + self.asset_sources: try: source.set_api_key( self.api_keys.keys.get( @@ -56,7 +66,13 @@ def model_post_init(self, __context: Any) -> None: # noqa: ANN401 self.sources.remove(source) def write_assets(self, query: Optional[AssetQuery] = None) -> None: - for source in self.sources: + existing_sources = self._bearish_db.read_sources() + asset_sources = [ + asset_source + for asset_source in self.asset_sources + if asset_source.__source__ not in existing_sources + ] + for source in asset_sources + self.sources: if query: cached_assets = self.read_assets(AssetQuery(countries=query.countries)) query.update_symbols(cached_assets) @@ -66,6 +82,7 @@ def write_assets(self, query: Optional[AssetQuery] = None) -> None: logger.warning(f"No assets found from {type(source).__name__}") continue self._bearish_db.write_assets(assets_) + self._bearish_db.write_source(source.__source__) def read_assets(self, assets_query: AssetQuery) -> Assets: return self._bearish_db.read_assets(assets_query) @@ -113,6 +130,9 @@ def write_series(self, ticker: str, type: str) -> None: if series: self._bearish_db.write_series(series) + def read_sources(self) -> List[str]: + return self._bearish_db.read_sources() + @app.command() def assets(path: Path, countries: List[str]) -> None: diff --git a/tests/sources/test_alphavantage.py b/tests/sources/test_alphavantage.py index 9ac541d..5eaf344 100644 --- a/tests/sources/test_alphavantage.py +++ b/tests/sources/test_alphavantage.py @@ -158,7 +158,7 @@ def test_alphavantage_read_series(): assert series -@pytest.mark.order(2) +@pytest.mark.skip("issue with api key") def test_api_key(): alpha = AlphaVantageSource() alpha.set_api_key("test") diff --git a/tests/test_bearish.py b/tests/test_bearish.py index c516250..cfe1a3d 100644 --- a/tests/test_bearish.py +++ b/tests/test_bearish.py @@ -33,7 +33,9 @@ def bearish_db() -> BearishDb: def test_update_asset_yfinance(bearish_db: BearishDb): - bearish = Bearish(path=bearish_db.database_path, sources=[yFinanceSource()]) + bearish = Bearish( + path=bearish_db.database_path, sources=[yFinanceSource()], asset_sources=[] + ) bearish.write_assets(AssetQuery(symbols=Symbols(equities=["AAPL"]))) assets = bearish.read_assets(AssetQuery(symbols=Symbols(equities=["AAPL"]))) assert assets @@ -66,7 +68,9 @@ def test_update_asset_financedatabase(bearish_db: BearishDb): .read_text(), ) bearish = Bearish( - path=bearish_db.database_path, sources=[FinanceDatabaseSource()] + path=bearish_db.database_path, + asset_sources=[], + sources=[FinanceDatabaseSource()], ) bearish.write_assets() assets = bearish.read_assets(AssetQuery(symbols=Symbols(equities=["AAVE-INR"]))) @@ -108,28 +112,36 @@ def test_update_assets_multi_sources(bearish_db: BearishDb): ) bearish = Bearish( path=bearish_db.database_path, - sources=[FinanceDatabaseSource(), yFinanceSource()], + asset_sources=[FinanceDatabaseSource()], + sources=[yFinanceSource()], ) bearish.write_assets() + assets = bearish.read_assets(AssetQuery(symbols=Symbols(equities=["AAVE-INR"]))) assets_multi = bearish.read_assets( AssetQuery(symbols=Symbols(equities=["000006.SZ", "AAVE-KRW"])) ) + sources = bearish.read_sources() assert assets.cryptos assert assets_multi.equities assert assets_multi.cryptos + assert sources def test_update_financials(bearish_db: BearishDb): - bearish = Bearish(path=bearish_db.database_path, sources=[yFinanceSource()]) + bearish = Bearish( + path=bearish_db.database_path, asset_sources=[], sources=[yFinanceSource()] + ) bearish.read_financials_from_many_sources("AAPL") financials = bearish.read_financials(AssetQuery(symbols=Symbols(equities=["AAPL"]))) assert financials def test_update_series(bearish_db: BearishDb): - bearish = Bearish(path=bearish_db.database_path, sources=[yFinanceSource()]) + bearish = Bearish( + path=bearish_db.database_path, asset_sources=[], sources=[yFinanceSource()] + ) bearish.write_series("AAPL", "full") series = bearish.read_series(AssetQuery(symbols=Symbols(equities=["AAPL"]))) assert series @@ -137,7 +149,9 @@ def test_update_series(bearish_db: BearishDb): def test_update_series_multiple_times(bearish_db: BearishDb): - bearish = Bearish(path=bearish_db.database_path, sources=[yFinanceSource()]) + bearish = Bearish( + path=bearish_db.database_path, asset_sources=[], sources=[yFinanceSource()] + ) bearish.write_series("AAPL", "5d") bearish.write_series("AAPL", "5d") series = bearish.read_series(AssetQuery(symbols=Symbols(equities=["AAPL"]))) @@ -151,6 +165,7 @@ def test_update_financials_alphavantage(bearish_db: BearishDb): bearish = Bearish( path=bearish_db.database_path, api_keys=SourceApiKeys(keys={"AlphaVantage": "AlphaVantage"}), + asset_sources=[], sources=[AlphaVantageSource()], ) bearish.read_financials_from_many_sources("AAPL") @@ -164,6 +179,7 @@ def test_update_series_alphavantage(bearish_db: BearishDb): bearish = Bearish( path=bearish_db.database_path, api_keys=SourceApiKeys(keys={"AlphaVantage": "AlphaVantage"}), + asset_sources=[], sources=[AlphaVantageSource()], ) bearish.write_series("AAPL", "full") @@ -232,7 +248,8 @@ def test_write_assets(bearish_db: BearishDb): ) bearish = Bearish( path=bearish_db.database_path, - sources=[FinanceDatabaseSource(), InvestPySource()], + asset_sources=[FinanceDatabaseSource()], + sources=[InvestPySource()], ) bearish.write_assets(AssetQuery(countries=["Argentina"])) assets = bearish.read_assets(AssetQuery(countries=["Argentina"]))