From d30ec6bd36760dbb02cabb4b434dd1fc89edbd03 Mon Sep 17 00:00:00 2001
From: Ryan Sepassi
Date: Wed, 9 Aug 2017 15:51:00 -0700
Subject: [PATCH 1/7] Add TransformerEncoder and TransformerDecoder models

PiperOrigin-RevId: 164785525
---
 tensor2tensor/bin/t2t-datagen                 |  10 ++
 tensor2tensor/bin/t2t-trainer                 |   0
 tensor2tensor/data_generators/all_problems.py |   1 -
 .../data_generators/generator_utils.py        |   0
 tensor2tensor/data_generators/ice_parsing.py  | 117 ------------------
 .../data_generators/problem_hparams.py        |  37 ++++++
 tensor2tensor/data_generators/wmt.py          |  22 ++++
 tensor2tensor/models/transformer.py           | 102 ++++++++++-----
 tensor2tensor/utils/decoding.py               |   7 +-
 tensor2tensor/utils/registry.py               |   6 +-
 10 files changed, 143 insertions(+), 159 deletions(-)
 mode change 100755 => 100644 tensor2tensor/bin/t2t-datagen
 mode change 100755 => 100644 tensor2tensor/bin/t2t-trainer
 mode change 100755 => 100644 tensor2tensor/data_generators/all_problems.py
 mode change 100755 => 100644 tensor2tensor/data_generators/generator_utils.py
 delete mode 100755 tensor2tensor/data_generators/ice_parsing.py
 mode change 100755 => 100644 tensor2tensor/data_generators/problem_hparams.py
 mode change 100755 => 100644 tensor2tensor/data_generators/wmt.py
 mode change 100755 => 100644 tensor2tensor/models/transformer.py
 mode change 100755 => 100644 tensor2tensor/utils/decoding.py
 mode change 100755 => 100644 tensor2tensor/utils/registry.py

diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen
old mode 100755
new mode 100644
index 97bbd1241..39453dbee
--- a/tensor2tensor/bin/t2t-datagen
+++ b/tensor2tensor/bin/t2t-datagen
@@ -82,6 +82,16 @@ _SUPPORTED_PROBLEM_GENERATORS = {
     "algorithmic_algebra_inverse": (
         lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
         lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
+    "ice_parsing_tokens": (
+        lambda: wmt.tabbed_parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8),
+        lambda: wmt.tabbed_parsing_token_generator(
+            FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)),
+    "ice_parsing_characters": (
+        lambda: wmt.tabbed_parsing_character_generator(
+            FLAGS.tmp_dir, True),
+        lambda: wmt.tabbed_parsing_character_generator(
+            FLAGS.tmp_dir, False)),
     "wmt_parsing_tokens_8k": (
         lambda: wmt.parsing_token_generator(
             FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer
old mode 100755
new mode 100644
diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py
old mode 100755
new mode 100644
index 10a4764f5..ca6dccfda
--- a/tensor2tensor/data_generators/all_problems.py
+++ b/tensor2tensor/data_generators/all_problems.py
@@ -31,7 +31,6 @@
 from tensor2tensor.data_generators import wiki
 from tensor2tensor.data_generators import wmt
 from tensor2tensor.data_generators import wsj_parsing
-from tensor2tensor.data_generators import ice_parsing
 
 
 # Problem modules that require optional dependencies
diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py
old mode 100755
new mode 100644
diff --git a/tensor2tensor/data_generators/ice_parsing.py b/tensor2tensor/data_generators/ice_parsing.py
deleted file mode 100755
index 7a90fec45..000000000
--- a/tensor2tensor/data_generators/ice_parsing.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# Copyright 2017 The Tensor2Tensor Authors.
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This module implements the ice_parsing_* problems, which -# parse plain text into flattened parse trees and POS tags. -# The training data is stored in files named `parsing_train.pairs` -# and `parsing_dev.pairs`. These files are UTF-8 text files where -# each line contains an input sentence and a target parse tree, -# separated by a tab character. - -import os - -# Dependency imports - -from tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import problem -from tensor2tensor.data_generators import text_encoder -from tensor2tensor.data_generators.wmt import tabbed_generator -from tensor2tensor.utils import registry - -import tensorflow as tf - - -# End-of-sentence marker. -EOS = text_encoder.EOS_ID - - -def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix, - source_vocab_size, target_vocab_size): - """Generate source and target data from a single file.""" - filename = "parsing_{0}.pairs".format("train" if train else "dev") - source_vocab = generator_utils.get_or_generate_tabbed_vocab( - data_dir, tmp_dir, filename, 0, - prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size) - target_vocab = generator_utils.get_or_generate_tabbed_vocab( - data_dir, tmp_dir, filename, 1, - prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size) - pair_filepath = os.path.join(tmp_dir, filename) - return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS) - - -def tabbed_parsing_character_generator(tmp_dir, train): - """Generate source and target data from a single file.""" - character_vocab = text_encoder.ByteTextEncoder() - filename = "parsing_{0}.pairs".format("train" if train else "dev") - pair_filepath = os.path.join(tmp_dir, filename) - return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS) - - -@registry.register_problem("ice_parsing_tokens") -class IceParsingTokens(problem.Problem): - """Problem spec for parsing tokenized Icelandic text to - constituency trees, also tokenized but to a smaller vocabulary.""" - - @property - def source_vocab_size(self): - return 2**14 # 16384 - - @property - def targeted_vocab_size(self): - return 2**8 # 256 - - @property - def input_space_id(self): - return problem.SpaceID.ICE_TOK - - @property - def target_space_id(self): - return problem.SpaceID.ICE_PARSE_TOK - - @property - def num_shards(self): - return 10 - - def feature_encoders(self, data_dir): - source_vocab_filename = os.path.join( - data_dir, "ice_source.tokens.vocab.%d" % self.source_vocab_size) - target_vocab_filename = os.path.join( - data_dir, "ice_target.tokens.vocab.%d" % self.targeted_vocab_size) - source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) - target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) - return { - "inputs": source_subtokenizer, - "targets": target_subtokenizer, - } - - def generate_data(self, data_dir, tmp_dir, task_id=-1): - 
generator_utils.generate_dataset_and_shuffle( - tabbed_parsing_token_generator(data_dir, tmp_dir, True, "ice", - self.source_vocab_size, - self.targeted_vocab_size), - self.training_filepaths(data_dir, self.num_shards, shuffled=False), - tabbed_parsing_token_generator(data_dir, tmp_dir, False, "ice", - self.source_vocab_size, - self.targeted_vocab_size), - self.dev_filepaths(data_dir, 1, shuffled=False)) - - def hparams(self, defaults, model_hparams): - p = defaults - source_vocab_size = self._encoders["inputs"].vocab_size - p.input_modality = {"inputs": (registry.Modalities.SYMBOL, source_vocab_size)} - p.target_modality = (registry.Modalities.SYMBOL, self.targeted_vocab_size) - p.input_space_id = self.input_space_id - p.target_space_id = self.target_space_id - p.loss_multiplier = 2.5 # Rough estimate of avg number of tokens per word - diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py old mode 100755 new mode 100644 index b0ed44f5b..d0577db52 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -462,6 +462,39 @@ def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size, return p +def ice_parsing_tokens(model_hparams, wrong_source_vocab_size): + """Icelandic to parse tree translation benchmark. + + Args: + model_hparams: a tf.contrib.training.HParams + wrong_source_vocab_size: a number used in the filename indicating the + approximate vocabulary size. This is not to be confused with the actual + vocabulary size. + + Returns: + A tf.contrib.training.HParams object. + """ + p = default_problem_hparams() + # This vocab file must be present within the data directory. + source_vocab_filename = os.path.join( + model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size) + target_vocab_filename = os.path.join(model_hparams.data_dir, + "ice_target.vocab.256") + source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) + target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) + p.input_modality = { + "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size) + } + p.target_modality = (registry.Modalities.SYMBOL, 256) + p.vocabulary = { + "inputs": source_subtokenizer, + "targets": target_subtokenizer, + } + p.input_space_id = 18 # Icelandic tokens + p.target_space_id = 19 # Icelandic parse tokens + return p + + def img2img_imagenet(unused_model_hparams): """Image 2 Image for imagenet dataset.""" p = default_problem_hparams() @@ -511,6 +544,10 @@ def image_celeba(unused_model_hparams): lm1b_32k, "wiki_32k": wiki_32k, + "ice_parsing_characters": + wmt_parsing_characters, + "ice_parsing_tokens": + lambda p: ice_parsing_tokens(p, 2**13), "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13), "wsj_parsing_tokens_16k": diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py old mode 100755 new mode 100644 index 35d1b5fca..0a47e9989 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -648,6 +648,28 @@ def target_space_id(self): return problem.SpaceID.CS_CHR +def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix, + source_vocab_size, target_vocab_size): + """Generate source and target data from a single file.""" + source_vocab = generator_utils.get_or_generate_tabbed_vocab( + data_dir, tmp_dir, "parsing_train.pairs", 0, + prefix + "_source.vocab.%d" % source_vocab_size, source_vocab_size) + target_vocab = 
generator_utils.get_or_generate_tabbed_vocab( + data_dir, tmp_dir, "parsing_train.pairs", 1, + prefix + "_target.vocab.%d" % target_vocab_size, target_vocab_size) + filename = "parsing_%s" % ("train" if train else "dev") + pair_filepath = os.path.join(tmp_dir, filename + ".pairs") + return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS) + + +def tabbed_parsing_character_generator(tmp_dir, train): + """Generate source and target data from a single file.""" + character_vocab = text_encoder.ByteTextEncoder() + filename = "parsing_%s" % ("train" if train else "dev") + pair_filepath = os.path.join(tmp_dir, filename + ".pairs") + return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS) + + def parsing_token_generator(data_dir, tmp_dir, train, vocab_size): symbolizer_vocab = generator_utils.get_or_generate_vocab( data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py old mode 100755 new mode 100644 index fa7ecdf81..37c1206bd --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -55,22 +55,66 @@ def model_fn_body(self, features): (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder( targets, hparams) - encoder_input = tf.nn.dropout( - encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) - decoder_input = tf.nn.dropout( - decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) - encoder_output = transformer_encoder( - encoder_input, encoder_self_attention_bias, hparams) + encoder_input = tf.nn.dropout(encoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + encoder_output = transformer_encoder(encoder_input, + encoder_self_attention_bias, hparams) decoder_output = transformer_decoder( - decoder_input, encoder_output, - decoder_self_attention_bias, + decoder_input, encoder_output, decoder_self_attention_bias, encoder_decoder_attention_bias, hparams) decoder_output = tf.expand_dims(decoder_output, 2) return decoder_output +@registry.register_model +class TransformerEncoder(t2t_model.T2TModel): + """Transformer, encoder only.""" + + def model_fn_body(self, features): + hparams = self._hparams + inputs = features["inputs"] + target_space = features["target_space_id"] + + inputs = common_layers.flatten4d3d(inputs) + + (encoder_input, encoder_self_attention_bias, + _) = (transformer_prepare_encoder(inputs, target_space, hparams)) + + encoder_input = tf.nn.dropout(encoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + encoder_output = transformer_encoder(encoder_input, + encoder_self_attention_bias, hparams) + + return encoder_output + + +@registry.register_model +class TransformerDecoder(t2t_model.T2TModel): + """Transformer, decoder only.""" + + def model_fn_body(self, features): + hparams = self._hparams + targets = features["targets"] + + targets = common_layers.flatten4d3d(targets) + + (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder( + targets, hparams) + + decoder_input = tf.nn.dropout(decoder_input, + 1.0 - hparams.layer_prepostprocess_dropout) + + decoder_output = transformer_decoder( + decoder_input, None, decoder_self_attention_bias, None, hparams) + decoder_output = tf.expand_dims(decoder_output, 2) + + return decoder_output + + def transformer_prepare_encoder(inputs, target_space, hparams): """Prepare one shard of the model for the encoder. 
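The TransformerDecoder added above calls transformer_decoder with encoder_output=None, so each position attends only to earlier target positions. As a point of reference (an illustrative NumPy sketch, not code from this patch), the lower-triangular bias that transformer_prepare_decoder supplies for that masked self-attention behaves like this:

import numpy as np

def causal_attention_bias(length, neg_inf=-1e9):
  # 0.0 where key position j <= query position i (past and present are
  # visible); a large negative number where j > i, so the softmax inside
  # multihead_attention effectively zeros out attention to the future.
  visible = np.tril(np.ones([length, length]))
  return (1.0 - visible) * neg_inf

# causal_attention_bias(3):
# [[ 0.,  -1e9, -1e9],
#  [ 0.,   0.,  -1e9],
#  [ 0.,   0.,   0. ]]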
@@ -150,14 +194,11 @@ def transformer_encoder(encoder_input, with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( - common_layers.layer_preprocess(x, hparams), - None, - encoder_self_attention_bias, + common_layers.layer_preprocess( + x, hparams), None, encoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout) + hparams.hidden_size, hparams.num_heads, hparams.attention_dropout) x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( @@ -196,26 +237,23 @@ def transformer_decoder(decoder_input, with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("self_attention"): y = common_attention.multihead_attention( - common_layers.layer_preprocess(x, hparams), - None, - decoder_self_attention_bias, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout) - x = common_layers.layer_postprocess(x, y, hparams) - with tf.variable_scope("encdec_attention"): - y = common_attention.multihead_attention( - common_layers.layer_preprocess(x, hparams), - encoder_output, - encoder_decoder_attention_bias, + common_layers.layer_preprocess( + x, hparams), None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout) + hparams.hidden_size, hparams.num_heads, hparams.attention_dropout) x = common_layers.layer_postprocess(x, y, hparams) + if encoder_output is not None: + assert encoder_decoder_attention_bias is not None + with tf.variable_scope("encdec_attention"): + y = common_attention.multihead_attention( + common_layers.layer_preprocess( + x, hparams), encoder_output, encoder_decoder_attention_bias, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, hparams.num_heads, + hparams.attention_dropout) + x = common_layers.layer_postprocess(x, y, hparams) with tf.variable_scope("ffn"): y = transformer_ffn_layer( common_layers.layer_preprocess(x, hparams), hparams) @@ -393,7 +431,7 @@ def transformer_parsing_big(): @registry.register_hparams def transformer_parsing_ice(): - """Hparams for parsing and tagging Icelandic text.""" + """Hparams for parsing Icelandic text.""" hparams = transformer_base_single_gpu() hparams.batch_size = 4096 hparams.shared_embedding_and_softmax_weights = int(False) diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py old mode 100755 new mode 100644 index fc9eb566f..5e8f4d482 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -259,11 +259,6 @@ def _interactive_input_fn(hparams): vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"] # This should be longer than the longest input. 
    const_array_size = 10000
-  # Import readline if available for command line editing and recall
-  try:
-    import readline
-  except ImportError:
-    pass
   while True:
     prompt = ("INTERACTIVE MODE  num_samples=%d  decode_length=%d  \n"
               "  it=<input_type>      ('text' or 'image' or 'label')\n"
               "  pr=<problem_num>     (set the problem number)\n"
@@ -271,7 +266,7 @@
               "  in=<input_problem>   (set the input problem number)\n"
               "  ou=<output_problem>  (set the output problem number)\n"
               "  ns=<num_samples>     (changes number of samples)\n"
-              "  dl=<decode_length>   (changes decode legnth)\n"
+              "  dl=<decode_length>   (changes decode length)\n"
               "  <%s>                 (decode)\n"
               "  q                    (quit)\n"
               ">" % (num_samples, decode_length, "source_string"
diff --git a/tensor2tensor/utils/registry.py b/tensor2tensor/utils/registry.py
old mode 100755
new mode 100644
index d79eef484..fea647b2b
--- a/tensor2tensor/utils/registry.py
+++ b/tensor2tensor/utils/registry.py
@@ -225,10 +225,10 @@ def parse_problem_name(problem_name):
     was_copy: A boolean.
   """
   # Recursively strip tags until we reach a base name.
-  if problem_name.endswith("_rev"):
+  if len(problem_name) > 4 and problem_name[-4:] == "_rev":
     base, _, was_copy = parse_problem_name(problem_name[:-4])
     return base, True, was_copy
-  elif problem_name.endswith("_copy"):
+  elif len(problem_name) > 5 and problem_name[-5:] == "_copy":
     base, was_reversed, _ = parse_problem_name(problem_name[:-5])
     return base, was_reversed, True
   else:
@@ -352,7 +352,7 @@ def list_modalities():
 
 
 def parse_modality_name(name):
-  name_parts = name.split(":", maxsplit=1)
+  name_parts = name.split(":")
   if len(name_parts) < 2:
     name_parts.append("default")
   modality_type, modality_name = name_parts
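For clarity before the next patch: the parse_problem_name change above can be exercised as plain Python. The function below restates the new logic outside the diff; the asserts are illustrative expectations, including the edge case the new length guards handle, where a bare tag no longer strips down to an empty base name.

def parse_problem_name(problem_name):
  # Recursively strip "_rev"/"_copy" tags until reaching a base name.
  if len(problem_name) > 4 and problem_name[-4:] == "_rev":
    base, _, was_copy = parse_problem_name(problem_name[:-4])
    return base, True, was_copy
  elif len(problem_name) > 5 and problem_name[-5:] == "_copy":
    base, was_reversed, _ = parse_problem_name(problem_name[:-5])
    return base, was_reversed, True
  return problem_name, False, False

assert parse_problem_name("wmt_ende_tokens_8k") == (
    "wmt_ende_tokens_8k", False, False)
assert parse_problem_name("wmt_ende_tokens_8k_rev") == (
    "wmt_ende_tokens_8k", True, False)
assert parse_problem_name("wmt_ende_tokens_8k_rev_copy") == (
    "wmt_ende_tokens_8k", True, True)
assert parse_problem_name("_rev") == ("_rev", False, False)  # no empty base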
From 12c59a7d3fa452af0de7b792126d32c35d60d37f Mon Sep 17 00:00:00 2001
From: Noam Shazeer
Date: Thu, 10 Aug 2017 10:20:53 -0700
Subject: [PATCH 2/7] Massively simplify expert_utils.

Breaks checkpoints for models that use experts.

Fixed bug in Parallelism, where caching devices were always used, even when
none.

Fixed bug in attention_lm, attention_lm_moe by setting the default norm_type
to "layer" instead of "none".

PiperOrigin-RevId: 164869403
---
 tensor2tensor/layers/common_hparams.py   |    7 +-
 tensor2tensor/layers/common_layers.py    |   67 +-
 tensor2tensor/layers/modalities.py       |    2 +-
 tensor2tensor/models/attention_lm_moe.py |   48 +-
 tensor2tensor/models/long_answer.py      |  276 ------
 tensor2tensor/models/models.py           |    1 -
 tensor2tensor/models/multimodel.py       |   38 +-
 tensor2tensor/models/slicenet.py         |    4 -
 tensor2tensor/models/transformer_moe.py  |   37 +-
 tensor2tensor/utils/expert_utils.py      | 1010 +++++-----------------
 10 files changed, 313 insertions(+), 1177 deletions(-)
 delete mode 100644 tensor2tensor/models/long_answer.py

diff --git a/tensor2tensor/layers/common_hparams.py b/tensor2tensor/layers/common_hparams.py
index 10b5e7e59..0ed62685f 100644
--- a/tensor2tensor/layers/common_hparams.py
+++ b/tensor2tensor/layers/common_hparams.py
@@ -69,6 +69,11 @@ def basic_params1():
       sampling_method="argmax",  # "argmax" or "random"
       problem_choice="adaptive",  # "uniform", "adaptive", "distributed"
       multiply_embedding_mode="sqrt_depth",
+      # Parameters related to mixtures of experts.
+      moe_hidden_sizes="2048",  # hidden layer sizes (comma-separated)
+      moe_num_experts=64,  # number of experts per layer
+      moe_k=2,  # how many experts to use for each batch element
+      moe_loss_coef=1e-2,
       # Sequences of operations to perform on layer input and layer output.
       # Used by common_layers.layer_preprocess, common_layers.layer_postprocess
       # Each character represents an operation:
@@ -83,7 +88,7 @@ def basic_params1():
       # dropout rate to use during layer_preprocess and layer_postprocess
       layer_prepostprocess_dropout=0.1,
       # What type of normalization to use
-      norm_type="none",  # "batch", "layer", "noam", "none".
+      norm_type="layer",  # "batch", "layer", "noam", "none".
       # epsilon parameter to normalization function
       norm_epsilon=1e-6,
       symbol_modality_num_shards=16,
diff --git a/tensor2tensor/layers/common_layers.py b/tensor2tensor/layers/common_layers.py
index a85430c1c..e9b195195 100644
--- a/tensor2tensor/layers/common_layers.py
+++ b/tensor2tensor/layers/common_layers.py
@@ -193,7 +193,7 @@ def embedding(x, vocab_size, dense_size, name=None, reuse=None, multiplier=1.0):
     # On the backwards pass, we want to convert the gradient from
     # an indexed-slices to a regular tensor before sending it back to the
     # parameter server. This avoids excess computation on the parameter server.
-    embedding_var = eu.ConvertGradientToTensor(embedding_var)
+    embedding_var = eu.convert_gradient_to_tensor(embedding_var)
     emb_x = tf.gather(embedding_var, x)
     if multiplier != 1.0:
       emb_x *= multiplier
@@ -823,71 +823,6 @@ def decompress_seqcnn(x,
   return tf.layers.dense(outputs, targets_vocab_size)
 
 
-def moe_layer(data_parallelism,
-              ps_devices,
-              xs,
-              train,
-              model_hidden_size,
-              expert_hidden_size,
-              n1,
-              n2,
-              loss_coef,
-              autoscale=True,
-              name=None):
-  """A mixture of experts layer.
-
-  Args:
-    data_parallelism: a expert_utils.Parallelism object.
-    ps_devices: a list of strings
-    xs: a list of input tensors.
-    train: a boolean scalar.
-    model_hidden_size: an integer (input/output size for this layer)
-    expert_hidden_size: an integer (size of each expert's hidden layer)
-    n1: an integer - number of experts (or # of groups for hierarchical MoE)
-    n2: optional integer - size of each group of experts for hierarchical MoE
-    loss_coef: a scalar - multiplier on load-balancing losses
-    autoscale: a boolean
-    name: a string
-
-  Returns:
-    ys: a list of tensors:
-    extra_training_loss: a scalar
-  """
-  dp = data_parallelism
-  with tf.variable_scope(name, default_name="moe"):
-    # Set up the hyperparameters for the gating networks.
-    primary_gating_hp = eu.NoisyTopKGatingParams()
-    primary_gating_hp.num_experts = n1
-    if n2:
-      # hierarchical MoE containing moe_n1 groups of moe_n2 experts.
-      assert n2 > 1
-      secondary_gating_hp = eu.NoisyTopKGatingParams()
-      secondary_gating_hp.num_experts = n2
-    else:
-      # flat mixture of moe_n1 experts.
-      secondary_gating_hp = None
-    # Set up the hyperparameters for the expert networks.
-    # Each expert contains a hidden RELU layer of size filter_size
-    expert_hp = eu.FeedForwardExpertParams()
-    expert_hp.autoscale = autoscale
-    expert_hp.hidden_layer_sizes = [expert_hidden_size]
-    # Create the mixture of experts.
-    moe = eu.DistributedMixtureOfExperts(primary_gating_hp, secondary_gating_hp,
-                                         expert_hp, model_hidden_size,
-                                         model_hidden_size, ps_devices, "moe")
-    # MoE expects input tensors to be 2d.
-    # Flatten out spatial dimensions.
-    xs_2d = dp(tf.reshape, xs, [[-1, model_hidden_size]] * dp.n)
-    # Call the MoE
-    moe_out_2d, importance, load, _, _ = moe.Eval(
-        dp.devices, xs_2d, train, identifiers=None)
-    # Reshape the output to the original shape.
-    moe_out = dp(tf.reshape, moe_out_2d, dp(tf.shape, xs))
-    # These losses encourage equal load on the different experts.
- loss = loss_coef * (eu.CVSquared(importance) + eu.CVSquared(load)) - return moe_out, loss - - def simple_attention(target, source, bias=None): """A simple attention function. diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index e44729041..acaacbf99 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -70,7 +70,7 @@ def _get_weights(self): ret = shards[0] else: ret = tf.concat(shards, 0) - ret = eu.ConvertGradientToTensor(ret) + ret = eu.convert_gradient_to_tensor(ret) return ret def bottom_simple(self, x, name, reuse): diff --git a/tensor2tensor/models/attention_lm_moe.py b/tensor2tensor/models/attention_lm_moe.py index 1869eef66..268e93f7b 100644 --- a/tensor2tensor/models/attention_lm_moe.py +++ b/tensor2tensor/models/attention_lm_moe.py @@ -32,6 +32,7 @@ from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_hparams from tensor2tensor.layers import common_layers +from tensor2tensor.utils import expert_utils from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model @@ -61,6 +62,7 @@ def postprocess(x, y): x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) extra_loss = 0.0 + moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("attention"): @@ -78,11 +80,18 @@ def postprocess(x, y): x = postprocess(x, y) with tf.variable_scope("ffn"): if str(layer) in hparams.moe_layers.split(","): - y, loss = common_layers.moe_layer( - dp, self._ps_devices, preprocess(x), + y, loss = expert_utils.distributed_moe( + dp, + self._ps_devices, + preprocess(x), hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, - hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1, - hparams.moe_n2, hparams.moe_loss_coef) + input_size=hparams.hidden_size, + expert_fn=expert_utils.ffn_expert_fn( + hparams.hidden_size, moe_hidden_sizes, + hparams.hidden_size), + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=hparams.moe_loss_coef) extra_loss += loss else: y = dp( @@ -149,16 +158,7 @@ def attention_lm_moe_base(): hparams.label_smoothing = 0.0 hparams.shared_embedding_and_softmax_weights = int(False) hparams.add_hparam("filter_size", 2048) # Add new ones like this. - # comma-separated list of layer numbers. - # At each of these layers, we replace the ffn with a mixture of experts. - hparams.add_hparam("moe_layers", "2") - # If moe_n2 is None, then use a flat MoE with moe_n1 experts. - # If moe_n2 is an integer, then use a hierarchical MoE - # consisting of moe_n1 groups of moe_n2 experts each. 
- hparams.add_hparam("moe_n1", 32) - hparams.add_hparam("moe_n2", 0) - hparams.add_hparam("moe_hidden_size", 2048) - hparams.add_hparam("moe_loss_coef", 1e-2) + hparams.moe_num_experts = 32 # attention-related flags hparams.add_hparam("num_heads", 8) hparams.add_hparam("attention_key_channels", 0) @@ -168,6 +168,7 @@ def attention_lm_moe_base(): hparams.add_hparam("attention_dropout", 0.0) hparams.add_hparam("relu_dropout", 0.0) hparams.add_hparam("pos", "timing") # timing, none + hparams.add_hparam("moe_layers", "2") # comma separated list of layer numbers return hparams @@ -188,9 +189,20 @@ def attention_lm_moe_small(): hparams.num_hidden_layers = 4 hparams.hidden_size = 512 hparams.filter_size = 2048 - hparams.moe_n1 = 128 + hparams.moe_num_experts = 128 hparams.moe_layers = "2" - hparams.moe_hidden_size = 2048 + return hparams + + +@registry.register_hparams +def attention_lm_moe_tiny(): + """Cheap model for debugging. + + Returns: + an hparams object. + """ + hparams = attention_lm_moe_small() + hparams.moe_num_experts = 32 return hparams @@ -233,7 +245,7 @@ def attention_lm_moe_large(): hparams.hidden_size = 1024 hparams.num_heads = 16 hparams.filter_size = 4096 - hparams.moe_hidden_size = 4096 - hparams.moe_n1 = 128 + hparams.moe_hidden_sizes = "4096" + hparams.moe_num_experts = 128 hparams.layer_prepostprocess_dropout = 0.2 return hparams diff --git a/tensor2tensor/models/long_answer.py b/tensor2tensor/models/long_answer.py deleted file mode 100644 index a9fb45e4a..000000000 --- a/tensor2tensor/models/long_answer.py +++ /dev/null @@ -1,276 +0,0 @@ -# coding=utf-8 -# Copyright 2017 The Tensor2Tensor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Model to generate long answers to short questions. - -E.g. wiki_32k title->article dataset. - -Variant on attention_lm_moe.py - - prepend the inputs to the targets. - - use masked local attention to avoid quadratic space and time blowup for - long sequences. - -This model is still highly experimental and under rapid iteration. - -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# Dependency imports - -from six.moves import xrange # pylint: disable=redefined-builtin - -from tensor2tensor.layers import common_attention -from tensor2tensor.layers import common_hparams -from tensor2tensor.layers import common_layers -from tensor2tensor.utils import registry -from tensor2tensor.utils import t2t_model - -import tensorflow as tf - - -@registry.register_model -class LongAnswer(t2t_model.T2TModel): - """Attention net. 
See file docstring.""" - - def model_fn_body_sharded(self, sharded_features): - # Remove dropout if not training - hparams = self._hparams - dp = self._data_parallelism - targets = sharded_features["targets"] - targets = dp(tf.squeeze, targets, 2) - inputs = sharded_features["inputs"] - inputs = dp(tf.squeeze, inputs, 2) - - decoder_input = dp(long_answer_prepare_decoder, inputs, targets, hparams) - - def residual_fn(x, y): - return common_layers.layer_norm(x + tf.nn.dropout( - y, 1.0 - hparams.residual_dropout)) - - x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.residual_dropout) - extra_loss = 0.0 - for layer in xrange(hparams.num_hidden_layers): - with tf.variable_scope("layer_%d" % layer): - with tf.variable_scope("attention"): - y = dp( - common_attention.multihead_attention, - x, - None, - None, - hparams.attention_key_channels or hparams.hidden_size, - hparams.attention_value_channels or hparams.hidden_size, - hparams.hidden_size, - hparams.num_heads, - hparams.attention_dropout, - attention_type="local_mask_right", - block_length=hparams.block_length, - name="decoder_self_attention") - x = dp(residual_fn, x, y) - with tf.variable_scope("ffn"): - if str(layer) in hparams.moe_layers.split(","): - y, loss = common_layers.moe_layer( - dp, self._ps_devices, x, - hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, - hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1, - hparams.moe_n2, hparams.moe_loss_coef) - extra_loss += loss - else: - y = dp( - common_layers.conv_hidden_relu, - x, - hparams.filter_size, - hparams.hidden_size, - dropout=hparams.relu_dropout) - x = dp(residual_fn, x, y) - x = dp(long_answer_output, x, inputs) - return x, extra_loss - - -def long_answer_prepare_decoder(inputs, targets, hparams): - """Prepare one shard of the model for the decoder. - - Args: - inputs: a Tensor. - targets: a Tensor. - hparams: run hyperparameters - - Returns: - decoder_input: a Tensor, bottom of decoder stack - """ - decoder_input = tf.concat([ - length_embedding(targets, hparams), inputs, - common_layers.shift_left_3d(targets) - ], 1) - if hparams.pos == "timing": - decoder_input = common_attention.add_timing_signal_1d(decoder_input) - return decoder_input - - -def length_embedding(targets, hparams): - """An embedding indicating approximate target length. - - This is a bit of a hack, where we want to be able to request a particular - target length during inference. - During training, we sometimes provide a target length. - During eval, we never provide a target length. - - Args: - targets: a Tensor. - hparams: run hyperparameters - - Returns: - a Tensor with shape [batch, 1, hparams.hidden_size] - """ - # encode the approx target length in case we want to specify it - # during inference. 
- batch = tf.shape(targets)[0] - padded_target_length = tf.shape(targets)[1] - if hparams.mode == tf.contrib.learn.ModeKeys.TRAIN: - lengths = padded_target_length * tf.to_int32( - tf.less(tf.random_uniform([batch]), hparams.answer_length_prob_train)) - elif hparams.mode == tf.contrib.learn.ModeKeys.EVAL: - lengths = 0 - else: - assert hparams.mode == tf.contrib.learn.ModeKeys.INFER - lengths = hparams.answer_length_infer - lengths = tf.to_int32(tf.log(tf.to_float(lengths + 1))) - lengths = tf.zeros([batch], dtype=tf.int32) + lengths - ret = tf.gather( - tf.get_variable("answer_length", [100, hparams.hidden_size]), lengths) - return tf.expand_dims(ret, 1) - - -def long_answer_output(x, inputs): - """Strip initial part corresponding to the inputs and the length embedding.""" - x = tf.slice(x, [0, tf.shape(inputs)[1] + 1, 0], [-1, -1, -1]) - x = tf.expand_dims(x, 2) - return x - - -@registry.register_hparams -def long_answer_base(): - """Set of hyperparameters. - - Returns: - a hparams object - """ - hparams = common_hparams.basic_params1() - hparams.hidden_size = 1024 - hparams.batch_size = 8192 - hparams.max_length = 8192 - hparams.dropout = 0.0 - hparams.batching_mantissa_bits = 3 - hparams.clip_grad_norm = 0. # i.e. no gradient clipping - hparams.optimizer_adam_epsilon = 1e-9 - hparams.learning_rate_decay_scheme = "noam" - hparams.learning_rate = 0.1 - hparams.learning_rate_warmup_steps = 1000 - hparams.initializer_gain = 1.0 - hparams.num_hidden_layers = 4 - hparams.initializer = "uniform_unit_scaling" - hparams.weight_decay = 0.0 - hparams.optimizer_adam_beta1 = 0.9 - hparams.optimizer_adam_beta2 = 0.98 - hparams.num_sampled_classes = 0 - hparams.label_smoothing = 0.0 - hparams.shared_embedding_and_softmax_weights = int(True) - hparams.sampling_method = "random" - hparams.add_hparam("filter_size", 2048) # Add new ones like this. - # comma-separated list of layer numbers. - # At each of these layers, we replace the ffn with a mixture of experts. - hparams.add_hparam("moe_layers", "2") - # If moe_n2 is None, then use a flat MoE with moe_n1 experts. - # If moe_n2 is an integer, then use a hierarchical MoE - # consisting of moe_n1 groups of moe_n2 experts each. - hparams.add_hparam("moe_n1", 64) - hparams.add_hparam("moe_n2", 0) - hparams.add_hparam("moe_hidden_size", 2048) - hparams.add_hparam("moe_loss_coef", 1e-2) - # attention-related flags - hparams.add_hparam("num_heads", 8) - hparams.add_hparam("attention_key_channels", 0) - hparams.add_hparam("attention_value_channels", 0) - # All hyperparameters ending in "dropout" are automatically set to 0.0 - # when not in training mode. - hparams.add_hparam("attention_dropout", 0.0) - hparams.add_hparam("relu_dropout", 0.0) - hparams.add_hparam("residual_dropout", 0.0) - hparams.add_hparam("pos", "timing") # timing, none - hparams.add_hparam("block_length", 512) - hparams.add_hparam("answer_length_prob_train", 0.5) - hparams.add_hparam("answer_length_infer", 1000) - # We cannot handle long sequence at this point, so drop them, during eval. - # This affects evaluation metrics. - # TODO(noam): find a different workaround - hparams.eval_drop_long_sequences = int(True) - return hparams - - -@registry.register_hparams -def long_answer_tiny(): - """Cheap model for validation. - - Returns: - an hparams object. 
- """ - hparams = long_answer_base() - hparams.num_hidden_layers = 3 - hparams.hidden_size = 512 - hparams.filter_size = 1024 - hparams.moe_layers = "2" - hparams.moe_hidden_size = 1024 - hparams.block_length = 128 - hparams.moe_n1 = 8 - hparams.batch_size = 2048 - hparams.max_length = 2048 - return hparams - - -@registry.register_hparams -def long_answer_small(): - """Cheap model for single-gpu training. - - Returns: - an hparams object. - """ - hparams = long_answer_base() - hparams.num_hidden_layers = 4 - hparams.hidden_size = 512 - hparams.filter_size = 2048 - hparams.moe_n1 = 128 - hparams.moe_layers = "2" - hparams.moe_hidden_size = 2048 - return hparams - - -@registry.register_hparams -def long_answer_large(): - """Large model for distributed training. - - Returns: - an hparams object. - """ - hparams = long_answer_base() - hparams.num_hidden_layers = 5 - hparams.moe_layers = "3" - hparams.hidden_size = 1024 - hparams.filter_size = 4096 - hparams.moe_hidden_size = 4096 - hparams.moe_n1 = 128 - hparams.block_length = 1024 - return hparams diff --git a/tensor2tensor/models/models.py b/tensor2tensor/models/models.py index cba779fc9..d4514408d 100644 --- a/tensor2tensor/models/models.py +++ b/tensor2tensor/models/models.py @@ -30,7 +30,6 @@ from tensor2tensor.models import bytenet from tensor2tensor.models import cycle_gan from tensor2tensor.models import gene_expression -from tensor2tensor.models import long_answer from tensor2tensor.models import lstm from tensor2tensor.models import multimodel from tensor2tensor.models import neural_gpu diff --git a/tensor2tensor/models/multimodel.py b/tensor2tensor/models/multimodel.py index 290c78732..c8d515c8d 100644 --- a/tensor2tensor/models/multimodel.py +++ b/tensor2tensor/models/multimodel.py @@ -27,6 +27,7 @@ from tensor2tensor.layers import common_layers from tensor2tensor.layers import modalities from tensor2tensor.models import slicenet +from tensor2tensor.utils import expert_utils from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model @@ -76,9 +77,19 @@ def conv_experts(xs, hparams, dp, ps, padding, mask, layer_id): train = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, conv_out = dp(conv_res_step, xs, hparams, padding, mask) loss = 0.0 - moe_out, loss = common_layers.moe_layer( - dp, ps, xs, train, hparams.hidden_size, hparams.filter_size, - hparams.moe_n1, hparams.moe_n2, 1.0) + moe_hidden_sizes = [hparams.filter_size] + expert_fn = expert_utils.ffn_expert_fn( + hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) + moe_out, loss = expert_utils.distributed_moe( + dp, + ps, + xs, + train, + input_size=hparams.hidden_size, + expert_fn=expert_fn, + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=1.0) return dp(residual_fn3, xs, moe_out, conv_out, hparams), loss @@ -136,6 +147,9 @@ def flatten(inputs): (decoder_input, decoder_self_attention_bias) = dp(prepare_decoder, targets, target_space_emb) + moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] + expert_fn = expert_utils.ffn_expert_fn( + hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.dropout) for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("dec_layer_%d" % layer): @@ -165,10 +179,16 @@ def flatten(inputs): x = dp(residual_fn3, x, y, z, hparams) with tf.variable_scope("ffn"): if str(layer) in hparams.moe_layers.split(","): - y, moe_loss = common_layers.moe_layer( - dp, self._ps_devices, x, train, hparams.hidden_size, - 
hparams.filter_size, hparams.moe_n1, hparams.moe_n2, - hparams.moe_loss_coef) + y, moe_loss = expert_utils.distributed_moe( + dp, + self._ps_devices, + x, + train, + input_size=hparams.hidden_size, + expert_fn=expert_fn, + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=hparams.moe_loss_coef) expert_loss += tf.reduce_mean(moe_loss) else: y = dp( @@ -199,10 +219,8 @@ def multimodel_base(): hparams.add_hparam("large_kernel_size", 15) hparams.add_hparam("attention_dropout", 0.1) hparams.add_hparam("num_heads", 8) - hparams.add_hparam("moe_n1", 30) - hparams.add_hparam("moe_n2", 0) hparams.add_hparam("moe_layers", "2") - hparams.add_hparam("moe_loss_coef", 1e-2) + hparams.moe_num_experts = 30 return hparams diff --git a/tensor2tensor/models/slicenet.py b/tensor2tensor/models/slicenet.py index 1079659b5..6b07dc640 100644 --- a/tensor2tensor/models/slicenet.py +++ b/tensor2tensor/models/slicenet.py @@ -322,9 +322,6 @@ def slicenet_params1(): # A kernel scheme, one of _KERNEL_SCHEMES; overrides large_kernel_size. hparams.add_hparam("kernel_scheme", "3.7.15.31") hparams.add_hparam("audio_compression", 8) - hparams.add_hparam("moe_n1", 32) - hparams.add_hparam("moe_n2", 0) - hparams.add_hparam("moe_loss_coef", 1e-2) # attention-related flags hparams.add_hparam("attention_type", "simple") hparams.add_hparam("num_heads", 8) @@ -358,7 +355,6 @@ def slicenet_params1_tiny(): hparams.separability = 0 hparams.hidden_size = 128 hparams.num_hidden_layers = 2 - hparams.moe_n1 = 2 hparams.batch_size = 512 hparams.learning_rate_warmup_steps = 200 return hparams diff --git a/tensor2tensor/models/transformer_moe.py b/tensor2tensor/models/transformer_moe.py index 6f01667d8..669b1842b 100644 --- a/tensor2tensor/models/transformer_moe.py +++ b/tensor2tensor/models/transformer_moe.py @@ -29,6 +29,7 @@ from tensor2tensor.layers import common_hparams from tensor2tensor.layers import common_layers from tensor2tensor.models import transformer +from tensor2tensor.utils import expert_utils from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model @@ -66,6 +67,9 @@ def postprocess(x, y): decoder_input = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) extra_loss = 0 + moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] + expert_fn = expert_utils.ffn_expert_fn( + hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) x = encoder_input for layer in xrange(hparams.num_hidden_layers): with tf.variable_scope("encoder_layer_%d" % layer): @@ -83,11 +87,16 @@ def postprocess(x, y): x = postprocess(x, y) with tf.variable_scope("ffn"): if str(layer) in hparams.moe_layers_encoder.split(","): - y, loss = common_layers.moe_layer( - dp, self._ps_devices, preprocess(x), + y, loss = expert_utils.distributed_moe( + dp, + self._ps_devices, + preprocess(x), hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, - hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1, - hparams.moe_n2, hparams.moe_loss_coef) + input_size=hparams.hidden_size, + expert_fn=expert_fn, + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=hparams.moe_loss_coef) extra_loss += loss else: y = dp( @@ -127,11 +136,16 @@ def postprocess(x, y): x = postprocess(x, y) with tf.variable_scope("ffn"): if str(layer) in hparams.moe_layers_decoder.split(","): - y, loss = common_layers.moe_layer( - dp, self._ps_devices, preprocess(x), + y, loss = expert_utils.distributed_moe( + dp, + self._ps_devices, + preprocess(x), hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, 
- hparams.hidden_size, hparams.moe_hidden_size, hparams.moe_n1, - hparams.moe_n2, hparams.moe_loss_coef) + input_size=hparams.hidden_size, + expert_fn=expert_fn, + num_experts=hparams.moe_num_experts, + k=hparams.moe_k, + loss_coef=hparams.moe_loss_coef) extra_loss += loss else: y = dp( @@ -192,13 +206,6 @@ def transformer_moe_base(): # At each of these layers, we replace the ffn with a mixture of experts. hparams.add_hparam("moe_layers_encoder", "2") hparams.add_hparam("moe_layers_decoder", "2") - # If moe_n2 is None, then use a flat MoE with moe_n1 experts. - # If moe_n2 is an integer, then use a hierarchical MoE - # consisting of moe_n1 groups of moe_n2 experts each. - hparams.add_hparam("moe_n1", 32) - hparams.add_hparam("moe_n2", 0) - hparams.add_hparam("moe_hidden_size", 2048) - hparams.add_hparam("moe_loss_coef", 1e-2) return hparams diff --git a/tensor2tensor/utils/expert_utils.py b/tensor2tensor/utils/expert_utils.py index e21f2453a..ac58ef3cd 100644 --- a/tensor2tensor/utils/expert_utils.py +++ b/tensor2tensor/utils/expert_utils.py @@ -15,8 +15,8 @@ """Utilities for creating Sparsely-Gated Mixture-of-Experts Layers. -See the most recent draft of our ICLR paper: -https://openreview.net/pdf?id=B1ckMDqlg +See "Outrageously Large Neural Networks" +https://arxiv.org/abs/1701.06538 """ from __future__ import absolute_import @@ -35,122 +35,10 @@ from tensorflow.python.framework import function -def NoisyTopKGatingParams(): - """Hyperparams defining NoisyTopK Gating Network. - - Returns: - a tf.contrib.training.HParams object - """ - return tf.contrib.training.HParams( - gating_class=NoisyTopKGating, - num_experts=16, # The number of experts - k=2, # 'The number of experts to use per example - input_size=None, # size of input to MoE. Set by MoE class - dtype=tf.float32, # floating point data type - initializer=tf.zeros_initializer(), # initializer for weight matrices - noisy_gating=True, # Add tunable noise (necessary for load-balancing) - noise_epsilon=1e-2, # Added to noise stddev for numerical stability - ) - - -def FeedForwardExpertParams(): - """Hyperparameters defining feed-forward expert networks. - - Returns: - a tf.contrib.training.HParams object - """ - return tf.contrib.training.HParams( - # The class that implements the expert network - expert_class=FeedForwardExpert, - input_size=None, # Size of input to MoE. Set by MoE class. - # List of hidden layer sizes, or None for no hidden layers. - # The length of this list determines the number of hidden layers - hidden_layer_sizes=None, - output_size=None, # Size of output from MoE. Set by MoE class. - dtype=tf.float32, # Floating point data type) - # Activation function applied at each hidden layer) - hidden_activation=tf.nn.relu, - initializer=None, # Optional initializer for weight matrices.) - # If autoscale=True, At each hidden/output layer, multiply by - # rsqrt(prev_layer_size / input_size). This scaling happens - # before application of hidden_activation) - autoscale=True,) - - -def _SetInputOutputSizes(hp, input_size, output_size): - """Fill in the input_size and output_size hyperparameters. - - This is used by LocalMixtureOfExperts and DistributedMixtureOfExperts to - fill in the input_size and output_size on the gating parameters and expert - parameters so that the user does not have to set them in multiple places. 
- - Args: - hp: a hyperparameters - input_size: an integer - output_size: an integer - """ - if hp.input_size is None: - hp.input_size = input_size - else: - assert hp.input_size == input_size - if output_size is not None: - if hp.output_size is None: - hp.output_size = output_size - else: - assert hp.output_size == output_size - - -class FeedForwardExpert(object): - """An object representing a feed forward network (used as an expert). - """ - - def __init__(self, hp, name): - """Creates a FeedForwardExpert. - - Args: - hp: hyperparameters. Call FeedForwardExpertParams() to create these. - name: a string. - """ - self._hp = hp - hidden_layer_sizes = hp.hidden_layer_sizes or [] - num_layers = 1 + len(hidden_layer_sizes) - layer_sizes = [hp.input_size] + hidden_layer_sizes + [hp.output_size] - self._layer_sizes = layer_sizes - self._w = [] - for layer in range(num_layers): - shape = layer_sizes[layer:layer + 2] - self._w.append( - tf.get_variable('%s_layer_%d' % (name, layer), shape, hp.dtype, - hp.initializer)) - - def Eval(self, x): - """Evaluate the FeedForwardExpert on the given input. - - Args: - x: a `Tensor` of shape `[batch_size, hp.input_size]` - - Returns: - a `Tensor` of shape `[batch_size, hp.output_size]` - """ - hp = self._hp - num_layers = len(self._w) - for i in xrange(num_layers): - x = tf.matmul(x, self._w[i]) - if hp.autoscale and self._layer_sizes[i] != hp.input_size: - x *= (self._layer_sizes[i] / hp.input_size)**-0.5 - if i + 1 < num_layers and hp.hidden_activation: - x = hp.hidden_activation(x) - return x - - @property - def vars(self): - return self._w - - @function.Defun( python_grad_func=lambda x, dy: tf.convert_to_tensor(dy), shape_func=lambda op: [op.inputs[0].get_shape()]) -def ConvertGradientToTensor(x): +def convert_gradient_to_tensor(x): """Identity operation whose gradient is converted to a `Tensor`. Currently, the gradient to `tf.concat` is particularly expensive to @@ -159,7 +47,7 @@ def ConvertGradientToTensor(x): the output of the `tf.concat` is eventually passed to `tf.gather`. It is sometimes faster to convert the gradient to a `Tensor`, so as to get the cheaper gradient for `tf.concat`. To do this, replace - `tf.concat(x)` with `ConvertGradientToTensor(tf.concat(x))`. + `tf.concat(x)` with `convert_gradient_to_tensor(tf.concat(x))`. Args: x: A `Tensor`. @@ -196,7 +84,7 @@ def __init__(self, """Create a Parallelism. Args: - device_names_or_functions: A list of of length n, containing device names + device_names_or_functions: A list of length n, containing device names or device functions (see `tf.device`) reuse: True or None. Whether to reuse variables created in the first replica in the subsequent replicas. @@ -212,7 +100,7 @@ def __init__(self, self._devices = device_names_or_functions self._n = len(device_names_or_functions) self._reuse = reuse - self._caching_devices = self._MaybeRepeat(caching_devices) + self._caching_devices = self._maybe_repeat(caching_devices) self._daisy_chain_variables = daisy_chain_variables def __call__(self, fn, *args, **kwargs): @@ -231,24 +119,25 @@ def __call__(self, fn, *args, **kwargs): """ # Construct lists or args and kwargs for each function. 
if args: - my_args = TransposeListOfLists([self._MaybeRepeat(arg) for arg in args]) + my_args = transpose_list_of_lists( + [self._maybe_repeat(arg) for arg in args]) else: my_args = [[] for _ in xrange(self.n)] my_kwargs = [{} for _ in xrange(self.n)] for k, v in six.iteritems(kwargs): - vals = self._MaybeRepeat(v) + vals = self._maybe_repeat(v) for i in xrange(self.n): my_kwargs[i][k] = vals[i] # Construct lists of functions. - fns = self._MaybeRepeat(fn) + fns = self._maybe_repeat(fn) # Now make the parallel call. outputs = [] cache = {} for i in xrange(self.n): - def DaisyChainGetter(getter, name, *args, **kwargs): + def daisy_chain_getter(getter, name, *args, **kwargs): """Get a variable and cache in a daisy chain.""" device_var_key = (self._devices[i], name) if device_var_key in cache: @@ -268,7 +157,7 @@ def DaisyChainGetter(getter, name, *args, **kwargs): # Variable scope will not reset caching_device on reused variables, # so we make a custom getter that uses identity to cache the variable. # pylint: disable=cell-var-from-loop - def CachingGetter(getter, name, *args, **kwargs): + def caching_getter(getter, name, *args, **kwargs): v = getter(name, *args, **kwargs) key = (self._caching_devices[i], name) if key in cache: @@ -279,15 +168,15 @@ def CachingGetter(getter, name, *args, **kwargs): return ret if self._daisy_chain_variables: - custom_getter = DaisyChainGetter - elif self._caching_devices: - custom_getter = CachingGetter + custom_getter = daisy_chain_getter + elif self._caching_devices[i]: + custom_getter = caching_getter else: custom_getter = None # pylint: enable=cell-var-from-loop - with tf.name_scope('parallel_%d' % i): + with tf.name_scope("parallel_%d" % i): with tf.variable_scope( - tf.get_variable_scope(), + tf.get_variable_scope() if self._reuse else "parallel_%d" % i, reuse=True if i > 0 and self._reuse else None, caching_device=self._caching_devices[i], custom_getter=custom_getter): @@ -306,7 +195,7 @@ def n(self): def devices(self): return self._devices - def _MaybeRepeat(self, x): + def _maybe_repeat(self, x): """Utility function for processing arguments that are singletons or lists. Args: @@ -322,25 +211,7 @@ def _MaybeRepeat(self, x): return [x] * self.n -def Parallel(device_names_or_functions, fn, *args): - """Deprecated interface. - - Use `Parallelism(device_names_or_functions)(fn, *args)` instead. - - Args: - device_names_or_functions: A list of length n. - fn: a function or a list of n functions. - *args: additional args. Each arg should either be not a list, or a list - of length n. - - Returns: - either a single list of length n (if fn does not return a tuple), or a - tuple of lists of length n (if fn returns a tuple). - """ - return Parallelism(device_names_or_functions)(fn, *args) - - -def _RowwiseUnsortedSegmentSum(values, indices, n): +def _rowwise_unsorted_segment_sum(values, indices, n): """UnsortedSegmentSum on each row. Args: @@ -357,7 +228,7 @@ def _RowwiseUnsortedSegmentSum(values, indices, n): return tf.reshape(ret_flat, [batch, n]) -def _NormalDistributionCDF(x, stddev): +def _normal_distribution_cdf(x, stddev): """Evaluates the CDF of the normal distribution. 
Normal distribution with mean 0 and standard deviation stddev, @@ -376,7 +247,8 @@ def _NormalDistributionCDF(x, stddev): return 0.5 * (1.0 + tf.erf(x / (math.sqrt(2) * stddev + 1e-20))) -def _ProbInTopK(clean_values, noisy_values, noise_stddev, noisy_top_values, k): +def _prob_in_top_k( + clean_values, noisy_values, noise_stddev, noisy_top_values, k): """Helper function to NoisyTopKGating. Computes the probability that value is in top k, given different random noise. @@ -393,7 +265,7 @@ def _ProbInTopK(clean_values, noisy_values, noise_stddev, noisy_top_values, k): normally distributed noise with standard deviation noise_stddev. noise_stddev: a `Tensor` of shape [batch, n], or None noisy_top_values: a `Tensor` of shape [batch, m]. - 'values' Output of tf.top_k(noisy_top_values, m). m >= k+1 + "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 k: an integer. Returns: @@ -415,15 +287,15 @@ def _ProbInTopK(clean_values, noisy_values, noise_stddev, noisy_top_values, k): threshold_if_out = tf.expand_dims( tf.gather(top_values_flat, threshold_positions_if_out), 1) # is each value currently in the top k. - prob_if_in = _NormalDistributionCDF(clean_values - threshold_if_in, - noise_stddev) - prob_if_out = _NormalDistributionCDF(clean_values - threshold_if_out, - noise_stddev) + prob_if_in = _normal_distribution_cdf(clean_values - threshold_if_in, + noise_stddev) + prob_if_out = _normal_distribution_cdf(clean_values - threshold_if_out, + noise_stddev) prob = tf.where(is_in, prob_if_in, prob_if_out) return prob -def CVSquared(x): +def cv_squared(x): """The squared coefficient of variation of a sample. Useful as a loss to encourage a positive distribution to be more uniform. @@ -443,33 +315,7 @@ def CVSquared(x): return variance / (tf.square(mean) + epsilon) -def MaxOverload(load): - """The load of the hardest-hit device relative to average. - - This is useful for monitoring the performance of MoEs. - - The load of an expert is the number of examples assigned to that expert. - The load of a device is the sum of the loads of all experts on that device. - - The input to this function is generally the 'load' output of - DistributedMixtureOfExperts.Eval(), which is either a 1d or 2d `Tensor` of - per-expert loads. In either case, the fist dimension corresponds to devices. - - This function sums over all dimensions other than dimension zero, then - computes the ratio of the maxmium value to the mean value. - - Args: - load: a 1d or 2d `Tensor`. - - Returns: - a `Scalar`. - """ - per_device_load = tf.reduce_sum(tf.reshape(load, [tf.shape(load)[0], -1]), 1) - return (tf.reduce_max(per_device_load) / - (tf.reduce_mean(per_device_load) + 1e-10)) - - -def _GatesToLoad(gates): +def _gates_to_load(gates): """Compute the true load per expert, given the gates. The load is the number of examples for which the corresponding gate is >0. @@ -482,11 +328,16 @@ def _GatesToLoad(gates): return tf.reduce_sum(tf.to_float(gates > 0), 0) -def _MyTopK(x, k): +def _my_top_k(x, k): """GPU-compatible version of top-k that works for very small constant k. Calls argmax repeatedly. + tf.nn.top_k is implemented for GPU, but the gradient, sparse_to_dense, + seems not to be, so if we use tf.nn.top_k, then both the top_k and its + gradient go on cpu. Once this is not an issue, this function becomes + obselete and should be replaced by tf.nn.top_k. + Args: x: a 2d Tensor. k: a small integer. 
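The repeated-argmax trick used by _my_top_k above is easier to see outside TensorFlow. Below is a rough NumPy equivalent with the same contract (an illustrative sketch, not the patch's implementation; the TF version instead masks each winner by adding a large negative one-hot):

import numpy as np

def my_top_k(x, k):
  """Row-wise top-k via k successive argmax calls; sensible for tiny k."""
  x = np.array(x, dtype=np.float64, copy=True)
  rows = np.arange(x.shape[0])
  values, indices = [], []
  for _ in range(k):
    idx = x.argmax(axis=1)       # best remaining column in each row
    values.append(x[rows, idx])
    indices.append(idx)
    x[rows, idx] = -np.inf       # mask the winner so argmax moves on
  return np.stack(values, axis=1), np.stack(indices, axis=1)

# my_top_k([[3., 1., 2.]], k=2) -> (array([[3., 2.]]), array([[0, 2]]))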
@@ -509,374 +360,72 @@ def _MyTopK(x, k): return tf.stack(values, axis=1), tf.to_int32(tf.stack(indices, axis=1)) -class NoisyTopKGating(object): - """Noisy top-k gating network. +def noisy_top_k_gating(x, + input_size, + num_experts, + train, + k=2, + initializer=tf.zeros_initializer(), + noisy_gating=True, + noise_epsilon=1e-2, + name=None): + """Noisy top-k gating. See paper: https://arxiv.org/abs/1701.06538. - """ - - def __init__(self, hp, name): - """Create a NoisyTopKGating network. - - Args: - hp: a hyperparameters created by NoisyTopKGatingParams() - name: a string - """ - self._vars = [] - self._hp = hp - self._w_gate = tf.get_variable('%s_gate' % name, - [hp.input_size, - hp.num_experts], hp.dtype, hp.initializer) - self._vars.append(self._w_gate) - if hp.noisy_gating: - self._w_noise = tf.get_variable('%s_noise' % name, - [hp.input_size, hp.num_experts], hp.dtype, - hp.initializer) - self._vars.append(self._w_noise) - - def Eval(self, x, train=True, summaries=False): - """Compute noisy top-k gating. - - Args: - x: a `Tensor` of shape `[batch_size, input_size]`. - train: a boolean `Scalar`. Setting this to false turns off noise. - summaries: a boolean. Whether to add summaries. - Returns: - gates: a `Tensor` of shape `[batch_size, n]` - load: a `Tensor` of shape `[n]`. - If we are using noise, this is a smooth approximation of the load, - and you can define a loss in terms of it to help with load-balancing. - """ - with tf.variable_scope('NoisyTopKGating'): - hp = self._hp - clean_logits = tf.matmul(x, self._w_gate) - if hp.noisy_gating: - raw_noise_stddev = tf.matmul(x, self._w_noise) - noise_stddev = ((tf.nn.softplus(raw_noise_stddev) + hp.noise_epsilon) * - (tf.to_float(train))) - noisy_logits = clean_logits + ( - tf.random_normal(tf.shape(clean_logits)) * noise_stddev) - logits = noisy_logits - if summaries: - tf.summary.histogram('noisy_logits', noisy_logits) - tf.summary.histogram('noise_stddev', noise_stddev) - else: - logits = clean_logits - top_logits, top_indices = _MyTopK(logits, min(hp.k + 1, hp.num_experts)) - top_k_logits = tf.slice(top_logits, [0, 0], [-1, hp.k]) - top_k_indices = tf.slice(top_indices, [0, 0], [-1, hp.k]) - top_k_gates = tf.nn.softmax(top_k_logits) - # This will be a `Tensor` of shape `[batch_size, n]`, with zeros in the - # positions corresponding to all but the top k experts per example. - gates = _RowwiseUnsortedSegmentSum(top_k_gates, top_k_indices, - hp.num_experts) - if hp.noisy_gating and hp.k < hp.num_experts: - load = tf.reduce_sum( - _ProbInTopK(clean_logits, noisy_logits, noise_stddev, top_logits, - hp.k), 0) - else: - load = _GatesToLoad(gates) - if summaries: - tf.summary.histogram('importance', tf.reduce_sum(gates, 0)) - tf.summary.histogram('load', load) - return gates, load - - @property - def vars(self): - return self._vars - - -class LocalMixtureOfExperts(object): - """A MoE on a single device. - """ - - def __init__(self, gating_hp, expert_hp, input_size, output_size, name): - """Create a LocalMixtureOfExperts. - - Args: - gating_hp: hyperparameters for the gating network. - e.g. NoisyTopKGatingParams() - expert_hp: hyperparameters for the expert networks. - e.g. FeedForwardExpertParams() - input_size: an integer. - output_size: an integer. - name: a string. 
- """ - self._name = name - _SetInputOutputSizes(gating_hp, input_size, None) - _SetInputOutputSizes(expert_hp, input_size, output_size) - self._gating_hp = gating_hp - self._gating = gating_hp.gating_class(gating_hp, name + '_gating') - self._expert_hp = expert_hp - self._experts = [ - expert_hp.expert_class(expert_hp, name + '_%d' % i) - for i in xrange(gating_hp.num_experts) - ] - - def Eval(self, - x, - train=True, - per_example_multiplier=None, - summaries=False, - identifiers=None): - """Evaluate mixture of experts. - - We provide a convenient debugging tool for determining the set of examples - that we passed to each expert. The caller may provide a `Tensor` of - "identifiers", of any type whose first dimension matches the number of - input examples. The function will then return a list - "expert_to_identifiers", with one `Tensor` for each expert containing the - identifiers for all examples assigned to that expert. A parallel list of - `Tensor`s, "expert_to_gates", is also returned, containing the - corresponding gate values. - - Args: - x: a `Tensor` of shape `[batch_size, input_size]` - train: a boolean Scalar. Are we in training mode? - per_example_multiplier: an optional `Tensor` of shape `[batch_size]` which - gets multiplied into the gate values. If this LocalMixtureOfExperts - represents one secondary MoE in a hierarchical MoE, then we pass in - in the gate values from the primary gating function here. This causes - the computed values (`y`, `importance` and `expert_to_gates`) to also - reflect the primary gate values. - summaries: an boolean. Enable summaries. - identifiers: an optional `Tensor` whose first dimension is equal to - batch_size. - - Returns: - y: a `Tensor` of shape `[batch_size, output_size]`. Output of the MoE. - importance: a `Tensor` of shape `[n]`. Batchwise sum of gates. - load: a `Tensor` of shape `[n]`. Smooth estimator of the number of - examples passed to each expert. This is useful for load-balancing, - as any gradient on this `Tensor` will back-propagate to the gating - network. - expert_to_identifiers: if `identifiers` was passed in, a list of - length `num_experts`. Each element is a `Tensor` whose shape matches - that of `identifiers` in all but the first dimension. Contains the - slices of `identifiers` corresponding to the batch elements that were - dispatched to that expert. - expert_to_gates: A list of length `num_experts`. Each element contains - a 1-dimensional tensor - """ - gating_hp = self._gating_hp - gates, load = self._gating.Eval(x, train, summaries) - if per_example_multiplier is not None: - gates *= tf.expand_dims(per_example_multiplier, 1) - dispatcher = SparseDispatcher(gating_hp.num_experts, gates) - expert_input = dispatcher.Dispatch(x) - expert_output = [ - self._experts[i].Eval(expert_input[i]) - for i in xrange(gating_hp.num_experts) - ] - y = dispatcher.Combine(expert_output) - if identifiers is not None: - expert_to_identifiers = dispatcher.Dispatch(identifiers) - else: - expert_to_identifiers = None - return (y, tf.reduce_sum(gates, 0), load, expert_to_identifiers, - dispatcher.ExpertToGates()) - - @property - def vars(self): - ret = [] - for x in self._experts: - ret.extend(x.vars) - ret.extend(self._gating.vars) - return ret - - -class DistributedMixtureOfExperts(object): - """Distributed (optionally Hierarchical) Mixture of Experts. - - This class implements the scheme described in our paper. - See link at the top of this file. - - The model is trained synchronously using one large TF graph using - multiple devices. 
- The conventional (non-MoE) layers use data-parallelism, with each device - processing a subset of the training batch. We call these datashards. + Args: + x: input Tensor with shape [batch_size, input_size] + input_size: an integer + num_experts: an integer + train: a boolean - we only add noise at training time. + k: an integer - number of experts per example + initializer: an initializer + noisy_gating: a boolean + noise_epsilon: a float + name: an optional string - The MoE layer (this object) uses model parallelism. Each expert is assigned - to a particular device, which hosts the expert parameters and performs the - expert computation for all examples assigned to that expert. In the case - of a hierarchical MoE, each second-level MoE is assigned to a device. + Returns: + gates: a Tensor with shape [batch_size, num_experts] + load: a Tensor with shape [num_experts] """ - - def __init__(self, primary_gating_hp, secondary_gating_hp, expert_hp, - input_size, output_size, expert_devices, name): - """Create a DistributedMixtureOfExperts. - - If `secondary_gating_hp` is `None`, then this is a flat MoE with - `primary_gating_hp.num_experts` experts. Otherwise, this is a hierarchical - MoE with `primary_gating_hp.num_experts` groups of - `secondary_gating_hp.num_experts` experts. - - The assignemnt of experts (or groups of experts) to devices is by - round-robin. So to make equal use of all the devices, one should set - `primary_gating_hp.num_experts` to the number of devices or a multiple - thereof. - - Args: - primary_gating_hp: hyperparameters for the primary gating network. - e.g. NoisyTopKGatingParams(). - secondary_gating_hp: hyperparameters for the secondary gating network. - e.g. NoisyTopKGatingParams(). None indicates a flat MoE. - expert_hp: hyperparameters for the expert networks. - e.g. FeedForwardExpertParams() - input_size: an integer. - output_size: an integer. - expert_devices: a list of device strings. The devices to be used for - the experts. - name: a string. - """ - self._name = name - # fill in the missing values in the hyperparameters - _SetInputOutputSizes(primary_gating_hp, input_size, None) - _SetInputOutputSizes(expert_hp, input_size, output_size) - self._is_hierarchical = secondary_gating_hp is not None - self._primary_gating_hp = primary_gating_hp - self._primary_gating = primary_gating_hp.gating_class( - primary_gating_hp, name + '_primary_gating') - n1 = self._primary_gating_hp.num_experts - # round robin assignment of experts to devices. 
- expert_devices = [ - expert_devices[i % len(expert_devices)] for i in xrange(n1) - ] - self._expert_devices = expert_devices - self._all_vars = [] - self._all_vars.extend(self._primary_gating.vars) - if self._is_hierarchical: - # hierarchical MoE - self._secondary_moe = [] - for i in xrange(n1): - with tf.device(expert_devices[i]): - secondary_moe = LocalMixtureOfExperts(secondary_gating_hp, expert_hp, - input_size, output_size, - '%s_secondary_%d' % (name, i)) - self._secondary_moe.append(secondary_moe) - self._all_vars.extend(secondary_moe.vars) + with tf.variable_scope(name, default_name="noisy_top_k_gating"): + w_gate = tf.get_variable( + "w_gate", [input_size, num_experts], tf.float32, initializer) + if noisy_gating: + w_noise = tf.get_variable("w_noise", + [input_size, num_experts], tf.float32, + initializer) + clean_logits = tf.matmul(x, w_gate) + if noisy_gating: + raw_noise_stddev = tf.matmul(x, w_noise) + noise_stddev = ((tf.nn.softplus(raw_noise_stddev) + noise_epsilon) * + (tf.to_float(train))) + noisy_logits = clean_logits + ( + tf.random_normal(tf.shape(clean_logits)) * noise_stddev) + logits = noisy_logits + if not tf.get_variable_scope().reuse: + tf.summary.histogram("noisy_logits", noisy_logits) + tf.summary.histogram("noise_stddev", noise_stddev) else: - # flat MoE - self._experts = [] - for i in xrange(n1): - with tf.device(expert_devices[i]): - expert = expert_hp.expert_class(expert_hp, name + '_%d' % i) - self._experts.append(expert) - self._all_vars.extend(expert.vars) - - def Eval(self, - datashard_devices, - xs, - train=True, - summaries=False, - identifiers=None, - shadow_xs=None): - """Evaluate MoE on given inputs. - - This class is designed for the case where the rest of the model is using - data parallelism. We receive an array of input `Tensor`s, one per - datashard, and we produce a list of output Tensors, one per datashard. - - We provide a convenient debugging tool for determining the set of examples - that we passed to each expert. The caller may provide a `Tensor` of - "identifiers", of any type whose first dimension matches the number of - input examples. The function will then return a list - "expert_to_identifiers", with one `Tensor` for each expert containing the - identifiers for all examples assigned to that expert. A parallel list of - `Tensor`s, "expert_to_gates", is also returned, containing the - corresponding gate values. - - Args: - datashard_devices: a `list` of device strings of length `num_datashards`. - Which devices to use for the output tensors. - xs: A `list` of `Tensor`s of length `num_datashards`. Each has shape - `[batch_size[d], input_size]. - train: a boolean `Scalar`. When train=`True`, noise is added to the - gating function. - summaries: a boolean. Whether to write summaries. - identifiers: an optional list of tensors. - Each tensor has shape [, extra_dims] - shadow_xs: Optional `list` of `Tensor`s of length `num_datashards`. Each - has shape `[batch_size[d], input_size]. Shadow_xs is useful if you want - to dispatch a transformed version of xs to the experts, but you want - untransformed xs for the gating network. - - Returns: - ys: the output (a list of one tensor per datashard). Each has shape - `[batch_size[d], output_size]. - importance: a `Tensor` of shape `[n]` for a flat MoE or `[n1, n2]` for a - hierarchical MoE. Batchwise sum of gates. - load: a `Tensor` of shape `[n]` for a flat MoE or `[n1, n2]` for a - hierarchical MoE. Smooth estimator of the number of - examples passed to each expert. 
This is useful for load-balancing, - as any gradient on this `Tensor` will back-propagate to the gating - network. - expert_to_identifiers: if `identifiers` was passed in, a list of - length `num_experts`. Each element is a `Tensor` whose shape matches - that of `identifiers` in all but the first dimension. Contains the - slices of `identifiers` corresponding to the batch elements that were - dispatched to that expert. - expert_to_gates: a list of one tensor per expert. - Each tensor has shape [] - - """ - n1 = self._primary_gating_hp.num_experts - epsilon = 1e-10 - assert len(datashard_devices) == len(xs) - num_datashards = len(xs) - expert_devices = self._expert_devices - has_identifiers = identifiers is not None - # pylint: disable=unbalanced-tuple-unpacking - primary_gates, primary_smooth_load = Parallel( - datashard_devices, self._primary_gating.Eval, xs, train, - [summaries] + [False] * (num_datashards - 1)) - primary_importance = tf.add_n( - Parallel(datashard_devices, tf.reduce_sum, primary_gates, 0)) - primary_smooth_load = tf.add_n(primary_smooth_load) - primary_true_load = tf.add_n( - Parallel(datashard_devices, _GatesToLoad, primary_gates)) - primary_dispatcher = DistributedSparseDispatcher( - datashard_devices, expert_devices, primary_gates) - - if shadow_xs is None: - secondary_input = primary_dispatcher.Dispatch(xs) + logits = clean_logits + top_logits, top_indices = _my_top_k(logits, min(k + 1, num_experts)) + top_k_logits = tf.slice(top_logits, [0, 0], [-1, k]) + top_k_indices = tf.slice(top_indices, [0, 0], [-1, k]) + top_k_gates = tf.nn.softmax(top_k_logits) + # This will be a `Tensor` of shape `[batch_size, n]`, with zeros in the + # positions corresponding to all but the top k experts per example. + gates = _rowwise_unsorted_segment_sum(top_k_gates, top_k_indices, + num_experts) + if noisy_gating and k < num_experts: + load = tf.reduce_sum( + _prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits, + k), 0) else: - secondary_input = primary_dispatcher.Dispatch(shadow_xs) - - primary_expert_to_identifiers = (primary_dispatcher.Dispatch(identifiers) - if has_identifiers else None) - primary_expert_to_gates = primary_dispatcher.ExpertToGates() - if not self._is_hierarchical: - # one-level distributed mixture of experts - secondary_output = Parallel(expert_devices, lambda a, b: a.Eval(b), - self._experts, secondary_input) - ys = primary_dispatcher.Combine(secondary_output) - return (ys, primary_importance, primary_smooth_load, - primary_expert_to_identifiers, primary_expert_to_gates) - # two-level hierarchical MoE - (secondary_output, secondary_importance, secondary_load, - secondary_expert_to_identifiers, secondary_expert_to_gates) = (Parallel( - expert_devices, [m.Eval for m in self._secondary_moe], secondary_input, - train, primary_expert_to_gates, [summaries] + [False] * (n1 - 1), - primary_expert_to_identifiers)) - # pylint: enable=unbalanced-tuple-unpacking - ys = primary_dispatcher.Combine(secondary_output, multiply_by_gates=False) - importance = tf.stack(secondary_importance) - load = tf.stack(secondary_load) * tf.expand_dims(primary_smooth_load / ( - primary_true_load + epsilon), 1) - expert_to_identifiers = [] - if identifiers is not None: - for el in secondary_expert_to_identifiers: - expert_to_identifiers.extend(el) - expert_to_gates = [] - for el in secondary_expert_to_gates: - expert_to_gates.extend(el) - return (ys, importance, load, expert_to_identifiers, expert_to_gates) - - @property - def vars(self): - return self._all_vars + load = 
_gates_to_load(gates) + if not tf.get_variable_scope().reuse: + tf.summary.histogram("importance", tf.reduce_sum(gates, 0)) + tf.summary.histogram("load", load) + return gates, load class SparseDispatcher(object): @@ -889,9 +438,9 @@ class SparseDispatcher(object): experts: a list of length `num_experts` containing sub-networks. dispatcher = SparseDispatcher(num_experts, gates) - expert_inputs = dispatcher.Dispatch(inputs) + expert_inputs = dispatcher.dispatch(inputs) expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] - outputs = dispatcher.Combine(expert_outputs) + outputs = dispatcher.combine(expert_outputs) The preceding code sets the output for a particular example b to: output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) @@ -920,14 +469,14 @@ def __init__(self, num_experts, gates): tf.reshape(self._gates, [-1]), self._batch_index * num_experts + self._expert_index) - def Dispatch(self, inp): + def dispatch(self, inp): """Create one input Tensor for each expert. The `Tensor` for a expert `i` contains the slices of `inp` corresponding to the batch elements `b` where `gates[b, i] > 0`. Args: - inp: a `Tensor` of shape '[batch_size, ]` + inp: a `Tensor` of shape "[batch_size, ]` Returns: a list of `num_experts` `Tensor`s with shapes `[expert_batch_size_i, ]`. @@ -935,7 +484,7 @@ def Dispatch(self, inp): inp = tf.gather(inp, self._batch_index) return tf.split(inp, self._part_sizes_tensor, 0) - def Combine(self, expert_out, multiply_by_gates=True): + def combine(self, expert_out, multiply_by_gates=True): """Sum together the expert output, weighted by the gates. The slice corresponding to a particular batch element `b` is computed @@ -951,15 +500,15 @@ def Combine(self, expert_out, multiply_by_gates=True): Returns: a `Tensor` with shape `[batch_size, ]`. """ - # see comments on ConvertGradientToTensor - stitched = ConvertGradientToTensor(tf.concat(expert_out, 0)) + # see comments on convert_gradient_to_tensor + stitched = convert_gradient_to_tensor(tf.concat(expert_out, 0)) if multiply_by_gates: stitched *= tf.expand_dims(self._nonzero_gates, 1) combined = tf.unsorted_segment_sum(stitched, self._batch_index, tf.shape(self._gates)[0]) return combined - def ExpertToGates(self): + def expert_to_gates(self): """Gate values corresponding to the examples in the per-expert `Tensor`s. Returns: @@ -985,28 +534,25 @@ class DistributedSparseDispatcher(object): `Tensor`s are created on those devices. There is no single-device bottleneck. """ - def __init__(self, datashard_devices, expert_devices, gates): + def __init__(self, data_parallelism, expert_parallelism, gates): """Create a DistributedSparseDispatcher. Args: - datashard_devices: a list of num_datashards device strings. - expert_devices: a list of num_experts device strings. - gates: a list of num_datashards `Tensor`s of shapes + data_parallelism: a Parallelism object. + expert_parallelism: a Parallelism object. + gates: a list of datashard_parallelism.n `Tensor`s of shapes `[batch_size[d], num_experts]`. 
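To make the gates/dispatch/combine contract above concrete, here is a hand-worked sketch in plain NumPy; the gate values are hypothetical, and the real SparseDispatcher does the same bookkeeping with tf.gather, tf.split and tf.unsorted_segment_sum:

import numpy as np

# Output of a top-1/top-2 gating step: 3 examples, 2 experts.
gates = np.array([[0.7, 0.3],   # example 0 -> experts 0 and 1
                  [0.0, 1.0],   # example 1 -> expert 1 only
                  [1.0, 0.0]])  # example 2 -> expert 0 only
inputs = np.array([[1.], [2.], [3.]])
experts = [lambda x: 2. * x, lambda x: -1. * x]  # stand-in expert networks

# dispatch(): expert i receives exactly the rows b with gates[b, i] > 0.
expert_inputs = [inputs[gates[:, i] > 0] for i in range(2)]
expert_outputs = [experts[i](expert_inputs[i]) for i in range(2)]

# combine(): output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])).
combined = np.zeros_like(inputs)
for i in range(2):
  rows = np.nonzero(gates[:, i] > 0)[0]
  combined[rows] += gates[rows, i:i + 1] * expert_outputs[i]
print(combined)  # [[1.1], [-2.], [6.]]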
Returns: a DistributedSparseDispatcher """ self._gates = gates - self._num_experts = len(expert_devices) - assert len(gates) == len(datashard_devices) - self._num_datashards = len(gates) - self._datashard_devices = datashard_devices - self._expert_devices = expert_devices - self._dispatchers = Parallel(self._datashard_devices, SparseDispatcher, - self._num_experts, gates) - - def Dispatch(self, inp): + self._dp = data_parallelism + self._ep = expert_parallelism + assert len(gates) == self._dp.n + self._dispatchers = self._dp(SparseDispatcher, self._ep.n, gates) + + def dispatch(self, inp): """Create one input Tensor for each expert. Args: @@ -1016,16 +562,14 @@ def Dispatch(self, inp): a list of `num_experts` `Tensor`s with shapes `[num_examples[i], ]`. """ - dispatched = Parallel(self._datashard_devices, lambda a, b: a.Dispatch(b), - self._dispatchers, inp) - ret = Parallel(self._expert_devices, tf.concat, - TransposeListOfLists(dispatched), 0) + dispatched = self._dp(lambda a, b: a.dispatch(b), self._dispatchers, inp) + ret = self._ep(tf.concat, transpose_list_of_lists(dispatched), 0) if ret[0].dtype == tf.float32: - # see comments on ConvertGradientToTensor - ret = Parallel(self._expert_devices, ConvertGradientToTensor, ret) + # see comments on convert_gradient_to_tensor + ret = self._ep(convert_gradient_to_tensor, ret) return ret - def Combine(self, expert_out, multiply_by_gates=True): + def combine(self, expert_out, multiply_by_gates=True): """Sum together the expert output, multiplied by the corresponding gates. Args: @@ -1038,40 +582,31 @@ def Combine(self, expert_out, multiply_by_gates=True): `[batch_size[d], ]`. """ expert_part_sizes = tf.unstack( - tf.stack([ - self._dispatchers[d].part_sizes - for d in xrange(self._num_datashards) - ]), - num=self._num_experts, + tf.stack([d.part_sizes for d in self._dispatchers]), + num=self._ep.n, axis=1) # list of lists of shape [num_experts][num_datashards] - expert_output_parts = Parallel(self._expert_devices, tf.split, expert_out, - expert_part_sizes) - expert_output_parts_t = TransposeListOfLists(expert_output_parts) - ret = [] - for d in xrange(self._num_datashards): - with tf.device(self._datashard_devices[d]): - ret.append(self._dispatchers[d].Combine( - # see comments on ConvertGradientToTensor - ConvertGradientToTensor(tf.concat(expert_output_parts_t[d], 0)), - multiply_by_gates=multiply_by_gates)) - return ret - - def ExpertToGates(self): + expert_output_parts = self._ep(tf.split, expert_out, expert_part_sizes) + expert_output_parts_t = transpose_list_of_lists(expert_output_parts) + def my_combine(dispatcher, parts): + return dispatcher.combine( + convert_gradient_to_tensor(tf.concat(parts, 0)), + multiply_by_gates=multiply_by_gates) + return self._dp(my_combine, self._dispatchers, expert_output_parts_t) + + def expert_to_gates(self): """Gate values corresponding to the examples in the per-expert `Tensor`s. Returns: a list of `num_experts` one-dimensional `Tensor`s of type `tf.float32`. """ - return Parallel(self._expert_devices, tf.concat, - TransposeListOfLists( - Parallel(self._datashard_devices, [ - self._dispatchers[d].ExpertToGates - for d in xrange(self._num_datashards) - ])), 0) + return self._ep( + tf.concat, + transpose_list_of_lists( + self._dp(lambda d: d.expert_to_gates(), self._dispatchers)), 0) -def TransposeListOfLists(lol): +def transpose_list_of_lists(lol): """Transpose a list of equally-sized python lists. 
Args: @@ -1079,205 +614,110 @@ def TransposeListOfLists(lol): Returns: a list of lists """ - assert lol, 'cannot pass the empty list' + assert lol, "cannot pass the empty list" return [list(x) for x in zip(*lol)] -class DistributedSingleDispatcher(object): - """Dispatches to experts according to gates. - - Each example goes to one expert. - - Unlike SparseDispatcher, the gates are one-dimensional `Tensor`s of integer - expert ids. There are no weights. - """ +def ffn_expert_fn(input_size, + hidden_sizes, + output_size, + hidden_activation=tf.nn.relu): + """Returns a function that creates a feed-forward network. - def __init__(self, data_parallelism, model_parallelism, gates): - """Constructs a Dispatcher. - - Args: - data_parallelism: a Parallelism object. - model_parallelism: a Parallelism object. - gates: a list of 1d integer `Tensor`s, one per datashard. - Says which expert to use for each batch element. - - Returns: - a DistributedSingleDispatcher - """ - gates = data_parallelism(tf.to_int32, gates) - self._gates = gates - self._data_parallelism = data_parallelism - self._model_parallelism = model_parallelism - - # Compute the sizes number of examples going from each datashard to each - # expert. - def _PartSizes(gates): - return tf.unsorted_segment_sum( - tf.ones_like(gates), gates, model_parallelism.n) - - part_sizes_by_datashard = data_parallelism(_PartSizes, gates) - self._part_sizes_by_expert = tf.unstack( - tf.stack(part_sizes_by_datashard), num=model_parallelism.n, axis=1) - - # These indices will be used to combine the output on the datashards. - def _StitchIndices(gates): - return tf.dynamic_partition( - tf.range(tf.size(gates)), gates, model_parallelism.n) - - self._stitch_indices = data_parallelism(_StitchIndices, gates) - - def Dispatch(self, d_tensors): - """Reshuffles input `Tensor`s to produce output `Tensor`s. - - The dimensions of all input and output `Tensor`s match, except for - dimension 0. In dimension 0, the input `Tensor`s match the corresponding - `gates` `Tensor`s which were passed to the constructor. - - Args: - d_tensors: a list of `Tensor`s, one per datashard. - - Returns: - a list of `Tensor`s, one per expert. - - """ - parts = self._data_parallelism(tf.dynamic_partition, d_tensors, self._gates, - self._model_parallelism.n) - parts_by_expert = TransposeListOfLists(parts) - x_tensors = self._model_parallelism(tf.concat, parts_by_expert, 0) - return x_tensors - - def Combine(self, x_tensors): - """Reshuffles per-expert `Tensor`s to produce per-datashard `Tensor`s. - - Dispatch must have been called at least once first. - - The dimensions of all input and output `Tensor`s match, except for - dimension 0. In dimension 0, the input `Tensor`s match the corresponding - outputs of `Dispatch`, and the output `Tensor`s match the corresponding - `gates` `Tensor`s which were passed to the constructor. - - Args: - x_tensors: a list of `Tensor`s, one per expert. - - Returns: - a list of `Tensor`s, one per datashard. - """ - parts = self._model_parallelism(tf.split, x_tensors, - self._part_sizes_by_expert) - d_tensors = self._data_parallelism(tf.dynamic_stitch, self._stitch_indices, - TransposeListOfLists(parts)) - return d_tensors - - -def ParallelEmbeddingLookup(params, ids, data_parallelism): - """Mod-sharded embedding lookup with multiple datashards. - - TODO(noam): does this work when vocab_size is not a multiple of `num_shards`? + Use this function to create the expert_fn argument to distributed_moe. 
Args: - params: A list of `num_shards` `Tensors`, each with shapes - `[vocab_size / num_params, depth]`. - ids: A list of `num_datashards` one-dimensional ineger `Tensors`, - with shapes `[batch_size[i]]` - data_parallelism: A Parallelism object. + input_size: an integer + hidden_sizes: a list of integers + output_size: an integer + hidden_activation: a unary function. Returns: - a list of `num_datashards` `Tensors`, each with shape - `[batch_size[i], depth]`. + a unary function """ - param_devices = [x.device for x in params] - model_parallelism = Parallelism(param_devices) - num_shards = len(param_devices) - # pylint: disable=unbalanced-tuple-unpacking - ids, unique_idx = data_parallelism(tf.unique, ids) - # pylint: enable=unbalanced-tuple-unpacking - gates = data_parallelism(tf.mod, ids, num_shards) - ids_div = data_parallelism(tf.div, ids, num_shards) - dispatcher = DistributedSingleDispatcher(data_parallelism, model_parallelism, - gates) - x_ids_div = dispatcher.Dispatch(ids_div) - params = model_parallelism(ConvertGradientToTensor, params) - x_emb = model_parallelism(tf.gather, params, x_ids_div) - r_emb = dispatcher.Combine(x_emb) - r_emb = data_parallelism(tf.gather, r_emb, unique_idx) - return r_emb - - -def SampledSoftmaxLoss(features, sampler, num_classes, target_classes, - target_params, sampled_classes, sampled_params): - """Loss for training softmax classifiers on large label vocabulary. - - This function assumes that we have already chosen the sampled classes and - fetched the parameters for the target classes and the sampled classes. + def my_fn(x): + layer_sizes = [input_size] + hidden_sizes + [output_size] + for i in xrange(1 + len(hidden_sizes)): + w = tf.get_variable("w_%d" % i, layer_sizes[i:i+2], tf.float32) + x = tf.matmul(x, w) + if i < len(hidden_sizes): + x = hidden_activation(x) + if layer_sizes[i] != input_size: + x *= (layer_sizes[i] / float(input_size))**-0.5 + return x + return my_fn - Args: - features: a Tensor with shape [batch_size, hidden_size] - sampler: a candidate sampler object - num_classes: an integer - target_classes: an integer Tensor with shape [batch_size] - target_params: a Tensor with shape [batch_size, hidden_size] - The parameters corresponding to the target classes. - sampled_classes: an integer tensor with shape [num_sampled_classes] - sampled_params: a Tensor with shape [num_sampled_classes, hidden_size] - The parameters corresponding to the sampled classes. - Returns: - a Tensor with shape [batch_size] - """ - sampled_logits = (tf.matmul(features, sampled_params, transpose_b=True) - - sampler.log_expected_count(sampled_classes)) - target_logits = (tf.reduce_sum(target_params * features, 1) - - sampler.log_expected_count(target_classes)) - sampled_log_denominator = tf.reduce_logsumexp( - sampled_logits, [1], name='SampledLogDenominator') - sampled_classes_mask = tf.unsorted_segment_sum( - tf.fill(tf.shape(sampled_classes), float('-inf')), sampled_classes, - num_classes) - target_log_denominator = ( - target_logits + tf.gather(sampled_classes_mask, target_classes)) - combined_log_denominator = tf.reduce_logsumexp( - tf.stack([sampled_log_denominator, target_log_denominator]), [0]) - loss = combined_log_denominator - target_logits - return loss - - -def ParallelSampledSoftmaxLoss(params, - features, - target_classes, - sampler, - num_classes, - data_parallelism, - target_weights=None): - """Computes sampled softmax loss across many datashards. - - This is used during training to efficiently train a softmax classifier layer. 
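The (layer_sizes[i] / input_size)**-0.5 factor above rescales activations so that layers of different widths feed roughly comparable magnitudes forward. As a hedged sketch of how ffn_expert_fn composes with distributed_moe (defined just below); the Parallelism object, device strings, and sizes here are illustrative assumptions, not part of the patch:

# Sketch only: assumes the Parallelism class from expert_utils and the
# ffn_expert_fn / distributed_moe definitions in this file; devices and
# shapes below are made up.
import tensorflow as tf

input_size, hidden_size, output_size = 512, 2048, 512
expert_fn = ffn_expert_fn(input_size, [hidden_size], output_size)

dp = Parallelism(["/gpu:0", "/gpu:1"])       # two datashards (assumed)
xs = [tf.zeros([8, 10, input_size])] * dp.n  # one input Tensor per datashard

ys, moe_loss = distributed_moe(
    dp, ["/gpu:0", "/gpu:1"], xs, train=True,
    input_size=input_size, expert_fn=expert_fn,
    num_experts=4, k=2, loss_coef=1e-2)
# Each ys[d] matches xs[d] except that its last dimension is output_size;
# moe_loss should be added to the model's training loss.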
+def reshape_like(a, b): + """Reshapes a to match the shape of b in all but the last dimension.""" + ret = tf.reshape(a, tf.concat([tf.shape(b)[:-1], tf.shape(a)[-1:]], 0)) + ret.set_shape(b.get_shape().as_list()[:-1] + a.get_shape().as_list()[-1:]) + return ret + + +def distributed_moe(data_parallelism, + expert_devices, + xs, + train, + input_size, + expert_fn, + num_experts, + k=2, + loss_coef=1e-2, + name=None): + """Call a distributed mixture of experts. Args: - params: A list of num_param_shards Tensors, each with shape - [num_classes / num_param_shards, num_features]. - The parameters are assumed to be mod-sharded by class. - features: a list of num_datashards Tensors, each with shape - [batch_size_i, num_features] - target_classes: A list of num_datashards integer Tensors each with shape - [batch_size_i] - sampler: a candidate sampler object - num_classes: an Integer - data_parallelism: a Parallelism object - target_weights: an optional list of num_datashards Tensors each with - shape [batch_size_i] + data_parallelism: a expert_utils.Parallelism object. + expert_devices: a list of strings. We round-robin the experts across these + devices. + xs: a list of input tensors, each with shape [... , input_size] + train: a boolean scalar. + input_size: an integer (input size for this layer) + expert_fn: a unary function for each expert to run + It should take a Tensor with shape [batch_size, input_size] + and return a Tensor with shape [batch_size, output_size] + num_experts: an integer - number of experts + k: an integer - how many experts to use for each batch element + loss_coef: a scalar - multiplier on load-balancing losses + name: a string + Returns: - a Scalar. + ys: a list of tensors. Each Tensor has the same shape as the corresponding + Tensor in xs, except for the last dimension, which is output_size. + extra_training_loss: a scalar. This should be added into the overall + training loss of the model. The backpropagation of this loss + encourages all experts to be approximately equally used across a batch. """ - sampled_classes = data_parallelism(sampler.sample) - sampled_params = ParallelEmbeddingLookup(params, sampled_classes, - data_parallelism) - target_params = ParallelEmbeddingLookup(params, target_classes, - data_parallelism) - ret = data_parallelism(SampledSoftmaxLoss, features, sampler, num_classes, - target_classes, target_params, sampled_classes, - sampled_params) - if target_weights is not None: - ret = data_parallelism(tf.multiply, ret, target_weights) - ret = data_parallelism(tf.reduce_sum, ret) - ret = tf.add_n(ret) - return ret + dp = data_parallelism + # create a parallelism object for running the experts. + # We use the default of reuse=False. Otherwise, the experts would all + # use the same variables. + ep = Parallelism( + [expert_devices[i % len(expert_devices)] for i in xrange(num_experts)]) + # Experts expect 2d input tensors, so flatten the batch dimension and all + # spatial dimensions together. + xs_flat = dp(tf.reshape, xs, [[-1, input_size]] * dp.n) + with tf.variable_scope(name, default_name="moe"): + # The gates indicate which batch elements go to which tensors. + # load is a measure of approximately how many examples go to each expert + gates, load = dp(noisy_top_k_gating, + xs_flat, + input_size, + num_experts, + train, + k, + initializer=tf.zeros_initializer(), + noisy_gating=True, + noise_epsilon=1e-2) + # This magic object helps us shuffle data between datashards and experts. 
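+    # (Concretely: dispatch() re-shards the dp.n per-datashard tensors into
+    # ep.n per-expert tensors, one on each expert device; combine() routes
+    # the expert outputs back to their datashards, weighting each example's
+    # output by its gate value.)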
+    dispatcher = DistributedSparseDispatcher(dp, ep, gates)
+    expert_in = dispatcher.dispatch(xs_flat)
+    expert_out = ep(expert_fn, expert_in)
+    ys_flat = dispatcher.combine(expert_out)
+    ys = dp(reshape_like, ys_flat, xs)
+    # compute some load-balancing losses.
+    load = tf.add_n(load)
+    importance = tf.add_n(dp(tf.reduce_sum, gates, 0))
+    loss = loss_coef * (cv_squared(importance) + cv_squared(load))
+    return ys, loss

From 35416daf4af61361113b51218c4960f25f38bfb7 Mon Sep 17 00:00:00 2001
From: Alexander Ku
Date: Thu, 10 Aug 2017 10:23:50 -0700
Subject: [PATCH 3/7] adding function for local_attention_2d

PiperOrigin-RevId: 164869818
---
 tensor2tensor/layers/common_attention.py      | 105 ++++++++++++++++++
 tensor2tensor/layers/common_attention_test.py |  46 ++++++--
 2 files changed, 142 insertions(+), 9 deletions(-)

diff --git a/tensor2tensor/layers/common_attention.py b/tensor2tensor/layers/common_attention.py
index 2b1bd124f..4f1273163 100644
--- a/tensor2tensor/layers/common_attention.py
+++ b/tensor2tensor/layers/common_attention.py
@@ -541,6 +541,111 @@ def pad_l_and_r(x, pad_length):
   return output


+def local_attention_2d(q,
+                       k,
+                       v,
+                       block_length=128,
+                       filter_flange=100,
+                       name=None):
+  """Strided block local self-attention.
+
+  Args:
+    q: a Tensor with shape [batch, heads, h, w, depth_k]
+    k: a Tensor with shape [batch, heads, h, w, depth_k]
+    v: a Tensor with shape [batch, heads, h, w, depth_v]
+    block_length: an integer indicating the side length of each square block.
+    filter_flange: an integer indicating how much to look around each block.
+    name: an optional string
+
+  Returns:
+    a Tensor of shape [batch, heads, h, w, depth_v]
+  """
+  with tf.variable_scope(
+      name, default_name="local_self_attention_2d", values=[q, k, v]):
+    v_shape = tf.shape(v)
+    depth_v = tf.shape(v)[4]
+    batch_size = tf.shape(q)[0]
+    num_heads = tf.shape(q)[1]
+    original_length = tf.shape(q)[2] * tf.shape(q)[3]
+
+    def reshape_range(tensor, i, j, shape):
+      """Reshapes a tensor between dimensions i and j."""
+      target_shape = tf.concat(
+          [tf.shape(tensor)[:i], shape, tf.shape(tensor)[j:]],
+          axis=0)
+      return tf.reshape(tensor, target_shape)
+
+    def pad_to_multiple(x, d):
+      """Pads x on the right so its height and width are multiples of d."""
+      height_padding = -tf.shape(x)[1] % d
+      width_padding = -tf.shape(x)[2] % d
+      paddings = [[0, 0], [0, 0], [0, height_padding],
+                  [0, width_padding], [0, 0]]
+      return tf.pad(x, paddings)
+
+    def gather_indices(x, block_length, stride):
+      """Getting gather indices."""
+      # making an identity matrix kernel
+      kernel = tf.eye(block_length ** 2)
+      kernel = reshape_range(kernel, 0, 1, [block_length, block_length, 1])
+      # making indices [1, h, w, 1] to apply convs
+      indices = tf.range(0, tf.shape(x)[2] * tf.shape(x)[3], delta=1)
+      indices = tf.reshape(indices, [1, tf.shape(x)[2], tf.shape(x)[3], 1])
+      indices = tf.nn.conv2d(
+          tf.cast(indices, tf.float32),
+          kernel,
+          strides=[1, stride, stride, 1],
+          padding="VALID")
+      # making indices [num_blocks, dim] to gather
+      num_blocks = tf.reduce_prod(tf.shape(indices)[:2])
+      indices = tf.reshape(indices, [num_blocks, -1])
+      return tf.cast(indices, tf.int32)
+
+    def gather_blocks(x, indices):
+      """Gathers flattened blocks from x."""
+      x_shape = tf.shape(x)
+      x = reshape_range(x, 2, 4, [tf.reduce_prod(x_shape[2:4])])
+      # [length, batch, heads, dim]
+      x_t = tf.transpose(x, [2, 0, 1, 3])
+      x_new = tf.gather(x_t, indices)
+      # returns [batch, heads, num_blocks, block_length ** 2, dim]
+      return tf.transpose(x_new, [2, 3, 0, 1, 4])
+
+    q = pad_to_multiple(q, block_length)
+    k = 
pad_to_multiple(k, block_length) + v = pad_to_multiple(v, block_length) + + # Setting up k and v values + paddings = [[0, 0], [0, 0], [filter_flange, filter_flange], + [filter_flange, filter_flange], [0, 0]] + k = tf.pad(k, paddings) + v = tf.pad(v, paddings) + + # Setting up q blocks + q_indices = gather_indices(q, block_length, block_length) + q_new = gather_blocks(q, q_indices) + + # Setting up k and v blocks + full_filter_width = block_length + 2 * filter_flange + k_and_v_indices = gather_indices(k, full_filter_width, block_length) + k_new = gather_blocks(k, k_and_v_indices) + v_new = gather_blocks(v, k_and_v_indices) + + attention_bias = tf.expand_dims( + tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2) + + logits = tf.matmul(q_new, k_new, transpose_b=True) + + attention = tf.nn.softmax(logits + attention_bias) + output = tf.matmul(attention, v_new) + + output = tf.reshape(output, [batch_size, num_heads, -1, depth_v]) + # Remove the padding if introduced + output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1]) + # [batch, heads, h, w, depth_v] + return tf.reshape(output, v_shape) + + def multihead_attention(query_antecedent, memory_antecedent, bias, diff --git a/tensor2tensor/layers/common_attention_test.py b/tensor2tensor/layers/common_attention_test.py index e846c2002..e49999fbb 100644 --- a/tensor2tensor/layers/common_attention_test.py +++ b/tensor2tensor/layers/common_attention_test.py @@ -41,14 +41,14 @@ def testDotProductAttention(self): res = session.run(a) self.assertEqual(res.shape, (5, 7, 12, 32)) - def testMaskedLocalAttention(self): - q = np.array([[[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [ - 1.0, 0.0, 0.0, 0.0 - ], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], + def testMaskedLocalAttention1D(self): + q = np.array([[[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]]]) - k = np.array([[[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [ - 1.0, 0.0, 0.0, 0.0 - ], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], + k = np.array([[[[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]]]]) v = np.ones((1, 1, 8, 1)) with self.test_session() as session: @@ -61,7 +61,7 @@ def testMaskedLocalAttention(self): self.assertEqual(res.shape, (1, 1, 8, 1)) - def testLocalUnmaskedAttention(self): + def testLocalUnmaskedAttention1D(self): x = np.random.rand(5, 4, 25, 16) y = np.random.rand(5, 4, 25, 16) with self.test_session() as session: @@ -75,7 +75,7 @@ def testLocalUnmaskedAttention(self): res = session.run(a) self.assertEqual(res.shape, (5, 4, 25, 16)) - def testLocalUnmaskedAttentionMatchingBlockLength(self): + def testLocalUnmaskedAttention1DMatchingBlockLength(self): x = np.random.rand(5, 4, 25, 16) y = np.random.rand(5, 4, 25, 16) with self.test_session() as session: @@ -89,6 +89,34 @@ def testLocalUnmaskedAttentionMatchingBlockLength(self): res = session.run(a) self.assertEqual(res.shape, (5, 4, 25, 16)) + def testLocalUnmaskedAttention2D(self): + x = np.random.rand(5, 4, 25, 25, 16) + y = np.random.rand(5, 4, 25, 25, 16) + with self.test_session() as session: + a = common_attention.local_attention_2d( + tf.constant(x, dtype=tf.float32), + tf.constant(y, dtype=tf.float32), + tf.constant(y, dtype=tf.float32), + block_length=4, + filter_flange=3) + 
session.run(tf.global_variables_initializer()) + res = session.run(a) + self.assertEqual(res.shape, (5, 4, 25, 25, 16)) + + def testLocalUnmaskedAttention2DMatchingBlockLength(self): + x = np.random.rand(5, 4, 25, 25, 16) + y = np.random.rand(5, 4, 25, 25, 16) + with self.test_session() as session: + a = common_attention.local_attention_2d( + tf.constant(x, dtype=tf.float32), + tf.constant(y, dtype=tf.float32), + tf.constant(y, dtype=tf.float32), + block_length=5, + filter_flange=3) + session.run(tf.global_variables_initializer()) + res = session.run(a) + self.assertEqual(res.shape, (5, 4, 25, 25, 16)) + if __name__ == "__main__": tf.test.main() From 94eca0c50e8c32d30d262fc249c03e3019ac03f7 Mon Sep 17 00:00:00 2001 From: Lukasz Kaiser Date: Fri, 11 Aug 2017 13:14:39 -0700 Subject: [PATCH 4/7] Rename train_generator to just generator and port wiki_32k to Problem. Also cleaning and speeding up vocab generation, algorithmic problems, wmt_zhen and BPE download. PiperOrigin-RevId: 165015579 --- setup.py | 1 + tensor2tensor/bin/t2t-datagen | 61 +--- tensor2tensor/data_generators/algorithmic.py | 326 ++++++++---------- .../data_generators/algorithmic_test.py | 12 +- tensor2tensor/data_generators/all_problems.py | 1 + tensor2tensor/data_generators/cipher.py | 48 +-- tensor2tensor/data_generators/desc2code.py | 2 +- .../data_generators/generator_utils.py | 1 + tensor2tensor/data_generators/ice_parsing.py | 120 +++++++ tensor2tensor/data_generators/problem.py | 30 +- .../data_generators/problem_hparams.py | 52 --- tensor2tensor/data_generators/ptb.py | 2 +- tensor2tensor/data_generators/text_encoder.py | 6 +- tensor2tensor/data_generators/wiki.py | 102 +++--- tensor2tensor/data_generators/wmt.py | 75 ++-- tensor2tensor/models/transformer.py | 2 +- tensor2tensor/utils/decoding.py | 5 + tensor2tensor/utils/registry.py | 4 +- tensor2tensor/utils/trainer_utils_test.py | 5 +- 19 files changed, 418 insertions(+), 437 deletions(-) create mode 100644 tensor2tensor/data_generators/ice_parsing.py diff --git a/setup.py b/setup.py index 5beeb1b3e..4ada714b6 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ 'tensor2tensor/bin/t2t-make-tf-configs', ], install_requires=[ + 'bz2file', 'numpy', 'requests', 'sympy', diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen index 39453dbee..30784fa60 100644 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -45,7 +45,6 @@ from tensor2tensor.data_generators import generator_utils from tensor2tensor.data_generators import image from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import snli -from tensor2tensor.data_generators import wiki from tensor2tensor.data_generators import wmt from tensor2tensor.data_generators import wsj_parsing from tensor2tensor.utils import registry @@ -82,16 +81,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { "algorithmic_algebra_inverse": ( lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000), lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)), - "ice_parsing_tokens": ( - lambda: wmt.tabbed_parsing_token_generator( - FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8), - lambda: wmt.tabbed_parsing_token_generator( - FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)), - "ice_parsing_characters": ( - lambda: wmt.tabbed_parsing_character_generator( - FLAGS.data_dir, FLAGS.tmp_dir, True), - lambda: wmt.tabbed_parsing_character_generator( - FLAGS.data_dir, FLAGS.tmp_dir, False)), "wmt_parsing_tokens_8k": ( lambda: wmt.parsing_token_generator( 
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13), @@ -115,10 +104,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { lambda: lm1b.generator(FLAGS.tmp_dir, True, characters=True), lambda: lm1b.generator(FLAGS.tmp_dir, False, characters=True) ), - "wiki_32k": ( - lambda: wiki.generator(FLAGS.tmp_dir, True), - 1000 - ), "image_celeba_tune": ( lambda: image.celeba_generator(FLAGS.tmp_dir, 162770), lambda: image.celeba_generator(FLAGS.tmp_dir, 19867, 162770)), @@ -180,17 +165,14 @@ def main(_): # Remove parsing if paths are not given. if not FLAGS.parsing_path: problems = [p for p in problems if "parsing" not in p] - # Remove en-de BPE if paths are not given. - if not FLAGS.ende_bpe_path: - problems = [p for p in problems if "ende_bpe" not in p] if not problems: problems_str = "\n * ".join( sorted(list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_problems())) error_msg = ("You must specify one of the supported problems to " "generate data for:\n * " + problems_str + "\n") - error_msg += ("TIMIT, ende_bpe and parsing need data_sets specified with " - "--timit_paths, --ende_bpe_path and --parsing_path.") + error_msg += ("TIMIT and parsing need data_sets specified with " + "--timit_paths and --parsing_path.") raise ValueError(error_msg) if not FLAGS.data_dir: @@ -213,34 +195,17 @@ def generate_data_for_problem(problem): """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS.""" training_gen, dev_gen = _SUPPORTED_PROBLEM_GENERATORS[problem] - if isinstance(dev_gen, int): - # The dev set and test sets are generated as extra shards using the - # training generator. The integer specifies the number of training - # shards. FLAGS.num_shards is ignored. - num_training_shards = dev_gen - tf.logging.info("Generating data for %s.", problem) - all_output_files = generator_utils.combined_data_filenames( - problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, - num_training_shards) - generator_utils.generate_files(training_gen(), all_output_files, - FLAGS.max_cases) - else: - # usual case - train data and dev data are generated using separate - # generators. 
- num_shards = FLAGS.num_shards or 10 - tf.logging.info("Generating training data for %s.", problem) - train_output_files = generator_utils.train_data_filenames( - problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards) - generator_utils.generate_files(training_gen(), train_output_files, - FLAGS.max_cases) - tf.logging.info("Generating development data for %s.", problem) - dev_shards = 10 if "coco" in problem else 1 - dev_output_files = generator_utils.dev_data_filenames( - problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, dev_shards) - generator_utils.generate_files(dev_gen(), dev_output_files) - all_output_files = train_output_files + dev_output_files - - tf.logging.info("Shuffling data...") + num_shards = FLAGS.num_shards or 10 + tf.logging.info("Generating training data for %s.", problem) + train_output_files = generator_utils.train_data_filenames( + problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, num_shards) + generator_utils.generate_files(training_gen(), train_output_files, + FLAGS.max_cases) + tf.logging.info("Generating development data for %s.", problem) + dev_output_files = generator_utils.dev_data_filenames( + problem + generator_utils.UNSHUFFLED_SUFFIX, FLAGS.data_dir, 1) + generator_utils.generate_files(dev_gen(), dev_output_files) + all_output_files = train_output_files + dev_output_files generator_utils.shuffle_dataset(all_output_files) diff --git a/tensor2tensor/data_generators/algorithmic.py b/tensor2tensor/data_generators/algorithmic.py index c115a1ebe..c44ce65d8 100644 --- a/tensor2tensor/data_generators/algorithmic.py +++ b/tensor2tensor/data_generators/algorithmic.py @@ -37,15 +37,10 @@ class AlgorithmicProblem(problem.Problem): def num_symbols(self): raise NotImplementedError() - @property - def train_generator(self): - """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases.""" + def generator(self, nbr_symbols, max_length, nbr_cases): + """Generates the data.""" raise NotImplementedError() - @property - def dev_generator(self): - return self.train_generator - @property def train_length(self): return 40 @@ -67,25 +62,19 @@ def num_shards(self): return 10 def generate_data(self, data_dir, _, task_id=-1): - def generator_eos(generator): + def generator_eos(nbr_symbols, max_length, nbr_cases): """Shift by NUM_RESERVED_IDS and append EOS token.""" - for case in generator: + for case in self.generator(nbr_symbols, max_length, nbr_cases): new_case = {} for feature in case: new_case[feature] = [i + text_encoder.NUM_RESERVED_TOKENS for i in case[feature]] + [text_encoder.EOS_ID] yield new_case - train_generator_eos = lambda: generator_eos( # pylint: disable=g-long-lambda - self.train_generator(self.num_symbols, - self.train_length, self.train_size)) - dev_generator_eos = lambda: generator_eos( # pylint: disable=g-long-lambda - self.dev_generator(self.num_symbols, self.dev_length, self.dev_size)) - utils.generate_dataset_and_shuffle( - train_generator_eos(), + generator_eos(self.num_symbols, self.train_length, self.train_size), self.training_filepaths(data_dir, self.num_shards, shuffled=True), - dev_generator_eos(), + generator_eos(self.num_symbols, self.dev_length, self.dev_size), self.dev_filepaths(data_dir, 1, shuffled=True), shuffle=False) @@ -98,28 +87,6 @@ def hparams(self, defaults, unused_model_hparams): p.target_space_id = problem.SpaceID.DIGIT_1 -def identity_generator(nbr_symbols, max_length, nbr_cases): - """Generator for the identity (copy) task on sequences of symbols. 
- - The length of the sequence is drawn uniformly at random from [1, max_length] - and then symbols are drawn uniformly at random from [0, nbr_symbols) until - nbr_cases sequences have been produced. - - Args: - nbr_symbols: number of symbols to use in each sequence. - max_length: integer, maximum length of sequences to generate. - nbr_cases: the number of cases to generate. - - Yields: - A dictionary {"inputs": input-list, "targets": target-list} where - input-list and target-list are the same. - """ - for _ in xrange(nbr_cases): - l = np.random.randint(max_length) + 1 - inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)] - yield {"inputs": inputs, "targets": inputs} - - @registry.register_problem class AlgorithmicIdentityBinary40(AlgorithmicProblem): """Problem spec for algorithmic binary identity task.""" @@ -128,9 +95,26 @@ class AlgorithmicIdentityBinary40(AlgorithmicProblem): def num_symbols(self): return 2 - @property - def train_generator(self): - return identity_generator + def generator(self, nbr_symbols, max_length, nbr_cases): + """Generator for the identity (copy) task on sequences of symbols. + + The length of the sequence is drawn uniformly at random from [1, max_length] + and then symbols are drawn uniformly at random from [0, nbr_symbols) until + nbr_cases sequences have been produced. + + Args: + nbr_symbols: number of symbols to use in each sequence. + max_length: integer, maximum length of sequences to generate. + nbr_cases: the number of cases to generate. + + Yields: + A dictionary {"inputs": input-list, "targets": target-list} where + input-list and target-list are the same. + """ + for _ in xrange(nbr_cases): + l = np.random.randint(max_length) + 1 + inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)] + yield {"inputs": inputs, "targets": inputs} @registry.register_problem @@ -142,32 +126,6 @@ def num_symbols(self): return 10 -def shift_generator(nbr_symbols, shift, max_length, nbr_cases): - """Generator for the shift task on sequences of symbols. - - The length of the sequence is drawn uniformly at random from [1, max_length] - and then symbols are drawn uniformly at random from [0, nbr_symbols - shift] - until nbr_cases sequences have been produced (output[i] = input[i] + shift). - - Args: - nbr_symbols: number of symbols to use in each sequence (input + output). - shift: by how much to shift the input. - max_length: integer, maximum length of sequences to generate. - nbr_cases: the number of cases to generate. - - Yields: - A dictionary {"inputs": input-list, "targets": target-list} where - target-list[i] = input-list[i] + shift. - """ - for _ in xrange(nbr_cases): - l = np.random.randint(max_length) + 1 - inputs = [np.random.randint(nbr_symbols - shift) for _ in xrange(l)] - yield { - "inputs": inputs, - "targets": [i + shift for i in inputs] - } - - @registry.register_problem class AlgorithmicShiftDecimal40(AlgorithmicProblem): """Problem spec for algorithmic decimal shift task.""" @@ -176,40 +134,36 @@ class AlgorithmicShiftDecimal40(AlgorithmicProblem): def num_symbols(self): return 20 - @property - def train_generator(self): - return lambda nbr_sym, l, size: shift_generator(nbr_sym, 10, l, size) + def generator(self, nbr_symbols, max_length, nbr_cases): + """Generator for the shift task on sequences of symbols. 
+ + The length of the sequence is drawn uniformly at random from [1, max_length] + and then symbols are drawn uniformly at random from [0, nbr_symbols - shift] + until nbr_cases sequences have been produced (output[i] = input[i] + shift). + + Args: + nbr_symbols: number of symbols to use in each sequence (input + output). + max_length: integer, maximum length of sequences to generate. + nbr_cases: the number of cases to generate. + + Yields: + A dictionary {"inputs": input-list, "targets": target-list} where + target-list[i] = input-list[i] + shift. + """ + shift = 10 + for _ in xrange(nbr_cases): + l = np.random.randint(max_length) + 1 + inputs = [np.random.randint(nbr_symbols - shift) for _ in xrange(l)] + yield { + "inputs": inputs, + "targets": [i + shift for i in inputs] + } @property def dev_length(self): return 80 -def reverse_generator(nbr_symbols, max_length, nbr_cases): - """Generator for the reversing task on sequences of symbols. - - The length of the sequence is drawn uniformly at random from [1, max_length] - and then symbols are drawn uniformly at random from [0, nbr_symbols) until - nbr_cases sequences have been produced. - - Args: - nbr_symbols: number of symbols to use in each sequence. - max_length: integer, maximum length of sequences to generate. - nbr_cases: the number of cases to generate. - - Yields: - A dictionary {"inputs": input-list, "targets": target-list} where - target-list is input-list reversed. - """ - for _ in xrange(nbr_cases): - l = np.random.randint(max_length) + 1 - inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)] - yield { - "inputs": inputs, - "targets": list(reversed(inputs)) - } - - @registry.register_problem class AlgorithmicReverseBinary40(AlgorithmicProblem): """Problem spec for algorithmic binary reversing task.""" @@ -218,9 +172,29 @@ class AlgorithmicReverseBinary40(AlgorithmicProblem): def num_symbols(self): return 2 - @property - def train_generator(self): - return reverse_generator + def generator(self, nbr_symbols, max_length, nbr_cases): + """Generator for the reversing task on sequences of symbols. + + The length of the sequence is drawn uniformly at random from [1, max_length] + and then symbols are drawn uniformly at random from [0, nbr_symbols) until + nbr_cases sequences have been produced. + + Args: + nbr_symbols: number of symbols to use in each sequence. + max_length: integer, maximum length of sequences to generate. + nbr_cases: the number of cases to generate. + + Yields: + A dictionary {"inputs": input-list, "targets": target-list} where + target-list is input-list reversed. 
+ """ + for _ in xrange(nbr_cases): + l = np.random.randint(max_length) + 1 + inputs = [np.random.randint(nbr_symbols) for _ in xrange(l)] + yield { + "inputs": inputs, + "targets": list(reversed(inputs)) + } @registry.register_problem @@ -305,17 +279,16 @@ def reverse_generator_nlplike(nbr_symbols, @registry.register_problem -class AlgorithmicReverseNlplike8K(AlgorithmicProblem): +class AlgorithmicReverseNlplike8k(AlgorithmicProblem): """Problem spec for algorithmic nlp-like reversing task.""" @property def num_symbols(self): return 8000 - @property - def train_generator(self): - return lambda nbr_sym, length, size: reverse_generator_nlplike( # pylint: disable=g-long-lambda - nbr_sym, length, size, 10, 1.300) + def generator(self, nbr_symbols, max_length, nbr_cases): + return reverse_generator_nlplike( + nbr_symbols, max_length, nbr_cases, 10, 1.300) @property def train_length(self): @@ -327,17 +300,16 @@ def dev_length(self): @registry.register_problem -class AlgorithmicReverseNlplike32K(AlgorithmicReverseNlplike8K): - """Problem spec for algorithmic nlp-like reversing task, 32K vocab.""" +class AlgorithmicReverseNlplike32k(AlgorithmicReverseNlplike8k): + """Problem spec for algorithmic nlp-like reversing task, 32k vocab.""" @property def num_symbols(self): return 32000 - @property - def train_generator(self): - return lambda nbr_sym, length, size: reverse_generator_nlplike( # pylint: disable=g-long-lambda - nbr_sym, length, size, 10, 1.050) + def generator(self, nbr_symbols, max_length, nbr_cases): + return reverse_generator_nlplike( + nbr_symbols, max_length, nbr_cases, 10, 1.050) def lower_endian_to_number(l, base): @@ -360,38 +332,6 @@ def random_number_lower_endian(length, base): return prefix + [np.random.randint(base - 1) + 1] # Last digit is not 0. -def addition_generator(base, max_length, nbr_cases): - """Generator for the addition task. - - The length of each number is drawn uniformly at random from [1, max_length/2] - and then digits are drawn uniformly at random. The numbers are added and - separated by [base] in the input. Stops at nbr_cases. - - Args: - base: in which base are the numbers. - max_length: integer, maximum length of sequences to generate. - nbr_cases: the number of cases to generate. - - Yields: - A dictionary {"inputs": input-list, "targets": target-list} where - input-list are the 2 numbers and target-list is the result of adding them. - - Raises: - ValueError: if max_length is lower than 3. - """ - if max_length < 3: - raise ValueError("Maximum length must be at least 3.") - for _ in xrange(nbr_cases): - l1 = np.random.randint(max_length // 2) + 1 - l2 = np.random.randint(max_length - l1 - 1) + 1 - n1 = random_number_lower_endian(l1, base) - n2 = random_number_lower_endian(l2, base) - result = lower_endian_to_number(n1, base) + lower_endian_to_number(n2, base) - inputs = n1 + [base] + n2 - targets = number_to_lower_endian(result, base) - yield {"inputs": inputs, "targets": targets} - - @registry.register_problem class AlgorithmicAdditionBinary40(AlgorithmicProblem): """Problem spec for algorithmic binary addition task.""" @@ -400,9 +340,37 @@ class AlgorithmicAdditionBinary40(AlgorithmicProblem): def num_symbols(self): return 2 - @property - def train_generator(self): - return addition_generator + def generator(self, base, max_length, nbr_cases): + """Generator for the addition task. + + The length of each number is drawn uniformly at random in [1, max_length/2] + and then digits are drawn uniformly at random. 
The numbers are added and + separated by [base] in the input. Stops at nbr_cases. + + Args: + base: in which base are the numbers. + max_length: integer, maximum length of sequences to generate. + nbr_cases: the number of cases to generate. + + Yields: + A dictionary {"inputs": input-list, "targets": target-list} where + input-list are the 2 numbers and target-list is the result of adding them. + + Raises: + ValueError: if max_length is lower than 3. + """ + if max_length < 3: + raise ValueError("Maximum length must be at least 3.") + for _ in xrange(nbr_cases): + l1 = np.random.randint(max_length // 2) + 1 + l2 = np.random.randint(max_length - l1 - 1) + 1 + n1 = random_number_lower_endian(l1, base) + n2 = random_number_lower_endian(l2, base) + result = lower_endian_to_number(n1, base) + lower_endian_to_number( + n2, base) + inputs = n1 + [base] + n2 + targets = number_to_lower_endian(result, base) + yield {"inputs": inputs, "targets": targets} @registry.register_problem @@ -414,39 +382,6 @@ def num_symbols(self): return 10 -def multiplication_generator(base, max_length, nbr_cases): - """Generator for the multiplication task. - - The length of each number is drawn uniformly at random from [1, max_length/2] - and then digits are drawn uniformly at random. The numbers are multiplied - and separated by [base] in the input. Stops at nbr_cases. - - Args: - base: in which base are the numbers. - max_length: integer, maximum length of sequences to generate. - nbr_cases: the number of cases to generate. - - Yields: - A dictionary {"inputs": input-list, "targets": target-list} where - input-list are the 2 numbers and target-list is the result of multiplying - them. - - Raises: - ValueError: if max_length is lower than 3. - """ - if max_length < 3: - raise ValueError("Maximum length must be at least 3.") - for _ in xrange(nbr_cases): - l1 = np.random.randint(max_length // 2) + 1 - l2 = np.random.randint(max_length - l1 - 1) + 1 - n1 = random_number_lower_endian(l1, base) - n2 = random_number_lower_endian(l2, base) - result = lower_endian_to_number(n1, base) * lower_endian_to_number(n2, base) - inputs = n1 + [base] + n2 - targets = number_to_lower_endian(result, base) - yield {"inputs": inputs, "targets": targets} - - @registry.register_problem class AlgorithmicMultiplicationBinary40(AlgorithmicProblem): """Problem spec for algorithmic binary multiplication task.""" @@ -455,9 +390,38 @@ class AlgorithmicMultiplicationBinary40(AlgorithmicProblem): def num_symbols(self): return 2 - @property - def train_generator(self): - return multiplication_generator + def generator(self, base, max_length, nbr_cases): + """Generator for the multiplication task. + + The length of each number is drawn uniformly at random in [1, max_length/2] + and then digits are drawn uniformly at random. The numbers are multiplied + and separated by [base] in the input. Stops at nbr_cases. + + Args: + base: in which base are the numbers. + max_length: integer, maximum length of sequences to generate. + nbr_cases: the number of cases to generate. + + Yields: + A dictionary {"inputs": input-list, "targets": target-list} where + input-list are the 2 numbers and target-list is the result of multiplying + them. + + Raises: + ValueError: if max_length is lower than 3. 
+ """ + if max_length < 3: + raise ValueError("Maximum length must be at least 3.") + for _ in xrange(nbr_cases): + l1 = np.random.randint(max_length // 2) + 1 + l2 = np.random.randint(max_length - l1 - 1) + 1 + n1 = random_number_lower_endian(l1, base) + n2 = random_number_lower_endian(l2, base) + result = lower_endian_to_number(n1, base) * lower_endian_to_number( + n2, base) + inputs = n1 + [base] + n2 + targets = number_to_lower_endian(result, base) + yield {"inputs": inputs, "targets": targets} @registry.register_problem diff --git a/tensor2tensor/data_generators/algorithmic_test.py b/tensor2tensor/data_generators/algorithmic_test.py index 57faaa80b..4ac6d3123 100644 --- a/tensor2tensor/data_generators/algorithmic_test.py +++ b/tensor2tensor/data_generators/algorithmic_test.py @@ -29,15 +29,17 @@ class AlgorithmicTest(tf.test.TestCase): def testIdentityGenerator(self): + identity_problem = algorithmic.AlgorithmicIdentityBinary40() counter = 0 - for d in algorithmic.identity_generator(3, 8, 10): + for d in identity_problem.generator(3, 8, 10): counter += 1 self.assertEqual(d["inputs"], d["targets"]) self.assertEqual(counter, 10) def testReverseGenerator(self): + reversing_problem = algorithmic.AlgorithmicReverseBinary40() counter = 0 - for d in algorithmic.reverse_generator(3, 8, 10): + for d in reversing_problem.generator(3, 8, 10): counter += 1 self.assertEqual(list(reversed(d["inputs"])), d["targets"]) self.assertEqual(counter, 10) @@ -76,8 +78,9 @@ def testNumberToLowerEndian(self): self.assertEqual(algorithmic.number_to_lower_endian(2137, 10), [7, 3, 1, 2]) def testAdditionGenerator(self): + addition_problem = algorithmic.AlgorithmicAdditionBinary40() counter = 0 - for d in algorithmic.addition_generator(4, 8, 10): + for d in addition_problem.generator(4, 8, 10): counter += 1 self.assertEqual(d["inputs"].count(4), 1) self.assertEqual(d["inputs"].count(5), 0) @@ -86,8 +89,9 @@ def testAdditionGenerator(self): self.assertEqual(counter, 10) def testMultiplicationGenerator(self): + multiplication_problem = algorithmic.AlgorithmicMultiplicationBinary40() counter = 0 - for d in algorithmic.multiplication_generator(4, 8, 10): + for d in multiplication_problem.generator(4, 8, 10): counter += 1 self.assertEqual(d["inputs"].count(4), 1) self.assertEqual(d["inputs"].count(5), 0) diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py index ca6dccfda..0078eb3f9 100644 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -24,6 +24,7 @@ from tensor2tensor.data_generators import audio from tensor2tensor.data_generators import cipher from tensor2tensor.data_generators import desc2code +from tensor2tensor.data_generators import ice_parsing from tensor2tensor.data_generators import image from tensor2tensor.data_generators import lm1b from tensor2tensor.data_generators import ptb diff --git a/tensor2tensor/data_generators/cipher.py b/tensor2tensor/data_generators/cipher.py index 41dcbd80e..a11776b84 100644 --- a/tensor2tensor/data_generators/cipher.py +++ b/tensor2tensor/data_generators/cipher.py @@ -44,23 +44,13 @@ def distribution(self): def shift(self): return 1 - @property - def train_generator(self): - """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases.""" - - def _gen(nbr_symbols, max_length, nbr_cases): - plain_vocab = range(nbr_symbols) - indices = generate_plaintext_random(plain_vocab, self.distribution, - nbr_cases, max_length) - codes = encipher_shift(indices, plain_vocab, 
self.shift) - - for plain, code in zip(indices, codes): - yield { - "inputs": plain, - "targets": code, - } - - return _gen + def generator(self, nbr_symbols, max_length, nbr_cases): + plain_vocab = range(nbr_symbols) + indices = generate_plaintext_random( + plain_vocab, self.distribution, nbr_cases, max_length) + codes = encipher_shift(indices, plain_vocab, self.shift) + for plain, code in zip(indices, codes): + yield {"inputs": plain, "targets": code} @property def train_length(self): @@ -87,23 +77,13 @@ def distribution(self): def key(self): return [1, 3] - @property - def train_generator(self): - """Generator; takes 3 args: nbr_symbols, max_length, nbr_cases.""" - - def _gen(nbr_symbols, max_length, nbr_cases): - plain_vocab = range(nbr_symbols) - indices = generate_plaintext_random(plain_vocab, self.distribution, - nbr_cases, max_length) - codes = encipher_vigenere(indices, plain_vocab, self.key) - - for plain, code in zip(indices, codes): - yield { - "inputs": plain, - "targets": code, - } - - return _gen + def generator(self, nbr_symbols, max_length, nbr_cases): + plain_vocab = range(nbr_symbols) + indices = generate_plaintext_random(plain_vocab, self.distribution, + nbr_cases, max_length) + codes = encipher_vigenere(indices, plain_vocab, self.key) + for plain, code in zip(indices, codes): + yield {"inputs": plain, "targets": code} @property def train_length(self): diff --git a/tensor2tensor/data_generators/desc2code.py b/tensor2tensor/data_generators/desc2code.py index 6cef6db63..438c116c8 100644 --- a/tensor2tensor/data_generators/desc2code.py +++ b/tensor2tensor/data_generators/desc2code.py @@ -138,7 +138,7 @@ def feature_encoders(self, data_dir): "targets": target_token, } - def train_generator(self, data_dir, tmp_dir, train): + def generator(self, data_dir, tmp_dir, train): # Called twice: for train and test # Get the list of the training samples (coding challenge samples) diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py index b38531c1a..eadca9bd6 100644 --- a/tensor2tensor/data_generators/generator_utils.py +++ b/tensor2tensor/data_generators/generator_utils.py @@ -308,6 +308,7 @@ def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size, vocab = text_encoder.SubwordTextEncoder(vocab_filepath) return vocab + tf.logging.info("Generating vocab file: %s", vocab_filepath) token_counts = defaultdict(int) for item in generator_fn(): for tok in tokenizer.encode(text_encoder.native_to_unicode(item)): diff --git a/tensor2tensor/data_generators/ice_parsing.py b/tensor2tensor/data_generators/ice_parsing.py new file mode 100644 index 000000000..591b205da --- /dev/null +++ b/tensor2tensor/data_generators/ice_parsing.py @@ -0,0 +1,120 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module implements the ice_parsing_* problems.""" + +# These parse plain text into flattened parse trees and POS tags. 
+# The training data is stored in files named `parsing_train.pairs` +# and `parsing_dev.pairs`. These files are UTF-8 text files where +# each line contains an input sentence and a target parse tree, +# separated by a tab character. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators.wmt import tabbed_generator +from tensor2tensor.utils import registry + + +# End-of-sentence marker. +EOS = text_encoder.EOS_ID + + +def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix, + source_vocab_size, target_vocab_size): + """Generate source and target data from a single file.""" + filename = "parsing_{0}.pairs".format("train" if train else "dev") + source_vocab = generator_utils.get_or_generate_tabbed_vocab( + data_dir, tmp_dir, filename, 0, + prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size) + target_vocab = generator_utils.get_or_generate_tabbed_vocab( + data_dir, tmp_dir, filename, 1, + prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size) + pair_filepath = os.path.join(tmp_dir, filename) + return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS) + + +def tabbed_parsing_character_generator(tmp_dir, train): + """Generate source and target data from a single file.""" + character_vocab = text_encoder.ByteTextEncoder() + filename = "parsing_{0}.pairs".format("train" if train else "dev") + pair_filepath = os.path.join(tmp_dir, filename) + return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS) + + +@registry.register_problem("ice_parsing_tokens") +class IceParsingTokens(problem.Problem): + """Problem spec for parsing tokenized Icelandic text to constituency trees.""" + + @property + def source_vocab_size(self): + return 2**14 # 16384 + + @property + def targeted_vocab_size(self): + return 2**8 # 256 + + @property + def input_space_id(self): + return problem.SpaceID.ICE_TOK + + @property + def target_space_id(self): + return problem.SpaceID.ICE_PARSE_TOK + + @property + def num_shards(self): + return 10 + + def feature_encoders(self, data_dir): + source_vocab_filename = os.path.join( + data_dir, "ice_source.tokens.vocab.%d" % self.source_vocab_size) + target_vocab_filename = os.path.join( + data_dir, "ice_target.tokens.vocab.%d" % self.targeted_vocab_size) + source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) + target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) + return { + "inputs": source_subtokenizer, + "targets": target_subtokenizer, + } + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + generator_utils.generate_dataset_and_shuffle( + tabbed_parsing_token_generator(data_dir, tmp_dir, True, "ice", + self.source_vocab_size, + self.targeted_vocab_size), + self.training_filepaths(data_dir, self.num_shards, shuffled=False), + tabbed_parsing_token_generator(data_dir, tmp_dir, False, "ice", + self.source_vocab_size, + self.targeted_vocab_size), + self.dev_filepaths(data_dir, 1, shuffled=False)) + + def hparams(self, defaults, model_hparams): + p = defaults + source_vocab_size = self._encoders["inputs"].vocab_size + p.input_modality = {"inputs": (registry.Modalities.SYMBOL, + source_vocab_size)} + p.target_modality = (registry.Modalities.SYMBOL, 
self.targeted_vocab_size) + p.input_space_id = self.input_space_id + p.target_space_id = self.target_space_id + p.loss_multiplier = 2.5 # Rough estimate of avg number of tokens per word diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index 07fafb492..7a84aac93 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -359,13 +359,14 @@ def is_character_level(self): def targeted_vocab_size(self): raise NotImplementedError() # Not needed if self.is_character_level. - def train_generator(self, data_dir, tmp_dir, is_training): - """Generator of the training data.""" + def generator(self, data_dir, tmp_dir, is_training): + """Generator for the training and evaluation data.""" raise NotImplementedError() - def dev_generator(self, data_dir, tmp_dir): - """Generator of the development data.""" - return self.train_generator(data_dir, tmp_dir, False) + @property + def use_train_shards_for_dev(self): + """If true, we only generate training data and hold out shards for dev.""" + return False @property def input_space_id(self): @@ -379,6 +380,10 @@ def target_space_id(self): def num_shards(self): raise NotImplementedError() + @property + def num_dev_shards(self): + return 1 + @property def vocab_name(self): raise NotImplementedError() @@ -396,11 +401,20 @@ def has_inputs(self): return True # Set to False for language models. def generate_data(self, data_dir, tmp_dir, task_id=-1): + train_paths = self.training_filepaths( + data_dir, self.num_shards, shuffled=False) + dev_paths = self.dev_filepaths( + data_dir, self.num_dev_shards, shuffled=False) + if self.use_train_shards_for_dev: + all_paths = train_paths + dev_paths + generator_utils.generate_files( + self.generator(data_dir, tmp_dir, True), all_paths) + generator_utils.shuffle_dataset(all_paths) generator_utils.generate_dataset_and_shuffle( - self.train_generator(data_dir, tmp_dir, True), + self.generator(data_dir, tmp_dir, True), self.training_filepaths(data_dir, self.num_shards, shuffled=False), - self.dev_generator(data_dir, tmp_dir), - self.dev_filepaths(data_dir, 1, shuffled=False)) + self.generator(data_dir, tmp_dir, False), + self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False)) def feature_encoders(self, data_dir): if self.is_character_level: diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py index d0577db52..b33438d6d 100644 --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -345,19 +345,6 @@ def lm1b_characters(unused_model_hparams): return p -def wiki_32k(model_hparams): - """Wikipedia title to article. 32k subtoken vocabulary.""" - p = default_problem_hparams() - encoder = text_encoder.SubwordTextEncoder( - os.path.join(model_hparams.data_dir, "wiki_32k.subword_text_encoder")) - modality_spec = (registry.Modalities.SYMBOL, encoder.vocab_size) - p.input_modality = {"inputs": modality_spec} - p.target_modality = modality_spec - p.vocabulary = {"inputs": encoder, "targets": encoder} - p.target_space_id = 3 - return p - - def wmt_ende_bpe32k(model_hparams): """English to German translation benchmark.""" p = default_problem_hparams() @@ -462,39 +449,6 @@ def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size, return p -def ice_parsing_tokens(model_hparams, wrong_source_vocab_size): - """Icelandic to parse tree translation benchmark. 
- - Args: - model_hparams: a tf.contrib.training.HParams - wrong_source_vocab_size: a number used in the filename indicating the - approximate vocabulary size. This is not to be confused with the actual - vocabulary size. - - Returns: - A tf.contrib.training.HParams object. - """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - source_vocab_filename = os.path.join( - model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size) - target_vocab_filename = os.path.join(model_hparams.data_dir, - "ice_target.vocab.256") - source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) - target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": source_subtokenizer, - "targets": target_subtokenizer, - } - p.input_space_id = 18 # Icelandic tokens - p.target_space_id = 19 # Icelandic parse tokens - return p - - def img2img_imagenet(unused_model_hparams): """Image 2 Image for imagenet dataset.""" p = default_problem_hparams() @@ -542,12 +496,6 @@ def image_celeba(unused_model_hparams): lm1b_characters, "lm1b_32k": lm1b_32k, - "wiki_32k": - wiki_32k, - "ice_parsing_characters": - wmt_parsing_characters, - "ice_parsing_tokens": - lambda p: ice_parsing_tokens(p, 2**13), "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13), "wsj_parsing_tokens_16k": diff --git a/tensor2tensor/data_generators/ptb.py b/tensor2tensor/data_generators/ptb.py index 18aedd640..b9014bcd6 100644 --- a/tensor2tensor/data_generators/ptb.py +++ b/tensor2tensor/data_generators/ptb.py @@ -105,7 +105,7 @@ def use_subword_tokenizer(self): def targeted_vocab_size(self): return 10000 - def train_generator(self, data_dir, tmp_dir, train): + def generator(self, data_dir, tmp_dir, train): filename = os.path.basename(PTB_URL) compressed_filepath = generator_utils.maybe_download( tmp_dir, filename, PTB_URL) diff --git a/tensor2tensor/data_generators/text_encoder.py b/tensor2tensor/data_generators/text_encoder.py index ad9c04c96..b628a538f 100644 --- a/tensor2tensor/data_generators/text_encoder.py +++ b/tensor2tensor/data_generators/text_encoder.py @@ -441,6 +441,8 @@ def build_to_target_size(cls, if min_val > max_val: raise ValueError("Lower bound for the minimum token count " "is greater than the upper bound.") + if target_size < 1: + raise ValueError("Target size must be positive.") def bisect(min_val, max_val): """Bisection to find the right size.""" @@ -450,8 +452,10 @@ def bisect(min_val, max_val): subtokenizer.build_from_token_counts(token_counts, present_count, num_iterations) + # Being within 1% of the target size is ok. + is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size # If min_val == max_val, we can't do any better than this. 
-      if subtokenizer.vocab_size == target_size or min_val >= max_val:
+      if is_ok or min_val >= max_val or present_count < 2:
         return subtokenizer
 
       if subtokenizer.vocab_size > target_size:
diff --git a/tensor2tensor/data_generators/wiki.py b/tensor2tensor/data_generators/wiki.py
index 49147962a..1e427dbe8 100644
--- a/tensor2tensor/data_generators/wiki.py
+++ b/tensor2tensor/data_generators/wiki.py
@@ -19,23 +19,21 @@
 from __future__ import division
 from __future__ import print_function
 
-import bz2
-from collections import defaultdict
 import os
 
 # Dependency imports
 
+import bz2file
+
 import six
 
 from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import problem
 from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.data_generators import tokenizer
-
-import tensorflow as tf
+from tensor2tensor.utils import registry
 
-# End-of-sentence marker (should correspond to the position of EOS in the
-# RESERVED_TOKENS list in text_encoder.py)
-EOS = 1
+# End-of-sentence marker.
+EOS = text_encoder.EOS_ID
 
 
 def _maybe_download_corpus(tmp_dir):
@@ -60,7 +58,7 @@ def page_generator(tmp_dir, max_docs=None):
   doc = u""
   count = 0
   corpus_filepath = _maybe_download_corpus(tmp_dir)
-  for line in bz2.BZ2File(corpus_filepath, "r"):
+  for line in bz2file.BZ2File(corpus_filepath, "r", buffering=1000000):
     line = unicode(line, "utf-8") if six.PY2 else line.decode("utf-8")
     if not doc and line != u"  <page>\n":
       continue
@@ -82,48 +80,52 @@ def _page_title(page):
   return page[start_pos:end_pos]
 
 
-def _get_or_build_subword_text_encoder(tmp_dir):
-  """Builds a SubwordTextEncoder based on the corpus.
+@registry.register_problem
+class Wiki32k(problem.Text2TextProblem):
+  """A class for generating Wikipedia title-to-article data."""
 
-  Args:
-    tmp_dir: a string
+  @property
+  def is_character_level(self):
+    return False
 
-  Returns:
-    a SubwordTextEncoder.
-  """
-  filename = os.path.join(tmp_dir, "wiki_32k.subword_text_encoder")
-  if tf.gfile.Exists(filename):
-    return text_encoder.SubwordTextEncoder(filename)
-  token_counts = defaultdict(int)
-  for page in page_generator(tmp_dir, max_docs=1000):
-    tokens = tokenizer.encode(page)
-    tokens = set(tokens)
-    for tok in tokens:
-      token_counts[tok] += 1
-  new_token_counts = defaultdict(int)
-  for token, count in six.iteritems(token_counts):
-    if count >= 3:
-      new_token_counts[token] = count
-  ret = text_encoder.SubwordTextEncoder()
-  ret.build_from_token_counts(new_token_counts, min_count=10)
-  ret.store_to_file(filename)
-  return ret
-
-
-def generator(tmp_dir, train):
-  """Generator for lm1b sentences.
+  @property
+  def has_inputs(self):
+    return True
 
-  Args:
-    tmp_dir: a string.
-    train: a boolean.
+ @property + def input_space_id(self): + return problem.SpaceID.EN_TOK - Yields: - A dictionary {"inputs": [], "targets": []} - """ - assert train - encoder = _get_or_build_subword_text_encoder(tmp_dir) - for page in page_generator(tmp_dir): - title = _page_title(page) - encoded = encoder.encode(page) + [EOS] - encoded_title = encoder.encode(title) + [EOS] - yield {"inputs": encoded_title, "targets": encoded} + @property + def target_space_id(self): + return problem.SpaceID.EN_TOK + + @property + def num_shards(self): + return 1000 + + @property + def vocab_name(self): + return "vocab.wiki" + + @property + def use_subword_tokenizer(self): + return True + + @property + def targeted_vocab_size(self): + return 2**15 # 32768 + + @property + def use_train_shards_for_dev(self): + return True + + def generator(self, data_dir, tmp_dir, _): + encoder = generator_utils.get_or_generate_vocab_inner( + data_dir, self.vocab_file, self.targeted_vocab_size, + lambda: page_generator(tmp_dir, max_docs=10000)) + for page in page_generator(tmp_dir): + title = _page_title(page) + encoded = encoder.encode(page) + [EOS] + encoded_title = encoder.encode(title) + [EOS] + yield {"inputs": encoded_title, "targets": encoded} diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py index 0a47e9989..52990eb5f 100644 --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -32,10 +32,6 @@ import tensorflow as tf -tf.flags.DEFINE_string("ende_bpe_path", "", "Path to BPE files in tmp_dir." - "Download from https://drive.google.com/open?" - "id=0B_bZck-ksdkpM25jRUN2X2UxMm8") - FLAGS = tf.flags.FLAGS @@ -295,15 +291,15 @@ def bi_vocabs_token_generator(source_path, # Generators. -def _get_wmt_ende_dataset(directory, filename): +def _get_wmt_ende_bpe_dataset(directory, filename): """Extract the WMT en-de corpus `filename` to directory unless it's there.""" train_path = os.path.join(directory, filename) if not (tf.gfile.Exists(train_path + ".de") and tf.gfile.Exists(train_path + ".en")): - # We expect that this file has been downloaded from: - # https://drive.google.com/open?id=0B_bZck-ksdkpM25jRUN2X2UxMm8 and placed - # in `directory`. - corpus_file = os.path.join(directory, FLAGS.ende_bpe_path) + url = ("https://drive.google.com/uc?export=download&id=" + "0B_bZck-ksdkpM25jRUN2X2UxMm8") + corpus_file = generator_utils.maybe_download_from_drive( + directory, "wmt16_en_de.tar.gz", url) with tarfile.open(corpus_file, "r:gz") as corpus_tar: corpus_tar.extractall(directory) return train_path @@ -313,7 +309,7 @@ def ende_bpe_token_generator(data_dir, tmp_dir, train): """Instance of token generator for the WMT en->de task, training set.""" dataset_path = ("train.tok.clean.bpe.32000" if train else "newstest2013.tok.bpe.32000") - train_path = _get_wmt_ende_dataset(tmp_dir, dataset_path) + train_path = _get_wmt_ende_bpe_dataset(tmp_dir, dataset_path) token_tmp_path = os.path.join(tmp_dir, "vocab.bpe.32000") token_path = os.path.join(data_dir, "vocab.bpe.32000") tf.gfile.Copy(token_tmp_path, token_path, overwrite=True) @@ -334,6 +330,7 @@ def _preprocess_sgm(line, is_sgm): if line.startswith("
<p
") or line.startswith("

"): return "" # Strip tags. + line = line.strip() if line.startswith(""): i = line.index(">") return line[i+1:-6] # Strip first and last . @@ -392,7 +389,7 @@ class WMTEnDeTokens8k(WMTProblem): def targeted_vocab_size(self): return 2**13 # 8192 - def train_generator(self, data_dir, tmp_dir, train): + def generator(self, data_dir, tmp_dir, train): symbolizer_vocab = generator_utils.get_or_generate_vocab( data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size) datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS @@ -426,7 +423,7 @@ class WMTEnDeCharacters(WMTProblem): def is_character_level(self): return True - def train_generator(self, _, tmp_dir, train): + def generator(self, _, tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() datasets = _ENDE_TRAIN_DATASETS if train else _ENDE_TEST_DATASETS tag = "train" if train else "dev" @@ -451,18 +448,22 @@ class WMTZhEnTokens8k(WMTProblem): def targeted_vocab_size(self): return 2**13 # 8192 - def train_generator(self, data_dir, tmp_dir, train): + @property + def num_shards(self): + return 10 # This is a small dataset. + + def generator(self, data_dir, tmp_dir, train): source_vocab_size = self.targeted_vocab_size target_vocab_size = self.targeted_vocab_size datasets = _ZHEN_TRAIN_DATASETS if train else _ZHEN_TEST_DATASETS source_datasets = [[item[0], [item[1][0]]] for item in _ZHEN_TRAIN_DATASETS] target_datasets = [[item[0], [item[1][1]]] for item in _ZHEN_TRAIN_DATASETS] source_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, "vocab.zh.%d" % source_vocab_size, source_vocab_size, - source_datasets) + data_dir, tmp_dir, "vocab.zhen-zh.%d" % source_vocab_size, + source_vocab_size, source_datasets) target_vocab = generator_utils.get_or_generate_vocab( - data_dir, tmp_dir, "vocab.en.%d" % target_vocab_size, target_vocab_size, - target_datasets) + data_dir, tmp_dir, "vocab.zhen-en.%d" % target_vocab_size, + target_vocab_size, target_datasets) tag = "train" if train else "dev" data_path = _compile_data(tmp_dir, datasets, "wmt_zhen_tok_%s" % tag) return bi_vocabs_token_generator(data_path + ".lang1", data_path + ".lang2", @@ -490,14 +491,6 @@ def feature_encoders(self, data_dir): } -@registry.register_problem("wmt_zhen_tokens_32k") -class WMTZhEnTokens32k(WMTZhEnTokens8k): - - @property - def targeted_vocab_size(self): - return 2**15 # 32768 - - @registry.register_problem("wmt_enfr_tokens_8k") class WMTEnFrTokens8k(WMTProblem): """Problem spec for WMT En-Fr translation.""" @@ -506,7 +499,7 @@ class WMTEnFrTokens8k(WMTProblem): def targeted_vocab_size(self): return 2**13 # 8192 - def train_generator(self, data_dir, tmp_dir, train): + def generator(self, data_dir, tmp_dir, train): symbolizer_vocab = generator_utils.get_or_generate_vocab( data_dir, tmp_dir, self.vocab_file, self.targeted_vocab_size) datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS @@ -540,7 +533,7 @@ class WMTEnFrCharacters(WMTProblem): def is_character_level(self): return True - def train_generator(self, data_dir, tmp_dir, train): + def generator(self, data_dir, tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() datasets = _ENFR_TRAIN_DATASETS if train else _ENFR_TEST_DATASETS tag = "train" if train else "dev" @@ -569,7 +562,7 @@ def targeted_vocab_size(self): def vocab_name(self): return "vocab.mken" - def train_generator(self, data_dir, tmp_dir, train): + def generator(self, data_dir, tmp_dir, train): datasets = _MKEN_TRAIN_DATASETS if train else _MKEN_TEST_DATASETS source_datasets = [[item[0], 
[item[1][0]]] for item in datasets] target_datasets = [[item[0], [item[1][1]]] for item in datasets] @@ -602,7 +595,7 @@ def targeted_vocab_size(self): def vocab_name(self): return "vocab.encs" - def train_generator(self, data_dir, tmp_dir, train): + def generator(self, data_dir, tmp_dir, train): datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS source_datasets = [[item[0], [item[1][0]]] for item in datasets] target_datasets = [[item[0], [item[1][1]]] for item in datasets] @@ -631,7 +624,7 @@ class WMTEnCsCharacters(WMTProblem): def is_character_level(self): return True - def train_generator(self, data_dir, tmp_dir, train): + def generator(self, data_dir, tmp_dir, train): character_vocab = text_encoder.ByteTextEncoder() datasets = _ENCS_TRAIN_DATASETS if train else _ENCS_TEST_DATASETS tag = "train" if train else "dev" @@ -648,28 +641,6 @@ def target_space_id(self): return problem.SpaceID.CS_CHR -def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix, - source_vocab_size, target_vocab_size): - """Generate source and target data from a single file.""" - source_vocab = generator_utils.get_or_generate_tabbed_vocab( - data_dir, tmp_dir, "parsing_train.pairs", 0, - prefix + "_source.vocab.%d" % source_vocab_size, source_vocab_size) - target_vocab = generator_utils.get_or_generate_tabbed_vocab( - data_dir, tmp_dir, "parsing_train.pairs", 1, - prefix + "_target.vocab.%d" % target_vocab_size, target_vocab_size) - filename = "parsing_%s" % ("train" if train else "dev") - pair_filepath = os.path.join(tmp_dir, filename + ".pairs") - return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS) - - -def tabbed_parsing_character_generator(tmp_dir, train): - """Generate source and target data from a single file.""" - character_vocab = text_encoder.ByteTextEncoder() - filename = "parsing_%s" % ("train" if train else "dev") - pair_filepath = os.path.join(tmp_dir, filename + ".pairs") - return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS) - - def parsing_token_generator(data_dir, tmp_dir, train, vocab_size): symbolizer_vocab = generator_utils.get_or_generate_vocab( data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 37c1206bd..06f49b231 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -431,7 +431,7 @@ def transformer_parsing_big(): @registry.register_hparams def transformer_parsing_ice(): - """Hparams for parsing Icelandic text.""" + """Hparams for parsing and tagging Icelandic text.""" hparams = transformer_base_single_gpu() hparams.batch_size = 4096 hparams.shared_embedding_and_softmax_weights = int(False) diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index 5e8f4d482..da33cf90e 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -259,6 +259,11 @@ def _interactive_input_fn(hparams): vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"] # This should be longer than the longest input. const_array_size = 10000 + # Import readline if available for command line editing and recall. 
+  try:
+    import readline  # pylint: disable=g-import-not-at-top,unused-variable
+  except ImportError:
+    pass
   while True:
     prompt = ("INTERACTIVE MODE  num_samples=%d  decode_length=%d  \n"
               "  it=<input_type>      ('text' or 'image' or 'label')\n"
diff --git a/tensor2tensor/utils/registry.py b/tensor2tensor/utils/registry.py
index fea647b2b..6ce650ac3 100644
--- a/tensor2tensor/utils/registry.py
+++ b/tensor2tensor/utils/registry.py
@@ -225,10 +225,10 @@ def parse_problem_name(problem_name):
     was_copy: A boolean.
   """
   # Recursively strip tags until we reach a base name.
-  if len(problem_name) > 4 and problem_name[-4:] == "_rev":
+  if problem_name.endswith("_rev"):
    base, _, was_copy = parse_problem_name(problem_name[:-4])
    return base, True, was_copy
-  elif len(problem_name) > 5 and problem_name[-5:] == "_copy":
+  elif problem_name.endswith("_copy"):
    base, was_reversed, _ = parse_problem_name(problem_name[:-5])
    return base, was_reversed, True
   else:
diff --git a/tensor2tensor/utils/trainer_utils_test.py b/tensor2tensor/utils/trainer_utils_test.py
index 8a71afe68..61156f227 100644
--- a/tensor2tensor/utils/trainer_utils_test.py
+++ b/tensor2tensor/utils/trainer_utils_test.py
@@ -36,11 +36,12 @@ class TinyAlgo(algorithmic.AlgorithmicIdentityBinary40):
 
   def generate_data(self, data_dir, _):
+    identity_problem = algorithmic.AlgorithmicIdentityBinary40()
     generator_utils.generate_files(
-        algorithmic.identity_generator(self.num_symbols, 40, 100000),
+        identity_problem.generator(self.num_symbols, 40, 100000),
         self.training_filepaths(data_dir, 1, shuffled=True), 100)
     generator_utils.generate_files(
-        algorithmic.identity_generator(self.num_symbols, 400, 10000),
+        identity_problem.generator(self.num_symbols, 400, 10000),
         self.dev_filepaths(data_dir, 1, shuffled=True), 100)

From d1f9bb26d3ebaaa65d1b26069ad6253b628aefd4 Mon Sep 17 00:00:00 2001
From: Ryan Sepassi
Date: Fri, 11 Aug 2017 14:01:02 -0700
Subject: [PATCH 5/7] Fix memory usage of rev_block

PiperOrigin-RevId: 165021509
---
 tensor2tensor/layers/rev_block.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/tensor2tensor/layers/rev_block.py b/tensor2tensor/layers/rev_block.py
index 1e1a7b848..d6fb95cf3 100644
--- a/tensor2tensor/layers/rev_block.py
+++ b/tensor2tensor/layers/rev_block.py
@@ -41,7 +41,7 @@ def _rev_layer_forward(xs, f, g):
     y1 = x1 + f(x2)
   with tf.variable_scope("g"):
     y2 = x2 + g(y1)
-  return (y1, y2)
+  return tf.tuple([y1, y2])
 
 
 def _rev_layer_backward(ys, grad_ys, f, g, f_vars, g_vars):
@@ -65,17 +65,26 @@ def _rev_layer_backward(ys, grad_ys, f, g, f_vars, g_vars):
 
   # Compute gradients wrt to inputs
   # dL/dy2 * dG(y1)/y1
-  grad_gy1_y2 = tf.gradients(gy1, y1_stop, grad_y2)[0]
+  grad_gy1_y2 = tf.gradients(gy1, y1_stop, grad_y2, gate_gradients=True)[0]
   grad_x1 = grad_y1 + grad_gy1_y2
-  grad_x2 = (tf.gradients(fx2, x2_stop, grad_y1)[0] + grad_y2 + tf.gradients(
-      fx2, x2_stop, grad_gy1_y2)[0])
+  grad_x2 = (
+      tf.gradients(fx2, x2_stop, grad_y1, gate_gradients=True)[0] + grad_y2 +
+      tf.gradients(fx2, x2_stop, grad_gy1_y2, gate_gradients=True)[0])
 
   # Compute gradients wrt to vars in f and g
-  grad_g_vars = tf.gradients(gy1, g_vars, grad_y2)
-  grad_f_y1 = tf.gradients(fx2, f_vars, grad_y1)
-  grad_f_y2 = tf.gradients(fx2, f_vars, grad_gy1_y2)
+  grad_g_vars = tf.gradients(gy1, g_vars, grad_y2, gate_gradients=True)
+  grad_f_y1 = tf.gradients(fx2, f_vars, grad_y1, gate_gradients=True)
+  grad_f_y2 = tf.gradients(fx2, f_vars, grad_gy1_y2, gate_gradients=True)
   grad_f_vars = [tf.add_n(grads) for grads in zip(grad_f_y1, grad_f_y2)]
 
+
# Put returns in a tuple to ensure a constant memory budget (i.e. don't want + # the subsequent layer to start computing and consuming memory based on a + # subset of these values). + outs = tf.tuple([x1, x2, grad_x1, grad_x2] + grad_f_vars + grad_g_vars) + x1, x2, grad_x1, grad_x2 = outs[:4] + grad_f_vars = outs[4:4 + len(grad_f_vars)] + grad_g_vars = outs[4 + len(grad_f_vars):] + return (x1, x2), (grad_x1, grad_x2), grad_f_vars, grad_g_vars From b31b3ae341407139ea0c52e8e813896db866f56e Mon Sep 17 00:00:00 2001 From: Lukasz Kaiser Date: Fri, 11 Aug 2017 15:10:45 -0700 Subject: [PATCH 6/7] Play more with VAE, small corrections elsewhere. PiperOrigin-RevId: 165031077 --- tensor2tensor/layers/modalities.py | 5 +- tensor2tensor/models/cycle_gan.py | 14 +-- tensor2tensor/models/shake_shake.py | 2 - tensor2tensor/models/transformer.py | 1 - tensor2tensor/models/transformer_vae.py | 129 ++++++++++++++++-------- 5 files changed, 98 insertions(+), 53 deletions(-) diff --git a/tensor2tensor/layers/modalities.py b/tensor2tensor/layers/modalities.py index acaacbf99..84f9adbe7 100644 --- a/tensor2tensor/layers/modalities.py +++ b/tensor2tensor/layers/modalities.py @@ -406,10 +406,11 @@ def top(self, body_output, _): # Assume input is a square with self._body_input_depth channels. if self._is_2d: length_float = tf.to_float(tf.shape(x)[1]) + length_float *= tf.to_float(tf.shape(x)[2]) spatial_dim_float = tf.sqrt(length_float) spatial_dim = tf.to_int32(spatial_dim_float) - x = tf.reshape(x, - [-1, spatial_dim, spatial_dim, self._body_input_depth]) + x_depth = int(x.get_shape()[3]) + x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth]) x = common_layers.conv_block_downsample(x, self._kernel, self._strides, self._padding) x = tf.nn.relu(x) diff --git a/tensor2tensor/models/cycle_gan.py b/tensor2tensor/models/cycle_gan.py index 5fcf96266..c17becbbe 100644 --- a/tensor2tensor/models/cycle_gan.py +++ b/tensor2tensor/models/cycle_gan.py @@ -39,7 +39,7 @@ def discriminator(x, compress, hparams, name, reuse=None): with tf.variable_scope(name, reuse=reuse): x = tf.stop_gradient(2 * x) - x # Reverse gradient. if compress: - x = transformer_vae.compress(x, hparams, "compress") + x = transformer_vae.compress(x, None, hparams, "compress") else: x = transformer_vae.residual_conv(x, 1, hparams, "compress_rc") y = tf.reduce_mean(x, axis=1) @@ -144,12 +144,12 @@ def cycle_vae_gan_internal(inputs, targets, _, hparams): # Input-input part. inp1_back, kl_loss1, inp1_mu, inp1_log_sigma = transformer_vae.vae_compress( - inputs1, hparams, "inp2hyp", "hyp2inp") + inputs1, None, hparams, "inp2hyp", "hyp2inp") inp1_hyp = tf.concat([inp1_mu, inp1_log_sigma], axis=3) # Target-target part. tgt2_back, kl_loss2, tgt2_mu, tgt2_log_sigma = transformer_vae.vae_compress( - targets2, hparams, "tgt2hyp", "hyp2tgt") + targets2, None, hparams, "tgt2hyp", "hyp2tgt") tgt2_hyp = tf.concat([tgt2_mu, tgt2_log_sigma], axis=3) # Reconstruction losses. @@ -165,7 +165,7 @@ def cycle_vae_gan_internal(inputs, targets, _, hparams): # Reconstruct targets from inputs. tgt, _, _, _ = transformer_vae.vae_compress( - inputs, hparams, "inp2hyp", "hyp2tgt", reuse=True) + inputs, None, hparams, "inp2hyp", "hyp2tgt", reuse=True) tgt = tf.layers.dense(tgt, hparams.vocab_size, name="softmax", reuse=True) # We use the reconstruction only for tracking progress, no gradients here! 
tgt = tf.stop_gradient(tf.expand_dims(tgt, axis=2)) @@ -173,8 +173,8 @@ def cycle_vae_gan_internal(inputs, targets, _, hparams): kl_rev_decay = common_layers.inverse_exp_decay(hparams.kl_warmup_steps) losses = {"input_input": hparams.cycle_loss_multiplier * inp1_loss, "target_target": hparams.cycle_loss_multiplier * tgt2_loss, - "input_kl": kl_loss1 * kl_rev_decay, - "target_kl": kl_loss2 * kl_rev_decay, + "input_kl": kl_loss1 * kl_rev_decay * 15.0, + "target_kl": kl_loss2 * kl_rev_decay * 15.0, "discriminator": dloss} return tgt, losses @@ -196,7 +196,7 @@ def cycle_gan_small(): hparams.input_modalities = "inputs:symbol:identity" hparams.target_modality = "symbol:identity" hparams.weight_decay = 3.0 - hparams.learning_rate = 0.005 + hparams.learning_rate = 0.05 hparams.kl_warmup_steps = 5000 hparams.learning_rate_warmup_steps = 3000 hparams.add_hparam("vocab_size", 32) # Vocabulary size, need to set here. diff --git a/tensor2tensor/models/shake_shake.py b/tensor2tensor/models/shake_shake.py index aa91654a3..a7b379e11 100644 --- a/tensor2tensor/models/shake_shake.py +++ b/tensor2tensor/models/shake_shake.py @@ -100,8 +100,6 @@ class ShakeShake(t2t_model.T2TModel): def model_fn_body(self, features): hparams = self._hparams - print(hparams.learning_rate) - inputs = features["inputs"] assert (hparams.num_hidden_layers - 2) % 6 == 0 blocks_per_stage = (hparams.num_hidden_layers - 2) // 6 diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index 06f49b231..0eed2dbdb 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -244,7 +244,6 @@ def transformer_decoder(decoder_input, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout) x = common_layers.layer_postprocess(x, y, hparams) if encoder_output is not None: - assert encoder_decoder_attention_bias is not None with tf.variable_scope("encdec_attention"): y = common_attention.multihead_attention( common_layers.layer_preprocess( diff --git a/tensor2tensor/models/transformer_vae.py b/tensor2tensor/models/transformer_vae.py index 74f1e4c8f..ffd791a04 100644 --- a/tensor2tensor/models/transformer_vae.py +++ b/tensor2tensor/models/transformer_vae.py @@ -23,6 +23,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin +from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_layers from tensor2tensor.models import transformer from tensor2tensor.utils import registry @@ -49,13 +50,43 @@ def residual_conv(x, repeat, hparams, name, reuse=None): return x -def decompress_step(source, hparams, first_relu, name): +def attend(x, source, hparams, name): + with tf.variable_scope(name): + x = tf.squeeze(x, axis=2) + if len(source.get_shape()) > 3: + source = tf.squeeze(source, axis=2) + source = common_attention.add_timing_signal_1d(source) + y = common_attention.multihead_attention( + common_layers.layer_preprocess(x, hparams), source, None, + hparams.attention_key_channels or hparams.hidden_size, + hparams.attention_value_channels or hparams.hidden_size, + hparams.hidden_size, hparams.num_heads, + hparams.attention_dropout) + res = common_layers.layer_postprocess(x, y, hparams) + return tf.expand_dims(res, axis=2) + + +def interleave(x, y, axis=1): + x = tf.expand_dims(x, axis=axis+1) + y = tf.expand_dims(y, axis=axis+1) + return tf.concat([x, y], axis=axis+1) + + +def decompress_step(source, c, hparams, first_relu, name): """Decompression function.""" with tf.variable_scope(name): shape = tf.shape(source) - thicker = 
common_layers.conv_block( - source, hparams.hidden_size * 2, [((1, 1), (1, 1))], - first_relu=first_relu, name="decompress_conv") + if c is not None: + source = attend(source, c, hparams, "decompress_attend") + first = common_layers.conv_block( + source, + hparams.hidden_size, [((1, 1), (3, 1)), ((1, 1), (3, 1))], + first_relu=first_relu, padding="SAME", name="decompress_conv1") + second = common_layers.conv_block( + tf.concat([source, first], axis=3), + hparams.hidden_size, [((1, 1), (3, 1)), ((1, 1), (3, 1))], + first_relu=first_relu, padding="SAME", name="decompress_conv2") + thicker = interleave(first, second) return tf.reshape(thicker, [shape[0], shape[1] * 2, 1, hparams.hidden_size]) @@ -71,12 +102,14 @@ def vae(x, hparams, name): return z, tf.reduce_mean(kl), mu, log_sigma -def compress(inputs, hparams, name): +def compress(x, c, hparams, name): """Compress.""" with tf.variable_scope(name): # Run compression by strided convs. - cur = inputs + cur = x for i in xrange(hparams.num_compress_steps): + if c is not None: + cur = attend(cur, c, hparams, "compress_attend_%d" % i) cur = residual_conv(cur, 1, hparams, "compress_rc_%d" % i) cur = common_layers.conv_block( cur, hparams.hidden_size, [((1, 1), (2, 1))], @@ -84,10 +117,10 @@ def compress(inputs, hparams, name): return cur -def vae_compress(inputs, hparams, compress_name, decompress_name, reuse=None): +def vae_compress(x, c, hparams, compress_name, decompress_name, reuse=None): """Compress, then VAE.""" with tf.variable_scope(compress_name, reuse=reuse): - cur = compress(inputs, hparams, "compress") + cur = compress(x, c, hparams, "compress") # Convolve and ReLu to get state. cur = common_layers.conv_block( cur, hparams.hidden_size, [((1, 1), (1, 1))], name="mid_conv") @@ -100,7 +133,7 @@ def vae_compress(inputs, hparams, compress_name, decompress_name, reuse=None): for i in xrange(hparams.num_compress_steps): j = hparams.num_compress_steps - i - 1 z = residual_conv(z, 1, hparams, "decompress_rc_%d" % j) - z = decompress_step(z, hparams, i > 0, "decompress__step_%d" % j) + z = decompress_step(z, c, hparams, i > 0, "decompress__step_%d" % j) return z, kl_loss, mu, log_sigma @@ -124,6 +157,13 @@ def dropmask(targets, targets_dropout_max, is_training): return targets * keep_mask +def ffn(x, hparams, name): + with tf.variable_scope(name): + y = transformer.transformer_ffn_layer( + common_layers.layer_preprocess(x, hparams), hparams) + return common_layers.layer_postprocess(x, y, hparams) + + def vae_transformer_internal(inputs, targets, target_space, hparams): """VAE Transformer, main step used for training.""" with tf.variable_scope("vae_transformer"): @@ -140,36 +180,40 @@ def vae_transformer_internal(inputs, targets, target_space, hparams): inputs = encode(inputs, target_space, hparams, "input_enc") # Dropout targets or swap for zeros 5% of the time. + targets_nodrop = targets max_prestep = hparams.kl_warmup_steps prob_targets = 0.95 if is_training else 1.0 targets_dropout_max = common_layers.inverse_lin_decay(max_prestep) - 0.01 targets = dropmask(targets, targets_dropout_max * 0.7, is_training) targets = tf.cond(tf.less(tf.random_uniform([]), prob_targets), lambda: targets, lambda: tf.zeros_like(targets)) - - # Join targets with inputs, run encoder. 
- # to_encode = common_layers.conv_block( - # tf.expand_dims(tf.concat([targets, inputs], axis=2), axis=2), - # hparams.hidden_size, [((1, 1), (1, 1))], - # first_relu=False, name="join_targets") - # to_compress = encode(tf.squeeze(to_encode, axis=2), - # target_space, hparams, "enc") + targets = targets_nodrop # Compress and vae. - z, kl_loss, _, _ = vae_compress(tf.expand_dims(targets, axis=2), hparams, - "vae_compress", "vae_decompress") + z = tf.get_variable("z", [hparams.hidden_size]) + z = tf.reshape(z, [1, 1, 1, -1]) + z = tf.tile(z, [tf.shape(inputs)[0], 1, 1, 1]) + + z = attend(z, inputs, hparams, "z_attendsi") + z = ffn(z, hparams, "zff2") + z = attend(z, targets, hparams, "z_attendst2") + z = ffn(z, hparams, "zff3") + z, kl_loss, _, _ = vae(z, hparams, name="vae") + z = tf.layers.dense(z, hparams.hidden_size, name="z_to_dense") + + # z, kl_loss, _, _ = vae_compress( + # tf.expand_dims(targets, axis=2), tf.expand_dims(inputs, axis=2), + # hparams, "vae_compress", "vae_decompress") - # Join z with inputs, run decoder. - to_decode = common_layers.conv_block( - tf.concat([z, tf.expand_dims(inputs, axis=2)], axis=3), - hparams.hidden_size, [((1, 1), (1, 1))], name="join_z") - ret = encode(tf.squeeze(to_decode, axis=2), target_space, hparams, "dec") - # to_decode = residual_conv(to_decode, 2, hparams, "dec_conv") - # ret = tf.squeeze(to_decode, axis=2) + decoder_in = tf.squeeze(z, axis=2) + tf.zeros_like(targets) + (decoder_input, decoder_self_attention_bias) = ( + transformer.transformer_prepare_decoder(decoder_in, hparams)) + ret = transformer.transformer_decoder( + decoder_input, inputs, decoder_self_attention_bias, None, hparams) - # Randomize decoder inputs.. - kl_loss *= common_layers.inverse_exp_decay(max_prestep) * 10.0 - return tf.expand_dims(ret, axis=2), kl_loss + kl_loss *= common_layers.inverse_exp_decay(int(max_prestep * 1.5)) * 5.0 + losses = {"kl": kl_loss} + return tf.expand_dims(ret, axis=2), losses @registry.register_model @@ -203,13 +247,15 @@ def infer(self, features=None, decode_length=50, beam_size=1, top_beams=1, sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4) samples = tf.concat(sharded_samples, 0) - # 2nd step. - with tf.variable_scope(tf.get_variable_scope(), reuse=True): - features["targets"] = samples - sharded_logits, _ = self.model_fn( - features, False, last_position_only=last_position_only) - sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4) - samples = tf.concat(sharded_samples, 0) + # More steps. + how_many_more_steps = 20 + for _ in xrange(how_many_more_steps): + with tf.variable_scope(tf.get_variable_scope(), reuse=True): + features["targets"] = samples + sharded_logits, _ = self.model_fn( + features, False, last_position_only=last_position_only) + sharded_samples = self._data_parallelism(tf.argmax, sharded_logits, 4) + samples = tf.concat(sharded_samples, 0) if inputs_old is not None: # Restore to not confuse Estimator. 
features["inputs"] = inputs_old @@ -221,9 +267,10 @@ def transformer_vae_small(): """Set of hyperparameters.""" hparams = transformer.transformer_small() hparams.batch_size = 2048 + hparams.learning_rate_warmup_steps = 16000 hparams.add_hparam("z_size", 128) hparams.add_hparam("num_compress_steps", 4) - hparams.add_hparam("kl_warmup_steps", 50000) + hparams.add_hparam("kl_warmup_steps", 60000) return hparams @@ -233,9 +280,9 @@ def transformer_vae_base(): hparams = transformer_vae_small() hparams.hidden_size = 512 hparams.filter_size = 2048 - hparams.attention_dropout = 0.1 - hparams.relu_dropout = 0.1 - hparams.dropout = 0.1 - hparams.num_hidden_layers = 4 + hparams.attention_dropout = 0.0 + hparams.relu_dropout = 0.0 + hparams.dropout = 0.0 + hparams.num_hidden_layers = 3 hparams.z_size = 256 return hparams From 8abc5d29b4b22a93c4aaa9ea17aa3b3302d1da86 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Fri, 11 Aug 2017 16:16:35 -0700 Subject: [PATCH 7/7] v1.1.8 PiperOrigin-RevId: 165038950 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4ada714b6..ff1503990 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.1.7', + version='1.1.8', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com',