diff --git a/setup.py b/setup.py index 6c3c3c885..850f0a679 100644 --- a/setup.py +++ b/setup.py @@ -175,15 +175,16 @@ def get_required_packages(): """Returns list of required packages.""" required_packages = [ - 'absl-py >= 0.6.1', - 'cloudpickle >= 1.3', - 'gin-config >= 0.4.0', - 'gym >= 0.17.0, <=0.23.0', - 'numpy >= 1.19.0', - 'pillow', - 'six >= 1.10.0', - 'protobuf >= 3.11.3', - 'wrapt >= 1.11.1', + 'absl-py >= 2.0.0', + 'cloudpickle >= 3.0.0', + 'gin-config >= 0.5.0', + 'gym >= 0.17.0, <= 0.23.1', + 'gymnasium >= 0.29.0', + 'numpy >= 1.26.2', + 'pillow >= 10.1.0', + 'six >= 1.16.0', + 'protobuf >= 3.11.3, <= 4.23.4', + 'wrapt >= 1.16.0', # Using an older version to avoid this bug # https://github.com/tensorflow/tensorflow/issues/62217 # while using tf 2.15.0 @@ -191,7 +192,7 @@ def get_required_packages(): # Used by gym >= 0.22.0. Only installed as a dependency when gym[all] is # installed or if gym[*] (where * is an environment which lists pygame as # a dependency). - 'pygame == 2.1.3', + 'pygame == 2.5.2', ] add_additional_packages(required_packages) return required_packages diff --git a/tf_agents/agents/dqn/dqn_agent.py b/tf_agents/agents/dqn/dqn_agent.py index 253ae10a4..c256ba34f 100644 --- a/tf_agents/agents/dqn/dqn_agent.py +++ b/tf_agents/agents/dqn/dqn_agent.py @@ -328,11 +328,14 @@ def _check_network_output(self, net, label): net: A `Network`. label: A label to print in case of a mismatch. """ - network_utils.check_single_floating_network_output( - net.create_variables(), - expected_output_shape=(self._num_actions,), - label=label, - ) + outputs = net.create_variables() + iterable = list(outputs) if isinstance(outputs, tuple) else [outputs] + for output in iterable: + network_utils.check_single_floating_network_output( + output, + expected_output_shape=(self._num_actions,), + label=label, + ) def _setup_policy( self, @@ -590,8 +593,10 @@ def _compute_q_values(self, time_steps, actions, training=False): # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1. action_spec = cast(tensor_spec.BoundedTensorSpec, self._action_spec) multi_dim_actions = action_spec.shape.rank > 0 + # support for dueling networks + a_values = q_values[0] if isinstance(q_values, tuple) else q_values return common.index_with_actions( - q_values, + a_values, tf.cast(actions, dtype=tf.int32), multi_dim_actions=multi_dim_actions, ) @@ -614,9 +619,12 @@ def _compute_next_q_values(self, next_time_steps, info): network_observation ) - next_target_q_values, _ = self._target_q_network( + q_next_target, _ = self._target_q_network( network_observation, step_type=next_time_steps.step_type ) + next_target_q_values = ( + q_next_target[0] if isinstance(q_next_target, tuple) else q_next_target + ) batch_size = ( next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0] ) @@ -668,9 +676,12 @@ def _compute_next_q_values(self, next_time_steps, info): network_observation ) - next_target_q_values, _ = self._target_q_network( + q_next_target, _ = self._target_q_network( network_observation, step_type=next_time_steps.step_type ) + next_target_q_values = ( + q_next_target[0] if isinstance(q_next_target, tuple) else q_next_target + ) batch_size = ( next_target_q_values.shape[0] or tf.shape(next_target_q_values)[0] ) @@ -687,3 +698,56 @@ def _compute_next_q_values(self, next_time_steps, info): best_next_actions, multi_dim_actions=multi_dim_actions, ) + + +@gin.configurable +class D3qnAgent(DqnAgent): + """A Dueling DQN Agent. 
+ + Implements the Double Dueling DQN algorithm from + + "Dueling Network Architectures for Deep Reinforcement Learning" + Wang et al., 2016 + https://arxiv.org/abs/1511.06581 + """ + + def _compute_next_q_values(self, next_time_steps, info): + """Compute the q value of the next state for TD error computation. + + Args: + next_time_steps: A batch of next timesteps + info: PolicyStep.info that may be used by other agents inherited from + dqn_agent. + + Returns: + A tensor of Q values for the given next state. + """ + del info + # TODO(b/117175589): Add binary tests for DDQN. + network_observation = next_time_steps.observation + + if self._observation_and_action_constraint_splitter is not None: + network_observation, _ = self._observation_and_action_constraint_splitter( + network_observation + ) + + q_next_target, _ = self._target_q_network( + network_observation, step_type=next_time_steps.step_type + ) + next_target_q_values = ( + q_next_target[0] if isinstance(q_next_target, tuple) else q_next_target + ) + q_next, _ = self._q_network( + network_observation, step_type=next_time_steps.step_type + ) + next_q_values = q_next[1] if isinstance(q_next, tuple) else q_next + best_next_actions = tf.math.argmax(next_q_values, axis=1) + + # Handle action_spec.shape=(), and shape=(1,) by using the multi_dim_actions + # param. Note: assumes len(tf.nest.flatten(action_spec)) == 1. + multi_dim_actions = tf.nest.flatten(self._action_spec)[0].shape.rank > 0 + return common.index_with_actions( + next_target_q_values, + best_next_actions, + multi_dim_actions=multi_dim_actions, + ) diff --git a/tf_agents/agents/dqn/dqn_agent_test.py b/tf_agents/agents/dqn/dqn_agent_test.py index 06d7de679..a66e2b047 100644 --- a/tf_agents/agents/dqn/dqn_agent_test.py +++ b/tf_agents/agents/dqn/dqn_agent_test.py @@ -82,7 +82,9 @@ def testComputeTDTargets(self): @parameterized.named_parameters( - ('DqnAgent', dqn_agent.DqnAgent), ('DdqnAgent', dqn_agent.DdqnAgent) + ('DqnAgent', dqn_agent.DqnAgent), + ('DdqnAgent', dqn_agent.DdqnAgent), + ('D3qnAgent', dqn_agent.D3qnAgent), ) class DqnAgentTest(test_utils.TestCase): @@ -216,6 +218,10 @@ def testLoss(self, agent_class): self.assertAllClose(self.evaluate(loss), expected_loss) def testLossWithChangedOptimalActions(self, agent_class): + + # if 'D3qnAgent' in agent_class.__name__: + # self.skipTest('invalid for dueling networks') + q_net = DummyNet(self._observation_spec, self._action_spec) agent = agent_class( self._time_step_spec, self._action_spec, q_network=q_net, optimizer=None @@ -475,6 +481,10 @@ def testLossNStepMidMidLastFirst(self, agent_class): self.assertAllClose(self.evaluate(loss), expected_loss) def testLossWithMaskedActions(self, agent_class): + + # if 'D3qnAgent' in agent_class.__name__: + # self.skipTest('invalid for dueling networks') + # Observations are now a tuple of the usual observation and an action mask. observation_spec_with_mask = ( self._observation_spec, @@ -529,10 +539,22 @@ def testLossWithMaskedActions(self, agent_class): # Target Q-value for second next_observation (only action 0 is valid): # 2 * 7 + 1 * 8 + 1 = 23 # TD targets: 10 + 0.9 * 12 = 20.8 and 20 + 0.9 * 23 = 40.7 - # TD errors: 20.8 - 5 = 15.8 and 40.7 - 8 = 32.7 - # TD loss: 15.3 and 32.2 (Huber loss subtracts 0.5) + # TD errors: 20.8 - 5 = 20.3 and 40.7 - 8 = 32.7 + # TD loss: 19.8 and 32.2 (Huber loss subtracts 0.5) # Overall loss: (15.3 + 32.2) / 2 = 23.75 expected_loss = 23.75 + if 'D3qnAgent' in agent_class.__name__: + # Using Q-values for next_observations only for D3qnAgent. 
+ # Q-value for first next_observation/action pair: + # 2 * 5 + 1 * 6 + 1 = 17 + # Q-value for second next_observation/action pair: + # 1 * 7 + 1 * 8 + 1 = 16 + # TD targets: 10 + 0.9 * 17 = 25.3 and 20 + 0.9 * 23 = 40.7 + # TD errors: 25.3 - 5 = 20.3 and 40.7 - 8 = 32.7 + # TD loss: 19.8 and 32.2 (Huber loss subtracts 0.5) + # Overall loss: (19.8 + 32.2) / 2 = 26.0 + expected_loss = 26.0 + loss, _ = agent._loss(experience) self.evaluate(tf.compat.v1.global_variables_initializer()) diff --git a/tf_agents/environments/__init__.py b/tf_agents/environments/__init__.py index 3ecb6866a..ff5b95d7a 100644 --- a/tf_agents/environments/__init__.py +++ b/tf_agents/environments/__init__.py @@ -29,7 +29,9 @@ # pylint: disable=g-import-not-at-top try: from tf_agents.environments import gym_wrapper + from tf_agents.environments import gymnasium_wrapper from tf_agents.environments import suite_gym + from tf_agents.environments import suite_gymnasium from tf_agents.environments import suite_atari from tf_agents.environments import suite_dm_control from tf_agents.environments import suite_mujoco diff --git a/tf_agents/environments/configs/suite_gymnasium.gin b/tf_agents/environments/configs/suite_gymnasium.gin new file mode 100644 index 000000000..0503a5363 --- /dev/null +++ b/tf_agents/environments/configs/suite_gymnasium.gin @@ -0,0 +1,9 @@ +#-*-Python-*- +import tf_agents.environments.suite_gymnasium + +## Configure Environment +ENVIRONMENT = @suite_gymnasium.load() +suite_gymnasium.load.environment_name = %ENVIRONMENT_NAME +# Note: The ENVIRONMENT_NAME can be overridden by passing the command line flag: +# --params="ENVIRONMENT_NAME='CartPole-v1'" +ENVIRONMENT_NAME = 'CartPole-v1' diff --git a/tf_agents/environments/gymnasium_wrapper.py b/tf_agents/environments/gymnasium_wrapper.py new file mode 100644 index 000000000..2d2e3e85d --- /dev/null +++ b/tf_agents/environments/gymnasium_wrapper.py @@ -0,0 +1,254 @@ +# coding=utf-8 +# Copyright 2020 The TF-Agents Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wrapper providing a PyEnvironmentBase adapter for Gymnasium environments.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from typing import Any, Optional, Text, Union + +import gymnasium as gym +from gymnasium.utils import seeding +import numpy as np +import tensorflow as tf +from tf_agents import specs +from tf_agents.environments import py_environment +from tf_agents.trajectories import time_step as ts +from tf_agents.typing import types + + +def spec_from_gym_space( + space: gym.Space, + simplify_box_bounds: bool = True, + name: Optional[Text] = None, +) -> Union[ + specs.BoundedArraySpec, + specs.ArraySpec, + tuple[specs.ArraySpec, ...], + list[specs.ArraySpec], + collections.OrderedDict[str, specs.ArraySpec], +]: + """Converts gymnasium spaces into array specs, or a collection thereof. 
+ + Please note: + Unlike OpenAI's gym, Farama's gymnasium provides a dtype for + each current implementation of spaces. dtype should be defined + in all specific subclasses of gymnasium.Space even if it is still + optional in the superclass. + + Args: + space: gymnasium.Space to turn into a spec. + simplify_box_bounds: Whether to replace bounds of Box space that are arrays + with identical values with one number and rely on broadcasting. + name: Name of the spec. + + Returns: + A BoundedArraySpec or an ArraySpec nest mirroring the given space structure. + The result can be a tuple, sequence or dict of specs for specific Spaces. + Raises: + ValueError: If there is an unknown space type. + """ + + # We try to simplify redundant arrays to make logging and debugging less + # verbose and easier to read since the printed spec bounds may be large. + def try_simplify_array_to_value(np_array): + """If given numpy array has all the same values, returns that value.""" + first_value = np_array.item(0) + if np.all(np_array == first_value): + return np.array(first_value, dtype=np_array.dtype) + else: + return np_array + + def nested_spec(spec, child_name): + """Returns the nested spec with a unique name.""" + nested_name = name + '/' + child_name if name else child_name + return spec_from_gym_space(spec, simplify_box_bounds, nested_name) + + if isinstance(space, gym.spaces.Discrete): + # Discrete spaces span the set {0, 1, ... , n-1} while Bounded Array specs + # are inclusive on their bounds. + maximum = space.n - 1 + return specs.BoundedArraySpec( + shape=(), dtype=np.int64, minimum=0, maximum=maximum, name=name + ) + elif isinstance(space, gym.spaces.MultiDiscrete): + dtype = np.integer + maximum = try_simplify_array_to_value( + np.asarray(space.nvec - 1, dtype=dtype) + ) + return specs.BoundedArraySpec( + shape=space.shape, dtype=dtype, minimum=0, maximum=maximum, name=name + ) + elif isinstance(space, gym.spaces.MultiBinary): + return specs.BoundedArraySpec( + shape=space.shape, dtype=np.int8, minimum=0, maximum=1, name=name + ) + elif isinstance(space, gym.spaces.Box): + dtype = space.dtype + minimum = np.asarray(space.low, dtype=dtype) + maximum = np.asarray(space.high, dtype=dtype) + if simplify_box_bounds: + simple_minimum = try_simplify_array_to_value(minimum) + simple_maximum = try_simplify_array_to_value(maximum) + # Can only simplify if both bounds are simplified. Otherwise + # broadcasting doesn't work from non-simplified to simplified. 
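+      # For example, Box(low=-1.0, high=1.0, shape=(3,)) simplifies both
+      # bounds to scalars, while Box(low=np.array([0.0, -1.0]), high=1.0)
+      # yields bounds of different shapes after simplification, so the
+      # original array bounds are kept for both minimum and maximum.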
+ if simple_minimum.shape == simple_maximum.shape: + minimum = simple_minimum + maximum = simple_maximum + return specs.BoundedArraySpec( + shape=space.shape, + dtype=dtype, + minimum=minimum, + maximum=maximum, + name=name, + ) + elif isinstance(space, gym.spaces.Tuple): + return tuple( + [nested_spec(s, 'tuple_%d' % i) for i, s in enumerate(space.spaces)] + ) + elif isinstance(space, gym.spaces.Dict): + return collections.OrderedDict( + [(key, nested_spec(s, key)) for key, s in space.spaces.items()] + ) + elif isinstance(space, gym.spaces.Sequence): + return list([nested_spec(space.feature_space, 'nested_space')]) + elif isinstance(space, gym.spaces.Graph): + return ( + nested_spec(space.node_space, 'node_space'), + nested_spec(space.edge_space, 'edge_space'), + ) + elif isinstance(space, gym.spaces.Text): + return specs.ArraySpec(shape=space.shape, dtype=tf.string, name=name) + else: + raise ValueError( + 'The gymnasium space {} is currently not supported.'.format(space) + ) + + +class GymnasiumWrapper(py_environment.PyEnvironment): + """Base wrapper implementing PyEnvironmentBaseWrapper interface for Gymnasium envs. + + Action and observation specs are automatically generated from the action and + observation spaces. See base class for py_environment.Base details. + """ + + def __init__( + self, + gym_env: gym.Env, + discount: types.Float = 1.0, + auto_reset: bool = True, + simplify_box_bounds: bool = True, + ): + super(GymnasiumWrapper, self).__init__(auto_reset) + + self._gym_env = gym_env + self._discount = discount + self._action_is_discrete = isinstance( + self._gym_env.action_space, gym.spaces.Discrete + ) + self._observation_spec = spec_from_gym_space( + self._gym_env.observation_space, + simplify_box_bounds, + 'observation', + ) + self._action_spec = spec_from_gym_space( + self._gym_env.action_space, + simplify_box_bounds, + 'action', + ) + self._flat_obs_spec = tf.nest.flatten(self._observation_spec) + self._info = None + self._truncated = True + + @property + def gym(self) -> gym.Env: + return self._gym_env + + def __getattr__(self, name: Text) -> Any: + """Forward all other calls to the base environment.""" + gym_env = super(GymnasiumWrapper, self).__getattribute__('_gym_env') + return getattr(gym_env, name) + + def get_info(self) -> Any: + """Returns the gym environment info returned on the last step.""" + return self._info + + def _reset(self): + # Upcoming update on gym adds **kwargs on reset. Update this to + # support that. + observation, self._info = self._gym_env.reset() + self._terminated = False + self._truncated = False + return ts.restart(observation) + + @property + def terminated(self) -> bool: + return self._terminated + + @property + def truncated(self) -> bool: + return self._truncated + + def _step(self, action): + # Some environments (e.g. FrozenLake) use the action as a key to the + # transition probability, so it has to be hashable. In the case of discrete + # actions we have a numpy scalar (e.g array(2)) which is not hashable + # in this case, we simply pull out the scalar value which will be hashable. + if self._action_is_discrete and isinstance(action, np.ndarray): + action = action.item() + + # Figure out how tuple or dict actions will be generated by the + # agents and if we can pass them through directly to gym. 
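+    # gymnasium's Env.step() returns a 5-tuple
+    # (observation, reward, terminated, truncated, info), unlike classic gym's
+    # 4-tuple with a single `done` flag; terminated and truncated are tracked
+    # separately so truncation can keep a non-zero discount below.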
+ observation, reward, self._terminated, self._truncated, self._info = ( + self._gym_env.step(action) + ) + + if self._terminated: + return ts.termination(observation, reward) + elif self._truncated: + return ts.truncation(observation, reward, self._discount) + else: + return ts.transition(observation, reward, self._discount) + + def observation_spec(self) -> types.NestedArraySpec: + return self._observation_spec + + def action_spec(self) -> types.NestedArraySpec: + return self._action_spec + + def close(self) -> None: + return self._gym_env.close() + + def seed(self, seed: types.Seed) -> types.Seed: + np_random, seed = seeding.np_random(seed) + self._gym_env.np_random = np_random + return seed + + def render(self, mode: Text = 'rgb_array') -> Any: + return ( + self._gym_env.render() + ) # mode should be set for key "render_mode" in make() + + # pytype: disable=attribute-error + def set_state(self, state: Any) -> None: + return self._gym_env.set_state(state) + + def get_state(self) -> Any: + return self._gym.get_state() + + # pytype: enable=attribute-error diff --git a/tf_agents/environments/gymnasium_wrapper_test.py b/tf_agents/environments/gymnasium_wrapper_test.py new file mode 100644 index 000000000..e5c078fea --- /dev/null +++ b/tf_agents/environments/gymnasium_wrapper_test.py @@ -0,0 +1,404 @@ +# coding=utf-8 +# Copyright 2020 The TF-Agents Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for environments.gymnasium_wrapper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +from absl.testing.absltest import mock +import gymnasium as gym +import numpy as np +from tf_agents.environments import gymnasium_wrapper as gym_wrapper +from tf_agents.utils import test_utils + + +class GymnasiumWrapperSpecTest(test_utils.TestCase): + + def test_spec_from_gym_space_discrete(self): + discrete_space = gym.spaces.Discrete(3) + spec = gym_wrapper.spec_from_gym_space(discrete_space) + + self.assertEqual((), spec.shape) + self.assertEqual(np.int64, spec.dtype) + self.assertEqual(0, spec.minimum) + self.assertEqual(2, spec.maximum) + + def test_spec_from_gym_space_multi_discrete(self): + multi_discrete_space = gym.spaces.MultiDiscrete([1, 2, 3, 4]) + spec = gym_wrapper.spec_from_gym_space(multi_discrete_space) + + self.assertEqual((4,), spec.shape) + self.assertEqual(np.int64, spec.dtype) + np.testing.assert_array_equal(np.array([0], dtype=int), spec.minimum) + np.testing.assert_array_equal( + np.array([0, 1, 2, 3], dtype=int), spec.maximum + ) + + def test_spec_from_gym_space_multi_binary(self): + multi_binary_space = gym.spaces.MultiBinary(4) + spec = gym_wrapper.spec_from_gym_space(multi_binary_space) + + self.assertEqual((4,), spec.shape) + self.assertEqual(np.int8, spec.dtype) + np.testing.assert_array_equal(np.array([0], dtype=int), spec.minimum) + np.testing.assert_array_equal(np.array([1], dtype=int), spec.maximum) + + def test_spec_from_gym_space_multi_binary_2d(self): + multi_binary_space = gym.spaces.MultiBinary((8, 8)) + spec = gym_wrapper.spec_from_gym_space(multi_binary_space) + + self.assertEqual((8, 8), spec.shape) + self.assertEqual(np.int8, spec.dtype) + np.testing.assert_array_equal(np.array([0], dtype=int), spec.minimum) + np.testing.assert_array_equal(np.array([1], dtype=int), spec.maximum) + + def test_spec_from_gym_space_box_scalars(self): + for dtype in (np.float32, np.float64): + box_space = gym.spaces.Box(-1.0, 1.0, (3, 4), dtype=dtype) + spec = gym_wrapper.spec_from_gym_space(box_space) + + self.assertEqual((3, 4), spec.shape) + self.assertEqual(dtype, spec.dtype) + np.testing.assert_array_equal(-np.ones((3, 4)), spec.minimum) + np.testing.assert_array_equal(np.ones((3, 4)), spec.maximum) + + def test_spec_from_gym_space_box_scalars_simplify_bounds(self): + box_space = gym.spaces.Box(-1.0, 1.0, (3, 4)) + spec = gym_wrapper.spec_from_gym_space(box_space, simplify_box_bounds=True) + + self.assertEqual((3, 4), spec.shape) + self.assertEqual(np.float32, spec.dtype) + np.testing.assert_array_equal(np.array([-1], dtype=int), spec.minimum) + np.testing.assert_array_equal(np.array([1], dtype=int), spec.maximum) + + def test_spec_from_gym_space_when_simplify_box_bounds_false(self): + # testing on gym.spaces.Dict which makes recursive calls to + # _spec_from_gym_space + box_space = gym.spaces.Box(-1.0, 1.0, (2,)) + dict_space = gym.spaces.Dict({'box1': box_space, 'box2': box_space}) + spec = gym_wrapper.spec_from_gym_space( + dict_space, simplify_box_bounds=False + ) + + self.assertEqual((2,), spec['box1'].shape) + self.assertEqual((2,), spec['box2'].shape) + self.assertEqual(np.float32, spec['box1'].dtype) + self.assertEqual(np.float32, spec['box2'].dtype) + self.assertEqual('box1', spec['box1'].name) + self.assertEqual('box2', spec['box2'].name) + np.testing.assert_array_equal( + np.array([-1, -1], dtype=int), spec['box1'].minimum + ) + np.testing.assert_array_equal( + 
np.array([1, 1], dtype=int), spec['box1'].maximum + ) + np.testing.assert_array_equal( + np.array([-1, -1], dtype=int), spec['box2'].minimum + ) + np.testing.assert_array_equal( + np.array([1, 1], dtype=int), spec['box2'].maximum + ) + + def test_spec_from_gym_space_box_array(self): + for dtype in (np.float32, np.float64): + box_space = gym.spaces.Box( + np.array([-1.0, -2.0]), np.array([2.0, 4.0]), dtype=dtype + ) + spec = gym_wrapper.spec_from_gym_space(box_space) + + self.assertEqual((2,), spec.shape) + self.assertEqual(dtype, spec.dtype) + np.testing.assert_array_equal(np.array([-1.0, -2.0]), spec.minimum) + np.testing.assert_array_equal(np.array([2.0, 4.0]), spec.maximum) + + def test_spec_from_gym_space_box_array_constant_bounds(self): + for dtype in (np.float32, np.float64): + box_space = gym.spaces.Box( + np.array([-1.0, -1.0]), np.array([2.0, 2.0]), dtype=dtype + ) + spec = gym_wrapper.spec_from_gym_space(box_space) + + self.assertEqual((2,), spec.shape) + self.assertEqual(dtype, spec.dtype) + self.assertAllEqual(-1.0, spec.minimum) + self.assertAllEqual(2.0, spec.maximum) + + def test_spec_from_gym_space_box_array_constant_min(self): + for dtype in (np.float32, np.float64): + box_space = gym.spaces.Box( + np.array([-1.0, -1.0]), np.array([2.0, 4.0]), dtype=dtype + ) + spec = gym_wrapper.spec_from_gym_space(box_space) + + self.assertEqual((2,), spec.shape) + self.assertEqual(dtype, spec.dtype) + self.assertAllEqual([-1.0, -1.0], spec.minimum) + self.assertAllEqual([2.0, 4.0], spec.maximum) + + def test_spec_from_gym_space_tuple(self): + tuple_space = gym.spaces.Tuple( + (gym.spaces.Discrete(2), gym.spaces.Discrete(3)) + ) + spec = gym_wrapper.spec_from_gym_space(tuple_space) + + self.assertEqual(2, len(spec)) + self.assertEqual((), spec[0].shape) + self.assertEqual(np.int64, spec[0].dtype) + self.assertEqual(0, spec[0].minimum) + self.assertEqual(1, spec[0].maximum) + + self.assertEqual((), spec[1].shape) + self.assertEqual(np.int64, spec[1].dtype) + self.assertEqual(0, spec[1].minimum) + self.assertEqual(2, spec[1].maximum) + + def test_spec_from_gym_space_tuple_mixed(self): + tuple_space = gym.spaces.Tuple(( + gym.spaces.Discrete(2), + gym.spaces.Box(-1.0, 1.0, (3, 4)), + gym.spaces.Tuple((gym.spaces.Discrete(2), gym.spaces.Discrete(3))), + gym.spaces.Dict({ + 'spec_1': gym.spaces.Discrete(2), + 'spec_2': gym.spaces.Tuple( + (gym.spaces.Discrete(2), gym.spaces.Discrete(3)) + ), + }), + )) + spec = gym_wrapper.spec_from_gym_space(tuple_space) + + self.assertEqual(4, len(spec)) + # Test Discrete + self.assertEqual((), spec[0].shape) + self.assertEqual(np.int64, spec[0].dtype) + self.assertEqual(0, spec[0].minimum) + self.assertEqual(1, spec[0].maximum) + + # Test Box + self.assertEqual((3, 4), spec[1].shape) + self.assertEqual(np.float32, spec[1].dtype) + np.testing.assert_array_almost_equal(-np.ones((3, 4)), spec[1].minimum) + np.testing.assert_array_almost_equal(np.ones((3, 4)), spec[1].maximum) + + # Test Tuple + self.assertEqual(2, len(spec[2])) + self.assertEqual((), spec[2][0].shape) + self.assertEqual(np.int64, spec[2][0].dtype) + self.assertEqual(0, spec[2][0].minimum) + self.assertEqual(1, spec[2][0].maximum) + self.assertEqual((), spec[2][1].shape) + self.assertEqual(np.int64, spec[2][1].dtype) + self.assertEqual(0, spec[2][1].minimum) + self.assertEqual(2, spec[2][1].maximum) + + # Test Dict + # Test Discrete in Dict + discrete_in_dict = spec[3]['spec_1'] + self.assertEqual((), discrete_in_dict.shape) + self.assertEqual(np.int64, discrete_in_dict.dtype) + 
self.assertEqual(0, discrete_in_dict.minimum) + self.assertEqual(1, discrete_in_dict.maximum) + + # Test Tuple in Dict + tuple_in_dict = spec[3]['spec_2'] + self.assertEqual(2, len(tuple_in_dict)) + self.assertEqual((), tuple_in_dict[0].shape) + self.assertEqual(np.int64, tuple_in_dict[0].dtype) + self.assertEqual(0, tuple_in_dict[0].minimum) + self.assertEqual(1, tuple_in_dict[0].maximum) + self.assertEqual((), tuple_in_dict[1].shape) + self.assertEqual(np.int64, tuple_in_dict[1].dtype) + self.assertEqual(0, tuple_in_dict[1].minimum) + self.assertEqual(2, tuple_in_dict[1].maximum) + + def test_spec_from_gym_space_dict(self): + dict_space = gym.spaces.Dict([ + ('spec_2', gym.spaces.Box(-1.0, 1.0, (3, 4))), + ('spec_1', gym.spaces.Discrete(2)), + ]) + + spec = gym_wrapper.spec_from_gym_space(dict_space) + + keys = list(spec.keys()) + self.assertEqual('spec_1', keys[1]) + self.assertEqual(2, len(spec)) + self.assertEqual((), spec['spec_1'].shape) + self.assertEqual(np.int64, spec['spec_1'].dtype) + self.assertEqual(0, spec['spec_1'].minimum) + self.assertEqual(1, spec['spec_1'].maximum) + + self.assertEqual('spec_2', keys[0]) + self.assertEqual((3, 4), spec['spec_2'].shape) + self.assertEqual(np.float32, spec['spec_2'].dtype) + np.testing.assert_array_almost_equal( + -np.ones((3, 4)), + spec['spec_2'].minimum, + ) + np.testing.assert_array_almost_equal( + np.ones((3, 4)), + spec['spec_2'].maximum, + ) + + def test_spec_name(self): + box_space = gym.spaces.Box( + np.array([-1.0, -2.0]), np.array([2.0, 4.0]), dtype=np.float32 + ) + spec = gym_wrapper.spec_from_gym_space(box_space, name='observation') + self.assertEqual('observation', spec.name) + + def test_spec_name_nested(self): + dict_space = gym.spaces.Tuple(( + gym.spaces.Dict({ + 'spec_0': gym.spaces.Dict({ + 'spec_1': gym.spaces.Discrete(2), + 'spec_2': gym.spaces.Discrete(2), + }), + }), + gym.spaces.Discrete(2), + )) + spec = gym_wrapper.spec_from_gym_space(dict_space, name='observation') + self.assertEqual( + 'observation/tuple_0/spec_0/spec_1', spec[0]['spec_0']['spec_1'].name + ) + self.assertEqual( + 'observation/tuple_0/spec_0/spec_2', spec[0]['spec_0']['spec_2'].name + ) + self.assertEqual('observation/tuple_1', spec[1].name) + + +class GymnasiumWrapperOnCartpoleTest(test_utils.TestCase): + + def test_wrapped_cartpole_specs(self): + # Note we use spec.make on gym envs to avoid getting a TimeLimit wrapper on + # the environment. 
+ cartpole_env = gym.spec('CartPole-v1').make() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + + action_spec = env.action_spec() + self.assertEqual((), action_spec.shape) + self.assertEqual(0, action_spec.minimum) + self.assertEqual(1, action_spec.maximum) + + observation_spec = env.observation_spec() + self.assertEqual((4,), observation_spec.shape) + self.assertEqual(np.float32, observation_spec.dtype) + high = np.array([ + 4.8, + np.finfo(np.float32).max, + 2 / 15.0 * math.pi, + np.finfo(np.float32).max, + ]) + np.testing.assert_array_almost_equal(-high, observation_spec.minimum) + np.testing.assert_array_almost_equal(high, observation_spec.maximum) + + def test_wrapped_cartpole_reset(self): + cartpole_env = gym.spec('CartPole-v1').make() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + + first_time_step = env.reset() + self.assertTrue(first_time_step.is_first()) + self.assertEqual(0.0, first_time_step.reward) + self.assertEqual(1.0, first_time_step.discount) + self.assertEqual((4,), first_time_step.observation.shape) + self.assertEqual(np.float32, first_time_step.observation.dtype) + + def test_wrapped_cartpole_transition(self): + cartpole_env = gym.spec('CartPole-v1').make() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + env.reset() + transition_time_step = env.step(np.array(0, dtype=np.int32)) + + self.assertTrue(transition_time_step.is_mid()) + self.assertNotEqual(None, transition_time_step.reward) + self.assertEqual(1.0, transition_time_step.discount) + self.assertEqual((4,), transition_time_step.observation.shape) + + def test_wrapped_cartpole_final(self): + cartpole_env = gym.spec('CartPole-v1').make() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + time_step = env.reset() + + while not time_step.is_last(): + time_step = env.step(np.array(1, dtype=np.int32)) + + self.assertTrue(time_step.is_last()) + self.assertNotEqual(None, time_step.reward) + self.assertEqual(0.0, time_step.discount) + self.assertEqual((4,), time_step.observation.shape) + + def test_get_info(self): + cartpole_env = gym.spec('CartPole-v1').make() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + self.assertIsNone(env.get_info()) + env.reset() + self.assertIsNone(None, env.get_info()) + env.step(np.array(0, dtype=np.int32)) + self.assertEqual({}, env.get_info()) + + def test_automatic_reset_after_create(self): + cartpole_env = gym.spec('CartPole-v1').make() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + + first_time_step = env.step(0) # pytype: disable=wrong-arg-types + self.assertTrue(first_time_step.is_first()) + + def test_automatic_reset_after_done(self): + cartpole_env = gym.spec('CartPole-v1').make() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + time_step = env.reset() + + while not time_step.is_last(): + time_step = env.step(np.array(1, dtype=np.int32)) + + self.assertTrue(time_step.is_last()) + first_time_step = env.step(0) # pytype: disable=wrong-arg-types + self.assertTrue(first_time_step.is_first()) + + def test_automatic_reset_after_done_not_using_reset_directly(self): + cartpole_env = gym.spec('CartPole-v1').make() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + time_step = env.step(1) # pytype: disable=wrong-arg-types + + while not time_step.is_last(): + time_step = env.step(np.array(1, dtype=np.int32)) + + self.assertTrue(time_step.is_last()) + first_time_step = env.step(0) # pytype: disable=wrong-arg-types + self.assertTrue(first_time_step.is_first()) + + def test_method_propagation(self): + cartpole_env = 
gym.spec('CartPole-v1').make(render_mode='rgb_array') + for method_name in ('render', 'close'): + setattr(cartpole_env, method_name, mock.MagicMock()) + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + env.render() + self.assertEqual(1, cartpole_env.render.call_count) + cartpole_env.render.assert_called_with() + env.close() + self.assertEqual(1, cartpole_env.close.call_count) + + def test_obs_dtype(self): + cartpole_env = gym.spec('CartPole-v1').make() + cartpole_env.render = mock.MagicMock() + env = gym_wrapper.GymnasiumWrapper(cartpole_env) + time_step = env.reset() + self.assertEqual(env.observation_spec().dtype, time_step.observation.dtype) + + +if __name__ == '__main__': + test_utils.main() diff --git a/tf_agents/environments/suite_gymnasium.py b/tf_agents/environments/suite_gymnasium.py new file mode 100644 index 000000000..51159902c --- /dev/null +++ b/tf_agents/environments/suite_gymnasium.py @@ -0,0 +1,142 @@ +# coding=utf-8 +# Copyright 2020 The TF-Agents Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Suite for loading Gym Environments. + +Note we use gym.spec(env_id).make() on gym envs to avoid getting a TimeLimit +wrapper on the environment. OpenAI's TimeLimit wrappers terminate episodes +without indicating if the failure is due to the time limit, or due to negative +agent behaviour. This prevents us from setting the appropriate discount value +for the final step of an episode. To prevent that we extract the step limit +from the environment specs and utilize our TimeLimit wrapper. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from typing import Any, Callable, Dict, Optional, Sequence, Text + +import gin +import gymnasium as gym +from tf_agents.environments import gymnasium_wrapper as gym_wrapper +from tf_agents.environments import py_environment +from tf_agents.environments import wrappers +from tf_agents.typing import types + +TimeLimitWrapperType = Callable[ + [py_environment.PyEnvironment, int], py_environment.PyEnvironment +] + + +@gin.configurable +def load( + environment_name: Text, + discount: types.Float = 1.0, + max_episode_steps: Optional[types.Int] = None, + gym_env_wrappers: Sequence[types.GymnasiumEnvWrapper] = (), + env_wrappers: Sequence[types.PyEnvWrapper] = (), + gym_kwargs: Optional[Dict[str, Any]] = None, + render_kwargs: Optional[Dict[str, Any]] = None, +) -> py_environment.PyEnvironment: + """Loads the selected environment and wraps it with the specified wrappers. + + Note that by default a TimeLimit wrapper is used to limit episode lengths + to the default benchmarks defined by the registered environments. + + Args: + environment_name: Name for the environment to load. + discount: Discount to use for the environment. + max_episode_steps: If None the max_episode_steps will be set to the default + step limit defined in the environment's spec. No limit is applied if set + to 0 or if there is no max_episode_steps set in the environment's spec. 
+ gym_env_wrappers: Iterable with references to wrapper classes to use + directly on the gym environment. + env_wrappers: Iterable with references to wrapper classes to use on the + gym_wrapped environment. + gym_kwargs: Optional kwargs to pass to the Gym environment class. + render_kwargs: Optional kwargs for rendering to pass to `make()` of the + gymnasium_wrapped environment. + + Returns: + A PyEnvironment instance. + """ + gym_kwargs = gym_kwargs if gym_kwargs else {} + render_kwargs = render_kwargs if render_kwargs else {} + make_args = {**gym_kwargs, **render_kwargs} + gym_spec = gym.spec(environment_name) + gym_env = gym_spec.make(**make_args) + + if max_episode_steps is None and gym_spec.max_episode_steps is not None: + max_episode_steps = gym_spec.max_episode_steps + + return wrap_env( + gym_env, + discount=discount, + max_episode_steps=max_episode_steps, + gym_env_wrappers=gym_env_wrappers, + env_wrappers=env_wrappers, + ) + + +@gin.configurable +def wrap_env( + gym_env: gym.Env, + discount: types.Float = 1.0, + max_episode_steps: Optional[types.Int] = None, + gym_env_wrappers: Sequence[types.GymnasiumEnvWrapper] = (), + time_limit_wrapper: TimeLimitWrapperType = wrappers.TimeLimit, + env_wrappers: Sequence[types.PyEnvWrapper] = (), + auto_reset: bool = True, +) -> py_environment.PyEnvironment: + """Wraps given gymnasium environment with TF Agent's GymnasiumWrapper. + + Note that by default a TimeLimit wrapper is used to limit episode lengths + to the default benchmarks defined by the registered environments. + + Args: + gym_env: An instance of OpenAI gym environment. + discount: Discount to use for the environment. + max_episode_steps: Used to create a TimeLimitWrapper. No limit is applied if + set to None or 0. Usually set to `gym_spec.max_episode_steps` in `load`. + gym_env_wrappers: Iterable with references to wrapper classes to use + directly on the gym environment. + time_limit_wrapper: Wrapper that accepts (env, max_episode_steps) params to + enforce a TimeLimit. Usuaully this should be left as the default, + wrappers.TimeLimit. + env_wrappers: Iterable with references to wrapper classes to use on the + gym_wrapped environment. + auto_reset: If True (default), reset the environment automatically after a + terminal state is reached. + + Returns: + A PyEnvironment instance. + """ + + for wrapper in gym_env_wrappers: + gym_env = wrapper(gym_env) + env = gym_wrapper.GymnasiumWrapper( + gym_env, + discount=discount, + auto_reset=auto_reset, + ) + + if max_episode_steps is not None and max_episode_steps > 0: + env = time_limit_wrapper(env, max_episode_steps) + + for wrapper in env_wrappers: + env = wrapper(env) + + return env diff --git a/tf_agents/environments/suite_gymnasium_test.py b/tf_agents/environments/suite_gymnasium_test.py new file mode 100644 index 000000000..e7aff184d --- /dev/null +++ b/tf_agents/environments/suite_gymnasium_test.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# Copyright 2020 The TF-Agents Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test for tf_agents.environments.suite_gym.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import gin +from tf_agents.environments import py_environment +from tf_agents.environments import suite_gymnasium as suite_gym +from tf_agents.environments import wrappers +from tf_agents.utils import test_utils + + +class SuiteGymnasiumTest(test_utils.TestCase): + + def tearDown(self): + gin.clear_config() + super(SuiteGymnasiumTest, self).tearDown() + + def test_load_adds_time_limit_steps(self): + env = suite_gym.load('CartPole-v1') + self.assertIsInstance(env, py_environment.PyEnvironment) + self.assertIsInstance(env, wrappers.TimeLimit) + + def test_load_disable_step_limit(self): + env = suite_gym.load('CartPole-v1', max_episode_steps=0) + self.assertIsInstance(env, py_environment.PyEnvironment) + self.assertNotIsInstance(env, wrappers.TimeLimit) + + def test_load_disable_wrappers_applied(self): + duration_wrapper = functools.partial(wrappers.TimeLimit, duration=10) + env = suite_gym.load( + 'CartPole-v1', max_episode_steps=0, env_wrappers=(duration_wrapper,) + ) + self.assertIsInstance(env, py_environment.PyEnvironment) + self.assertIsInstance(env, wrappers.TimeLimit) + + def test_custom_max_steps(self): + env = suite_gym.load('CartPole-v1', max_episode_steps=5) + self.assertIsInstance(env, py_environment.PyEnvironment) + self.assertIsInstance(env, wrappers.TimeLimit) + self.assertEqual(5, env._duration) + + def testGinConfig(self): + gin.parse_config_file( + test_utils.test_src_dir_path('environments/configs/suite_gymnasium.gin') + ) + env = suite_gym.load() + self.assertIsInstance(env, py_environment.PyEnvironment) + self.assertIsInstance(env, wrappers.TimeLimit) + + def test_gym_kwargs_argument(self): + env = suite_gym.load('MountainCar-v0', gym_kwargs={'goal_velocity': 21}) + self.assertEqual(env.unwrapped.goal_velocity, 21) + + env = suite_gym.load('MountainCar-v0', gym_kwargs={'goal_velocity': 50}) + self.assertEqual(env.unwrapped.goal_velocity, 50) + + def test_render_kwargs_argument(self): + env = suite_gym.load( + 'MountainCar-v0', render_kwargs={'render_mode': 'human'} + ) + self.assertEqual(env.unwrapped.render_mode, 'human') + + env = suite_gym.load( + 'MountainCar-v0', render_kwargs={'render_mode': 'rgb_array'} + ) + self.assertEqual(env.unwrapped.render_mode, 'rgb_array') + + +if __name__ == '__main__': + test_utils.main() diff --git a/tf_agents/examples/dqn/gymnasium/d3qn_train_eval.py b/tf_agents/examples/dqn/gymnasium/d3qn_train_eval.py new file mode 100644 index 000000000..6543be8ff --- /dev/null +++ b/tf_agents/examples/dqn/gymnasium/d3qn_train_eval.py @@ -0,0 +1,286 @@ +# coding=utf-8 +# Copyright 2020 The TF-Agents Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Example training Double Dueling DQN (D3QN) using actor/learner in a gymnasium environment. 
+ +To run D3QN on LunarLander: + +```bash +tensorboard --logdir $HOME/tmp/d3qn_lunar_lander --port 2223 & +python tf_agents/examples/dqn/gymnasium/d3qn_train_eval.py +--root_dir=$HOME/tmp/d3qn_lunar_lander +``` +""" + +import functools +import os + +from absl import app +from absl import flags +from absl import logging +import gin +import reverb +import tensorflow.compat.v2 as tf +from tf_agents.agents.dqn import dqn_agent +from tf_agents.environments import suite_gymnasium as suite_gym +from tf_agents.metrics import py_metrics +from tf_agents.networks import dueling_q_network +from tf_agents.policies import py_tf_eager_policy +from tf_agents.policies import random_py_policy +from tf_agents.replay_buffers import reverb_replay_buffer +from tf_agents.replay_buffers import reverb_utils +from tf_agents.system import system_multiprocessing as multiprocessing +from tf_agents.train import actor +from tf_agents.train import learner +from tf_agents.train import triggers +from tf_agents.train.utils import spec_utils +from tf_agents.train.utils import train_utils +from tf_agents.utils import common + +FLAGS = flags.FLAGS + +flags.DEFINE_string( + 'root_dir', + os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'), + 'Root directory for writing logs/summaries/checkpoints.', +) +flags.DEFINE_integer( + 'reverb_port', + None, + 'Port for reverb server, if None, use a randomly chosen unused port.', +) +flags.DEFINE_integer( + 'num_iterations', 200000, 'Total number train/eval iterations to perform.' +) +flags.DEFINE_integer( + 'eval_interval', + 1000, + 'Number of train steps between evaluations. Set to 0 to skip.', +) +flags.DEFINE_boolean('dueling', False, 'Set to True for dueling DQN') +flags.DEFINE_multi_string('gin_file', None, 'Paths to the gin-config files.') +flags.DEFINE_multi_string('gin_bindings', None, 'Gin binding parameters.') + + +@gin.configurable +def train_eval( + root_dir, + env_name='LunarLander-v2', + # Training params + initial_collect_steps=1000, + num_iterations=200000, + fc_layer_params=(128, 128), + # Agent params + epsilon_greedy=0.1, + min_epsilon=0.0001, + num_decay_steps=40000, + batch_size=64, + learning_rate=1e-3, + n_step_update=1, + gamma=0.99, + target_update_tau=1.0, + target_update_period=120, + reward_scale_factor=1.0, + # Replay params + reverb_port=None, + replay_capacity=100000, + # Others + policy_save_interval=1000, + eval_interval=1000, + eval_episodes=10, +): + """Trains and evaluates D3QN.""" + collect_env = suite_gym.load(env_name, max_episode_steps=400) + eval_env = suite_gym.load( + env_name, + max_episode_steps=400, + render_kwargs={'render_mode': 'rgb_array'}, + ) + + observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = ( + spec_utils.get_tensor_specs(collect_env) + ) + + train_step = train_utils.create_train_step() + + # D3QN requires Dueling Q Networks + q_net = dueling_q_network.DuelingQNetwork( + input_tensor_spec=observation_tensor_spec, + action_spec=action_tensor_spec, + fc_layer_params=fc_layer_params, + ) + + target_q_net = dueling_q_network.DuelingQNetwork( + input_tensor_spec=observation_tensor_spec, + action_spec=action_tensor_spec, + fc_layer_params=fc_layer_params, + ) + + epsilon_decay = tf.compat.v1.train.polynomial_decay( + learning_rate=epsilon_greedy, + global_step=train_step, + decay_steps=num_decay_steps, + end_learning_rate=min_epsilon, + power=0.5, + ) + + # D3QN requires a Dueling Double DQN agent + agent = dqn_agent.D3qnAgent( + time_step_tensor_spec, + action_tensor_spec, + q_network=q_net, + epsilon_greedy=epsilon_decay, + 
n_step_update=n_step_update, + target_q_network=target_q_net, + target_update_tau=target_update_tau, + target_update_period=target_update_period, + optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), + td_errors_loss_fn=common.element_wise_squared_loss, + gamma=gamma, + reward_scale_factor=reward_scale_factor, + train_step_counter=train_step, + ) + + table_name = 'uniform_table' + sequence_length = n_step_update + 1 + table = reverb.Table( + table_name, + max_size=replay_capacity, + sampler=reverb.selectors.Uniform(), + # sampler=reverb.selectors.Prioritized(priority_exponent=0.5), + remover=reverb.selectors.Fifo(), + rate_limiter=reverb.rate_limiters.MinSize(1), + ) + reverb_server = reverb.Server([table], port=reverb_port) + reverb_replay = reverb_replay_buffer.ReverbReplayBuffer( + agent.collect_data_spec, + sequence_length=sequence_length, + table_name=table_name, + local_server=reverb_server, + ) + rb_observer = reverb_utils.ReverbAddTrajectoryObserver( + reverb_replay.py_client, + table_name, + sequence_length=sequence_length, + stride_length=1, + ) + + dataset = reverb_replay.as_dataset( + num_parallel_calls=3, sample_batch_size=batch_size, num_steps=2 + ).prefetch(3) + experience_dataset_fn = lambda: dataset + + saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR) + env_step_metric = py_metrics.EnvironmentSteps() + + learning_triggers = [ + triggers.PolicySavedModelTrigger( + saved_model_dir, + agent, + train_step, + interval=policy_save_interval, + metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}, + ), + triggers.StepPerSecondLogTrigger(train_step, interval=100), + ] + + dqn_learner = learner.Learner( + root_dir, + train_step, + agent, + experience_dataset_fn, + triggers=learning_triggers, + ) + + # If we haven't trained yet make sure we collect some random samples first to + # fill up the Replay Buffer with some experience. 
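+  # The random policy samples uniformly from the action spec; its transitions
+  # are written to Reverb through the same `rb_observer` used later by the
+  # trained collect policy.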
+ random_policy = random_py_policy.RandomPyPolicy( + collect_env.time_step_spec(), collect_env.action_spec() + ) + initial_collect_actor = actor.Actor( + collect_env, + random_policy, + train_step, + steps_per_run=initial_collect_steps, + observers=[rb_observer], + ) + logging.info('Doing initial collect.') + initial_collect_actor.run() + + tf_collect_policy = agent.collect_policy + collect_policy = py_tf_eager_policy.PyTFEagerPolicy( + tf_collect_policy, use_tf_function=True + ) + + collect_actor = actor.Actor( + collect_env, + collect_policy, + train_step, + steps_per_run=1, + observers=[rb_observer, env_step_metric], + metrics=actor.collect_metrics(10), + summary_dir=os.path.join(root_dir, learner.TRAIN_DIR), + ) + + tf_greedy_policy = agent.policy + greedy_policy = py_tf_eager_policy.PyTFEagerPolicy( + tf_greedy_policy, use_tf_function=True + ) + + eval_actor = actor.Actor( + eval_env, + greedy_policy, + train_step, + episodes_per_run=eval_episodes, + metrics=actor.eval_metrics(eval_episodes), + summary_dir=os.path.join(root_dir, 'eval'), + ) + + if eval_interval: + logging.info('Evaluating.') + eval_actor.run_and_log() + + logging.info('Training.') + for _ in range(num_iterations): + collect_actor.run() + dqn_learner.run(iterations=1) + + if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0: + logging.info('Evaluating (epsilon: %s).', epsilon_decay()) + eval_actor.run_and_log() + average_return = eval_actor.metrics[0].result() + # LunarLander-v2 goal is 200.0 + if average_return > 200.0: + break + + rb_observer.close() + reverb_server.stop() + + +def main(_): + logging.set_verbosity(logging.INFO) + gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_bindings) + + train_eval( + root_dir=FLAGS.root_dir, + num_iterations=FLAGS.num_iterations, + reverb_port=FLAGS.reverb_port, + eval_interval=FLAGS.eval_interval, + ) + + +if __name__ == '__main__': + flags.mark_flag_as_required('root_dir') + multiprocessing.handle_main(functools.partial(app.run, main)) diff --git a/tf_agents/networks/__init__.py b/tf_agents/networks/__init__.py index 896c19c69..b75873eb5 100644 --- a/tf_agents/networks/__init__.py +++ b/tf_agents/networks/__init__.py @@ -19,6 +19,7 @@ from tf_agents.networks import actor_distribution_rnn_network from tf_agents.networks import categorical_projection_network from tf_agents.networks import categorical_q_network +from tf_agents.networks import dueling_q_network from tf_agents.networks import encoding_network from tf_agents.networks import expand_dims_layer from tf_agents.networks import lstm_encoding_network diff --git a/tf_agents/networks/dueling_q_network.py b/tf_agents/networks/dueling_q_network.py new file mode 100644 index 000000000..b266590e4 --- /dev/null +++ b/tf_agents/networks/dueling_q_network.py @@ -0,0 +1,149 @@ +# coding=utf-8 +# Copyright 2020 The TF-Agents Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sample Keras networks for Dueling (D)DQN. 
+ +Implements a TF-Agents Network from + +"Dueling Network Architectures for Deep Reinforcement Learning" + Wang et al., 2016 + https://arxiv.org/abs/1511.06581 +""" + +import gin +import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import +from tf_agents.networks import q_network + + +@gin.configurable +class DuelingQNetwork(q_network.QNetwork): + """Extensions to the classic (D)DQN networks.""" + + def __init__( + self, + input_tensor_spec, + action_spec, + preprocessing_layers=None, + preprocessing_combiner=None, + conv_layer_params=None, + fc_layer_params=(75, 40), + dropout_layer_params=None, + activation_fn=tf.keras.activations.relu, + kernel_initializer=None, + batch_squash=True, + dtype=tf.float32, + q_layer_activation_fn=None, + name='DuelingQNetwork', + ): + """Creates an instance of `DuelingQNetwork` as a subclass of QNetwork. + + Args: + input_tensor_spec: A nest of `tensor_spec.TensorSpec` representing the + input observations. + action_spec: A nest of `tensor_spec.BoundedTensorSpec` representing the + actions. + preprocessing_layers: (Optional.) A nest of `tf.keras.layers.Layer` + representing preprocessing for the different observations. All of these + layers must not be already built. For more details see the documentation + of `networks.EncodingNetwork`. + preprocessing_combiner: (Optional.) A keras layer that takes a flat list + of tensors and combines them. Good options include `tf.keras.layers.Add` + and `tf.keras.layers.Concatenate(axis=-1)`. This layer must not be + already built. For more details see the documentation of + `networks.EncodingNetwork`. + conv_layer_params: Optional list of convolution layers parameters, where + each item is a length-three tuple indicating (filters, kernel_size, + stride). + fc_layer_params: Optional list of fully_connected parameters, where each + item is the number of units in the layer. + dropout_layer_params: Optional list of dropout layer parameters, where + each item is the fraction of input units to drop. The dropout layers are + interleaved with the fully connected layers; there is a dropout layer + after each fully connected layer, except if the entry in the list is + None. This list must have the same length of fc_layer_params, or be + None. + activation_fn: Activation function, e.g. tf.keras.activations.relu. + kernel_initializer: Initializer to use for the kernels of the conv and + dense layers. If none is provided a default variance_scaling_initializer + batch_squash: If True the outer_ranks of the observation are squashed into + the batch dimension. This allows encoding networks to be used with + observations with shape [BxTx...]. + dtype: The dtype to use by the convolution and fully connected layers. + q_layer_activation_fn: Activation function for the Q layer. + name: A string representing the name of the network. + + Raises: + ValueError: If `input_tensor_spec` contains more than one observation. Or + if `action_spec` contains more than one action. 
+ """ + super(DuelingQNetwork, self).__init__( + input_tensor_spec=input_tensor_spec, + action_spec=action_spec, + preprocessing_layers=preprocessing_layers, + preprocessing_combiner=preprocessing_combiner, + conv_layer_params=conv_layer_params, + fc_layer_params=fc_layer_params, + dropout_layer_params=dropout_layer_params, + activation_fn=activation_fn, + kernel_initializer=kernel_initializer, + batch_squash=batch_squash, + dtype=dtype, + q_layer_activation_fn=q_layer_activation_fn, + name=name, + ) + + # Add a dense layer to the encoding network, in parallel to the + # q_value_layer. This 'dueling' layer estimates the state value for the + # input. + dueling_layer = tf.keras.layers.Dense( + 1, + activation=None, + kernel_initializer=tf.random_uniform_initializer( + minval=-0.03, maxval=0.03 + ), + bias_initializer=tf.constant_initializer(-0.2), + dtype=dtype, + ) + + self._dueling_layer = dueling_layer + self.layers.append(self._dueling_layer) # state value + self.layers.append(self._q_value_layer) # action advantage + + def call(self, observation, step_type=None, network_state=(), training=False): + """Runs the given observation through the network. + + Args: + observation: The observation to provide to the network. + step_type: The step type for the given observation. See `StepType` in + time_step.py. + network_state: A state tuple to pass to the network, mainly used by RNNs. + training: Whether the output is being used for training. + + Returns: + A tuple `(logits, network_state)`. + """ + state, network_state = self._encoder( + observation, + step_type=step_type, + network_state=network_state, + training=training, + ) + + q_values = self._q_value_layer(state) + state_value = self._dueling_layer(state) + advantage = state_value + ( + q_values - tf.reduce_mean(q_values, axis=1, keepdims=True) + ) + return (advantage, q_values), network_state diff --git a/tf_agents/policies/q_policy.py b/tf_agents/policies/q_policy.py index 92c639abb..b558e3e49 100644 --- a/tf_agents/policies/q_policy.py +++ b/tf_agents/policies/q_policy.py @@ -120,9 +120,13 @@ def observation_and_action_constraint_splitter(observation): return ) num_actions = spec.maximum - spec.minimum + 1 - network_utils.check_single_floating_network_output( - q_network.create_variables(), (num_actions,), str(q_network) - ) + # enable checking of dueling Q networks + outputs = q_network.create_variables() + iterable = list(outputs) if isinstance(outputs, tuple) else [outputs] + for output in iterable: + network_utils.check_single_floating_network_output( + output, (num_actions,), str(q_network) + ) # We need to maintain the flat action spec for dtype, shape and range. self._flat_action_spec = flat_action_spec[0] @@ -165,7 +169,8 @@ def _distribution(self, time_step, policy_state): step_type=time_step.step_type, ) - logits = q_values + # use action values + logits = q_values[1] if isinstance(q_values, tuple) else q_values if observation_and_action_constraint_splitter is not None: # Overwrite the logits for invalid actions to logits.dtype.min. 
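Note on the `q_policy.py` change above: the dueling network added in this patch returns a pair, where element `[0]` is the combined estimate `V(s) + (A(s, a) - mean_a A(s, a))` consumed by the agent's loss and TD targets, and element `[1]` is the raw per-action head that the policy now uses as logits. A minimal sketch (illustrative values, not part of the patch) of why greedy action selection is identical either way:

```python
import numpy as np

# Raw per-action head (tuple element [1]) and state-value head for one state.
advantages = np.array([[1.0, 3.0, 2.0]])
state_value = np.array([[5.0]])

# Combined dueling output (tuple element [0]): V(s) + (A(s, a) - mean_a A(s, a)).
combined = state_value + (advantages - advantages.mean(axis=1, keepdims=True))

# V(s) and the subtracted mean shift every action by the same constant, so the
# greedy action is unchanged.
assert np.array_equal(np.argmax(advantages, axis=1), np.argmax(combined, axis=1))
```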
diff --git a/tf_agents/typing/types.py b/tf_agents/typing/types.py index 919eca095..e764339ab 100644 --- a/tf_agents/typing/types.py +++ b/tf_agents/typing/types.py @@ -102,6 +102,8 @@ GymEnv = ForwardRef('gym.Env') # pylint: disable=invalid-name GymEnvWrapper = Callable[[GymEnv], GymEnv] +GymnasiumEnv = ForwardRef('gymnasium.Env') # pylint: disable=invalid-name +GymnasiumEnvWrapper = Callable[[GymnasiumEnv], GymnasiumEnv] PyEnv = ForwardRef('tf_agents.environments.py_environment.PyEnvironment') # pylint: disable=invalid-name PyEnvWrapper = Callable[[PyEnv], PyEnv]
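Taken together, the new suite, network, and agent wire up the same way the example script in this patch does; a minimal usage sketch (assumes the patch is applied; the environment name and hyperparameters are illustrative only):

```python
import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gymnasium
from tf_agents.networks import dueling_q_network
from tf_agents.train.utils import spec_utils
from tf_agents.utils import common

# Load a Gymnasium environment; episode limits are applied by TF-Agents' own
# TimeLimit wrapper (see suite_gymnasium.py above).
env = suite_gymnasium.load('CartPole-v1')
obs_spec, action_spec, time_step_spec = spec_utils.get_tensor_specs(env)

# Dueling Q-network used for both the online and (copied) target estimates.
q_net = dueling_q_network.DuelingQNetwork(
    input_tensor_spec=obs_spec,
    action_spec=action_spec,
    fc_layer_params=(128, 128),
)

agent = dqn_agent.D3qnAgent(
    time_step_spec,
    action_spec,
    q_network=q_net,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    td_errors_loss_fn=common.element_wise_squared_loss,
)
agent.initialize()
```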