20
20
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
21
# SOFTWARE.
22
22
23
- from typing import Dict , List , Optional , Set , Union
23
+ from abc import ABC
24
+ from typing import List , Optional , Set , Union
24
25
25
26
from numpy import array
26
27
from numpy .random import choice
27
28
from numpy .typing import ArrayLike
28
29
30
+ from pybandits .actions_manager import CmabActionsManager
29
31
from pybandits .base import ActionId , BinaryReward , CmabPredictions
30
32
from pybandits .mab import BaseMab
31
33
from pybandits .model import BayesianLogisticRegression , BayesianLogisticRegressionCC
32
- from pybandits .pydantic_version_compatibility import field_validator , validate_call
34
+ from pybandits .pydantic_version_compatibility import validate_call
33
35
from pybandits .strategy import (
34
36
BestActionIdentificationBandit ,
35
37
ClassicBandit ,
36
38
CostControlBandit ,
37
39
)
38
40
39
41
40
- class BaseCmabBernoulli (BaseMab ):
42
+ class BaseCmabBernoulli (BaseMab , ABC ):
41
43
"""
42
44
Base model for a Contextual Multi-Armed Bandit for Bernoulli bandits with Thompson Sampling.
43
45
@@ -54,27 +56,10 @@ class BaseCmabBernoulli(BaseMab):
54
56
bandit strategy.
55
57
"""
56
58
57
- actions : Dict [ ActionId , BayesianLogisticRegression ]
59
+ actions_manager : CmabActionsManager [ BayesianLogisticRegression ]
58
60
predict_with_proba : bool
59
61
predict_actions_randomly : bool
60
62
61
- @field_validator ("actions" , mode = "after" )
62
- @classmethod
63
- def check_bayesian_logistic_regression_models (cls , v ):
64
- action_models = list (v .values ())
65
- first_action = action_models [0 ]
66
- first_action_type = type (first_action )
67
- for action in action_models [1 :]:
68
- if not isinstance (action , first_action_type ):
69
- raise AttributeError ("All actions should follow the same type." )
70
- if not len (action .betas ) == len (first_action .betas ):
71
- raise AttributeError ("All actions should have the same number of betas." )
72
- if not action .update_method == first_action .update_method :
73
- raise AttributeError ("All actions should have the same update method." )
74
- if not action .update_kwargs == first_action .update_kwargs :
75
- raise AttributeError ("All actions should have the same update kwargs." )
76
- return v
77
-
78
63
@validate_call (config = dict (arbitrary_types_allowed = True ))
79
64
def predict (
80
65
self ,
@@ -169,20 +154,7 @@ def update(
169
154
If strategy is MultiObjectiveBandit, rewards should be a list of list, e.g. (with n_objectives=2):
170
155
rewards = [[1, 1], [1, 0], [1, 1], [1, 0], [1, 1], ...]
171
156
"""
172
- self ._validate_update_params (actions = actions , rewards = rewards )
173
- if len (context ) != len (rewards ):
174
- raise AttributeError (f"Shape mismatch: actions and rewards should have the same length { len (actions )} ." )
175
-
176
- # cast inputs to numpy arrays to facilitate their manipulation
177
- context , actions , rewards = array (context ), array (actions ), array (rewards )
178
-
179
- for a in set (actions ):
180
- # get context and rewards of the samples associated to action a
181
- context_of_a = context [actions == a ]
182
- rewards_of_a = rewards [actions == a ].tolist ()
183
-
184
- # update model associated to action a
185
- self .actions [a ].update (context = context_of_a , rewards = rewards_of_a )
157
+ super ().update (actions = actions , rewards = rewards , context = context )
186
158
187
159
# always set predict_actions_randomly after update
188
160
self .predict_actions_randomly = False
@@ -208,7 +180,7 @@ class CmabBernoulli(BaseCmabBernoulli):
208
180
bandit strategy.
209
181
"""
210
182
211
- actions : Dict [ ActionId , BayesianLogisticRegression ]
183
+ actions_manager : CmabActionsManager [ BayesianLogisticRegression ]
212
184
strategy : ClassicBandit
213
185
predict_with_proba : bool = False
214
186
predict_actions_randomly : bool = False
@@ -234,7 +206,7 @@ class CmabBernoulliBAI(BaseCmabBernoulli):
234
206
bandit strategy.
235
207
"""
236
208
237
- actions : Dict [ ActionId , BayesianLogisticRegression ]
209
+ actions_manager : CmabActionsManager [ BayesianLogisticRegression ]
238
210
strategy : BestActionIdentificationBandit
239
211
predict_with_proba : bool = False
240
212
predict_actions_randomly : bool = False
@@ -268,7 +240,7 @@ class CmabBernoulliCC(BaseCmabBernoulli):
268
240
bandit strategy.
269
241
"""
270
242
271
- actions : Dict [ ActionId , BayesianLogisticRegressionCC ]
243
+ actions_manager : CmabActionsManager [ BayesianLogisticRegressionCC ]
272
244
strategy : CostControlBandit
273
245
predict_with_proba : bool = True
274
246
predict_actions_randomly : bool = False
0 commit comments