diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index 2a3c766e31..67f0e61d4a 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -129,7 +129,7 @@ def _merge_configs(base, custom) -> Type: - Otherwise, the field from `custom` overrides the one in `base` (type, default, etc.). - Fields only present in `base` or `custom` are preserved as-is. """ - result = [] + result: list[str | tuple[str, Any] | tuple[str, Any, Any]] = [] b_map = {f.name: f for f in fields(base)} c_map = {f.name: f for f in fields(custom)} diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 3e3256dfec..a7230a9ea5 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -121,7 +121,10 @@ def unflatten_mesh( backend_override[name] = "fake" return world_mesh._unflatten( - 0, dim_degrees, dim_names, backend_override=backend_override + 0, + dim_degrees, + dim_names, + backend_override=backend_override, ) logger.info( diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 27550e977c..9e0dbf8ccb 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -14,9 +14,10 @@ The integration consists of two main components: ## Quick Start ### Prerequisites -1. Install PyTorch nightly for torchtitan: -``` -pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +1. Install PyTorch nightly & Monarch for torchtitan: +```bash +uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +uv pip install torchmonarch ``` @@ -33,7 +34,7 @@ uv pip install --no-build-isolation -e . NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. -``` +```bash # Set CUDA version environment variable export CUDA_HOME=/usr/local/cuda-12.4 export PATH=/usr/local/cuda-12.4/bin:$PATH @@ -49,23 +50,23 @@ uv pip install -e . ``` 3. Download Qwen/Qwen3-0.6B checkpoint from HuggingFace and put into `torchtitan/experiments/rl/example_checkpoint` folder. -``` +```bash python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=... ``` 4. Run inference: -``` +```bash python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path ``` Run with TP: (work in progress) -``` +```bash python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path --tensor-parallel-size 2 ``` 5. Run simple rl loop -``` +```bash VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py ``` Right now we only support VLLM_COMPAT mode, which could achieve trainer and generator bitwise identical. 
We are working on support UNIFIED mode, diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index 45d89095b7..61842f5bd8 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -28,6 +28,8 @@ ) from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm from vllm import LLM, SamplingParams +from vllm.config import AttentionConfig +from vllm.v1.attention.backends.registry import AttentionBackendEnum logger = logging.getLogger(__name__) @@ -210,6 +212,9 @@ def update_weights(self, vllm_compat_state: dict) -> None: seed=42, # Fixed seed for determinism enforce_eager=True, tensor_parallel_size=self.tp_size, # Explicitly single GPU + attention_config=AttentionConfig( + backend=AttentionBackendEnum.FLASH_ATTN, + ), ) logger.info("Created new vLLM engine") else: diff --git a/torchtitan/experiments/rl/unified/models/attention.py b/torchtitan/experiments/rl/unified/models/attention.py index 3f11b3a294..8f317ed2ad 100644 --- a/torchtitan/experiments/rl/unified/models/attention.py +++ b/torchtitan/experiments/rl/unified/models/attention.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import torch - from vllm.model_executor.layers.attention import Attention diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py index 3e914f3778..b221dad8d7 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py @@ -33,6 +33,7 @@ init_batch_invariance, vllm_is_batch_invariant, ) +from vllm.v1.attention.backends.registry import AttentionBackendEnum logger = logging.getLogger(__name__) @@ -63,7 +64,7 @@ async def main(): trainer_tp_size = 1 generator_tp_size = 1 - init_batch_invariance() + init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) batch_invariant = vllm_is_batch_invariant() mode = ModelMode.UNIFIED diff --git a/torchtitan/experiments/rl/vllm_compat/README.md b/torchtitan/experiments/rl/vllm_compat/README.md index 84df62d3ed..3cdb35d232 100644 --- a/torchtitan/experiments/rl/vllm_compat/README.md +++ b/torchtitan/experiments/rl/vllm_compat/README.md @@ -71,7 +71,8 @@ Initialize vLLM's batch-invariant mode before training: ```python from vllm.model_executor.layers.batch_invariant import init_batch_invariance -init_batch_invariance() +from vllm.v1.attention.backends.registry import AttentionBackendEnum +init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) ``` ## Usage diff --git a/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py index b67244478e..b7cf56efa7 100644 --- a/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py +++ b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py @@ -22,7 +22,7 @@ from batch_invariant_backward import enable_batch_invariant_backward_mode # Initialize vLLM's deterministic mode first - init_batch_invariance() + init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) # Then enable gradient support enable_batch_invariant_backward_mode() @@ -308,7 +308,7 @@ def enable_batch_invariant_backward_mode(): ): raise RuntimeError( "vLLM's batch_invariant mode is not initialized. " - "Call init_batch_invariance() first." + "Call init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) first." 
) # Use vLLM's existing library - don't destroy it! diff --git a/torchtitan/experiments/rl/vllm_compat/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py index 752b416922..5c54f14968 100644 --- a/torchtitan/experiments/rl/vllm_compat/models/attention.py +++ b/torchtitan/experiments/rl/vllm_compat/models/attention.py @@ -6,6 +6,7 @@ import math +from collections.abc import Callable import torch from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func @@ -30,6 +31,7 @@ def forward( v: torch.Tensor, *, scale: float | None = None, + enable_gqa: bool = False, ) -> torch.Tensor: # Flash Attention varlen expects: (batch, seqlen, nheads, headdim) # The input from TorchTitan is always (batch, num_heads, seq_len, head_dim) @@ -42,12 +44,13 @@ def forward( # Get dimensions batch_size, seq_len, num_heads, head_dim = q.shape + num_kv_heads = k.shape[2] # Convert to varlen format: flatten batch and sequence dimensions # (batch, seqlen, nheads, headdim) -> (total_tokens, nheads, headdim) q_varlen = q.reshape(-1, num_heads, head_dim) - k_varlen = k.reshape(-1, k.shape[2], head_dim) - v_varlen = v.reshape(-1, v.shape[2], head_dim) + k_varlen = k.reshape(-1, num_kv_heads, head_dim) + v_varlen = v.reshape(-1, num_kv_heads, head_dim) # Create cumulative sequence lengths # cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len] @@ -59,21 +62,29 @@ def forward( if scale is None: scale = 1.0 / math.sqrt(q.size(-1)) + # Pre-allocate output tensor with correct shape (num_heads from Q, not K/V) + # This ensures flash attention writes to a tensor with the correct GQA output shape + total_tokens = batch_size * seq_len + out_varlen = torch.empty( + (total_tokens, num_heads, head_dim), dtype=q.dtype, device=q.device + ) + # Wrap Flash Attention with manual backward pass class FlashAttnWithBackward(torch.autograd.Function): @staticmethod def forward( - ctx, - q, - k, - v, - cu_seqlens, - seq_len, - scale, - num_splits, - flash_fn, - fa_version, - ): + ctx: torch.autograd.function.FunctionCtx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + cu_seqlens: torch.Tensor, + seq_len: int, + scale: float, + num_splits: int, + flash_fn: Callable[..., torch.Tensor], + fa_version: int, + ) -> torch.Tensor: # Call flash attention for forward (fast) output = flash_fn( q, @@ -87,6 +98,7 @@ def forward( causal=True, num_splits=num_splits, fa_version=fa_version, + out=out, ) # Save for backward ctx.save_for_backward(q, k, v, output) @@ -95,7 +107,20 @@ def forward( return output @staticmethod - def backward(ctx, grad_output): + def backward( + ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + None, + ]: q, k, v, output = ctx.saved_tensors scale = ctx.scale seq_len = ctx.seq_len @@ -104,12 +129,13 @@ def backward(ctx, grad_output): # Assume uniform sequence lengths (batch_size = total_tokens / seq_len) total_tokens = q.shape[0] num_heads = q.shape[1] + num_kv_heads = k.shape[1] head_dim = q.shape[2] batch_size = total_tokens // seq_len q_batch = q.reshape(batch_size, seq_len, num_heads, head_dim) - k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim) - v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim) + k_batch = k.reshape(batch_size, seq_len, num_kv_heads, head_dim) + v_batch = v.reshape(batch_size, seq_len, num_kv_heads, head_dim) out_batch = output.reshape(batch_size, seq_len, num_heads, head_dim) 
grad_out_batch = grad_output.reshape( batch_size, seq_len, num_heads, head_dim @@ -122,6 +148,17 @@ def backward(ctx, grad_output): out_t = out_batch.transpose(1, 2) grad_out_t = grad_out_batch.transpose(1, 2) + # For GQA, we need to expand K/V to match Q's num_heads + # Each KV head serves (num_heads // num_kv_heads) Q heads + if num_kv_heads < num_heads: + assert enable_gqa, "GQA requires enable_gqa=True" + assert ( + num_heads % num_kv_heads == 0 + ), "num_heads must be a multiple of num_kv_heads" + n_rep = num_heads // num_kv_heads + k_t = k_t.repeat_interleave(n_rep, dim=1) + v_t = v_t.repeat_interleave(n_rep, dim=1) + # Compute attention scores: QK^T # q_t: (B, H, N, D), k_t: (B, H, N, D) -> scores: (B, H, N, N) scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale @@ -167,20 +204,37 @@ def backward(ctx, grad_output): grad_q = grad_q_t.transpose(1, 2).reshape( total_tokens, num_heads, head_dim ) + + # For GQA, we need to reduce grad_k and grad_v back to num_kv_heads + if num_kv_heads < num_heads: + assert enable_gqa, "GQA requires enable_gqa=True" + assert ( + num_heads % num_kv_heads == 0 + ), "num_heads must be a multiple of num_kv_heads" + n_rep = num_heads // num_kv_heads + # Reshape and sum over the repeated dimension + grad_k_t = grad_k_t.reshape( + batch_size, num_kv_heads, n_rep, seq_len, head_dim + ).sum(dim=2) + grad_v_t = grad_v_t.reshape( + batch_size, num_kv_heads, n_rep, seq_len, head_dim + ).sum(dim=2) + grad_k = grad_k_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim + total_tokens, num_kv_heads, head_dim ) grad_v = grad_v_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim + total_tokens, num_kv_heads, head_dim ) - return grad_q, grad_k, grad_v, None, None, None, None, None, None + return grad_q, grad_k, grad_v, None, None, None, None, None, None, None # Call Flash Attention varlen with custom backward output_varlen = FlashAttnWithBackward.apply( q_varlen, k_varlen, v_varlen, + out_varlen, cu_seqlens, seq_len, scale,
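The varlen conversion in the attention patch flattens `(batch, seqlen, num_heads, head_dim)` tensors into `(total_tokens, num_heads, head_dim)` and describes sequence boundaries with cumulative lengths, per the comment `cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len]`. A minimal sketch of that layout for fixed-length batches; the shapes are illustrative and the `torch.arange` construction is one way to build `cu_seqlens`, not necessarily the exact code in the patch:

```python
import torch

# Illustrative shapes only (not taken from the patch).
batch_size, seq_len, num_heads, num_kv_heads, head_dim = 2, 4, 8, 2, 16

q = torch.randn(batch_size, seq_len, num_heads, head_dim)
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim)

# Varlen packing: flatten batch and sequence into one token dimension.
q_varlen = q.reshape(-1, num_heads, head_dim)      # (total_tokens, H, D)
k_varlen = k.reshape(-1, num_kv_heads, head_dim)   # (total_tokens, H_kv, D)

# Sequence boundaries as cumulative lengths:
# [0, seq_len, 2*seq_len, ..., batch_size*seq_len]
cu_seqlens = torch.arange(
    0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32
)
print(q_varlen.shape, k_varlen.shape, cu_seqlens)
# torch.Size([8, 8, 16]) torch.Size([8, 2, 16]) tensor([0, 4, 8], dtype=torch.int32)
```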
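The GQA handling added to the manual backward relies on a standard identity: if K/V heads are expanded with `repeat_interleave(n_rep, dim=1)` on the forward path, the gradients flowing into the expanded tensors must be summed back over the repeat dimension to recover gradients for the original `num_kv_heads` tensors. A self-contained sketch that checks this reduction against autograd using plain attention math (no flash-attention or vLLM dependency; shapes are illustrative, not taken from the patch):

```python
import torch

torch.manual_seed(0)
B, S, D = 2, 8, 16          # batch, sequence length, head dim (illustrative)
H, H_KV = 8, 2              # query heads vs. KV heads
n_rep = H // H_KV
scale = D ** -0.5

q = torch.randn(B, H, S, D, requires_grad=True)
k = torch.randn(B, H_KV, S, D, requires_grad=True)
v = torch.randn(B, H_KV, S, D, requires_grad=True)

# Forward: expand KV heads to match Q heads, then ordinary causal attention.
k_rep = k.repeat_interleave(n_rep, dim=1)
v_rep = v.repeat_interleave(n_rep, dim=1)
scores = (q @ k_rep.transpose(-2, -1)) * scale
causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))
out = torch.softmax(scores, dim=-1) @ v_rep

grad_out = torch.randn_like(out)

# Gradients w.r.t. the expanded K/V, reduced back to H_KV heads the same way
# the patch does it: reshape to (.., H_KV, n_rep, ..) and sum over n_rep.
gq, gk_rep, gv_rep = torch.autograd.grad(
    out, (q, k_rep, v_rep), grad_out, retain_graph=True
)
gk_reduced = gk_rep.reshape(B, H_KV, n_rep, S, D).sum(dim=2)
gv_reduced = gv_rep.reshape(B, H_KV, n_rep, S, D).sum(dim=2)

# Reference: gradients w.r.t. the original (unexpanded) K/V.
_, gk_ref, gv_ref = torch.autograd.grad(out, (q, k, v), grad_out)

assert torch.allclose(gk_reduced, gk_ref, atol=1e-6)
assert torch.allclose(gv_reduced, gv_ref, atol=1e-6)
```

The reshape-and-sum works because `repeat_interleave` places the `n_rep` copies of each KV head in consecutive positions along the head dimension, which is the same grouping the patch assumes when it reduces `grad_k_t` and `grad_v_t`.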