LAION-AI/scaled-echo-tts

Echo TTS — Scaled Training

A DiT-based diffusion transformer for text-to-speech, conditioned on a frozen T5Gemma2 text encoder and on learned speaker embeddings computed from DAC-VAE latents.

Two training backends share the same model architecture:

  • TorchTitan — DDP / FSDP2 with optional torch.compile
  • Megatron — Tensor parallelism (TP) via Megatron-Core + TransformerEngine

Repository structure

model/
  config.py          Model config dataclass + 800M/3B/8B scale presets
  components.py      Shared layers (RMSNorm, QKNorm, RoPE, AdaLN, MLP)
  dit.py             Standard PyTorch model (DDP / FSDP2 / compile)
  dit_megatron.py    Megatron variant (tensor parallelism + TE kernels)

data/
  dataset.py         LocalTarDataset — streams from tar archives
  tokenizer.py       Extended tokenizer with conditioning tokens

encoder/
  frozen.py          FrozenTextEncoder wrapper (T5 / T5Gemma2)

train/
  torchtitan.py      DDP/FSDP2 training entry point
  megatron.py        Megatron TP+DDP training entry point
  utils.py           MFU calculation, distributed helpers

configs/             YAML configs per scale and framework
scripts/             SLURM launcher

Setup

git clone git@github.com:GLJS/scaled_echo_tts.git
cd scaled_echo_tts
pip install -e .

# For Megatron support:
pip install -e ".[megatron]"

Quick start (single machine, no SLURM)

TorchTitan — DDP on 4 GPUs

torchrun --nproc_per_node=4 \
    -m train.torchtitan --config configs/tt_800M.yaml

TorchTitan — FSDP2 + compile on 4 GPUs

torchrun --nproc_per_node=4 \
    -m train.torchtitan --config configs/tt_3B.yaml

Megatron — TP=2 on 4 GPUs

torchrun --nproc_per_node=4 \
    -m train.megatron --config configs/mg_3B.yaml

Megatron — TP=4 on 4 GPUs

torchrun --nproc_per_node=4 \
    -m train.megatron --config configs/mg_8B.yaml

Multi-node with SLURM

The launcher script auto-detects the framework from the YAML config:

# 1 node (4 GPUs)
sbatch --nodes=1 scripts/run.sh configs/tt_3B.yaml

# 4 nodes (16 GPUs), 2 hour wall time
sbatch --nodes=4 --time=02:00:00 scripts/run.sh configs/mg_8B.yaml

# 16 nodes (64 GPUs)
sbatch --nodes=16 --time=02:00:00 scripts/run.sh configs/tt_800M.yaml

Multi-node without SLURM

Set the distributed environment variables manually:

# On each node:
export MASTER_ADDR=<first-node-hostname>
export MASTER_PORT=29500
export NCCL_IB_HCA=mlx5            # InfiniBand adapter
export NCCL_SOCKET_IFNAME=ib0      # IB network interface

# Node 0:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
    -m train.torchtitan --config configs/tt_3B.yaml

# Node 1:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
    --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
    -m train.torchtitan --config configs/tt_3B.yaml

Configuration

All settings live in a single YAML file. Copy and modify:

cp configs/tt_3B.yaml configs/my_experiment.yaml

Key config fields

| Field | Values | Description |
|---|---|---|
| framework | torchtitan / megatron | Training backend |
| scale | 800M / 3B / 8B | Model scale (sets architecture dims) |
| parallelism | ddp / fsdp2 | Data parallelism strategy (TorchTitan only) |
| compile | false / default / max-autotune / max-autotune-no-cudagraphs | torch.compile mode (TorchTitan only) |
| tensor_parallel_size | 1 / 2 / 4 | TP degree (Megatron only) |
| batch_size_per_gpu | int | Per-GPU micro batch size |
| remat | true / false | Activation checkpointing |
| frozen_encoder | true / false | Freeze T5Gemma2 text encoder |
| encoder_model | HF model name | e.g. google/t5gemma-2-1b-1b |
| use_te_attn | true / false | TransformerEngine fused attention (Megatron) |
| fused_optimizer | true / false | PyTorch fused AdamW |
| duration_seconds | int | Benchmark measurement duration |
| max_latent_length | int | Max audio latent frames (default 768) |
| max_text_length | int | Max text token length (default 192) |
| max_speaker_length | int | Max speaker latent frames (default 384) |
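As a minimal sketch of how such a config might be loaded and sanity-checked (assuming PyYAML and the field names above; `load_config` is a hypothetical helper, not the repository's actual loader):

```python
import yaml  # PyYAML


def load_config(path: str) -> dict:
    """Load a training config and check a few invariants from the field table."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    assert cfg["framework"] in ("torchtitan", "megatron")
    assert cfg["scale"] in ("800M", "3B", "8B")
    if cfg["framework"] == "megatron":
        assert cfg.get("tensor_parallel_size", 1) in (1, 2, 4)
    else:
        assert cfg.get("parallelism", "ddp") in ("ddp", "fsdp2")
    return cfg
```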

Data format

Training data is stored as tar archives containing triplets:

sample_001.npy       # Audio latent (T, 128) float16 — DAC-VAE encoded
sample_001.ref.npy   # Speaker reference latent (T, 128) float16
sample_001.json      # {"text": "...", "target_duration": 5.2, ...}

Set the data directory in the training script or config. LocalTarDataset streams samples from the tar archives with threaded prefetching.
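The triplet layout above can be read back with nothing more than tarfile and numpy. This is an illustrative sketch only (`read_triplets` is a hypothetical helper, not the LocalTarDataset API, and it omits the threaded prefetching the real dataset uses):

```python
import io
import json
import tarfile

import numpy as np


def read_triplets(tar_path: str):
    """Yield (key, audio_latent, speaker_ref, metadata) from one tar shard."""
    samples = {}
    with tarfile.open(tar_path) as tar:
        for member in tar:
            if not member.isfile():
                continue
            # "sample_001.ref.npy" -> key "sample_001", part "ref.npy"
            key, _, part = member.name.partition(".")
            samples.setdefault(key, {})[part] = tar.extractfile(member).read()
    for key, parts in samples.items():
        latent = np.load(io.BytesIO(parts["npy"]))   # (T, 128) float16
        ref = np.load(io.BytesIO(parts["ref.npy"]))  # (T, 128) float16
        meta = json.loads(parts["json"])             # {"text": ..., ...}
        yield key, latent, ref, meta
```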

Model architecture

Text ──→ [Frozen T5Gemma2] ──→ Dense Projection ──→ text_embeddings
                                                          │
Audio ──→ [DAC-VAE latent] ──→ noise + timestep t         │
                │                                          │
Speaker ──→ [Patch + 8L Transformer] ──→ speaker_emb       │
                                              │            │
                                    ┌─────────┴────────────┘
                                    ▼
                            DiT Decoder (24–32 layers)
                            - AdaLN timestep conditioning
                            - Joint self + cross attention
                              (text stream + speaker stream)
                            - QKNorm + RoPE
                            - Gated SiLU MLP
                                    │
                                    ▼
                            noise prediction (T, 128)
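The AdaLN timestep conditioning in the decoder can be sketched as follows. This is a hedged illustration of the general technique, not the repository's components.py (names like `AdaLNBlock` are invented here): the conditioning vector is projected to per-layer shift, scale, and gate terms that modulate a normalized hidden state.

```python
import torch
import torch.nn as nn


class AdaLNBlock(nn.Module):
    """Adaptive LayerNorm: a conditioning vector c (e.g. a timestep
    embedding) produces shift/scale/gate terms for the sublayer."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.mlp = nn.Linear(dim, dim)      # stand-in for attention/MLP sublayer
        self.ada = nn.Linear(dim, 3 * dim)  # -> shift, scale, gate
        nn.init.zeros_(self.ada.weight)     # start as identity modulation
        nn.init.zeros_(self.ada.bias)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale, gate = self.ada(c).chunk(3, dim=-1)
        h = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return x + gate.unsqueeze(1) * self.mlp(h)
```

Zero-initializing the modulation projection makes each block start as the identity, a common choice for stabilizing DiT training.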

Scales:

| Scale | Decoder | Speaker Enc | Text Encoder | Total Params |
|---|---|---|---|---|
| 800M | 1280d, 24L, 10H | 768d, 8L, 6H | T5Gemma2-270m | ~780M |
| 3B | 2560d, 32L, 20H | 1024d, 8L, 8H | T5Gemma2-1B | ~3.5B |
| 8B | 4096d, 32L, 32H | 1024d, 8L, 8H | T5Gemma2-4B | ~8.2B |

Swapping components

Different text encoder

Change encoder_model in the YAML config:

encoder_model: google/t5gemma-2-4b-4b   # larger encoder
# or
encoder_model: google-t5/t5-base         # standard T5

Update t5_hidden_dim to match the encoder's output dimension.

Unfrozen encoder

frozen_encoder: false

With frozen_encoder set to false, gradients flow through the encoder and its parameters are included in the optimizer.
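In plain PyTorch terms, freezing amounts to the following sketch (assumed behavior; the actual wiring lives in encoder/frozen.py and the train scripts, and `build_param_groups` is a hypothetical helper):

```python
import torch.nn as nn


def build_param_groups(encoder: nn.Module, decoder: nn.Module,
                       frozen_encoder: bool) -> list:
    """Freeze (or unfreeze) the text encoder and collect optimizer params."""
    encoder.requires_grad_(not frozen_encoder)
    params = list(decoder.parameters())
    if not frozen_encoder:
        params += list(encoder.parameters())
    return params
```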

Different parallelism

# DDP (replicate model, shard data)
parallelism: ddp

# FSDP2 (shard model + optimizer across GPUs)
parallelism: fsdp2

# Megatron TP=2 (split attention heads across 2 GPUs, DDP across rest)
framework: megatron
tensor_parallel_size: 2

Results

Benchmark results on NVIDIA GH200 (990 TFLOPS bf16 peak):

| Config | 1N (4 GPU) | 4N (16 GPU) | 16N (64 GPU) |
|---|---|---|---|
| TorchTitan 800M DDP + max-autotune-no-cg | 35.6% MFU | 35.8% MFU | 35.7% MFU |
| Megatron 3B TP=2 | 30.0% MFU | 29.4% MFU | 29.1% MFU |
| Megatron 8B TP=4 | 37.5% MFU | 36.7% MFU | 36.0% MFU |

Full results dashboard: https://share.gijs.me/model_flops_utilization.html
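MFU is achieved model FLOPs divided by the hardware peak. A common approximation, sketched below under the usual ~6·N FLOPs-per-parameter-per-token rule for dense transformers (the repository's exact accounting lives in train/utils.py and may differ), looks like:

```python
def mfu(n_params: float, tokens_per_second: float,
        n_gpus: int, peak_tflops: float = 990.0) -> float:
    """Approximate model FLOPs utilization.

    6 * n_params FLOPs per token covers forward + backward for a dense
    transformer (ignoring attention-score FLOPs). The default peak of
    990 TFLOPS is the GH200 bf16 figure quoted above.
    """
    achieved = 6.0 * n_params * tokens_per_second
    peak = n_gpus * peak_tflops * 1e12
    return achieved / peak
```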

About

Scaled diffusion transformer for text-to-speech synthesis (DiT + T5Gemma2 conditioning, TorchTitan & Megatron backends, tested up to 1024 GPUs)
