From ea2f9311cbf2414ed1333c66683c12dd47fdb9cf Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Mon, 9 Feb 2026 10:59:26 -0800 Subject: [PATCH 1/3] Scaffolding changes for torch.compile support --- .gitignore | 4 ++ torchtitan/components/metrics.py | 1 + torchtitan/config/manager.py | 3 ++ torchtitan/distributed/parallel_dims.py | 6 ++- torchtitan/experiments/rl/unified/README.md | 3 +- .../rl/unified/actors/generator.py | 5 +++ .../rl/unified/models/attention.py | 4 +- .../rl/unified/simple_rl_multiprocess.py | 3 +- .../experiments/rl/vllm_compat/README.md | 3 +- .../vllm_compat/batch_invariant_backward.py | 4 +- .../rl/vllm_compat/models/attention.py | 42 +++++++++++++++---- 11 files changed, 63 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 415631ff9c..b81b3275a9 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,7 @@ Sessionx.vim # Vibe coding .claude + +# Experiment rl +converted/ +models/ diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index bffe42946a..b6e4f2b6c9 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -136,6 +136,7 @@ class WandBLogger(BaseLogger): def __init__(self, log_dir: str, job_config: JobConfig, tag: str | None = None): # Import wandb here to avoid startup import + # pyrefly: ignore [missing-import] import wandb self.wandb = wandb diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index 2a3c766e31..02817556d3 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -144,15 +144,18 @@ def _merge_configs(base, custom) -> Type: # Custom field overrides base type elif name in c_map: + # pyrefly: ignore [bad-argument-type] result.append((name, c_map[name].type, c_map[name])) # Only in Base else: + # pyrefly: ignore [bad-argument-type] result.append((name, f.type, f)) # Only in Custom for name, f in c_map.items(): if name not in b_map: + # pyrefly: ignore [bad-argument-type] result.append((name, f.type, f)) return make_dataclass(f"Merged{base.__name__}", result, bases=(base,)) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 3e3256dfec..5fcf646a6b 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -121,7 +121,11 @@ def unflatten_mesh( backend_override[name] = "fake" return world_mesh._unflatten( - 0, dim_degrees, dim_names, backend_override=backend_override + 0, + dim_degrees, + dim_names, + # pyrefly: ignore [bad-argument-type] + backend_override=backend_override, ) logger.info( diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 27550e977c..0861120dba 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -14,9 +14,10 @@ The integration consists of two main components: ## Quick Start ### Prerequisites -1. Install PyTorch nightly for torchtitan: +1. 
Install PyTorch nightly & Monarch for torchtitan: ``` pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +pip3 install torchmonarch ``` diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py index 45d89095b7..61842f5bd8 100644 --- a/torchtitan/experiments/rl/unified/actors/generator.py +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -28,6 +28,8 @@ ) from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm from vllm import LLM, SamplingParams +from vllm.config import AttentionConfig +from vllm.v1.attention.backends.registry import AttentionBackendEnum logger = logging.getLogger(__name__) @@ -210,6 +212,9 @@ def update_weights(self, vllm_compat_state: dict) -> None: seed=42, # Fixed seed for determinism enforce_eager=True, tensor_parallel_size=self.tp_size, # Explicitly single GPU + attention_config=AttentionConfig( + backend=AttentionBackendEnum.FLASH_ATTN, + ), ) logger.info("Created new vLLM engine") else: diff --git a/torchtitan/experiments/rl/unified/models/attention.py b/torchtitan/experiments/rl/unified/models/attention.py index 3f11b3a294..827782f9ac 100644 --- a/torchtitan/experiments/rl/unified/models/attention.py +++ b/torchtitan/experiments/rl/unified/models/attention.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import torch - from vllm.model_executor.layers.attention import Attention @@ -99,7 +98,8 @@ def forward( output_flat = self.vllm_attn(q, k, v) # Output is (batch * seq_len, num_heads * head_dim), reshape to (batch, seq_len, num_heads, head_dim) - output = output_flat.view(batch_size, seq_len, num_heads, head_dim) + # Use self.num_heads and self.head_dim since vLLM Attention outputs based on its configured dimensions + output = output_flat.view(batch_size, seq_len, self.num_heads, self.head_dim) # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) output = output.transpose(1, 2) diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py index 3e914f3778..b221dad8d7 100644 --- a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py @@ -33,6 +33,7 @@ init_batch_invariance, vllm_is_batch_invariant, ) +from vllm.v1.attention.backends.registry import AttentionBackendEnum logger = logging.getLogger(__name__) @@ -63,7 +64,7 @@ async def main(): trainer_tp_size = 1 generator_tp_size = 1 - init_batch_invariance() + init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) batch_invariant = vllm_is_batch_invariant() mode = ModelMode.UNIFIED diff --git a/torchtitan/experiments/rl/vllm_compat/README.md b/torchtitan/experiments/rl/vllm_compat/README.md index 84df62d3ed..3cdb35d232 100644 --- a/torchtitan/experiments/rl/vllm_compat/README.md +++ b/torchtitan/experiments/rl/vllm_compat/README.md @@ -71,7 +71,8 @@ Initialize vLLM's batch-invariant mode before training: ```python from vllm.model_executor.layers.batch_invariant import init_batch_invariance -init_batch_invariance() +from vllm.v1.attention.backends.registry import AttentionBackendEnum +init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) ``` ## Usage diff --git a/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py index b67244478e..b7cf56efa7 100644 --- 
a/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py +++ b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py @@ -22,7 +22,7 @@ from batch_invariant_backward import enable_batch_invariant_backward_mode # Initialize vLLM's deterministic mode first - init_batch_invariance() + init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) # Then enable gradient support enable_batch_invariant_backward_mode() @@ -308,7 +308,7 @@ def enable_batch_invariant_backward_mode(): ): raise RuntimeError( "vLLM's batch_invariant mode is not initialized. " - "Call init_batch_invariance() first." + "Call init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) first." ) # Use vLLM's existing library - don't destroy it! diff --git a/torchtitan/experiments/rl/vllm_compat/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py index 752b416922..5474d77c27 100644 --- a/torchtitan/experiments/rl/vllm_compat/models/attention.py +++ b/torchtitan/experiments/rl/vllm_compat/models/attention.py @@ -30,6 +30,7 @@ def forward( v: torch.Tensor, *, scale: float | None = None, + enable_gqa: bool = False, ) -> torch.Tensor: # Flash Attention varlen expects: (batch, seqlen, nheads, headdim) # The input from TorchTitan is always (batch, num_heads, seq_len, head_dim) @@ -42,12 +43,13 @@ def forward( # Get dimensions batch_size, seq_len, num_heads, head_dim = q.shape + num_kv_heads = k.shape[2] # Convert to varlen format: flatten batch and sequence dimensions # (batch, seqlen, nheads, headdim) -> (total_tokens, nheads, headdim) q_varlen = q.reshape(-1, num_heads, head_dim) - k_varlen = k.reshape(-1, k.shape[2], head_dim) - v_varlen = v.reshape(-1, v.shape[2], head_dim) + k_varlen = k.reshape(-1, num_kv_heads, head_dim) + v_varlen = v.reshape(-1, num_kv_heads, head_dim) # Create cumulative sequence lengths # cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len] @@ -59,6 +61,13 @@ def forward( if scale is None: scale = 1.0 / math.sqrt(q.size(-1)) + # Pre-allocate output tensor with correct shape (num_heads from Q, not K/V) + # This ensures flash attention writes to a tensor with the correct GQA output shape + total_tokens = batch_size * seq_len + out_varlen = torch.empty( + (total_tokens, num_heads, head_dim), dtype=q.dtype, device=q.device + ) + # Wrap Flash Attention with manual backward pass class FlashAttnWithBackward(torch.autograd.Function): @staticmethod @@ -67,6 +76,7 @@ def forward( q, k, v, + out, cu_seqlens, seq_len, scale, @@ -87,6 +97,7 @@ def forward( causal=True, num_splits=num_splits, fa_version=fa_version, + out=out, ) # Save for backward ctx.save_for_backward(q, k, v, output) @@ -104,12 +115,13 @@ def backward(ctx, grad_output): # Assume uniform sequence lengths (batch_size = total_tokens / seq_len) total_tokens = q.shape[0] num_heads = q.shape[1] + num_kv_heads = k.shape[1] head_dim = q.shape[2] batch_size = total_tokens // seq_len q_batch = q.reshape(batch_size, seq_len, num_heads, head_dim) - k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim) - v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim) + k_batch = k.reshape(batch_size, seq_len, num_kv_heads, head_dim) + v_batch = v.reshape(batch_size, seq_len, num_kv_heads, head_dim) out_batch = output.reshape(batch_size, seq_len, num_heads, head_dim) grad_out_batch = grad_output.reshape( batch_size, seq_len, num_heads, head_dim @@ -122,6 +134,13 @@ def backward(ctx, grad_output): out_t = out_batch.transpose(1, 2) grad_out_t = grad_out_batch.transpose(1, 2) + # For GQA, we need to expand K/V to 
match Q's num_heads + # Each KV head serves (num_heads // num_kv_heads) Q heads + if num_kv_heads != num_heads: + n_rep = num_heads // num_kv_heads + k_t = k_t.repeat_interleave(n_rep, dim=1) + v_t = v_t.repeat_interleave(n_rep, dim=1) + # Compute attention scores: QK^T # q_t: (B, H, N, D), k_t: (B, H, N, D) -> scores: (B, H, N, N) scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale @@ -167,20 +186,29 @@ def backward(ctx, grad_output): grad_q = grad_q_t.transpose(1, 2).reshape( total_tokens, num_heads, head_dim ) + + # For GQA, we need to reduce grad_k and grad_v back to num_kv_heads + if num_kv_heads != num_heads: + n_rep = num_heads // num_kv_heads + # Reshape and sum over the repeated dimension + grad_k_t = grad_k_t.reshape(batch_size, num_kv_heads, n_rep, seq_len, head_dim).sum(dim=2) + grad_v_t = grad_v_t.reshape(batch_size, num_kv_heads, n_rep, seq_len, head_dim).sum(dim=2) + grad_k = grad_k_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim + total_tokens, num_kv_heads, head_dim ) grad_v = grad_v_t.transpose(1, 2).reshape( - total_tokens, num_heads, head_dim + total_tokens, num_kv_heads, head_dim ) - return grad_q, grad_k, grad_v, None, None, None, None, None, None + return grad_q, grad_k, grad_v, None, None, None, None, None, None, None # Call Flash Attention varlen with custom backward output_varlen = FlashAttnWithBackward.apply( q_varlen, k_varlen, v_varlen, + out_varlen, cu_seqlens, seq_len, scale, From 28bb3638aab65cec2c4c1d36b51a1ec0e221fc76 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 10 Feb 2026 08:24:30 -0800 Subject: [PATCH 2/3] Incorpororate feedback --- .gitignore | 6 +- torchtitan/experiments/rl/unified/README.md | 16 ++--- .../rl/unified/models/attention.py | 3 +- .../rl/vllm_compat/models/attention.py | 60 +++++++++++++------ 4 files changed, 55 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index b81b3275a9..39d6694f6b 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,6 @@ Sessionx.vim # Vibe coding .claude -# Experiment rl -converted/ -models/ +# Experiments/rl artifacts +/converted/ +/models/ diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md index 0861120dba..9e0dbf8ccb 100644 --- a/torchtitan/experiments/rl/unified/README.md +++ b/torchtitan/experiments/rl/unified/README.md @@ -15,9 +15,9 @@ The integration consists of two main components: ### Prerequisites 1. Install PyTorch nightly & Monarch for torchtitan: -``` -pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall -pip3 install torchmonarch +```bash +uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +uv pip install torchmonarch ``` @@ -34,7 +34,7 @@ uv pip install --no-build-isolation -e . NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. -``` +```bash # Set CUDA version environment variable export CUDA_HOME=/usr/local/cuda-12.4 export PATH=/usr/local/cuda-12.4/bin:$PATH @@ -50,23 +50,23 @@ uv pip install -e . ``` 3. Download Qwen/Qwen3-0.6B checkpoint from HuggingFace and put into `torchtitan/experiments/rl/example_checkpoint` folder. 
-``` +```bash python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=... ``` 4. Run inference: -``` +```bash python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path ``` Run with TP: (work in progress) -``` +```bash python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path --tensor-parallel-size 2 ``` 5. Run simple rl loop -``` +```bash VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py ``` Right now we only support VLLM_COMPAT mode, which could achieve trainer and generator bitwise identical. We are working on support UNIFIED mode, diff --git a/torchtitan/experiments/rl/unified/models/attention.py b/torchtitan/experiments/rl/unified/models/attention.py index 827782f9ac..8f317ed2ad 100644 --- a/torchtitan/experiments/rl/unified/models/attention.py +++ b/torchtitan/experiments/rl/unified/models/attention.py @@ -98,8 +98,7 @@ def forward( output_flat = self.vllm_attn(q, k, v) # Output is (batch * seq_len, num_heads * head_dim), reshape to (batch, seq_len, num_heads, head_dim) - # Use self.num_heads and self.head_dim since vLLM Attention outputs based on its configured dimensions - output = output_flat.view(batch_size, seq_len, self.num_heads, self.head_dim) + output = output_flat.view(batch_size, seq_len, num_heads, head_dim) # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) output = output.transpose(1, 2) diff --git a/torchtitan/experiments/rl/vllm_compat/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py index 5474d77c27..5c54f14968 100644 --- a/torchtitan/experiments/rl/vllm_compat/models/attention.py +++ b/torchtitan/experiments/rl/vllm_compat/models/attention.py @@ -6,6 +6,7 @@ import math +from collections.abc import Callable import torch from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func @@ -72,18 +73,18 @@ def forward( class FlashAttnWithBackward(torch.autograd.Function): @staticmethod def forward( - ctx, - q, - k, - v, - out, - cu_seqlens, - seq_len, - scale, - num_splits, - flash_fn, - fa_version, - ): + ctx: torch.autograd.function.FunctionCtx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + cu_seqlens: torch.Tensor, + seq_len: int, + scale: float, + num_splits: int, + flash_fn: Callable[..., torch.Tensor], + fa_version: int, + ) -> torch.Tensor: # Call flash attention for forward (fast) output = flash_fn( q, @@ -106,7 +107,20 @@ def forward( return output @staticmethod - def backward(ctx, grad_output): + def backward( + ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + None, + None, + None, + None, + None, + None, + None, + ]: q, k, v, output = ctx.saved_tensors scale = ctx.scale seq_len = ctx.seq_len @@ -136,7 +150,11 @@ def backward(ctx, grad_output): # For GQA, we need to expand K/V to match Q's num_heads # Each KV head serves (num_heads // num_kv_heads) Q heads - if num_kv_heads != num_heads: + if num_kv_heads < num_heads: + assert enable_gqa, "GQA requires enable_gqa=True" + assert ( + num_heads % num_kv_heads == 0 + ), "num_heads must be a multiple of num_kv_heads" n_rep = num_heads // num_kv_heads k_t = k_t.repeat_interleave(n_rep, dim=1) v_t = v_t.repeat_interleave(n_rep, dim=1) @@ -188,11 +206,19 @@ def backward(ctx, grad_output): ) # For GQA, we need to reduce grad_k and grad_v back to num_kv_heads - if 
num_kv_heads != num_heads: + if num_kv_heads < num_heads: + assert enable_gqa, "GQA requires enable_gqa=True" + assert ( + num_heads % num_kv_heads == 0 + ), "num_heads must be a multiple of num_kv_heads" n_rep = num_heads // num_kv_heads # Reshape and sum over the repeated dimension - grad_k_t = grad_k_t.reshape(batch_size, num_kv_heads, n_rep, seq_len, head_dim).sum(dim=2) - grad_v_t = grad_v_t.reshape(batch_size, num_kv_heads, n_rep, seq_len, head_dim).sum(dim=2) + grad_k_t = grad_k_t.reshape( + batch_size, num_kv_heads, n_rep, seq_len, head_dim + ).sum(dim=2) + grad_v_t = grad_v_t.reshape( + batch_size, num_kv_heads, n_rep, seq_len, head_dim + ).sum(dim=2) grad_k = grad_k_t.transpose(1, 2).reshape( total_tokens, num_kv_heads, head_dim From 4d41bf1df7ad3081a889029ddbb58dd409fbd31d Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 10 Feb 2026 14:41:21 -0800 Subject: [PATCH 3/3] Remove gitignore, pyrefly changes --- .gitignore | 4 ---- torchtitan/components/metrics.py | 1 - torchtitan/config/manager.py | 5 +---- torchtitan/distributed/parallel_dims.py | 1 - 4 files changed, 1 insertion(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 39d6694f6b..415631ff9c 100644 --- a/.gitignore +++ b/.gitignore @@ -45,7 +45,3 @@ Sessionx.vim # Vibe coding .claude - -# Experiments/rl artifacts -/converted/ -/models/ diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index b6e4f2b6c9..bffe42946a 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -136,7 +136,6 @@ class WandBLogger(BaseLogger): def __init__(self, log_dir: str, job_config: JobConfig, tag: str | None = None): # Import wandb here to avoid startup import - # pyrefly: ignore [missing-import] import wandb self.wandb = wandb diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index 02817556d3..67f0e61d4a 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -129,7 +129,7 @@ def _merge_configs(base, custom) -> Type: - Otherwise, the field from `custom` overrides the one in `base` (type, default, etc.). - Fields only present in `base` or `custom` are preserved as-is. """ - result = [] + result: list[str | tuple[str, Any] | tuple[str, Any, Any]] = [] b_map = {f.name: f for f in fields(base)} c_map = {f.name: f for f in fields(custom)} @@ -144,18 +144,15 @@ def _merge_configs(base, custom) -> Type: # Custom field overrides base type elif name in c_map: - # pyrefly: ignore [bad-argument-type] result.append((name, c_map[name].type, c_map[name])) # Only in Base else: - # pyrefly: ignore [bad-argument-type] result.append((name, f.type, f)) # Only in Custom for name, f in c_map.items(): if name not in b_map: - # pyrefly: ignore [bad-argument-type] result.append((name, f.type, f)) return make_dataclass(f"Merged{base.__name__}", result, bases=(base,)) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 5fcf646a6b..a7230a9ea5 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -124,7 +124,6 @@ def unflatten_mesh( 0, dim_degrees, dim_names, - # pyrefly: ignore [bad-argument-type] backend_override=backend_override, )
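
The GQA handling added to the custom backward in PATCH 1/2 rests on a single identity: when each of the `num_kv_heads` K/V heads is expanded to `n_rep = num_heads // num_kv_heads` query heads via `repeat_interleave` for the recomputation, the gradient with respect to the original K/V is recovered by summing the expanded-head gradients over each group of `n_rep` (the `grad_k_t.reshape(batch_size, num_kv_heads, n_rep, seq_len, head_dim).sum(dim=2)` step), not by averaging or slicing. Below is a minimal standalone sketch of that check; it is not part of the patches, uses `torch.nn.functional.scaled_dot_product_attention` as a stand-in for `flash_attn_varlen_func`, and picks hypothetical shapes with double precision so `allclose` is tight.

```python
# Illustrative check (not part of the patches): folding GQA K/V gradients
# back from num_heads to num_kv_heads by summing over the repeat dimension.
# Stand-ins: SDPA instead of flash_attn_varlen_func; shapes are hypothetical.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H_Q, H_KV, N, D = 2, 8, 2, 16, 32  # batch, query heads, KV heads, seq len, head dim
n_rep = H_Q // H_KV

q = torch.randn(B, H_Q, N, D, dtype=torch.float64, requires_grad=True)
k = torch.randn(B, H_KV, N, D, dtype=torch.float64, requires_grad=True)
v = torch.randn(B, H_KV, N, D, dtype=torch.float64, requires_grad=True)
grad_out = torch.randn(B, H_Q, N, D, dtype=torch.float64)

# Reference: let autograd differentiate through repeat_interleave itself.
out_ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    is_causal=True,
)
gq_ref, gk_ref, gv_ref = torch.autograd.grad(out_ref, (q, k, v), grad_out)

# Manual path mirroring the patch: treat the expanded K/V as independent leaves,
# then reduce their gradients back to the original KV heads with a group sum.
k_rep = k.detach().repeat_interleave(n_rep, dim=1).requires_grad_(True)
v_rep = v.detach().repeat_interleave(n_rep, dim=1).requires_grad_(True)
q2 = q.detach().requires_grad_(True)
out = F.scaled_dot_product_attention(q2, k_rep, v_rep, is_causal=True)
gq, gk_rep, gv_rep = torch.autograd.grad(out, (q2, k_rep, v_rep), grad_out)

gk = gk_rep.reshape(B, H_KV, n_rep, N, D).sum(dim=2)  # same reduction as grad_k_t in the patch
gv = gv_rep.reshape(B, H_KV, n_rep, N, D).sum(dim=2)  # same reduction as grad_v_t in the patch

assert torch.allclose(gq, gq_ref)
assert torch.allclose(gk, gk_ref) and torch.allclose(gv, gv_ref)
print("GQA gradient reduction matches autograd")
```

The sum (rather than a mean) falls directly out of the chain rule: every repeated query head reads the same K/V values, so their gradient contributions accumulate into the original KV head.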