Merge pull request NVIDIA#96 from NVIDIA/clean_slate
Clean slate
rafaelvalle authored Nov 27, 2018
2 parents fc0cf6a + ba8cf36 commit f02704f
Showing 13 changed files with 232 additions and 180 deletions.
33 changes: 19 additions & 14 deletions README.md
100644 → 100755
@@ -1,6 +1,6 @@
 # Tacotron 2 (without wavenet)
 
-Tacotron 2 PyTorch implementation of [Natural TTS Synthesis By Conditioning
+PyTorch implementation of [Natural TTS Synthesis By Conditioning
 Wavenet On Mel Spectrogram Predictions](https://arxiv.org/pdf/1712.05884.pdf).
 
 This implementation includes **distributed** and **fp16** support
@@ -11,9 +11,7 @@ Distributed and FP16 support relies on work by Christian Sarofeen and NVIDIA's
 
 ![Alignment, Predicted Mel Spectrogram, Target Mel Spectrogram](tensorboard.png)
 
-[Download demo audio](https://github.com/NVIDIA/tacotron2/blob/master/demo.wav) trained on LJS and using Ryuchi Yamamoto's [pre-trained Mixture of Logistics
-wavenet](https://github.com/r9y9/wavenet_vocoder/)
-"Scientists at the CERN laboratory say they have discovered a new particle."
+Visit our [website] for audio samples.
 
 ## Pre-requisites
 1. NVIDIA GPU + CUDA cuDNN
@@ -24,11 +22,9 @@ wavenet](https://github.com/r9y9/wavenet_vocoder/)
 3. CD into this repo: `cd tacotron2`
 4. Update .wav paths: `sed -i -- 's,DUMMY,ljs_dataset_folder/wavs,g' filelists/*.txt`
     - Alternatively, set `load_mel_from_disk=True` in `hparams.py` and update mel-spectrogram paths
-5. Install [pytorch 0.4](https://github.com/pytorch/pytorch)
+5. Install [PyTorch 1.0]
 6. Install python requirements or build docker image
     - Install python requirements: `pip install -r requirements.txt`
-    - **OR**
-    - Build docker image: `docker build --tag tacotron2 .`
 
 ## Training
 1. `python train.py --output_directory=outdir --log_directory=logdir`
@@ -37,17 +33,22 @@ wavenet](https://github.com/r9y9/wavenet_vocoder/)
 
 ## Multi-GPU (distributed) and FP16 Training
 1. `python -m multiproc train.py --output_directory=outdir --log_directory=logdir --hparams=distributed_run=True,fp16_run=True`
 
-## Inference
-When performing Mel-Spectrogram to Audio synthesis with a WaveNet model, make sure Tacotron 2 and WaveNet were trained on the same mel-spectrogram representation. Follow these steps to use a a simple inference pipeline using griffin-lim:
-
-1. `jupyter notebook --ip=127.0.0.1 --port=31337`
-2. load inference.ipynb
+## Inference demo
+1. Download our published [Tacotron 2] model
+2. Download our published [WaveGlow] model
+3. `jupyter notebook --ip=127.0.0.1 --port=31337`
+4. Load inference.ipynb
 
+N.b. When performing Mel-Spectrogram to Audio synthesis, make sure Tacotron 2
+and the Mel decoder were trained on the same mel-spectrogram representation.
 
 
 ## Related repos
-[nv-wavenet](https://github.com/NVIDIA/nv-wavenet/): Faster than real-time
-wavenet inference
+[WaveGlow](https://github.com/NVIDIA/WaveGlow) Faster than real time Flow-based
+Generative Network for Speech Synthesis
+
+[nv-wavenet](https://github.com/NVIDIA/nv-wavenet/) Faster than real time
+WaveNet.
 
 ## Acknowledgements
 This implementation uses code from the following repos: [Keith
@@ -61,3 +62,7 @@ We are thankful to the Tacotron 2 paper authors, specially Jonathan Shen, Yuxuan
 Wang and Zongheng Yang.
 
 
+[WaveGlow]: https://drive.google.com/file/d/1cjKPHbtAMh_4HTHmuIGNkbOkPBD9qwhj/view?usp=sharing
+[Tacotron 2]: https://drive.google.com/file/d/1c5ZTuT7J08wLUoVZ2KkUs_VdZuJ86ZqA/view?usp=sharing
+[pytorch 1.0]: https://github.com/pytorch/pytorch#installation
+[website]: https://nv-adlr.github.io/WaveGlow
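The new README points users at published checkpoints instead of the old Griffin-Lim pipeline. Below is a rough sketch of what the updated inference.ipynb does with them; the checkpoint filenames are placeholders, and the `load_model` and `text_to_sequence` helpers are assumed from this repo's `train.py` and `text` module (loading the WaveGlow checkpoint also requires its code on the Python path):

```python
import numpy as np
import torch

from hparams import create_hparams
from text import text_to_sequence
from train import load_model

hparams = create_hparams()

# Placeholder paths -- substitute the downloaded Tacotron 2 / WaveGlow files.
model = load_model(hparams)
model.load_state_dict(torch.load("tacotron2_statedict.pt")['state_dict'])
model.cuda().eval()

waveglow = torch.load("waveglow_checkpoint.pt")['model']
waveglow.cuda().eval()

text = "Scientists at the CERN laboratory say they have discovered a new particle."
sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
sequence = torch.from_numpy(sequence).cuda().long()

# Tacotron 2 predicts a mel spectrogram from text; WaveGlow vocodes it to audio.
mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
with torch.no_grad():
    audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)
```

Griffin-Lim is no longer involved; the only requirement, as the N.b. above warns, is that both models were trained on the same mel-spectrogram representation.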
19 changes: 10 additions & 9 deletions data_utils.py
@@ -14,9 +14,8 @@ class TextMelLoader(torch.utils.data.Dataset):
         2) normalizes text and converts them to sequences of one-hot vectors
         3) computes mel-spectrograms from audio files.
     """
-    def __init__(self, audiopaths_and_text, hparams, shuffle=True):
-        self.audiopaths_and_text = load_filepaths_and_text(
-            audiopaths_and_text, hparams.sort_by_length)
+    def __init__(self, audiopaths_and_text, hparams):
+        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
         self.text_cleaners = hparams.text_cleaners
         self.max_wav_value = hparams.max_wav_value
         self.sampling_rate = hparams.sampling_rate
@@ -26,8 +25,7 @@ def __init__(self, audiopaths_and_text, hparams, shuffle=True):
             hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
             hparams.mel_fmax)
         random.seed(1234)
-        if shuffle:
-            random.shuffle(self.audiopaths_and_text)
+        random.shuffle(self.audiopaths_and_text)
 
     def get_mel_text_pair(self, audiopath_and_text):
         # separate filename and text
@@ -38,7 +36,10 @@ def get_mel_text_pair(self, audiopath_and_text):
 
     def get_mel(self, filename):
         if not self.load_mel_from_disk:
-            audio = load_wav_to_torch(filename, self.sampling_rate)
+            audio, sampling_rate = load_wav_to_torch(filename)
+            if sampling_rate != self.stft.sampling_rate:
+                raise ValueError("{} SR doesn't match target {} SR".format(
+                    sampling_rate, self.stft.sampling_rate))
             audio_norm = audio / self.max_wav_value
             audio_norm = audio_norm.unsqueeze(0)
             audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
@@ -87,9 +88,9 @@ def __call__(self, batch):
             text = batch[ids_sorted_decreasing[i]][0]
             text_padded[i, :text.size(0)] = text
 
-        # Right zero-pad mel-spec with extra single zero vector to mark the end
+        # Right zero-pad mel-spec
         num_mels = batch[0][1].size(0)
-        max_target_len = max([x[1].size(1) for x in batch]) + 1
+        max_target_len = max([x[1].size(1) for x in batch])
         if max_target_len % self.n_frames_per_step != 0:
             max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
             assert max_target_len % self.n_frames_per_step == 0
@@ -103,7 +104,7 @@ for i in range(len(ids_sorted_decreasing)):
         for i in range(len(ids_sorted_decreasing)):
             mel = batch[ids_sorted_decreasing[i]][1]
             mel_padded[i, :, :mel.size(1)] = mel
-            gate_padded[i, mel.size(1):] = 1
+            gate_padded[i, mel.size(1)-1:] = 1
             output_lengths[i] = mel.size(1)
 
         return text_padded, input_lengths, mel_padded, gate_padded, \
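Two behavioral changes hide in this diff: the loader now returns the file's own sampling rate and rejects mismatches against the STFT's target rate, and the collate function no longer appends an extra all-zero frame, instead placing the stop token on the last real mel frame. A toy illustration of the new gate target (standalone, not repo code):

```python
import torch

max_target_len, mel_len = 8, 5             # padded batch length vs. real length
gate_padded = torch.zeros(max_target_len)
gate_padded[mel_len - 1:] = 1              # stop token now fires on the final frame
print(gate_padded)  # tensor([0., 0., 0., 0., 1., 1., 1., 1.])
```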
52 changes: 52 additions & 0 deletions distributed.py
@@ -118,3 +118,55 @@ def train(self, mode=True):
         super(DistributedDataParallel, self).train(mode)
         self.module.train(mode)
 '''
+'''
+Modifies existing model to do gradient allreduce, but doesn't change class
+so you don't need "module"
+'''
+def apply_gradient_allreduce(module):
+    if not hasattr(dist, '_backend'):
+        module.warn_on_half = True
+    else:
+        module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
+
+    for p in module.state_dict().values():
+        if not torch.is_tensor(p):
+            continue
+        dist.broadcast(p, 0)
+
+    def allreduce_params():
+        if module.needs_reduction:
+            module.needs_reduction = False
+            buckets = {}
+            for param in module.parameters():
+                if param.requires_grad and param.grad is not None:
+                    tp = type(param.data)
+                    if tp not in buckets:
+                        buckets[tp] = []
+                    buckets[tp].append(param)
+            if module.warn_on_half:
+                if torch.cuda.HalfTensor in buckets:
+                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
+                          " It is recommended to use the NCCL backend in this case. This currently requires" +
+                          " PyTorch built from top of tree master.")
+                    module.warn_on_half = False
+
+            for tp in buckets:
+                bucket = buckets[tp]
+                grads = [param.grad.data for param in bucket]
+                coalesced = _flatten_dense_tensors(grads)
+                dist.all_reduce(coalesced)
+                coalesced /= dist.get_world_size()
+                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
+                    buf.copy_(synced)
+
+    for param in list(module.parameters()):
+        def allreduce_hook(*unused):
+            param._execution_engine.queue_callback(allreduce_params)
+        if param.requires_grad:
+            param.register_hook(allreduce_hook)
+
+    def set_needs_reduction(self, input, output):
+        self.needs_reduction = True
+
+    module.register_forward_hook(set_needs_reduction)
+    return module
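A minimal sketch of how this helper might be wired into training, mirroring (but not copied from) this repo's train.py; the backend, URL, and rank handling here are assumptions, with rank 0 shown for brevity:

```python
import torch
import torch.distributed as dist

from distributed import apply_gradient_allreduce
from hparams import create_hparams
from model import Tacotron2

n_gpus, rank = torch.cuda.device_count(), 0   # one rank per GPU, launched by multiproc.py

dist.init_process_group(
    backend='nccl', init_method='tcp://localhost:54321',
    world_size=n_gpus, rank=rank, group_name='group_name')

hparams = create_hparams()
model = Tacotron2(hparams).cuda()
model = apply_gradient_allreduce(model)
# loss.backward() now also all-reduces gradients across ranks through the
# registered hooks; no DistributedDataParallel wrapper or .module indirection.
```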
19 changes: 10 additions & 9 deletions hparams.py
@@ -10,7 +10,7 @@ def create_hparams(hparams_string=None, verbose=False):
         # Experiment Parameters #
         ################################
         epochs=500,
-        iters_per_checkpoint=500,
+        iters_per_checkpoint=1000,
         seed=1234,
         dynamic_loss_scaling=True,
         fp16_run=False,
@@ -24,10 +24,9 @@ def create_hparams(hparams_string=None, verbose=False):
         # Data Parameters #
         ################################
         load_mel_from_disk=False,
-        training_files='filelists/ljs_audio_text_train_filelist.txt',
-        validation_files='filelists/ljs_audio_text_val_filelist.txt',
+        training_files='filelists/ljs_audio22khz_text_train_filelist.txt',
+        validation_files='filelists/ljs_audio22khz_text_val_filelist.txt',
         text_cleaners=['english_cleaners'],
-        sort_by_length=False,
 
         ################################
         # Audio Parameters #
@@ -39,7 +38,7 @@ def create_hparams(hparams_string=None, verbose=False):
         win_length=1024,
         n_mel_channels=80,
         mel_fmin=0.0,
-        mel_fmax=None,  # if None, half the sampling rate
+        mel_fmax=8000.0,
 
         ################################
         # Model Parameters #
@@ -57,7 +56,9 @@ def create_hparams(hparams_string=None, verbose=False):
         decoder_rnn_dim=1024,
         prenet_dim=256,
         max_decoder_steps=1000,
-        gate_threshold=0.6,
+        gate_threshold=0.5,
+        p_attention_dropout=0.1,
+        p_decoder_dropout=0.1,
 
         # Attention parameters
         attention_rnn_dim=1024,
@@ -78,9 +79,9 @@ def create_hparams(hparams_string=None, verbose=False):
         use_saved_learning_rate=False,
         learning_rate=1e-3,
         weight_decay=1e-6,
-        grad_clip_thresh=1,
-        batch_size=48,
-        mask_padding=False  # set model's padded outputs to padded values
+        grad_clip_thresh=1.0,
+        batch_size=64,
+        mask_padding=True  # set model's padded outputs to padded values
     )
 
     if hparams_string:
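The `--hparams` flag from the README resolves to the `hparams_string` argument here; overrides are parsed from a comma-separated list. A quick usage sketch:

```python
from hparams import create_hparams

# Equivalent to passing --hparams=distributed_run=True,fp16_run=True to train.py
hparams = create_hparams("distributed_run=True,fp16_run=True")
print(hparams.batch_size, hparams.mask_padding)  # 64 True under the new defaults
```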
117 changes: 60 additions & 57 deletions inference.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions layers.py
@@ -10,7 +10,7 @@ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
         super(LinearNorm, self).__init__()
         self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
 
-        torch.nn.init.xavier_uniform(
+        torch.nn.init.xavier_uniform_(
             self.linear_layer.weight,
             gain=torch.nn.init.calculate_gain(w_init_gain))
 
@@ -31,7 +31,7 @@ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                                     padding=padding, dilation=dilation,
                                     bias=bias)
 
-        torch.nn.init.xavier_uniform(
+        torch.nn.init.xavier_uniform_(
             self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
 
     def forward(self, signal):
@@ -42,7 +42,7 @@ def forward(self, signal):
 class TacotronSTFT(torch.nn.Module):
     def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                  n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
-                 mel_fmax=None):
+                 mel_fmax=8000.0):
         super(TacotronSTFT, self).__init__()
         self.n_mel_channels = n_mel_channels
         self.sampling_rate = sampling_rate
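The rename from `xavier_uniform` to `xavier_uniform_` follows PyTorch's 0.4-era deprecation of the out-of-place initializer names; the trailing underscore marks an in-place operation. A standalone example:

```python
import torch

w = torch.empty(256, 512)
# Deprecated spelling was: torch.nn.init.xavier_uniform(w, ...)
torch.nn.init.xavier_uniform_(
    w, gain=torch.nn.init.calculate_gain('tanh'))  # fills w in place
```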