diff --git a/.flake8 b/.flake8 index 69331a0..ff16ab7 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] max-line-length = 88 -extend-ignore = E203, E704, E741 \ No newline at end of file +extend-ignore = E203, E704, E741, F401 \ No newline at end of file diff --git a/yabte/backtest/asset.py b/yabte/backtest/asset.py index 6e31db6..5f3c3ac 100644 --- a/yabte/backtest/asset.py +++ b/yabte/backtest/asset.py @@ -1,6 +1,8 @@ +from __future__ import annotations + from dataclasses import dataclass from decimal import Decimal -from typing import TypeAlias, TypeVar, Union, cast +from typing import TYPE_CHECKING, TypeAlias, TypeVar, Union, cast import pandas as pd from mypy_extensions import mypyc_attr @@ -9,6 +11,7 @@ # use ints until mypyc supports IntFlag +# https://github.com/mypyc/mypyc/issues/1022 AssetDataFieldInfo = int ADFI_AVAILABLE_AT_CLOSE: int = 1 ADFI_AVAILABLE_AT_OPEN: int = 2 diff --git a/yabte/backtest/strategy.py b/yabte/backtest/strategy.py index 4b94f26..dcda760 100644 --- a/yabte/backtest/strategy.py +++ b/yabte/backtest/strategy.py @@ -193,7 +193,7 @@ def _check_data(df, asset_map): # check and fix data for each asset dfs = { - asset.data_label: asset.check_and_fix_data(df[asset.data_label]) + asset.data_label: asset.check_and_fix_data(asset._filter_data(df)) for asset_name, asset in asset_map.items() } diff --git a/yabte/backtest/transaction.py b/yabte/backtest/transaction.py index 478b176..a8d8164 100644 --- a/yabte/backtest/transaction.py +++ b/yabte/backtest/transaction.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging from dataclasses import dataclass from decimal import Decimal +from typing import TYPE_CHECKING import pandas as pd @@ -8,7 +11,8 @@ # (https://github.com/mypyc/mypyc/issues/1000) from pandas import Timestamp # type: ignore -from .asset import AssetName +if TYPE_CHECKING: + from .asset import AssetName logger = logging.getLogger(__name__) diff --git a/yabte/utilities/plot/matplotlib/strategy_runner.py b/yabte/utilities/plot/matplotlib/strategy_runner.py index 70acf8d..e29a14a 100644 --- a/yabte/utilities/plot/matplotlib/strategy_runner.py +++ b/yabte/utilities/plot/matplotlib/strategy_runner.py @@ -55,7 +55,7 @@ def plot_strategy_runner(sr: StrategyRunner, settings: dict[str, Any] | None = N for book, axs in zip(sr.books, axss.T): for i, asset in enumerate(traded_assets): - prices = sr.data[asset.data_label] + prices = asset._filter_data(sr.data) up = prices[prices.Close >= prices.Open] down = prices[prices.Close < prices.Open] diff --git a/yabte/utilities/plot/plotly/strategy_runner.py b/yabte/utilities/plot/plotly/strategy_runner.py index 2bd6b84..35e7a0b 100644 --- a/yabte/utilities/plot/plotly/strategy_runner.py +++ b/yabte/utilities/plot/plotly/strategy_runner.py @@ -48,7 +48,7 @@ def plot_strategy_runner(sr: StrategyRunner, settings: dict[str, Any] | None = N for col, book in enumerate(sr.books, start=1): for row, asset in enumerate(traded_assets, start=1): - prices = sr.data[asset.data_label] + prices = asset._filter_data(sr.data) fig.add_trace( go.Candlestick(