Skip to content

Commit 8381d11

Browse files
committed
more typing improvements
1 parent 058d26b commit 8381d11

File tree

8 files changed

+31
-30
lines changed

8 files changed

+31
-30
lines changed

notebooks/Delta_Hedging.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,9 @@
2727
"import pyfeng as pf\n",
2828
"\n",
2929
"from yabte.backtest import (\n",
30+
" ADFI_AVAILABLE_AT_CLOSE,\n",
31+
" ADFI_AVAILABLE_AT_OPEN,\n",
3032
" Asset,\n",
31-
" AssetDataFieldInfo,\n",
3233
" Book,\n",
3334
" CashTransaction,\n",
3435
" Order,\n",
@@ -73,8 +74,7 @@
7374
" dfs.append(\n",
7475
" (\n",
7576
" \"IVol\",\n",
76-
" AssetDataFieldInfo.AVAILABLE_AT_CLOSE\n",
77-
" | AssetDataFieldInfo.AVAILABLE_AT_OPEN,\n",
77+
" ADFI_AVAILABLE_AT_CLOSE | ADFI_AVAILABLE_AT_OPEN,\n",
7878
" )\n",
7979
" )\n",
8080
" return dfs\n",

yabte/backtest/asset.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22
from decimal import Decimal
3-
from typing import TypeAlias
3+
from typing import TypeAlias, TypeVar, Union, cast
44

55
import pandas as pd
66
from mypy_extensions import mypyc_attr
@@ -14,6 +14,7 @@
1414
ADFI_AVAILABLE_AT_OPEN: int = 2
1515
ADFI_REQUIRED: int = 4
1616

17+
T = TypeVar("T", bound=Union[pd.Series, pd.DataFrame])
1718

1819
AssetName: TypeAlias = str
1920
"""Asset name string."""
@@ -84,6 +85,12 @@ def _get_fields(self, field_info: AssetDataFieldInfo) -> list[str]:
8485
"""Internal method to get fields from `data_fields` with `field_info`."""
8586
return [f for f, fi in self.data_fields() if fi & field_info]
8687

88+
def _filter_data(self, data: T) -> T:
89+
"""Internal method to filter `data` columns and return only those relevant to
90+
pricing."""
91+
assert isinstance(self.data_label, str)
92+
return cast(T, data[self.data_label])
93+
8794

8895
@mypyc_attr(allow_interpreted_subclasses=True)
8996
@dataclass(kw_only=True)

yabte/backtest/book.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def add_transactions(self, transactions: Sequence[Transaction]):
101101
self.transactions.append(tran)
102102

103103
def eod_tasks(
104-
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
104+
self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]
105105
):
106106
"""Run end of day tasks such as book keeping."""
107107
# accumulate continously compounded interest
@@ -119,7 +119,7 @@ def eod_tasks(
119119
cash = float(self.cash)
120120
mtm = float(
121121
sum(
122-
asset.end_of_day_price(day_data[asset.data_label]) * q
122+
asset.end_of_day_price(asset._filter_data(day_data)) * q
123123
for an, q in self.positions.items()
124124
if (asset := asset_map.get(an))
125125
)

yabte/backtest/order.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ def post_complete(self, trades: List[Trade]):
9898
"""
9999
pass
100100

101-
def apply(
102-
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
103-
):
101+
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
104102
"""Applies order to `self.book` for time `ts` using provided `day_data` and
105103
dictionary of asset information `asset_map`."""
106104
raise NotImplementedError("The apply methods needs to be implemented.")
@@ -128,9 +126,11 @@ def __post_init__(self):
128126
self.size = ensure_decimal(self.size)
129127
self.size_type = ensure_enum(self.size_type, OrderSizeType)
130128

131-
def _calc_quantity_price(self, day_data, asset_map) -> Tuple[Decimal, Decimal]:
129+
def _calc_quantity_price(
130+
self, day_data: pd.Series, asset_map: Dict[str, Asset]
131+
) -> Tuple[Decimal, Decimal]:
132132
asset = asset_map[self.asset_name]
133-
asset_day_data = day_data[asset.data_label]
133+
asset_day_data = asset._filter_data(day_data)
134134
trade_price = asset.intraday_traded_price(asset_day_data, size=self.size)
135135

136136
if self.size_type == OrderSizeType.QUANTITY:
@@ -159,9 +159,7 @@ def pre_execute_check(
159159
"""
160160
return None
161161

162-
def apply(
163-
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
164-
):
162+
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
165163
if not self.book or not isinstance(self.book, Book):
166164
raise RuntimeError("Cannot apply order without book instance")
167165

@@ -202,9 +200,7 @@ def __post_init__(self):
202200
super().__post_init__()
203201
self.check_type = ensure_enum(self.check_type, PositionalOrderCheckType)
204202

205-
def apply(
206-
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
207-
):
203+
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
208204
if not self.book or not isinstance(self.book, Book):
209205
raise RuntimeError("Cannot apply order without book instance")
210206

@@ -275,10 +271,10 @@ def __post_init__(self):
275271
self.size_type = ensure_enum(self.size_type, OrderSizeType)
276272

277273
def _calc_quantity_price(
278-
self, day_data, asset_map
274+
self, day_data: pd.Series, asset_map: Dict[str, Asset]
279275
) -> List[Tuple[Decimal, Decimal]]:
280276
assets = [asset_map[an] for an in self.asset_names]
281-
assets_day_data = [day_data[a.data_label] for a in assets]
277+
assets_day_data = [a._filter_data(day_data) for a in assets]
282278
trade_prices = [
283279
asset.intraday_traded_price(add, size=self.size)
284280
for asset, add in zip(assets, assets_day_data)
@@ -313,9 +309,7 @@ def _calc_quantity_price(
313309
for a, q, tp in zip(assets, quantities, trade_prices)
314310
]
315311

316-
def apply(
317-
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
318-
):
312+
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
319313
if not self.book or not isinstance(self.book, Book):
320314
raise RuntimeError("Cannot apply order without book instance")
321315

@@ -343,9 +337,7 @@ class PositionalBasketOrder(BasketOrder):
343337

344338
check_type: PositionalOrderCheckType = PositionalOrderCheckType.POS_TQ_DIFFER
345339

346-
def apply(
347-
self, ts: pd.Timestamp, day_data: pd.DataFrame, asset_map: Dict[str, Asset]
348-
):
340+
def apply(self, ts: pd.Timestamp, day_data: pd.Series, asset_map: Dict[str, Asset]):
349341
if not self.book or not isinstance(self.book, Book):
350342
raise RuntimeError("Cannot apply order without book instance")
351343

yabte/tests/test_strategy_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def init(self):
106106

107107
def on_close(self):
108108
p = self.params
109-
s = self.data["SPREAD"].Close[-1]
109+
s = self.data["SPREAD"].Close.iloc[-1]
110110
if s < self.mu - 0.5 * self.sigma:
111111
self.orders.append(PositionalOrder(asset_name=p.s1, size=100))
112112
self.orders.append(PositionalOrder(asset_name=p.s2, size=p.factor * 100))
@@ -159,7 +159,7 @@ def test_sma_crossover(self):
159159
th.pivot_table(index="ts", columns="book", values="nc", aggfunc="sum")
160160
.cumsum()
161161
.reindex(sr.data.index)
162-
.fillna(method="ffill")
162+
.ffill()
163163
.fillna(0)
164164
)
165165
self.assertTrue(

yabte/utilities/pandas_extension.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def log_returns(self):
3333

3434
@property
3535
def returns(self):
36-
return self._obj.pct_change()[1:]
36+
return self._obj.pct_change(fill_method=None)[1:]
3737

3838
@property
3939
def frequency(self):

yabte/utilities/portopt/hierarchical_risk_parity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _getClusterVar(cov, cItems):
3636

3737
def _getRecBipart(cov, sortIx):
3838
# Compute HRP alloc
39-
w = pd.Series(1, index=sortIx)
39+
w = pd.Series(1., index=sortIx)
4040
cItems = [sortIx] # initialize all items in one cluster
4141
while len(cItems) > 0:
4242
cItems = [

yabte/utilities/strategy_helpers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ def crossover(series1: pd.Series, series2: pd.Series) -> bool:
99
True
1010
"""
1111
try:
12-
return series1[-2] < series2[-2] and series1[-1] > series2[-1]
12+
return (
13+
series1.iloc[-2] < series2.iloc[-2] and series1.iloc[-1] > series2.iloc[-1]
14+
)
1315
except IndexError:
1416
return False

0 commit comments

Comments
 (0)