2 changes: 1 addition & 1 deletion torchtitan/config/manager.py
@@ -129,7 +129,7 @@ def _merge_configs(base, custom) -> Type:
- Otherwise, the field from `custom` overrides the one in `base` (type, default, etc.).
- Fields only present in `base` or `custom` are preserved as-is.
"""
result = []
result: list[str | tuple[str, Any] | tuple[str, Any, Any]] = []
b_map = {f.name: f for f in fields(base)}
c_map = {f.name: f for f in fields(custom)}

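The new annotation on `result` matches the field-spec format accepted by `dataclasses.make_dataclass`: each entry is a bare field name, a `(name, type)` pair, or a `(name, type, field)` triple. Below is a minimal, illustrative sketch of the override behavior the docstring describes; it is not the actual `_merge_configs` body, and the `Base`/`Custom` dataclasses are invented for the example.

```python
from dataclasses import dataclass, field, fields, make_dataclass
from typing import Any


def merge_field_specs(
    base: type, custom: type
) -> list[str | tuple[str, Any] | tuple[str, Any, Any]]:
    """Illustrative only: same-named fields from `custom` override those from `base`."""
    b_map = {f.name: f for f in fields(base)}
    c_map = {f.name: f for f in fields(custom)}
    result: list[str | tuple[str, Any] | tuple[str, Any, Any]] = []
    for name in {**b_map, **c_map}:
        f = c_map.get(name, b_map[name])  # `custom` wins on name conflicts
        result.append((name, f.type, field(default=f.default)))
    return result


@dataclass
class Base:
    lr: float = 1e-4
    steps: int = 10


@dataclass
class Custom:
    lr: float = 3e-4  # overrides Base.lr


Merged = make_dataclass("Merged", merge_field_specs(Base, Custom))
print(Merged())  # Merged(lr=0.0003, steps=10)
```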
5 changes: 4 additions & 1 deletion torchtitan/distributed/parallel_dims.py
@@ -121,7 +121,10 @@ def unflatten_mesh(
backend_override[name] = "fake"

return world_mesh._unflatten(
0, dim_degrees, dim_names, backend_override=backend_override
0,
dim_degrees,
dim_names,
backend_override=backend_override,
)

logger.info(
17 changes: 9 additions & 8 deletions torchtitan/experiments/rl/unified/README.md
@@ -14,9 +14,10 @@ The integration consists of two main components:
## Quick Start
### Prerequisites

1. Install PyTorch nightly for torchtitan:
```
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
1. Install PyTorch nightly & Monarch for torchtitan:
```bash
uv pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall
uv pip install torchmonarch
```


@@ -33,7 +34,7 @@ uv pip install --no-build-isolation -e .

NOTE: If `flash_attn_varlen_func` hits the error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during the forward pass, the GPU driver version is not compatible with the CUDA toolchain vLLM/PyTorch were compiled with. Use the following commands to recompile vLLM.

```
```bash
# Set CUDA version environment variable
export CUDA_HOME=/usr/local/cuda-12.4
export PATH=/usr/local/cuda-12.4/bin:$PATH
@@ -49,23 +50,23 @@ uv pip install -e .
```

3. Download the Qwen/Qwen3-0.6B checkpoint from HuggingFace and put it into the `torchtitan/experiments/rl/example_checkpoint` folder.
```
```bash
python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...
```

4. Run inference:
```
```bash
python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path <path_to_model_checkpoint>
```

Run with TP (work in progress):
```
```bash
python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path <path_to_model_checkpoint> --tensor-parallel-size 2

```

5. Run the simple RL loop:
```
```bash
VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
```
Right now we only support VLLM_COMPAT mode, which makes the trainer and generator bitwise identical. We are working on supporting UNIFIED mode,
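To make "bitwise identical" concrete: with batch-invariant mode enabled, the trainer's and the generator's forward passes are expected to agree exactly, not merely within floating-point tolerance. A hypothetical check, where `trainer_logits` and `generator_logits` are placeholders rather than variables from this repo:

```python
import torch

# Placeholders for logits produced by the trainer and the vLLM generator
# on the same prompt with the same weights.
trainer_logits = torch.randn(4, 32, 151_936)
generator_logits = trainer_logits.clone()

# Bitwise equality, not torch.allclose: every element must match exactly.
assert torch.equal(trainer_logits, generator_logits)
```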
5 changes: 5 additions & 0 deletions torchtitan/experiments/rl/unified/actors/generator.py
@@ -28,6 +28,8 @@
)
from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm
from vllm import LLM, SamplingParams
from vllm.config import AttentionConfig
from vllm.v1.attention.backends.registry import AttentionBackendEnum

logger = logging.getLogger(__name__)

@@ -210,6 +212,9 @@ def update_weights(self, vllm_compat_state: dict) -> None:
seed=42, # Fixed seed for determinism
enforce_eager=True,
tensor_parallel_size=self.tp_size, # Explicitly single GPU
attention_config=AttentionConfig(
backend=AttentionBackendEnum.FLASH_ATTN,
),
)
logger.info("Created new vLLM engine")
else:
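Taken together, the engine arguments in this hunk pin the main sources of nondeterminism: a fixed seed, eager execution, and an explicitly selected FLASH_ATTN backend. A condensed sketch of the construction, assuming a vLLM build that provides `vllm.config.AttentionConfig` and the v1 backend registry (as imported above); the model path is a placeholder:

```python
from vllm import LLM
from vllm.config import AttentionConfig
from vllm.v1.attention.backends.registry import AttentionBackendEnum

llm = LLM(
    model="torchtitan/experiments/rl/example_checkpoint",  # placeholder path
    seed=42,             # fixed seed for determinism
    enforce_eager=True,  # skip CUDA graph capture
    tensor_parallel_size=1,
    attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN),
)
```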
1 change: 0 additions & 1 deletion torchtitan/experiments/rl/unified/models/attention.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import torch

from vllm.model_executor.layers.attention import Attention


3 changes: 2 additions & 1 deletion torchtitan/experiments/rl/unified/simple_rl_multiprocess.py
@@ -33,6 +33,7 @@
init_batch_invariance,
vllm_is_batch_invariant,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum

logger = logging.getLogger(__name__)

@@ -63,7 +64,7 @@ async def main():
trainer_tp_size = 1
generator_tp_size = 1

init_batch_invariance()
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
batch_invariant = vllm_is_batch_invariant()
mode = ModelMode.UNIFIED

3 changes: 2 additions & 1 deletion torchtitan/experiments/rl/vllm_compat/README.md
@@ -71,7 +71,8 @@ Initialize vLLM's batch-invariant mode before training:

```python
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
init_batch_invariance()
from vllm.v1.attention.backends.registry import AttentionBackendEnum
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
```
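After initialization you can verify that the mode took effect. A small sketch, assuming `vllm_is_batch_invariant` lives in the same `batch_invariant` module (it is imported alongside `init_batch_invariance` elsewhere in this PR):

```python
from vllm.model_executor.layers.batch_invariant import (
    init_batch_invariance,
    vllm_is_batch_invariant,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum

# May also require VLLM_BATCH_INVARIANT=1 in the environment,
# as used in the unified README's RL-loop command.
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)
assert vllm_is_batch_invariant(), "batch-invariant mode did not initialize"
```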

## Usage
@@ -22,7 +22,7 @@
from batch_invariant_backward import enable_batch_invariant_backward_mode

# Initialize vLLM's deterministic mode first
init_batch_invariance()
init_batch_invariance(AttentionBackendEnum.FLASH_ATTN)

# Then enable gradient support
enable_batch_invariant_backward_mode()
@@ -308,7 +308,7 @@ def enable_batch_invariant_backward_mode():
):
raise RuntimeError(
"vLLM's batch_invariant mode is not initialized. "
"Call init_batch_invariance() first."
"Call init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) first."
)

# Use vLLM's existing library - don't destroy it!
92 changes: 73 additions & 19 deletions torchtitan/experiments/rl/vllm_compat/models/attention.py
@@ -6,6 +6,7 @@


import math
from collections.abc import Callable

import torch
from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func
@@ -30,6 +31,7 @@ def forward(
v: torch.Tensor,
*,
scale: float | None = None,
enable_gqa: bool = False,
) -> torch.Tensor:
# Flash Attention varlen expects: (batch, seqlen, nheads, headdim)
# The input from TorchTitan is always (batch, num_heads, seq_len, head_dim)
@@ -42,12 +44,13 @@

# Get dimensions
batch_size, seq_len, num_heads, head_dim = q.shape
num_kv_heads = k.shape[2]

# Convert to varlen format: flatten batch and sequence dimensions
# (batch, seqlen, nheads, headdim) -> (total_tokens, nheads, headdim)
q_varlen = q.reshape(-1, num_heads, head_dim)
k_varlen = k.reshape(-1, k.shape[2], head_dim)
v_varlen = v.reshape(-1, v.shape[2], head_dim)
k_varlen = k.reshape(-1, num_kv_heads, head_dim)
v_varlen = v.reshape(-1, num_kv_heads, head_dim)

# Create cumulative sequence lengths
# cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len]
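The construction of `cu_seqlens` is folded out of this hunk; for uniform sequence lengths it amounts to an int32 arange. A sketch, not necessarily the exact code used here:

```python
import torch

batch_size, seq_len = 4, 128
# [0, seq_len, 2*seq_len, ..., batch_size*seq_len], as described above
cu_seqlens = torch.arange(
    0, (batch_size + 1) * seq_len, step=seq_len, dtype=torch.int32
)
```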
@@ -59,21 +62,29 @@
if scale is None:
scale = 1.0 / math.sqrt(q.size(-1))

# Pre-allocate output tensor with correct shape (num_heads from Q, not K/V)
# This ensures flash attention writes to a tensor with the correct GQA output shape
total_tokens = batch_size * seq_len
out_varlen = torch.empty(
(total_tokens, num_heads, head_dim), dtype=q.dtype, device=q.device
)

# Wrap Flash Attention with manual backward pass
class FlashAttnWithBackward(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
k,
v,
cu_seqlens,
seq_len,
scale,
num_splits,
flash_fn,
fa_version,
):
ctx: torch.autograd.function.FunctionCtx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
cu_seqlens: torch.Tensor,
seq_len: int,
scale: float,
num_splits: int,
flash_fn: Callable[..., torch.Tensor],
fa_version: int,
) -> torch.Tensor:
# Call flash attention for forward (fast)
output = flash_fn(
q,
@@ -87,6 +98,7 @@
causal=True,
num_splits=num_splits,
fa_version=fa_version,
out=out,
)
# Save for backward
ctx.save_for_backward(q, k, v, output)
@@ -95,7 +107,20 @@
return output

@staticmethod
def backward(ctx, grad_output):
def backward(
ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
None,
None,
None,
None,
None,
None,
None,
]:
q, k, v, output = ctx.saved_tensors
scale = ctx.scale
seq_len = ctx.seq_len
@@ -104,12 +129,13 @@ def backward(ctx, grad_output):
# Assume uniform sequence lengths (batch_size = total_tokens / seq_len)
total_tokens = q.shape[0]
num_heads = q.shape[1]
num_kv_heads = k.shape[1]
head_dim = q.shape[2]
batch_size = total_tokens // seq_len

q_batch = q.reshape(batch_size, seq_len, num_heads, head_dim)
k_batch = k.reshape(batch_size, seq_len, num_heads, head_dim)
v_batch = v.reshape(batch_size, seq_len, num_heads, head_dim)
k_batch = k.reshape(batch_size, seq_len, num_kv_heads, head_dim)
v_batch = v.reshape(batch_size, seq_len, num_kv_heads, head_dim)
out_batch = output.reshape(batch_size, seq_len, num_heads, head_dim)
grad_out_batch = grad_output.reshape(
batch_size, seq_len, num_heads, head_dim
@@ -122,6 +148,17 @@
out_t = out_batch.transpose(1, 2)
grad_out_t = grad_out_batch.transpose(1, 2)

# For GQA, we need to expand K/V to match Q's num_heads
# Each KV head serves (num_heads // num_kv_heads) Q heads
if num_kv_heads < num_heads:
assert enable_gqa, "GQA requires enable_gqa=True"
assert (
num_heads % num_kv_heads == 0
), "num_heads must be a multiple of num_kv_heads"
n_rep = num_heads // num_kv_heads
k_t = k_t.repeat_interleave(n_rep, dim=1)
v_t = v_t.repeat_interleave(n_rep, dim=1)

# Compute attention scores: QK^T
# q_t: (B, H, N, D), k_t: (B, H, N, D) -> scores: (B, H, N, N)
scores = torch.matmul(q_t, k_t.transpose(-2, -1)) * scale
@@ -167,20 +204,37 @@ def backward(ctx, grad_output):
grad_q = grad_q_t.transpose(1, 2).reshape(
total_tokens, num_heads, head_dim
)

# For GQA, we need to reduce grad_k and grad_v back to num_kv_heads
if num_kv_heads < num_heads:
assert enable_gqa, "GQA requires enable_gqa=True"
assert (
num_heads % num_kv_heads == 0
), "num_heads must be a multiple of num_kv_heads"
n_rep = num_heads // num_kv_heads
# Reshape and sum over the repeated dimension
grad_k_t = grad_k_t.reshape(
batch_size, num_kv_heads, n_rep, seq_len, head_dim
).sum(dim=2)
grad_v_t = grad_v_t.reshape(
batch_size, num_kv_heads, n_rep, seq_len, head_dim
).sum(dim=2)

grad_k = grad_k_t.transpose(1, 2).reshape(
total_tokens, num_heads, head_dim
total_tokens, num_kv_heads, head_dim
)
grad_v = grad_v_t.transpose(1, 2).reshape(
total_tokens, num_heads, head_dim
total_tokens, num_kv_heads, head_dim
)

return grad_q, grad_k, grad_v, None, None, None, None, None, None
return grad_q, grad_k, grad_v, None, None, None, None, None, None, None

# Call Flash Attention varlen with custom backward
output_varlen = FlashAttnWithBackward.apply(
q_varlen,
k_varlen,
v_varlen,
out_varlen,
cu_seqlens,
seq_len,
scale,
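For reference, the manual backward implemented in this file follows the standard attention gradient identities, stated here with scale $s$ and omitting the causal mask and the GQA head-sharing details that the code handles via masking and `repeat_interleave`/sum:

$$
\begin{aligned}
S &= s\, Q K^\top, \qquad P = \operatorname{softmax}(S), \qquad O = P V,\\
dV &= P^\top\, dO, \qquad dP = dO\, V^\top,\\
dS &= P \odot \bigl(dP - \operatorname{rowsum}(P \odot dP)\bigr),\\
dQ &= s\, dS\, K, \qquad dK = s\, dS^\top\, Q.
\end{aligned}
$$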