
Commit 4dd189e: internal
PiperOrigin-RevId: 187661295
Lukasz Kaiser authored and Ryan Sepassi committed Mar 2, 2018
1 parent 0584c15 commit 4dd189e
Showing 6 changed files with 292 additions and 44 deletions.
86 changes: 77 additions & 9 deletions tensor2tensor/data_generators/gym.py
@@ -19,25 +19,28 @@
from __future__ import division
from __future__ import print_function

import os
import functools

# Dependency imports

import gym
import numpy as np

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.models.research import rl
from tensor2tensor.rl.envs import atari_wrappers
from tensor2tensor.utils import registry

import tensorflow as tf



def gym_lib():
"""Access to gym to allow for import of this file without a gym install."""
try:
import gym # pylint: disable=g-import-not-at-top
except ImportError:
raise ImportError("pip install gym to use gym-based Problems")
return gym

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("model_path", "", "File with model for pong")


class GymDiscreteProblem(problem.Problem):
@@ -55,7 +58,7 @@ def env_name(self):
@property
def env(self):
if self._env is None:
self._env = gym_lib().make(self.env_name)
self._env = gym.make(self.env_name)
return self._env

@property
@@ -143,3 +146,68 @@ def num_rewards(self):
@property
def num_steps(self):
return 5000


@registry.register_problem
class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
"""Pong game, loaded actions."""

def __init__(self, event_dir, *args, **kwargs):
super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
self._env = None
self._event_dir = event_dir
env_spec = lambda: atari_wrappers.wrap_atari( # pylint: disable=g-long-lambda
gym.make("PongNoFrameskip-v4"),
warp=False,
frame_skip=4,
frame_stack=False)
hparams = rl.atari_base()
with tf.variable_scope("train"):
policy_lambda = hparams.network
policy_factory = tf.make_template(
"network",
functools.partial(policy_lambda, env_spec().action_space, hparams))
self._max_frame_pl = tf.placeholder(
tf.float32, self.env.observation_space.shape)
actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
self._max_frame_pl, 0), 0))
policy = actor_critic.policy
self._last_policy_op = policy.mode()
self._last_action = self.env.action_space.sample()
self._skip = 4
self._skip_step = 0
self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
dtype=np.uint8)
self._sess = tf.Session()
model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
model_saver.restore(self._sess, FLAGS.model_path)

  # TODO(blazej0): Wrappers are usually used when training Atari agents.
  # The code below is a hacky workaround meant to be used together with
  # atari_wrappers.MaxAndSkipEnv.
def get_action(self, observation=None):
if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation
if self._skip_step == self._skip - 1: self._obs_buffer[1] = observation
self._skip_step = (self._skip_step + 1) % self._skip
if self._skip_step == 0:
max_frame = self._obs_buffer.max(axis=0)
self._last_action = int(self._sess.run(
self._last_policy_op,
feed_dict={self._max_frame_pl: max_frame})[0, 0])
return self._last_action

@property
def env_name(self):
return "PongNoFrameskip-v4"

@property
def num_actions(self):
return 4

@property
def num_rewards(self):
return 2

@property
def num_steps(self):
return 5000
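
The following is a reading aid, not part of the diff: a minimal sketch of how `get_action` is meant to drive the raw, unwrapped Pong environment. It assumes `--model_path` points to a checkpoint produced by the PPO trainer with the `atari_base` hparams; `event_dir` is unused by the snippet and passed as `None` purely for illustration.

```python
# Hypothetical driver loop (assumption, not from this commit).
problem = GymPongTrajectoriesFromPolicy(event_dir=None)
obs = problem.env.reset()
for _ in range(1000):
  # get_action buffers the last two raw frames and re-runs the policy once
  # every 4 steps, mimicking atari_wrappers.MaxAndSkipEnv; in between it
  # repeats the previous action.
  action = problem.get_action(obs)
  obs, _, done, _ = problem.env.step(action)
  if done:
    obs = problem.env.reset()
```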
1 change: 1 addition & 0 deletions tensor2tensor/models/__init__.py
@@ -44,6 +44,7 @@
from tensor2tensor.models.research import cycle_gan
from tensor2tensor.models.research import gene_expression
from tensor2tensor.models.research import multimodel
from tensor2tensor.models.research import rl
from tensor2tensor.models.research import super_lm
from tensor2tensor.models.research import transformer_moe
from tensor2tensor.models.research import transformer_revnet
79 changes: 49 additions & 30 deletions tensor2tensor/models/research/rl.py
@@ -49,6 +49,7 @@ def ppo_base_v1():
hparams.add_hparam("eval_every_epochs", 10)
hparams.add_hparam("num_eval_agents", 3)
hparams.add_hparam("video_during_eval", True)
hparams.add_hparam("save_models_every_epochs", 30)
return hparams


@@ -66,7 +67,23 @@ def discrete_action_base():
return hparams


# Neural networks for actor-critic algorithms
@registry.register_hparams
def atari_base():
"""Atari base parameters."""
hparams = discrete_action_base()
hparams.learning_rate = 16e-5
hparams.num_agents = 5
hparams.epoch_length = 200
hparams.gae_gamma = 0.985
hparams.gae_lambda = 0.985
hparams.entropy_loss_coef = 0.002
hparams.value_loss_coef = 0.025
hparams.optimization_epochs = 10
hparams.epochs_num = 10000
hparams.num_eval_agents = 1
hparams.network = feed_forward_cnn_small_categorical_fun
return hparams
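
A note for orientation (illustration, not part of the diff): the `network` field stores the policy-builder function itself rather than a registry name, which is how `GymPongTrajectoriesFromPolicy` in gym.py above picks it up.

```python
# Sketch of how the new hparams set is consumed (mirrors gym.py in this commit).
from tensor2tensor.models.research import rl

hparams = rl.atari_base()
policy_fn = hparams.network  # feed_forward_cnn_small_categorical_fun
# gym.py wraps policy_fn in tf.make_template("network", ...) so the policy
# variables land under a reusable "network" scope.
```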


NetworkOutput = collections.namedtuple(
"NetworkOutput", "policy, value, action_postprocessing")
@@ -85,23 +102,24 @@ def feed_forward_gaussian_fun(action_space, config, observations):
tf.shape(observations)[0], tf.shape(observations)[1],
functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])

with tf.variable_scope("policy"):
x = flat_observations
for size in config.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
mean = tf.contrib.layers.fully_connected(
x, action_space.shape[0], tf.tanh,
weights_initializer=mean_weights_initializer)
logstd = tf.get_variable(
"logstd", mean.shape[2:], tf.float32, logstd_initializer)
logstd = tf.tile(
logstd[None, None],
[tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
with tf.variable_scope("value"):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
with tf.variable_scope("network_parameters"):
with tf.variable_scope("policy"):
x = flat_observations
for size in config.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
mean = tf.contrib.layers.fully_connected(
x, action_space.shape[0], tf.tanh,
weights_initializer=mean_weights_initializer)
logstd = tf.get_variable(
"logstd", mean.shape[2:], tf.float32, logstd_initializer)
logstd = tf.tile(
logstd[None, None],
[tf.shape(mean)[0], tf.shape(mean)[1]] + [1] * (mean.shape.ndims - 2))
with tf.variable_scope("value"):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
mean = tf.check_numerics(mean, "mean")
logstd = tf.check_numerics(logstd, "logstd")
value = tf.check_numerics(value, "value")
@@ -119,17 +137,18 @@ def feed_forward_categorical_fun(action_space, config, observations):
flat_observations = tf.reshape(observations, [
tf.shape(observations)[0], tf.shape(observations)[1],
functools.reduce(operator.mul, observations.shape.as_list()[2:], 1)])
with tf.variable_scope("policy"):
x = flat_observations
for size in config.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
logits = tf.contrib.layers.fully_connected(x, action_space.n,
activation_fn=None)
with tf.variable_scope("value"):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
with tf.variable_scope("network_parameters"):
with tf.variable_scope("policy"):
x = flat_observations
for size in config.policy_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
logits = tf.contrib.layers.fully_connected(x, action_space.n,
activation_fn=None)
with tf.variable_scope("value"):
x = flat_observations
for size in config.value_layers:
x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
value = tf.contrib.layers.fully_connected(x, 1, None)[..., 0]
policy = tf.contrib.distributions.Categorical(logits=logits)
return NetworkOutput(policy, value, lambda a: a)

@@ -141,7 +160,7 @@ def feed_forward_cnn_small_categorical_fun(action_space, config, observations):
obs_shape = observations.shape.as_list()
x = tf.reshape(observations, [-1] + obs_shape[2:])

with tf.variable_scope("policy"):
with tf.variable_scope("network_parameters"):
x = tf.to_float(x) / 255.0
x = tf.contrib.layers.conv2d(x, 32, [5, 5], [2, 2],
activation_fn=tf.nn.relu, padding="SAME")
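
One consequence of the new `network_parameters` scope worth spelling out (reading note, not part of the diff): gym.py restores only the policy/value weights by filtering global variables on that scope name. A minimal sketch of the pattern, assuming a TF 1.x graph:

```python
import tensorflow as tf

with tf.variable_scope("train"):
  with tf.variable_scope("network_parameters"):
    _ = tf.get_variable("w", [1])  # stand-in for the policy/value weights

# Same filter as gym.py: only variables whose names match the scope regex are
# saved/restored, so the checkpoint can be loaded without rebuilding the
# trainer's full graph.
saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
```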
12 changes: 11 additions & 1 deletion tensor2tensor/rl/README.md
@@ -7,6 +7,16 @@ for now and under heavy development.

Currently the only supported algorithm is Proximal Policy Optimization (PPO).

## Sample usage - training in Pendulum-v0 environment.
## Sample usage - training in the Pendulum-v0 environment.

```python rl/t2t_rl_trainer.py --problems=Pendulum-v0 --hparams_set continuous_action_base [--output_dir dir_location]```

## Sample usage - training in the PongNoFrameskip-v0 environment.

```python tensor2tensor/rl/t2t_rl_trainer.py --problem stacked_pong --hparams_set atari_base --hparams num_agents=5 --output_dir /tmp/pong$(date +%Y%m%d_%H%M%S)```

## Sample usage - generating trajectory data from a trained model.

```python tensor2tensor/bin/t2t-trainer --generate_data --data_dir=~/t2t_data --problems=gym_pong_trajectories_from_policy --hparams_set=atari_base --model_path [model]```

```python tensor2tensor/bin/t2t-datagen --data_dir=~/t2t_data --tmp_dir=~/t2t_data/tmp --problem=gym_pong_trajectories_from_policy --model_path [model]```
139 changes: 139 additions & 0 deletions tensor2tensor/rl/envs/atari_wrappers.py
@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Various wrappers copied for Gym Baselines."""

from collections import deque
import gym
import numpy as np


# Adapted from the link below.
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py


class WarpFrame(gym.ObservationWrapper):
"""Wrap a frame."""

def __init__(self, env):
"""Warp frames to 84x84 as done in the Nature paper and later work."""
gym.ObservationWrapper.__init__(self, env)
self.width = 84
self.height = 84
self.observation_space = gym.spaces.Box(
low=0, high=255,
shape=(self.height, self.width, 1), dtype=np.uint8)

def observation(self, frame):
import cv2 # pylint: disable=g-import-not-at-top
cv2.ocl.setUseOpenCL(False)
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
frame = cv2.resize(frame, (self.width, self.height),
interpolation=cv2.INTER_AREA)
return frame[:, :, None]


class LazyFrames(object):
"""Lazy frame storage."""

def __init__(self, frames):
"""Lazy frame storage.
This object ensures that common frames between the observations
are only stored once. It exists purely to optimize memory usage
which can be huge for DQN's 1M frames replay buffers.
This object should only be converted to numpy array before being passed
to the model.
Args:
frames: the frames.
"""
self._frames = frames

def __array__(self, dtype=None):
out = np.concatenate(self._frames, axis=2)
if dtype is not None:
out = out.astype(dtype)
return out
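
A quick illustration of the laziness (an assumption for clarity, not from the commit): the frames are only concatenated when the object is converted to an array.

```python
import numpy as np

frames = [np.zeros((84, 84, 1), dtype=np.uint8)] * 4
lazy = LazyFrames(frames)    # stores references only, no copy yet
stacked = np.asarray(lazy)   # __array__ concatenates -> shape (84, 84, 4)
```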


class FrameStack(gym.Wrapper):
"""Stack frames."""

def __init__(self, env, k):
"""Stack k last frames. Returns lazy array, memory efficient."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)

def reset(self):
ob = self.env.reset()
for _ in range(self.k):
self.frames.append(ob)
return self._get_ob()

def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self._get_ob(), reward, done, info

def _get_ob(self):
assert len(self.frames) == self.k
return LazyFrames(list(self.frames))


class MaxAndSkipEnv(gym.Wrapper):
"""Max and skip env."""

def __init__(self, env, skip=4):
"""Return only every `skip`-th frame."""
gym.Wrapper.__init__(self, env)
# Most recent raw observations (for max pooling across time steps).
self._obs_buffer = np.zeros((2,) + env.observation_space.shape,
dtype=np.uint8)
self._skip = skip

def reset(self, **kwargs):
return self.env.reset(**kwargs)

def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
if i == self._skip - 2: self._obs_buffer[0] = obs
if i == self._skip - 1: self._obs_buffer[1] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0)

return max_frame, total_reward, done, info
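
A tiny numeric illustration (not from the commit) of why the last two frames are max-pooled: Atari sprites that flicker on alternate frames survive the pooling.

```python
import numpy as np

frame_a = np.array([[0, 255], [0, 0]], dtype=np.uint8)  # sprite visible here
frame_b = np.array([[0, 0], [255, 0]], dtype=np.uint8)  # ...and here
pooled = np.stack([frame_a, frame_b]).max(axis=0)       # both sprites remain
```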


def wrap_atari(env, warp=False, frame_skip=False, frame_stack=False):
"""Optionally warp, frame-skip, and frame-stack an Atari env."""
if warp:
env = WarpFrame(env)
if frame_skip:
env = MaxAndSkipEnv(env, frame_skip)
if frame_stack:
env = FrameStack(env, frame_stack)
return env
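
For completeness, a hedged usage sketch (assumes gym with the Atari extras installed); it mirrors the `env_spec` lambda in gym.py from this commit:

```python
import gym
from tensor2tensor.rl.envs import atari_wrappers

# Max-pool/skip every 4 frames, keep full-size RGB frames, no frame stacking.
env = atari_wrappers.wrap_atari(
    gym.make("PongNoFrameskip-v4"),
    warp=False,
    frame_skip=4,
    frame_stack=False)
obs = env.reset()
```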