Merge pull request #484 from rsepassi/push
v1.4.1
lukaszkaiser authored Dec 22, 2017
2 parents 758991d + 02da1be commit c1cd875
Showing 17 changed files with 973 additions and 252 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -14,7 +14,7 @@ env:
- T2T_DATA_DIR=/tmp/t2t-data
- T2T_TRAIN_DIR=/tmp/t2t-train
script:
- pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py
- pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py
- pytest tensor2tensor/utils/registry_test.py
- pytest tensor2tensor/tpu/tpu_trainer_lib_test.py
- t2t-datagen 2>&1 | grep translate && echo passed
3 changes: 2 additions & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.4.0',
version='1.4.1',
description='Tensor2Tensor',
author='Google Inc.',
author_email='no-reply@google.com',
@@ -30,6 +30,7 @@
'gym',
'numpy',
'requests',
'scipy',
'sympy',
'six',
],
14 changes: 8 additions & 6 deletions tensor2tensor/bin/t2t-trainer
@@ -45,6 +45,7 @@ flags.DEFINE_string("t2t_usr_dir", "",
"The imported files should contain registrations, "
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-trainer.")
flags.DEFINE_integer("random_seed", 1234, "Random seed.")
flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
flags.DEFINE_integer("iterations_per_loop", 1000,
"Number of iterations in a TPU training loop.")
@@ -61,7 +62,11 @@ try:
flags.DEFINE_string("output_dir", "", "Base output directory for run.")
flags.DEFINE_string("schedule", "continuous_train_and_eval",
"Method of Experiment to run.")
flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
flags.DEFINE_integer("eval_steps", 10000,
"Number of steps in evaluation. By default, eval will "
"stop after eval_steps or when it runs through the eval "
"dataset once in full, whichever comes first, so this "
"can be a very large number.")
except: # pylint: disable=bare-except
pass

@@ -77,9 +82,6 @@ def create_hparams():


def create_experiment_fn():
use_validation_monitor = (FLAGS.schedule in
["train_and_evaluate", "continuous_train_and_eval"]
and FLAGS.local_eval_frequency)
return tpu_trainer_lib.create_experiment_fn(
model_name=FLAGS.model,
problem_name=get_problem_name(),
@@ -92,9 +94,9 @@ def create_experiment_fn():
decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
use_tfdbg=FLAGS.tfdbg,
use_dbgprofile=FLAGS.dbgprofile,
use_validation_monitor=use_validation_monitor,
eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
eval_early_stopping_metric_minimize=FLAGS.
eval_early_stopping_metric_minimize,
use_tpu=FLAGS.use_tpu)
@@ -170,7 +172,7 @@ def execute_schedule(exp):

def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
tf.set_random_seed(123)
tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
log_registry()

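The trainer now seeds all randomness from the new `--random_seed` flag instead of the hard-coded `tf.set_random_seed(123)`. The `tpu_trainer_lib.set_random_seed` implementation is not part of this diff; below is a minimal sketch of what such a helper typically does, assuming it seeds every RNG source a run touches:

```python
import random

import numpy as np
import tensorflow as tf


def set_random_seed(seed):
  """Sketch of a seeding helper matching the call in main() above."""
  tf.set_random_seed(seed)  # TensorFlow graph-level seed.
  random.seed(seed)         # Python stdlib RNG.
  np.random.seed(seed)      # NumPy RNG.


set_random_seed(1234)  # 1234 is the flag's default value.
```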
14 changes: 8 additions & 6 deletions tensor2tensor/bin/t2t_trainer.py
@@ -44,6 +44,7 @@
"The imported files should contain registrations, "
"e.g. @registry.register_model calls, that will then be "
"available to the t2t-trainer.")
flags.DEFINE_integer("random_seed", 1234, "Random seed.")
flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
flags.DEFINE_integer("iterations_per_loop", 1000,
"Number of iterations in a TPU training loop.")
@@ -60,7 +61,11 @@
flags.DEFINE_string("output_dir", "", "Base output directory for run.")
flags.DEFINE_string("schedule", "continuous_train_and_eval",
"Method of Experiment to run.")
flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
flags.DEFINE_integer("eval_steps", 10000,
"Number of steps in evaluation. By default, eval will "
"stop after eval_steps or when it runs through the eval "
"dataset once in full, whichever comes first, so this "
"can be a very large number.")
except: # pylint: disable=bare-except
pass

@@ -76,9 +81,6 @@ def create_hparams():


def create_experiment_fn():
use_validation_monitor = (FLAGS.schedule in
["train_and_evaluate", "continuous_train_and_eval"]
and FLAGS.local_eval_frequency)
return tpu_trainer_lib.create_experiment_fn(
model_name=FLAGS.model,
problem_name=get_problem_name(),
@@ -91,9 +93,9 @@ def create_experiment_fn():
decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
use_tfdbg=FLAGS.tfdbg,
use_dbgprofile=FLAGS.dbgprofile,
use_validation_monitor=use_validation_monitor,
eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
eval_early_stopping_metric_minimize=FLAGS.
eval_early_stopping_metric_minimize,
use_tpu=FLAGS.use_tpu)
@@ -169,7 +171,7 @@ def execute_schedule(exp):

def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
tf.set_random_seed(123)
tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
log_registry()

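t2t_trainer.py mirrors the t2t-trainer changes above, including the larger `eval_steps` default (200 → 10000). The bigger default is safe because, per the new help text, evaluation stops after `eval_steps` batches or one full pass over the eval dataset, whichever comes first. An illustrative sketch of that stopping rule in plain Python (not the actual Estimator internals):

```python
def run_eval(eval_batches, eval_steps):
  """Sketch: stop after eval_steps batches or when the dataset ends."""
  steps = 0
  for _ in eval_batches:
    if steps >= eval_steps:
      break  # eval_steps reached before the dataset was exhausted.
    # ... per-batch metric updates would happen here ...
    steps += 1
  return steps


assert run_eval(iter(range(300)), eval_steps=10000) == 300      # dataset ends first
assert run_eval(iter(range(50000)), eval_steps=10000) == 10000  # eval_steps hit first
```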
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/algorithmic_math_test.py
@@ -14,6 +14,7 @@
# limitations under the License.

"""Tests for tensor2tensor.data_generators.algorithmic_math."""
# TODO(rsepassi): This test is flaky. Disable, remove, or update.

from __future__ import absolute_import
from __future__ import division
203 changes: 34 additions & 169 deletions tensor2tensor/data_generators/librispeech.py
@@ -16,23 +16,14 @@
"""Librispeech dataset."""

import os
from subprocess import call
import tarfile
import wave

# Dependency imports

import numpy as np

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import modality
from tensor2tensor.data_generators import speech_recognition
from tensor2tensor.utils import registry

import tensorflow as tf


_LIBRISPEECH_TRAIN_DATASETS = [
[
@@ -86,130 +77,13 @@ def _collect_data(directory, input_ext, transcription_ext):
return data_files


def _get_audio_data(filepath):
# Construct a true .wav file.
out_filepath = filepath.strip(".flac") + ".wav"
# Assumes sox is installed on system. Sox converts from FLAC to WAV.
call(["sox", filepath, out_filepath])
wav_file = wave.open(open(out_filepath))
frame_count = wav_file.getnframes()
byte_array = wav_file.readframes(frame_count)

data = np.fromstring(byte_array, np.uint8).tolist()
return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels()


class LibrispeechTextEncoder(text_encoder.TextEncoder):

def encode(self, s):
return [self._num_reserved_ids + ord(c) for c in s]

def decode(self, ids):
"""Transform a sequence of int ids into a human-readable string.
EOS is not expected in ids.
Args:
ids: list of integers to be converted.
Returns:
s: human-readable string.
"""
decoded_ids = []
for id_ in ids:
if 0 <= id_ < self._num_reserved_ids:
decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)])
else:
decoded_ids.append(id_ - self._num_reserved_ids)
return "".join([chr(d) for d in decoded_ids])


@registry.register_audio_modality
class LibrispeechModality(modality.Modality):
"""Performs strided conv compressions for audio spectral data."""

def bottom(self, inputs):
"""Transform input from data space to model space.
Args:
inputs: A Tensor with shape [batch, ...]
Returns:
body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
"""
with tf.variable_scope(self.name):
# TODO(aidangomez): Will need to sort out a better audio pipeline
def xnet_resblock(x, filters, res_relu, name):
with tf.variable_scope(name):
# We only stride along the length dimension to preserve the spectral
# bins (which are tiny in dimensionality relative to length)
y = common_layers.separable_conv_block(
x,
filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
first_relu=True,
padding="SAME",
force2d=True,
name="sep_conv_block")
y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
return y + common_layers.conv_block(
x,
filters, [((1, 1), (1, 1))],
padding="SAME",
strides=(2, 1),
first_relu=res_relu,
force2d=True,
name="res_conv0")

# Rescale from UINT8 to floats in [-1, 1]
signals = (tf.to_float(inputs)-127)/128.
signals = tf.squeeze(signals, [2, 3])

# `stfts` is a complex64 Tensor representing the short-time Fourier
# Transform of each signal in `signals`. Its shape is
# [batch_size, ?, fft_unique_bins]
# where fft_unique_bins = fft_length // 2 + 1 = 513.
stfts = tf.contrib.signal.stft(signals, frame_length=1024, frame_step=512,
fft_length=1024)

# An energy spectrogram is the magnitude of the complex-valued STFT.
# A float32 Tensor of shape [batch_size, ?, 513].
magnitude_spectrograms = tf.abs(stfts)

# Warp the linear-scale, magnitude spectrograms into the mel-scale.
num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
sample_rate = 16000
linear_to_mel_weight_matrix = (
tf.contrib.signal.linear_to_mel_weight_matrix(
num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
upper_edge_hertz))
mel_spectrograms = tf.tensordot(
magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
# Note: Shape inference for tensordot does not currently handle this case.
mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
linear_to_mel_weight_matrix.shape[-1:]))

x = tf.expand_dims(mel_spectrograms, 2)
x.set_shape([None, None, None, num_mel_bins])
for i in xrange(self._model_hparams.audio_compression):
x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
return xnet_resblock(x, self._body_input_depth, False,
"compress_block_final")


@registry.register_problem()
class Librispeech(problem.Problem):
"""Problem spec for English word to dictionary definition."""
class Librispeech(speech_recognition.SpeechRecognitionProblem):
"""Problem spec for Librispeech using clean and noisy data."""

@property
def is_character_level(self):
return True

@property
def input_space_id(self):
return problem.SpaceID.AUDIO_SPECTRAL

@property
def target_space_id(self):
return problem.SpaceID.EN_CHR
# Select only the clean data
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS

@property
def num_shards(self):
@@ -228,26 +102,8 @@ def use_train_shards_for_dev(self):
"""If true, we only generate training data and hold out shards for dev."""
return False

def feature_encoders(self, _):
return {
"inputs": text_encoder.TextEncoder(),
"targets": LibrispeechTextEncoder(),
}

def example_reading_spec(self):
data_fields = {
"inputs": tf.VarLenFeature(tf.int64),
"targets": tf.VarLenFeature(tf.int64),
}
data_items_to_decoders = None
return (data_fields, data_items_to_decoders)

def generator(self, data_dir, tmp_dir, training,
def generator(self, data_dir, tmp_dir, datasets,
eos_list=None, start_from=0, how_many=0):
eos_list = [1] if eos_list is None else eos_list
datasets = (_LIBRISPEECH_TRAIN_DATASETS if training
else _LIBRISPEECH_TEST_DATASETS)
num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
i = 0
for url, subdir in datasets:
filename = os.path.basename(url)
@@ -267,44 +123,53 @@ def generator(self, data_dir, tmp_dir, training,
data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir)
data_files = _collect_data(data_dir, "flac", "txt")
data_pairs = data_files.values()

encoders = self.feature_encoders(None)
audio_encoder = encoders["waveforms"]
text_encoder = encoders["targets"]

for media_file, text_data in sorted(data_pairs)[start_from:]:
if how_many > 0 and i == how_many:
return
i += 1
audio_data, sample_count, sample_width, num_channels = _get_audio_data(
media_file)
label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
yield {
"inputs": audio_data,
"audio/channel_count": [num_channels],
"audio/sample_count": [sample_count],
"audio/sample_width": [sample_width],
"targets": label
"waveforms": audio_encoder.encode(media_file),
"targets": text_encoder.encode(text_data)
}

def generate_data(self, data_dir, tmp_dir, task_id=-1):
train_paths = self.training_filepaths(
data_dir, self.num_shards, shuffled=False)
dev_paths = self.dev_filepaths(
data_dir, self.num_dev_shards, shuffled=False)

if self.use_train_shards_for_dev:
all_paths = train_paths + dev_paths
generator_utils.generate_files(
self.generator(data_dir, tmp_dir, True), all_paths)
self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), all_paths)
generator_utils.shuffle_dataset(all_paths)
else:
generator_utils.generate_dataset_and_shuffle(
self.generator(data_dir, tmp_dir, True), train_paths,
self.generator(data_dir, tmp_dir, False), dev_paths)
self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), train_paths,
self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths)

def hparams(self, defaults, unused_model_hparams):
p = defaults
p.stop_at_eos = int(False)
p.input_modality = {"inputs": ("audio:librispeech_modality", None)}
p.target_modality = (registry.Modalities.SYMBOL, 256)

def preprocess_example(self, example, mode, hparams):
return example
@registry.register_problem()
class LibrispeechCleanSmall(Librispeech):
"""Problem spec for Librispeech using 100h clean train data."""

# Select only the clean data
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:1]
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]


@registry.register_problem()
class LibrispeechClean(Librispeech):
"""Problem spec for Librispeech using 460h clean train data."""

# Select only the clean data
TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:2]
DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]


# TODO(lukaszkaiser): clean up hparams or remove from here.
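With this refactor, `Librispeech` inherits its encoders, modalities, and input pipeline from `speech_recognition.SpeechRecognitionProblem`, and the clean-data subsets register as separate problems. Assuming the usual t2t registry convention of snake_casing class names, the new problems would be looked up as follows (a usage sketch, not code from this commit):

```python
from tensor2tensor.utils import registry

# Registered class names are snake_cased:
#   Librispeech           -> "librispeech"             (clean + noisy, all data)
#   LibrispeechCleanSmall -> "librispeech_clean_small" (100h clean train data)
#   LibrispeechClean      -> "librispeech_clean"       (460h clean train data)
problem = registry.problem("librispeech_clean_small")
```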
(Diff for the remaining 11 changed files not loaded.)
