From 73dd1ae3e4f3a09ba581d58aaa8d9e862f710029 Mon Sep 17 00:00:00 2001 From: Rich <24254625+richklee@users.noreply.github.com> Date: Sun, 9 Jun 2024 17:54:09 +0100 Subject: [PATCH] refactor for correct type hints --- alphavec/backtest.py | 59 ++++++++++++++++++++++-------------------- example.ipynb | 42 +++++++++++++++--------------- tests/test_backtest.py | 1 - 3 files changed, 52 insertions(+), 50 deletions(-) diff --git a/alphavec/backtest.py b/alphavec/backtest.py index fa69b60..4d1e075 100644 --- a/alphavec/backtest.py +++ b/alphavec/backtest.py @@ -7,7 +7,7 @@ DEFAULT_RISK_FREE_RATE = 0.02 -def zero_commission(weights: pd.DataFrame, prices: pd.DataFrame) -> float: +def zero_commission(weights: pd.DataFrame, prices: pd.DataFrame) -> pd.DataFrame: """Zero trading commission. Args: @@ -17,10 +17,12 @@ def zero_commission(weights: pd.DataFrame, prices: pd.DataFrame) -> float: Returns: Always returns 0. """ - return 0 + return pd.DataFrame(0, index=weights.index, columns=weights.columns) -def flat_commission(weights: pd.DataFrame, prices: pd.DataFrame, fee: float) -> float: +def flat_commission( + weights: pd.DataFrame, prices: pd.DataFrame, fee: float +) -> pd.DataFrame: """Flat commission applies a fixed fee per trade. Args: @@ -37,7 +39,9 @@ def flat_commission(weights: pd.DataFrame, prices: pd.DataFrame, fee: float) -> return commissions -def pct_commission(weights: pd.DataFrame, prices: pd.DataFrame, fee: float) -> float: +def pct_commission( + weights: pd.DataFrame, prices: pd.DataFrame, fee: float +) -> pd.DataFrame: """Percentage commission applies a percentage fee per trade. Args: @@ -60,7 +64,9 @@ def backtest( freq_day: int = 1, trading_days_year: int = DEFAULT_TRADING_DAYS_YEAR, shift_periods: int = 1, - commission_func: Callable[[pd.DataFrame, pd.DataFrame], float] = zero_commission, + commission_func: Callable[ + [pd.DataFrame, pd.DataFrame], pd.DataFrame + ] = zero_commission, ann_borrow_rate: float = 0, spread_pct: float = 0, ann_risk_free_rate: float = DEFAULT_RISK_FREE_RATE, @@ -69,7 +75,7 @@ def backtest( ]: """Backtest a trading strategy. - Strategy is simulated using the given weights, returns, and cost parameters. + Strategy is simulated using the given weights, prices, and cost parameters. Zero costs are calculated by default: no commission, no borrowing, no spread. To prevent look-ahead bias the returns will be shifted 1 interval by default relative to the weights during backtest. @@ -89,7 +95,7 @@ def backtest( Index should be a DatetimeIndex. Shape must match returns. prices: - Prices of the assets at each interval used to calculate returns ans costs. + Prices of the assets at each interval used to calculate returns and costs. Each column should be the mark prices for a specific asset, with the column name being the asset name. Column names should match weights. Index should be a DatetimeIndex. @@ -129,16 +135,16 @@ def backtest( asset_cum = (1 + asset_rets).cumprod() - 1 asset_perf = pd.concat( [ - asset_rets.apply( - _ann_sharpe, periods=freq_year, risk_free_rate=ann_risk_free_rate + _ann_sharpe( + asset_rets, periods=freq_year, risk_free_rate=ann_risk_free_rate ), - asset_rets.apply(_ann_vol, periods=freq_year), - asset_rets.apply(_cagr, periods=freq_year), - asset_rets.apply(_max_drawdown), + _ann_vol(asset_rets, periods=freq_year), + _cagr(asset_rets, periods=freq_year), + _max_drawdown(asset_rets), ], keys=["annual_sharpe", "annual_volatility", "cagr", "max_drawdown"], axis=1, - ) + ) # type: ignore # Backtest a cost-aware strategy as defined by the given weights: # 1. Calc costs @@ -173,12 +179,12 @@ def backtest( # Evaluate the strategy asset-wise performance strat_perf = pd.concat( [ - strat_rets.apply( - _ann_sharpe, periods=freq_year, risk_free_rate=ann_risk_free_rate + _ann_sharpe( + strat_rets, periods=freq_year, risk_free_rate=ann_risk_free_rate ), - strat_rets.apply(_ann_vol, periods=freq_year), - strat_rets.apply(_cagr, periods=freq_year), - strat_rets.apply(_max_drawdown), + _ann_vol(strat_rets, periods=freq_year), + _cagr(strat_rets, periods=freq_year), + _max_drawdown(strat_rets), strat_ann_turnover, _trade_count(weights) / strat_total_days, ], @@ -191,7 +197,7 @@ def backtest( "trades_per_day", ], axis=1, - ) + ) # type: ignore # Evaluate the strategy portfolio performance port_rets = strat_rets.sum(axis=1) @@ -213,8 +219,7 @@ def backtest( index=["portfolio"], ) - # Combine the asset and strategy performance metrics - # into a single dataframe for comparison + # Combine the asset and strategy performance metrics into a single dataframe for comparison perf = pd.concat( [asset_perf, strat_perf], keys=["asset", "strategy"], @@ -264,7 +269,7 @@ def _ann_sharpe( rets: pd.DataFrame | pd.Series, risk_free_rate: float = DEFAULT_RISK_FREE_RATE, periods: int = DEFAULT_TRADING_DAYS_YEAR, -) -> float: +) -> pd.DataFrame | pd.Series: ann_rfr = (1 + risk_free_rate) ** (1 / periods) - 1 mu = rets.mean() sigma = rets.std() @@ -290,11 +295,9 @@ def _cagr( ) -> pd.DataFrame | pd.Series: cumprod = (1 + rets).cumprod().dropna() if len(cumprod) == 0: - return 0 + return rets * 0 final = cumprod.iloc[-1] - if final <= 0: - return 0 n = len(cumprod) / periods cagr = final ** (1 / n) - 1 @@ -327,8 +330,8 @@ def _turnover( diff = weights.fillna(0).diff() # Capital is fixed (uncompounded) for each interval so we can calculate the trade volume # Sum the volume of the buy and sell trades - buy_volume = (diff.where(diff > 0, 0).abs() * capital).sum() - sell_volume = (diff.where(diff < 0, 0).abs() * capital).sum() + buy_volume = (diff.where(lambda x: x.gt(0), 0).abs() * capital).sum() + sell_volume = (diff.where(lambda x: x.lt(0), 0).abs() * capital).sum() # Trade volume is the minimum of the buy and sell volumes # Wrap in Series in case of scalar volume sum (when weights is a Series) trade_volume = pd.concat( @@ -341,7 +344,7 @@ def _turnover( return turnover -def _trade_count(weights: pd.DataFrame | pd.Series) -> pd.DataFrame | pd.Series: +def _trade_count(weights: pd.DataFrame | pd.Series) -> int | pd.Series: diff = weights.fillna(0).diff().abs() != 0 tx = diff.astype(int) return tx.sum() diff --git a/example.ipynb b/example.ipynb index d20ee79..32748b6 100644 --- a/example.ipynb +++ b/example.ipynb @@ -703,11 +703,11 @@ "