This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit 56cb37f
Merge pull request #290 from rsepassi/push
v1.2.2
lukaszkaiser authored Sep 8, 2017
2 parents 8f83adf + b8e59e7
Showing 53 changed files with 1,822 additions and 730 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -214,7 +214,8 @@ on the task (e.g. fed through a final linear transform to produce logits for a
softmax over classes). All models are imported in
[`models.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/models/models.py),
inherit from `T2TModel` - defined in
-[`t2t_model.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/t2t_model.py) - and are registered with
+[`t2t_model.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/t2t_model.py) -
+and are registered with
[`@registry.register_model`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/utils/registry.py).
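
The paragraph above describes the registration pattern; here is a minimal sketch of what it looks like in user code, assuming the 1.2-era API where models override `model_fn_body` (the class name and body are illustrative, not part of the library):

```python
from tensor2tensor.utils import registry, t2t_model


# Illustrative only: a do-nothing model registered the way the README
# describes. T2T wraps the body with the input/output modalities.
@registry.register_model
class MyExampleModel(t2t_model.T2TModel):

  def model_fn_body(self, features):
    # The body receives embedded features and returns the pre-logit output.
    return features["inputs"]
```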

### Hyperparameter Sets
40 changes: 32 additions & 8 deletions docs/new_problem.md
@@ -15,9 +15,17 @@ Let's add a new dataset together and train the transformer model. We'll be learn

For each problem we want to tackle we create a new problem class and register it. Let's call our problem `Word2def`.

-Since many text2text problems share similar methods, there's already a class called [`Text2TextProblem`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L354) that extends the base problem class, `Problem` (both found in [`problem.py`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py)).
-
-For our problem, we can go ahead and create the file `word2def.py` in the [`data_generators`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/) folder and add our new problem, `Word2def`, which extends [`Text2TextProblem`](https://github.com/tensorflow/tensor2tensor/blob/24071ba07d5a14c170044c5e60a24bda8179fb7a/tensor2tensor/data_generators/problem.py#L354). Let's also register it while we're at it so we can specify the problem through flags.
+Since many text2text problems share similar methods, there's already a class
+called `Text2TextProblem` that extends the base problem class, `Problem`
+(both found in
+[`problem.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)).
+
+For our problem, we can go ahead and create the file `word2def.py` in the
+[`data_generators`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/)
+folder and add our new problem, `Word2def`, which extends
+[`Text2TextProblem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py).
+Let's also register it while we're at it so we can specify the problem through
+flags.

```python
@registry.register_problem
class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""
  ...
```
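
Registration keys the class under the snake_cased form of its name, which is the same string the flags use; a small sketch of the lookup, assuming `word2def.py` has been imported so the decorator has run:

```python
from tensor2tensor.utils import registry

# "word2def" is derived automatically from the class name Word2def.
word2def_problem = registry.problem("word2def")
```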

-We need to implement the following methods from [`Text2TextProblem`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py#L354) in our new class:
+We need to implement the following methods from
+[`Text2TextProblem`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
+in our new class:
* is_character_level
* targeted_vocab_size
* generator
@@ -42,7 +52,12 @@ Let's tackle them one by one:

**input_space_id, target_space_id, is_character_level, targeted_vocab_size, use_subword_tokenizer**:

-SpaceIDs tell Tensor2Tensor what sort of space the input and target tensors are in. These are things like, EN_CHR (English character), EN_TOK (English token), AUDIO_WAV (audio waveform), IMAGE, DNA (genetic bases). The complete list can be found at [`data_generators/problem.py`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/problem.py) in the class `SpaceID`.
+SpaceIDs tell Tensor2Tensor what sort of space the input and target tensors are
+in. These are things like EN_CHR (English character), EN_TOK (English token),
+AUDIO_WAV (audio waveform), IMAGE, and DNA (genetic bases). The complete list
+can be found at
+[`data_generators/problem.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/problem.py)
+in the class `SpaceID`.

Since we're generating definitions and feeding in words at the character level, we set `is_character_level` to true, and use the same SpaceID, EN_CHR, for both input and target. Additionally, since we aren't using tokens, we don't need to give a `targeted_vocab_size` or define `use_subword_tokenizer`.
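
As a sketch, that paragraph maps onto two small properties on `Word2def`; `SpaceID.EN_CHR` comes from the `SpaceID` class referenced above, and the snippet restates the prose rather than being copied from the commit:

```python
from tensor2tensor.data_generators import problem


class Word2def(problem.Text2TextProblem):
  """Problem spec for English word to dictionary definition."""

  @property
  def input_space_id(self):
    # Inputs are English characters.
    return problem.SpaceID.EN_CHR

  @property
  def target_space_id(self):
    # Targets (definitions) are English characters too.
    return problem.SpaceID.EN_CHR
```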

@@ -58,7 +73,7 @@ The number of shards to break data files into.
@registry.register_problem()
class Word2def(problem.Text2TextProblem):
"""Problem spec for English word to dictionary definition."""

@property
def is_character_level(self):
return True
@@ -86,7 +101,15 @@ class Word2def(problem.Text2TextProblem):

**generator**:

-We're almost done. `generator` generates the training and evaluation data and stores them in files like "word2def_train.lang1" in your DATA_DIR. Thankfully several commonly used methods like `character_generator`, and `token_generator` are already written in the file [`wmt.py`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/wmt.py). We will import `character_generator` and [`text_encoder`](https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py) to write:
+We're almost done. `generator` generates the training and evaluation data and
+stores them in files like "word2def_train.lang1" in your DATA_DIR. Thankfully,
+several commonly used methods like `character_generator` and `token_generator`
+are already written in the file
+[`wmt.py`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/wmt.py).
+We will import `character_generator` and
+[`text_encoder`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/data_generators/text_encoder.py)
+to write:

```python
def generator(self, data_dir, tmp_dir, train):
  character_vocab = text_encoder.ByteTextEncoder()
  ...
```

@@ -152,6 +175,7 @@ _WORD2DEF_TEST_DATASETS = [
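
The hunk above jumps ahead to the dataset constants, so the body of `generator` is elided in the diff. A hedged sketch of how the imported pieces could fit together, assuming `_WORD2DEF_TRAIN_DATASETS`/`_WORD2DEF_TEST_DATASETS` hold a (source_path, target_path) pair and `EOS` is defined in the module:

```python
# Hedged sketch, not the commit's verbatim code.
def generator(self, data_dir, tmp_dir, train):
  character_vocab = text_encoder.ByteTextEncoder()
  datasets = _WORD2DEF_TRAIN_DATASETS if train else _WORD2DEF_TEST_DATASETS
  source_path, target_path = datasets  # assumed (source, target) pair
  return character_generator(source_path, target_path, character_vocab, EOS)
```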
## Putting it all together

Now our `word2def.py` file looks like:

```python
""" Problem definition for word to dictionary definition.
"""
...
```

@@ -210,7 +234,7 @@ class Word2def(problem.Text2TextProblem):

# Hyperparameters
All hyperparameters inherit from `_default_hparams()` in `problem.py`. If you would like to customize your hyperparameters, register a new hyperparameter set in `word2def.py` like the example provided in the walkthrough. For example:

```python
from tensor2tensor.models import transformer
...
```
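
The diff cuts this example off after the import; a hedged sketch of a complete registered set, assuming `transformer_base` as the starting point (the override shown is illustrative, not taken from the commit):

```python
from tensor2tensor.models import transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def word2def_hparams():
  # Start from the transformer's base hyperparameters and override one value.
  hparams = transformer.transformer_base()
  hparams.batch_size = 1024  # illustrative override
  return hparams
```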
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
    name='tensor2tensor',
-    version='1.2.1',
+    version='1.2.2',
    description='Tensor2Tensor',
    author='Google Inc.',
    author_email='no-reply@google.com',
1 change: 0 additions & 1 deletion tensor2tensor/data_generators/all_problems.py
@@ -45,4 +45,3 @@
  pass
# pylint: enable=g-import-not-at-top
# pylint: enable=unused-import

4 changes: 2 additions & 2 deletions tensor2tensor/data_generators/cnn_dailymail.py
@@ -53,8 +53,8 @@ def _maybe_download_corpora(tmp_dir):
    filepath of the downloaded corpus file.
  """
  cnn_filename = "cnn_stories.tgz"
-  dailymail_filename = "dailymail_stories.tgz"
  cnn_finalpath = os.path.join(tmp_dir, "cnn/stories/")
+  dailymail_filename = "dailymail_stories.tgz"
  dailymail_finalpath = os.path.join(tmp_dir, "dailymail/stories/")
  if not tf.gfile.Exists(cnn_finalpath):
    cnn_file = generator_utils.maybe_download_from_drive(
@@ -63,7 +63,7 @@ def _maybe_download_corpora(tmp_dir):
      cnn_tar.extractall(tmp_dir)
  if not tf.gfile.Exists(dailymail_finalpath):
    dailymail_file = generator_utils.maybe_download_from_drive(
-        tmp_dir, dailymail_filename, _CNN_STORIES_DRIVE_URL)
+        tmp_dir, dailymail_filename, _DAILYMAIL_STORIES_DRIVE_URL)
    with tarfile.open(dailymail_file, "r:gz") as dailymail_tar:
      dailymail_tar.extractall(tmp_dir)
  return [cnn_finalpath, dailymail_finalpath]
5 changes: 2 additions & 3 deletions tensor2tensor/data_generators/gene_expression.py
@@ -142,7 +142,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
    # Shuffle
    generator_utils.shuffle_dataset(all_filepaths)

-  def hparams(self, defaults, model_hparams):
+  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    vocab_size = self._encoders["inputs"].vocab_size
    p.input_modality = {"inputs": (registry.Modalities.SYMBOL, vocab_size)}
@@ -159,9 +159,8 @@ def example_reading_spec(self):
    data_items_to_decoders = None
    return (data_fields, data_items_to_decoders)

-  def preprocess_examples(self, examples, mode, hparams):
+  def preprocess_examples(self, examples, mode, unused_hparams):
    del mode
-    del hparams

    # Reshape targets to contain num_output_predictions per output timestep
    examples["targets"] = tf.reshape(examples["targets"],
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/ice_parsing.py
@@ -109,7 +109,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
            self.targeted_vocab_size),
        self.dev_filepaths(data_dir, 1, shuffled=False))

-  def hparams(self, defaults, model_hparams):
+  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    source_vocab_size = self._encoders["inputs"].vocab_size
    p.input_modality = {"inputs": (registry.Modalities.SYMBOL,
41 changes: 32 additions & 9 deletions tensor2tensor/data_generators/image.py
@@ -105,7 +105,7 @@ def resize(img, size):
examples["targets"] = resize(inputs, 32)
return examples

-  def hparams(self, defaults, model_hparams):
+  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    p.input_modality = {"inputs": ("image:identity_no_pad", None)}
    p.target_modality = ("image:identity_no_pad", None)
@@ -229,7 +229,7 @@ def feature_encoders(self, data_dir):
"targets": text_encoder.SubwordTextEncoder(vocab_filename)
}

-  def hparams(self, defaults, model_hparams):
+  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
    vocab_size = self._encoders["targets"].vocab_size
@@ -264,10 +264,21 @@ def train_shards(self):
  def dev_shards(self):
    return 1

+  @property
+  def class_labels(self):
+    return ["ID_%d" % i for i in range(self.num_classes)]
+
+  def feature_encoders(self, data_dir):
+    del data_dir
+    return {
+        "inputs": text_encoder.TextEncoder(),
+        "targets": text_encoder.ClassLabelEncoder(self.class_labels)
+    }
+
  def generator(self, data_dir, tmp_dir, is_training):
    raise NotImplementedError()

-  def hparams(self, defaults, model_hparams):
+  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    small_modality = "%s:small_image_modality" % registry.Modalities.IMAGE
    modality = small_modality if self.is_small else registry.Modalities.IMAGE
@@ -302,7 +313,7 @@ def resize(img):
      return tf.to_int64(tf.image.resize_images(img, [299, 299]))

    inputs = tf.cast(examples["inputs"], tf.int64)
-    if mode == tf.contrib.learn.ModeKeys.TRAIN:
+    if mode == tf.estimator.ModeKeys.TRAIN:
examples["inputs"] = tf.cond( # Preprocess 90% of the time.
tf.less(tf.random_uniform([]), 0.9),
lambda img=inputs: preprocess(img),
@@ -349,7 +360,7 @@ def is_small(self):
  def num_classes(self):
    return 1000

-  def preprocess_examples(self, examples, mode, hparams):
+  def preprocess_examples(self, examples, mode, unused_hparams):
    # Just resize with area.
    if self._was_reversed:
      examples["inputs"] = tf.to_int64(
@@ -491,6 +502,10 @@ def is_small(self):
  def num_classes(self):
    return 10

+  @property
+  def class_labels(self):
+    return [str(c) for c in range(self.num_classes)]

  @property
  def train_shards(self):
    return 10
@@ -564,9 +579,17 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0):

@registry.register_problem
class ImageCifar10Tune(ImageMnistTune):
"""Cifar-10 Tune."""

@property
def class_labels(self):
return [
"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse",
"ship", "truck"
]

-  def preprocess_examples(self, examples, mode, hparams):
-    if mode == tf.contrib.learn.ModeKeys.TRAIN:
+  def preprocess_examples(self, examples, mode, unused_hparams):
+    if mode == tf.estimator.ModeKeys.TRAIN:
examples["inputs"] = common_layers.cifar_image_augmentation(
examples["inputs"])
return examples
@@ -591,7 +614,7 @@ def generator(self, data_dir, tmp_dir, is_training):
@registry.register_problem
class ImageCifar10Plain(ImageCifar10):

-  def preprocess_examples(self, examples, mode, hparams):
+  def preprocess_examples(self, examples, mode, unused_hparams):
    return examples


@@ -730,7 +753,7 @@ def feature_encoders(self, data_dir):
    encoder = text_encoder.SubwordTextEncoder(vocab_filename)
    return {"targets": encoder}

-  def hparams(self, defaults, model_hparams):
+  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    p.input_modality = {"inputs": (registry.Modalities.IMAGE, None)}
    encoder = self._encoders["targets"]
4 changes: 2 additions & 2 deletions tensor2tensor/data_generators/imdb.py
@@ -97,7 +97,7 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1):
        self.generator(data_dir, tmp_dir, True), train_paths,
        self.generator(data_dir, tmp_dir, False), dev_paths)

-  def hparams(self, defaults, model_hparams):
+  def hparams(self, defaults, unused_model_hparams):
    p = defaults
    source_vocab_size = self._encoders["inputs"].vocab_size
    p.input_modality = {
@@ -112,7 +112,7 @@ def feature_encoders(self, data_dir):
    encoder = text_encoder.SubwordTextEncoder(vocab_filename)
    return {
        "inputs": encoder,
"targets": text_encoder.TextEncoder(),
"targets": text_encoder.ClassLabelEncoder(["neg", "pos"]),
    }

  def example_reading_spec(self):