
Commit

CR fixes
Shahar-Bar committed Aug 5, 2024
1 parent eaa17e2 commit 6addb32
Showing 2 changed files with 16 additions and 76 deletions.
90 changes: 15 additions & 75 deletions pybandits/smab.py
@@ -144,13 +144,8 @@ class SmabBernoulli(BaseSmabBernoulli):
     actions: Dict[ActionId, Beta]
     strategy: ClassicBandit
 
-    def __init__(
-        self,
-        actions: Dict[ActionId, Beta],
-        epsilon: Optional[Float01] = None,
-        default_action: Optional[ActionId] = None,
-    ):
-        super().__init__(actions=actions, strategy=ClassicBandit(epsilon=epsilon, default_action=default_action))
+    def __init__(self, actions: Dict[ActionId, Beta]):
+        super().__init__(actions=actions, strategy=ClassicBandit())
 
     @classmethod
     def from_state(cls, state: dict) -> "SmabBernoulli":
@@ -179,18 +174,8 @@ class SmabBernoulliBAI(BaseSmabBernoulli):
     actions: Dict[ActionId, Beta]
     strategy: BestActionIdentification
 
-    def __init__(
-        self,
-        actions: Dict[ActionId, Beta],
-        epsilon: Optional[Float01] = None,
-        default_action: Optional[ActionId] = None,
-        exploit_p: Optional[Float01] = None,
-    ):
-        strategy = (
-            BestActionIdentification(epsilon=epsilon, default_action=default_action)
-            if exploit_p is None
-            else BestActionIdentification(epsilon=epsilon, default_action=default_action, exploit_p=exploit_p)
-        )
+    def __init__(self, actions: Dict[ActionId, Beta], exploit_p: Optional[Float01] = None):
+        strategy = BestActionIdentification() if exploit_p is None else BestActionIdentification(exploit_p=exploit_p)
         super().__init__(actions=actions, strategy=strategy)
 
     @classmethod
@@ -231,15 +216,9 @@ class SmabBernoulliCC(BaseSmabBernoulli):
     def __init__(
         self,
         actions: Dict[ActionId, BetaCC],
-        epsilon: Optional[Float01] = None,
-        default_action: Optional[ActionId] = None,
         subsidy_factor: Optional[Float01] = None,
     ):
-        strategy = (
-            CostControlBandit(epsilon=epsilon, default_action=default_action)
-            if subsidy_factor is None
-            else CostControlBandit(epsilon=epsilon, default_action=default_action, subsidy_factor=subsidy_factor)
-        )
+        strategy = CostControlBandit() if subsidy_factor is None else CostControlBandit(subsidy_factor=subsidy_factor)
         super().__init__(actions=actions, strategy=strategy)
 
     @classmethod
@@ -306,10 +285,8 @@ class SmabBernoulliMO(BaseSmabBernoulliMO):
     def __init__(
         self,
         actions: Dict[ActionId, Beta],
-        epsilon: Optional[Float01] = None,
-        default_action: Optional[ActionId] = None,
     ):
-        super().__init__(actions=actions, strategy=MultiObjectiveBandit(epsilon=epsilon, default_action=default_action))
+        super().__init__(actions=actions, strategy=MultiObjectiveBandit())
 
     @classmethod
     def from_state(cls, state: dict) -> "SmabBernoulliMO":
@@ -338,12 +315,8 @@ class SmabBernoulliMOCC(BaseSmabBernoulliMO):
     def __init__(
         self,
         actions: Dict[ActionId, Beta],
-        epsilon: Optional[Float01] = None,
-        default_action: Optional[ActionId] = None,
     ):
-        super().__init__(
-            actions=actions, strategy=MultiObjectiveCostControlBandit(epsilon=epsilon, default_action=default_action)
-        )
+        super().__init__(actions=actions, strategy=MultiObjectiveCostControlBandit())
 
     @classmethod
     def from_state(cls, state: dict) -> "SmabBernoulliMOCC":
@@ -352,7 +325,7 @@ def from_state(cls, state: dict) -> "SmabBernoulliMOCC":
 
 @validate_call
 def create_smab_bernoulli_cold_start(
-    action_ids: Set[ActionId], epsilon: Optional[Float01] = None, default_action: Optional[ActionId] = None
+    action_ids: Set[ActionId],
 ) -> SmabBernoulli:
     """
     Utility function to create a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling, with default
@@ -362,10 +335,6 @@ def create_smab_bernoulli_cold_start(
     ----------
     action_ids: Set[ActionId]
         The list of possible actions.
-    epsilon: Optional[Float01]
-        epsilon for epsilon-greedy approach. If None, epsilon-greedy is not used.
-    default_action: Optional[ActionId]
-        Default action to select if the epsilon-greedy approach is used. None for random selection.
 
     Returns
     -------
@@ -375,15 +344,12 @@ def create_smab_bernoulli_cold_start(
     actions = {}
     for a in set(action_ids):
         actions[a] = Beta()
-    return SmabBernoulli(actions=actions, epsilon=epsilon, default_action=default_action)
+    return SmabBernoulli(actions=actions)
 
 
 @validate_call
 def create_smab_bernoulli_bai_cold_start(
-    action_ids: Set[ActionId],
-    exploit_p: Optional[Float01] = None,
-    epsilon: Optional[Float01] = None,
-    default_action: Optional[ActionId] = None,
+    action_ids: Set[ActionId], exploit_p: Optional[Float01] = None
 ) -> SmabBernoulliBAI:
     """
     Utility function to create a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling, and Best Action
@@ -402,10 +368,6 @@ def create_smab_bernoulli_bai_cold_start(
         (it behaves as a Greedy strategy).
         If exploit_p is 0, the bandits always select the action with 2nd highest probability of getting a positive
         reward.
-    epsilon: Optional[Float01]
-        epsilon for epsilon-greedy approach. If None, epsilon-greedy is not used.
-    default_action: Optional[ActionId]
-        Default action to select if the epsilon-greedy approach is used. None for random selection.
 
     Returns
     -------
@@ -415,15 +377,13 @@ def create_smab_bernoulli_bai_cold_start(
     actions = {}
     for a in set(action_ids):
         actions[a] = Beta()
-    return SmabBernoulliBAI(actions=actions, epsilon=epsilon, default_action=default_action, exploit_p=exploit_p)
+    return SmabBernoulliBAI(actions=actions, exploit_p=exploit_p)
 
 
 @validate_call
 def create_smab_bernoulli_cc_cold_start(
     action_ids_cost: Dict[ActionId, NonNegativeFloat],
     subsidy_factor: Optional[Float01] = None,
-    epsilon: Optional[Float01] = None,
-    default_action: Optional[ActionId] = None,
 ) -> SmabBernoulliCC:
     """
     Utility function to create a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling, and Cost Control
@@ -449,10 +409,6 @@ def create_smab_bernoulli_cc_cold_start(
         If subsidy_factor is 1, the bandits always selects the action with the minimum cost.
         If subsidy_factor is 0, the bandits always selects the action with highest probability of getting a positive
         reward (it behaves as a classic Bernoulli bandit).
-    epsilon: Optional[Float01]
-        epsilon for epsilon-greedy approach. If None, epsilon-greedy is not used.
-    default_action: Optional[ActionId]
-        Default action to select if the epsilon-greedy approach is used. None for random selection.
 
     Returns
     -------
@@ -462,17 +418,13 @@ def create_smab_bernoulli_cc_cold_start(
     actions = {}
     for a, cost in action_ids_cost.items():
         actions[a] = BetaCC(cost=cost)
-    return SmabBernoulliCC(
-        actions=actions, epsilon=epsilon, default_action=default_action, subsidy_factor=subsidy_factor
-    )
+    return SmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor)
 
 
 @validate_call
 def create_smab_bernoulli_mo_cold_start(
     action_ids: Set[ActionId],
     n_objectives: PositiveInt,
-    epsilon: Optional[Float01] = None,
-    default_action: Optional[ActionId] = None,
 ) -> SmabBernoulliMO:
     """
     Utility function to create a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling, and Multi-Objectives
@@ -492,10 +444,6 @@ def create_smab_bernoulli_mo_cold_start(
         The list of possible actions.
     n_objectives: PositiveInt
         The number of objectives to optimize. The bandit assumes the same number of objectives for all actions.
-    epsilon: Optional[Float01]
-        epsilon for epsilon-greedy approach. If None, epsilon-greedy is not used.
-    default_action: Optional[ActionId]
-        Default action to select if the epsilon-greedy approach is used. None for random selection.
 
     Returns
     -------
@@ -505,15 +453,12 @@ def create_smab_bernoulli_mo_cold_start(
     actions = {}
     for a in set(action_ids):
         actions[a] = BetaMO(counters=n_objectives * [Beta()])
-    return SmabBernoulliMO(actions=actions, epsilon=epsilon, default_action=default_action)
+    return SmabBernoulliMO(actions=actions)
 
 
 @validate_call
 def create_smab_bernoulli_mo_cc_cold_start(
-    action_ids_cost: Dict[ActionId, NonNegativeFloat],
-    n_objectives: PositiveInt,
-    epsilon: Optional[Float01] = None,
-    default_action: Optional[ActionId] = None,
+    action_ids_cost: Dict[ActionId, NonNegativeFloat], n_objectives: PositiveInt
 ) -> SmabBernoulliMOCC:
     """
     Utility function to create a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling implementation for
@@ -528,11 +473,6 @@ def create_smab_bernoulli_mo_cc_cold_start(
         The list of possible actions, and their cost.
     n_objectives: PositiveInt
         The number of objectives to optimize. The bandit assumes the same number of objectives for all actions.
-    epsilon: Optional[Float01]
-        epsilon for epsilon-greedy approach. If None, epsilon-greedy is not used.
-    default_action: Optional[ActionId]
-        Default action to select if the epsilon-greedy approach is used. None for random selection.
-
 
     Returns
     -------
@@ -542,4 +482,4 @@ def create_smab_bernoulli_mo_cc_cold_start(
     actions = {}
     for a, cost in action_ids_cost.items():
         actions[a] = BetaMOCC(counters=n_objectives * [Beta()], cost=cost)
-    return SmabBernoulliMOCC(actions=actions, epsilon=epsilon, default_action=default_action)
+    return SmabBernoulliMOCC(actions=actions)
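
As a quick illustration (not part of this commit), here is a minimal sketch of how the simplified sMAB constructors and cold-start helpers read after this change. The action ids and costs below are made up for the example, and the import paths are assumed from the file layout above (pybandits.smab for the bandits, pybandits.model for the Beta/BetaCC priors).

from pybandits.model import Beta, BetaCC
from pybandits.smab import (
    SmabBernoulli,
    SmabBernoulliCC,
    create_smab_bernoulli_bai_cold_start,
)

# Plain Bernoulli bandit: only the per-action Beta priors are passed now.
smab = SmabBernoulli(actions={"a1": Beta(), "a2": Beta()})

# Best-action-identification variant via its cold-start helper;
# exploit_p is still accepted, epsilon/default_action no longer are.
bai = create_smab_bernoulli_bai_cold_start(action_ids={"a1", "a2"}, exploit_p=0.5)

# Cost-control variant: per-action costs travel with BetaCC, subsidy_factor stays optional.
cc = SmabBernoulliCC(
    actions={"a1": BetaCC(cost=10.0), "a2": BetaCC(cost=5.0)},
    subsidy_factor=0.3,
)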
2 changes: 1 addition & 1 deletion tests/test_strategy.py
@@ -276,7 +276,7 @@ def test_can_init_multiobjective():
 @given(
     st.dictionaries(
         st.text(min_size=1, alphabet=st.characters(blacklist_characters=("\x00"))),
-        st.lists(st.floats(min_value=0, max_value=1), min_size=3, max_size=3),
+        st.lists(st.floats(min_value=0, max_value=2), min_size=3, max_size=3),
         min_size=3,
     )
 )
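
For context on the test tweak (again, not part of the diff), here is a standalone sketch of the kind of input the amended hypothesis strategy draws for test_can_init_multiobjective: a dict with at least three string keys, each mapped to exactly three floats now sampled from [0, 2] rather than [0, 1]. The printed values are illustrative only.

from hypothesis import strategies as st

p_strategy = st.dictionaries(
    st.text(min_size=1, alphabet=st.characters(blacklist_characters=("\x00"))),
    st.lists(st.floats(min_value=0, max_value=2), min_size=3, max_size=3),
    min_size=3,
)

print(p_strategy.example())
# e.g. {'abc': [0.0, 1.5, 2.0], 'd': [0.25, 0.25, 1.75], 'e0': [2.0, 0.0, 0.5]}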
