
Merge pull request #220 from vthorsteinsson/ice
Move Icelandic parsing problem to separate module
lukaszkaiser committed Aug 11, 2017
2 parents 73f0be2 + ab9b004 commit b669110
Showing 10 changed files with 128 additions and 74 deletions.
10 changes: 0 additions & 10 deletions tensor2tensor/bin/t2t-datagen
100644 → 100755
@@ -82,16 +82,6 @@ _SUPPORTED_PROBLEM_GENERATORS = {
"algorithmic_algebra_inverse": (
lambda: algorithmic_math.algebra_inverse(26, 0, 2, 100000),
lambda: algorithmic_math.algebra_inverse(26, 3, 3, 10000)),
"ice_parsing_tokens": (
lambda: wmt.tabbed_parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, "ice", 2**13, 2**8),
lambda: wmt.tabbed_parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, "ice", 2**13, 2**8)),
"ice_parsing_characters": (
lambda: wmt.tabbed_parsing_character_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True),
lambda: wmt.tabbed_parsing_character_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False)),
"wmt_parsing_tokens_8k": (
lambda: wmt.parsing_token_generator(
FLAGS.data_dir, FLAGS.tmp_dir, True, 2**13),
Empty file modified tensor2tensor/bin/t2t-trainer
100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/all_problems.py
100644 → 100755
@@ -31,6 +31,7 @@
from tensor2tensor.data_generators import wiki
from tensor2tensor.data_generators import wmt
from tensor2tensor.data_generators import wsj_parsing
from tensor2tensor.data_generators import ice_parsing


# Problem modules that require optional dependencies
Empty file modified tensor2tensor/data_generators/generator_utils.py
100644 → 100755
Empty file.
117 changes: 117 additions & 0 deletions tensor2tensor/data_generators/ice_parsing.py
@@ -0,0 +1,117 @@
# Copyright 2017 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This module implements the ice_parsing_* problems, which
# parse plain text into flattened parse trees and POS tags.
# The training data is stored in files named `parsing_train.pairs`
# and `parsing_dev.pairs`. These files are UTF-8 text files where
# each line contains an input sentence and a target parse tree,
# separated by a tab character.
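# An illustrative, hypothetical line (the bracketing below is made up and
# not taken from any corpus), shown only to make the layout concrete:
#
#   Hún las bókina .<TAB>( S ( NP Hún ) ( VP las ( NP bókina ) ) ( . . ) )
#
# i.e. a tokenized source sentence, a tab character, and the flattened
# target parse with POS tags.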

import os

# Dependency imports

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.data_generators.wmt import tabbed_generator
from tensor2tensor.utils import registry

import tensorflow as tf


# End-of-sentence marker.
EOS = text_encoder.EOS_ID


def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
source_vocab_size, target_vocab_size):
"""Generate source and target data from a single file."""
filename = "parsing_{0}.pairs".format("train" if train else "dev")
source_vocab = generator_utils.get_or_generate_tabbed_vocab(
data_dir, tmp_dir, filename, 0,
prefix + "_source.tokens.vocab.%d" % source_vocab_size, source_vocab_size)
target_vocab = generator_utils.get_or_generate_tabbed_vocab(
data_dir, tmp_dir, filename, 1,
prefix + "_target.tokens.vocab.%d" % target_vocab_size, target_vocab_size)
pair_filepath = os.path.join(tmp_dir, filename)
return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)


def tabbed_parsing_character_generator(tmp_dir, train):
"""Generate source and target data from a single file."""
character_vocab = text_encoder.ByteTextEncoder()
filename = "parsing_{0}.pairs".format("train" if train else "dev")
pair_filepath = os.path.join(tmp_dir, filename)
return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)


@registry.register_problem("ice_parsing_tokens")
class IceParsingTokens(problem.Problem):
"""Problem spec for parsing tokenized Icelandic text to
constituency trees, also tokenized but to a smaller vocabulary."""

@property
def source_vocab_size(self):
return 2**14 # 16384

@property
def targeted_vocab_size(self):
return 2**8 # 256

@property
def input_space_id(self):
return problem.SpaceID.ICE_TOK

@property
def target_space_id(self):
return problem.SpaceID.ICE_PARSE_TOK

@property
def num_shards(self):
return 10

def feature_encoders(self, data_dir):
source_vocab_filename = os.path.join(
data_dir, "ice_source.tokens.vocab.%d" % self.source_vocab_size)
target_vocab_filename = os.path.join(
data_dir, "ice_target.tokens.vocab.%d" % self.targeted_vocab_size)
source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
return {
"inputs": source_subtokenizer,
"targets": target_subtokenizer,
}

def generate_data(self, data_dir, tmp_dir, task_id=-1):
generator_utils.generate_dataset_and_shuffle(
tabbed_parsing_token_generator(data_dir, tmp_dir, True, "ice",
self.source_vocab_size,
self.targeted_vocab_size),
self.training_filepaths(data_dir, self.num_shards, shuffled=False),
tabbed_parsing_token_generator(data_dir, tmp_dir, False, "ice",
self.source_vocab_size,
self.targeted_vocab_size),
self.dev_filepaths(data_dir, 1, shuffled=False))

def hparams(self, defaults, model_hparams):
p = defaults
source_vocab_size = self._encoders["inputs"].vocab_size
p.input_modality = {"inputs": (registry.Modalities.SYMBOL, source_vocab_size)}
p.target_modality = (registry.Modalities.SYMBOL, self.targeted_vocab_size)
p.input_space_id = self.input_space_id
p.target_space_id = self.target_space_id
p.loss_multiplier = 2.5 # Rough estimate of avg number of tokens per word
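
For orientation, a minimal usage sketch of the newly registered problem (not part of the commit; it assumes the 2017-era registry API, i.e. registry.problem, and the directory paths are placeholders):

    # Importing all_problems pulls in ice_parsing, which runs the
    # @registry.register_problem decorator as a side effect.
    from tensor2tensor.data_generators import all_problems  # noqa
    from tensor2tensor.utils import registry

    ice = registry.problem("ice_parsing_tokens")
    # Expects parsing_train.pairs / parsing_dev.pairs under tmp_dir.
    ice.generate_data(data_dir="/tmp/t2t_data", tmp_dir="/tmp/t2t_tmp")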

37 changes: 0 additions & 37 deletions tensor2tensor/data_generators/problem_hparams.py
100644 → 100755
@@ -462,39 +462,6 @@ def wsj_parsing_tokens(model_hparams, prefix, wrong_source_vocab_size,
return p


def ice_parsing_tokens(model_hparams, wrong_source_vocab_size):
"""Icelandic to parse tree translation benchmark.
Args:
model_hparams: a tf.contrib.training.HParams
wrong_source_vocab_size: a number used in the filename indicating the
approximate vocabulary size. This is not to be confused with the actual
vocabulary size.
Returns:
A tf.contrib.training.HParams object.
"""
p = default_problem_hparams()
# This vocab file must be present within the data directory.
source_vocab_filename = os.path.join(
model_hparams.data_dir, "ice_source.vocab.%d" % wrong_source_vocab_size)
target_vocab_filename = os.path.join(model_hparams.data_dir,
"ice_target.vocab.256")
source_subtokenizer = text_encoder.SubwordTextEncoder(source_vocab_filename)
target_subtokenizer = text_encoder.SubwordTextEncoder(target_vocab_filename)
p.input_modality = {
"inputs": (registry.Modalities.SYMBOL, source_subtokenizer.vocab_size)
}
p.target_modality = (registry.Modalities.SYMBOL, 256)
p.vocabulary = {
"inputs": source_subtokenizer,
"targets": target_subtokenizer,
}
p.input_space_id = 18 # Icelandic tokens
p.target_space_id = 19 # Icelandic parse tokens
return p


def img2img_imagenet(unused_model_hparams):
"""Image 2 Image for imagenet dataset."""
p = default_problem_hparams()
@@ -544,10 +511,6 @@ def image_celeba(unused_model_hparams):
lm1b_32k,
"wiki_32k":
wiki_32k,
"ice_parsing_characters":
wmt_parsing_characters,
"ice_parsing_tokens":
lambda p: ice_parsing_tokens(p, 2**13),
"wmt_parsing_tokens_8k":
lambda p: wmt_parsing_tokens(p, 2**13),
"wsj_parsing_tokens_16k":
22 changes: 0 additions & 22 deletions tensor2tensor/data_generators/wmt.py
100644 → 100755
@@ -648,28 +648,6 @@ def target_space_id(self):
return problem.SpaceID.CS_CHR


def tabbed_parsing_token_generator(data_dir, tmp_dir, train, prefix,
source_vocab_size, target_vocab_size):
"""Generate source and target data from a single file."""
source_vocab = generator_utils.get_or_generate_tabbed_vocab(
data_dir, tmp_dir, "parsing_train.pairs", 0,
prefix + "_source.vocab.%d" % source_vocab_size, source_vocab_size)
target_vocab = generator_utils.get_or_generate_tabbed_vocab(
data_dir, tmp_dir, "parsing_train.pairs", 1,
prefix + "_target.vocab.%d" % target_vocab_size, target_vocab_size)
filename = "parsing_%s" % ("train" if train else "dev")
pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
return tabbed_generator(pair_filepath, source_vocab, target_vocab, EOS)


def tabbed_parsing_character_generator(tmp_dir, train):
"""Generate source and target data from a single file."""
character_vocab = text_encoder.ByteTextEncoder()
filename = "parsing_%s" % ("train" if train else "dev")
pair_filepath = os.path.join(tmp_dir, filename + ".pairs")
return tabbed_generator(pair_filepath, character_vocab, character_vocab, EOS)


def parsing_token_generator(data_dir, tmp_dir, train, vocab_size):
symbolizer_vocab = generator_utils.get_or_generate_vocab(
data_dir, tmp_dir, "vocab.endefr.%d" % vocab_size, vocab_size)
2 changes: 1 addition & 1 deletion tensor2tensor/models/transformer.py
100644 → 100755
@@ -393,7 +393,7 @@ def transformer_parsing_big():

@registry.register_hparams
def transformer_parsing_ice():
"""Hparams for parsing Icelandic text."""
"""Hparams for parsing and tagging Icelandic text."""
hparams = transformer_base_single_gpu()
hparams.batch_size = 4096
hparams.shared_embedding_and_softmax_weights = int(False)
7 changes: 6 additions & 1 deletion tensor2tensor/utils/decoding.py
100644 → 100755
@@ -259,14 +259,19 @@ def _interactive_input_fn(hparams):
vocabulary = p_hparams.vocabulary["inputs" if has_input else "targets"]
# This should be longer than the longest input.
const_array_size = 10000
# Import readline if available for command line editing and recall
try:
import readline
except ImportError:
pass
while True:
prompt = ("INTERACTIVE MODE num_samples=%d decode_length=%d \n"
" it=<input_type> ('text' or 'image' or 'label')\n"
" pr=<problem_num> (set the problem number)\n"
" in=<input_problem> (set the input problem number)\n"
" ou=<output_problem> (set the output problem number)\n"
" ns=<num_samples> (changes number of samples)\n"
" dl=<decode_length> (changes decode legnth)\n"
" dl=<decode_length> (changes decode length)\n"
" <%s> (decode)\n"
" q (quit)\n"
">" % (num_samples, decode_length, "source_string"
6 changes: 3 additions & 3 deletions tensor2tensor/utils/registry.py
100644 → 100755
@@ -225,10 +225,10 @@ def parse_problem_name(problem_name):
was_copy: A boolean.
"""
# Recursively strip tags until we reach a base name.
if len(problem_name) > 4 and problem_name[-4:] == "_rev":
if problem_name.endswith("_rev"):
base, _, was_copy = parse_problem_name(problem_name[:-4])
return base, True, was_copy
elif len(problem_name) > 5 and problem_name[-5:] == "_copy":
elif problem_name.endswith("_copy"):
base, was_reversed, _ = parse_problem_name(problem_name[:-5])
return base, was_reversed, True
else:
@@ -352,7 +352,7 @@ def list_modalities():


def parse_modality_name(name):
name_parts = name.split(":")
name_parts = name.split(":", maxsplit=1)
if len(name_parts) < 2:
name_parts.append("default")
modality_type, modality_name = name_parts
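
As a side note on the registry.py changes above, a small illustrative sketch (not part of the diff; the problem and modality names are placeholders) of what the two tweaks do:

    from tensor2tensor.utils import registry

    # "_rev" and "_copy" are stripped recursively from the end of a name, so a
    # composed name resolves to (base_name, was_reversed, was_copy):
    assert registry.parse_problem_name("algorithmic_reverse_binary_rev_copy") == (
        "algorithmic_reverse_binary", True, True)

    # With maxsplit=1 only the first ":" splits the modality spec, so later
    # colons stay inside the modality name instead of breaking the unpacking:
    #   "symbol:my_modality:v2" -> type "symbol", name "my_modality:v2"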
