Skip to content

Commit e3b7fcf

Browse files
committed
Renew simulation framework for multi-armed bandits
### Changes

* Introduced `Simulator` abstract base class for simulating multi-armed bandit environments.
* Added `SmabSimulator` class for simulating stochastic multi-armed bandits (sMAB).
* Added `CmabSimulator` class for simulating contextual multi-armed bandits (cMAB).
* Added a utility function for identifying the running code environment under utils.py.
* Updated pyproject.toml to include the bokeh dependency for interactive visualization.
* Added unit tests for the various simulators to ensure proper functionality.
* Removed simulation_plots.py.
1 parent 9c15f78 commit e3b7fcf

13 files changed

+1118
-796
lines changed

pybandits/cmab_simulator.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# MIT License
2+
#
3+
# Copyright (c) 2022 Playtika Ltd.
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
#
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
#
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
23+
import random
24+
from typing import Dict, List, Optional, Tuple
25+
26+
import numpy as np
27+
import pandas as pd
28+
from pydantic import Field, model_validator
29+
30+
from pybandits.base import ActionId, BinaryReward
31+
from pybandits.cmab import BaseCmabBernoulli
32+
from pybandits.simulator import Simulator
33+
34+
35+
class CmabSimulator(Simulator):
    """
    Simulate environment for contextual multi-armed bandit models.

    This class simulates information required by the contextual bandit. Generated data are processed by the bandit with
    batches of size n>=1. For each batch of samples, actions are recommended by the bandit and corresponding simulated
    rewards collected. Bandit policy parameters are then updated based on returned rewards from recommended actions.

    Parameters
    ----------
    mab : BaseCmabBernoulli
        Contextual multi-armed bandit model. Accepted under the alias "cmab" as well (see `validation_alias`).
    context : np.ndarray of shape (n_samples, n_feature)
        Context matrix of samples features.
    group : Optional[List] with length=n_samples
        Group to which each sample belongs. Samples which belong to the same group have features that come from the
        same distribution and they have the same probability to receive a positive/negative feedback from each action.
        If not supplied, all samples are assigned to the same group.
    """

    # "cmab" is accepted as an input alias for backward compatibility with callers
    # that pass the bandit as `cmab=...` instead of `mab=...`.
    mab: BaseCmabBernoulli = Field(validation_alias="cmab")
    context: np.ndarray
    group: Optional[List] = None
    # Columns shared by every per-batch result row; extra columns are appended in _finalize_step.
    _base_columns: List[str] = ["batch", "action", "reward", "group"]

    @model_validator(mode="after")
    def replace_nulls_and_validate_sizes(self):
        """
        Fill optional fields with defaults and validate cross-field sizes.

        Runs after standard field validation: checks that `context` covers exactly
        `batch_size * n_updates` samples (both presumably defined on the `Simulator`
        base — not visible here), defaults `group` to a single group of zeros, and
        defaults `probs_reward` to a uniform 0.5 DataFrame with one row per group
        and one column per action.

        Raises
        ------
        ValueError
            If the context, group, or probs_reward sizes are inconsistent.
        """
        if len(self.context) != self.batch_size * self.n_updates:
            raise ValueError("Context length must equal to batch_size x n_updates.")
        if self.group is None:
            # All samples fall into one default group (id 0).
            self.group = len(self.context) * [0]
        else:
            if len(self.context) != len(self.group):
                raise ValueError("Mismatch between context length and group length")
        mab_action_ids = list(self.mab.actions.keys())
        # Unique group ids become the row index of probs_reward.
        # NOTE(review): set() ordering is arbitrary, so the row order of the default
        # probs_reward DataFrame is not deterministic across runs — values are uniform
        # (0.5) so lookups by label are unaffected.
        index = list(set(self.group))
        if self.probs_reward is None:
            self.probs_reward = pd.DataFrame(0.5, index=index, columns=mab_action_ids)
        else:
            if self.probs_reward.shape[0] != len(index):
                raise ValueError("number of probs_reward rows must match the number of groups.")
        return self

    def _initialize_results(self) -> None:
        """
        Initialize the results DataFrame. The results DataFrame is used to store the raw simulation results.
        """
        self._results = pd.DataFrame(
            columns=["action", "reward", "group", "selected_prob_reward", "max_prob_reward"],
        )

    def _draw_rewards(self, actions: List[ActionId], metadata: Dict[str, List]) -> List[BinaryReward]:
        """
        Draw rewards for the selected actions based on metadata according to probs_reward.

        Each reward is a Bernoulli draw: 1 with probability probs_reward[group, action].

        Parameters
        ----------
        actions : List[ActionId]
            The actions selected by the multi-armed bandit model.
        metadata : Dict[str, List]
            The metadata for the selected actions; should contain the batch groups association
            under the "group" key.

        Returns
        -------
        reward : List[BinaryReward]
            A list of binary rewards.
        """
        rewards = [int(random.random() < self.probs_reward.loc[g, a]) for g, a in zip(metadata["group"], actions)]
        return rewards

    def _get_batch_step_kwargs_and_metadata(
        self, batch_index: int
    ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, List]]:
        """
        Extract context required for the cMAB's update and predict functionality,
        as well as metadata for sample group.

        Parameters
        ----------
        batch_index : int
            The index of the batch.

        Returns
        -------
        predict_kwargs : Dict[str, np.ndarray]
            Dictionary containing the context for the batch.
        update_kwargs : Dict[str, np.ndarray]
            Dictionary containing the context for the batch.
        metadata : Dict[str, List]
            Dictionary containing the group information for the batch.
        """
        idx_batch_min = batch_index * self.batch_size
        idx_batch_max = (batch_index + 1) * self.batch_size
        # predict and update both consume the same context slice, so the same
        # dict object is returned for both kwargs positions.
        predict_and_update_kwargs = {"context": self.context[idx_batch_min:idx_batch_max]}
        metadata = {"group": self.group[idx_batch_min:idx_batch_max]}
        return predict_and_update_kwargs, predict_and_update_kwargs, metadata

    def _finalize_step(self, batch_results: pd.DataFrame) -> pd.DataFrame:
        """
        Finalize the step by adding additional information to the batch results.

        Parameters
        ----------
        batch_results : pd.DataFrame
            raw batch results

        Returns
        -------
        batch_results : pd.DataFrame
            batch results with added reward probability for selected and most rewarding action
        """
        group_id = batch_results.loc[:, "group"]
        action_id = batch_results.loc[:, "action"]
        # Probability of reward for the action actually taken for each sample.
        selected_prob_reward = [self.probs_reward.loc[g, a] for g, a in zip(group_id, action_id)]
        batch_results.loc[:, "selected_prob_reward"] = selected_prob_reward
        # Best achievable reward probability per sample (row-wise max over actions),
        # used later to compute regret.
        max_prob_reward = self.probs_reward.loc[group_id].max(axis=1)
        # .tolist() drops the probs_reward index so assignment aligns positionally
        # with batch_results rows.
        batch_results.loc[:, "max_prob_reward"] = max_prob_reward.tolist()
        return batch_results

    def _finalize_results(self) -> None:
        """
        Finalize the simulation process. Used to add regret and cumulative regret.

        Regret per step is the gap between the best achievable and the selected
        reward probability; cum_regret is its running sum over all steps.

        Returns
        -------
        None
        """
        self._results["regret"] = self._results["max_prob_reward"] - self._results["selected_prob_reward"]
        self._results["cum_regret"] = self._results["regret"].cumsum()

0 commit comments

Comments
 (0)