Skip to content

Commit

Permalink
more typing improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
bsdz committed Jan 12, 2024
1 parent 058d26b commit 8381d11
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 30 deletions.
6 changes: 3 additions & 3 deletions notebooks/Delta_Hedging.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
"import pyfeng as pf\n",
"\n",
"from yabte.backtest import (\n",
" ADFI_AVAILABLE_AT_CLOSE,\n",
" ADFI_AVAILABLE_AT_OPEN,\n",
" Asset,\n",
" AssetDataFieldInfo,\n",
" Book,\n",
" CashTransaction,\n",
" Order,\n",
Expand Down Expand Up @@ -73,8 +74,7 @@
" dfs.append(\n",
" (\n",
" \"IVol\",\n",
" AssetDataFieldInfo.AVAILABLE_AT_CLOSE\n",
" | AssetDataFieldInfo.AVAILABLE_AT_OPEN,\n",
" ADFI_AVAILABLE_AT_CLOSE | ADFI_AVAILABLE_AT_OPEN,\n",
" )\n",
" )\n",
" return dfs\n",
Expand Down
9 changes: 8 additions & 1 deletion yabte/backtest/asset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from decimal import Decimal
from typing import TypeAlias
from typing import TypeAlias, TypeVar, Union, cast

import pandas as pd
from mypy_extensions import mypyc_attr
Expand All @@ -14,6 +14,7 @@
ADFI_AVAILABLE_AT_OPEN: int = 2
ADFI_REQUIRED: int = 4

T = TypeVar("T", bound=Union[pd.Series, pd.DataFrame])

AssetName: TypeAlias = str
"""Asset name string."""
Expand Down Expand Up @@ -84,6 +85,12 @@ def _get_fields(self, field_info: AssetDataFieldInfo) -> list[str]:
"""Internal method to get fields from `data_fields` with `field_info`."""
return [f for f, fi in self.data_fields() if fi & field_info]

def _filter_data(self, data: T) -> T:
    """Return only the slice of `data` relevant to pricing this asset.

    Indexes `data` (a Series or DataFrame) by this asset's ``data_label``,
    preserving the input container type via the ``T`` TypeVar.
    """
    label = self.data_label
    # data_label must already be resolved to a concrete string key here
    assert isinstance(label, str)
    selected = data[label]
    return cast(T, selected)


@mypyc_attr(allow_interpreted_subclasses=True)
@dataclass(kw_only=True)
Expand Down
4 changes: 2 additions & 2 deletions yabte/backtest/book.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def add_transactions(self, transactions: Sequence[Transaction]):
self.transactions.append(tran)

def eod_tasks(
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]
):
"""Run end of day tasks such as book keeping."""
# accumulate continuously compounded interest
Expand All @@ -119,7 +119,7 @@ def eod_tasks(
cash = float(self.cash)
mtm = float(
sum(
asset.end_of_day_price(day_data[asset.data_label]) * q
asset.end_of_day_price(asset._filter_data(day_data)) * q
for an, q in self.positions.items()
if (asset := asset_map.get(an))
)
Expand Down
30 changes: 11 additions & 19 deletions yabte/backtest/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@ def post_complete(self, trades: List[Trade]):
"""
pass

def apply(
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
):
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
"""Applies order to `self.book` for time `ts` using provided `day_data` and
dictionary of asset information `asset_map`."""
raise NotImplementedError("The apply methods needs to be implemented.")
Expand Down Expand Up @@ -128,9 +126,11 @@ def __post_init__(self):
self.size = ensure_decimal(self.size)
self.size_type = ensure_enum(self.size_type, OrderSizeType)

def _calc_quantity_price(self, day_data, asset_map) -> Tuple[Decimal, Decimal]:
def _calc_quantity_price(
self, day_data: pd.Series, asset_map: Dict[str, Asset]
) -> Tuple[Decimal, Decimal]:
asset = asset_map[self.asset_name]
asset_day_data = day_data[asset.data_label]
asset_day_data = asset._filter_data(day_data)
trade_price = asset.intraday_traded_price(asset_day_data, size=self.size)

if self.size_type == OrderSizeType.QUANTITY:
Expand Down Expand Up @@ -159,9 +159,7 @@ def pre_execute_check(
"""
return None

def apply(
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
):
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
if not self.book or not isinstance(self.book, Book):
raise RuntimeError("Cannot apply order without book instance")

Expand Down Expand Up @@ -202,9 +200,7 @@ def __post_init__(self):
super().__post_init__()
self.check_type = ensure_enum(self.check_type, PositionalOrderCheckType)

def apply(
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
):
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
if not self.book or not isinstance(self.book, Book):
raise RuntimeError("Cannot apply order without book instance")

Expand Down Expand Up @@ -275,10 +271,10 @@ def __post_init__(self):
self.size_type = ensure_enum(self.size_type, OrderSizeType)

def _calc_quantity_price(
self, day_data, asset_map
self, day_data: pd.Series, asset_map: Dict[str, Asset]
) -> List[Tuple[Decimal, Decimal]]:
assets = [asset_map[an] for an in self.asset_names]
assets_day_data = [day_data[a.data_label] for a in assets]
assets_day_data = [a._filter_data(day_data) for a in assets]
trade_prices = [
asset.intraday_traded_price(add, size=self.size)
for asset, add in zip(assets, assets_day_data)
Expand Down Expand Up @@ -313,9 +309,7 @@ def _calc_quantity_price(
for a, q, tp in zip(assets, quantities, trade_prices)
]

def apply(
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
):
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
if not self.book or not isinstance(self.book, Book):
raise RuntimeError("Cannot apply order without book instance")

Expand Down Expand Up @@ -343,9 +337,7 @@ class PositionalBasketOrder(BasketOrder):

check_type: PositionalOrderCheckType = PositionalOrderCheckType.POS_TQ_DIFFER

def apply(
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
):
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
if not self.book or not isinstance(self.book, Book):
raise RuntimeError("Cannot apply order without book instance")

Expand Down
4 changes: 2 additions & 2 deletions yabte/tests/test_strategy_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def init(self):

def on_close(self):
p = self.params
s = self.data["SPREAD"].Close[-1]
s = self.data["SPREAD"].Close.iloc[-1]
if s < self.mu - 0.5 * self.sigma:
self.orders.append(PositionalOrder(asset_name=p.s1, size=100))
self.orders.append(PositionalOrder(asset_name=p.s2, size=p.factor * 100))
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_sma_crossover(self):
th.pivot_table(index="ts", columns="book", values="nc", aggfunc="sum")
.cumsum()
.reindex(sr.data.index)
.fillna(method="ffill")
.ffill()
.fillna(0)
)
self.assertTrue(
Expand Down
2 changes: 1 addition & 1 deletion yabte/utilities/pandas_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def log_returns(self):

@property
def returns(self):
return self._obj.pct_change()[1:]
return self._obj.pct_change(fill_method=None)[1:]

@property
def frequency(self):
Expand Down
2 changes: 1 addition & 1 deletion yabte/utilities/portopt/hierarchical_risk_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _getClusterVar(cov, cItems):

def _getRecBipart(cov, sortIx):
# Compute HRP alloc
w = pd.Series(1, index=sortIx)
w = pd.Series(1., index=sortIx)
cItems = [sortIx] # initialize all items in one cluster
while len(cItems) > 0:
cItems = [
Expand Down
4 changes: 3 additions & 1 deletion yabte/utilities/strategy_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def crossover(series1: pd.Series, series2: pd.Series) -> bool:
True
"""
try:
return series1[-2] < series2[-2] and series1[-1] > series2[-1]
return (
series1.iloc[-2] < series2.iloc[-2] and series1.iloc[-1] > series2.iloc[-1]
)
except IndexError:
return False

0 comments on commit 8381d11

Please sign in to comment.