22
22
23
23
24
24
from random import betavariate
25
- from typing import List , Tuple
25
+ from typing import List , Optional , Tuple , Union
26
26
27
- from numpy import array , c_ , exp , insert , mean , multiply , ones , sqrt , std
27
+ import numpy as np
28
+ import pymc .math as pmath
29
+ from numpy import array , c_ , insert , mean , multiply , ones , sqrt , std
28
30
from numpy .typing import ArrayLike
29
31
from pydantic import (
30
32
Field ,
34
36
model_validator ,
35
37
validate_call ,
36
38
)
37
- from pymc import Bernoulli , Data , Deterministic , sample
39
+ from pymc import Bernoulli , Data , Deterministic , fit , sample
38
40
from pymc import Model as PymcModel
39
41
from pymc import StudentT as PymcStudentT
40
- from pymc .math import sigmoid
41
- from pytensor .tensor import dot
42
+ from pytensor .tensor import TensorVariable , dot
42
43
from scipy .stats import t
43
44
44
45
from pybandits .base import BinaryReward , Model , Probability , PyBanditsBaseModel
@@ -231,16 +232,62 @@ class BaseBayesianLogisticRegression(Model):
231
232
232
233
Parameters
233
234
----------
234
- alpha: StudentT
235
+ alpha : StudentT
235
236
Student's t-distribution of the alpha coefficient.
236
- betas: StudentT
237
+ betas : StudentT
237
238
Student's t-distributions of the betas coefficients.
238
- params_sample: Dict
239
- Parameters for the function pymc.sample()
239
+ fast_inference : bool, defaults to False
240
+ Whether to utilize standard MCMC (False) or faster variational inference (True)
241
+ for the Bayesian inference on update steps.
242
+ update_kwargs : Optional[dict], uses default values if not specified
243
+ Additional arguments to pass to the update method.
240
244
"""
241
245
242
246
alpha : StudentT
243
247
betas : List [StudentT ] = Field (..., min_items = 1 )
248
+ fast_inference : bool = False
249
+ update_kwargs : Optional [dict ] = None
250
+ _default_update_kwargs = dict (draws = 1000 , progressbar = False , return_inferencedata = False )
251
+ _default_mcmc_kwargs = dict (
252
+ tune = 500 ,
253
+ draws = 1000 ,
254
+ chains = 2 ,
255
+ init = "adapt_diag" ,
256
+ cores = 1 ,
257
+ target_accept = 0.95 ,
258
+ progressbar = False ,
259
+ return_inferencedata = False ,
260
+ )
261
+ _default_variational_inference_kwargs = dict (method = "advi" )
262
+
263
+ @model_validator (mode = "after" )
264
+ def arrange_update_kwargs (self ):
265
+ if self .update_kwargs is None :
266
+ self .update_kwargs = self ._default_update_kwargs
267
+ if self .fast_inference :
268
+ self .update_kwargs = {** self ._default_variational_inference_kwargs , ** self .update_kwargs }
269
+ else :
270
+ self .update_kwargs = {** self ._default_mcmc_kwargs , ** self .update_kwargs }
271
+ return self
272
+
273
+ @classmethod
274
+ def _stable_sigmoid (cls , x : Union [np .ndarray , TensorVariable ]) -> Union [np .ndarray , TensorVariable ]:
275
+ """
276
+ Vectorized sigmoid function that avoids overflow and underflow.
277
+ Compatible with both numpy and PyMC3 tensors.
278
+ Parameters
279
+ ----------
280
+ x : Union[np.ndarray, TensorVariable]
281
+ Input values.
282
+
283
+ Returns
284
+ -------
285
+ prob : Union[np.ndarray, TensorVariable]
286
+ Sigmoid function applied to the input values.
287
+ """
288
+ backend = np if isinstance (x , np .ndarray ) else pmath
289
+ prob = backend .where (x >= 0 , 1 / (1 + backend .exp (- x )), backend .exp (x ) / (1 + backend .exp (x )))
290
+ return prob
244
291
245
292
@validate_call (config = dict (arbitrary_types_allowed = True ))
246
293
def check_context_matrix (self , context : ArrayLike ):
@@ -249,12 +296,12 @@ def check_context_matrix(self, context: ArrayLike):
249
296
250
297
Parameters
251
298
----------
252
- context: ArrayLike of shape (n_samples, n_features)
299
+ context : ArrayLike of shape (n_samples, n_features)
253
300
Matrix of contextual features.
254
301
255
302
Returns
256
303
-------
257
- context: pandas DataFrame of shape (n_samples, n_features)
304
+ context : pandas DataFrame of shape (n_samples, n_features)
258
305
Matrix of contextual features.
259
306
"""
260
307
try :
@@ -304,25 +351,12 @@ def sample_proba(self, context: ArrayLike) -> Tuple[Probability, float]:
304
351
weighted_sum = multiply (context_ext , coeff .T ).sum (axis = 1 )
305
352
306
353
# compute the probability with the sigmoid function
307
- prob = 1.0 / ( 1.0 + exp ( - weighted_sum ) )
354
+ prob = self . _stable_sigmoid ( weighted_sum )
308
355
309
356
return prob , weighted_sum
310
357
311
358
@validate_call (config = dict (arbitrary_types_allowed = True ))
312
- def update (
313
- self ,
314
- context : ArrayLike ,
315
- rewards : List [BinaryReward ],
316
- tune = 500 ,
317
- draws = 1000 ,
318
- chains = 2 ,
319
- init = "adapt_diag" ,
320
- cores = 2 ,
321
- target_accept = 0.95 ,
322
- progressbar = False ,
323
- return_inferencedata = False ,
324
- ** kwargs ,
325
- ):
359
+ def update (self , context : ArrayLike , rewards : List [BinaryReward ]):
326
360
"""
327
361
Update the model parameters.
328
362
@@ -344,40 +378,39 @@ def update(
344
378
# if model was never updated priors_parameters = default arguments
345
379
# else priors_parameters are calculated from traces of the previous update
346
380
alpha = PymcStudentT ("alpha" , mu = self .alpha .mu , sigma = self .alpha .sigma , nu = self .alpha .nu )
347
- betas = [
348
- PymcStudentT ( "beta" + str ( i ), mu = self . betas [ i ]. mu , sigma = self . betas [ i ] .sigma , nu = self .betas [ i ]. nu )
349
- for i in range ( len ( self .betas ))
350
- ]
381
+ beta_mu = [b . mu for b in self . betas ]
382
+ beta_sigma = [ b .sigma for b in self .betas ]
383
+ beta_nu = [ b . nu for b in self .betas ]
384
+ betas = PymcStudentT ( "betas" , mu = beta_mu , sigma = beta_sigma , nu = beta_nu , shape = len ( self . betas ))
351
385
352
- context = Data ("context" , context )
353
- rewards = Data ("rewards" , rewards )
386
+ context = Data ("context" , context , mutable = False )
387
+ rewards = Data ("rewards" , rewards , mutable = False )
354
388
355
389
# Likelihood (sampling distribution) of observations
356
390
weighted_sum = Deterministic ("weighted_sum" , alpha + dot (betas , context .T ))
357
- p = Deterministic ("p" , sigmoid (weighted_sum ))
391
+ p = Deterministic ("p" , self . _stable_sigmoid (weighted_sum ))
358
392
359
393
# Bernoulli random vector with probability of success given by sigmoid function and actual data as observed
360
394
_ = Bernoulli ("likelihood" , p = p , observed = rewards )
361
395
362
396
# update traces object by sampling from posterior distribution
363
- trace = sample (
364
- tune = tune ,
365
- draws = draws ,
366
- chains = chains ,
367
- init = init ,
368
- cores = cores ,
369
- target_accept = target_accept ,
370
- progressbar = progressbar ,
371
- return_inferencedata = return_inferencedata ,
372
- ** kwargs ,
373
- )
397
+ if self .fast_inference :
398
+ # variational inference
399
+ update_kwargs = self .update_kwargs .copy ()
400
+ approx = fit (method = update_kwargs .pop ("method" ))
401
+ trace = approx .sample (** update_kwargs )
402
+ else :
403
+ # MCMC
404
+ trace = sample (** self .update_kwargs )
374
405
375
406
# compute mean and std of the coefficients distributions
376
407
self .alpha .mu = mean (trace ["alpha" ])
377
408
self .alpha .sigma = std (trace ["alpha" ], ddof = 1 )
378
- for i in range (len (self .betas )):
379
- self .betas [i ].mu = mean (trace ["beta" + str (i )])
380
- self .betas [i ].sigma = std (trace ["beta" + str (i )], ddof = 1 )
409
+ betas_mu = mean (trace ["betas" ], axis = 0 )
410
+ betas_std = std (trace ["betas" ], axis = 0 , ddof = 1 )
411
+ self .betas = [
412
+ StudentT (mu = mu , sigma = sigma , nu = beta .nu ) for mu , sigma , beta in zip (betas_mu , betas_std , self .betas )
413
+ ]
381
414
382
415
383
416
class BayesianLogisticRegression (BaseBayesianLogisticRegression ):
@@ -392,12 +425,15 @@ class BayesianLogisticRegression(BaseBayesianLogisticRegression):
392
425
393
426
Parameters
394
427
----------
395
- alpha: StudentT
428
+ alpha : StudentT
396
429
Student's t-distribution of the alpha coefficient.
397
- betas: StudentT
430
+ betas : StudentT
398
431
Student's t-distributions of the betas coefficients.
399
- params_sample: Dict
400
- Parameters for the function pymc.sample()
432
+ fast_inference : bool, defaults to False
433
+ Whether to utilize standard MCMC (False) or faster variational inference (True)
434
+ for the Bayesian inference on update steps.
435
+ update_kwargs: Optional[dict], uses default values if not specified
436
+ Additional arguments to pass to the update method.
401
437
"""
402
438
403
439
@@ -417,16 +453,21 @@ class BayesianLogisticRegressionCC(BaseBayesianLogisticRegression):
417
453
Student's t-distribution of the alpha coefficient.
418
454
betas: StudentT
419
455
Student's t-distributions of the betas coefficients.
420
- params_sample: Dict
421
- Parameters for the function pymc.sample()
456
+ fast_inference : bool, defaults to False
457
+ Whether to utilize standard MCMC (False) or faster variational inference (True)
458
+ for the Bayesian inference on update steps.
459
+ update_kwargs : Optional[dict], uses default values if not specified
460
+ Additional arguments to pass to the update method.
422
461
cost: NonNegativeFloat
423
462
Cost associated to the Bayesian Logistic Regression model.
424
463
"""
425
464
426
465
cost : NonNegativeFloat
427
466
428
467
429
- def create_bayesian_logistic_regression_cold_start (n_betas : PositiveInt ) -> BayesianLogisticRegression :
468
+ def create_bayesian_logistic_regression_cold_start (
469
+ n_betas : PositiveInt , fast_inference : bool = False , update_kwargs : Optional [dict ] = None
470
+ ) -> BayesianLogisticRegression :
430
471
"""
431
472
Utility function to create a Bayesian Logistic Regression model, with default parameters.
432
473
@@ -441,17 +482,27 @@ def create_bayesian_logistic_regression_cold_start(n_betas: PositiveInt) -> Baye
441
482
n_betas : PositiveInt
442
483
The number of betas of the Bayesian Logistic Regression model. This is also the number of features expected
443
484
after in the context matrix.
485
+ fast_inference : bool, defaults to False
486
+ Whether to utilize standard MCMC (False) or faster variational inference (True)
487
+ for the Bayesian inference on update steps.
488
+ update_kwargs : Optional[dict], uses default values if not specified
489
+ Additional arguments to pass to the update method.
444
490
445
491
Returns
446
492
-------
447
493
blr: BayesianLogisticRegression
448
494
The Bayesian Logistic Regression model.
449
495
"""
450
- return BayesianLogisticRegression (alpha = StudentT (), betas = [StudentT () for _ in range (n_betas )])
496
+ return BayesianLogisticRegression (
497
+ alpha = StudentT (),
498
+ betas = [StudentT () for _ in range (n_betas )],
499
+ fast_inference = fast_inference ,
500
+ update_kwargs = update_kwargs ,
501
+ )
451
502
452
503
453
504
def create_bayesian_logistic_regression_cc_cold_start (
454
- n_betas : PositiveInt , cost : NonNegativeFloat
505
+ n_betas : PositiveInt , cost : NonNegativeFloat , fast_inference : bool = False , update_kwargs : Optional [ dict ] = None
455
506
) -> BayesianLogisticRegressionCC :
456
507
"""
457
508
Utility function to create a Bayesian Logistic Regression model with cost control, with default parameters.
@@ -469,10 +520,21 @@ def create_bayesian_logistic_regression_cc_cold_start(
469
520
after in the context matrix.
470
521
cost: NonNegativeFloat
471
522
Cost associated to the Bayesian Logistic Regression model.
523
+ fast_inference : bool, defaults to False
524
+ Whether to utilize standard MCMC (False) or faster variational inference (True)
525
+ for the Bayesian inference on update steps.
526
+ update_kwargs : Optional[dict], uses default values if not specified
527
+ Additional arguments to pass to the update method.
472
528
473
529
Returns
474
530
-------
475
531
blr: BayesianLogisticRegressionCC
476
532
The Bayesian Logistic Regression model.
477
533
"""
478
- return BayesianLogisticRegressionCC (alpha = StudentT (), betas = [StudentT () for _ in range (n_betas )], cost = cost )
534
+ return BayesianLogisticRegressionCC (
535
+ alpha = StudentT (),
536
+ betas = [StudentT () for _ in range (n_betas )],
537
+ cost = cost ,
538
+ fast_inference = fast_inference ,
539
+ update_kwargs = update_kwargs ,
540
+ )
0 commit comments