Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
Make number of bandit arms configurable.
Browse files Browse the repository at this point in the history
Default kwargs keep bsuite the same.

PiperOrigin-RevId: 355511900
Change-Id: I0d674e9650ec24f5b9f9932219e66451df3e0fca
  • Loading branch information
iosband authored and copybara-github committed Feb 4, 2021
1 parent b4c9fed commit 5116216
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions bsuite/environments/bandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@
class SimpleBandit(base.Environment):
"""SimpleBandit environment."""

def __init__(self, mapping_seed: int = None):
def __init__(self, mapping_seed: int = None, num_actions: int = 11):
"""Builds a simple bandit environment.
Args:
mapping_seed: Optional integer. Seed for action mapping.
num_actions: number of actions available, defaults to 11.
"""
super(SimpleBandit, self).__init__()
self._rng = np.random.RandomState(mapping_seed)

self._n_actions = 11
self._num_actions = num_actions
action_mask = self._rng.choice(
range(self._n_actions), size=self._n_actions, replace=False)
self._rewards = np.linspace(0, 1, self._n_actions)[action_mask]
range(self._num_actions), size=self._num_actions, replace=False)
self._rewards = np.linspace(0, 1, self._num_actions)[action_mask]

self._total_regret = 0.
self._optimal_return = 1.
Expand All @@ -66,7 +66,7 @@ def observation_spec(self):
return specs.Array(shape=(1, 1), dtype=np.float32)

def action_spec(self):
return specs.DiscreteArray(self._n_actions, name='action')
return specs.DiscreteArray(self._num_actions, name='action')

def bsuite_info(self):
return dict(total_regret=self._total_regret)
4 changes: 2 additions & 2 deletions bsuite/experiments/bandit_noise/bandit_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from bsuite.utils import wrappers


def load(noise_scale, seed, mapping_seed):
def load(noise_scale, seed, mapping_seed, num_actions=11):
"""Load a bandit_noise experiment with the prescribed settings."""
env = wrappers.RewardNoise(
env=bandit.SimpleBandit(mapping_seed=mapping_seed),
env=bandit.SimpleBandit(mapping_seed, num_actions=num_actions),
noise_scale=noise_scale,
seed=seed)
env.bsuite_num_episodes = sweep.NUM_EPISODES
Expand Down

0 comments on commit 5116216

Please sign in to comment.