A DiT-style diffusion transformer for text-to-speech, conditioned on a frozen T5Gemma2 text encoder and on learned speaker embeddings derived from DAC-VAE latents.
Two training backends share the same model architecture:
- TorchTitan — DDP / FSDP2 with optional torch.compile
- Megatron — tensor parallelism (TP) via Megatron-Core + TransformerEngine
```
model/
  config.py        Model config dataclass + 800M/3B/8B scale presets
  components.py    Shared layers (RMSNorm, QKNorm, RoPE, AdaLN, MLP)
  dit.py           Standard PyTorch model (DDP / FSDP2 / compile)
  dit_megatron.py  Megatron variant (tensor parallelism + TE kernels)
data/
  dataset.py       LocalTarDataset — streams from tar archives
  tokenizer.py     Extended tokenizer with conditioning tokens
encoder/
  frozen.py        FrozenTextEncoder wrapper (T5 / T5Gemma2)
train/
  torchtitan.py    DDP/FSDP2 training entry point
  megatron.py      Megatron TP+DDP training entry point
  utils.py         MFU calculation, distributed helpers
configs/           YAML configs per scale and framework
scripts/           SLURM launcher
```
```
git clone git@github.com:GLJS/scaled_echo_tts.git
cd scaled_echo_tts
pip install -e .

# For Megatron support:
pip install -e ".[megatron]"
```

Train at a given scale (single node, 4 GPUs):

```
# TorchTitan, 800M
torchrun --nproc_per_node=4 -m train.torchtitan --config configs/tt_800M.yaml

# TorchTitan, 3B
torchrun --nproc_per_node=4 -m train.torchtitan --config configs/tt_3B.yaml

# Megatron, 3B
torchrun --nproc_per_node=4 -m train.megatron --config configs/mg_3B.yaml

# Megatron, 8B
torchrun --nproc_per_node=4 -m train.megatron --config configs/mg_8B.yaml
```

The launcher script auto-detects the framework from the YAML config:
```
# 1 node (4 GPUs)
sbatch --nodes=1 scripts/run.sh configs/tt_3B.yaml

# 4 nodes (16 GPUs), 2 hour wall time
sbatch --nodes=4 --time=02:00:00 scripts/run.sh configs/mg_8B.yaml

# 16 nodes (64 GPUs)
sbatch --nodes=16 --time=02:00:00 scripts/run.sh configs/tt_800M.yaml
```

Without SLURM, set the distributed environment variables manually:
```
# On each node:
export MASTER_ADDR=<first-node-hostname>
export MASTER_PORT=29500
export NCCL_IB_HCA=mlx5        # InfiniBand adapter
export NCCL_SOCKET_IFNAME=ib0  # IB network interface

# Node 0:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
  --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
  -m train.torchtitan --config configs/tt_3B.yaml

# Node 1:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
  --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT \
  -m train.torchtitan --config configs/tt_3B.yaml
```

All settings live in a single YAML file. Copy and modify:
```
cp configs/tt_3B.yaml configs/my_experiment.yaml
```

| Field | Values | Description |
|---|---|---|
| `framework` | torchtitan / megatron | Training backend |
| `scale` | 800M / 3B / 8B | Model scale (sets architecture dims) |
| `parallelism` | ddp / fsdp2 | Data parallelism strategy (TorchTitan only) |
| `compile` | false / default / max-autotune / max-autotune-no-cudagraphs | torch.compile mode (TorchTitan only) |
| `tensor_parallel_size` | 1 / 2 / 4 | TP degree (Megatron only) |
| `batch_size_per_gpu` | int | Per-GPU micro batch size |
| `remat` | true / false | Activation checkpointing |
| `frozen_encoder` | true / false | Freeze T5Gemma2 text encoder |
| `encoder_model` | HF model name | e.g. google/t5gemma-2-1b-1b |
| `use_te_attn` | true / false | TransformerEngine fused attention (Megatron) |
| `fused_optimizer` | true / false | PyTorch fused AdamW |
| `duration_seconds` | int | Benchmark measurement duration |
| `max_latent_length` | int | Max audio latent frames (default 768) |
| `max_text_length` | int | Max text token length (default 192) |
| `max_speaker_length` | int | Max speaker latent frames (default 384) |
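Putting several of these fields together, a hypothetical experiment config might look like the following (values are illustrative, not recommendations):

```yaml
# configs/my_experiment.yaml — illustrative values only
framework: torchtitan
scale: 3B
parallelism: fsdp2
compile: max-autotune-no-cudagraphs
batch_size_per_gpu: 4
remat: true
frozen_encoder: true
encoder_model: google/t5gemma-2-1b-1b
fused_optimizer: true
max_latent_length: 768
max_text_length: 192
max_speaker_length: 384
```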
Training data is stored as tar archives containing triplets:

```
sample_001.npy      # Audio latent (T, 128) float16 — DAC-VAE encoded
sample_001.ref.npy  # Speaker reference latent (T, 128) float16
sample_001.json     # {"text": "...", "target_duration": 5.2, ...}
```

Set the data directory in the training script or config. `LocalTarDataset` streams from the tars with threaded prefetching.
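`LocalTarDataset`'s internals aren't shown here, but reading one triplet from a shard with this layout can be sketched as follows (a minimal illustration; the function name and grouping logic are assumptions, not the repo's code):

```python
import io
import json
import tarfile

import numpy as np

def read_triplets(tar_path):
    """Yield (latent, ref_latent, meta) triplets from one tar shard."""
    samples = {}
    with tarfile.open(tar_path) as tar:
        for member in tar:
            # Group members by sample key: "sample_001.ref.npy" -> "sample_001"
            key, _, ext = member.name.partition(".")
            samples.setdefault(key, {})[ext] = tar.extractfile(member).read()
    for parts in samples.values():
        latent = np.load(io.BytesIO(parts["npy"]))       # audio latent (T, 128)
        ref = np.load(io.BytesIO(parts["ref.npy"]))      # speaker reference latent
        meta = json.loads(parts["json"])                 # text + duration metadata
        yield latent, ref, meta
```

The real dataset streams and prefetches in background threads rather than materializing the whole shard dict, but the grouping-by-stem idea is the same.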
```
Text ───→ [Frozen T5Gemma2] ──→ Dense Projection ──→ text_embeddings ──┐
                                                                       │
Audio ──→ [DAC-VAE latent] ──→ noise + timestep t ──┐                  │
                                                    │                  │
Speaker ─→ [Patch + 8L Transformer] ─→ speaker_emb ─┤                  │
                                                    │                  │
                                          ┌─────────┴──────────────────┘
                                          ▼
                            DiT Decoder (24–32 layers)
                              - AdaLN timestep conditioning
                              - Joint self + cross attention
                                (text stream + speaker stream)
                              - QKNorm + RoPE
                              - Gated SiLU MLP
                                          │
                                          ▼
                            noise prediction (T, 128)
```
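The AdaLN timestep conditioning in the decoder can be sketched as the usual shift/scale/gate pattern. This is a minimal NumPy illustration, not the repo's implementation; the helper names and the unlearned RMSNorm are simplifications:

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # RMSNorm without a learned gain, for brevity
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)

def adaln_block(x, t_emb, w_mod, block_fn):
    """One AdaLN-modulated residual block.

    x:      (T, d) hidden states
    t_emb:  (d,)   timestep embedding
    w_mod:  (d, 3*d) projection producing shift, scale, gate
    """
    shift, scale, gate = np.split(t_emb @ w_mod, 3)
    h = rms_norm(x) * (1 + scale) + shift   # modulate the normalized input
    return x + gate * block_fn(h)           # gated residual update

# Tiny usage example with an identity "block".
d = 8
x = np.random.randn(16, d)
t_emb = np.random.randn(d)
w_mod = np.zeros((d, 3 * d))                # zero-init => gate is 0, block is a no-op
out = adaln_block(x, t_emb, w_mod, lambda h: h)
```

Zero-initializing the modulation projection (so each block starts as the identity) is the standard AdaLN-Zero trick from the DiT paper.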
Scales:
| Scale | Decoder | Speaker Enc | Text Encoder | Total Params |
|---|---|---|---|---|
| 800M | 1280d, 24L, 10H | 768d, 8L, 6H | T5Gemma2-270m | ~780M |
| 3B | 2560d, 32L, 20H | 1024d, 8L, 8H | T5Gemma2-1B | ~3.5B |
| 8B | 4096d, 32L, 32H | 1024d, 8L, 8H | T5Gemma2-4B | ~8.2B |
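As a sanity check, the table's totals can be roughly reproduced with the standard 12·L·d² parameter estimate for a transformer stack (this ignores embeddings, norms, and projection heads; the encoder size is read off the model name):

```python
def approx_params(depth, dim):
    # ~4*d^2 for attention (q, k, v, out) + ~8*d^2 for a gated MLP
    return 12 * depth * dim**2

# 800M preset: 1280d/24L decoder + 768d/8L speaker encoder + T5Gemma2-270m
decoder = approx_params(24, 1280)   # ~472M
speaker = approx_params(8, 768)     # ~57M
text_encoder = 270e6                # frozen T5Gemma2-270m
total = decoder + speaker + text_encoder
print(f"{total / 1e9:.2f}B")        # ~0.80B, in line with ~780M in the table
```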
Change `encoder_model` in the YAML config:

```
encoder_model: google/t5gemma-2-4b-4b   # larger encoder
# or
encoder_model: google-t5/t5-base        # standard T5
```

Update `t5_hidden_dim` to match the encoder's output dimension.

To fine-tune the text encoder instead of keeping it frozen:

```
frozen_encoder: false
```

With this setting, gradients flow through the encoder and its parameters are included in the optimizer.
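Conceptually, the optimizer only ever sees parameters whose gradients are enabled. A minimal sketch of that selection logic, with a stand-in `Param` class instead of framework tensors (the repo's actual wiring lives in train/):

```python
from dataclasses import dataclass

@dataclass
class Param:
    name: str
    requires_grad: bool = True

def trainable(params, frozen_encoder):
    # With frozen_encoder=true, encoder weights are excluded from the optimizer
    for p in params:
        if frozen_encoder and p.name.startswith("encoder."):
            p.requires_grad = False
    return [p for p in params if p.requires_grad]

params = [Param("encoder.layer0.w"), Param("decoder.layer0.w")]
print([p.name for p in trainable(params, frozen_encoder=True)])
# ['decoder.layer0.w']
```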
```
# DDP (replicate model, shard data)
parallelism: ddp

# FSDP2 (shard model + optimizer across GPUs)
parallelism: fsdp2

# Megatron TP=2 (split attention heads across 2 GPUs, DDP across the rest)
framework: megatron
tensor_parallel_size: 2
```

Benchmark results on NVIDIA GH200 (990 TFLOPS bf16 peak):
| Config | 1N (4 GPU) | 4N (16 GPU) | 16N (64 GPU) |
|---|---|---|---|
| TorchTitan 800M DDP + max-autotune-no-cg | 35.6% MFU | 35.8% MFU | 35.7% MFU |
| Megatron 3B TP=2 | 30.0% MFU | 29.4% MFU | 29.1% MFU |
| Megatron 8B TP=4 | 37.5% MFU | 36.7% MFU | 36.0% MFU |
Full results dashboard: https://share.gijs.me/model_flops_utilization.html
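The MFU numbers follow the usual definition of achieved FLOP/s divided by the hardware's peak FLOP/s, which `train/utils.py` computes. A minimal sketch using the common 6·N FLOPs-per-token rule for forward plus backward (the throughput figure below is made up for illustration; the real calculation may account for conditioning streams differently):

```python
def mfu(params, tokens_per_sec_per_gpu, peak_flops=990e12):
    """Model FLOPs utilization: achieved bf16 FLOP/s over hardware peak."""
    achieved = 6 * params * tokens_per_sec_per_gpu  # ~6N FLOPs/token (fwd + bwd)
    return achieved / peak_flops

# Hypothetical: a 3.5B-param model at 14k tokens/s/GPU on GH200 (990 TFLOPS peak)
print(f"{mfu(3.5e9, 14_000):.1%}")  # 29.7%
```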