Skip to content

Contrib/ltx2 video audio#57

Open
jimburtoft wants to merge 3 commits intoaws-neuron:mainfrom
jimburtoft:contrib/ltx2-video-audio
Open

Contrib/ltx2 video audio#57
jimburtoft wants to merge 3 commits intoaws-neuron:mainfrom
jimburtoft:contrib/ltx2-video-audio

Conversation

@jimburtoft
Copy link

Description

First diffusion model contribution to the NxDI contrib collection. Ports the Lightricks LTX-2 (https://huggingface.co/Lightricks/LTX-2) 19B-parameter audio-video diffusion model to AWS Trainium using NxDI's SPMD infrastructure with TP=4 sharding for both the DiT transformer backbone (48 blocks, ~6B params) and the Gemma 3-12B text encoder.
LTX-2 generates synchronized video + audio from text prompts. This contribution compiles the two compute-heavy components (DiT and Gemma3) for Neuron while leaving the VAE decoders and vocoder on CPU.

Model Information

Model Name: LTX-2 (Lightricks/LTX-2)
Model Architecture: DiT (Diffusion Transformer) with dual video+audio streams, 48 joint transformer blocks, Gemma 3-12B text encoder
Purpose: Text-to-video+audio generation

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)
    • Smoke test: pipeline loads without errors
    • Generation test: produces expected number of frames at correct resolution
    • Accuracy test: SSIM comparison between Neuron output and GPU reference frames (same seed/settings) with threshold > 0.7
    • Performance test: warm generation completes within 120s
    • Note: This is a diffusion model — accuracy is validated via structural similarity (SSIM) between Neuron and GPU output frames, rather than logit/token matching
  • README.md with the following sections:
    • Usage Example: Quick Start with compile + generate instructions, full E2E script, and interactive notebook
    • Compatibility Matrix: Tested with SDK 2.27 and 2.28 on trn2.3xlarge
    • Example Checkpoints: Lightricks/LTX-2 (https://huggingface.co/Lightricks/LTX-2) (downloaded automatically)
    • Testing Instructions: pytest and standalone commands documented
  • Source Code (src/)
    • modeling_ltx2.py — DiT backbone with TP sharding, SPMDRank RoPE, DistributedRMSNorm, BMM-based SDPA
    • modeling_gemma3_encoder.py — Custom Gemma3 encoder outputting all 49 hidden states
    • pipeline.py — NeuronTransformerWrapper with CFG batch-splitting
    • application.py — Orchestrator
    • compile_gemma3.py / shard_gemma3_weights.py — Compilation and weight pre-sharding scripts

Optional Components

  • Unit Tests — Not included (validated via integration test and executed notebook)
    Folder Structure
    /contrib/models/ltx2-video-audio/
    README.md
    /src
    init.py
    modeling_ltx2.py
    modeling_gemma3_encoder.py
    pipeline.py
    application.py
    compile_gemma3.py
    shard_gemma3_weights.py
    generate_ltx2.py
    /test
    /integration
    test_model.py
    /notebooks
    ltx2_neuron_inference.ipynb
    ltx2_neuron_inference_executed.ipynb
    /examples
    neuron_e2e.py
    gpu_generate.py
    /samples
    /neuron
    /gpu

Testing

How did you test this change?
Full end-to-end pipeline executed on trn2.3xlarge instances in sa-east-1 with both SDK 2.27 (DLAMI 20260126) and SDK 2.28 (DLAMI 20260227). The executed notebook (ltx2_neuron_inference_executed.ipynb) demonstrates compilation, model loading, and two successful video+audio generations with embedded output frames.
GPU reference frames were generated on g5.12xlarge (us-east-2) with identical settings (seed=42, guidance_scale=4.0, max_sequence_length=1024, 8 steps) for SSIM comparison.

Test Results:

Test Result
Pipeline loads PASS
Generates 25 frames at 512x384 PASS
SSIM vs GPU reference > 0.7 PASS (user-validated as "nearly identical")
Warm generation < 120s PASS (~22s on SDK 2.28)
Generation SDK 2.28
------------ ----------
First (with warmup) ~64s
Warm ~22s

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.27, 2.28
  • Instance Type(s): trn2.3xlarge
  • PyTorch Version: 2.9.0
  • Python Version: 3.12

Additional Information

Key implementation details:

  • SPMDRank RoPE: Uses NxD SPMDRank module instead of Python int for per-rank RoPE slicing (Python int gets baked as constant 0 during SPMD XLA tracing)
  • CFG batch-split: Compiled for batch_size=1; wrapper splits CFG batch_size=2 into two sequential calls
  • Pre-sharded Gemma3 weights: Saved per-rank (~5.5 GB each) with .contiguous().clone() to avoid serializing full unsharded storage
  • 22-input compiled backbone: DiT forward pass takes 22 positional tensor arguments, all preprocessed on CPU
  • Dual Neuron models: Both Gemma3 (12B) and DiT (~6B) coexist on 4 NeuronCores, executing sequentially
    Known limitations:
  • Fixed resolution (512x384, 25 frames) — changing resolution requires recompilation
  • trn2.3xlarge only (needs 4 NeuronCores with LNC=2)
  • Requires diffusers dev version (LTX-2 not yet in a stable release)
    Related Issues
    N/A — first diffusion model contribution to NxDI.

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions
    Not applicable — this is a diffusion model, not a language model.

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines (../contrib/CONTRIBUTING.md)
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

First diffusion model in the NxDI contrib collection. Ports the Lightricks
LTX-2 19B-parameter audio-video model to Trainium with TP=4 sharding for
both the DiT transformer backbone (48 blocks) and Gemma 3-12B text encoder.

Includes compiled notebook with outputs showing ~62s generation on
trn2.3xlarge (~5x cheaper than GPU equivalent).
…, remove cost mentions

- Tested on Neuron SDK 2.28 (DLAMI 20260227, neuronx-cc 2.23, torch-neuronx 2.9.0.2.12)
- Generation times: ~64s first run (with warmup), ~22s warm
- Removed hardware cost/pricing references from README and notebook
- Fixed hardcoded paths in compile_gemma3.py and shard_gemma3_weights.py
- Notebook now clearly labels first generation (warmup) vs warm generation
- Add test/integration/test_model.py with smoke, generation, SSIM accuracy,
  and warm performance tests (uses GPU reference frames for comparison)
- Add Compatibility Matrix section (SDK 2.27/2.28 on trn2.3xlarge)
- Add Example Checkpoints section (Lightricks/LTX-2 HuggingFace link)
- Add Testing section with pytest and standalone run instructions
- Update file structure to include test/ directory
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant