MultiBacktest multi-dataset backtesting wrapper #1223

Merged · 3 commits · Feb 20, 2025
7 changes: 6 additions & 1 deletion .github/workflows/ci.yml
@@ -34,7 +34,12 @@ jobs:
     timeout-minutes: 3
     strategy:
       matrix:
-        python-version: [3.11, 3.12, 3.13]
+        python-version: [3.12, 3.13]
+        experimental: [false]
+        include:
+          - python-version: '3.*'
+            experimental: true
+    continue-on-error: ${{ matrix.experimental }}
     steps:
       - uses: actions/setup-python@v5
         with:
99 changes: 79 additions & 20 deletions backtesting/_util.py
@@ -1,8 +1,11 @@
 from __future__ import annotations
 
+import os
 import sys
 import warnings
 from contextlib import contextmanager
+from functools import partial
+from itertools import chain
 from multiprocessing import resource_tracker as _mprt
 from multiprocessing import shared_memory as _mpshm
 from numbers import Number
@@ -12,6 +15,13 @@
 import numpy as np
 import pandas as pd
 
+try:
+    from tqdm.auto import tqdm as _tqdm
+    _tqdm = partial(_tqdm, leave=False)
+except ImportError:
+    def _tqdm(seq, **_):
+        return seq
+
 
 def try_(lazy_func, default=None, exception=Exception):
     try:
@@ -55,6 +65,13 @@ def _as_list(value) -> List:
     return [value]
 
 
+def _batch(seq):
+    # XXX: Replace with itertools.batched
+    n = np.clip(int(len(seq) // (os.cpu_count() or 1)), 1, 300)
+    for i in range(0, len(seq), n):
+        yield seq[i:i + n]
+
+
 def _data_period(index) -> Union[pd.Timedelta, Number]:
     """Return data index period as pd.Timedelta"""
     values = pd.Series(index[-100:])
@@ -233,7 +250,6 @@ def __setstate__(self, state):
 
 if sys.version_info >= (3, 13):
     SharedMemory = _mpshm.SharedMemory
-    from multiprocessing.managers import SharedMemoryManager  # noqa: F401
 else:
     class SharedMemory(_mpshm.SharedMemory):
         # From https://github.com/python/cpython/issues/82300#issuecomment-2169035092
@@ -244,7 +260,7 @@ def __init__(self, *args, track: bool = True, **kwargs):
             if track:
                 return super().__init__(*args, **kwargs)
             with self.__lock:
-                with patch(_mprt, 'register', lambda *a, **kw: None):  # TODO lambda
+                with patch(_mprt, 'register', lambda *a, **kw: None):
                     super().__init__(*args, **kwargs)
 
         def unlink(self):
@@ -253,23 +269,66 @@ def unlink(self):
             if self._track:
                 _mprt.unregister(self._name, "shared_memory")
 
-    class SharedMemoryManager:
-        def __init__(self) -> None:
-            self._shms: list[SharedMemory] = []
-
-        def SharedMemory(self, size):
-            shm = SharedMemory(create=True, size=size, track=True)
-            self._shms.append(shm)
-            return shm
-
-        def __enter__(self):
-            return self
-
-        def __exit__(self, *args, **kwargs):
-            for shm in self._shms:
-                try:
-                    shm.close()
-                    shm.unlink()
-                except Exception:
-                    warnings.warn(f'Failed to unlink shared memory {shm.name!r}',
-                                  category=ResourceWarning, stacklevel=2)
+
+class SharedMemoryManager:
+    """
+    A simple shared memory contextmanager based on
+    https://docs.python.org/3/library/multiprocessing.shared_memory.html#multiprocessing.shared_memory.SharedMemory
+    """
+    def __init__(self, create=False) -> None:
+        self._shms: list[SharedMemory] = []
+        self.__create = create
+
+    def SharedMemory(self, *, name=None, create=False, size=0, track=True):
+        shm = SharedMemory(name=name, create=create, size=size, track=track)
+        shm._create = create
+        # Essential to keep refs on Windows
+        # https://stackoverflow.com/questions/74193377/filenotfounderror-when-passing-a-shared-memory-to-a-new-process#comment130999060_74194875  # noqa: E501
+        self._shms.append(shm)
+        return shm
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        for shm in self._shms:
+            try:
+                shm.close()
+                if shm._create:
+                    shm.unlink()
+            except Exception:
+                warnings.warn(f'Failed to unlink shared memory {shm.name!r}',
+                              category=ResourceWarning, stacklevel=2)
+                raise
+
+    def arr2shm(self, vals):
+        """Array to shared memory. Returns (shm_name, shape, dtype) used for restore."""
+        assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
+        shm = self.SharedMemory(size=vals.nbytes, create=True)
+        buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
+        buf[:] = vals[:]  # Copy into shared memory
+        return shm.name, vals.shape, vals.dtype
+
+    def df2shm(self, df):
+        return tuple((
+            (column, *self.arr2shm(values))
+            for column, values in chain([(self._DF_INDEX_COL, df.index)], df.items())
+        ))
+
+    @staticmethod
+    def shm2arr(shm, shape, dtype):
+        arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
+        arr.setflags(write=False)
+        return arr
+
+    _DF_INDEX_COL = '__bt_index'
+
+    @staticmethod
+    def shm2df(data_shm):
+        shm = [SharedMemory(name=name, create=False, track=False) for _, name, _, _ in data_shm]
+        df = pd.DataFrame({
+            col: SharedMemoryManager.shm2arr(shm, shape, dtype)
+            for shm, (col, _, shape, dtype) in zip(shm, data_shm)})
+        df.set_index(SharedMemoryManager._DF_INDEX_COL, drop=True, inplace=True)
+        df.index.name = None
+        return df, shm
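
For review context, a minimal round-trip sketch of the new `SharedMemoryManager` helpers (illustrative only: `backtesting._util` is a private module, and the DataFrame contents here are invented):

    import numpy as np
    import pandas as pd
    from backtesting._util import SharedMemoryManager

    df = pd.DataFrame({'Close': np.linspace(1.0, 2.0, 100)})
    with SharedMemoryManager() as smm:
        spec = smm.df2shm(df)  # parent: copy index + columns into named shared memory
        # Worker side: re-attach by segment name, rebuild a read-only DataFrame view
        df2, shms = SharedMemoryManager.shm2df(spec)
        assert df2['Close'].equals(df['Close'])
        for s in shms:  # workers close their own handles; the manager unlinks on exit
            s.close()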
50 changes: 4 additions & 46 deletions backtesting/backtesting.py
@@ -9,7 +9,6 @@
 from __future__ import annotations
 
 import multiprocessing as mp
-import os
 import sys
 import warnings
 from abc import ABCMeta, abstractmethod
@@ -24,18 +23,11 @@
 import pandas as pd
 from numpy.random import default_rng
 
-try:
-    from tqdm.auto import tqdm as _tqdm
-    _tqdm = partial(_tqdm, leave=False)
-except ImportError:
-    def _tqdm(seq, **_):
-        return seq
-
 from ._plotting import plot  # noqa: I001
 from ._stats import compute_stats
 from ._util import (
-    SharedMemory, SharedMemoryManager, _as_str, _Indicator, _Data, _indicator_warmup_nbars,
-    _strategy_indicators, patch, try_,
+    SharedMemoryManager, _as_str, _Indicator, _Data, _batch, _indicator_warmup_nbars,
+    _strategy_indicators, patch, try_, _tqdm,
 )
 
 __pdoc__ = {
@@ -1507,36 +1499,14 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]:
                                     [p.values() for p in param_combos],
                                     names=next(iter(param_combos)).keys()))
 
-            def _batch(seq):
-                # XXX: Replace with itertools.batched
-                n = np.clip(int(len(seq) // (os.cpu_count() or 1)), 1, 300)
-                for i in range(0, len(seq), n):
-                    yield seq[i:i + n]
-
             with mp.Pool() as pool, \
                     SharedMemoryManager() as smm:
-
-                shm_refs = []  # https://stackoverflow.com/questions/74193377/filenotfounderror-when-passing-a-shared-memory-to-a-new-process#comment130999060_74194875  # noqa: E501
-
-                def arr2shm(vals):
-                    nonlocal smm
-                    shm = smm.SharedMemory(size=vals.nbytes)
-                    buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
-                    buf[:] = vals[:]  # Copy into shared memory
-                    assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
-                    shm_refs.append(shm)
-                    return shm.name, vals.shape, vals.dtype
-
-                data_shm = tuple((
-                    (column, *arr2shm(values))
-                    for column, values in chain([(Backtest._mp_task_INDEX_COL, self._data.index)],
-                                                self._data.items())
-                ))
                 with patch(self, '_data', None):
                     bt = copy(self)  # bt._data will be reassigned in _mp_task worker
                 results = _tqdm(
                     pool.imap(Backtest._mp_task,
-                              ((bt, data_shm, params_batch)
+                              ((bt, smm.df2shm(self._data), params_batch)
                                for params_batch in _batch(param_combos))),
                     total=len(param_combos),
                     desc='Backtest.optimize'
@@ -1640,27 +1610,15 @@ def cons(x):
     @staticmethod
     def _mp_task(arg):
         bt, data_shm, params_batch = arg
-        shm = [SharedMemory(name=shm_name, create=False, track=False)
-               for _, shm_name, *_ in data_shm]
+        bt._data, shm = SharedMemoryManager.shm2df(data_shm)
         try:
-            def shm2arr(shm, shape, dtype):
-                arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
-                arr.setflags(write=False)
-                return arr
-
-            bt._data = df = pd.DataFrame({
-                col: shm2arr(shm, shape, dtype)
-                for shm, (col, _, shape, dtype) in zip(shm, data_shm)})
-            df.set_index(Backtest._mp_task_INDEX_COL, drop=True, inplace=True)
             return [stats.filter(regex='^[^_]') if stats['# Trades'] else None
                     for stats in (bt.run(**params)
                                   for params in params_batch)]
         finally:
             for shmem in shm:
                 shmem.close()
 
-    _mp_task_INDEX_COL = '__bt_index'
-
     def plot(self, *, results: pd.Series = None, filename=None, plot_width=None,
              plot_equity=True, plot_return=False, plot_pl=True,
              plot_volume=True, plot_drawdown=False, plot_trades=True,
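
For a sense of the chunking above, a small illustration of the `_batch` helper the hunk now imports from the private `backtesting._util` module (the 4-core figure is assumed for the example):

    from backtesting._util import _batch

    # n = clip(len(seq) // cpu_count, 1, 300); with 10 items on 4 cores, n == 2
    print(list(_batch(list(range(10)))))
    # -> [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], i.e. 5 pickled tasks, not 10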
88 changes: 85 additions & 3 deletions backtesting/lib.py
@@ -13,9 +13,10 @@
 
 from __future__ import annotations
 
+import multiprocessing as mp
 from collections import OrderedDict
 from inspect import currentframe
-from itertools import compress
+from itertools import chain, compress, count
 from numbers import Number
 from typing import Callable, Generator, Optional, Sequence, Union
 
@@ -24,7 +25,7 @@
 
 from ._plotting import plot_heatmaps as _plot_heatmaps
 from ._stats import compute_stats as _compute_stats
-from ._util import _Array, _as_str
+from ._util import SharedMemoryManager, _Array, _as_str, _batch, _tqdm
 from .backtesting import Backtest, Strategy
 
 __pdoc__ = {}
@@ -474,11 +475,24 @@ def set_atr_periods(self, periods: int = 100):
 
     def set_trailing_sl(self, n_atr: float = 6):
         """
-        Sets the future trailing stop-loss as some multiple (`n_atr`)
+        Set the future trailing stop-loss as some multiple (`n_atr`)
         average true bar ranges away from the current price.
         """
         self.__n_atr = n_atr
 
+    def set_trailing_pct(self, pct: float = .05):
+        """
+        Set the future trailing stop-loss as some percent (`0 < pct < 1`)
+        below the current price (default 5% below).
+
+        .. note:: Stop-loss set by `pct` is inexact
+            Stop-loss set by `set_trailing_pct` is converted to units of ATR
+            with `mean(Close * pct / atr)` and set with `set_trailing_sl`.
+        """
+        assert 0 < pct < 1, 'Need pct= as rate, i.e. 5% == 0.05'
+        pct_in_atr = np.mean(self.data.Close * pct / self.__atr)  # type: ignore
+        self.set_trailing_sl(pct_in_atr)
+
     def next(self):
         super().next()
         # Can't use index=-1 because self.__atr is not an Indicator type
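
To make the note concrete, a rough worked example of the `mean(Close * pct / atr)` conversion (price and ATR values are invented for illustration):

    import numpy as np

    close = np.full(100, 100.0)  # price around 100
    atr = np.full(100, 2.0)      # ATR around 2
    n_atr = np.mean(close * .05 / atr)  # == 2.5
    # set_trailing_sl(2.5) then trails ~2.5 ATRs, i.e. roughly 5% below price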
@@ -522,6 +536,74 @@ def __init__(self,
         __pdoc__[f'{cls.__name__}.__init__'] = False
 
 
+class MultiBacktest:
+    """
+    Multi-dataset `backtesting.backtesting.Backtest` wrapper.
+
+    Run supplied `backtesting.backtesting.Strategy` on several instruments,
+    in parallel. Used for comparing strategy runs across many instruments
+    or classes of instruments. Example:
+
+        from backtesting.test import EURUSD, BTCUSD, SmaCross
+        btm = MultiBacktest([EURUSD, BTCUSD], SmaCross)
+        stats_per_ticker: pd.DataFrame = btm.run(fast=10, slow=20)
+        heatmap_per_ticker: pd.DataFrame = btm.optimize(...)
+    """
+    def __init__(self, df_list, strategy_cls, **kwargs):
+        self._dfs = df_list
+        self._strategy = strategy_cls
+        self._bt_kwargs = kwargs
+
+    def run(self, **kwargs):
+        """
+        Wraps `backtesting.backtesting.Backtest.run`. Returns a `pd.DataFrame`
+        with one column of stats per instrument.
+        """
+        with mp.Pool() as pool, \
+                SharedMemoryManager() as smm:
+            shm = [smm.df2shm(df) for df in self._dfs]
+            results = _tqdm(
+                pool.imap(self._mp_task_run,
+                          ((df_batch, self._strategy, self._bt_kwargs, kwargs)
+                           for df_batch in _batch(shm))),
+                total=len(shm),
+                desc=self.__class__.__name__,
+            )
+            df = pd.DataFrame(list(chain(*results))).transpose()
+        return df
+
+    @staticmethod
+    def _mp_task_run(args):
+        data_shm, strategy, bt_kwargs, run_kwargs = args
+        dfs, shms = zip(*(SharedMemoryManager.shm2df(i) for i in data_shm))
+        try:
+            return [stats.filter(regex='^[^_]') if stats['# Trades'] else None
+                    for stats in (Backtest(df, strategy, **bt_kwargs).run(**run_kwargs)
+                                  for df in dfs)]
+        finally:
+            for shmem in chain(*shms):
+                shmem.close()
+
+    def optimize(self, **kwargs) -> pd.DataFrame:
+        """
+        Wraps `backtesting.backtesting.Backtest.optimize`, but returns a
+        `pd.DataFrame` with one heatmap column per instrument.
+
+            heatmap: pd.DataFrame = btm.optimize(...)
+            from backtesting.lib import plot_heatmaps
+            plot_heatmaps(heatmap.mean(axis=1))
+        """
+        heatmaps = []
+        # Simple loop since bt.optimize already does its own multiprocessing
+        for df in _tqdm(self._dfs, desc=self.__class__.__name__):
+            bt = Backtest(df, self._strategy, **self._bt_kwargs)
+            _best_stats, heatmap = bt.optimize(  # type: ignore
+                return_heatmap=True, return_optimization=False, **kwargs)
+            heatmaps.append(heatmap)
+        heatmap = pd.DataFrame(dict(zip(count(), heatmaps)))
+        return heatmap
+
+
 # NOTE: Don't put anything below this __all__ list
 
 __all__ = [getattr(v, '__name__', k)
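
Putting the new class together, a usage sketch (the cash/commission values and parameter ranges are invented for illustration; `EURUSD`, `BTCUSD`, and `SmaCross` ship in `backtesting.test` per the docstring above):

    from backtesting.lib import MultiBacktest
    from backtesting.test import BTCUSD, EURUSD, SmaCross

    btm = MultiBacktest([EURUSD, BTCUSD], SmaCross, cash=10_000, commission=.0002)
    stats = btm.run(fast=10, slow=20)  # one column of stats per dataset
    print(stats.loc['Return [%]'])
    heatmap = btm.optimize(fast=[5, 10, 15], slow=[20, 30, 40])
    print(heatmap.mean(axis=1).sort_values())  # params that do best on average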