Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
Merge pull request #633 from rsepassi/push
Browse files Browse the repository at this point in the history
v1.5.4
  • Loading branch information
lukaszkaiser authored Mar 2, 2018
2 parents 11f1ae4 + 4dd189e commit fd9b315
Show file tree
Hide file tree
Showing 30 changed files with 417 additions and 184 deletions.
4 changes: 2 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ env:
matrix:
- TF_VERSION="1.4.*"
- TF_VERSION="1.5.*"
- TF_VERSION="1.6.0rc1"
- TF_VERSION="1.6.*"
matrix:
exclude:
- python: "3.6"
env: TF_VERSION="1.4.*"
- python: "3.6"
env: TF_VERSION="1.6.0rc1"
env: TF_VERSION="1.6.*"
before_install:
- echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
- curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
Expand Down
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO

[Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
[T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
of deep learning models and datasets designed to [accelerate deep learning
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.

T2T is actively used and maintained by researchers and engineers within the
of deep learning models and datasets designed to make deep learning more
accessible and [accelerate ML
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
is actively used and maintained by researchers and engineers within the
[Google Brain team](https://research.google.com/teams/brain/) and a community
of users. We're eager to collaborate with you too, so feel free to
[open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
Expand Down Expand Up @@ -368,6 +368,7 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research
* [Discrete Autoencoders for Sequence Models](https://arxiv.org/abs/1801.09797)
* [Generating Wikipedia by Summarizing Long
Sequences](https://arxiv.org/abs/1801.10198)
* [Image Transformer](https://openreview.net/forum?id=r16Vyf-0-)
* [Image Transformer](https://arxiv.org/abs/1802.05751)
* [Training Tips for the Transformer Model](http://ufallab.ms.mff.cuni.cz/~popel/training-tips-transformer.pdf)

*Note: This is not an official Google product.*
5 changes: 3 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO

[Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
[T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
of deep learning models and datasets designed to [accelerate deep learning
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.
of deep learning models and datasets designed to make deep learning more
accessible and [accelerate ML
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).


## Basics
Expand Down
4 changes: 4 additions & 0 deletions docs/new_problem.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)

Another good overview of this part together with training is given in
[The Cloud ML Poetry Blog
Post](https://cloud.google.com/blog/big-data/2018/02/cloud-poetry-training-and-hyperparameter-tuning-custom-text-models-on-cloud-ml-engine)

Let's add a new dataset together and train the
[Transformer](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/models/transformer.py)
model on it. We'll give the model a line of poetry, and it will learn to
Expand Down
11 changes: 6 additions & 5 deletions docs/walkthrough.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CO

[Tensor2Tensor](https://github.com/tensorflow/tensor2tensor), or
[T2T](https://github.com/tensorflow/tensor2tensor) for short, is a library
of deep learning models and datasets designed to [accelerate deep learning
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html) and make it more accessible.

T2T is actively used and maintained by researchers and engineers within the
of deep learning models and datasets designed to make deep learning more
accessible and [accelerate ML
research](https://research.googleblog.com/2017/06/accelerating-deep-learning-research.html).
is actively used and maintained by researchers and engineers within the
[Google Brain team](https://research.google.com/teams/brain/) and a community
of users. We're eager to collaborate with you too, so feel free to
[open an issue on GitHub](https://github.com/tensorflow/tensor2tensor/issues)
Expand Down Expand Up @@ -368,6 +368,7 @@ T2T](https://research.googleblog.com/2017/06/accelerating-deep-learning-research
* [Discrete Autoencoders for Sequence Models](https://arxiv.org/abs/1801.09797)
* [Generating Wikipedia by Summarizing Long
Sequences](https://arxiv.org/abs/1801.10198)
* [Image Transformer](https://openreview.net/forum?id=r16Vyf-0-)
* [Image Transformer](https://arxiv.org/abs/1802.05751)
* [Training Tips for the Transformer Model](http://ufallab.ms.mff.cuni.cz/~popel/training-tips-transformer.pdf)

*Note: This is not an official Google product.*
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.5.3',
version='1.5.4',
description='Tensor2Tensor',
author='Google Inc.',
author_email='no-reply@google.com',
Expand Down
6 changes: 5 additions & 1 deletion tensor2tensor/data_generators/generator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ def generate_files(generator, output_filenames, max_cases=None):
if outputs_exist(output_filenames):
tf.logging.info("Skipping generator because outputs files exist")
return
tmp_filenames = [fname + ".incomplete" for fname in output_filenames]
num_shards = len(output_filenames)
writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filenames]
writers = [tf.python_io.TFRecordWriter(fname) for fname in tmp_filenames]
counter, shard = 0, 0
for case in generator:
if case is None:
Expand All @@ -165,6 +166,9 @@ def generate_files(generator, output_filenames, max_cases=None):
for writer in writers:
writer.close()

for tmp_name, final_name in zip(tmp_filenames, output_filenames):
tf.gfile.Rename(tmp_name, final_name)

tf.logging.info("Generated %s Examples", counter)


Expand Down
50 changes: 23 additions & 27 deletions tensor2tensor/data_generators/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,30 @@
from __future__ import division
from __future__ import print_function

import os
import functools

# Dependency imports

import numpy as np
import functools
import gym
import numpy as np

from tensor2tensor.rl import rl_trainer_lib
from tensor2tensor.rl.envs import atari_wrappers
from tensor2tensor.models.research import rl
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.models.research import rl
from tensor2tensor.rl.envs import atari_wrappers
from tensor2tensor.utils import registry

import tensorflow as tf




flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("model_path", "", "File with model for pong")


def gym_lib():
"""Access to gym to allow for import of this file without a gym install."""
try:
import gym # pylint: disable=g-import-not-at-top
except ImportError:
raise ImportError("pip install gym to use gym-based Problems")
return gym


class GymDiscreteProblem(problem.Problem):
"""Gym environment with discrete actions and rewards."""

Expand All @@ -67,7 +58,7 @@ def env_name(self):
@property
def env(self):
if self._env is None:
self._env = gym_lib().make(self.env_name)
self._env = gym.make(self.env_name)
return self._env

@property
Expand Down Expand Up @@ -157,8 +148,6 @@ def num_steps(self):
return 5000




@registry.register_problem
class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
"""Pong game, loaded actions."""
Expand All @@ -167,28 +156,34 @@ def __init__(self, event_dir, *args, **kwargs):
super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
self._env = None
self._event_dir = event_dir
env_spec = lambda: atari_wrappers.wrap_atari(
gym.make("PongNoFrameskip-v4"), warp=False, frame_skip=4, frame_stack=False)
env_spec = lambda: atari_wrappers.wrap_atari( # pylint: disable=g-long-lambda
gym.make("PongNoFrameskip-v4"),
warp=False,
frame_skip=4,
frame_stack=False)
hparams = rl.atari_base()
with tf.variable_scope("train"):
policy_lambda = hparams.network
policy_factory = tf.make_template(
"network",
functools.partial(policy_lambda, env_spec().action_space, hparams))
self._max_frame_pl = tf.placeholder(tf.float32, self.env.observation_space.shape)
actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(self._max_frame_pl, 0), 0))
"network",
functools.partial(policy_lambda, env_spec().action_space, hparams))
self._max_frame_pl = tf.placeholder(
tf.float32, self.env.observation_space.shape)
actor_critic = policy_factory(tf.expand_dims(tf.expand_dims(
self._max_frame_pl, 0), 0))
policy = actor_critic.policy
self._last_policy_op = policy.mode()
self._last_action = self.env.action_space.sample()
self._skip = 4
self._skip_step = 0
self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape, dtype=np.uint8)
self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
dtype=np.uint8)
self._sess = tf.Session()
model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
model_saver.restore(self._sess, FLAGS.model_path)

# TODO(blazej0): For training of atari agents wrappers are usually used.
# Below we have a hacky solution which is a temporary workaround to be used together
# Below we have a hacky solution which is a workaround to be used together
# with atari_wrappers.MaxAndSkipEnv.
def get_action(self, observation=None):
if self._skip_step == self._skip - 2: self._obs_buffer[0] = observation
Expand All @@ -197,7 +192,8 @@ def get_action(self, observation=None):
if self._skip_step == 0:
max_frame = self._obs_buffer.max(axis=0)
self._last_action = int(self._sess.run(
self._last_policy_op, feed_dict={self._max_frame_pl: max_frame})[0, 0])
self._last_policy_op,
feed_dict={self._max_frame_pl: max_frame})[0, 0])
return self._last_action

@property
Expand Down
99 changes: 99 additions & 0 deletions tensor2tensor/data_generators/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,84 @@
from __future__ import division
from __future__ import print_function

import os
# Dependency imports

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import image_utils
from tensor2tensor.utils import registry

import tensorflow as tf

# URLs and filenames for IMAGENET 32x32 data from
# https://arxiv.org/abs/1601.06759.
_IMAGENET_SMALL_ROOT_URL = "http://image-net.org/small/"
_IMAGENET_SMALL_URLS = [
"train_32x32.tar", "valid_32x32.tar"]
_IMAGENET_SMALL_TRAIN_PREFIX = "train_32x32"
_IMAGENET_SMALL_EVAL_PREFIX = "valid_32x32"
_IMAGENET_SMALL_IMAGE_SIZE = 32


# URLs and filenames for IMAGENET 64x64 data.
_IMAGENET_MEDIUM_ROOT_URL = "http://image-net.org/small/"
_IMAGENET_MEDIUM_URLS = [
"train_64x64.tar", "valid_64x64.tar"]
_IMAGENET_MEDIUM_TRAIN_PREFIX = "train_64x64"
_IMAGENET_MEDIUM_EVAL_PREFIX = "valid_64x64"
_IMAGENET_MEDIUM_IMAGE_SIZE = 64


# Derived from ImageNet data
MEAN_RGB = [0.485, 0.456, 0.406]
STDDEV_RGB = [0.229, 0.224, 0.225]


def imagenet_pixelrnn_generator(tmp_dir,
training,
size=_IMAGENET_SMALL_IMAGE_SIZE):
"""Image generator for Imagenet 64x64 downsampled images.
It assumes that the data has been downloaded from
http://image-net.org/small/*_32x32.tar or
http://image-net.org/small/*_64x64.tar into tmp_dir.
Args:
tmp_dir: path to temporary storage directory.
training: a Boolean; if true, we use the train set, otherwise the test set.
size: image size (assumes height and width are same)
Yields:
A dictionary representing the images with the following fields:
* image/encoded: the string encoding the image as JPEG,
* image/format: the string "jpeg" representing image format,
* image/height: an integer representing the height,
* image/width: an integer representing the width.
Every field is actually a list of the corresponding type.
"""
if size == _IMAGENET_SMALL_IMAGE_SIZE:
train_prefix = _IMAGENET_SMALL_TRAIN_PREFIX
eval_prefix = _IMAGENET_SMALL_EVAL_PREFIX
else:
train_prefix = _IMAGENET_MEDIUM_TRAIN_PREFIX
eval_prefix = _IMAGENET_MEDIUM_EVAL_PREFIX
prefix = train_prefix if training else eval_prefix
images_filepath = os.path.join(tmp_dir, prefix)
image_files = tf.gfile.Glob(images_filepath + "/*")
height = size
width = size
const_label = 0
for filename in image_files:
with tf.gfile.Open(filename, "r") as f:
encoded_image = f.read()
yield {
"image/encoded": [encoded_image],
"image/format": ["png"],
"image/class/label": [const_label],
"image/height": [height],
"image/width": [width]
}


def imagenet_preprocess_example(example, mode, resize_size=None):
"""Preprocessing used for Imagenet and similar problems."""
resize_size = resize_size or [299, 299]
Expand Down Expand Up @@ -123,6 +188,40 @@ def preprocess_example(self, example, mode, _):
return example


@registry.register_problem
class ImageImagenet64Gen(ImageImagenet):
"""Cifar-10 Tune."""

@property
def train_shards(self):
return 1024

@property
def dev_shards(self):
return 10

def generate_data(self, data_dir, tmp_dir, task_id=-1):
generator_utils.generate_dataset_and_shuffle(
self.generator(data_dir, tmp_dir, True),
self.training_filepaths(data_dir, self.train_shards, shuffled=True),
self.generator(data_dir, tmp_dir, False),
self.dev_filepaths(data_dir, self.dev_shards, shuffled=True))

def generator(self, data_dir, tmp_dir, is_training):
if is_training:
return imagenet_pixelrnn_generator(
tmp_dir, int(True), size=_IMAGENET_MEDIUM_IMAGE_SIZE)
else:
return imagenet_pixelrnn_generator(
tmp_dir, int(False), size=_IMAGENET_MEDIUM_IMAGE_SIZE)

def preprocess_example(self, example, mode, unused_hparams):
example["inputs"].set_shape([_IMAGENET_MEDIUM_IMAGE_SIZE,
_IMAGENET_MEDIUM_IMAGE_SIZE, 3])
example["inputs"] = tf.to_int64(example["inputs"])
return example


@registry.register_problem
class ImageImagenet64(ImageImagenet32):
"""Imagenet rescaled to 64x64."""
Expand Down
6 changes: 6 additions & 0 deletions tensor2tensor/data_generators/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def main(_):
total_sequences = 0
total_input_tokens = 0
total_target_tokens = 0
nonpadding_input_tokens = 0
nonpadding_target_tokens = 0
max_input_length = 0
max_target_length = 0
for record in reader:
Expand All @@ -71,6 +73,8 @@ def main(_):
print("INPUTS:\n" + encoder.decode(inputs) if encoder else inputs)
if FLAGS.print_targets:
print("TARGETS:\n" + encoder.decode(targets) if encoder else targets)
nonpadding_input_tokens += len(inputs) - inputs.count(0)
nonpadding_target_tokens += len(targets) - targets.count(0)
total_input_tokens += len(inputs)
total_target_tokens += len(targets)
total_sequences += 1
Expand All @@ -83,6 +87,8 @@ def main(_):
print("total_sequences: %d" % total_sequences)
print("total_input_tokens: %d" % total_input_tokens)
print("total_target_tokens: %d" % total_target_tokens)
print("nonpadding_input_tokens: %d" % nonpadding_input_tokens)
print("nonpadding_target_tokens: %d" % nonpadding_target_tokens)
print("max_input_length: %d" % max_input_length)
print("max_target_length: %d" % max_target_length)

Expand Down
Loading

0 comments on commit fd9b315

Please sign in to comment.