From eb15c77263ecdd4aadd2856170bd8448e2abf8e5 Mon Sep 17 00:00:00 2001 From: mjvakili Date: Wed, 24 Sep 2025 11:16:34 +0200 Subject: [PATCH] add building blocks of american tf system --- pyproject.toml | 13 + tests/test_american_tf.py | 663 +++++++++++++++++++++++++ tests/test_filters.py | 138 +++-- tests/test_volatility.py | 288 +++++++---- trendfollower/__init__.py | 3 + trendfollower/batch/__init__.py | 0 trendfollower/core/__init__.py | 0 trendfollower/core/american_tf.py | 321 ++++++++++++ trendfollower/core/filters.py | 185 +++++++ trendfollower/{ => core}/volatility.py | 2 +- trendfollower/filters.py | 105 ---- trendfollower/streaming/__init__.py | 0 12 files changed, 1473 insertions(+), 245 deletions(-) create mode 100644 tests/test_american_tf.py create mode 100644 trendfollower/batch/__init__.py create mode 100644 trendfollower/core/__init__.py create mode 100644 trendfollower/core/american_tf.py create mode 100644 trendfollower/core/filters.py rename trendfollower/{ => core}/volatility.py (95%) delete mode 100644 trendfollower/filters.py create mode 100644 trendfollower/streaming/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 4e992f3..7ffcc33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,3 +14,16 @@ tests = [ "pytest", "pytest-cov", ] + +docs = [ + "scipy", + "pdoc3", + "matplotlib", + "plotly", + "seaborn", + "nbformat==5.4.0", + "ipython", + "notebook", + "ipywidgets", + "yfinance" +] diff --git a/tests/test_american_tf.py b/tests/test_american_tf.py new file mode 100644 index 0000000..ee3a116 --- /dev/null +++ b/tests/test_american_tf.py @@ -0,0 +1,663 @@ +"""Tests for American Trend Following System.""" + +import pytest +from pydantic import ValidationError +from unittest import mock + +from trendfollower.core.american_tf import ( + AmericanTFState, + AmericanTrendFollower, + MarketData, + StrategyParams, + handle_long_position, + handle_neutral_position, + handle_short_position, +) + + +class TestMarketData: + """Test cases for MarketData model.""" + + def test_valid_market_data(self): + """Test creation of valid MarketData.""" + market = MarketData(price=100.0, fast_ema=102.0, slow_ema=98.0, atr=2.0) + assert market.price == 100.0 + assert market.fast_ema == 102.0 + assert market.slow_ema == 98.0 + assert market.atr == 2.0 + + def test_market_data_immutable(self): + """Test that MarketData is immutable.""" + market = MarketData(price=100.0, fast_ema=102.0, slow_ema=98.0, atr=2.0) + with pytest.raises(ValidationError): + market.price = 150.0 + + @pytest.mark.parametrize("field,value", [ + ("price", 0.0), + ("price", -10.0), + ("fast_ema", 0.0), + ("fast_ema", -5.0), + ("slow_ema", 0.0), + ("slow_ema", -3.0), + ("atr", 0.0), + ("atr", -1.0), + ]) + def test_market_data_validation_errors(self, field, value): + """Test MarketData validation for non-positive values.""" + data = {"price": 100.0, "fast_ema": 102.0, "slow_ema": 98.0, "atr": 2.0} + data[field] = value + with pytest.raises(ValidationError): + MarketData(**data) + + +class TestStrategyParams: + """Test cases for StrategyParams model.""" + + def test_valid_strategy_params(self): + """Test creation of valid StrategyParams.""" + params = StrategyParams(entry_point_width=0.5, stop_loss_width=2.0, risk_multiple=0.02) + assert params.entry_point_width == 0.5 + assert params.stop_loss_width == 2.0 + assert params.risk_multiple == 0.02 + + def test_strategy_params_immutable(self): + """Test that StrategyParams is immutable.""" + params = StrategyParams(entry_point_width=0.5, stop_loss_width=2.0, risk_multiple=0.02) + with pytest.raises(ValidationError): + params.entry_point_width = 1.0 + + @pytest.mark.parametrize("field,value", [ + ("entry_point_width", 0.0), + ("entry_point_width", -0.1), + ("stop_loss_width", 0.0), + ("stop_loss_width", -1.0), + ("risk_multiple", 0.0), + ("risk_multiple", -0.01), + ]) + def test_strategy_params_validation_errors(self, field, value): + """Test StrategyParams validation for non-positive values.""" + data = {"entry_point_width": 0.5, "stop_loss_width": 2.0, "risk_multiple": 0.02} + data[field] = value + with pytest.raises(ValidationError): + StrategyParams(**data) + + +class TestAmericanTFState: + """Test cases for AmericanTFState model.""" + + def test_valid_neutral_state(self): + """Test creation of valid neutral state.""" + state = AmericanTFState() + assert state.position == 0 + assert state.position_size == 0 + assert state.stop_loss is None + assert state.entry_price is None + assert state.close_next_day is False + + def test_valid_long_state(self): + """Test creation of valid long state.""" + state = AmericanTFState( + position=1, position_size=10.0, stop_loss=95.0, entry_price=100.0 + ) + assert state.position == 1 + assert state.position_size == 10.0 + assert state.stop_loss == 95.0 + assert state.entry_price == 100.0 + + def test_valid_short_state(self): + """Test creation of valid short state.""" + state = AmericanTFState( + position=-1, position_size=-10.0, stop_loss=105.0, entry_price=100.0 + ) + assert state.position == -1 + assert state.position_size == -10.0 + assert state.stop_loss == 105.0 + assert state.entry_price == 100.0 + + def test_state_immutable(self): + """Test that AmericanTFState is immutable.""" + state = AmericanTFState() + with pytest.raises(ValidationError): + state.position = 1 + + @pytest.mark.parametrize("position", [-2, 2, 5]) + def test_invalid_position_values(self, position): + """Test validation error for invalid position values.""" + with pytest.raises(ValidationError): + AmericanTFState(position=position) + + def test_invalid_neutral_position_with_position_size(self): + """Test validation error for neutral position with non-zero position_size.""" + with pytest.raises(ValidationError): + AmericanTFState(position=0, position_size=10.0) + + def test_invalid_neutral_position_with_stop_loss(self): + """Test validation error for neutral position with stop_loss.""" + with pytest.raises(ValidationError): + AmericanTFState(position=0, stop_loss=95.0) + + def test_invalid_neutral_position_with_entry_price(self): + """Test validation error for neutral position with entry_price.""" + with pytest.raises(ValidationError): + AmericanTFState(position=0, entry_price=100.0) + + def test_invalid_long_position_negative_size(self): + """Test validation error for long position with negative position_size.""" + with pytest.raises(ValidationError): + AmericanTFState(position=1, position_size=-10.0, stop_loss=95.0, entry_price=100.0) + + def test_invalid_long_position_no_stop_loss(self): + """Test validation error for long position without stop_loss.""" + with pytest.raises(ValidationError): + AmericanTFState(position=1, position_size=10.0, entry_price=100.0) + + def test_invalid_long_position_no_entry_price(self): + """Test validation error for long position without entry_price.""" + with pytest.raises(ValidationError): + AmericanTFState(position=1, position_size=10.0, stop_loss=95.0) + + def test_invalid_short_position_positive_size(self): + """Test validation error for short position with positive position_size.""" + with pytest.raises(ValidationError): + AmericanTFState(position=-1, position_size=10.0, stop_loss=105.0, entry_price=100.0) + + def test_invalid_short_position_no_stop_loss(self): + """Test validation error for short position without stop_loss.""" + with pytest.raises(ValidationError): + AmericanTFState(position=-1, position_size=-10.0, entry_price=100.0) + + def test_invalid_short_position_no_entry_price(self): + """Test validation error for short position without entry_price.""" + with pytest.raises(ValidationError): + AmericanTFState(position=-1, position_size=-10.0, stop_loss=105.0) + + def test_negative_stop_loss(self): + """Test validation error for negative stop_loss.""" + with pytest.raises(ValidationError): + AmericanTFState(position=1, position_size=10.0, stop_loss=-5.0, entry_price=100.0) + + def test_negative_entry_price(self): + """Test validation error for negative entry_price.""" + with pytest.raises(ValidationError): + AmericanTFState(position=1, position_size=10.0, stop_loss=95.0, entry_price=-100.0) + + +class TestHandleNeutralPosition: + """Test cases for handle_neutral_position function.""" + + def test_long_signal_entry(self): + """Test long entry when fast EMA > slow EMA + signal buffer.""" + market = MarketData(price=100.0, fast_ema=105.0, slow_ema=100.0, atr=2.0) + params = StrategyParams(entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.02) + + result = handle_neutral_position(market, params) + + assert result.position == 1 + assert result.position_size == 1.0 # 0.02 * 100 / 2 + assert result.stop_loss == 94.0 # 100 - 3 * 2 + assert result.entry_price == 100.0 + assert result.close_next_day is False + + def test_short_signal_entry(self): + """Test short entry when fast EMA < slow EMA - signal buffer.""" + market = MarketData(price=100.0, fast_ema=95.0, slow_ema=100.0, atr=2.0) + params = StrategyParams(entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.02) + + result = handle_neutral_position(market, params) + + assert result.position == -1 + assert result.position_size == -1.0 # -0.02 * 100 / 2 + assert result.stop_loss == 106.0 # 100 + 3 * 2 + assert result.entry_price == 100.0 + assert result.close_next_day is False + + def test_no_signal_remain_neutral(self): + """Test remaining neutral when no clear signal.""" + market = MarketData(price=100.0, fast_ema=101.0, slow_ema=100.0, atr=2.0) + params = StrategyParams(entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.02) + + result = handle_neutral_position(market, params) + + assert result.position == 0 + assert result.position_size == 0.0 + assert result.stop_loss is None + assert result.entry_price is None + assert result.close_next_day is False + + def test_boundary_condition_long(self): + """Test boundary condition for long signal (exactly at threshold).""" + market = MarketData(price=100.0, fast_ema=104.0, slow_ema=100.0, atr=2.0) + params = StrategyParams(entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.02) + + result = handle_neutral_position(market, params) + + # fast_ema (104) == slow_ema (100) + signal_buffer (4), so no entry + assert result.position == 0 + + def test_boundary_condition_short(self): + """Test boundary condition for short signal (exactly at threshold).""" + market = MarketData(price=100.0, fast_ema=96.0, slow_ema=100.0, atr=2.0) + params = StrategyParams(entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.02) + + result = handle_neutral_position(market, params) + + # fast_ema (96) == slow_ema (100) - signal_buffer (4), so no entry + assert result.position == 0 + + +class TestHandleLongPosition: + """Test cases for handle_long_position function.""" + + @pytest.fixture + def long_state(self): + """Fixture for a typical long position state.""" + return AmericanTFState( + position=1, position_size=1.0, stop_loss=95.0, entry_price=100.0 + ) + + @pytest.fixture + def market_data(self): + """Fixture for typical market data.""" + return MarketData(price=98.0, fast_ema=105.0, slow_ema=100.0, atr=2.0) + + @pytest.fixture + def strategy_params(self): + """Fixture for typical strategy parameters.""" + return StrategyParams(entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.02) + + def test_close_next_day_flag(self, market_data, strategy_params): + """Test that position is closed when close_next_day flag is set.""" + state = AmericanTFState( + position=1, position_size=1.0, stop_loss=95.0, entry_price=100.0, close_next_day=True + ) + + result = handle_long_position(market_data, strategy_params, state) + + assert result.position == 0 + assert result.position_size == 0.0 + assert result.stop_loss is None + assert result.entry_price is None + assert result.close_next_day is False + + def test_stop_loss_breach_signal_off(self, strategy_params, long_state): + """Test stop loss breach when signal is off (immediate close).""" + market = MarketData(price=94.0, fast_ema=100.0, slow_ema=100.0, atr=2.0) # Signal off + + result = handle_long_position(market, strategy_params, long_state) + + assert result.position == 0 + assert result.position_size == 0.0 + assert result.stop_loss is None + assert result.entry_price is None + assert result.close_next_day is False + + def test_stop_loss_breach_signal_on(self, strategy_params, long_state): + """Test stop loss breach when signal is still on (mark for closure).""" + market = MarketData(price=94.0, fast_ema=105.0, slow_ema=100.0, atr=2.0) # Signal on + + result = handle_long_position(market, strategy_params, long_state) + + assert result.position == 1 # Position unchanged + assert result.position_size == 1.0 # Position size unchanged + assert result.stop_loss == 95.0 # Stop loss unchanged + assert result.entry_price == 100.0 # Entry price unchanged + assert result.close_next_day is True # Marked for closure next day + + def test_trailing_stop_loss_update(self, strategy_params, long_state): + """Test trailing stop loss update when signal is on.""" + market = MarketData(price=102.0, fast_ema=105.0, slow_ema=100.0, atr=2.0) + + result = handle_long_position(market, strategy_params, long_state) + + expected_new_stop = max(95.0, 102.0 - 3.0 * 2.0) # max(95, 96) = 96 + assert result.position == 1 + assert result.position_size == 1.0 + assert result.stop_loss == 96.0 + assert result.entry_price == 100.0 + assert result.close_next_day is False + + def test_stop_loss_no_update_when_lower(self, strategy_params, long_state): + """Test that stop loss doesn't move down.""" + market = MarketData(price=97.0, fast_ema=105.0, slow_ema=100.0, atr=2.0) + + result = handle_long_position(market, strategy_params, long_state) + + expected_new_stop = max(95.0, 97.0 - 3.0 * 2.0) # max(95, 91) = 95 + assert result.stop_loss == 95.0 # Unchanged + + def test_invalid_state_no_stop_loss(self, market_data, strategy_params): + """Test error when long position has no stop loss.""" + # We can't create an invalid state due to Pydantic validation, + # but we can test the handler's explicit check by directly creating a mock + + # Create a mock state that mimics invalid data + mock_state = mock.Mock(spec=AmericanTFState) + mock_state.position = 1 + mock_state.position_size = 1.0 + mock_state.stop_loss = None + mock_state.entry_price = 100.0 + + with pytest.raises(ValueError, match="stop_loss must be set for open positions"): + handle_long_position(market_data, strategy_params, mock_state) + + +class TestHandleShortPosition: + """Test cases for handle_short_position function.""" + + @pytest.fixture + def short_state(self): + """Fixture for a typical short position state.""" + return AmericanTFState( + position=-1, position_size=-1.0, stop_loss=105.0, entry_price=100.0 + ) + + @pytest.fixture + def market_data(self): + """Fixture for typical market data.""" + return MarketData(price=102.0, fast_ema=95.0, slow_ema=100.0, atr=2.0) + + @pytest.fixture + def strategy_params(self): + """Fixture for typical strategy parameters.""" + return StrategyParams(entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.02) + + def test_close_next_day_flag(self, market_data, strategy_params): + """Test that position is closed when close_next_day flag is set.""" + state = AmericanTFState( + position=-1, position_size=-1.0, stop_loss=105.0, entry_price=100.0, close_next_day=True + ) + + result = handle_short_position(market_data, strategy_params, state) + + assert result.position == 0 + assert result.position_size == 0 + assert result.stop_loss is None + assert result.entry_price is None + assert result.close_next_day is False + + def test_stop_loss_breach_signal_off(self, strategy_params, short_state): + """Test stop loss breach when signal is off (immediate close).""" + market = MarketData(price=106.0, fast_ema=100.0, slow_ema=100.0, atr=2.0) # Signal off + + result = handle_short_position(market, strategy_params, short_state) + + assert result.position == 0 + assert result.position_size == 0.0 + assert result.stop_loss is None + assert result.entry_price is None + assert result.close_next_day is False + + def test_stop_loss_breach_signal_on(self, strategy_params, short_state): + """Test stop loss breach when signal is still on (mark for closure).""" + market = MarketData(price=106.0, fast_ema=95.0, slow_ema=100.0, atr=2.0) # Signal on + + result = handle_short_position(market, strategy_params, short_state) + + assert result.position == -1 # Position unchanged + assert result.position_size == -1.0 # Position size unchanged + assert result.stop_loss == 105.0 # Stop loss unchanged + assert result.entry_price == 100.0 # Entry price unchanged + assert result.close_next_day is True # Marked for closure + + def test_trailing_stop_loss_update(self, strategy_params, short_state): + """Test trailing stop loss update when price moves favorably.""" + market = MarketData(price=98.0, fast_ema=95.0, slow_ema=100.0, atr=2.0) + + result = handle_short_position(market, strategy_params, short_state) + + expected_new_stop = min(105.0, 98.0 + 3.0 * 2.0) # min(105, 104) = 104 + assert result.position == -1 + assert result.position_size == -1.0 + assert result.stop_loss == 104.0 + assert result.entry_price == 100.0 + assert result.close_next_day is False + + def test_stop_loss_no_update_when_higher(self, strategy_params, short_state): + """Test that stop loss doesn't move up.""" + market = MarketData(price=103.0, fast_ema=95.0, slow_ema=100.0, atr=2.0) + + result = handle_short_position(market, strategy_params, short_state) + + expected_new_stop = min(105.0, 103.0 + 3.0 * 2.0) # min(105, 109) = 105 + assert result.stop_loss == 105.0 # Unchanged + + def test_invalid_state_no_stop_loss(self, market_data, strategy_params): + """Test error when short position has no stop loss.""" + # Create a mock state that mimics invalid data + mock_state = mock.Mock(spec=AmericanTFState) + mock_state.position = -1 + mock_state.position_size = -1.0 + mock_state.stop_loss = None # This is what we want to test + mock_state.entry_price = 100.0 + + with pytest.raises(ValueError, match="stop_loss must be set for open positions"): + handle_short_position(market_data, strategy_params, mock_state) + + +class TestAmericanTrendFollower: + """Test cases for AmericanTrendFollower class.""" + + @pytest.fixture + def strategy(self): + """Fixture for AmericanTrendFollower instance.""" + return AmericanTrendFollower( + entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.02 + ) + + def test_initialization(self, strategy): + """Test proper initialization of AmericanTrendFollower.""" + assert strategy.params.entry_point_width == 2.0 + assert strategy.params.stop_loss_width == 3.0 + assert strategy.params.risk_multiple == 0.02 + assert len(strategy.handlers) == 3 + assert 0 in strategy.handlers + assert 1 in strategy.handlers + assert -1 in strategy.handlers + + def test_initialization_validation_errors(self): + """Test validation errors during initialization.""" + with pytest.raises(ValidationError): + AmericanTrendFollower(entry_point_width=0.0, stop_loss_width=3.0, risk_multiple=0.02) + + with pytest.raises(ValidationError): + AmericanTrendFollower(entry_point_width=2.0, stop_loss_width=0.0, risk_multiple=0.02) + + with pytest.raises(ValidationError): + AmericanTrendFollower(entry_point_width=2.0, stop_loss_width=3.0, risk_multiple=0.0) + + def test_update_state_from_neutral_to_long(self, strategy): + """Test state update from neutral to long position.""" + initial_state = AmericanTFState() + + result = strategy.update_state( + s=100.0, s_fast=105.0, s_slow=100.0, atr=2.0, state=initial_state + ) + + assert result.position == 1 + assert result.position_size == 1.0 + assert result.stop_loss == 94.0 + assert result.entry_price == 100.0 + + def test_update_state_from_neutral_to_short(self, strategy): + """Test state update from neutral to short position.""" + initial_state = AmericanTFState() + + result = strategy.update_state( + s=100.0, s_fast=95.0, s_slow=100.0, atr=2.0, state=initial_state + ) + + assert result.position == -1 + assert result.position_size == -1.0 + assert result.stop_loss == 106.0 + assert result.entry_price == 100.0 + + def test_update_state_remain_neutral(self, strategy): + """Test state update remaining neutral.""" + initial_state = AmericanTFState() + + result = strategy.update_state( + s=100.0, s_fast=101.0, s_slow=100.0, atr=2.0, state=initial_state + ) + + assert result.position == 0 + assert result.position_size == 0.0 + assert result.stop_loss is None + assert result.entry_price is None + + def test_update_state_long_position_trailing_stop(self, strategy): + """Test state update for long position with trailing stop.""" + initial_state = AmericanTFState( + position=1, position_size=1.0, stop_loss=95.0, entry_price=100.0 + ) + + result = strategy.update_state( + s=102.0, s_fast=105.0, s_slow=100.0, atr=2.0, state=initial_state + ) + + assert result.position == 1 + assert result.stop_loss == 96.0 # Updated trailing stop + + def test_update_state_short_position_trailing_stop(self, strategy): + """Test state update for short position with trailing stop.""" + initial_state = AmericanTFState( + position=-1, position_size=-1.0, stop_loss=105.0, entry_price=100.0 + ) + + result = strategy.update_state( + s=98.0, s_fast=95.0, s_slow=100.0, atr=2.0, state=initial_state + ) + + assert result.position == -1 + assert result.stop_loss == 104.0 # Updated trailing stop + + def test_invalid_market_data_validation(self, strategy): + """Test that invalid market data raises validation error.""" + initial_state = AmericanTFState() + + with pytest.raises(ValidationError): + strategy.update_state( + s=0.0, s_fast=105.0, s_slow=100.0, atr=2.0, state=initial_state + ) + + with pytest.raises(ValidationError): + strategy.update_state( + s=100.0, s_fast=0.0, s_slow=100.0, atr=2.0, state=initial_state + ) + + with pytest.raises(ValidationError): + strategy.update_state( + s=100.0, s_fast=105.0, s_slow=0.0, atr=2.0, state=initial_state + ) + + with pytest.raises(ValidationError): + strategy.update_state( + s=100.0, s_fast=105.0, s_slow=100.0, atr=0.0, state=initial_state + ) + + +class TestEndtoEndScenarios: + """Test end-to-end scenarios.""" + + @pytest.fixture + def strategy(self): + """Fixture for AmericanTrendFollower instance.""" + return AmericanTrendFollower( + entry_point_width=1.0, stop_loss_width=2.0, risk_multiple=0.01 + ) + + def test_complete_trading_cycle_long(self, strategy): + """Test a complete long trading cycle: entry -> trail -> exit.""" + # Start neutral + state = AmericanTFState() + + # 1. Enter long position + state = strategy.update_state( + s=100.0, s_fast=102.0, s_slow=100.0, atr=1.0, state=state + ) + assert state.position == 1 + assert state.stop_loss == 98.0 # 100 - 2*1 + + # 2. Price moves up, trail stop loss + state = strategy.update_state( + s=105.0, s_fast=107.0, s_slow=100.0, atr=1.0, state=state + ) + assert state.position == 1 + assert state.stop_loss == 103.0 # max(98, 105-2) = 103 + + # 3. Stop loss breach with signal off - exit + state = strategy.update_state( + s=102.0, s_fast=100.0, s_slow=100.0, atr=1.0, state=state + ) + assert state.position == 0 + assert state.stop_loss is None + + def test_complete_trading_cycle_short(self, strategy): + """Test a complete short trading cycle: entry -> trail -> exit.""" + # Start neutral + state = AmericanTFState() + + # 1. Enter short position + state = strategy.update_state( + s=100.0, s_fast=98.0, s_slow=100.0, atr=1.0, state=state + ) + assert state.position == -1 + assert state.stop_loss == 102.0 # 100 + 2*1 + + # 2. Price moves down, trail stop loss + state = strategy.update_state( + s=95.0, s_fast=93.0, s_slow=100.0, atr=1.0, state=state + ) + assert state.position == -1 + assert state.stop_loss == 97.0 # min(102, 95+2) = 97 + + # 3. Stop loss breach with signal off - exit + state = strategy.update_state( + s=98.0, s_fast=100.0, s_slow=100.0, atr=1.0, state=state + ) + assert state.position == 0 + assert state.stop_loss is None + + def test_stop_loss_breach_with_signal_persistence(self, strategy): + """Test stop loss breach when signal persists (close_next_day scenario).""" + # Start with long position + state = AmericanTFState( + position=1, position_size=1.0, stop_loss=95.0, entry_price=100.0 + ) + + # Stop loss breach but signal still on + state = strategy.update_state( + s=94.0, s_fast=105.0, s_slow=100.0, atr=1.0, state=state + ) + assert state.position == 1 # Still long + assert state.close_next_day is True # Marked for closure + + # Next day - position should be closed + state = strategy.update_state( + s=94.0, s_fast=105.0, s_slow=100.0, atr=1.0, state=state + ) + assert state.position == 0 # Now closed + assert state.close_next_day is False + + def test_very_small_atr_edge_case(self, strategy): + """Test behavior with very small ATR values.""" + state = AmericanTFState() + + # Very small ATR + result = strategy.update_state( + s=100.0, s_fast=102.0, s_slow=100.0, atr=0.01, state=state + ) + + assert result.position == 1 + assert result.position_size == 100.0 # 0.01 * 100 / 0.01 + assert result.stop_loss == 99.98 # 100 - 2 * 0.01 + + def test_high_volatility_edge_case(self, strategy): + """Test behavior with high ATR values.""" + state = AmericanTFState() + result = strategy.update_state( + s=100.0, s_fast=111.0, s_slow=100.0, atr=10.0, state=state + ) + + assert result.position == 1 + assert result.position_size == 0.1 # 0.01 * 100 / 10 + assert result.stop_loss == 80.0 # 100 - 2 * 10 diff --git a/tests/test_filters.py b/tests/test_filters.py index 6776b09..26f3a83 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,15 +1,27 @@ +""" +Test suite for filter functions in trendfollower.core.filters. + +This module tests the exponentially weighted moving average (EWMA) filters, +including standard EWMA, variance-preserving EWMA, and long-short variance-preserving EWMA. +""" + from itertools import combinations, product import numpy as np import polars as pl import pytest -from trendfollower.filters import ewma, variance_preserving_ewma, long_short_variance_preserving_ewma +from trendfollower.core.filters import ( + ewma, + variance_preserving_ewma, + long_short_variance_preserving_ewma +) +# Test Configuration Constants RNG = np.random.default_rng(seed=42) -SMOOTHING_PARS = 0.1 + 0.1*np.arange(8) -VARS = [0.1, 1.0, 2.0] +SMOOTHING_PARS = 0.1 + 0.1 * np.arange(8) # [0.1, 0.2, ..., 0.8] +VARS = [0.1, 1.0, 2.0] # Different variance levels for testing def sample_long_series_with_input_variance(var: float) -> pl.Series: @@ -24,7 +36,6 @@ def sample_long_series_with_input_variance(var: float) -> pl.Series: ------- pl.Series A sample long series with the specified variance. - """ samples = RNG.normal(loc=0, scale=np.sqrt(var), size=100000) return pl.Series(samples) @@ -32,53 +43,82 @@ def sample_long_series_with_input_variance(var: float) -> pl.Series: @pytest.fixture(scope="module", autouse=True) def sample_series_per_variance() -> dict[float, pl.Series]: + """Generate sample series for each variance level.""" return {var: sample_long_series_with_input_variance(var) for var in VARS} -@pytest.mark.parametrize("input_var, alpha", product(VARS, SMOOTHING_PARS)) -def test_ewma_filter_mean(sample_series_per_variance, input_var, alpha): - series = sample_series_per_variance[input_var] - filtered = ewma(series, alpha=alpha) - expected = series.mean() - calculated = filtered.mean() - msg = f"Under ewma transformation, mean remains unchanged, expected {expected}, got {calculated}" - assert calculated == pytest.approx(expected, rel=1e-1), msg - - -@pytest.mark.parametrize("input_var, alpha", product(VARS, SMOOTHING_PARS)) -def test_ewma_filter_variance(sample_series_per_variance, input_var, alpha): - series = sample_series_per_variance[input_var] - filtered = ewma(series, alpha=alpha) - expected = input_var * ((1 - alpha) / (1 + alpha)) - calculated = filtered.var() - msg = f"Variance of ewma of a series with input_variance var must be (1 - alpha) / (1 + alpha) * var, expected {expected}, got {calculated}" - assert calculated == pytest.approx(expected, rel=1e-1), msg - - -@pytest.mark.parametrize("input_var, alpha", product(VARS, SMOOTHING_PARS)) -def test_variance_preserving_ewma(sample_series_per_variance, input_var, alpha): - series = sample_series_per_variance[input_var] - filtered = variance_preserving_ewma(series, alpha=alpha) - expected = input_var - calculated = filtered.var() - msg = f"Variance of variance_preserving_ewma of a series with input_variance var must be var, expected {expected}, got {calculated}" - assert calculated == pytest.approx(expected, rel=1e-1), msg - - -@pytest.mark.parametrize("input_var, alphas", product(VARS, combinations(SMOOTHING_PARS, 2))) -def test_variance_long_short_variance_preserving_ewma(sample_series_per_variance, input_var, alphas): - series = sample_series_per_variance[input_var] - filtered = long_short_variance_preserving_ewma(series, alpha1=alphas[0], alpha2=alphas[1]) - expected = input_var - calculated = filtered.var() - msg = f"Variance of long_short_variance_preserving_ewma of a series with input_variance var must be var, expected {expected}, got {calculated}" - assert calculated == pytest.approx(expected, rel=1e-1), msg - - -@pytest.mark.parametrize("alpha", SMOOTHING_PARS) -def test_trivial_long_short_variance_preserving_ewma(sample_series_per_variance, alpha): - series = sample_series_per_variance[0.1] - with pytest.raises(ValueError, match="alpha1 and alpha2 must be different. When they are equal, the long-short filter is ill-defined."): - long_short_variance_preserving_ewma(series, alpha1=alpha, alpha2=alpha) +class TestEWMAFilter: + """Test cases for the standard EWMA filter.""" + + @pytest.mark.parametrize("input_var, alpha", product(VARS, SMOOTHING_PARS)) + def test_mean_preservation(self, sample_series_per_variance, input_var, alpha): + """Test that EWMA preserves the mean of the input series.""" + series = sample_series_per_variance[input_var] + filtered = ewma(series, nu=alpha) + expected = series.mean() + calculated = filtered.mean() + + msg = f"EWMA should preserve mean: expected {expected}, got {calculated}" + assert calculated == pytest.approx(expected, rel=1e-1), msg + + @pytest.mark.parametrize("input_var, alpha", product(VARS, SMOOTHING_PARS)) + def test_variance_reduction(self, sample_series_per_variance, input_var, alpha): + """Test that EWMA reduces variance according to the theoretical formula.""" + series = sample_series_per_variance[input_var] + filtered = ewma(series, nu=alpha) + + # Theoretical variance for EWMA: var * (1 - alpha) / (1 + alpha) + expected = input_var * ((1 - alpha) / (1 + alpha)) + calculated = filtered.var() + + msg = (f"EWMA variance should be (1 - alpha) / (1 + alpha) * var: " + f"expected {expected}, got {calculated}") + assert calculated == pytest.approx(expected, rel=1e-1), msg + + +class TestVariancePreservingEWMA: + """Test cases for the variance-preserving EWMA filter.""" + + @pytest.mark.parametrize("input_var, alpha", product(VARS, SMOOTHING_PARS)) + def test_variance_preservation(self, sample_series_per_variance, input_var, alpha): + """Test that variance-preserving EWMA maintains input variance.""" + series = sample_series_per_variance[input_var] + filtered = variance_preserving_ewma(series, nu=alpha) + + expected = input_var + calculated = filtered.var() + + msg = (f"Variance-preserving EWMA should maintain input variance: " + f"expected {expected}, got {calculated}") + assert calculated == pytest.approx(expected, rel=1e-1), msg + + +class TestLongShortVariancePreservingEWMA: + """Test cases for the long-short variance-preserving EWMA filter.""" + + @pytest.mark.parametrize("input_var, alphas", product(VARS, combinations(SMOOTHING_PARS, 2))) + def test_variance_preservation(self, sample_series_per_variance, input_var, alphas): + """Test that long-short variance-preserving EWMA maintains input variance.""" + series = sample_series_per_variance[input_var] + filtered = long_short_variance_preserving_ewma( + series, nu1=alphas[0], nu2=alphas[1] + ) + + expected = input_var + calculated = filtered.var() + + msg = (f"Long-short variance-preserving EWMA should maintain input variance: " + f"expected {expected}, got {calculated}") + assert calculated == pytest.approx(expected, rel=1e-1), msg + + @pytest.mark.parametrize("alpha", SMOOTHING_PARS) + def test_equal_alphas_error(self, sample_series_per_variance, alpha): + """Test that equal alpha values raise a ValueError.""" + series = sample_series_per_variance[0.1] + + error_msg = ("nu1 and nu2 must be different. When they are equal, the long-short filter is ill-defined.") + + with pytest.raises(ValueError, match=error_msg): + long_short_variance_preserving_ewma(series, nu1=alpha, nu2=alpha) \ No newline at end of file diff --git a/tests/test_volatility.py b/tests/test_volatility.py index c8c2395..fe96f62 100644 --- a/tests/test_volatility.py +++ b/tests/test_volatility.py @@ -1,14 +1,13 @@ -from itertools import product -from matplotlib import axis +"""Tests suite for volatility functions in trendfollower.core.volatility.""" + + import numpy as np import polars as pl import pytest -from trendfollower.volatility import ( +from trendfollower.core.volatility import ( lag1_diff, relative_return, - sigma_t_price, - sigma_t_return, true_range, relative_true_range, ma_true_range, @@ -21,96 +20,205 @@ from .test_filters import SMOOTHING_PARS, VARS, sample_long_series_with_input_variance + +# Test Configuration Constants RNG = np.random.default_rng(seed=42) -PERIODS = [2, 10, 100] +PERIODS = [2, 10, 100] # Different periods for moving average tests + -def hlc_per_variance(var): +def hlc_per_variance(var: float) -> tuple[pl.Series, pl.Series, pl.Series]: + """Generate high, low, close price series with specified variance. + + Parameters + ---------- + var : float + The desired variance of the close price series. + + Returns + ------- + tuple[pl.Series, pl.Series, pl.Series] + High, low, and close price series. + """ close = sample_long_series_with_input_variance(var) high = pl.Series(RNG.uniform(1.01, 1.02, size=close.shape)) * close low = pl.Series(RNG.uniform(0.95, 0.96, size=close.shape)) * close return high, low, close -def test_lag1_diff_mean(): - _,_,close = hlc_per_variance(.1) - assert lag1_diff(close).mean() == pytest.approx(0, abs=1e-2) - - -@pytest.mark.parametrize("var", VARS) -def test_lag1_diff_std(var): - _,_,close = hlc_per_variance(var) - assert lag1_diff(close).std() == pytest.approx(np.sqrt(2 * var), abs=1e-2) - - -def test_logic_relative_return(): - s = pl.Series(np.arange(1, 1000)) - actual = relative_return(s).to_numpy()[1:] - desired = 1./np.linspace(1, 998, 998) - np.testing.assert_array_almost_equal(actual, desired, decimal=3) - - -def test_true_range(): - high, low, close = hlc_per_variance(0.1) - actual = true_range(high, low, close).to_numpy() - s1 = (high - low).abs().to_numpy() - s2 = (high - close.shift(1)).abs().to_numpy() - s3 = (low - close.shift(1)).abs().to_numpy() - - assert np.all(actual >= 0), "True range should be non-negative" - assert np.all(actual >= s1), "True range should be at least as large as high - low" - assert np.all(actual[1:] >= s2[1:]), "True range should be at least as large as the absolute value of high - close.shift(1)" - assert np.all(actual[1:] >= s3[1:]), "True range should be at least as large as the absolute value of low - close.shift(1)" - - -def test_relative_true_range(): - high, low, close = hlc_per_variance(0.1) - rtr = relative_true_range(high, low, close) - tr = true_range(high, low, close) - close_shift = close.shift(1) - expected = tr / close_shift - np.testing.assert_array_almost_equal(rtr.to_numpy(), expected.to_numpy(), decimal=3) - - -@pytest.mark.parametrize("period", PERIODS) -def test_ma_true_range(period): - high, low, close = hlc_per_variance(0.1) - # moving average true range from true range directly - tr = true_range(high, low, close) - ma_tr = ma_true_range(tr, period=period).to_numpy() - assert np.isfinite(ma_tr).sum() == len(ma_tr) - period + 1, f"MA True Range must have {len(ma_tr) - period + 1} finite values" - ma_tr[-1] = np.mean(tr.to_numpy()[-period:]) - ma_tr[-2] = np.mean(tr.to_numpy()[-(period+1):-1]) - # moving average true range from HLC - ma_tr = ma_true_range_from_hlc(high, low, close, period=period).to_numpy() - assert np.isfinite(ma_tr).sum() == len(ma_tr) - period + 1, f"MA True Range must have {len(ma_tr) - period + 1} finite values" - ma_tr[-1] = np.mean(ma_tr[-period:]) - ma_tr[-2] = np.mean(ma_tr[-(period+1):-1]) - # moving relative true range from relative true range directly - rtr = relative_true_range(high, low, close) - ma_rtr = ma_relative_true_range(rtr, period=period).to_numpy() - assert np.isfinite(ma_rtr).sum() == len(ma_rtr) - period, f"MA Relative True Range must have {len(ma_rtr) - period} finite values" - ma_rtr[-1] = np.mean(ma_rtr[-period:]) - ma_rtr[-2] = np.mean(ma_rtr[-(period+1):-1]) - # moving relative true range from HLC - ma_rtr = ma_relative_true_range_from_hlc(high, low, close, period=period).to_numpy() - assert np.isfinite(ma_rtr).sum() == len(ma_rtr) - period, f"MA Relative True Range must have {len(ma_rtr) - period} finite values" - ma_rtr[-1] = np.mean(ma_rtr[-period:]) - ma_rtr[-2] = np.mean(ma_rtr[-(period+1):-1]) - - -@pytest.mark.parametrize("alpha", SMOOTHING_PARS) -def test_ewma_relative_true_range(alpha): - high, low, close = hlc_per_variance(0.1) - rtr = relative_true_range(high=high, low=low, close=close) - var_rtr = rtr.var() - # ewma rtr directly from rtr - ewma_rtr = ewma_relative_true_range(rtr=rtr, alpha=alpha) - assert np.isfinite(ewma_rtr.to_numpy()).sum() == len(ewma_rtr) - 1, "All elements of EWMA Relative True Range, except the first element, must be finite" - var_ewma_rtr = ewma_rtr.var() - expected_var_ewma_rtr = (1 - alpha) / (1 + alpha) * var_rtr - assert var_ewma_rtr == pytest.approx(expected_var_ewma_rtr, rel=1e-2), f"Variance of EWMA Relative True Range must be approximately {(1 - alpha) / (1 + alpha)} times the variance of Relative True Range" - # ewma rtr from HLC - ewma_rtr_hlc = ewma_relative_true_range_from_hlc(high=high, low=low, close=close, alpha=alpha) - assert np.isfinite(ewma_rtr_hlc.to_numpy()).sum() == len(ewma_rtr_hlc) - 1, "All elements of EWMA Relative True Range from HLC, except the first element, must be finite" - var_ewma_rtr_hlc = ewma_rtr_hlc.var() - assert var_ewma_rtr_hlc == pytest.approx(expected_var_ewma_rtr, rel=1e-2), f"Variance of EWMA Relative True Range from HLC must be approximately {(1 - alpha) / (1 + alpha)} times the variance of Relative True Range" \ No newline at end of file +class TestBasicVolatilityMeasures: + """Test cases for basic volatility measures: lag1_diff and relative_return.""" + + def test_lag1_diff_mean(self): + """Test that lag1_diff has zero mean for random data.""" + _, _, close = hlc_per_variance(0.1) + result = lag1_diff(close).mean() + assert result == pytest.approx(0, abs=1e-2) + + @pytest.mark.parametrize("var", VARS) + def test_lag1_diff_std(self, var): + """Test that lag1_diff has correct standard deviation.""" + _, _, close = hlc_per_variance(var) + result = lag1_diff(close).std() + expected = np.sqrt(2 * var) + assert result == pytest.approx(expected, abs=1e-2) + + def test_relative_return_logic(self): + """Test the mathematical logic of relative_return calculation.""" + s = pl.Series(np.arange(1, 1000)) + actual = relative_return(s).to_numpy()[1:] + expected = 1. / np.linspace(1, 998, 998) + np.testing.assert_array_almost_equal(actual, expected, decimal=3) + + +class TestTrueRange: + """Test cases for true range calculations.""" + + def test_true_range_properties(self): + """Test that true range satisfies certain mathematical properties.""" + high, low, close = hlc_per_variance(0.1) + actual = true_range(high, low, close).to_numpy() + + s1 = (high - low).abs().to_numpy() + s2 = (high - close.shift(1)).abs().to_numpy() + s3 = (low - close.shift(1)).abs().to_numpy() + + # True range should be non-negative + assert np.all(actual >= 0), "True range should be non-negative" + + # True range should be at least as large as each component + assert np.all(actual >= s1), "True range should be at least as large as high - low" + assert np.all(actual[1:] >= s2[1:]), ( + "True range should be at least as large as |high - prev_close|" + ) + assert np.all(actual[1:] >= s3[1:]), ( + "True range should be at least as large as |low - prev_close|" + ) + + def test_relative_true_range_calculation(self): + """Test that relative true range is correctly calculated.""" + high, low, close = hlc_per_variance(0.1) + rtr = relative_true_range(high, low, close) + tr = true_range(high, low, close) + close_shift = close.shift(1) + expected = tr / close_shift + + np.testing.assert_array_almost_equal( + rtr.to_numpy(), expected.to_numpy(), decimal=3 + ) + + +class TestMovingAverageTrueRange: + """Test cases for moving average true range calculations.""" + + @pytest.mark.parametrize("period", PERIODS) + def test_ma_true_range_finite_values(self, period): + """Test that moving average true range has the correct number of finite values.""" + high, low, close = hlc_per_variance(0.1) + + # Test moving average true range from true range directly + tr = true_range(high, low, close) + ma_tr = ma_true_range(tr, period=period).to_numpy() + expected_finite = len(ma_tr) - period + 1 + actual_finite = np.isfinite(ma_tr).sum() + + assert actual_finite == expected_finite, ( + f"MA True Range must have {expected_finite} finite values, got {actual_finite}" + ) + + # Verify last values are computed correctly + assert ma_tr[-1] == pytest.approx(np.mean(tr.to_numpy()[-period:]), rel=1e-10) + assert ma_tr[-2] == pytest.approx(np.mean(tr.to_numpy()[-(period+1):-1]), rel=1e-10) + + @pytest.mark.parametrize("period", PERIODS) + def test_ma_true_range_from_hlc(self, period): + """Test moving average true range calculated directly from HLC data.""" + high, low, close = hlc_per_variance(0.1) + + ma_tr = ma_true_range_from_hlc(high, low, close, period=period).to_numpy() + expected_finite = len(ma_tr) - period + 1 + actual_finite = np.isfinite(ma_tr).sum() + + assert actual_finite == expected_finite, ( + f"MA True Range from HLC must have {expected_finite} finite values, got {actual_finite}" + ) + + @pytest.mark.parametrize("period", PERIODS) + def test_ma_relative_true_range(self, period): + """Test moving average relative true range calculations.""" + high, low, close = hlc_per_variance(0.1) + + # Test from relative true range directly + rtr = relative_true_range(high, low, close) + ma_rtr = ma_relative_true_range(rtr, period=period).to_numpy() + expected_finite = len(ma_rtr) - period + actual_finite = np.isfinite(ma_rtr).sum() + + assert actual_finite == expected_finite, ( + f"MA Relative True Range must have {expected_finite} finite values, got {actual_finite}" + ) + + # Test from HLC directly + ma_rtr_hlc = ma_relative_true_range_from_hlc(high, low, close, period=period).to_numpy() + actual_finite_hlc = np.isfinite(ma_rtr_hlc).sum() + + assert actual_finite_hlc == expected_finite, ( + f"MA Relative True Range from HLC must have {expected_finite} finite values, got {actual_finite_hlc}" + ) + + +class TestEWMATrueRange: + """Test cases for EWMA true range calculations.""" + + @pytest.mark.parametrize("alpha", SMOOTHING_PARS) + def test_ewma_relative_true_range_properties(self, alpha): + """Test EWMA of relative true range has correct variance properties.""" + high, low, close = hlc_per_variance(0.1) + rtr = relative_true_range(high=high, low=low, close=close) + var_rtr = rtr.var() + + # Test EWMA RTR directly from RTR + ewma_rtr = ewma_relative_true_range(rtr=rtr, alpha=alpha) + finite_count = np.isfinite(ewma_rtr.to_numpy()).sum() + expected_finite = len(ewma_rtr) - 1 + + assert finite_count == expected_finite, ( + f"All elements of EWMA RTR, except the first, must be finite. " + f"Expected {expected_finite}, got {finite_count}" + ) + + # Check variance reduction follows EWMA theory + var_ewma_rtr = ewma_rtr.var() + expected_var_ewma_rtr = (1 - alpha) / (1 + alpha) * var_rtr + + assert var_ewma_rtr == pytest.approx(expected_var_ewma_rtr, rel=1e-2), ( + f"Variance of EWMA RTR must be approximately " + f"{(1 - alpha) / (1 + alpha)} times the variance of RTR" + ) + + @pytest.mark.parametrize("alpha", SMOOTHING_PARS) + def test_ewma_relative_true_range_from_hlc(self, alpha): + """Test EWMA relative true range calculated directly from HLC data.""" + high, low, close = hlc_per_variance(0.1) + rtr = relative_true_range(high=high, low=low, close=close) + var_rtr = rtr.var() + + # Test EWMA RTR from HLC + ewma_rtr_hlc = ewma_relative_true_range_from_hlc( + high=high, low=low, close=close, alpha=alpha + ) + finite_count = np.isfinite(ewma_rtr_hlc.to_numpy()).sum() + expected_finite = len(ewma_rtr_hlc) - 1 + + assert finite_count == expected_finite, ( + f"All elements of EWMA RTR from HLC, except the first, must be finite. " + f"Expected {expected_finite}, got {finite_count}" + ) + + # Check variance reduction follows EWMA theory + var_ewma_rtr_hlc = ewma_rtr_hlc.var() + expected_var_ewma_rtr = (1 - alpha) / (1 + alpha) * var_rtr + + assert var_ewma_rtr_hlc == pytest.approx(expected_var_ewma_rtr, rel=1e-2), ( + f"Variance of EWMA RTR from HLC must be approximately " + f"{(1 - alpha) / (1 + alpha)} times the variance of RTR" + ) \ No newline at end of file diff --git a/trendfollower/__init__.py b/trendfollower/__init__.py index e69de29..108f681 100644 --- a/trendfollower/__init__.py +++ b/trendfollower/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2025 Mohammadjavad Vakili, All rights reserved. +"""Python package implementing European-Style, +American-Style, and Time-Series-Momentum Trend Following Systems.""" \ No newline at end of file diff --git a/trendfollower/batch/__init__.py b/trendfollower/batch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trendfollower/core/__init__.py b/trendfollower/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/trendfollower/core/american_tf.py b/trendfollower/core/american_tf.py new file mode 100644 index 0000000..8487141 --- /dev/null +++ b/trendfollower/core/american_tf.py @@ -0,0 +1,321 @@ +# Copyright (c) 2025 Mohammadjavad Vakili, All rights reserved. + +"""Implementation of the building blocks of an American Trend Following System.""" + + +from collections.abc import Callable +from typing import Literal, Self + +from pydantic import BaseModel, Field, model_validator + + +class MarketData(BaseModel): + """Market data for a single time point.""" + + model_config = {"frozen": True} # Make immutable + + price: float = Field(gt=0, description="Current price.") + fast_ema: float = Field(gt=0, description="Fast EMA.") + slow_ema: float = Field(gt=0, description="Slow EMA.") + atr: float = Field(gt=0, description="Average True Range.") + + +class StrategyParams(BaseModel): + """Strategy parameters for American Trend Follower.""" + + model_config = {"frozen": True} # Make immutable + + entry_point_width: float = Field(gt=0, description="Entry point width as a multiple of ATR.") + stop_loss_width: float = Field(gt=0, description="Stop loss width as a multiple of ATR.") + risk_multiple: float = Field(gt=0, description="Risk multiple for position sizing.") + + +class AmericanTFState(BaseModel): + """State of the American Trend Follower strategy.""" + + model_config = {"frozen": True} # Make immutable + + position: Literal[-1, 0, 1] = Field( + default=0, + description="Current position: 1 for long, -1 for short, 0 for neutral.", + ) + position_size: float = Field( + default=0, + description="Size of the current position.", + ) + stop_loss: float | None = Field(default=None, ge=0, description="Stop loss level.") + entry_price: float | None = Field( + default=None, gt=0, + description="Entry price of the position.", + ) + close_next_day: bool = Field( + default=False, + description="Whether to close the position the next day. True if stop-loss was hit but signal is still on.", + ) + + @model_validator(mode="after") + def validate_position_consistency(self) -> Self: + """Ensure position, position_size, stop_loss, and entry_price are consistent. + + Returns + ------- + Self + The validated state. + + Raises + ------ + ValueError + If any inconsistency is found. + + """ + invalid_neutral_position = ( + self.position == 0 and ( + (self.position_size != 0) or (self.stop_loss is not None) or (self.entry_price is not None) + ) + ) + invalid_short_position = ( + self.position == -1 and ( + (self.position_size >= 0) or (self.stop_loss is None) or (self.entry_price is None) + ) + ) + invalid_long_position = ( + self.position == 1 and ( + (self.position_size <= 0) or (self.stop_loss is None) or (self.entry_price is None) + ) + ) + if invalid_neutral_position: + err_msg = "For neutral position, position_size must be 0, stop_loss and entry_price must be None." + raise ValueError(err_msg) + if invalid_short_position: + err_msg = "For short position, position_size must be negative, stop_loss and entry_price must be set." + raise ValueError(err_msg) + if invalid_long_position: + err_msg = "For long position, position_size must be positive, stop_loss and entry_price must be set." + raise ValueError(err_msg) + return self + + +POSITION_HANDLER = Callable[[MarketData, StrategyParams, AmericanTFState], AmericanTFState] + + +def handle_neutral_position( + market: MarketData, params: StrategyParams, state: AmericanTFState | None = None, +) -> AmericanTFState: + """Update state of an American Trend Follower strategy when in a neutral position. + + Parameters + ---------- + market : MarketData + Current market data, containing price, fast EMA, slow EMA, and ATR. + params : StrategyParams + Strategy parameters, including entry point width, stop loss width, and risk multiple. + state : AmericanTFState | None, optional + Current state of the strategy. Not used in this handler, by default None. + Not being used because in neutral state, previous state does not affect the decision. + It is included for conformity with POSITION_HANDLER type. + + Returns + ------- + AmericanTFState + Updated state of the strategy. + + """ + signal_buffer: float = params.entry_point_width * market.atr + stop_loss_buffer: float = params.stop_loss_width * market.atr + position_size_magnitude: float = params.risk_multiple * market.price / market.atr + # Long entry upon long signal + if market.fast_ema > market.slow_ema + signal_buffer: + return AmericanTFState( + position=1, + position_size=position_size_magnitude, + stop_loss=market.price - stop_loss_buffer, + entry_price=market.price, + close_next_day=False, + ) + # Short entry upon short signal + if market.fast_ema < market.slow_ema - signal_buffer: + return AmericanTFState( + position=-1, + position_size=-position_size_magnitude, + stop_loss=market.price + stop_loss_buffer, + entry_price=market.price, + close_next_day=False, + ) + # Remain neutral if no signal + return AmericanTFState( + position=0, + position_size=0.0, + stop_loss=None, + entry_price=None, + close_next_day=False, + ) + + +def handle_long_position(market: MarketData, params: StrategyParams, state: AmericanTFState) -> AmericanTFState: + """Update state of an American Trend Follower strategy when in a long position. + + Parameters + ---------- + market : MarketData + Current market data, containing price, fast EMA, slow EMA, and ATR. + params : StrategyParams + Strategy parameters, including entry point width, stop loss width, and risk multiple. + state : AmericanTFState + Current state of the strategy. + + Returns + ------- + AmericanTFState + Updated state of the strategy. + + Raises + ------ + ValueError + If stop_loss is not set for an open position. + + """ + if state.stop_loss is None: + msg = "stop_loss must be set for open positions" + raise ValueError(msg) + # Check if position was marked for closure from previous day + if state.close_next_day: + return AmericanTFState( + position=0, + position_size=0.0, + stop_loss=None, + entry_price=None, + close_next_day=False, + ) + + signal_buffer = params.entry_point_width * market.atr + stop_loss_buffer = params.stop_loss_width * market.atr + + if market.price < state.stop_loss: # Stop-loss is breached + if market.fast_ema <= market.slow_ema + signal_buffer: # Signal is off, close position immediately + return AmericanTFState( + position=0, + position_size=0.0, + stop_loss=None, + entry_price=None, + close_next_day=False, + ) + return state.model_copy(update={"close_next_day": True}) + # Stop-loss is not breached, update the stop-loss level + new_stop_loss = max(state.stop_loss, market.price - stop_loss_buffer) + return state.model_copy(update={"stop_loss": new_stop_loss}) + + +def handle_short_position(market: MarketData, params: StrategyParams, state: AmericanTFState) -> AmericanTFState: + """Update state of an American Trend Follower strategy when in a short position. + + Parameters + ---------- + market : MarketData + Current market data, containing price, fast EMA, slow EMA, and ATR. + params : StrategyParams + Strategy parameters, including entry point width, stop loss width, and risk multiple. + state : AmericanTFState + Current state of the strategy. + + Returns + ------- + AmericanTFState + Updated state of the strategy. + + Raises + ------ + ValueError + If stop_loss is not set for an open position. + + """ + if state.stop_loss is None: + msg = "stop_loss must be set for open positions" + raise ValueError(msg) + # Check if position was marked for closure from previous day + if state.close_next_day: + return AmericanTFState( + position=0, + position_size=0, + stop_loss=None, + entry_price=None, + close_next_day=False, + ) + + signal_buffer = params.entry_point_width * market.atr + stop_loss_buffer = params.stop_loss_width * market.atr + + if market.price > state.stop_loss: # Stop-loss is breached + if market.fast_ema >= market.slow_ema - signal_buffer: # Signal is off, close position immediately + return AmericanTFState( + position=0, + position_size=0.0, + stop_loss=None, + entry_price=None, + close_next_day=False, + ) + # Signal is still on, keep position but mark for closure next day + return state.model_copy(update={"close_next_day": True}) + # Stop-loss is not breached, update the stop-loss level + new_stop_loss = min(state.stop_loss, market.price + stop_loss_buffer) + return state.model_copy(update={"stop_loss": new_stop_loss}) + + +class AmericanTrendFollower: + """American-style trend following strategy. + + Parameters + ---------- + entry_point_width : float + Entry point width as a multiple of ATR. + stop_loss_width : float + Stop loss width as a multiple of ATR. + risk_multiple : float + Risk multiple for position sizing. + + Methods + ------- + update_state(s, s_fast, s_slow, atr, state) -> AmericanTFState + Update the state of the strategy given current market data and previous state. + + """ + + def __init__(self, entry_point_width: float, stop_loss_width: float, risk_multiple: float) -> None: + + self.params = StrategyParams( + entry_point_width=entry_point_width, + stop_loss_width=stop_loss_width, + risk_multiple=risk_multiple, + ) + self.handlers: dict[Literal[-1, 0, 1], POSITION_HANDLER] = { + 0: handle_neutral_position, + 1: handle_long_position, + -1: handle_short_position, + } + + def update_state( + self, s: float, s_fast: float, s_slow: float, atr: float, state: AmericanTFState + ) -> AmericanTFState: + """Update the state of the American Trend Follower strategy. + + Parameters + ---------- + s : float + Current price. + s_fast : float + Current fast EMA. + s_slow : float + Current slow EMA. + atr : float + Current ATR. + state : AmericanTFState + Current state of the strategy. + + Returns + ------- + AmericanTFState + Updated state of the strategy. + + """ + market = MarketData(price=s, fast_ema=s_fast, slow_ema=s_slow, atr=atr) + handler = self.handlers[state.position] + return handler(market, self.params, state) diff --git a/trendfollower/core/filters.py b/trendfollower/core/filters.py new file mode 100644 index 0000000..97f2727 --- /dev/null +++ b/trendfollower/core/filters.py @@ -0,0 +1,185 @@ +# Copyright (c) 2025, Mohammadjavad Vakili +# All rights reserved. +# This code is licensed under the MIT License. + +"""Implementation of exponential weighted moving average filters. + +Implemented functions: + +`ewma` + +Exponentially-weighted moving average filter. + +`variance_preserving_ewma` + +Variance preserving exponential moving average filter. + +`long_short_variance_preserving_ewma` + +Long-short variance preserving exponential moving average filter. +""" + +import polars as pl + + +def _validate_nu_or_span(nu: float | None, span: int | None) -> float: + """Validate and convert nu/span parameters to nu. + + Parameters + ---------- + nu : float | None + The smoothing parameter (0 < nu < 1). + span : int | None + The span of the moving average (positive integer). + + Returns + ------- + float + The validated nu parameter. + + Raises + ------ + ValueError + If neither nu nor span is provided, or if parameters are invalid. + + """ + if nu is None and span is None: + msg = "Either 'nu' or 'span' must be provided." + raise ValueError(msg) + if nu is not None and not (0 < nu < 1): + msg = f"Invalid nu={nu}: must be between 0 and 1 (exclusive)." + raise ValueError(msg) + if span is not None and (span <= 0 or span != int(span)): + msg = f"Invalid span={span}: must be a positive integer." + raise ValueError(msg) + return nu if nu is not None else 1 - 2 / (span + 1) + + +def ewma(z: pl.Series, nu: float | None = None, span: int | None = None) -> pl.Series: + r"""Calculate exponential moving average filter of a series. + + The smoothing parameter `nu` and the span are related by: + + .. math:: + + \nu = 1 - \frac{2}{\text{span} + 1} + + If both `nu` and `span` are provided, `nu` takes precedence. + If neither is provided, a ValueError is raised. + + Parameters + ---------- + z : pl.Series + The input time series. + nu : float | None + The smoothing parameter. + span : int | None + The span of the moving average (number of periods), defaults to None. + span is related to nu by: nu = 1 - 2 / (span + 1). + + Returns + ------- + pl.Series + The exponentially weighted moving average of the input series. + + References + ---------- + .. [1] https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.ewm_mean.html + .. [2] Eq. (3) Science & Practice of trend-following systems. + .. [3] Eq. (6) Science & Practice of trend-following systems. + + + """ + nu = _validate_nu_or_span(nu, span) + return z.ewm_mean(alpha=1 - nu) + + +def variance_preserving_ewma(z: pl.Series, nu: float | None = None, span: int | None = None) -> pl.Series: + r"""Calculate variance preserving exponential moving average filter of a series. + + .. math:: + \text{VP-EWMA}(z_t) = \text{EWMA}(z_t) \sqrt{\frac{1 + \nu}{1 - \nu}} + + Parameters + ---------- + z : pl.Series + The input time series. + nu : float | None + The smoothing factor (0 < nu < 1). This is related + to polars' ewm_mean's alpha by: alpha = 1 - nu. + span : int | None + The span of the moving average (number of periods), defaults to None. + span is related to nu by: nu = 1 - 2 / (span + 1). + If both `nu` and `span` are provided, `nu` takes precedence. + If neither is provided, a ValueError is raised. + + Returns + ------- + pl.Series + The variance preserving exponentially weighted moving average of the input series. + + References + ---------- + .. [1] https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.ewm_mean.html + .. [2] Eq. (4) Science & Practice of trend-following systems. + .. [3] Eq. (7) Science & Practice of trend-following systems. + + """ + nu = _validate_nu_or_span(nu, span) + return ewma(z=z, nu=nu) * ((1.0 + nu) / (1.0 - nu)) ** 0.5 + + +def long_short_variance_preserving_ewma( + z: pl.Series, + nu1: float | None = None, + span1: int | None = None, + nu2: float | None = None, + span2: int | None = None, +) -> pl.Series: + r"""Calculate long-short variance preserving exponential moving average filter of a series. + + ... math:: + \text{LS-VP-EWMA}(z_t) = l_1 \text{EWMA}(z_t, \nu_1) - l_2 \text{EWMA}(z_t, \nu_2) + \\text{where} \quad l_1 = \frac{q}{1 - \nu_1}, \quad l_2 = \frac{q}{1 - \nu_2}, \quad + q = \left(\frac{1}{1 - \nu_1^2} + \frac{1}{1 - \nu_2^2} - \frac{2}{1 - \nu_1 \nu_2}\right)^{-0.5} + + Parameters + ---------- + z : pl.Series + The input time series. + nu1 : float | None + The smoothing factor for the long position (0 < nu1 < 1). + span1 : int | None + The span for the long position (alternative to nu1). + nu2 : float | None + The smoothing factor for the short position (0 < nu2 < 1). + span2 : int | None + The span for the short position (alternative to nu2). + + Returns + ------- + pl.Series + The long-short variance preserving exponentially weighted moving average of the input series. + + Raises + ------ + ValueError + If nu1 and nu2 are equal or invalid. + + References + ---------- + .. [1] https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.ewm_mean.html + .. [2] Eq. (8) Science & Practice of trend-following systems. + .. [3] Eq. (9) Science & Practice of trend-following systems. + In Eq. 9 of paper, there is a typo: q should be 1 / q. + + """ + nu1 = _validate_nu_or_span(nu1, span1) + nu2 = _validate_nu_or_span(nu2, span2) + if nu1 == nu2: + msg = "nu1 and nu2 must be different. When they are equal, the long-short filter is ill-defined." + raise ValueError(msg) + q = (1 / (1 - nu1**2.) + 1 / (1 - nu2**2.) - 2 / (1 - nu1 * nu2)) ** -.5 + l1 = q / (1 - nu1) + l2 = q / (1 - nu2) + return l1 * ewma(z=z, nu=nu1) - l2 * ewma(z=z, nu=nu2) diff --git a/trendfollower/volatility.py b/trendfollower/core/volatility.py similarity index 95% rename from trendfollower/volatility.py rename to trendfollower/core/volatility.py index af7083a..eb8b229 100644 --- a/trendfollower/volatility.py +++ b/trendfollower/core/volatility.py @@ -18,7 +18,7 @@ import polars as pl -from trendfollower.filters import ewma +from trendfollower.core.filters import ewma def lag1_diff(z: pl.Series) -> pl.Series: diff --git a/trendfollower/filters.py b/trendfollower/filters.py deleted file mode 100644 index 7f9ce1e..0000000 --- a/trendfollower/filters.py +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright (c) 2025, Mohammadjavad Vakili -# All rights reserved. -# This code is licensed under the MIT License. - -"""Filter functions for time series data. - -Implemented functions: - -`ewma` - -Exponentially-weighted moving average filter. - -`variance_preserving_ewma` - -Variance preserving exponential moving average filter. - -`long_short_variance_preserving_ewma` - -Long-short variance preserving exponential moving average filter. -""" - -import polars as pl - - -def ewma(z: pl.Series, alpha: float) -> pl.Series: - """Calculate exponential moving average filter of a series. - - Parameters - ---------- - z : pl.Series - The input time series. - alpha : float - The smoothing factor (0 < alpha < 1). - - Returns - ------- - pl.Series - The exponentially weighted moving average of the input series. - - References - ---------- - .. [1] https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.ewm_mean.html - - """ - return z.ewm_mean(alpha=1 - alpha) - - -def variance_preserving_ewma(z: pl.Series, alpha: float) -> pl.Series: - """Calculate variance preserving exponential moving average filter of a series. - - Parameters - ---------- - z : pl.Series - The input time series. - alpha : float - The smoothing factor (0 < alpha < 1). - - Returns - ------- - pl.Series - The variance preserving exponentially weighted moving average of the input series. - - References - ---------- - .. [1] https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.ewm_mean.html - - """ - return ewma(z=z, alpha=alpha) * ((1 + alpha) / (1 - alpha))**.5 - - -def long_short_variance_preserving_ewma(z: pl.Series, alpha1: float, alpha2: float) -> pl.Series: - """Calculate long-short variance preserving exponential moving average filter of a series. - - Parameters - ---------- - z : pl.Series - The input time series. - alpha1 : float - The smoothing factor for the long position (0 < alpha1 < 1). - alpha2 : float - The smoothing factor for the short position (0 < alpha2 < 1). - - Returns - ------- - pl.Series - The long-short variance preserving exponentially weighted moving average of the input series. - - Raises - ------ - ValueError - If alpha1 and alpha2 are equal. - - References - ---------- - .. [1] https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.ewm_mean.html - In Eq. 9 of paper, there is a typo: q should be 1 / q. - - """ - if alpha1 == alpha2: - msg = "alpha1 and alpha2 must be different. When they are equal, the long-short filter is ill-defined." - raise ValueError(msg) - q = (1 / (1 - alpha1**2.) + 1 / (1 - alpha2**2.) - 2 / (1 - alpha1 * alpha2)) ** -.5 - l1 = q / (1 - alpha1) - l2 = q / (1 - alpha2) - return l1 * ewma(z=z, alpha=alpha1) - l2 * ewma(z=z, alpha=alpha2) diff --git a/trendfollower/streaming/__init__.py b/trendfollower/streaming/__init__.py new file mode 100644 index 0000000..e69de29