
Commit 3f1bb19

Refactor MAB and Strategy Classes with Cold Start Methods and Enhanced Validation
Change log:
1. Moved Strategy, Model, and MAB classes to strategy.py, model.py, and the new mab.py; base.py now holds only shared definitions and the abstract PyBanditsBaseModel. The abstract MAB now allows all children to either accept a strategy instance as a parameter or receive the strategy parameters and instantiate it accordingly.
2. The from_state functionality is now inherited directly by all MABs from BaseMab.
3. Replaced all cold_start methods in cmab.py and smab.py with cold_start inherited from BaseMab, and updated the test cases to use the new cold_start_instantiate methods.
4. Introduced numerize_field and get_expected_value_from_state methods in the Strategy class to handle default values and state extraction. Added a field_validator for exploit_p in BestActionIdentification and for subsidy_factor in CostControlBandit to ensure proper default handling and validation.
5. Merged common functionality into a new CostControlStrategy abstract class, now inherited by CostControlBandit and MultiObjectiveCostControlBandit. Simplified the select_action methods by using helper methods such as _evaluate_and_select and _reduce.
6. Moved get_pareto_front into a new MultiObjectiveStrategy abstract class, now inherited by MultiObjectiveBandit and MultiObjectiveCostControlBandit.
7. In model.py, removed the redundant BaseBetaMO and BaseBayesianLogisticRegression, and added a cold_start_instantiate method to the BetaMO and BayesianLogisticRegression models.
8. Added extract_argument_names_from_function in utils.py to allow extracting function parameter names from a function handle.
9. Renamed test_base.py to test_mab.py.
10. Updated deprecated linter settings in pyproject.toml.
11. Added a test_smab_mo_cc_update test in test_smab.py.
12. Bumped the version to 1.0.0 in pyproject.toml.
1 parent 9c15f78 commit 3f1bb19
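
A minimal sketch of the two instantiation paths described in items 1 and 3 of the change log, based on the tutorial diffs in this commit (explicit per-action Beta models vs. the new BaseMab-derived cold_start classmethod); treat it as illustrative rather than a complete description of the new API:

from pybandits.model import Beta
from pybandits.smab import SmabBernoulli

# explicit per-action models, as before
mab = SmabBernoulli(
    actions={
        "a1": Beta(n_successes=1, n_failures=1),
        "a2": Beta(n_successes=1, n_failures=1),
    }
)

# cold start is now inherited from BaseMab and replaces the
# removed create_smab_bernoulli_cold_start() helper
mab = SmabBernoulli.cold_start(action_ids=["a1", "a2", "a3"])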

File tree

17 files changed, +1152 −1319 lines


.github/workflows/continuous_delivery.yml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [ 3.8, 3.9 ]
+        python-version: [ "3.8", "3.9", "3.10" ]
 
     steps:
       - name: Checkout repository

.github/workflows/continuous_integration.yml

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: [3.8, 3.9]
+        python-version: [ "3.8", "3.9", "3.10" ]
 
     steps:
       - name: Checkout repository

docs/tutorials/mab.ipynb

Lines changed: 3 additions & 5 deletions
@@ -20,7 +20,7 @@
     "from rich import print\n",
     "\n",
     "from pybandits.model import Beta\n",
-    "from pybandits.smab import SmabBernoulli, create_smab_bernoulli_cold_start"
+    "from pybandits.smab import SmabBernoulli"
    ]
   },
   {
@@ -73,8 +73,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "n_objectives = 2\n",
-    "\n",
     "mab = SmabBernoulli(\n",
     "    actions={\n",
     "        \"a1\": Beta(n_successes=1, n_failures=1),\n",
@@ -137,7 +135,7 @@
    "id": "564914fd-73cc-4854-8ec7-548970f794a6",
    "metadata": {},
    "source": [
-    "You can initialize the bandit via the utility function `create_smab_bernoulli_mo_cc_cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
+    "You can initialize the bandit via the utility function `SmabBernoulliMOCC.cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
    ]
   },
   {
@@ -148,7 +146,7 @@
    "outputs": [],
    "source": [
     "# generate a smab bernoulli in cold start settings\n",
-    "mab = create_smab_bernoulli_cold_start(action_ids=[\"a1\", \"a2\", \"a3\"])"
+    "mab = SmabBernoulli.cold_start(action_ids=[\"a1\", \"a2\", \"a3\"])"
    ]
   },
   {

docs/tutorials/smab_mo_cc.ipynb

Lines changed: 3 additions & 6 deletions
@@ -20,7 +20,7 @@
     "from rich import print\n",
     "\n",
     "from pybandits.model import Beta, BetaMOCC\n",
-    "from pybandits.smab import SmabBernoulliMOCC, create_smab_bernoulli_mo_cc_cold_start"
+    "from pybandits.smab import SmabBernoulliMOCC"
    ]
   },
   {
@@ -72,8 +72,6 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "n_objectives = 2\n",
-    "\n",
     "mab = SmabBernoulliMOCC(\n",
     "    actions={\n",
     "        \"a1\": BetaMOCC(counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=30),\n",
@@ -153,7 +151,7 @@
    "id": "564914fd-73cc-4854-8ec7-548970f794a6",
    "metadata": {},
    "source": [
-    "You can initialize the bandit via the utility function `create_smab_bernoulli_mo_cc_cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
+    "You can initialize the bandit via the utility function `SmabBernoulliMOCC.cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
    ]
   },
   {
@@ -165,10 +163,9 @@
    "source": [
     "# list of action IDs with their cost\n",
     "action_ids_cost = {\"a1\": 30, \"a2\": 10, \"a3\": 20}\n",
-    "n_objectives = 2\n",
     "\n",
     "# generate a smab bernoulli in cold start settings\n",
-    "mab = create_smab_bernoulli_mo_cc_cold_start(action_ids_cost=action_ids_cost, n_objectives=n_objectives)"
+    "mab = SmabBernoulliMOCC.cold_start(action_ids_cost=action_ids_cost)"
    ]
   },
   {
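
For the multi-objective cost-control case, the same pattern applies. A hedged usage sketch follows: the cold_start call mirrors the notebook diff above, while the update call with one binary reward per objective follows the test_smab_mo_cc_update test mentioned in the change log, so its exact signature should be treated as an assumption:

from pybandits.smab import SmabBernoulliMOCC

# cold start from action IDs and their costs, as in the updated notebook
action_ids_cost = {"a1": 30, "a2": 10, "a3": 20}
mab = SmabBernoulliMOCC.cold_start(action_ids_cost=action_ids_cost)

# assumed update shape: one list of binary rewards per sample,
# with one entry per objective (two objectives here)
mab.update(actions=["a1", "a2"], rewards=[[1, 0], [0, 1]])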

pybandits/base.py

Lines changed: 12 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -21,241 +21,27 @@
 # SOFTWARE.
 
 
-from abc import ABC, abstractmethod
-from typing import Any, Dict, List, NewType, Optional, Set, Tuple, Union
+from typing import Dict, List, NewType, Tuple, Union
 
-import numpy as np
-from pydantic import (
-    BaseModel,
-    NonNegativeInt,
-    confloat,
-    conint,
-    constr,
-    field_validator,
-    model_validator,
-    validate_call,
-)
+from pydantic import BaseModel, confloat, conint, constr
 
 ActionId = NewType("ActionId", constr(min_length=1))
 Float01 = NewType("Float_0_1", confloat(ge=0, le=1))
 Probability = NewType("Probability", Float01)
-Predictions = NewType("Predictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]]])
+SmabPredictions = NewType("SmabPredictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]]])
+CmabPredictions = NewType(
+    "CmabPredictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]], List[Dict[ActionId, float]]]
+)
+Predictions = NewType("Predictions", Union[SmabPredictions, CmabPredictions])
 BinaryReward = NewType("BinaryReward", conint(ge=0, le=1))
+ActionRewardLikelihood = NewType(
+    "ActionRewardLikelihood",
+    Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
+)
+ACTION_IDS_PREFIX = "action_ids_"
 
 
 class PyBanditsBaseModel(BaseModel, extra="forbid"):
     """
     BaseModel of the PyBandits library.
     """
-
-
-class Model(PyBanditsBaseModel, ABC):
-    """
-    Class to model the prior distributions.
-    """
-
-    @abstractmethod
-    def sample_proba(self) -> Probability:
-        """
-        Sample the probability of getting a positive reward.
-        """
-
-    @abstractmethod
-    def update(self, rewards: List[Any]):
-        """
-        Update the model parameters.
-        """
-
-
-class Strategy(PyBanditsBaseModel, ABC):
-    """
-    Strategy to select actions in multi-armed bandits.
-    """
-
-    @abstractmethod
-    def select_action(self, p: Dict[ActionId, Probability], actions: Optional[Dict[ActionId, Model]]) -> ActionId:
-        """
-        Select the action.
-        """
-
-
-class BaseMab(PyBanditsBaseModel, ABC):
-    """
-    Multi-armed bandit superclass.
-
-    Parameters
-    ----------
-    actions: Dict[ActionId, Model]
-        The list of possible actions, and their associated Model.
-    strategy: Strategy
-        The strategy used to select actions.
-    epsilon: Optional[Float01]
-        The probability of selecting a random action.
-    default_action: Optional[ActionId]
-        The default action to select with a probability of epsilon when using the epsilon-greedy approach.
-        If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
-    """
-
-    actions: Dict[ActionId, Model]
-    strategy: Strategy
-    epsilon: Optional[Float01]
-    default_action: Optional[ActionId]
-
-    @field_validator("actions", mode="before")
-    @classmethod
-    def at_least_2_actions_are_defined(cls, v):
-        # validate that at least 2 actions are defined
-        if len(v) < 2:
-            raise AttributeError("At least 2 actions should be defined.")
-        # validate that all actions are of the same configuration
-        action_models = list(v.values())
-        first_action = action_models[0]
-        first_action_type = type(first_action)
-        if any(not isinstance(action, first_action_type) for action in action_models[1:]):
-            raise AttributeError("All actions should follow the same type.")
-
-        return v
-
-    @model_validator(mode="after")
-    def check_default_action(self):
-        if not self.epsilon and self.default_action:
-            raise AttributeError("A default action should only be defined when epsilon is defined.")
-        if self.default_action and self.default_action not in self.actions:
-            raise AttributeError("The default action should be defined in the actions.")
-        return self
-
-    def _get_valid_actions(self, forbidden_actions: Optional[Set[ActionId]]) -> Set[ActionId]:
-        """
-        Given a set of forbidden action IDs, return a set of valid action IDs.
-
-        Parameters
-        ----------
-        forbidden_actions: Optional[Set[ActionId]]
-            The set of forbidden action IDs.
-
-        Returns
-        -------
-        valid_actions: Set[ActionId]
-            The list of valid (i.e. not forbidden) action IDs.
-        """
-        if forbidden_actions is None:
-            forbidden_actions = set()
-
-        if not all(a in self.actions.keys() for a in forbidden_actions):
-            raise ValueError("forbidden_actions contains invalid action IDs.")
-        valid_actions = set(self.actions.keys()) - forbidden_actions
-        if len(valid_actions) == 0:
-            raise ValueError("All actions are forbidden. You must allow at least 1 action.")
-        if self.default_action and self.default_action not in valid_actions:
-            raise ValueError("The default action is forbidden.")
-
-        return valid_actions
-
-    def _check_update_params(self, actions: List[ActionId], rewards: List[Union[NonNegativeInt, List[NonNegativeInt]]]):
-        """
-        Verify that the given list of action IDs is a subset of the currently defined actions.
-
-        Parameters
-        ----------
-        actions : List[ActionId]
-            The selected action for each sample.
-        rewards: List[Union[BinaryReward, List[BinaryReward]]]
-            The reward for each sample.
-        """
-        invalid = set(actions) - set(self.actions.keys())
-        if invalid:
-            raise AttributeError(f"The following invalid action(s) were specified: {invalid}.")
-        if len(actions) != len(rewards):
-            raise AttributeError(f"Shape mismatch: actions and rewards should have the same length {len(actions)}.")
-
-    @abstractmethod
-    @validate_call
-    def update(self, actions: List[ActionId], rewards: List[Union[BinaryReward, List[BinaryReward]]], *args, **kwargs):
-        """
-        Update the stochastic multi-armed bandit model.
-
-        actions: List[ActionId]
-            The selected action for each sample.
-        rewards: List[Union[BinaryReward, List[BinaryReward]]]
-            The reward for each sample.
-        """
-
-    @abstractmethod
-    @validate_call
-    def predict(self, forbidden_actions: Optional[Set[ActionId]] = None):
-        """
-        Predict actions.
-
-        Parameters
-        ----------
-        forbidden_actions : Optional[Set[ActionId]], default=None
-            Set of forbidden actions. If specified, the model will discard the forbidden_actions and it will only
-            consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
-            Note that: actions = allowed_actions U forbidden_actions.
-
-        Returns
-        -------
-        actions: List[ActionId] of shape (n_samples,)
-            The actions selected by the multi-armed bandit model.
-        probs: List[Dict[ActionId, float]] of shape (n_samples,)
-            The probabilities of getting a positive reward for each action.
-        """
-
-    def get_state(self) -> (str, dict):
-        """
-        Access the complete model internal state, enough to create an exact copy of the same model from it.
-        Returns
-        -------
-        model_class_name: str
-            The name of the class of the model.
-        model_state: dict
-            The internal state of the model (actions, scores, etc.).
-        """
-        model_name = self.__class__.__name__
-        state: dict = self.dict()
-        return model_name, state
-
-    @validate_call
-    def _select_epsilon_greedy_action(
-        self,
-        p: Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
-        actions: Optional[Dict[ActionId, Model]] = None,
-    ) -> ActionId:
-        """
-        Wraps self.strategy.select_action function with epsilon-greedy strategy,
-        such that with probability epsilon a default_action is selected,
-        and with probability 1-epsilon the select_action function is triggered to choose action.
-        If no default_action is provided, a random action is selected.
-
-        Reference: Reinforcement Learning: An Introduction, Ch. 2 (Sutton and Burto, 2018)
-        https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&ved=2ahUKEwjMy8WV9N2HAxVe0gIHHVjjG5sQFnoECEMQAQ&usg=AOvVaw3bKK-Y_1kf6XQVwR-UYrBY
-
-        Parameters
-        ----------
-        p: Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]]
-            The dictionary or actions and their sampled probability of getting a positive reward.
-            For MO strategy, the sampled probability is a list with elements corresponding to the objectives.
-        actions: Optional[Dict[ActionId, Model]]
-            The dictionary of actions and their associated Model.
-
-        Returns
-        -------
-        selected_action: ActionId
-            The selected action.
-
-        Raises
-        ------
-        KeyError
-            If self.default_action is not present as a key in the probabilities dictionary.
-        """
-
-        if self.epsilon:
-            if self.default_action and self.default_action not in p.keys():
-                raise KeyError(f"Default action {self.default_action} not in actions.")
-            if np.random.binomial(1, self.epsilon):
-                selected_action = self.default_action if self.default_action else np.random.choice(list(p.keys()))
-            else:
-                selected_action = self.strategy.select_action(p=p, actions=actions)
-        else:
-            selected_action = self.strategy.select_action(p=p, actions=actions)
-        return selected_action
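
The _select_epsilon_greedy_action method removed above (and, per the change log, now expected to live with BaseMab in the new mab.py) documents the epsilon-greedy mechanism in its deleted docstring. A standalone sketch of that mechanism, independent of the library's classes; note the exploit branch is a plain argmax here, whereas the library delegates it to strategy.select_action:

from typing import Dict, Optional

import numpy as np

def select_epsilon_greedy(
    p: Dict[str, float], epsilon: float = 0.1, default_action: Optional[str] = None
) -> str:
    """With probability epsilon return default_action (or a uniformly random action);
    otherwise return the action with the highest sampled reward probability."""
    if default_action is not None and default_action not in p:
        raise KeyError(f"Default action {default_action} not in actions.")
    if epsilon and np.random.binomial(1, epsilon):
        return default_action if default_action else np.random.choice(list(p.keys()))
    # exploit: the library would call strategy.select_action here instead of argmax
    return max(p, key=p.get)

# example with sampled reward probabilities per action
select_epsilon_greedy({"a1": 0.4, "a2": 0.7, "a3": 0.55}, epsilon=0.2)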
