Skip to content

Commit

Permalink
generalize backtesting asset fields.
Browse files Browse the repository at this point in the history
support size effects on pricing.
improve typing.
bump a couple of packages.
  • Loading branch information
bsdz committed Jan 10, 2024
1 parent f528476 commit f685c5a
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 109 deletions.
3 changes: 3 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[flake8]
max-line-length = 88
extend-ignore = E203, E704
109 changes: 69 additions & 40 deletions notebooks/Delta_Hedging.ipynb

Large diffs are not rendered by default.

65 changes: 33 additions & 32 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ generate-setup-file = true
python = "^3.10,<3.12"
pandas = ">1.5,<3"
scipy = "^1.10.0"
pandas-stubs = "^2.0.0.230412"
mypy = "^1.3.0"
pandas-stubs = "^2.1.4.231227"
mypy = "^1.8.0"

[tool.poetry.group.dev]
optional = true
Expand Down Expand Up @@ -51,5 +51,5 @@ nbconvert = "^7.2.9"
profile = "black"

[build-system]
requires = ["poetry-core>=1.0.0", "setuptools==67.8.0", "mypy==1.3.0", "pandas-stubs==2.0.0.230412"]
requires = ["poetry-core>=1.0.0", "setuptools==69.0.2", "mypy==1.7.1", "pandas-stubs==2.1.4.231227"]
build-backend = "poetry.core.masonry.api"
3 changes: 2 additions & 1 deletion yabte/backtest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# TODO: use absolute imports until mypyc fixes relative imports in __init__.py
# (https://github.com/mypyc/mypyc/issues/996)
from yabte.backtest.asset import Asset, AssetName
from yabte.backtest.asset import Asset, AssetDataFieldInfo, AssetName
from yabte.backtest.book import Book, BookMandate, BookName
from yabte.backtest.order import (
BasketOrder,
Expand All @@ -16,6 +16,7 @@

__all__ = [
"Asset",
"AssetDataFieldInfo",
"AssetName",
"Book",
"BookName",
Expand Down
8 changes: 5 additions & 3 deletions yabte/backtest/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from decimal import Decimal
from enum import Enum
from typing import Any, Type
from typing import Any, Type, TypeVar

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=Enum)

def ensure_enum(value: Any, enum_type: Type[Enum]):

def ensure_enum(value: Any, enum_type: Type[T]) -> T:
if isinstance(value, enum_type):
return value
if isinstance(value, str):
Expand All @@ -16,7 +18,7 @@ def ensure_enum(value: Any, enum_type: Type[Enum]):
raise ValueError(f"Unexpected enum type {value} for {enum_type}")


def ensure_decimal(value: Any):
def ensure_decimal(value: Any) -> Decimal:
if isinstance(value, Decimal):
return value
if isinstance(value, (str, float, int)):
Expand Down
78 changes: 56 additions & 22 deletions yabte/backtest/asset.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from dataclasses import dataclass
from decimal import Decimal
from typing import Sequence, TypeAlias
from enum import Flag, auto
from typing import TypeAlias

import pandas as pd
from mypy_extensions import mypyc_attr

__all__ = ["Asset"]


class AssetDataFieldInfo(Flag):
AVAILABLE_AT_CLOSE = auto()
AVAILABLE_AT_OPEN = auto()
REQUIRED = auto()


_FI = AssetDataFieldInfo


AssetName: TypeAlias = str
"""Asset name string."""

Expand Down Expand Up @@ -43,23 +53,36 @@ def round_quantity(self, quantity) -> Decimal:
"""Round `quantity`."""
return round(quantity, self.quantity_round_dp)

@property
def fields_available_at_open(self) -> Sequence[str]:
"""A sequence of field names available at open.
Any fields not in this sequence will be masked out.
"""
return []

def intraday_traded_price(self, asset_day_data) -> Decimal:
def intraday_traded_price(
self, asset_day_data: pd.Series, size: Decimal | None = None
) -> Decimal:
"""Calculate price during market hours with given row of
`asset_day_data`."""
raise NotImplementedError("The apply methods needs to be implemented.")
`asset_day_data` and the order `size`. The `size` can be used to
determine a price from say, bid / ask spreads."""
raise NotImplementedError(
"The intraday_traded_price method needs to be implemented."
)

def end_of_day_price(self, asset_day_data: pd.Series) -> Decimal:
"""Calculate price at end of day with given row of `asset_day_data`."""
raise NotImplementedError(
"The end_of_day_price method needs to be implemented."
)

def check_and_fix_data(self, data: pd.DataFrame) -> pd.DataFrame:
"""Checks dataframe `data` has correct fields and fixes columns where
necessary."""
raise NotImplementedError("The apply methods needs to be implemented.")
raise NotImplementedError(
"The check_and_fix_data method needs to be implemented."
)

def data_fields(self) -> list[tuple[str, AssetDataFieldInfo]]:
"""List of data fields and their availability."""
raise NotImplementedError("The data_fields method needs to be implemented.")

def _get_fields(self, field_info: AssetDataFieldInfo) -> list[str]:
"""Internal method to get fields from `data_fields` with `field_info`."""
return [f for f, fi in self.data_fields() if fi & field_info]


@mypyc_attr(allow_interpreted_subclasses=True)
Expand All @@ -68,30 +91,41 @@ class Asset(AssetBase):
"""Assets whose price history is represented by High, Low, Open, Close and Volume
fields."""

@property
def fields_available_at_open(self) -> Sequence[str]:
return ["Open"]

def intraday_traded_price(self, asset_day_data) -> Decimal:
def data_fields(self) -> list[tuple[str, AssetDataFieldInfo]]:
return [
("High", _FI.AVAILABLE_AT_CLOSE),
("Low", _FI.AVAILABLE_AT_CLOSE),
("Open", _FI.AVAILABLE_AT_CLOSE | _FI.AVAILABLE_AT_OPEN),
("Close", _FI.AVAILABLE_AT_CLOSE | _FI.REQUIRED),
("Volume", _FI.AVAILABLE_AT_CLOSE),
]

def intraday_traded_price(
self, asset_day_data: pd.Series, size: Decimal | None = None
) -> Decimal:
if pd.notnull(asset_day_data.Low) and pd.notnull(asset_day_data.High):
p = Decimal((asset_day_data.Low + asset_day_data.High) / 2)
else:
p = Decimal(asset_day_data.Close)
return round(p, self.price_round_dp)

def end_of_day_price(self, asset_day_data: pd.Series) -> Decimal:
return round(Decimal(asset_day_data.Close), self.price_round_dp)

def check_and_fix_data(self, data: pd.DataFrame) -> pd.DataFrame:
# TODO: check low <= open, high, close & high >= open, low, close
# TODO: check volume >= 0

# check each asset has required fields
required_fields = {"Close"}
missing_req_fields = required_fields - set(data.columns)
required_fields = self._get_fields(_FI.REQUIRED)
missing_req_fields = set(required_fields) - set(data.columns)
if len(missing_req_fields):
raise ValueError(
f"data columns index requires fields {required_fields} and missing {missing_req_fields}"
f"data columns index requires fields {required_fields}"
f" and missing {missing_req_fields}"
)

# reindex columns with expected fields + additional fields
expected_fields = ["High", "Low", "Open", "Close", "Volume"]
expected_fields = self._get_fields(_FI.AVAILABLE_AT_CLOSE)
other_fields = list(set(data.columns) - set(expected_fields))
return data.reindex(expected_fields + other_fields, axis=1)
9 changes: 6 additions & 3 deletions yabte/backtest/book.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,11 @@ def eod_tasks(
]
)
cash = float(self.cash)
mtm = sum(
day_data[asset_map[an].data_label].Close * float(q)
for an, q in self.positions.items()
mtm = float(
sum(
asset.end_of_day_price(day_data[asset.data_label]) * q
for an, q in self.positions.items()
if (asset := asset_map.get(an))
)
)
self._history.append([ts, cash, mtm, cash + mtm])
4 changes: 2 additions & 2 deletions yabte/backtest/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __post_init__(self):
def _calc_quantity_price(self, day_data, asset_map) -> Tuple[Decimal, Decimal]:
asset = asset_map[self.asset_name]
asset_day_data = day_data[asset.data_label]
trade_price = asset.intraday_traded_price(asset_day_data)
trade_price = asset.intraday_traded_price(asset_day_data, size=self.size)

if self.size_type == OrderSizeType.QUANTITY:
quantity = self.size
Expand Down Expand Up @@ -280,7 +280,7 @@ def _calc_quantity_price(
assets = [asset_map[an] for an in self.asset_names]
assets_day_data = [day_data[a.data_label] for a in assets]
trade_prices = [
asset.intraday_traded_price(add)
asset.intraday_traded_price(add, size=self.size)
for asset, add in zip(assets, assets_day_data)
]

Expand Down
9 changes: 6 additions & 3 deletions yabte/backtest/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# (https://github.com/mypyc/mypyc/issues/1000)
from pandas import DataFrame, Series, Timestamp # type: ignore

from .asset import Asset, AssetName
from .asset import Asset, AssetDataFieldInfo, AssetName
from .book import Book, BookMandate, BookName
from .order import Order, OrderBase, OrderStatus

Expand Down Expand Up @@ -109,7 +109,10 @@ def _get_col_indexer(self):
mix = pd.MultiIndex.from_tuples(
chain(
*[
product([asset.data_label], asset.fields_available_at_open)
product(
[asset.data_label],
asset._get_fields(AssetDataFieldInfo.AVAILABLE_AT_OPEN),
)
for asset_name, asset in self.assets.items()
]
)
Expand Down Expand Up @@ -174,7 +177,7 @@ def _check_data(df, asset_map):
if not df.index.is_unique:
raise ValueError("data index must be unique")

# colum level 1 = asset, level 2 = field
# column level 1 = asset, level 2 = field
if not isinstance(df.columns, pd.MultiIndex):
raise ValueError("data columns must be multindex asset/field")
if len(df.columns.levels) != 2:
Expand Down

0 comments on commit f685c5a

Please sign in to comment.