From a2d4d07887dd621633ee8916fa748f74e3273d02 Mon Sep 17 00:00:00 2001 From: adarmiento Date: Tue, 24 Oct 2023 15:06:56 +0200 Subject: [PATCH] BRAIN-15566 - Access model state with get_state() function ### Changes * Implemented get_state() function in BaseMab abstract class --- pybandits/base.py | 14 +++++ tests/test_base.py | 5 ++ tests/test_cmab.py | 102 +++++++++++++++++++++++++++++++++++- tests/test_smab.py | 123 +++++++++++++++++++++++++++++++++++++++++++- tests/test_utils.py | 9 ++++ 5 files changed, 250 insertions(+), 3 deletions(-) create mode 100644 tests/test_utils.py diff --git a/pybandits/base.py b/pybandits/base.py index adbb575..818599f 100644 --- a/pybandits/base.py +++ b/pybandits/base.py @@ -187,3 +187,17 @@ def predict(self, forbidden_actions: Optional[Set[ActionId]] = None): probs: List[Dict[ActionId, float]] of shape (n_samples,) The probabilities of getting a positive reward for each action. """ + + def get_state(self) -> (str, dict): + """ + Access the complete model internal state, enough to create an exact copy of the same model from it. + Returns + ------- + model_class_name: str + The name of the class of the model. + model_state: dict + The internal state of the model (actions, scores, etc.). + """ + model_name = self.__class__.__name__ + state: dict = self.dict() + return model_name, state diff --git a/tests/test_base.py b/tests/test_base.py index a9c14a9..ecb6372 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -44,6 +44,11 @@ def update( def predict(): pass + def get_state(self) -> (str, dict): + model_name = self.__class__.__name__ + state: dict = {"actions": self.actions} + return model_name, state + def test_base_mab_raise_on_less_than_2_actions(): with pytest.raises(ValidationError): diff --git a/tests/test_cmab.py b/tests/test_cmab.py index 38d0648..4620901 100644 --- a/tests/test_cmab.py +++ b/tests/test_cmab.py @@ -20,13 +20,16 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import json + import numpy as np import pandas as pd import pytest from hypothesis import given, settings from hypothesis import strategies as st -from pydantic import ValidationError +from pydantic import NonNegativeFloat, ValidationError +from pybandits.base import Float01 from pybandits.cmab import ( CmabBernoulli, CmabBernoulliBAI, @@ -37,6 +40,7 @@ ) from pybandits.model import ( BayesianLogisticRegression, + BayesianLogisticRegressionCC, StudentT, create_bayesian_logistic_regression_cc_cold_start, create_bayesian_logistic_regression_cold_start, @@ -46,6 +50,7 @@ ClassicBandit, CostControlBandit, ) +from tests.test_utils import is_serializable ######################################################################################################################## @@ -312,6 +317,29 @@ def run_predict(mab): run_predict(mab=mab) +@settings(deadline=500) +@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=2, max_value=100)) +def test_cmab_get_state(mu, sigma, n_features): + actions: dict = { + "a1": BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]), + "a2": create_bayesian_logistic_regression_cold_start(n_betas=n_features), + } + + cmab = CmabBernoulli(actions=actions) + expected_state = json.loads( + json.dumps( + {"actions": actions, "strategy": {}, "predict_with_proba": False, "predict_actions_randomly": False}, + default=dict, + ) + ) + + class_name, cmab_state = cmab.get_state() + assert class_name == "CmabBernoulli" + assert cmab_state == expected_state + + assert is_serializable(cmab_state), "Internal state is not serializable" + + ######################################################################################################################## @@ -451,6 +479,39 @@ def test_cmab_bai_update(n_samples=100, n_features=3): assert not mab.predict_actions_randomly +@settings(deadline=500) +@given( + st.integers(min_value=1), + st.integers(min_value=1), + st.integers(min_value=2, max_value=100), + st.floats(min_value=0, max_value=1), +) +def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01): + actions: dict = { + "a1": BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]), + "a2": create_bayesian_logistic_regression_cold_start(n_betas=n_features), + } + + cmab = CmabBernoulliBAI(actions=actions, exploit_p=exploit_p) + expected_state = json.loads( + json.dumps( + { + "actions": actions, + "strategy": {"exploit_p": exploit_p}, + "predict_with_proba": False, + "predict_actions_randomly": False, + }, + default=dict, + ) + ) + + class_name, cmab_state = cmab.get_state() + assert class_name == "CmabBernoulliBAI" + assert cmab_state == expected_state + + assert is_serializable(cmab_state), "Internal state is not serializable" + + ######################################################################################################################## @@ -597,3 +658,42 @@ def test_cmab_cc_update(n_samples=100, n_features=3): ] ) assert not mab.predict_actions_randomly + + +@settings(deadline=500) +@given( + st.integers(min_value=1), + st.integers(min_value=1), + st.integers(min_value=2, max_value=100), + st.floats(min_value=0), + st.floats(min_value=0), + st.floats(min_value=0, max_value=1), +) +def test_cmab_cc_get_state( + mu, sigma, n_features, cost_1: NonNegativeFloat, cost_2: NonNegativeFloat, subsidy_factor: Float01 +): + actions: dict = { + "a1": BayesianLogisticRegressionCC( + alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()], cost=cost_1 + ), + "a2": create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=cost_2), + } + + cmab = CmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor) + expected_state = json.loads( + json.dumps( + { + "actions": actions, + "strategy": {"subsidy_factor": subsidy_factor}, + "predict_with_proba": True, + "predict_actions_randomly": False, + }, + default=dict, + ) + ) + + class_name, cmab_state = cmab.get_state() + assert class_name == "CmabBernoulliCC" + assert cmab_state == expected_state + + assert is_serializable(cmab_state), "Internal state is not serializable" diff --git a/tests/test_smab.py b/tests/test_smab.py index 015f44f..b136790 100644 --- a/tests/test_smab.py +++ b/tests/test_smab.py @@ -26,9 +26,9 @@ import pytest from hypothesis import given from hypothesis import strategies as st -from pydantic import ValidationError +from pydantic import NonNegativeFloat, ValidationError -from pybandits.base import BinaryReward +from pybandits.base import BinaryReward, Float01 from pybandits.model import Beta, BetaCC, BetaMO, BetaMOCC from pybandits.smab import ( SmabBernoulli, @@ -48,6 +48,7 @@ MultiObjectiveBandit, MultiObjectiveCostControlBandit, ) +from tests.test_utils import is_serializable ######################################################################################################################## @@ -200,6 +201,19 @@ def test_smab_accepts_only_valid_actions(s): SmabBernoulli(actions={s: Beta(), s + "_": Beta()}) +@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=1)) +def test_smab_get_state(a, b, c, d): + actions = {"action1": Beta(n_successes=a, n_failures=b), "action2": Beta(n_successes=c, n_failures=d)} + smab = SmabBernoulli(actions=actions) + + expected_state = {"actions": actions, "strategy": {}} + smab_state = smab.get_state() + + class_name, smab_state = smab.get_state() + assert class_name == "SmabBernoulli" + assert smab_state == expected_state + + ######################################################################################################################## @@ -265,6 +279,25 @@ def test_smabbai_with_betacc(): ) +@given( + st.integers(min_value=1), + st.integers(min_value=1), + st.integers(min_value=1), + st.integers(min_value=1), + st.floats(min_value=0, max_value=1), +) +def test_smab_bai_get_state(a, b, c, d, exploit_p: Float01): + actions = {"action1": Beta(n_successes=a, n_failures=b), "action2": Beta(n_successes=c, n_failures=d)} + smab = SmabBernoulliBAI(actions=actions, exploit_p=exploit_p) + expected_state = {"actions": actions, "strategy": {"exploit_p": exploit_p}} + + class_name, smab_state = smab.get_state() + assert class_name == "SmabBernoulliBAI" + assert smab_state == expected_state + + assert is_serializable(smab_state), "Internal state is not serializable" + + ######################################################################################################################## @@ -327,6 +360,30 @@ def test_smabcc_update(): s.update(actions=["a1", "a1"], rewards=[1, 0]) +@given( + st.integers(min_value=1), + st.integers(min_value=1), + st.integers(min_value=1), + st.integers(min_value=1), + st.floats(min_value=0), + st.floats(min_value=0), + st.floats(min_value=0, max_value=1), +) +def test_smab_cc_get_state(a, b, c, d, cost1: NonNegativeFloat, cost2: NonNegativeFloat, subsidy_factor: Float01): + actions = { + "action1": BetaCC(n_successes=a, n_failures=b, cost=cost1), + "action2": BetaCC(n_successes=c, n_failures=d, cost=cost2), + } + smab = SmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor) + expected_state = {"actions": actions, "strategy": {"subsidy_factor": subsidy_factor}} + + class_name, smab_state = smab.get_state() + assert class_name == "SmabBernoulliCC" + assert smab_state == expected_state + + assert is_serializable(smab_state), "Internal state is not serializable" + + ######################################################################################################################## @@ -414,6 +471,36 @@ def test_smab_mo_update(): mab.update(actions=["a1", "a1"], rewards=[[1, 0, 1], [1, 1, 0]]) +@given(st.lists(st.integers(min_value=1), min_size=6, max_size=6)) +def test_smab_mo_get_state(a_list): + a, b, c, d, e, f = a_list + + actions = { + "a1": BetaMO( + counters=[ + Beta(n_successes=a, n_failures=b), + Beta(n_successes=c, n_failures=d), + Beta(n_successes=e, n_failures=f), + ] + ), + "a2": BetaMO( + counters=[ + Beta(n_successes=d, n_failures=a), + Beta(n_successes=e, n_failures=b), + Beta(n_successes=f, n_failures=c), + ] + ), + } + smab = SmabBernoulliMO(actions=actions) + expected_state = {"actions": actions, "strategy": {}} + + class_name, smab_state = smab.get_state() + assert class_name == "SmabBernoulliMO" + assert smab_state == expected_state + + assert is_serializable(smab_state), "Internal state is not serializable" + + ######################################################################################################################## @@ -498,3 +585,35 @@ def test_smab_mo_cc_predict(): forbidden = ["a1", "a3"] with pytest.raises(ValueError): s.predict(n_samples=n_samples, forbidden_actions=forbidden) + + +@given(st.lists(st.integers(min_value=1), min_size=8, max_size=8)) +def test_smab_mocc_get_state(a_list): + a, b, c, d, e, f, g, h = a_list + + actions = { + "a1": BetaMOCC( + counters=[ + Beta(n_successes=a, n_failures=b), + Beta(n_successes=c, n_failures=d), + Beta(n_successes=e, n_failures=f), + ], + cost=g, + ), + "a2": BetaMOCC( + counters=[ + Beta(n_successes=d, n_failures=a), + Beta(n_successes=e, n_failures=b), + Beta(n_successes=f, n_failures=c), + ], + cost=h, + ), + } + smab = SmabBernoulliMOCC(actions=actions) + expected_state = {"actions": actions, "strategy": {}} + + class_name, smab_state = smab.get_state() + assert class_name == "SmabBernoulliMOCC" + assert smab_state == expected_state + + assert is_serializable(smab_state), "Internal state is not serializable" diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..b015cdb --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,9 @@ +import json + + +def is_serializable(something) -> bool: + try: + json.dumps(something) + return True + except Exception: + return False