
Commit e2e11ee

cMAB Fast Update via Variational Inference
### Changes

* Edited BaseBayesianLogisticRegression and its inheritors in model.py to support variational inference, adding a fast_inference control parameter to the class attributes and control arguments to the update method.
* Edited BaseBayesianLogisticRegression to allow faster updates via vectorization of the PyMC operations.
* Edited the "update" UTs in test_cmab.py to support the new inference mode.
* Edited the cMAB cold-start functions to support the new inference mode.
* Removed the redundant test_execution_time.py.
* Edited the version in pyproject.toml.
1 parent fcd0896 commit e2e11ee
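
At the bandit level the change is a single new keyword on the cold-start factory functions. A minimal sketch of the intended usage, assuming synthetic context data and the existing cMAB `update` signature (`context`, `actions`, `rewards`), which this commit does not touch:

```python
import numpy as np
from pybandits.cmab import create_cmab_bernoulli_cold_start

# Cold-start bandit whose per-action Bayesian logistic regressions will be
# updated with variational inference (ADVI) instead of MCMC sampling.
mab = create_cmab_bernoulli_cold_start(
    action_ids={"action_a", "action_b"},
    n_features=5,
    fast_inference=True,  # new parameter in this commit; defaults to False (MCMC)
)

# Illustrative feedback batch: context rows, served actions, binary rewards.
context = np.random.randn(100, 5)
actions = ["action_a" if i % 2 == 0 else "action_b" for i in range(100)]
rewards = np.random.randint(0, 2, size=100).tolist()

# The update call itself is unchanged; only the inference backend behind it differs.
mab.update(context=context, actions=actions, rewards=rewards)
```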


5 files changed: +221 -406 lines changed


pybandits/cmab.py

Lines changed: 21 additions & 5 deletions
@@ -63,7 +63,8 @@ class BaseCmabBernoulli(BaseMab):
     predict_with_proba: bool
     predict_actions_randomly: bool
 
-    @field_validator("actions")
+    @field_validator("actions", mode="after")
+    @classmethod
     def check_bayesian_logistic_regression_models_len(cls, v):
         blr_betas_len = [len(b.betas) for b in v.values()]
         if not all(blr_betas_len[0] == x for x in blr_betas_len):
@@ -329,6 +330,7 @@ def create_cmab_bernoulli_cold_start(
     n_features: PositiveInt,
     epsilon: Optional[Float01] = None,
     default_action: Optional[ActionId] = None,
+    fast_inference: bool = False,
 ) -> CmabBernoulli:
     """
     Utility function to create a Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling, with default
@@ -347,15 +349,19 @@ def create_cmab_bernoulli_cold_start(
     default_action: Optional[ActionId]
         The default action to select with a probability of epsilon when using the epsilon-greedy approach.
         If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
+    fast_inference: bool, defaults to False
+        Whether to utilize MCMC (False) or variational inference (True) for the Bayesian inference on update
 
     Returns
     -------
     cmab: CmabBernoulli
         Contextual Multi-Armed Bandit with strategy = ClassicBandit
     """
     actions = {}
-    for a in set(action_ids):
-        actions[a] = create_bayesian_logistic_regression_cold_start(n_betas=n_features)
+    for action_id in set(action_ids):
+        actions[action_id] = create_bayesian_logistic_regression_cold_start(
+            n_betas=n_features, fast_inference=fast_inference
+        )
     mab = CmabBernoulli(actions=actions, epsilon=epsilon, default_action=default_action)
     mab.predict_actions_randomly = True
     return mab
@@ -368,6 +374,7 @@ def create_cmab_bernoulli_bai_cold_start(
     exploit_p: Optional[Float01] = None,
     epsilon: Optional[Float01] = None,
     default_action: Optional[ActionId] = None,
+    fast_inference: bool = False,
 ) -> CmabBernoulliBAI:
     """
     Utility function to create a Contextual Bernoulli Multi-Armed Bandit with Thompson Sampling, and Best Action
@@ -395,6 +402,9 @@ def create_cmab_bernoulli_bai_cold_start(
     default_action: Optional[ActionId]
         The default action to select with a probability of epsilon when using the epsilon-greedy approach.
         If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
+    fast_inference: bool, defaults to False
+        Whether to utilize standard MCMC (False) or faster variational inference (True)
+        for the Bayesian inference on update steps.
 
     Returns
     -------
@@ -403,7 +413,7 @@ def create_cmab_bernoulli_bai_cold_start(
     """
     actions = {}
     for a in set(action_ids):
-        actions[a] = create_bayesian_logistic_regression_cold_start(n_betas=n_features)
+        actions[a] = create_bayesian_logistic_regression_cold_start(n_betas=n_features, fast_inference=fast_inference)
     mab = CmabBernoulliBAI(actions=actions, exploit_p=exploit_p, epsilon=epsilon, default_action=default_action)
     mab.predict_actions_randomly = True
     return mab
@@ -416,6 +426,7 @@ def create_cmab_bernoulli_cc_cold_start(
     subsidy_factor: Optional[Float01] = None,
     epsilon: Optional[Float01] = None,
     default_action: Optional[ActionId] = None,
+    fast_inference: bool = False,
 ) -> CmabBernoulliCC:
     """
     Utility function to create a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling, and Cost Control
@@ -449,6 +460,9 @@ def create_cmab_bernoulli_cc_cold_start(
     default_action: Optional[ActionId]
         The default action to select with a probability of epsilon when using the epsilon-greedy approach.
         If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
+    fast_inference: bool, defaults to False
+        Whether to utilize standard MCMC (False) or faster variational inference (True)
+        for the Bayesian inference on update steps.
 
     Returns
     -------
@@ -457,7 +471,9 @@ def create_cmab_bernoulli_cc_cold_start(
     """
     actions = {}
     for a, cost in action_ids_cost.items():
-        actions[a] = create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=cost)
+        actions[a] = create_bayesian_logistic_regression_cc_cold_start(
+            n_betas=n_features, cost=cost, fast_inference=fast_inference
+        )
     mab = CmabBernoulliCC(
         actions=actions, subsidy_factor=subsidy_factor, epsilon=epsilon, default_action=default_action
     )
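
The Best Action Identification and Cost Control factories in the same file gain the identical flag. A sketch with illustrative `exploit_p`, per-action costs, and `subsidy_factor` values:

```python
from pybandits.cmab import (
    create_cmab_bernoulli_bai_cold_start,
    create_cmab_bernoulli_cc_cold_start,
)

# Best Action Identification variant with variational-inference updates.
mab_bai = create_cmab_bernoulli_bai_cold_start(
    action_ids={"action_a", "action_b"},
    n_features=5,
    exploit_p=0.5,
    fast_inference=True,
)

# Cost Control variant: per-action costs plus the same fast_inference switch.
mab_cc = create_cmab_bernoulli_cc_cold_start(
    action_ids_cost={"action_a": 0.5, "action_b": 1.0},
    n_features=5,
    subsidy_factor=0.7,
    fast_inference=True,
)
```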

pybandits/model.py

Lines changed: 119 additions & 57 deletions
@@ -22,9 +22,11 @@
 
 
 from random import betavariate
-from typing import List, Tuple
+from typing import List, Optional, Tuple, Union
 
-from numpy import array, c_, exp, insert, mean, multiply, ones, sqrt, std
+import numpy as np
+import pymc.math as pmath
+from numpy import array, c_, insert, mean, multiply, ones, sqrt, std
 from numpy.typing import ArrayLike
 from pydantic import (
     Field,
@@ -34,11 +36,10 @@
     model_validator,
     validate_call,
 )
-from pymc import Bernoulli, Data, Deterministic, sample
+from pymc import Bernoulli, Data, Deterministic, fit, sample
 from pymc import Model as PymcModel
 from pymc import StudentT as PymcStudentT
-from pymc.math import sigmoid
-from pytensor.tensor import dot
+from pytensor.tensor import TensorVariable, dot
 from scipy.stats import t
 
 from pybandits.base import BinaryReward, Model, Probability, PyBanditsBaseModel
@@ -231,16 +232,62 @@ class BaseBayesianLogisticRegression(Model):
 
     Parameters
     ----------
-    alpha: StudentT
+    alpha : StudentT
         Student's t-distribution of the alpha coefficient.
-    betas: StudentT
+    betas : StudentT
         Student's t-distributions of the betas coefficients.
-    params_sample: Dict
-        Parameters for the function pymc.sample()
+    fast_inference : bool, defaults to False
+        Whether to utilize standard MCMC (False) or faster variational inference (True)
+        for the Bayesian inference on update steps.
+    update_kwargs : Optional[dict], uses default values if not specified
+        Additional arguments to pass to the update method.
     """
 
     alpha: StudentT
     betas: List[StudentT] = Field(..., min_items=1)
+    fast_inference: bool = False
+    update_kwargs: Optional[dict] = None
+    _default_update_kwargs = dict(draws=1000, progressbar=False, return_inferencedata=False)
+    _default_mcmc_kwargs = dict(
+        tune=500,
+        draws=1000,
+        chains=2,
+        init="adapt_diag",
+        cores=1,
+        target_accept=0.95,
+        progressbar=False,
+        return_inferencedata=False,
+    )
+    _default_variational_inference_kwargs = dict(method="advi")
+
+    @model_validator(mode="after")
+    def arrange_update_kwargs(self):
+        if self.update_kwargs is None:
+            self.update_kwargs = self._default_update_kwargs
+        if self.fast_inference:
+            self.update_kwargs = {**self._default_variational_inference_kwargs, **self.update_kwargs}
+        else:
+            self.update_kwargs = {**self._default_mcmc_kwargs, **self.update_kwargs}
+        return self
+
+    @classmethod
+    def _stable_sigmoid(cls, x: Union[np.ndarray, TensorVariable]) -> Union[np.ndarray, TensorVariable]:
+        """
+        Vectorized sigmoid function that avoids overflow and underflow.
+        Compatible with both numpy and PyMC3 tensors.
+        Parameters
+        ----------
+        x : Union[np.ndarray, TensorVariable]
+            Input values.
+
+        Returns
+        -------
+        prob : Union[np.ndarray, TensorVariable]
+            Sigmoid function applied to the input values.
+        """
+        backend = np if isinstance(x, np.ndarray) else pmath
+        prob = backend.where(x >= 0, 1 / (1 + backend.exp(-x)), backend.exp(x) / (1 + backend.exp(x)))
+        return prob
 
     @validate_call(config=dict(arbitrary_types_allowed=True))
     def check_context_matrix(self, context: ArrayLike):
@@ -249,12 +296,12 @@ def check_context_matrix(self, context: ArrayLike):
 
         Parameters
         ----------
-        context: ArrayLike of shape (n_samples, n_features)
+        context : ArrayLike of shape (n_samples, n_features)
            Matrix of contextual features.
 
         Returns
         -------
-        context: pandas DataFrame of shape (n_samples, n_features)
+        context : pandas DataFrame of shape (n_samples, n_features)
            Matrix of contextual features.
         """
         try:
@@ -304,25 +351,12 @@ def sample_proba(self, context: ArrayLike) -> Tuple[Probability, float]:
         weighted_sum = multiply(context_ext, coeff.T).sum(axis=1)
 
         # compute the probability with the sigmoid function
-        prob = 1.0 / (1.0 + exp(-weighted_sum))
+        prob = self._stable_sigmoid(weighted_sum)
 
         return prob, weighted_sum
 
     @validate_call(config=dict(arbitrary_types_allowed=True))
-    def update(
-        self,
-        context: ArrayLike,
-        rewards: List[BinaryReward],
-        tune=500,
-        draws=1000,
-        chains=2,
-        init="adapt_diag",
-        cores=2,
-        target_accept=0.95,
-        progressbar=False,
-        return_inferencedata=False,
-        **kwargs,
-    ):
+    def update(self, context: ArrayLike, rewards: List[BinaryReward]):
         """
         Update the model parameters.
 
@@ -344,40 +378,39 @@ def update(
             # if model was never updated priors_parameters = default arguments
             # else priors_parameters are calculated from traces of the previous update
             alpha = PymcStudentT("alpha", mu=self.alpha.mu, sigma=self.alpha.sigma, nu=self.alpha.nu)
-            betas = [
-                PymcStudentT("beta" + str(i), mu=self.betas[i].mu, sigma=self.betas[i].sigma, nu=self.betas[i].nu)
-                for i in range(len(self.betas))
-            ]
+            beta_mu = [b.mu for b in self.betas]
+            beta_sigma = [b.sigma for b in self.betas]
+            beta_nu = [b.nu for b in self.betas]
+            betas = PymcStudentT("betas", mu=beta_mu, sigma=beta_sigma, nu=beta_nu, shape=len(self.betas))
 
-            context = Data("context", context)
-            rewards = Data("rewards", rewards)
+            context = Data("context", context, mutable=False)
+            rewards = Data("rewards", rewards, mutable=False)
 
             # Likelihood (sampling distribution) of observations
             weighted_sum = Deterministic("weighted_sum", alpha + dot(betas, context.T))
-            p = Deterministic("p", sigmoid(weighted_sum))
+            p = Deterministic("p", self._stable_sigmoid(weighted_sum))
 
             # Bernoulli random vector with probability of success given by sigmoid function and actual data as observed
             _ = Bernoulli("likelihood", p=p, observed=rewards)
 
             # update traces object by sampling from posterior distribution
-            trace = sample(
-                tune=tune,
-                draws=draws,
-                chains=chains,
-                init=init,
-                cores=cores,
-                target_accept=target_accept,
-                progressbar=progressbar,
-                return_inferencedata=return_inferencedata,
-                **kwargs,
-            )
+            if self.fast_inference:
+                # variational inference
+                update_kwargs = self.update_kwargs.copy()
+                approx = fit(method=update_kwargs.pop("method"))
+                trace = approx.sample(**update_kwargs)
+            else:
+                # MCMC
+                trace = sample(**self.update_kwargs)
 
         # compute mean and std of the coefficients distributions
         self.alpha.mu = mean(trace["alpha"])
         self.alpha.sigma = std(trace["alpha"], ddof=1)
-        for i in range(len(self.betas)):
-            self.betas[i].mu = mean(trace["beta" + str(i)])
-            self.betas[i].sigma = std(trace["beta" + str(i)], ddof=1)
+        betas_mu = mean(trace["betas"], axis=0)
+        betas_std = std(trace["betas"], axis=0, ddof=1)
+        self.betas = [
+            StudentT(mu=mu, sigma=sigma, nu=beta.nu) for mu, sigma, beta in zip(betas_mu, betas_std, self.betas)
+        ]
 
 
 class BayesianLogisticRegression(BaseBayesianLogisticRegression):
@@ -392,12 +425,15 @@ class BayesianLogisticRegression(BaseBayesianLogisticRegression):
 
     Parameters
     ----------
-    alpha: StudentT
+    alpha : StudentT
         Student's t-distribution of the alpha coefficient.
-    betas: StudentT
+    betas : StudentT
         Student's t-distributions of the betas coefficients.
-    params_sample: Dict
-        Parameters for the function pymc.sample()
+    fast_inference : bool, defaults to False
+        Whether to utilize standard MCMC (False) or faster variational inference (True)
+        for the Bayesian inference on update steps.
+    update_kwargs: Optional[dict], uses default values if not specified
+        Additional arguments to pass to the update method.
     """
 
 
@@ -417,16 +453,21 @@ class BayesianLogisticRegressionCC(BaseBayesianLogisticRegression):
         Student's t-distribution of the alpha coefficient.
     betas: StudentT
         Student's t-distributions of the betas coefficients.
-    params_sample: Dict
-        Parameters for the function pymc.sample()
+    fast_inference : bool, defaults to False
+        Whether to utilize standard MCMC (False) or faster variational inference (True)
+        for the Bayesian inference on update steps.
+    update_kwargs : Optional[dict], uses default values if not specified
+        Additional arguments to pass to the update method.
     cost: NonNegativeFloat
         Cost associated to the Bayesian Logistic Regression model.
     """
 
     cost: NonNegativeFloat
 
 
-def create_bayesian_logistic_regression_cold_start(n_betas: PositiveInt) -> BayesianLogisticRegression:
+def create_bayesian_logistic_regression_cold_start(
+    n_betas: PositiveInt, fast_inference: bool = False, update_kwargs: Optional[dict] = None
+) -> BayesianLogisticRegression:
     """
     Utility function to create a Bayesian Logistic Regression model, with default parameters.
 
@@ -441,17 +482,27 @@ def create_bayesian_logistic_regression_cold_start(n_betas: PositiveInt) -> Baye
     n_betas : PositiveInt
         The number of betas of the Bayesian Logistic Regression model. This is also the number of features expected
         after in the context matrix.
+    fast_inference : bool, defaults to False
+        Whether to utilize standard MCMC (False) or faster variational inference (True)
+        for the Bayesian inference on update steps.
+    update_kwargs : Optional[dict], uses default values if not specified
+        Additional arguments to pass to the update method.
 
     Returns
     -------
     blr: BayesianLogisticRegression
         The Bayesian Logistic Regression model.
     """
-    return BayesianLogisticRegression(alpha=StudentT(), betas=[StudentT() for _ in range(n_betas)])
+    return BayesianLogisticRegression(
+        alpha=StudentT(),
+        betas=[StudentT() for _ in range(n_betas)],
+        fast_inference=fast_inference,
+        update_kwargs=update_kwargs,
+    )
 
 
 def create_bayesian_logistic_regression_cc_cold_start(
-    n_betas: PositiveInt, cost: NonNegativeFloat
+    n_betas: PositiveInt, cost: NonNegativeFloat, fast_inference: bool = False, update_kwargs: Optional[dict] = None
 ) -> BayesianLogisticRegressionCC:
     """
     Utility function to create a Bayesian Logistic Regression model with cost control, with default parameters.
@@ -469,10 +520,21 @@ def create_bayesian_logistic_regression_cc_cold_start(
         after in the context matrix.
     cost: NonNegativeFloat
         Cost associated to the Bayesian Logistic Regression model.
+    fast_inference : bool, defaults to False
+        Whether to utilize standard MCMC (False) or faster variational inference (True)
+        for the Bayesian inference on update steps.
+    update_kwargs : Optional[dict], uses default values if not specified
+        Additional arguments to pass to the update method.
 
     Returns
     -------
     blr: BayesianLogisticRegressionCC
         The Bayesian Logistic Regression model.
     """
-    return BayesianLogisticRegressionCC(alpha=StudentT(), betas=[StudentT() for _ in range(n_betas)], cost=cost)
+    return BayesianLogisticRegressionCC(
+        alpha=StudentT(),
+        betas=[StudentT() for _ in range(n_betas)],
+        cost=cost,
+        fast_inference=fast_inference,
+        update_kwargs=update_kwargs,
+    )
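
At the model level, the same switch plus the new `update_kwargs` field can be passed to the logistic-regression factories directly. A minimal sketch, assuming an illustrative draw count and the `update(context, rewards)` signature shown in the diff above; `return_inferencedata=False` mirrors the commit's default update kwargs:

```python
import numpy as np
from pybandits.model import create_bayesian_logistic_regression_cold_start

# Per-action model that updates via ADVI; the draw count here is illustrative.
blr = create_bayesian_logistic_regression_cold_start(
    n_betas=5,
    fast_inference=True,
    update_kwargs={"draws": 500, "return_inferencedata": False},
)

context = np.random.randn(50, 5)
rewards = np.random.randint(0, 2, size=50).tolist()
blr.update(context=context, rewards=rewards)  # signature unchanged apart from the removed sampler arguments
```

The numerical piece of the change is `_stable_sigmoid`, which selects one of two algebraically equivalent forms of the logistic function based on the sign of the input, so that `exp` is never evaluated at a large positive argument. A standalone NumPy re-statement of the same idea (the clipping is added here only because `numpy.where` evaluates both branches eagerly; it is not part of the committed code):

```python
import numpy as np

def stable_sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic function computed without overflow for large |x|."""
    # x >= 0: exp(-x) <= 1, so 1 / (1 + exp(-x)) cannot overflow.
    # x <  0: exp(x)  <= 1, so exp(x) / (1 + exp(x)) cannot overflow.
    pos = 1.0 / (1.0 + np.exp(-np.clip(x, 0.0, None)))
    neg = np.exp(np.clip(x, None, 0.0)) / (1.0 + np.exp(np.clip(x, None, 0.0)))
    return np.where(x >= 0, pos, neg)

print(stable_sigmoid(np.array([-750.0, 0.0, 750.0])))  # approximately [0.0, 0.5, 1.0]
```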

0 commit comments
