Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
Extract environments to their own package.
Browse files Browse the repository at this point in the history
This makes the distinction between an _environment_ and an _experiment_ clearer.

If users want to import individual environments for their own debugging/development:

✗ from bsuite.experiments.catch import catch
✓ from bsuite.environments import catch

This change also introduces some more formal typing of bsuite environments:
- Add a base class which includes the bsuite_* attributes/methods.

PiperOrigin-RevId: 307575828
Change-Id: Iba2303d64a397ccef8a3f3f154e414bf343f905b
  • Loading branch information
aslanides authored and copybara-github committed Apr 21, 2020
1 parent 6c12227 commit f9b74bf
Show file tree
Hide file tree
Showing 56 changed files with 1,251 additions and 948 deletions.
3 changes: 2 additions & 1 deletion bsuite/baselines/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def run(agent: base.Agent,
"""

if verbose:
environment = terminal_logging.wrap_environment(environment, log_every=True)
environment = terminal_logging.wrap_environment(
environment, log_every=True) # pytype: disable=wrong-arg-types

for _ in range(num_episodes):
# Run an episode.
Expand Down
2 changes: 1 addition & 1 deletion bsuite/baselines/third_party/dopamine_dqn/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def create_environment() -> gym.Env:
"""Factory method for environment initialization in Dopamine."""
env = wrappers.ImageObservation(raw_env, OBSERVATION_SHAPE)
if FLAGS.verbose:
env = terminal_logging.wrap_environment(env, log_every=True)
env = terminal_logging.wrap_environment(env, log_every=True) # pytype: disable=wrong-arg-types
env = gym_wrapper.GymFromDMEnv(env)
env.game_over = False # Dopamine looks for this
return env
Expand Down
2 changes: 1 addition & 1 deletion bsuite/baselines/third_party/openai_dqn/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def run(bsuite_id: str) -> str:
overwrite=FLAGS.overwrite,
)
if FLAGS.verbose:
raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)
raw_env = terminal_logging.wrap_environment(raw_env, log_every=True) # pytype: disable=wrong-arg-types
env = gym_wrapper.GymFromDMEnv(raw_env)

num_episodes = FLAGS.num_episodes or getattr(raw_env, 'bsuite_num_episodes')
Expand Down
2 changes: 1 addition & 1 deletion bsuite/baselines/third_party/openai_ppo/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def _load_env():
overwrite=FLAGS.overwrite,
)
if FLAGS.verbose:
raw_env = terminal_logging.wrap_environment(raw_env, log_every=True)
raw_env = terminal_logging.wrap_environment(raw_env, log_every=True) # pytype: disable=wrong-arg-types
return gym_wrapper.GymFromDMEnv(raw_env)
env = dummy_vec_env.DummyVecEnv([_load_env])

Expand Down
21 changes: 11 additions & 10 deletions bsuite/bsuite.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Mapping, Tuple

from bsuite import sweep
from bsuite.environments import base
from bsuite.experiments.bandit import bandit
from bsuite.experiments.bandit_noise import bandit_noise
from bsuite.experiments.bandit_scale import bandit_scale
Expand Down Expand Up @@ -54,29 +55,29 @@
# Each constructor or load function accepts keyword arguments as defined in
# each experiment's sweep.py file.
EXPERIMENT_NAME_TO_ENVIRONMENT = dict(
bandit=bandit.SimpleBandit,
bandit=bandit.load,
bandit_noise=bandit_noise.load,
bandit_scale=bandit_scale.load,
cartpole=cartpole.Cartpole,
cartpole=cartpole.load,
cartpole_noise=cartpole_noise.load,
cartpole_scale=cartpole_scale.load,
cartpole_swingup=cartpole_swingup.CartpoleSwingup,
catch=catch.Catch,
catch=catch.load,
catch_noise=catch_noise.load,
catch_scale=catch_scale.load,
deep_sea=deep_sea.DeepSea,
deep_sea=deep_sea.load,
deep_sea_stochastic=deep_sea_stochastic.load,
discounting_chain=discounting_chain.DiscountingChain,
discounting_chain=discounting_chain.load,
memory_len=memory_len.load,
memory_size=memory_size.load,
mnist=mnist.MNISTBandit,
mnist=mnist.load,
mnist_noise=mnist_noise.load,
mnist_scale=mnist_scale.load,
mountain_car=mountain_car.MountainCar,
mountain_car=mountain_car.load,
mountain_car_noise=mountain_car_noise.load,
mountain_car_scale=mountain_car_scale.load,
umbrella_distract=umbrella_distract.load,
umbrella_length=umbrella_length.UmbrellaChain,
umbrella_length=umbrella_length.load,
)


Expand All @@ -92,12 +93,12 @@ def unpack_bsuite_id(bsuite_id: str) -> Tuple[str, int]:
def load(
experiment_name: str,
kwargs: Mapping[str, Any],
) -> dm_env.Environment:
) -> base.Environment:
"""Returns a bsuite environment given an experiment name and settings."""
return EXPERIMENT_NAME_TO_ENVIRONMENT[experiment_name](**kwargs)


def load_from_id(bsuite_id: str) -> dm_env.Environment:
def load_from_id(bsuite_id: str) -> base.Environment:
"""Returns a bsuite environment given a bsuite_id."""
kwargs = sweep.SETTINGS[bsuite_id]
experiment_name, _ = unpack_bsuite_id(bsuite_id)
Expand Down
15 changes: 15 additions & 0 deletions bsuite/environments/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Environments

This folder contains the raw *environments* used in `bsuite` experiments; we
expose them here for debugging and development purposes.

Recall that in the context of bsuite, an *experiment* consists of three parts:

1. Environments: a fixed set of environments determined by some parameters.
2. Interaction: a fixed regime of agent/environment interaction (e.g. 100
   episodes).
3. Analysis: a fixed procedure that maps agent behaviour to results and plots.

Note: If you load the environment from this folder you will miss out on the
interaction+analysis as specified by bsuite. In general, you should use the
`bsuite_id` to load the environment via `bsuite.load_from_id(bsuite_id)` rather
than the raw environment.
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for bsuite.experiments.memory_len."""
"""bsuite environments package."""

from absl.testing import absltest
from bsuite.experiments.memory_size import memory_size
from dm_env import test_utils
import numpy as np


class InterfaceTest(test_utils.EnvironmentTestMixin, absltest.TestCase):

def make_object_under_test(self):
return memory_size.load(10)

def make_action_sequence(self):
valid_actions = [0, 1]
rng = np.random.RandomState(42)

for _ in range(100):
yield rng.choice(valid_actions)

if __name__ == '__main__':
absltest.main()
from bsuite.environments.base import Environment
72 changes: 72 additions & 0 deletions bsuite/environments/bandit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Simple diagnostic bandit environment.
Observation is a single constant pixel (value 1) - this is an independent arm
bandit problem!
Rewards are [0, 0.1, .. 1] assigned randomly to 11 arms and deterministic
"""

from bsuite.environments import base
from bsuite.experiments.bandit import sweep

import dm_env
from dm_env import specs
import numpy as np


class SimpleBandit(base.Environment):
  """Eleven-armed bandit with fixed, deterministic per-arm rewards.

  The rewards are 11 evenly spaced values in [0, 1], assigned to the arms in a
  random order at construction time. Each call to `_step` returns a
  termination timestep and adds the shortfall from the best possible reward
  (1.0) to a running regret total.
  """

  def __init__(self, seed=None):
    """Creates the bandit and randomly assigns rewards to arms.

    Args:
      seed: Optional integer. Seed for numpy's random number generator (RNG).
    """
    super().__init__()
    self._rng = np.random.RandomState(seed)

    self._n_actions = 11
    # Draw a random ordering of the arms, then lay the evenly spaced rewards
    # out in that order.
    arm_order = self._rng.choice(
        range(self._n_actions), size=self._n_actions, replace=False)
    evenly_spaced = np.linspace(0, 1, self._n_actions)
    self._rewards = evenly_spaced[arm_order]

    self._total_regret = 0.
    self._optimal_return = 1.
    self.bsuite_num_episodes = sweep.NUM_EPISODES

  def _get_observation(self):
    """Returns the constant (1, 1) float32 observation of ones."""
    return np.ones((1, 1), dtype=np.float32)

  def _reset(self) -> dm_env.TimeStep:
    """Returns a restart timestep carrying the constant observation."""
    return dm_env.restart(self._get_observation())

  def _step(self, action: int) -> dm_env.TimeStep:
    """Pays out the chosen arm's reward and accumulates regret.

    Args:
      action: Index of the arm to pull.

    Returns:
      A termination timestep with the arm's reward and the constant
      observation.
    """
    reward = self._rewards[action]
    self._total_regret += self._optimal_return - reward
    return dm_env.termination(
        reward=reward, observation=self._get_observation())

  def observation_spec(self):
    """Returns the spec of the (1, 1) float32 observation."""
    return specs.Array(shape=(1, 1), dtype=np.float32)

  def action_spec(self):
    """Returns a discrete spec with one value per arm."""
    return specs.DiscreteArray(self._n_actions, name='action')

  def bsuite_info(self):
    """Returns the regret accumulated so far, keyed by `total_regret`."""
    return {'total_regret': self._total_regret}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""Tests for bsuite.experiments.bandit."""

from absl.testing import absltest
from bsuite.experiments.bandit import bandit
from bsuite.environments import bandit
from dm_env import test_utils
import numpy as np

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,65 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto-resetting environment base class.
"""Base class for bsuite environments.
The environment API states that stepping an environment after a LAST timestep
should return the first timestep of a new episode.
This inherits from the dm_env base class, with two major differences:
- Includes bsuite-specific metadata:
- `bsuite_info` returns metadata for logging, e.g. for computing regret/score.
- `bsuite_num_episodes` specifies how long the experiment should run for.
- Implements the auto-reset behavior specified by the environment API.
That is, stepping an environment after a LAST timestep should return the
first timestep of a new episode.
"""

import abc
from typing import Any, Dict

import dm_env


class Base(dm_env.Environment):
"""This class implements the required `step()` and `reset()` methods.
class Environment(dm_env.Environment):
"""Base class for bsuite environments.
A bsuite environment is a dm_env environment with extra metadata:
- bsuite_info method.
- bsuite_num_episodes attribute.
A bsuite environment also has auto-reset behavior.
This class implements the required `step()` and `reset()` methods.
It instead requires users to implement `_step()` and `_reset()`. This class
handles the reset behaviour automatically when it detects a LAST timestep.
"""

# Number of episodes that this environment should be run for.
bsuite_num_episodes: int

def __init__(self):
self._reset_next_step = True

@abc.abstractmethod
def _reset(self):
"""Returns a `timestep` namedtuple as per the regular `reset()` method."""

@abc.abstractmethod
def _step(self, action):
"""Returns a `timestep` namedtuple as per the regular `step()` method."""

def reset(self):
def reset(self) -> dm_env.TimeStep:
"""Resets the environment, calling the underlying _reset() method."""
self._reset_next_step = False
return self._reset()

def step(self, action):
def step(self, action: int) -> dm_env.TimeStep:
"""Steps the environment and implements the auto-reset behavior."""
if self._reset_next_step:
return self.reset()
timestep = self._step(action)
self._reset_next_step = timestep.last()
return timestep

@abc.abstractmethod
def _reset(self) -> dm_env.TimeStep:
"""Returns a `timestep` namedtuple as per the regular `reset()` method."""

@abc.abstractmethod
def _step(self, action: int) -> dm_env.TimeStep:
"""Returns a `timestep` namedtuple as per the regular `step()` method."""

@abc.abstractmethod
def bsuite_info(self) -> Dict[str, Any]:
"""Returns metadata specific to this environment for logging/scoring."""
Loading

0 comments on commit f9b74bf

Please sign in to comment.