diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen old mode 100644 new mode 100755 index 39453dbee..97bbd1241 --- a/tensor2tensor/bin/t2t-datagen +++ b/tensor2tensor/bin/t2t-datagen @@ -82,16 +82,6 @@ _SUPPORTED_PROBLEM_GENERATORS = { "algorithmic_algebra_inverse": ( lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000), lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)), - "ice_parsing_tokens": ( - lambda: wmt.tabbed_parsing_token_generator( - FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8), - lambda: wmt.tabbed_parsing_token_generator( - FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)), - "ice_parsing_characters": ( - lambda: wmt.tabbed_parsing_character_generator( - FLAGS.data_dir, FLAGS.tmp_dir, True), - lambda: wmt.tabbed_parsing_character_generator( - FLAGS.data_dir, FLAGS.tmp_dir, False)), "wmt_parsing_tokens_8k": ( lambda: wmt.parsing_token_generator( FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13), diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer old mode 100644 new mode 100755 diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py old mode 100644 new mode 100755 index ca6dccfda..10a4764f5 --- a/tensor2tensor/data_generators/all_problems.py +++ b/tensor2tensor/data_generators/all_problems.py @@ -31,6 +31,7 @@ from tensor2tensor.data_generators import wiki from tensor2tensor.data_generators import wmt from tensor2tensor.data_generators import wsj_parsing +from tensor2tensor.data_generators import ice_parsing # Problem modules that require optional dependencies diff --git a/tensor2tensor/data_generators/generator_utils.py b/tensor2tensor/data_generators/generator_utils.py old mode 100644 new mode 100755 diff --git a/tensor2tensor/data_generators/ice_parsing.py b/tensor2tensor/data_generators/ice_parsing.py new file mode 100755 index 000000000..7a90fec45 --- /dev/null +++ b/tensor2tensor/data_generators/ice_parsing.py @@ -0,0 
+1,117 @@ +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This module implements the ice_parsing_* problems, which +# parse plain text into flattened parse trees and POS tags. +# The training data is stored in files named `parsing_train.pairs` +# and `parsing_dev.pairs`. These files are UTF-8 text files where +# each line contains an input sentence and a target parse tree, +# separated by a tab character. + +import os + +# Dependency imports + +from tensor2tensor.data_generators import generator_utils +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.data_generators.wmt import tabbed_generator +from tensor2tensor.utils import registry + +import tensorflow as tf + + +# End-of-sentence marker. 
+EOS = text_encoder.EOS_ID + + +def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix, + source_vocab_size, target_vocab_size): + """Generate source and target data from a single file.""" + filename = "parsing_{0}.pairs".format("train" if train else "dev") + source_vocab = generator_utils.get_or_generate_tabbed_vocab( + data_dir, tmp_dir, filename, 0, + prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size) + target_vocab = generator_utils.get_or_generate_tabbed_vocab( + data_dir, tmp_dir, filename, 1, + prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size) + pair_filepath = os.path.join(tmp_dir, filename) + return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS) + + +def tabbed_parsing_character_generator(tmp_dir, train): + """Generate source and target data from a single file.""" + character_vocab = text_encoder.ByteTextEncoder() + filename = "parsing_{0}.pairs".format("train" if train else "dev") + pair_filepath = os.path.join(tmp_dir, filename) + return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS) + + +@registry.register_problem("ice_parsing_tokens") +class IceParsingTokens(problem.Problem): + """Problem spec for parsing tokenized Icelandic text to + constituency trees, also tokenized but to a smaller vocabulary.""" + + @property + def source_vocab_size(self): + return 2**14 # 16384 + + @property + def targeted_vocab_size(self): + return 2**8 # 256 + + @property + def input_space_id(self): + return problem.SpaceID.ICE_TOK + + @property + def target_space_id(self): + return problem.SpaceID.ICE_PARSE_TOK + + @property + def num_shards(self): + return 10 + + def feature_encoders(self, data_dir): + source_vocab_filename = os.path.join( + data_dir, "ice_source.tokens.vocab.%d" % self.source_vocab_size) + target_vocab_filename = os.path.join( + data_dir, "ice_target.tokens.vocab.%d" % self.targeted_vocab_size) + source_subtokenizer = 
text_encoder.SubwordTextEncoder(source_vocab_filename) + target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) + return { + "inputs": source_subtokenizer, + "targets": target_subtokenizer, + } + + def generate_data(self, data_dir, tmp_dir, task_id=-1): + generator_utils.generate_dataset_and_shuffle( + tabbed_parsing_token_generator(data_dir, tmp_dir, True, "ice", + self.source_vocab_size, + self.targeted_vocab_size), + self.training_filepaths(data_dir, self.num_shards, shuffled=False), + tabbed_parsing_token_generator(data_dir, tmp_dir, False, "ice", + self.source_vocab_size, + self.targeted_vocab_size), + self.dev_filepaths(data_dir, 1, shuffled=False)) + + def hparams(self, defaults, model_hparams): + p = defaults + source_vocab_size = self._encoders["inputs"].vocab_size + p.input_modality = {"inputs": (registry.Modalities.SYMBOL, source_vocab_size)} + p.target_modality = (registry.Modalities.SYMBOL, self.targeted_vocab_size) + p.input_space_id = self.input_space_id + p.target_space_id = self.target_space_id + p.loss_multiplier = 2.5 # Rough estimate of avg number of tokens per word + diff --git a/tensor2tensor/data_generators/problem_hparams.py b/tensor2tensor/data_generators/problem_hparams.py old mode 100644 new mode 100755 index d0577db52..b0ed44f5b --- a/tensor2tensor/data_generators/problem_hparams.py +++ b/tensor2tensor/data_generators/problem_hparams.py @@ -462,39 +462,6 @@ def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size, return p -def ice_parsing_tokens(model_hparams, wrong_source_vocab_size): - """Icelandic to parse tree translation benchmark. - - Args: - model_hparams: a tf.contrib.training.HParams - wrong_source_vocab_size: a number used in the filename indicating the - approximate vocabulary size. This is not to be confused with the actual - vocabulary size. - - Returns: - A tf.contrib.training.HParams object. 
- """ - p = default_problem_hparams() - # This vocab file must be present within the data directory. - source_vocab_filename = os.path.join( - model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size) - target_vocab_filename = os.path.join(model_hparams.data_dir, - "ice_target.vocab.256") - source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename) - target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename) - p.input_modality = { - "inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size) - } - p.target_modality = (registry.Modalities.SYMBOL, 256) - p.vocabulary = { - "inputs": source_subtokenizer, - "targets": target_subtokenizer, - } - p.input_space_id = 18 # Icelandic tokens - p.target_space_id = 19 # Icelandic parse tokens - return p - - def img2img_imagenet(unused_model_hparams): """Image 2 Image for imagenet dataset.""" p = default_problem_hparams() @@ -544,10 +511,6 @@ def image_celeba(unused_model_hparams): lm1b_32k, "wiki_32k": wiki_32k, - "ice_parsing_characters": - wmt_parsing_characters, - "ice_parsing_tokens": - lambda p: ice_parsing_tokens(p, 2**13), "wmt_parsing_tokens_8k": lambda p: wmt_parsing_tokens(p, 2**13), "wsj_parsing_tokens_16k": diff --git a/tensor2tensor/data_generators/wmt.py b/tensor2tensor/data_generators/wmt.py old mode 100644 new mode 100755 index 0a47e9989..35d1b5fca --- a/tensor2tensor/data_generators/wmt.py +++ b/tensor2tensor/data_generators/wmt.py @@ -648,28 +648,6 @@ def target_space_id(self): return problem.SpaceID.CS_CHR -def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix, - source_vocab_size, target_vocab_size): - """Generate source and target data from a single file.""" - source_vocab = generator_utils.get_or_generate_tabbed_vocab( - data_dir, tmp_dir, "parsing_train.pairs", 0, - prefix + "_source.vocab.%d" % source_vocab_size, source_vocab_size) - target_vocab = generator_utils.get_or_generate_tabbed_vocab( - data_dir, tmp_dir, 
"parsing_train.pairs", 1, - prefix + "_target.vocab.%d" % target_vocab_size, target_vocab_size) - filename = "parsing_%s" % ("train" if train else "dev") - pair_filepath = os.path.join(tmp_dir, filename + ".pairs") - return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS) - - -def tabbed_parsing_character_generator(tmp_dir, train): - """Generate source and target data from a single file.""" - character_vocab = text_encoder.ByteTextEncoder() - filename = "parsing_%s" % ("train" if train else "dev") - pair_filepath = os.path.join(tmp_dir, filename + ".pairs") - return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS) - - def parsing_token_generator(data_dir, tmp_dir, train, vocab_size): symbolizer_vocab = generator_utils.get_or_generate_vocab( data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size) diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py old mode 100644 new mode 100755 index caf8ab198..fa7ecdf81 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -393,7 +393,7 @@ def transformer_parsing_big(): @registry.register_hparams def transformer_parsing_ice(): - """Hparams for parsing Icelandic text.""" + """Hparams for parsing and tagging Icelandic text.""" hparams = transformer_base_single_gpu() hparams.batch_size = 4096 hparams.shared_embedding_and_softmax_weights = int(False) diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py old mode 100644 new mode 100755 index 5e8f4d482..fc9eb566f --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -259,6 +259,11 @@ def _interactive_input_fn(hparams): vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"] # This should be longer than the longest input. 
const_array_size = 10000 + # Import readline if available for command line editing and recall + try: + import readline + except ImportError: + pass while True: prompt = ("INTERACTIVE MODE num_samples=%d decode_length=%d \n" " it= ('text' or 'image' or 'label')\n" @@ -266,7 +271,7 @@ def _interactive_input_fn(hparams): " in= (set the input problem number)\n" " ou= (set the output problem number)\n" " ns= (changes number of samples)\n" - " dl= (changes decode legnth)\n" + " dl= (changes decode length)\n" " <%s> (decode)\n" " q (quit)\n" ">" % (num_samples, decode_length, "source_string" diff --git a/tensor2tensor/utils/registry.py b/tensor2tensor/utils/registry.py old mode 100644 new mode 100755 index fea647b2b..d79eef484 --- a/tensor2tensor/utils/registry.py +++ b/tensor2tensor/utils/registry.py @@ -225,10 +225,10 @@ def parse_problem_name(problem_name): was_copy: A boolean. """ # Recursively strip tags until we reach a base name. - if len(problem_name) > 4 and problem_name[-4:] == "_rev": + if problem_name.endswith("_rev"): base, _, was_copy = parse_problem_name(problem_name[:-4]) return base, True, was_copy - elif len(problem_name) > 5 and problem_name[-5:] == "_copy": + elif problem_name.endswith("_copy"): base, was_reversed, _ = parse_problem_name(problem_name[:-5]) return base, was_reversed, True else: @@ -352,7 +352,7 @@ def list_modalities(): def parse_modality_name(name): - name_parts = name.split(":") + name_parts = name.split(":", 1) if len(name_parts) < 2: name_parts.append("default") modality_type, modality_name = name_parts