Boltz-DAP: Distributed Axial Parallelism for Boltz-2

Run Boltz-2 protein structure prediction on large complexes (>2,000 amino acid residues) across multiple GPUs without OOM.

DAP (Distributed Axial Parallelism) shards the pair representation z [B, N, N, D] across multiple GPUs along the row dimension, so no single GPU ever holds the full N×N tensor. This reduces peak memory roughly in proportion to the number of GPUs — 4 GPUs → ~4× less memory per GPU.

Why?

The original Boltz-2 holds the full pair representation tensor on a single GPU. For large complexes (>2,000 residues), this leads to CUDA out-of-memory (OOM) errors on consumer-grade GPUs (VRAM < 48 GB). DAP enables Boltz-2 to run on multiple GPUs without OOM, even for large complexes like adeno-associated virus (AAV) hexamers.

| Complex | N (tokens) | Original Boltz-2 | DAP (4 × RTX 5880 Ada, 48 GB VRAM) |
|---|---|---|---|
| AAV2 VP3 Trimer (3 × 519 aa) | ~1,557 | ⚠️ Tight | ✅ ~12 GB/GPU |
| AAV2 VP3 Pentamer (5 × 519 aa) | ~2,595 | ❌ OOM | ✅ ~36 GB/GPU |
| AAV2 VP3 Hexamer (6 × 519 aa) | ~3,114 | ❌ OOM | ✅ ~45 GB/GPU |

How It Works

┌─────────────────────────────────────────────────────┐
│  ALL GPUs: Input embedding → z_init [B, N, N, 128]  │
└──────────────────────┬──────────────────────────────┘
                       ▼
              scatter(z, dim=1)
         ┌─────────────┼──────────────┐
         ▼             ▼              ▼
   GPU 0: z₀       GPU 1: z₁     GPU 2: z₂    ...
   [B,N/P,N,D]    [B,N/P,N,D]   [B,N/P,N,D]
         │             │              │
         ▼             ▼              ▼
   ┌──────────────────────────────────────────┐
   │  Trunk Loop (48 Pairformer layers):      │
   │    • TriMulOut  (broadcast-chunked)      │
   │    • TriMulIn   (row↔col + broadcast)    │
   │    • TriAttStart (gather only H-bias)    │
   │    • TriAttEnd   (row↔col + attention)   │
   │    • Transition  (pointwise, no comm)    │
   │    • SeqAttn     (gather only pair bias) │
   └──────────────────────────────────────────┘
         │             │              │
         ▼             ▼              ▼
              gather(z, dim=1)
                       ▼
        z_full [B, N, N, 128]  (GPU 0 only)
                       ▼
         Distogram → Diffusion → Confidence

The full z is only materialized at scatter/gather boundaries. The entire trunk loop operates on smaller shards.
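The scatter/gather pattern above can be sketched in plain Python (shapes only — no GPUs or torch.distributed; function names here are illustrative, not the actual comm.py API):

```python
# Illustrative sketch of DAP row-sharding. z is modeled as a list of
# N rows; the real code shards a [B, N, N, D] tensor along dim=1.

def scatter_rows(z, world_size):
    """Split z along the row dimension into world_size shards."""
    n = len(z)
    shard = n // world_size  # assumes n is divisible by world_size
    return [z[r * shard:(r + 1) * shard] for r in range(world_size)]

def gather_rows(shards):
    """Concatenate row shards back into the full pair representation."""
    return [row for shard in shards for row in shard]

n, world_size = 8, 4
z = [[(i, j) for j in range(n)] for i in range(n)]  # full N x N pair rep
shards = scatter_rows(z, world_size)                # each "rank" holds N/P x N rows
assert len(shards[0]) == n // world_size
assert gather_rows(shards) == z                     # round-trip is lossless
```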

Quick Start

First time here? See docs/GETTING_STARTED.md for a step-by-step guide: install Boltz-2, clone this repo, prepare input YAML, and run DAP.

Prerequisites

  • 2+ GPUs on the same node (NVLink recommended)
  • Python 3.10+, PyTorch 2.x with CUDA
  • Boltz-2 installed (pip install boltz)

Tested environment

| Item | Used in development |
|---|---|
| GPU | NVIDIA RTX 5880 Ada (48 GB VRAM); NVIDIA H800 (80 GB VRAM) |
| CUDA | Compatible with PyTorch 2.x |
| GPU counts tested | 2, 4, 8 GPUs |
| Settings tested | Boltz-2 default: recycling_steps=3, sampling_steps=200, diffusion_samples=1; AF3 default: recycling_steps=10, sampling_steps=200, diffusion_samples=25 |
| Workloads | AAV2 VP3 Trimer/Pentamer (3×519 aa, 5×519 aa; 4 GPUs); AAV2 VP3 Hexamer (6×519 aa, 25 samples with --use_flex_attention_chunked; 4 GPUs); 9MME (4,642 tokens; 8 GPUs) |

Other GPU models (A100, V100, etc.) should work with 2+ GPUs; memory per GPU scales with shard size.
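A back-of-envelope estimate of how the pair representation alone scales (illustrative only — this counts a single fp32 copy of z; the real peak is several times higher once per-layer intermediates and attention biases are included):

```python
def pair_rep_gib(n_tokens, d=128, bytes_per_elem=4, n_gpus=1):
    """Approximate size of one [N, N, D] fp32 pair-representation
    tensor, divided by the number of row shards (GPUs)."""
    return n_tokens * n_tokens * d * bytes_per_elem / n_gpus / 2**30

# Hexamer-scale complex (~3,114 tokens): one full copy vs 4-way row shards
full = pair_rep_gib(3114)
per_gpu = pair_rep_gib(3114, n_gpus=4)
print(f"full: {full:.1f} GiB, per GPU (P=4): {per_gpu:.1f} GiB")
```

Sharding divides this term by the GPU count; the remaining per-GPU footprint comes from model weights and layer intermediates, which is why the observed numbers (~45 GB/GPU for the hexamer) are much larger than the raw z shard.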

Example log file: example_hexamer_25cif_full.log — full run that produced 25 CIF files (AAV2 VP3 Hexamer, 4 GPUs, --use_flex_attention_chunked, AF3 defaults). Large (~8.8 MB) but useful as a reference.

Running

# 4 GPUs
torchrun --nproc_per_node=4 boltz_dap_v2/run_boltz_dap_v2.py \
    input.yaml \
    --out_dir ./output \
    --cache ~/.boltz

# 2 GPUs
torchrun --nproc_per_node=2 boltz_dap_v2/run_boltz_dap_v2.py \
    input.yaml \
    --out_dir ./output \
    --cache ~/.boltz

Options

| Flag | Default | Description |
|---|---|---|
| --out_dir | (required) | Output directory |
| --cache | ~/.boltz | Model weights cache |
| --recycling_steps | 3 | Number of recycling iterations (Boltz-2 default) |
| --sampling_steps | 200 | Diffusion sampling steps |
| --diffusion_samples | 1 | Number of diffusion samples |
| --use_msa_server | off | Use an MSA server (e.g. ColabFold) for MSA generation |
| --no_kernels | off | Disable cuequivariance CUDA kernels (PyTorch-native triangle attention) |
| --use_flex_attention | off | Use FlexAttention for triangle attention (memory/throughput; may need the chunked variant for large N) |
| --use_flex_attention_chunked | off | Chunked FlexAttention for DAP (avoids OOM on the 25-sample hexamer; numerically matches the original) |
| --use_potentials | off | Enable FK steering + physical guidance potentials |
| --seed | None | Random seed for reproducibility |

Confidence (pLDDT, pTM, iPTM, PAE, PDE) is always computed when the model supports it; no flag required.

For a full prediction guide (entrypoint, launch, input data, CLI options, pipeline stages), see docs/boltz2_dap_prediction.md.

SLURM Example

#!/bin/bash
#SBATCH --job-name=boltz-dap
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-task=1
#SBATCH --mem=128G

srun torchrun --nproc_per_node=4 \
    boltz_dap_v2/run_boltz_dap_v2.py \
    input.yaml \
    --out_dir ./output \
    --cache ~/.boltz \
    --recycling_steps 3

Project Structure

boltz_dap/
├── boltz_dap_v2/                    # DAP-aware layer wrappers
│   ├── run_boltz_dap_v2.py          # Entry point (replaces `boltz predict`)
│   ├── dap_trunk.py                 # Main forward: scatter → trunk → gather
│   ├── dap_pairformer.py            # PairformerLayer wrapper (with seq attention)
│   ├── dap_pairformer_noseq.py      # PairformerLayer wrapper (for templates)
│   ├── dap_trimul.py                # Triangle multiplication (broadcast-chunked)
│   ├── dap_tri_att.py               # Triangle attention (gather only bias)
│   ├── dap_msa.py                   # MSA module wrapper
│   └── dap_confidence.py            # Confidence module wrapper
├── boltz_distributed/               # Communication primitives
│   ├── core.py                      # init_dap(), get_dap_rank(), get_dap_size()
│   ├── comm.py                      # scatter, gather, row_to_col, col_to_row
│   └── wrappers.py                  # Helper wrappers
├── docs/                            # Getting started, prediction guide
├── scripts/                         # Auxiliary Python scripts (compare, analyze, test, etc.)
├── slurm/                           # SLURM job scripts (.sbatch, .sh) for HPC runs
└── README.md

Key Design Decisions

Zero Boltz-2 Modifications

DAP does not modify any original Boltz-2 source code. Instead, it monkey-patches the model at runtime:

# dap_trunk.py
inject_dap_into_model(model)  # Wraps each layer with DAP-aware version

The original boltz/ package remains untouched. All weights are identical.
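The wrapping idea, in a minimal dependency-free sketch (the class and attribute names below are hypothetical stand-ins, not the actual Boltz-2 module layout):

```python
# Hypothetical sketch of runtime layer wrapping; not the real Boltz-2 classes.

class PairformerLayer:                      # stand-in for an original layer
    def forward(self, z):
        return [v + 1 for v in z]

class DAPPairformerLayer:
    """Wraps a layer: the real version scatters/gathers around the
    original forward; the original weights are reused untouched."""
    def __init__(self, inner):
        self.inner = inner                  # original layer, unmodified

    def forward(self, z):
        # Real DAP inserts communication here; this sketch runs locally.
        return self.inner.forward(z)

class Model:
    def __init__(self):
        self.layers = [PairformerLayer() for _ in range(3)]

def inject_dap_into_model(model):
    """Replace each layer attribute with a DAP-aware wrapper in place."""
    model.layers = [DAPPairformerLayer(layer) for layer in model.layers]

model = Model()
inject_dap_into_model(model)
assert model.layers[0].forward([1, 2, 3]) == [2, 3, 4]  # same computation
```

Because only attributes on the live model object are swapped, the installed boltz package and its checkpoints stay byte-identical.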

Broadcast-Chunked Triangle Multiplication

This is the hardest operation to distribute. Instead of all-gathering the full tensor (which would defeat the purpose), each GPU broadcasts its shard one at a time:

# Each rank holds row shards a_local, b_local of shape [B, N/P, N, D].
# Ranks take turns broadcasting their b shard; every rank fills the
# matching j-columns of its local output with a partial einsum.
for src in range(dap_size):
    b_chunk = b_local if rank == src else torch.empty_like(b_local)
    dist.broadcast(b_chunk, src=src)            # one shard at a time
    j_start, j_end = src * shard_len, (src + 1) * shard_len
    out[:, :, j_start:j_end, :] = torch.einsum(  # fill these j-columns
        "bikd,bjkd->bijd", a_local, b_chunk
    )

Peak memory stays at ~2× shard size vs full N×N.
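The same pattern can be checked in a single-process simulation (plain Python, no distributed runtime; only the einsum core out[i][j] = Σₖ a[i][k]·b[j][k] is modeled — gates and norms are omitted, and the real code operates on [B, N/P, N, D] tensors):

```python
# Pure-Python check that chunked column-filling reproduces the full product.

def full_trimul(a, b):
    n = len(a)
    return [[sum(a[i][k] * b[j][k] for k in range(n)) for j in range(n)]
            for i in range(n)]

def chunked_trimul(a, b, world_size):
    n = len(a)
    shard = n // world_size
    out = [[0.0] * n for _ in range(n)]
    for src in range(world_size):                  # "broadcast" src's b shard
        b_chunk = b[src * shard:(src + 1) * shard]
        for i in range(n):                         # each rank does its own rows;
            for jj, b_row in enumerate(b_chunk):   # this sketch does all rows
                j = src * shard + jj
                out[i][j] = sum(a[i][k] * b_row[k] for k in range(n))
    return out

n, p = 4, 2
a = [[float(i * n + k) for k in range(n)] for i in range(n)]
b = [[float(j - k) for k in range(n)] for j in range(n)]
assert chunked_trimul(a, b, p) == full_trimul(a, b)  # identical result
```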

Bias-Only Gathering

For triangle attention and sequence attention, only the small bias tensor [B, H, N, N] (H ≈ 4–16) is gathered, not the full z [B, N, N, 128]. This reduces communication by ~8–32×.
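The saving is just the ratio of channel dimensions per (i, j) position — gathering H bias channels instead of D = 128 pair channels:

```python
def comm_reduction(d_pair=128, n_heads=8):
    """Bytes gathered for full z vs. bias only, per (i, j) pair position."""
    return d_pair / n_heads

assert comm_reduction(n_heads=16) == 8.0   # H = 16 -> ~8x less traffic
assert comm_reduction(n_heads=4) == 32.0   # H = 4  -> ~32x less traffic
```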

Numerical Accuracy

DAP produces results with minor floating-point differences from single-GPU Boltz-2, due to different operation ordering in distributed reductions. Structure predictions (LDDT, TM-score) are statistically equivalent.

References

If you find this project useful, please cite this repository in any work that uses or builds upon it. A formal citation will be provided in a forthcoming preprint describing our implementation, benchmarks, and AAV multimer structure prediction results with this approach.

License

This DAP wrapper follows the same MIT license as Boltz-2.

Contact

For any inquiries, please email {gleeai, wjkimab}@connect.ust.hk; we would be happy to help.

Acknowledgements

We sincerely thank:

  • the original Boltz-2 team for fully open-sourcing their state-of-the-art biomolecular structure prediction models,
  • the FastFold team for their open-source distributed communication utilities,
  • the AlphaFold 3 team for open-sourcing their inference code and model weights,
  • the biomolecular interaction modeling and broader AI for Science communities for their ongoing contributions to this exciting field, and
  • the developers and maintainers of all the packages used in this project!

This project was developed with generous compute support in HKUST HPC4 and SuperPOD from The Hong Kong University of Science and Technology (HKUST). This work was conducted at the lab of Prof. Bonnie Danqing Zhu in the Department of Chemical and Biological Engineering (CBE).

We note the parallel development of Fold-CP by the team at NVIDIA Digital Bio, which also enables multi-GPU Boltz-2 inference (and also training) with a different approach. We look forward to comparing and learning from each other's implementations!

Differences with Fold-CP

This comparison is adapted from boltz2_cp_prediction. Boltz-DAP supports most features of the original serial prediction.

| Aspect | Boltz-DAP (boltz_dap_v2/run_boltz_dap_v2.py) | Fold-CP (src/boltz/distributed/main.py) |
|---|---|---|
| Multi-GPU strategy | SingleDeviceStrategy | DTensor CP mesh |
| Device management | — | DistributedManager via --size_dp, --size_cp |
| Launch method | torchrun or srun | torchrun or srun |
| Input formats | config_files (YAML/FASTA), preprocessed | preprocessed only |
| num_workers | Configurable | Fixed at 0 (DTensor CP requires main-process collation) |
| Precision | Lightning --precision string | Top-level --precision enum |
| Attention backends | — | --triattn_backend, --sdpa_with_bias_backend, --sdpa_with_bias_shardwise_backend |
| CUDA memory profiling | — | --cuda_memory_profile flag |
| Confidence prediction | Supported | Not yet supported (write_confidence_summary=False) |
| Steering potentials | Supported | Not yet supported |
| Affinity prediction | Supported | Not yet supported |
| Template features | Supported | Weights loaded but distributed TemplateModule not yet implemented |
| Constraint features | Supported | Not yet supported |
| Checkpoint loading | Reads checkpoint hparams, merges v2 flags, loads with strict=True | — |
| Output writing | All ranks write | Only CP rank 0 per DP group writes output |
