Skip to content

Commit 61144ed

Browse files
committed
BRAIN-15566 - Access model state with get_state() function
### Changes * Implemented get_state() function in BaseMab abstract class
1 parent 5c48947 commit 61144ed

File tree

5 files changed

+242
-10
lines changed

5 files changed

+242
-10
lines changed

pybandits/base.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,9 @@ def _get_valid_actions(self, forbidden_actions: Optional[Set[ActionId]]) -> Set[
129129
return valid_actions
130130

131131
def _check_update_params(
132-
self,
133-
actions: List[ActionId],
134-
rewards: List[Union[NonNegativeInt, List[NonNegativeInt]]],
132+
self,
133+
actions: List[ActionId],
134+
rewards: List[Union[NonNegativeInt, List[NonNegativeInt]]],
135135
):
136136
"""
137137
Verify that the given list of action IDs is a subset of the currently defined actions.
@@ -152,11 +152,11 @@ def _check_update_params(
152152
@abstractmethod
153153
@validate_arguments
154154
def update(
155-
self,
156-
actions: List[ActionId],
157-
rewards: List[Union[BinaryReward, List[BinaryReward]]],
158-
*args,
159-
**kwargs,
155+
self,
156+
actions: List[ActionId],
157+
rewards: List[Union[BinaryReward, List[BinaryReward]]],
158+
*args,
159+
**kwargs,
160160
):
161161
"""
162162
Update the stochastic multi-armed bandit model.
@@ -187,3 +187,17 @@ def predict(self, forbidden_actions: Optional[Set[ActionId]] = None):
187187
probs: List[Dict[ActionId, float]] of shape (n_samples,)
188188
The probabilities of getting a positive reward for each action.
189189
"""
190+
191+
def get_state(self) -> (str, dict):
192+
"""
193+
Access the complete model internal state, enough to create an exact copy of the same model from it.
194+
Returns
195+
-------
196+
model_class_name: str
197+
The name of the class of the model.
198+
model_state: dict
199+
The internal state of the model (actions, scores, etc.).
200+
"""
201+
model_name = self.__class__.__name__
202+
state: dict = self.dict()
203+
return model_name, state

tests/test_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def update(
4444
def predict():
4545
pass
4646

47+
def get_state(self) -> (str, dict):
48+
model_name = self.__class__.__name__
49+
state: dict = {"actions": self.actions}
50+
return model_name, state
51+
4752

4853
def test_base_mab_raise_on_less_than_2_actions():
4954
with pytest.raises(ValidationError):

tests/test_cmab.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222

23+
import json
24+
from typing import Optional
2325
import numpy as np
2426
import pandas as pd
2527
import pytest
2628
from hypothesis import given, settings
2729
from hypothesis import strategies as st
2830
from pydantic import ValidationError
2931

32+
from pybandits.base import Float01
3033
from pybandits.cmab import (
3134
CmabBernoulli,
3235
CmabBernoulliBAI,
@@ -40,12 +43,14 @@
4043
StudentT,
4144
create_bayesian_logistic_regression_cc_cold_start,
4245
create_bayesian_logistic_regression_cold_start,
46+
BayesianLogisticRegressionCC
4347
)
4448
from pybandits.strategy import (
4549
BestActionIdentification,
4650
ClassicBandit,
4751
CostControlBandit,
4852
)
53+
from tests.test_utils import is_serializable
4954

5055
########################################################################################################################
5156

@@ -311,6 +316,33 @@ def run_predict(mab):
311316
assert mab != create_cmab_bernoulli_cold_start(action_ids=["a1", "a2", "a3", "a4", "a5"], n_features=n_features)
312317
run_predict(mab=mab)
313318

319+
@pytest.mark.parametrize("action_dict, action_ids", [
320+
({"a1": BayesianLogisticRegression(alpha=StudentT(mu=1, sigma=2), betas=[StudentT(), StudentT(), StudentT()]),
321+
"a2": create_bayesian_logistic_regression_cold_start(n_betas=3)}, None),
322+
(None, {"a0", "a1", "a2"})])
323+
def test_cmab_get_state(action_dict: Optional[dict], action_ids: Optional[set]):
324+
if action_dict:
325+
cmab = CmabBernoulli(actions=action_dict)
326+
expected_state = json.loads(json.dumps({
327+
"actions": action_dict,
328+
"strategy": {},
329+
'predict_with_proba': False,
330+
'predict_actions_randomly': False}, default=dict))
331+
else:
332+
cmab = create_cmab_bernoulli_cold_start(action_ids=action_ids, n_features=3)
333+
expected_state = json.loads(json.dumps({
334+
"actions": {action_id: create_bayesian_logistic_regression_cold_start(n_betas=3)
335+
for action_id in action_ids},
336+
"strategy": {},
337+
'predict_with_proba': False,
338+
'predict_actions_randomly': True}, default=dict))
339+
340+
cmab_state = cmab.get_state()
341+
assert cmab_state[0] == "CmabBernoulli"
342+
assert cmab_state[1] == expected_state
343+
344+
assert is_serializable(cmab_state[1]), "Internal state is not serializable"
345+
314346

315347
########################################################################################################################
316348

@@ -450,6 +482,33 @@ def test_cmab_bai_update(n_samples=100, n_features=3):
450482
)
451483
assert not mab.predict_actions_randomly
452484

485+
@pytest.mark.parametrize("action_dict, action_ids, exploit_p", [
486+
({"a1": BayesianLogisticRegression(alpha=StudentT(mu=1, sigma=2), betas=[StudentT(), StudentT(), StudentT()]),
487+
"a2": create_bayesian_logistic_regression_cold_start(n_betas=3)}, None, 0.5),
488+
(None, {"a0", "a1", "a2"}, 0.8)])
489+
def test_cmab_bai_get_state(action_dict: Optional[dict], action_ids: Optional[set], exploit_p: Float01):
490+
if action_dict:
491+
cmab = CmabBernoulliBAI(actions=action_dict, exploit_p=exploit_p)
492+
expected_state = json.loads(json.dumps({
493+
"actions": action_dict,
494+
"strategy": {"exploit_p": exploit_p},
495+
'predict_with_proba': False,
496+
'predict_actions_randomly': False}, default=dict))
497+
else:
498+
cmab = create_cmab_bernoulli_bai_cold_start(action_ids=action_ids, n_features=3, exploit_p=exploit_p)
499+
expected_state = json.loads(json.dumps({
500+
"actions": {action_id: create_bayesian_logistic_regression_cold_start(n_betas=3)
501+
for action_id in action_ids},
502+
"strategy": {"exploit_p": exploit_p},
503+
'predict_with_proba': False,
504+
'predict_actions_randomly': True}, default=dict))
505+
506+
cmab_state = cmab.get_state()
507+
assert cmab_state[0] == "CmabBernoulliBAI"
508+
assert cmab_state[1] == expected_state
509+
510+
assert is_serializable(cmab_state[1]), "Internal state is not serializable"
511+
453512

454513
########################################################################################################################
455514

@@ -597,3 +656,32 @@ def test_cmab_cc_update(n_samples=100, n_features=3):
597656
]
598657
)
599658
assert not mab.predict_actions_randomly
659+
660+
661+
@pytest.mark.parametrize("action_dict, action_ids_cost, subsidy_factor", [
662+
({"a1": BayesianLogisticRegressionCC(alpha=StudentT(mu=1, sigma=2), betas=[StudentT(), StudentT()], cost=0.1),
663+
"a2": create_bayesian_logistic_regression_cc_cold_start(n_betas=2, cost=0.2)}, None, 0.3),
664+
(None, {"a0": 0.1, "a1": 0.2, "a2": 0.3}, 0.5)])
665+
def test_cmab_cc_get_state(action_dict: Optional[dict], action_ids_cost: Optional[dict], subsidy_factor: Float01):
666+
if action_dict:
667+
cmab = CmabBernoulliCC(actions=action_dict, subsidy_factor=subsidy_factor)
668+
expected_state = json.loads(json.dumps({
669+
"actions": action_dict,
670+
"strategy": {"subsidy_factor": subsidy_factor},
671+
'predict_with_proba': True,
672+
'predict_actions_randomly': False}, default=dict))
673+
else:
674+
cmab = create_cmab_bernoulli_cc_cold_start(action_ids_cost=action_ids_cost, n_features=2,
675+
subsidy_factor=subsidy_factor)
676+
expected_state = json.loads(json.dumps({
677+
"actions": {action_id: create_bayesian_logistic_regression_cc_cold_start(n_betas=2, cost=action_cost)
678+
for action_id, action_cost in action_ids_cost.items()},
679+
"strategy": {"subsidy_factor": subsidy_factor},
680+
'predict_with_proba': True,
681+
'predict_actions_randomly': True}, default=dict))
682+
683+
cmab_state = cmab.get_state()
684+
assert cmab_state[0] == "CmabBernoulliCC"
685+
assert cmab_state[1] == expected_state
686+
687+
assert is_serializable(cmab_state[1]), "Internal state is not serializable"

tests/test_smab.py

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
# SOFTWARE.
2222

2323
from copy import deepcopy
24-
from typing import List
24+
from typing import List, Optional, Tuple
2525

2626
import pytest
2727
from hypothesis import given
2828
from hypothesis import strategies as st
2929
from pydantic import ValidationError
3030

31-
from pybandits.base import BinaryReward
31+
from pybandits.base import BinaryReward, Float01
3232
from pybandits.model import Beta, BetaCC, BetaMO, BetaMOCC
3333
from pybandits.smab import (
3434
SmabBernoulli,
@@ -48,6 +48,7 @@
4848
MultiObjectiveBandit,
4949
MultiObjectiveCostControlBandit,
5050
)
51+
from tests.test_utils import is_serializable
5152

5253
########################################################################################################################
5354

@@ -200,6 +201,26 @@ def test_smab_accepts_only_valid_actions(s):
200201
SmabBernoulli(actions={s: Beta(), s + "_": Beta()})
201202

202203

204+
@pytest.mark.parametrize("action_dict, action_ids", [
205+
({"a0": Beta(), "a1": Beta(), "a2": Beta()}, None),
206+
({"a0": Beta(), "a1": Beta(n_successes=5, n_failures=5), "a2": Beta(n_successes=10, n_failures=1),
207+
"a3": Beta(n_successes=10, n_failures=5), "a4": Beta(n_successes=100, n_failures=4), "a5": Beta()}, None),
208+
(None, {"a0", "a1", "a2"})])
209+
def test_smab_get_state(action_dict: Optional[dict], action_ids: Optional[set]):
210+
if action_dict:
211+
smab = SmabBernoulli(actions=action_dict)
212+
expected_state = {"actions": action_dict, "strategy": {}}
213+
else:
214+
smab = create_smab_bernoulli_cold_start(action_ids=action_ids)
215+
expected_state = {"actions": {action_id: Beta() for action_id in action_ids}, "strategy": {}}
216+
217+
smab_state = smab.get_state()
218+
assert smab_state[0] == "SmabBernoulli"
219+
assert smab_state[1] == expected_state
220+
221+
assert is_serializable(smab_state[1]), "Internal state is not serializable"
222+
223+
203224
########################################################################################################################
204225

205226

@@ -265,6 +286,27 @@ def test_smabbai_with_betacc():
265286
)
266287

267288

289+
@pytest.mark.parametrize("action_dict, action_ids, exploit_p", [
290+
({"a0": Beta(), "a1": Beta(), "a2": Beta()}, None, 0.3),
291+
({"a0": Beta(), "a1": Beta(n_successes=5, n_failures=5), "a2": Beta(n_successes=10, n_failures=1),
292+
"a3": Beta(n_successes=10, n_failures=5), "a4": Beta(n_successes=100, n_failures=4), "a5": Beta()}, None, 0.8),
293+
(None, {"a0", "a1", "a2"}, 0.5)])
294+
def test_smab_bai_get_state(action_dict: Optional[dict], action_ids: Optional[set], exploit_p: Float01):
295+
if action_dict:
296+
smab = SmabBernoulliBAI(actions=action_dict, exploit_p=exploit_p)
297+
expected_state = {"actions": action_dict, "strategy": {"exploit_p": exploit_p}}
298+
else:
299+
smab = create_smab_bernoulli_bai_cold_start(action_ids=action_ids, exploit_p=exploit_p)
300+
expected_state = {"actions": {action_id: Beta() for action_id in action_ids},
301+
"strategy": {"exploit_p": exploit_p}}
302+
303+
smab_state = smab.get_state()
304+
assert smab_state[0] == "SmabBernoulliBAI"
305+
assert smab_state[1] == expected_state
306+
307+
assert is_serializable(smab_state[1]), "Internal state is not serializable"
308+
309+
268310
########################################################################################################################
269311

270312

@@ -327,6 +369,28 @@ def test_smabcc_update():
327369
s.update(actions=["a1", "a1"], rewards=[1, 0])
328370

329371

372+
@pytest.mark.parametrize("action_dict, action_ids_cost, subsidy_factor", [
373+
({"a0": BetaCC(cost=0.1), "a1": BetaCC(cost=0.2), "a2": BetaCC(cost=0.3)}, None, 0.3),
374+
({"a0": BetaCC(cost=0.1), "a1": BetaCC(n_successes=5, n_failures=5, cost=0.2), "a2":
375+
BetaCC(n_successes=10, n_failures=1, cost=0.3), "a3": BetaCC(n_successes=10, n_failures=5, cost=0.4),
376+
"a4": BetaCC(n_successes=100, n_failures=4, cost=0.5), "a5": BetaCC(cost=0.6)}, None, 0.8),
377+
(None, {"a0": 0.1, "a1": 0.2, "a2": 0.3}, 0.5)])
378+
def test_smab_cc_get_state(action_dict: Optional[dict], action_ids_cost: Optional[dict], subsidy_factor: Float01):
379+
if action_dict:
380+
smab = SmabBernoulliCC(actions=action_dict, subsidy_factor=subsidy_factor)
381+
expected_state = {"actions": action_dict, "strategy": {"subsidy_factor": subsidy_factor}}
382+
else:
383+
smab = create_smab_bernoulli_cc_cold_start(action_ids_cost=action_ids_cost, subsidy_factor=subsidy_factor)
384+
expected_state = {"actions": {k: BetaCC(cost=v) for k, v in action_ids_cost.items()},
385+
"strategy": {"subsidy_factor": subsidy_factor}}
386+
387+
smab_state = smab.get_state()
388+
assert smab_state[0] == "SmabBernoulliCC"
389+
assert smab_state[1] == expected_state
390+
391+
assert is_serializable(smab_state[1]), "Internal state is not serializable"
392+
393+
330394
########################################################################################################################
331395

332396

@@ -414,6 +478,31 @@ def test_smab_mo_update():
414478
mab.update(actions=["a1", "a1"], rewards=[[1, 0, 1], [1, 1, 0]])
415479

416480

481+
@pytest.mark.parametrize("action_dict, action_ids_with_obj", [
482+
({"a0": BetaMO(counters=[Beta(), Beta()]), "a1": BetaMO(counters=[Beta(), Beta()]),
483+
"a2": BetaMO(counters=[Beta(), Beta()])}, None),
484+
({"a0": BetaMO(counters=[Beta(), Beta()]),
485+
"a1": BetaMO(counters=[Beta(n_successes=2, n_failures=3), Beta(n_successes=4, n_failures=5)]),
486+
"a2": BetaMO(counters=[Beta(n_successes=6, n_failures=7), Beta(n_successes=8, n_failures=9)])}, None),
487+
(None, ({"a0", "a1", "a2"}, 2))
488+
])
489+
def test_smab_mo_get_state(action_dict: Optional[dict], action_ids_with_obj: Optional[Tuple[set, int]]):
490+
if action_dict:
491+
smab = SmabBernoulliMO(actions=action_dict)
492+
expected_state = {"actions": action_dict, "strategy": {}}
493+
else:
494+
smab = create_smab_bernoulli_mo_cold_start(action_ids=action_ids_with_obj[0],
495+
n_objectives=action_ids_with_obj[1])
496+
expected_state = {"actions": {action_id: BetaMO(counters=[Beta()] * action_ids_with_obj[1]) for action_id in
497+
action_ids_with_obj[0]}, "strategy": {}}
498+
499+
smab_state = smab.get_state()
500+
assert smab_state[0] == "SmabBernoulliMO"
501+
assert smab_state[1] == expected_state
502+
503+
assert is_serializable(smab_state[1]), "Internal state is not serializable"
504+
505+
417506
########################################################################################################################
418507

419508

@@ -498,3 +587,30 @@ def test_smab_mo_cc_predict():
498587
forbidden = ["a1", "a3"]
499588
with pytest.raises(ValueError):
500589
s.predict(n_samples=n_samples, forbidden_actions=forbidden)
590+
591+
592+
@pytest.mark.parametrize("action_dict, action_ids_cost_with_obj", [
593+
({"a0": BetaMOCC(counters=[Beta(), Beta()], cost=0.1), "a1": BetaMOCC(counters=[Beta(), Beta()], cost=0.2),
594+
"a2": BetaMOCC(counters=[Beta(), Beta()], cost=0.3)}, None),
595+
({"a0": BetaMOCC(counters=[Beta(), Beta()], cost=0.1),
596+
"a1": BetaMOCC(counters=[Beta(n_successes=2, n_failures=3), Beta(n_successes=4, n_failures=5)], cost=0.2),
597+
"a2": BetaMOCC(counters=[Beta(n_successes=6, n_failures=7), Beta(n_successes=8, n_failures=9)], cost=0.3)}, None),
598+
(None, ({"a0": 0.1, "a1": 0.2, "a2": 0.3}, 2))
599+
])
600+
def test_smab_mo_cc_get_state(action_dict: Optional[dict],
601+
action_ids_cost_with_obj: Optional[Tuple[dict, int]]):
602+
if action_dict:
603+
smab = SmabBernoulliMOCC(actions=action_dict)
604+
expected_state = {"actions": action_dict, "strategy": {}}
605+
else:
606+
smab = create_smab_bernoulli_mo_cc_cold_start(action_ids_cost=action_ids_cost_with_obj[0],
607+
n_objectives=action_ids_cost_with_obj[1])
608+
expected_state = {"actions": {action_id: BetaMOCC(counters=[Beta()] * action_ids_cost_with_obj[1],
609+
cost=action_ids_cost_with_obj[0][action_id]) for action_id in
610+
action_ids_cost_with_obj[0]}, "strategy": {}}
611+
612+
smab_state = smab.get_state()
613+
assert smab_state[0] == "SmabBernoulliMOCC"
614+
assert smab_state[1] == expected_state
615+
616+
assert is_serializable(smab_state[1]), "Internal state is not serializable"

tests/test_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import json
2+
3+
4+
def is_serializable(something) -> bool:
5+
try:
6+
json.dumps(something)
7+
return True
8+
except Exception:
9+
return False

0 commit comments

Comments
 (0)