Skip to content

Commit 259f36a

Browse files
committed
Adaptive Windowing for Multi-Armed Bandits
### Changes: * Added adaptive windowing mechanism to detect and handle concept drift in MAB models. * Introduced ActionsManager class to handle action memory and updates with configurable window sizes. * Refactored Model class hierarchy to support model resetting and memory management. * Added support for infinite and fixed-size windows with change detection via delta parameter. * Enhanced test coverage for adaptive windowing functionality across MAB variants.
1 parent 64913ef commit 259f36a

15 files changed

+2182
-538
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,4 @@ MANIFEST
6565

6666
# poetry
6767
poetry.lock
68+
.qodo

pybandits/actions_manager.py

Lines changed: 626 additions & 0 deletions
Large diffs are not rendered by default.

pybandits/base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# SOFTWARE.
2222

2323

24-
from typing import Any, Dict, List, NewType, Tuple, Union
24+
from typing import Any, Dict, List, NewType, Tuple, Union, _GenericAlias, get_args, get_origin
2525

2626
from pybandits.pydantic_version_compatibility import (
2727
PYDANTIC_VERSION_1,
@@ -52,6 +52,7 @@
5252
Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
5353
)
5454
ACTION_IDS_PREFIX = "action_ids_"
55+
ACTIONS = "actions"
5556

5657

5758
class _classproperty(property):
@@ -96,6 +97,18 @@ def _apply_version_adjusted_method(self, v2_method_name: str, v1_method_name: st
9697
def _get_value_with_default(cls, key: str, values: Dict[str, Any]) -> Any:
9798
return values.get(key, cls.model_fields[key].default)
9899

100+
@classmethod
101+
def _get_field_type(cls, key: str) -> Any:
102+
if pydantic_version == PYDANTIC_VERSION_1:
103+
annotation = cls.model_fields[key].type_
104+
elif pydantic_version == PYDANTIC_VERSION_2:
105+
annotation = cls.model_fields[key].annotation
106+
if isinstance(annotation, _GenericAlias) and get_origin(annotation) is dict:
107+
annotation = get_args(annotation)[1] # refer to the type of the Dict values
108+
else:
109+
raise ValueError(f"Unsupported pydantic version: {pydantic_version}")
110+
return annotation
111+
99112
if pydantic_version == PYDANTIC_VERSION_1:
100113

101114
@_classproperty

pybandits/cmab.py

Lines changed: 10 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,26 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222

23-
from typing import Dict, List, Optional, Set, Union
23+
from abc import ABC
24+
from typing import List, Optional, Set, Union
2425

2526
from numpy import array
2627
from numpy.random import choice
2728
from numpy.typing import ArrayLike
2829

30+
from pybandits.actions_manager import CmabActionsManager
2931
from pybandits.base import ActionId, BinaryReward, CmabPredictions
3032
from pybandits.mab import BaseMab
3133
from pybandits.model import BayesianLogisticRegression, BayesianLogisticRegressionCC
32-
from pybandits.pydantic_version_compatibility import field_validator, validate_call
34+
from pybandits.pydantic_version_compatibility import validate_call
3335
from pybandits.strategy import (
3436
BestActionIdentificationBandit,
3537
ClassicBandit,
3638
CostControlBandit,
3739
)
3840

3941

40-
class BaseCmabBernoulli(BaseMab):
42+
class BaseCmabBernoulli(BaseMab, ABC):
4143
"""
4244
Base model for a Contextual Multi-Armed Bandit for Bernoulli bandits with Thompson Sampling.
4345
@@ -54,27 +56,10 @@ class BaseCmabBernoulli(BaseMab):
5456
bandit strategy.
5557
"""
5658

57-
actions: Dict[ActionId, BayesianLogisticRegression]
59+
actions_manager: CmabActionsManager[BayesianLogisticRegression]
5860
predict_with_proba: bool
5961
predict_actions_randomly: bool
6062

61-
@field_validator("actions", mode="after")
62-
@classmethod
63-
def check_bayesian_logistic_regression_models(cls, v):
64-
action_models = list(v.values())
65-
first_action = action_models[0]
66-
first_action_type = type(first_action)
67-
for action in action_models[1:]:
68-
if not isinstance(action, first_action_type):
69-
raise AttributeError("All actions should follow the same type.")
70-
if not len(action.betas) == len(first_action.betas):
71-
raise AttributeError("All actions should have the same number of betas.")
72-
if not action.update_method == first_action.update_method:
73-
raise AttributeError("All actions should have the same update method.")
74-
if not action.update_kwargs == first_action.update_kwargs:
75-
raise AttributeError("All actions should have the same update kwargs.")
76-
return v
77-
7863
@validate_call(config=dict(arbitrary_types_allowed=True))
7964
def predict(
8065
self,
@@ -169,20 +154,7 @@ def update(
169154
If strategy is MultiObjectiveBandit, rewards should be a list of list, e.g. (with n_objectives=2):
170155
rewards = [[1, 1], [1, 0], [1, 1], [1, 0], [1, 1], ...]
171156
"""
172-
self._validate_update_params(actions=actions, rewards=rewards)
173-
if len(context) != len(rewards):
174-
raise AttributeError(f"Shape mismatch: actions and rewards should have the same length {len(actions)}.")
175-
176-
# cast inputs to numpy arrays to facilitate their manipulation
177-
context, actions, rewards = array(context), array(actions), array(rewards)
178-
179-
for a in set(actions):
180-
# get context and rewards of the samples associated to action a
181-
context_of_a = context[actions == a]
182-
rewards_of_a = rewards[actions == a].tolist()
183-
184-
# update model associated to action a
185-
self.actions[a].update(context=context_of_a, rewards=rewards_of_a)
157+
super().update(actions=actions, rewards=rewards, context=context)
186158

187159
# always set predict_actions_randomly after update
188160
self.predict_actions_randomly = False
@@ -208,7 +180,7 @@ class CmabBernoulli(BaseCmabBernoulli):
208180
bandit strategy.
209181
"""
210182

211-
actions: Dict[ActionId, BayesianLogisticRegression]
183+
actions_manager: CmabActionsManager[BayesianLogisticRegression]
212184
strategy: ClassicBandit
213185
predict_with_proba: bool = False
214186
predict_actions_randomly: bool = False
@@ -234,7 +206,7 @@ class CmabBernoulliBAI(BaseCmabBernoulli):
234206
bandit strategy.
235207
"""
236208

237-
actions: Dict[ActionId, BayesianLogisticRegression]
209+
actions_manager: CmabActionsManager[BayesianLogisticRegression]
238210
strategy: BestActionIdentificationBandit
239211
predict_with_proba: bool = False
240212
predict_actions_randomly: bool = False
@@ -268,7 +240,7 @@ class CmabBernoulliCC(BaseCmabBernoulli):
268240
bandit strategy.
269241
"""
270242

271-
actions: Dict[ActionId, BayesianLogisticRegressionCC]
243+
actions_manager: CmabActionsManager[BayesianLogisticRegressionCC]
272244
strategy: CostControlBandit
273245
predict_with_proba: bool = True
274246
predict_actions_randomly: bool = False

0 commit comments

Comments
 (0)