Skip to content

Commit

Permalink
cMAB Fast Update via Variational Inference
Browse files Browse the repository at this point in the history
 ### Changes
 * Edited BaseBayesianLogisticRegression and inheritors on model.py to support variational inference by adding fast_inference control parameter on class attributes and adding control arguments on update method.
 * Edited BaseBayesianLogisticRegression to allow faster update via vectorization of PyMC operations.
 * Edited "update" UTs on test_cmab.py to support new inference mode.
 * Edited cMABs cold start function tto support new inference mode.
 * Removed redundant test_execution_time.py.
 * Edited version on pyproject.toml.
  • Loading branch information
Shahar-Bar committed Sep 23, 2024
1 parent 70dde37 commit 35b7da0
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 433 deletions.
8 changes: 8 additions & 0 deletions pybandits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,16 @@ class BaseMab(PyBanditsBaseModel, ABC):
@field_validator("actions", mode="before")
@classmethod
def at_least_2_actions_are_defined(cls, v):
# validate that at least 2 actions are defined
if len(v) < 2:
raise AttributeError("At least 2 actions should be defined.")
# validate that all actions are of the same configuration
action_models = list(v.values())
first_action = action_models[0]
first_action_type = type(first_action)
if any(not isinstance(action, first_action_type) for action in action_models[1:]):
raise AttributeError("All actions should follow the same type.")

return v

@model_validator(mode="after")
Expand Down
63 changes: 52 additions & 11 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
BaseBayesianLogisticRegression,
BayesianLogisticRegression,
BayesianLogisticRegressionCC,
UpdateMethods,
create_bayesian_logistic_regression_cc_cold_start,
create_bayesian_logistic_regression_cold_start,
)
Expand Down Expand Up @@ -63,13 +64,21 @@ class BaseCmabBernoulli(BaseMab):
predict_with_proba: bool
predict_actions_randomly: bool

@field_validator("actions")
def check_bayesian_logistic_regression_models_len(cls, v):
blr_betas_len = [len(b.betas) for b in v.values()]
if not all(blr_betas_len[0] == x for x in blr_betas_len):
raise AttributeError(
f"All bayesian logistic regression models must have the same n_betas. Models betas_len={blr_betas_len}."
)
@field_validator("actions", mode="after")
@classmethod
def check_bayesian_logistic_regression_models(cls, v):
action_models = list(v.values())
first_action = action_models[0]
first_action_type = type(first_action)
for action in action_models[1:]:
if not isinstance(action, first_action_type):
raise AttributeError("All actions should follow the same type.")
if not len(action.betas) == len(first_action.betas):
raise AttributeError("All actions should have the same number of betas.")
if not action.update_method == first_action.update_method:
raise AttributeError("All actions should have the same update method.")
if not action.update_kwargs == first_action.update_kwargs:
raise AttributeError("All actions should have the same update kwargs.")
return v

@validate_call(config=dict(arbitrary_types_allowed=True))
Expand Down Expand Up @@ -329,6 +338,8 @@ def create_cmab_bernoulli_cold_start(
n_features: PositiveInt,
epsilon: Optional[Float01] = None,
default_action: Optional[ActionId] = None,
update_method: UpdateMethods = "MCMC",
update_kwargs: Optional[dict] = None,
) -> CmabBernoulli:
"""
Utility function to create a Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling, with default
Expand All @@ -347,15 +358,23 @@ def create_cmab_bernoulli_cold_start(
default_action: Optional[ActionId]
The default action to select with a probability of epsilon when using the epsilon-greedy approach.
If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
update_method: UpdateMethods, defaults to MCMC
The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov
chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the
full list.
update_kwargs : Optional[dict], uses default values if not specified
Additional arguments to pass to the update method of each of the action models.
Returns
-------
cmab: CmabBernoulli
Contextual Multi-Armed Bandit with strategy = ClassicBandit
"""
actions = {}
for a in set(action_ids):
actions[a] = create_bayesian_logistic_regression_cold_start(n_betas=n_features)
for action_id in set(action_ids):
actions[action_id] = create_bayesian_logistic_regression_cold_start(
n_betas=n_features, update_method=update_method, update_kwargs=update_kwargs
)
mab = CmabBernoulli(actions=actions, epsilon=epsilon, default_action=default_action)
mab.predict_actions_randomly = True
return mab
Expand All @@ -368,6 +387,8 @@ def create_cmab_bernoulli_bai_cold_start(
exploit_p: Optional[Float01] = None,
epsilon: Optional[Float01] = None,
default_action: Optional[ActionId] = None,
update_method: UpdateMethods = "MCMC",
update_kwargs: Optional[dict] = None,
) -> CmabBernoulliBAI:
"""
Utility function to create a Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling, and Best Action
Expand Down Expand Up @@ -395,6 +416,12 @@ def create_cmab_bernoulli_bai_cold_start(
default_action: Optional[ActionId]
The default action to select with a probability of epsilon when using the epsilon-greedy approach.
If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
update_method: UpdateMethods, defaults to MCMC
The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov
chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the
full list.
update_kwargs : Optional[dict], uses default values if not specified
Additional arguments to pass to the update method of each of the action models.
Returns
-------
Expand All @@ -403,7 +430,11 @@ def create_cmab_bernoulli_bai_cold_start(
"""
actions = {}
for a in set(action_ids):
actions[a] = create_bayesian_logistic_regression_cold_start(n_betas=n_features)
actions[a] = create_bayesian_logistic_regression_cold_start(
n_betas=n_features,
update_method=update_method,
update_kwargs=update_kwargs,
)
mab = CmabBernoulliBAI(actions=actions, exploit_p=exploit_p, epsilon=epsilon, default_action=default_action)
mab.predict_actions_randomly = True
return mab
Expand All @@ -416,6 +447,8 @@ def create_cmab_bernoulli_cc_cold_start(
subsidy_factor: Optional[Float01] = None,
epsilon: Optional[Float01] = None,
default_action: Optional[ActionId] = None,
update_method: UpdateMethods = "MCMC",
update_kwargs: Optional[dict] = None,
) -> CmabBernoulliCC:
"""
Utility function to create a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling, and Cost Control
Expand Down Expand Up @@ -449,6 +482,12 @@ def create_cmab_bernoulli_cc_cold_start(
default_action: Optional[ActionId]
The default action to select with a probability of epsilon when using the epsilon-greedy approach.
If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
update_method: UpdateMethods, defaults to MCMC
The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov
chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the
full list.
update_kwargs : Optional[dict], uses default values if not specified
Additional arguments to pass to the update method.
Returns
-------
Expand All @@ -457,7 +496,9 @@ def create_cmab_bernoulli_cc_cold_start(
"""
actions = {}
for a, cost in action_ids_cost.items():
actions[a] = create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=cost)
actions[a] = create_bayesian_logistic_regression_cc_cold_start(
n_betas=n_features, cost=cost, update_method=update_method, update_kwargs=update_kwargs
)
mab = CmabBernoulliCC(
actions=actions, subsidy_factor=subsidy_factor, epsilon=epsilon, default_action=default_action
)
Expand Down
Loading

0 comments on commit 35b7da0

Please sign in to comment.