From 4dd35e47cce1cddc68ab68bc33f02f663f4059f2 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 28 Oct 2025 10:15:23 +0000 Subject: [PATCH 1/6] [WIP] --- autoparallel/_testing/models/llama3.py | 92 +++++++++++++++++++++++++- examples/example_llama3.py | 6 +- 2 files changed, 95 insertions(+), 3 deletions(-) diff --git a/autoparallel/_testing/models/llama3.py b/autoparallel/_testing/models/llama3.py index 9d349e1a..b4bcaa09 100644 --- a/autoparallel/_testing/models/llama3.py +++ b/autoparallel/_testing/models/llama3.py @@ -19,6 +19,93 @@ def has_cuda_capability(major: int, minor: int) -> bool: ) +from torch.distributed.tensor.experimental._attention import ( + _scaled_dot_product_ring_flash_attention, + _scaled_dot_product_ring_flash_attention_backward, +) + + +class _ContextParallelAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, dropout_p, is_causal, scale, mesh): + ctx.scale = scale + ctx.is_causal = is_causal + ctx.dropout_p = dropout_p + ctx.mesh = mesh + # Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask + ( + out, + lse, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + philox_seed, + philox_offset, + debug_attn_mask, + ) = _scaled_dot_product_ring_flash_attention( + mesh, q, k, v, dropout_p, is_causal, return_debug_mask=False, scale=scale + ) + ctx.max_q = max_q + ctx.max_k = max_k + ctx.save_for_backward( + q, k, v, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset + ) + return out + + @staticmethod + def backward(ctx, grad_out): + ( + q, + k, + v, + out, + lse, + cum_seq_q, + cum_seq_k, + philox_seed, + philox_offset, + ) = ctx.saved_tensors + return _scaled_dot_product_ring_flash_attention_backward( + ctx.mesh, + grad_out, + q, + k, + v, + out, + lse, + cum_seq_q, + cum_seq_k, + ctx.max_q, + ctx.max_k, + ctx.dropout_p, + ctx.is_causal, + philox_seed, + philox_offset, + scale=ctx.scale, + ) + + +from torch.distributed.tensor.placement_types import Replicate, Shard + +from autoparallel.collectives import get_mesh_from_global, local_map + + +def context_parallel_attention(q, k, v, *, dropout_p=0.0, is_causal=False, scale=None): + mesh = get_mesh_from_global() + plc = (Shard(0), Shard(2)) + out_placements = (plc,) + in_placements = (plc, plc, plc, None, None, None, None) + return local_map( + _ContextParallelAttention.apply, + out_placements=out_placements, + in_placements=in_placements, + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=mesh, + )(q, k, v, dropout_p, is_causal, scale, mesh["tp"]) + + class ScaledDotProductAttention(torch.nn.Module): backends: ClassVar[list[SDPBackend]] = [] @@ -49,8 +136,9 @@ def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> torch.Tensor: assert self.backends, "SDPA Backends should not be empty." 
- with sdpa_kernel(self.backends, set_priority=True): - return F.scaled_dot_product_attention(q, k, v, is_causal=True) + # with sdpa_kernel(self.backends, set_priority=True): + # return F.scaled_dot_product_attention(q, k, v, is_causal=True) + return context_parallel_attention(q, k, v, is_causal=True) def build_attention( diff --git a/examples/example_llama3.py b/examples/example_llama3.py index bc41e96c..29cf5bac 100644 --- a/examples/example_llama3.py +++ b/examples/example_llama3.py @@ -52,12 +52,16 @@ model_type = "8b" enable_asynctp = False +import autoparallel.collectives + +autoparallel.collectives._local_map_device_mesh = mesh + def model_fn(): if model_type == "8b": model_args = TransformerModelArgs( dim=4096, - n_layers=32, + n_layers=1, # 32, n_heads=32, n_kv_heads=8, ffn_dim_multiplier=1.3, From 06f17c709e28d44a0a77ab8bb855f68084751591 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 13 Jan 2026 09:24:29 +0000 Subject: [PATCH 2/6] Use Claude to generalize to all sdpa backends --- autoparallel/_testing/models/llama3.py | 335 ++++++++++++++++++++----- 1 file changed, 277 insertions(+), 58 deletions(-) diff --git a/autoparallel/_testing/models/llama3.py b/autoparallel/_testing/models/llama3.py index b4bcaa09..c2171ac7 100644 --- a/autoparallel/_testing/models/llama3.py +++ b/autoparallel/_testing/models/llama3.py @@ -9,7 +9,18 @@ import torch import torch.nn.functional as F from torch import nn -from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.distributed.tensor.experimental._context_parallel._attention import ( + _scaled_dot_product_ring_cudnn_attention, + _scaled_dot_product_ring_cudnn_attention_backward, + _scaled_dot_product_ring_efficient_attention, + _scaled_dot_product_ring_efficient_attention_backward, + _scaled_dot_product_ring_flash_attention, + _scaled_dot_product_ring_flash_attention_backward, +) +from torch.distributed.tensor.placement_types import Shard +from torch.nn.attention import SDPBackend # , sdpa_kernel + +from autoparallel.collectives import get_mesh_from_global, local_map def has_cuda_capability(major: int, minor: int) -> bool: @@ -19,83 +30,267 @@ def has_cuda_capability(major: int, minor: int) -> bool: ) -from torch.distributed.tensor.experimental._attention import ( - _scaled_dot_product_ring_flash_attention, - _scaled_dot_product_ring_flash_attention_backward, -) +# Backend-specific backward wrappers to handle signature differences +def _flash_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): + """Handle flash attention backward with correct argument order.""" + # Forward outputs: lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask + # Backward expects: mesh, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, + # max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, *, scale + ( + lse, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + philox_seed, + philox_offset, + ) = forward_outputs[:7] + return _scaled_dot_product_ring_flash_attention_backward( + mesh, + grad_out, + q, + k, + v, + out, + lse, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + kwargs.get("dropout_p", 0.0), + kwargs.get("is_causal", False), + philox_seed, + philox_offset, + scale=kwargs.get("scale", None), + ) + + +def _efficient_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): + """Handle efficient attention backward with correct argument order.""" + # Forward outputs: lse, philox_seed, philox_offset + # Backward expects: mesh, grad_out, query, key, value, bias, 
out, logsumexp, + # philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, *, scale + lse, philox_seed, philox_offset = forward_outputs[:3] + # Build grad_input_mask based on which inputs require gradients + attn_bias = kwargs.get("attn_bias", None) + grad_input_mask = ( + q.requires_grad, + k.requires_grad, + v.requires_grad, + attn_bias.requires_grad if attn_bias is not None else False, + ) + return _scaled_dot_product_ring_efficient_attention_backward( + mesh, + grad_out, + q, + k, + v, + attn_bias, + out, + lse, + philox_seed, + philox_offset, + kwargs.get("dropout_p", 0.0), + grad_input_mask, + kwargs.get("is_causal", False), + scale=kwargs.get("scale", None), + ) + + +def _cudnn_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): + """Handle cudnn attention backward with correct argument order.""" + # Forward outputs: lse, philox_seed, philox_offset, softmax_stats(?), bias(?), cum_seq_q, cum_seq_k, max_q, max_k, debug_attn_mask + # Backward expects: mesh, grad_out, query, key, value, out, logsumexp, + # philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, + # max_q, max_k, dropout_p, is_causal, *, scale + lse, philox_seed, philox_offset = forward_outputs[:3] + # CuDNN may have additional outputs; extract what we need + if len(forward_outputs) >= 9: + cum_seq_q, cum_seq_k, max_q, max_k = forward_outputs[5:9] + else: + # Fallback if structure is different + cum_seq_q, cum_seq_k, max_q, max_k = forward_outputs[3:7] + + return _scaled_dot_product_ring_cudnn_attention_backward( + mesh, + grad_out, + q, + k, + v, + out, + lse, + philox_seed, + philox_offset, + kwargs.get("attn_bias", None), + cum_seq_q, + cum_seq_k, + max_q, + max_k, + kwargs.get("dropout_p", 0.0), + kwargs.get("is_causal", False), + scale=kwargs.get("scale", None), + ) + + +# Mapping of backward functions to their wrappers +_CP_BACKWARD_WRAPPERS = { + _scaled_dot_product_ring_flash_attention_backward: _flash_backward_wrapper, + _scaled_dot_product_ring_efficient_attention_backward: _efficient_backward_wrapper, + _scaled_dot_product_ring_cudnn_attention_backward: _cudnn_backward_wrapper, +} class _ContextParallelAttention(torch.autograd.Function): + """ + Generic context parallel attention that supports multiple backends. + Uses **kwargs to be future-proof against signature changes. 
+ """ + @staticmethod - def forward(ctx, q, k, v, dropout_p, is_causal, scale, mesh): - ctx.scale = scale - ctx.is_causal = is_causal - ctx.dropout_p = dropout_p + def forward(ctx, op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values): + """ + Args: + op_forward: Forward operation (e.g., _scaled_dot_product_ring_flash_attention) + op_backward: Backward operation (e.g., _scaled_dot_product_ring_flash_attention_backward) + q, k, v: Query, key, value tensors + kwargs_keys_str: Comma-separated string of kwarg names (e.g., 'dropout_p,is_causal,scale') + *kwargs_values: Values corresponding to kwargs_keys + """ + # Get mesh from global context (avoids passing it through local_map which would flatten it) + mesh = get_mesh_from_global()["tp"] + + ctx.op_backward = op_backward ctx.mesh = mesh - # Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask - ( - out, - lse, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - philox_seed, - philox_offset, - debug_attn_mask, - ) = _scaled_dot_product_ring_flash_attention( - mesh, q, k, v, dropout_p, is_causal, return_debug_mask=False, scale=scale - ) - ctx.max_q = max_q - ctx.max_k = max_k - ctx.save_for_backward( - q, k, v, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset - ) + + # Reconstruct kwargs dict from keys string and values + kwargs_keys = kwargs_keys_str.split(",") if kwargs_keys_str else [] + kwargs_dict = dict(zip(kwargs_keys, kwargs_values)) + ctx.kwargs = kwargs_dict + + # Call the forward operation with all kwargs + outputs = op_forward(mesh, q, k, v, **kwargs_dict) + + # outputs is a tuple: (out, lse, ...) where the rest varies by backend + out = outputs[0] + forward_outputs = outputs[1:] + + # Separate tensors from non-tensors for proper saving + # Tensors must be saved via save_for_backward for proper memory management + tensors_to_save = [q, k, v, out] + non_tensor_outputs = [] + + for i, item in enumerate(forward_outputs): + if isinstance(item, torch.Tensor): + tensors_to_save.append(item) + non_tensor_outputs.append(("tensor", len(tensors_to_save) - 1)) + else: + non_tensor_outputs.append(("value", item)) + + ctx.save_for_backward(*tensors_to_save) + ctx.non_tensor_outputs = non_tensor_outputs + ctx.num_forward_outputs = len(forward_outputs) + return out @staticmethod def backward(ctx, grad_out): - ( - q, - k, - v, - out, - lse, - cum_seq_q, - cum_seq_k, - philox_seed, - philox_offset, - ) = ctx.saved_tensors - return _scaled_dot_product_ring_flash_attention_backward( + # Retrieve saved tensors + saved_tensors = ctx.saved_tensors + q, k, v, out = saved_tensors[:4] + saved_forward_tensors = saved_tensors[4:] + + # Reconstruct forward_outputs from saved tensors and non-tensor values + forward_outputs = [] + tensor_idx = 0 + for output_type, output_value in ctx.non_tensor_outputs: + if output_type == "tensor": + forward_outputs.append(saved_forward_tensors[tensor_idx]) + tensor_idx += 1 + else: + forward_outputs.append(output_value) + forward_outputs = tuple(forward_outputs) + + # Use the backend-specific wrapper to handle argument ordering + wrapper_fn = _CP_BACKWARD_WRAPPERS.get(ctx.op_backward) + if wrapper_fn is None: + raise RuntimeError( + f"No backward wrapper found for {ctx.op_backward}. " + "This backend may not be supported yet." 
+ ) + + grads = wrapper_fn( ctx.mesh, grad_out, q, k, v, out, - lse, - cum_seq_q, - cum_seq_k, - ctx.max_q, - ctx.max_k, - ctx.dropout_p, - ctx.is_causal, - philox_seed, - philox_offset, - scale=ctx.scale, + forward_outputs, + ctx.kwargs, ) + # Return gradients: + # (None for op_forward, None for op_backward, grad_q, grad_k, grad_v, None for kwargs_keys_str, None for each kwargs_value) + num_kwargs = len(ctx.kwargs) + return (None, None) + grads[:3] + (None,) + (None,) * num_kwargs + + +# Backend registry for context parallel attention +_CP_ATTENTION_BACKENDS = { + SDPBackend.FLASH_ATTENTION: ( + _scaled_dot_product_ring_flash_attention, + _scaled_dot_product_ring_flash_attention_backward, + ), + SDPBackend.EFFICIENT_ATTENTION: ( + _scaled_dot_product_ring_efficient_attention, + _scaled_dot_product_ring_efficient_attention_backward, + ), + SDPBackend.CUDNN_ATTENTION: ( + _scaled_dot_product_ring_cudnn_attention, + _scaled_dot_product_ring_cudnn_attention_backward, + ), +} + + +def context_parallel_attention( + q, k, v, *, backend=SDPBackend.FLASH_ATTENTION, **kwargs +): + """ + Generic context parallel attention supporting multiple backends. -from torch.distributed.tensor.placement_types import Replicate, Shard + Args: + q, k, v: Query, key, value tensors + backend: SDPBackend to use (FLASH_ATTENTION, EFFICIENT_ATTENTION, or CUDNN_ATTENTION) + **kwargs: Additional arguments passed to the attention operation (e.g., dropout_p, is_causal, scale, attn_bias) -from autoparallel.collectives import get_mesh_from_global, local_map + Returns: + Attention output tensor + This function is future-proof as it uses **kwargs to pass arguments, so changes + to backend signatures won't require updating this function. + """ + if backend not in _CP_ATTENTION_BACKENDS: + raise ValueError( + f"Unsupported backend: {backend}. Supported backends: {list(_CP_ATTENTION_BACKENDS.keys())}" + ) + + op_forward, op_backward = _CP_ATTENTION_BACKENDS[backend] -def context_parallel_attention(q, k, v, *, dropout_p=0.0, is_causal=False, scale=None): mesh = get_mesh_from_global() plc = (Shard(0), Shard(2)) out_placements = (plc,) - in_placements = (plc, plc, plc, None, None, None, None) + + # Convert kwargs to a comma-separated string of keys and a tuple of values + # Using a string prevents pytree from flattening it + kwargs_keys_str = ",".join(kwargs.keys()) if kwargs else "" + kwargs_values = tuple(kwargs.values()) + + # in_placements for: op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values + # Note: mesh is NOT passed through local_map (it would be flattened by pytree) + # Instead, we retrieve it inside the autograd function using get_mesh_from_global() + num_kwargs = len(kwargs) + in_placements = (None, None, plc, plc, plc, None) + (None,) * num_kwargs + return local_map( _ContextParallelAttention.apply, out_placements=out_placements, @@ -103,7 +298,7 @@ def context_parallel_attention(q, k, v, *, dropout_p=0.0, is_causal=False, scale redistribute_inputs=True, in_grad_placements=None, device_mesh=mesh, - )(q, k, v, dropout_p, is_causal, scale, mesh["tp"]) + )(op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values) class ScaledDotProductAttention(torch.nn.Module): @@ -132,13 +327,37 @@ def _init_backend(cls) -> None: if has_cuda_capability(10, 0): cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) + def _select_backend(self) -> SDPBackend: + """ + Select the best available backend for context parallel attention. + Only considers backends that are supported by context parallel. 
+ """ + supported_cp_backends = { + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + } + + for backend in self.backends: + if backend in supported_cp_backends: + return backend + + # Fallback to flash attention if no supported backend is found + return SDPBackend.FLASH_ATTENTION + def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> torch.Tensor: assert self.backends, "SDPA Backends should not be empty." - # with sdpa_kernel(self.backends, set_priority=True): - # return F.scaled_dot_product_attention(q, k, v, is_causal=True) - return context_parallel_attention(q, k, v, is_causal=True) + + # Select the best available backend + backend = self._select_backend() + + # Use context parallel attention with the selected backend + # All backend-specific arguments (is_causal, dropout_p, scale, etc.) are passed via kwargs + return context_parallel_attention( + q, k, v, backend=backend, is_causal=True, dropout_p=0.0 + ) def build_attention( From 36a9fd985573371dd7f9284f135e250b48dd8ab0 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 13 Jan 2026 09:30:32 +0000 Subject: [PATCH 3/6] Move context_parallel_attention to new file --- autoparallel/_testing/models/llama3.py | 282 +----------------------- autoparallel/ops.py | 289 +++++++++++++++++++++++++ 2 files changed, 290 insertions(+), 281 deletions(-) create mode 100644 autoparallel/ops.py diff --git a/autoparallel/_testing/models/llama3.py b/autoparallel/_testing/models/llama3.py index c2171ac7..685add65 100644 --- a/autoparallel/_testing/models/llama3.py +++ b/autoparallel/_testing/models/llama3.py @@ -9,18 +9,9 @@ import torch import torch.nn.functional as F from torch import nn -from torch.distributed.tensor.experimental._context_parallel._attention import ( - _scaled_dot_product_ring_cudnn_attention, - _scaled_dot_product_ring_cudnn_attention_backward, - _scaled_dot_product_ring_efficient_attention, - _scaled_dot_product_ring_efficient_attention_backward, - _scaled_dot_product_ring_flash_attention, - _scaled_dot_product_ring_flash_attention_backward, -) -from torch.distributed.tensor.placement_types import Shard from torch.nn.attention import SDPBackend # , sdpa_kernel -from autoparallel.collectives import get_mesh_from_global, local_map +from autoparallel.ops import context_parallel_attention def has_cuda_capability(major: int, minor: int) -> bool: @@ -30,277 +21,6 @@ def has_cuda_capability(major: int, minor: int) -> bool: ) -# Backend-specific backward wrappers to handle signature differences -def _flash_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): - """Handle flash attention backward with correct argument order.""" - # Forward outputs: lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask - # Backward expects: mesh, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, - # max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, *, scale - ( - lse, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - philox_seed, - philox_offset, - ) = forward_outputs[:7] - return _scaled_dot_product_ring_flash_attention_backward( - mesh, - grad_out, - q, - k, - v, - out, - lse, - cum_seq_q, - cum_seq_k, - max_q, - max_k, - kwargs.get("dropout_p", 0.0), - kwargs.get("is_causal", False), - philox_seed, - philox_offset, - scale=kwargs.get("scale", None), - ) - - -def _efficient_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): - """Handle efficient attention backward with correct argument 
order.""" - # Forward outputs: lse, philox_seed, philox_offset - # Backward expects: mesh, grad_out, query, key, value, bias, out, logsumexp, - # philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, *, scale - lse, philox_seed, philox_offset = forward_outputs[:3] - # Build grad_input_mask based on which inputs require gradients - attn_bias = kwargs.get("attn_bias", None) - grad_input_mask = ( - q.requires_grad, - k.requires_grad, - v.requires_grad, - attn_bias.requires_grad if attn_bias is not None else False, - ) - return _scaled_dot_product_ring_efficient_attention_backward( - mesh, - grad_out, - q, - k, - v, - attn_bias, - out, - lse, - philox_seed, - philox_offset, - kwargs.get("dropout_p", 0.0), - grad_input_mask, - kwargs.get("is_causal", False), - scale=kwargs.get("scale", None), - ) - - -def _cudnn_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): - """Handle cudnn attention backward with correct argument order.""" - # Forward outputs: lse, philox_seed, philox_offset, softmax_stats(?), bias(?), cum_seq_q, cum_seq_k, max_q, max_k, debug_attn_mask - # Backward expects: mesh, grad_out, query, key, value, out, logsumexp, - # philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, - # max_q, max_k, dropout_p, is_causal, *, scale - lse, philox_seed, philox_offset = forward_outputs[:3] - # CuDNN may have additional outputs; extract what we need - if len(forward_outputs) >= 9: - cum_seq_q, cum_seq_k, max_q, max_k = forward_outputs[5:9] - else: - # Fallback if structure is different - cum_seq_q, cum_seq_k, max_q, max_k = forward_outputs[3:7] - - return _scaled_dot_product_ring_cudnn_attention_backward( - mesh, - grad_out, - q, - k, - v, - out, - lse, - philox_seed, - philox_offset, - kwargs.get("attn_bias", None), - cum_seq_q, - cum_seq_k, - max_q, - max_k, - kwargs.get("dropout_p", 0.0), - kwargs.get("is_causal", False), - scale=kwargs.get("scale", None), - ) - - -# Mapping of backward functions to their wrappers -_CP_BACKWARD_WRAPPERS = { - _scaled_dot_product_ring_flash_attention_backward: _flash_backward_wrapper, - _scaled_dot_product_ring_efficient_attention_backward: _efficient_backward_wrapper, - _scaled_dot_product_ring_cudnn_attention_backward: _cudnn_backward_wrapper, -} - - -class _ContextParallelAttention(torch.autograd.Function): - """ - Generic context parallel attention that supports multiple backends. - Uses **kwargs to be future-proof against signature changes. - """ - - @staticmethod - def forward(ctx, op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values): - """ - Args: - op_forward: Forward operation (e.g., _scaled_dot_product_ring_flash_attention) - op_backward: Backward operation (e.g., _scaled_dot_product_ring_flash_attention_backward) - q, k, v: Query, key, value tensors - kwargs_keys_str: Comma-separated string of kwarg names (e.g., 'dropout_p,is_causal,scale') - *kwargs_values: Values corresponding to kwargs_keys - """ - # Get mesh from global context (avoids passing it through local_map which would flatten it) - mesh = get_mesh_from_global()["tp"] - - ctx.op_backward = op_backward - ctx.mesh = mesh - - # Reconstruct kwargs dict from keys string and values - kwargs_keys = kwargs_keys_str.split(",") if kwargs_keys_str else [] - kwargs_dict = dict(zip(kwargs_keys, kwargs_values)) - ctx.kwargs = kwargs_dict - - # Call the forward operation with all kwargs - outputs = op_forward(mesh, q, k, v, **kwargs_dict) - - # outputs is a tuple: (out, lse, ...) 
where the rest varies by backend - out = outputs[0] - forward_outputs = outputs[1:] - - # Separate tensors from non-tensors for proper saving - # Tensors must be saved via save_for_backward for proper memory management - tensors_to_save = [q, k, v, out] - non_tensor_outputs = [] - - for i, item in enumerate(forward_outputs): - if isinstance(item, torch.Tensor): - tensors_to_save.append(item) - non_tensor_outputs.append(("tensor", len(tensors_to_save) - 1)) - else: - non_tensor_outputs.append(("value", item)) - - ctx.save_for_backward(*tensors_to_save) - ctx.non_tensor_outputs = non_tensor_outputs - ctx.num_forward_outputs = len(forward_outputs) - - return out - - @staticmethod - def backward(ctx, grad_out): - # Retrieve saved tensors - saved_tensors = ctx.saved_tensors - q, k, v, out = saved_tensors[:4] - saved_forward_tensors = saved_tensors[4:] - - # Reconstruct forward_outputs from saved tensors and non-tensor values - forward_outputs = [] - tensor_idx = 0 - for output_type, output_value in ctx.non_tensor_outputs: - if output_type == "tensor": - forward_outputs.append(saved_forward_tensors[tensor_idx]) - tensor_idx += 1 - else: - forward_outputs.append(output_value) - forward_outputs = tuple(forward_outputs) - - # Use the backend-specific wrapper to handle argument ordering - wrapper_fn = _CP_BACKWARD_WRAPPERS.get(ctx.op_backward) - if wrapper_fn is None: - raise RuntimeError( - f"No backward wrapper found for {ctx.op_backward}. " - "This backend may not be supported yet." - ) - - grads = wrapper_fn( - ctx.mesh, - grad_out, - q, - k, - v, - out, - forward_outputs, - ctx.kwargs, - ) - - # Return gradients: - # (None for op_forward, None for op_backward, grad_q, grad_k, grad_v, None for kwargs_keys_str, None for each kwargs_value) - num_kwargs = len(ctx.kwargs) - return (None, None) + grads[:3] + (None,) + (None,) * num_kwargs - - -# Backend registry for context parallel attention -_CP_ATTENTION_BACKENDS = { - SDPBackend.FLASH_ATTENTION: ( - _scaled_dot_product_ring_flash_attention, - _scaled_dot_product_ring_flash_attention_backward, - ), - SDPBackend.EFFICIENT_ATTENTION: ( - _scaled_dot_product_ring_efficient_attention, - _scaled_dot_product_ring_efficient_attention_backward, - ), - SDPBackend.CUDNN_ATTENTION: ( - _scaled_dot_product_ring_cudnn_attention, - _scaled_dot_product_ring_cudnn_attention_backward, - ), -} - - -def context_parallel_attention( - q, k, v, *, backend=SDPBackend.FLASH_ATTENTION, **kwargs -): - """ - Generic context parallel attention supporting multiple backends. - - Args: - q, k, v: Query, key, value tensors - backend: SDPBackend to use (FLASH_ATTENTION, EFFICIENT_ATTENTION, or CUDNN_ATTENTION) - **kwargs: Additional arguments passed to the attention operation (e.g., dropout_p, is_causal, scale, attn_bias) - - Returns: - Attention output tensor - - This function is future-proof as it uses **kwargs to pass arguments, so changes - to backend signatures won't require updating this function. - """ - if backend not in _CP_ATTENTION_BACKENDS: - raise ValueError( - f"Unsupported backend: {backend}. 
Supported backends: {list(_CP_ATTENTION_BACKENDS.keys())}" - ) - - op_forward, op_backward = _CP_ATTENTION_BACKENDS[backend] - - mesh = get_mesh_from_global() - plc = (Shard(0), Shard(2)) - out_placements = (plc,) - - # Convert kwargs to a comma-separated string of keys and a tuple of values - # Using a string prevents pytree from flattening it - kwargs_keys_str = ",".join(kwargs.keys()) if kwargs else "" - kwargs_values = tuple(kwargs.values()) - - # in_placements for: op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values - # Note: mesh is NOT passed through local_map (it would be flattened by pytree) - # Instead, we retrieve it inside the autograd function using get_mesh_from_global() - num_kwargs = len(kwargs) - in_placements = (None, None, plc, plc, plc, None) + (None,) * num_kwargs - - return local_map( - _ContextParallelAttention.apply, - out_placements=out_placements, - in_placements=in_placements, - redistribute_inputs=True, - in_grad_placements=None, - device_mesh=mesh, - )(op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values) - - class ScaledDotProductAttention(torch.nn.Module): backends: ClassVar[list[SDPBackend]] = [] diff --git a/autoparallel/ops.py b/autoparallel/ops.py new file mode 100644 index 00000000..1853f242 --- /dev/null +++ b/autoparallel/ops.py @@ -0,0 +1,289 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.distributed.tensor.experimental._context_parallel._attention import ( + _scaled_dot_product_ring_cudnn_attention, + _scaled_dot_product_ring_cudnn_attention_backward, + _scaled_dot_product_ring_efficient_attention, + _scaled_dot_product_ring_efficient_attention_backward, + _scaled_dot_product_ring_flash_attention, + _scaled_dot_product_ring_flash_attention_backward, +) +from torch.distributed.tensor.placement_types import Shard +from torch.nn.attention import SDPBackend + +from autoparallel.collectives import get_mesh_from_global, local_map + + +# Backend-specific backward wrappers to handle signature differences +def _flash_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): + """Handle flash attention backward with correct argument order.""" + # Forward outputs: lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask + # Backward expects: mesh, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, + # max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, *, scale + ( + lse, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + philox_seed, + philox_offset, + ) = forward_outputs[:7] + return _scaled_dot_product_ring_flash_attention_backward( + mesh, + grad_out, + q, + k, + v, + out, + lse, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + kwargs.get("dropout_p", 0.0), + kwargs.get("is_causal", False), + philox_seed, + philox_offset, + scale=kwargs.get("scale", None), + ) + + +def _efficient_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): + """Handle efficient attention backward with correct argument order.""" + # Forward outputs: lse, philox_seed, philox_offset + # Backward expects: mesh, grad_out, query, key, value, bias, out, logsumexp, + # philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, *, scale + lse, philox_seed, philox_offset = forward_outputs[:3] + # Build grad_input_mask based on which inputs require gradients + attn_bias = 
kwargs.get("attn_bias", None) + grad_input_mask = ( + q.requires_grad, + k.requires_grad, + v.requires_grad, + attn_bias.requires_grad if attn_bias is not None else False, + ) + return _scaled_dot_product_ring_efficient_attention_backward( + mesh, + grad_out, + q, + k, + v, + attn_bias, + out, + lse, + philox_seed, + philox_offset, + kwargs.get("dropout_p", 0.0), + grad_input_mask, + kwargs.get("is_causal", False), + scale=kwargs.get("scale", None), + ) + + +def _cudnn_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs): + """Handle cudnn attention backward with correct argument order.""" + # Forward outputs: lse, philox_seed, philox_offset, softmax_stats(?), bias(?), cum_seq_q, cum_seq_k, max_q, max_k, debug_attn_mask + # Backward expects: mesh, grad_out, query, key, value, out, logsumexp, + # philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, + # max_q, max_k, dropout_p, is_causal, *, scale + lse, philox_seed, philox_offset = forward_outputs[:3] + # CuDNN may have additional outputs; extract what we need + if len(forward_outputs) >= 9: + cum_seq_q, cum_seq_k, max_q, max_k = forward_outputs[5:9] + else: + # Fallback if structure is different + cum_seq_q, cum_seq_k, max_q, max_k = forward_outputs[3:7] + + return _scaled_dot_product_ring_cudnn_attention_backward( + mesh, + grad_out, + q, + k, + v, + out, + lse, + philox_seed, + philox_offset, + kwargs.get("attn_bias", None), + cum_seq_q, + cum_seq_k, + max_q, + max_k, + kwargs.get("dropout_p", 0.0), + kwargs.get("is_causal", False), + scale=kwargs.get("scale", None), + ) + + +# Mapping of backward functions to their wrappers +_CP_BACKWARD_WRAPPERS = { + _scaled_dot_product_ring_flash_attention_backward: _flash_backward_wrapper, + _scaled_dot_product_ring_efficient_attention_backward: _efficient_backward_wrapper, + _scaled_dot_product_ring_cudnn_attention_backward: _cudnn_backward_wrapper, +} + + +class _ContextParallelAttention(torch.autograd.Function): + """ + Generic context parallel attention that supports multiple backends. + Uses **kwargs to be future-proof against signature changes. + """ + + @staticmethod + def forward(ctx, op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values): + """ + Args: + op_forward: Forward operation (e.g., _scaled_dot_product_ring_flash_attention) + op_backward: Backward operation (e.g., _scaled_dot_product_ring_flash_attention_backward) + q, k, v: Query, key, value tensors + kwargs_keys_str: Comma-separated string of kwarg names (e.g., 'dropout_p,is_causal,scale') + *kwargs_values: Values corresponding to kwargs_keys + """ + # Get mesh from global context (avoids passing it through local_map which would flatten it) + mesh = get_mesh_from_global()["tp"] + + ctx.op_backward = op_backward + ctx.mesh = mesh + + # Reconstruct kwargs dict from keys string and values + kwargs_keys = kwargs_keys_str.split(",") if kwargs_keys_str else [] + kwargs_dict = dict(zip(kwargs_keys, kwargs_values)) + ctx.kwargs = kwargs_dict + + # Call the forward operation with all kwargs + outputs = op_forward(mesh, q, k, v, **kwargs_dict) + + # outputs is a tuple: (out, lse, ...) 
where the rest varies by backend + out = outputs[0] + forward_outputs = outputs[1:] + + # Separate tensors from non-tensors for proper saving + # Tensors must be saved via save_for_backward for proper memory management + tensors_to_save = [q, k, v, out] + non_tensor_outputs = [] + + for i, item in enumerate(forward_outputs): + if isinstance(item, torch.Tensor): + tensors_to_save.append(item) + non_tensor_outputs.append(("tensor", len(tensors_to_save) - 1)) + else: + non_tensor_outputs.append(("value", item)) + + ctx.save_for_backward(*tensors_to_save) + ctx.non_tensor_outputs = non_tensor_outputs + ctx.num_forward_outputs = len(forward_outputs) + + return out + + @staticmethod + def backward(ctx, grad_out): + # Retrieve saved tensors + saved_tensors = ctx.saved_tensors + q, k, v, out = saved_tensors[:4] + saved_forward_tensors = saved_tensors[4:] + + # Reconstruct forward_outputs from saved tensors and non-tensor values + forward_outputs = [] + tensor_idx = 0 + for output_type, output_value in ctx.non_tensor_outputs: + if output_type == "tensor": + forward_outputs.append(saved_forward_tensors[tensor_idx]) + tensor_idx += 1 + else: + forward_outputs.append(output_value) + forward_outputs = tuple(forward_outputs) + + # Use the backend-specific wrapper to handle argument ordering + wrapper_fn = _CP_BACKWARD_WRAPPERS.get(ctx.op_backward) + if wrapper_fn is None: + raise RuntimeError( + f"No backward wrapper found for {ctx.op_backward}. " + "This backend may not be supported yet." + ) + + grads = wrapper_fn( + ctx.mesh, + grad_out, + q, + k, + v, + out, + forward_outputs, + ctx.kwargs, + ) + + # Return gradients: + # (None for op_forward, None for op_backward, grad_q, grad_k, grad_v, None for kwargs_keys_str, None for each kwargs_value) + num_kwargs = len(ctx.kwargs) + return (None, None) + grads[:3] + (None,) + (None,) * num_kwargs + + +# Backend registry for context parallel attention +_CP_ATTENTION_BACKENDS = { + SDPBackend.FLASH_ATTENTION: ( + _scaled_dot_product_ring_flash_attention, + _scaled_dot_product_ring_flash_attention_backward, + ), + SDPBackend.EFFICIENT_ATTENTION: ( + _scaled_dot_product_ring_efficient_attention, + _scaled_dot_product_ring_efficient_attention_backward, + ), + SDPBackend.CUDNN_ATTENTION: ( + _scaled_dot_product_ring_cudnn_attention, + _scaled_dot_product_ring_cudnn_attention_backward, + ), +} + + +def context_parallel_attention( + q, k, v, *, backend=SDPBackend.FLASH_ATTENTION, **kwargs +): + """ + Generic context parallel attention supporting multiple backends. + + Args: + q, k, v: Query, key, value tensors + backend: SDPBackend to use (FLASH_ATTENTION, EFFICIENT_ATTENTION, or CUDNN_ATTENTION) + **kwargs: Additional arguments passed to the attention operation (e.g., dropout_p, is_causal, scale, attn_bias) + + Returns: + Attention output tensor + + This function is future-proof as it uses **kwargs to pass arguments, so changes + to backend signatures won't require updating this function. + """ + if backend not in _CP_ATTENTION_BACKENDS: + raise ValueError( + f"Unsupported backend: {backend}. 
Supported backends: {list(_CP_ATTENTION_BACKENDS.keys())}" + ) + + op_forward, op_backward = _CP_ATTENTION_BACKENDS[backend] + + mesh = get_mesh_from_global() + plc = (Shard(0), Shard(2)) + out_placements = (plc,) + + # Convert kwargs to a comma-separated string of keys and a tuple of values + # Using a string prevents pytree from flattening it + kwargs_keys_str = ",".join(kwargs.keys()) if kwargs else "" + kwargs_values = tuple(kwargs.values()) + + # in_placements for: op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values + # Note: mesh is NOT passed through local_map (it would be flattened by pytree) + # Instead, we retrieve it inside the autograd function using get_mesh_from_global() + num_kwargs = len(kwargs) + in_placements = (None, None, plc, plc, plc, None) + (None,) * num_kwargs + + return local_map( + _ContextParallelAttention.apply, + out_placements=out_placements, + in_placements=in_placements, + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=mesh, + )(op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values) From 661dad3849edee71f1a6c3d48fcb73bf5704ee45 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 13 Jan 2026 12:23:31 +0000 Subject: [PATCH 4/6] Add auto-backend selection Thanks Claude! --- autoparallel/ops.py | 72 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/autoparallel/ops.py b/autoparallel/ops.py index 1853f242..1def6182 100644 --- a/autoparallel/ops.py +++ b/autoparallel/ops.py @@ -240,15 +240,68 @@ def backward(ctx, grad_out): } -def context_parallel_attention( - q, k, v, *, backend=SDPBackend.FLASH_ATTENTION, **kwargs -): +def _select_cp_backend(q, k, v, dropout_p=0.0, is_causal=False, attn_mask=None): + """ + Select the best available backend for context parallel attention. + + Uses PyTorch's internal backend selection logic to determine which SDPA + backend can handle the given inputs, respecting the priority order. + + Args: + q, k, v: Query, key, value tensors + dropout_p: Dropout probability + is_causal: Whether to use causal attention + attn_mask: Optional attention mask + + Returns: + SDPBackend enum value for the selected backend + + Raises: + RuntimeError: If no suitable backend is available + """ + from torch.backends.cuda import ( + SDPAParams, + can_use_cudnn_attention, + can_use_efficient_attention, + can_use_flash_attention, + ) + + # Create params object for backend selection + # SDPAParams signature: query, key, value, attn_mask, dropout, is_causal, enable_gqa + params = SDPAParams(q, k, v, attn_mask, dropout_p, is_causal, False) + + # Map backend enum values to their can_use functions + # Only include backends we support for context parallel + backend_checks = { + SDPBackend.FLASH_ATTENTION: can_use_flash_attention, + SDPBackend.EFFICIENT_ATTENTION: can_use_efficient_attention, + SDPBackend.CUDNN_ATTENTION: can_use_cudnn_attention, + } + + # Get priority order from PyTorch + priority_order = torch._C._get_sdp_priority_order() + + for backend_id in priority_order: + for backend_enum, can_use_fn in backend_checks.items(): + if backend_enum.value == backend_id and can_use_fn(params): + return backend_enum + + raise RuntimeError( + "No suitable SDPA backend available for context parallel attention. " + "Supported backends are: FLASH_ATTENTION, EFFICIENT_ATTENTION, CUDNN_ATTENTION. " + "Check that your inputs are compatible with at least one of these backends." 
+ ) + + +def context_parallel_attention(q, k, v, *, backend=None, **kwargs): """ Generic context parallel attention supporting multiple backends. Args: q, k, v: Query, key, value tensors - backend: SDPBackend to use (FLASH_ATTENTION, EFFICIENT_ATTENTION, or CUDNN_ATTENTION) + backend: SDPBackend to use (FLASH_ATTENTION, EFFICIENT_ATTENTION, or CUDNN_ATTENTION). + If None (default), automatically selects the best available backend using + PyTorch's internal selection logic. **kwargs: Additional arguments passed to the attention operation (e.g., dropout_p, is_causal, scale, attn_bias) Returns: @@ -257,6 +310,17 @@ def context_parallel_attention( This function is future-proof as it uses **kwargs to pass arguments, so changes to backend signatures won't require updating this function. """ + # Auto-select backend if not specified + if backend is None: + backend = _select_cp_backend( + q, + k, + v, + dropout_p=kwargs.get("dropout_p", 0.0), + is_causal=kwargs.get("is_causal", False), + attn_mask=kwargs.get("attn_bias", None), + ) + if backend not in _CP_ATTENTION_BACKENDS: raise ValueError( f"Unsupported backend: {backend}. Supported backends: {list(_CP_ATTENTION_BACKENDS.keys())}" From f92e7dbffabac308ede34931b47bbf05ec18e618 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 13 Jan 2026 14:10:24 +0000 Subject: [PATCH 5/6] More fixes for compile --- autoparallel/_testing/models/llama3.py | 33 ++++---------------------- autoparallel/ops.py | 2 ++ 2 files changed, 7 insertions(+), 28 deletions(-) diff --git a/autoparallel/_testing/models/llama3.py b/autoparallel/_testing/models/llama3.py index 685add65..418682df 100644 --- a/autoparallel/_testing/models/llama3.py +++ b/autoparallel/_testing/models/llama3.py @@ -9,7 +9,7 @@ import torch import torch.nn.functional as F from torch import nn -from torch.nn.attention import SDPBackend # , sdpa_kernel +from torch.nn.attention import SDPBackend, sdpa_kernel from autoparallel.ops import context_parallel_attention @@ -47,37 +47,14 @@ def _init_backend(cls) -> None: if has_cuda_capability(10, 0): cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) - def _select_backend(self) -> SDPBackend: - """ - Select the best available backend for context parallel attention. - Only considers backends that are supported by context parallel. - """ - supported_cp_backends = { - SDPBackend.FLASH_ATTENTION, - SDPBackend.EFFICIENT_ATTENTION, - SDPBackend.CUDNN_ATTENTION, - } - - for backend in self.backends: - if backend in supported_cp_backends: - return backend - - # Fallback to flash attention if no supported backend is found - return SDPBackend.FLASH_ATTENTION - def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor ) -> torch.Tensor: assert self.backends, "SDPA Backends should not be empty." - - # Select the best available backend - backend = self._select_backend() - - # Use context parallel attention with the selected backend - # All backend-specific arguments (is_causal, dropout_p, scale, etc.) are passed via kwargs - return context_parallel_attention( - q, k, v, backend=backend, is_causal=True, dropout_p=0.0 - ) + with sdpa_kernel(self.backends, set_priority=True): + # Use context parallel attention with the selected backend + # All backend-specific arguments (is_causal, dropout_p, scale, etc.) 
are passed via kwargs + return context_parallel_attention(q, k, v, is_causal=True, dropout_p=0.0) def build_attention( diff --git a/autoparallel/ops.py b/autoparallel/ops.py index 1def6182..d6f5664d 100644 --- a/autoparallel/ops.py +++ b/autoparallel/ops.py @@ -240,6 +240,8 @@ def backward(ctx, grad_out): } +# TODO: using assume_constant_result isn't strictly correct, but gets the job done for now +@torch.compiler.assume_constant_result def _select_cp_backend(q, k, v, dropout_p=0.0, is_causal=False, attn_mask=None): """ Select the best available backend for context parallel attention. From 6215779c6f6ea4554dfcd7793ef2baca830d3ca8 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 13 Jan 2026 14:13:10 +0000 Subject: [PATCH 6/6] Fix for AssertionError: Tracing local_map is only currently supported with None placements last. --- autoparallel/ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/autoparallel/ops.py b/autoparallel/ops.py index d6f5664d..f84af127 100644 --- a/autoparallel/ops.py +++ b/autoparallel/ops.py @@ -135,7 +135,7 @@ class _ContextParallelAttention(torch.autograd.Function): """ @staticmethod - def forward(ctx, op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values): + def forward(ctx, q, k, v, op_forward, op_backward, kwargs_keys_str, *kwargs_values): """ Args: op_forward: Forward operation (e.g., _scaled_dot_product_ring_flash_attention) @@ -220,7 +220,7 @@ def backward(ctx, grad_out): # Return gradients: # (None for op_forward, None for op_backward, grad_q, grad_k, grad_v, None for kwargs_keys_str, None for each kwargs_value) num_kwargs = len(ctx.kwargs) - return (None, None) + grads[:3] + (None,) + (None,) * num_kwargs + return grads[:3] + (None, None, None) + (None,) * num_kwargs # Backend registry for context parallel attention @@ -343,7 +343,7 @@ def context_parallel_attention(q, k, v, *, backend=None, **kwargs): # Note: mesh is NOT passed through local_map (it would be flattened by pytree) # Instead, we retrieve it inside the autograd function using get_mesh_from_global() num_kwargs = len(kwargs) - in_placements = (None, None, plc, plc, plc, None) + (None,) * num_kwargs + in_placements = (plc, plc, plc, None, None, None) + (None,) * num_kwargs return local_map( _ContextParallelAttention.apply, @@ -352,4 +352,4 @@ def context_parallel_attention(q, k, v, *, backend=None, **kwargs): redistribute_inputs=True, in_grad_placements=None, device_mesh=mesh, - )(op_forward, op_backward, q, k, v, kwargs_keys_str, *kwargs_values) + )(q, k, v, op_forward, op_backward, kwargs_keys_str, *kwargs_values)
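
Usage sketch (not part of the patches above): the series adds `autoparallel.ops.context_parallel_attention` as a drop-in replacement for `F.scaled_dot_product_attention` in the llama3 test model. The sketch below is a hedged illustration only; the `("dp", "tp")` mesh shape, tensor sizes, and `torchrun` launch are assumptions, while the `_local_map_device_mesh` hook and the call signature are taken from `examples/example_llama3.py` (PATCH 1/6) and `ScaledDotProductAttention.forward` (PATCH 5/6). The final block also demonstrates, standalone, the kwargs-packing trick `_ContextParallelAttention` uses to carry keyword arguments through `local_map` without pytree flattening.

    # Sketch only -- assumes a multi-process torchrun launch and a 2-D ("dp", "tp") mesh.
    import torch
    import autoparallel.collectives
    from torch.distributed.device_mesh import init_device_mesh
    from autoparallel.ops import context_parallel_attention

    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))
    autoparallel.collectives._local_map_device_mesh = mesh  # same hook as examples/example_llama3.py

    # q, k, v in SDPA layout [batch, n_heads, seq_len, head_dim]; local_map redistributes them to
    # (Shard(0), Shard(2)): batch over the first mesh dim, sequence over "tp", before ring attention.
    q = torch.randn(8, 32, 4096, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    k = torch.randn(8, 32, 4096, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    v = torch.randn(8, 32, 4096, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Backend is auto-selected via _select_cp_backend (PATCH 4/6); extra kwargs flow through to the ring op.
    out = context_parallel_attention(q, k, v, is_causal=True, dropout_p=0.0)
    out.sum().backward()

    # The kwargs-packing trick from _ContextParallelAttention, shown standalone: a comma-joined key
    # string is a single pytree leaf, so local_map does not flatten the keyword arguments.
    kwargs = {"dropout_p": 0.0, "is_causal": True, "scale": None}
    keys_str, values = ",".join(kwargs.keys()), tuple(kwargs.values())
    assert dict(zip(keys_str.split(","), values)) == kwargs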