|
21 | 21 | # SOFTWARE.
|
22 | 22 |
|
23 | 23 | from copy import deepcopy
|
24 |
| -from typing import List |
| 24 | +from typing import List, Optional, Tuple |
25 | 25 |
|
26 | 26 | import pytest
|
27 | 27 | from hypothesis import given
|
28 | 28 | from hypothesis import strategies as st
|
29 | 29 | from pydantic import ValidationError
|
30 | 30 |
|
31 |
| -from pybandits.base import BinaryReward |
| 31 | +from pybandits.base import BinaryReward, Float01 |
32 | 32 | from pybandits.model import Beta, BetaCC, BetaMO, BetaMOCC
|
33 | 33 | from pybandits.smab import (
|
34 | 34 | SmabBernoulli,
|
|
48 | 48 | MultiObjectiveBandit,
|
49 | 49 | MultiObjectiveCostControlBandit,
|
50 | 50 | )
|
| 51 | +from tests.test_utils import is_serializable |
51 | 52 |
|
52 | 53 | ########################################################################################################################
|
53 | 54 |
|
@@ -200,6 +201,26 @@ def test_smab_accepts_only_valid_actions(s):
|
200 | 201 | SmabBernoulli(actions={s: Beta(), s + "_": Beta()})
|
201 | 202 |
|
202 | 203 |
|
| 204 | +@pytest.mark.parametrize("action_dict, action_ids", [ |
| 205 | + ({"a0": Beta(), "a1": Beta(), "a2": Beta()}, None), |
| 206 | + ({"a0": Beta(), "a1": Beta(n_successes=5, n_failures=5), "a2": Beta(n_successes=10, n_failures=1), |
| 207 | + "a3": Beta(n_successes=10, n_failures=5), "a4": Beta(n_successes=100, n_failures=4), "a5": Beta()}, None), |
| 208 | + (None, {"a0", "a1", "a2"})]) |
| 209 | +def test_smab_get_state(action_dict: Optional[dict], action_ids: Optional[set]): |
| 210 | + if action_dict: |
| 211 | + smab = SmabBernoulli(actions=action_dict) |
| 212 | + expected_state = {"actions": action_dict, "strategy": {}} |
| 213 | + else: |
| 214 | + smab = create_smab_bernoulli_cold_start(action_ids=action_ids) |
| 215 | + expected_state = {"actions": {action_id: Beta() for action_id in action_ids}, "strategy": {}} |
| 216 | + |
| 217 | + smab_state = smab.get_state() |
| 218 | + assert smab_state[0] == "SmabBernoulli" |
| 219 | + assert smab_state[1] == expected_state |
| 220 | + |
| 221 | + assert is_serializable(smab_state[1]), "Internal state is not serializable" |
| 222 | + |
| 223 | + |
203 | 224 | ########################################################################################################################
|
204 | 225 |
|
205 | 226 |
|
@@ -265,6 +286,27 @@ def test_smabbai_with_betacc():
|
265 | 286 | )
|
266 | 287 |
|
267 | 288 |
|
| 289 | +@pytest.mark.parametrize("action_dict, action_ids, exploit_p", [ |
| 290 | + ({"a0": Beta(), "a1": Beta(), "a2": Beta()}, None, 0.3), |
| 291 | + ({"a0": Beta(), "a1": Beta(n_successes=5, n_failures=5), "a2": Beta(n_successes=10, n_failures=1), |
| 292 | + "a3": Beta(n_successes=10, n_failures=5), "a4": Beta(n_successes=100, n_failures=4), "a5": Beta()}, None, 0.8), |
| 293 | + (None, {"a0", "a1", "a2"}, 0.5)]) |
| 294 | +def test_smab_bai_get_state(action_dict: Optional[dict], action_ids: Optional[set], exploit_p: Float01): |
| 295 | + if action_dict: |
| 296 | + smab = SmabBernoulliBAI(actions=action_dict, exploit_p=exploit_p) |
| 297 | + expected_state = {"actions": action_dict, "strategy": {"exploit_p": exploit_p}} |
| 298 | + else: |
| 299 | + smab = create_smab_bernoulli_bai_cold_start(action_ids=action_ids, exploit_p=exploit_p) |
| 300 | + expected_state = {"actions": {action_id: Beta() for action_id in action_ids}, |
| 301 | + "strategy": {"exploit_p": exploit_p}} |
| 302 | + |
| 303 | + smab_state = smab.get_state() |
| 304 | + assert smab_state[0] == "SmabBernoulliBAI" |
| 305 | + assert smab_state[1] == expected_state |
| 306 | + |
| 307 | + assert is_serializable(smab_state[1]), "Internal state is not serializable" |
| 308 | + |
| 309 | + |
268 | 310 | ########################################################################################################################
|
269 | 311 |
|
270 | 312 |
|
@@ -327,6 +369,28 @@ def test_smabcc_update():
|
327 | 369 | s.update(actions=["a1", "a1"], rewards=[1, 0])
|
328 | 370 |
|
329 | 371 |
|
| 372 | +@pytest.mark.parametrize("action_dict, action_ids_cost, subsidy_factor", [ |
| 373 | + ({"a0": BetaCC(cost=0.1), "a1": BetaCC(cost=0.2), "a2": BetaCC(cost=0.3)}, None, 0.3), |
| 374 | + ({"a0": BetaCC(cost=0.1), "a1": BetaCC(n_successes=5, n_failures=5, cost=0.2), "a2": |
| 375 | + BetaCC(n_successes=10, n_failures=1, cost=0.3), "a3": BetaCC(n_successes=10, n_failures=5, cost=0.4), |
| 376 | + "a4": BetaCC(n_successes=100, n_failures=4, cost=0.5), "a5": BetaCC(cost=0.6)}, None, 0.8), |
| 377 | + (None, {"a0": 0.1, "a1": 0.2, "a2": 0.3}, 0.5)]) |
| 378 | +def test_smab_cc_get_state(action_dict: Optional[dict], action_ids_cost: Optional[dict], subsidy_factor: Float01): |
| 379 | + if action_dict: |
| 380 | + smab = SmabBernoulliCC(actions=action_dict, subsidy_factor=subsidy_factor) |
| 381 | + expected_state = {"actions": action_dict, "strategy": {"subsidy_factor": subsidy_factor}} |
| 382 | + else: |
| 383 | + smab = create_smab_bernoulli_cc_cold_start(action_ids_cost=action_ids_cost, subsidy_factor=subsidy_factor) |
| 384 | + expected_state = {"actions": {k: BetaCC(cost=v) for k, v in action_ids_cost.items()}, |
| 385 | + "strategy": {"subsidy_factor": subsidy_factor}} |
| 386 | + |
| 387 | + smab_state = smab.get_state() |
| 388 | + assert smab_state[0] == "SmabBernoulliCC" |
| 389 | + assert smab_state[1] == expected_state |
| 390 | + |
| 391 | + assert is_serializable(smab_state[1]), "Internal state is not serializable" |
| 392 | + |
| 393 | + |
330 | 394 | ########################################################################################################################
|
331 | 395 |
|
332 | 396 |
|
@@ -414,6 +478,31 @@ def test_smab_mo_update():
|
414 | 478 | mab.update(actions=["a1", "a1"], rewards=[[1, 0, 1], [1, 1, 0]])
|
415 | 479 |
|
416 | 480 |
|
| 481 | +@pytest.mark.parametrize("action_dict, action_ids_with_obj", [ |
| 482 | + ({"a0": BetaMO(counters=[Beta(), Beta()]), "a1": BetaMO(counters=[Beta(), Beta()]), |
| 483 | + "a2": BetaMO(counters=[Beta(), Beta()])}, None), |
| 484 | + ({"a0": BetaMO(counters=[Beta(), Beta()]), |
| 485 | + "a1": BetaMO(counters=[Beta(n_successes=2, n_failures=3), Beta(n_successes=4, n_failures=5)]), |
| 486 | + "a2": BetaMO(counters=[Beta(n_successes=6, n_failures=7), Beta(n_successes=8, n_failures=9)])}, None), |
| 487 | + (None, ({"a0", "a1", "a2"}, 2)) |
| 488 | +]) |
| 489 | +def test_smab_mo_get_state(action_dict: Optional[dict], action_ids_with_obj: Optional[Tuple[set, int]]): |
| 490 | + if action_dict: |
| 491 | + smab = SmabBernoulliMO(actions=action_dict) |
| 492 | + expected_state = {"actions": action_dict, "strategy": {}} |
| 493 | + else: |
| 494 | + smab = create_smab_bernoulli_mo_cold_start(action_ids=action_ids_with_obj[0], |
| 495 | + n_objectives=action_ids_with_obj[1]) |
| 496 | + expected_state = {"actions": {action_id: BetaMO(counters=[Beta()] * action_ids_with_obj[1]) for action_id in |
| 497 | + action_ids_with_obj[0]}, "strategy": {}} |
| 498 | + |
| 499 | + smab_state = smab.get_state() |
| 500 | + assert smab_state[0] == "SmabBernoulliMO" |
| 501 | + assert smab_state[1] == expected_state |
| 502 | + |
| 503 | + assert is_serializable(smab_state[1]), "Internal state is not serializable" |
| 504 | + |
| 505 | + |
417 | 506 | ########################################################################################################################
|
418 | 507 |
|
419 | 508 |
|
@@ -498,3 +587,30 @@ def test_smab_mo_cc_predict():
|
498 | 587 | forbidden = ["a1", "a3"]
|
499 | 588 | with pytest.raises(ValueError):
|
500 | 589 | s.predict(n_samples=n_samples, forbidden_actions=forbidden)
|
| 590 | + |
| 591 | + |
| 592 | +@pytest.mark.parametrize("action_dict, action_ids_cost_with_obj", [ |
| 593 | + ({"a0": BetaMOCC(counters=[Beta(), Beta()], cost=0.1), "a1": BetaMOCC(counters=[Beta(), Beta()], cost=0.2), |
| 594 | + "a2": BetaMOCC(counters=[Beta(), Beta()], cost=0.3)}, None), |
| 595 | + ({"a0": BetaMOCC(counters=[Beta(), Beta()], cost=0.1), |
| 596 | + "a1": BetaMOCC(counters=[Beta(n_successes=2, n_failures=3), Beta(n_successes=4, n_failures=5)], cost=0.2), |
| 597 | + "a2": BetaMOCC(counters=[Beta(n_successes=6, n_failures=7), Beta(n_successes=8, n_failures=9)], cost=0.3)}, None), |
| 598 | + (None, ({"a0": 0.1, "a1": 0.2, "a2": 0.3}, 2)) |
| 599 | +]) |
| 600 | +def test_smab_mo_cc_get_state(action_dict: Optional[dict], |
| 601 | + action_ids_cost_with_obj: Optional[Tuple[dict, int]]): |
| 602 | + if action_dict: |
| 603 | + smab = SmabBernoulliMOCC(actions=action_dict) |
| 604 | + expected_state = {"actions": action_dict, "strategy": {}} |
| 605 | + else: |
| 606 | + smab = create_smab_bernoulli_mo_cc_cold_start(action_ids_cost=action_ids_cost_with_obj[0], |
| 607 | + n_objectives=action_ids_cost_with_obj[1]) |
| 608 | + expected_state = {"actions": {action_id: BetaMOCC(counters=[Beta()] * action_ids_cost_with_obj[1], |
| 609 | + cost=action_ids_cost_with_obj[0][action_id]) for action_id in |
| 610 | + action_ids_cost_with_obj[0]}, "strategy": {}} |
| 611 | + |
| 612 | + smab_state = smab.get_state() |
| 613 | + assert smab_state[0] == "SmabBernoulliMOCC" |
| 614 | + assert smab_state[1] == expected_state |
| 615 | + |
| 616 | + assert is_serializable(smab_state[1]), "Internal state is not serializable" |
0 commit comments