diff --git a/autoparallel/_testing/models/llama3.py b/autoparallel/_testing/models/llama3.py
index 9d349e1a..418682df 100644
--- a/autoparallel/_testing/models/llama3.py
+++ b/autoparallel/_testing/models/llama3.py
@@ -11,6 +11,8 @@
 from torch import nn
 from torch.nn.attention import SDPBackend, sdpa_kernel
 
+from autoparallel.ops import context_parallel_attention
+
 
 def has_cuda_capability(major: int, minor: int) -> bool:
     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
@@ -50,7 +52,9 @@ def forward(
     ) -> torch.Tensor:
         assert self.backends, "SDPA Backends should not be empty."
         with sdpa_kernel(self.backends, set_priority=True):
-            return F.scaled_dot_product_attention(q, k, v, is_causal=True)
+            # Use context parallel attention with the selected backend
+            # All backend-specific arguments (is_causal, dropout_p, scale, etc.) are passed via kwargs
+            return context_parallel_attention(q, k, v, is_causal=True, dropout_p=0.0)
 
 
 def build_attention(
diff --git a/autoparallel/ops.py b/autoparallel/ops.py
new file mode 100644
index 00000000..f84af127
--- /dev/null
+++ b/autoparallel/ops.py
@@ -0,0 +1,355 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.distributed.tensor.experimental._context_parallel._attention import (
+    _scaled_dot_product_ring_cudnn_attention,
+    _scaled_dot_product_ring_cudnn_attention_backward,
+    _scaled_dot_product_ring_efficient_attention,
+    _scaled_dot_product_ring_efficient_attention_backward,
+    _scaled_dot_product_ring_flash_attention,
+    _scaled_dot_product_ring_flash_attention_backward,
+)
+from torch.distributed.tensor.placement_types import Shard
+from torch.nn.attention import SDPBackend
+
+from autoparallel.collectives import get_mesh_from_global, local_map
+
+
+# Backend-specific backward wrappers to handle signature differences
+def _flash_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs):
+    """Handle flash attention backward with correct argument order."""
+    # Forward outputs: lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask
+    # Backward expects: mesh, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k,
+    #     max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, *, scale
+    (
+        lse,
+        cum_seq_q,
+        cum_seq_k,
+        max_q,
+        max_k,
+        philox_seed,
+        philox_offset,
+    ) = forward_outputs[:7]
+    return _scaled_dot_product_ring_flash_attention_backward(
+        mesh,
+        grad_out,
+        q,
+        k,
+        v,
+        out,
+        lse,
+        cum_seq_q,
+        cum_seq_k,
+        max_q,
+        max_k,
+        kwargs.get("dropout_p", 0.0),
+        kwargs.get("is_causal", False),
+        philox_seed,
+        philox_offset,
+        scale=kwargs.get("scale", None),
+    )
+
+
+def _efficient_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs):
+    """Handle efficient attention backward with correct argument order."""
+    # Forward outputs: lse, philox_seed, philox_offset
+    # Backward expects: mesh, grad_out, query, key, value, bias, out, logsumexp,
+    #     philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, *, scale
+    lse, philox_seed, philox_offset = forward_outputs[:3]
+    # Build grad_input_mask based on which inputs require gradients
+    attn_bias = kwargs.get("attn_bias", None)
+    grad_input_mask = (
+        q.requires_grad,
+        k.requires_grad,
+        v.requires_grad,
+        attn_bias.requires_grad if attn_bias is not None else False,
+    )
+    return _scaled_dot_product_ring_efficient_attention_backward(
+        mesh,
+        grad_out,
+        q,
+        k,
+        v,
+        attn_bias,
+        out,
+        lse,
+        philox_seed,
+        philox_offset,
+        kwargs.get("dropout_p", 0.0),
+        grad_input_mask,
+        kwargs.get("is_causal", False),
+        scale=kwargs.get("scale", None),
+    )
+
+
+def _cudnn_backward_wrapper(mesh, grad_out, q, k, v, out, forward_outputs, kwargs):
+    """Handle cudnn attention backward with correct argument order."""
+    # Forward outputs: lse, philox_seed, philox_offset, softmax_stats(?), bias(?), cum_seq_q, cum_seq_k, max_q, max_k, debug_attn_mask
+    # Backward expects: mesh, grad_out, query, key, value, out, logsumexp,
+    #     philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k,
+    #     max_q, max_k, dropout_p, is_causal, *, scale
+    lse, philox_seed, philox_offset = forward_outputs[:3]
+    # CuDNN may have additional outputs; extract what we need
+    if len(forward_outputs) >= 9:
+        cum_seq_q, cum_seq_k, max_q, max_k = forward_outputs[5:9]
+    else:
+        # Fallback if structure is different
+        cum_seq_q, cum_seq_k, max_q, max_k = forward_outputs[3:7]
+
+    return _scaled_dot_product_ring_cudnn_attention_backward(
+        mesh,
+        grad_out,
+        q,
+        k,
+        v,
+        out,
+        lse,
+        philox_seed,
+        philox_offset,
+        kwargs.get("attn_bias", None),
+        cum_seq_q,
+        cum_seq_k,
+        max_q,
+        max_k,
+        kwargs.get("dropout_p", 0.0),
+        kwargs.get("is_causal", False),
+        scale=kwargs.get("scale", None),
+    )
+
+
+# Mapping of backward functions to their wrappers
+_CP_BACKWARD_WRAPPERS = {
+    _scaled_dot_product_ring_flash_attention_backward: _flash_backward_wrapper,
+    _scaled_dot_product_ring_efficient_attention_backward: _efficient_backward_wrapper,
+    _scaled_dot_product_ring_cudnn_attention_backward: _cudnn_backward_wrapper,
+}
+
+
+class _ContextParallelAttention(torch.autograd.Function):
+    """
+    Generic context parallel attention that supports multiple backends.
+    Uses **kwargs to be future-proof against signature changes.
+    """
+
+    @staticmethod
+    def forward(ctx, q, k, v, op_forward, op_backward, kwargs_keys_str, *kwargs_values):
+        """
+        Args:
+            q, k, v: Query, key, value tensors
+            op_forward: Forward operation (e.g., _scaled_dot_product_ring_flash_attention)
+            op_backward: Backward operation (e.g., _scaled_dot_product_ring_flash_attention_backward)
+            kwargs_keys_str: Comma-separated string of kwarg names (e.g., 'dropout_p,is_causal,scale')
+            *kwargs_values: Values corresponding to kwargs_keys_str
+        """
+        # Get mesh from global context (avoids passing it through local_map, which would flatten it)
+        mesh = get_mesh_from_global()["tp"]
+
+        ctx.op_backward = op_backward
+        ctx.mesh = mesh
+
+        # Reconstruct kwargs dict from keys string and values
+        kwargs_keys = kwargs_keys_str.split(",") if kwargs_keys_str else []
+        kwargs_dict = dict(zip(kwargs_keys, kwargs_values))
+        ctx.kwargs = kwargs_dict
+
+        # Call the forward operation with all kwargs
+        outputs = op_forward(mesh, q, k, v, **kwargs_dict)
+
+        # outputs is a tuple: (out, lse, ...) where the rest varies by backend
+        out = outputs[0]
+        forward_outputs = outputs[1:]
+
+        # Separate tensors from non-tensors for proper saving
+        # Tensors must be saved via save_for_backward for proper memory management
+        tensors_to_save = [q, k, v, out]
+        non_tensor_outputs = []
+
+        for i, item in enumerate(forward_outputs):
+            if isinstance(item, torch.Tensor):
+                tensors_to_save.append(item)
+                non_tensor_outputs.append(("tensor", len(tensors_to_save) - 1))
+            else:
+                non_tensor_outputs.append(("value", item))
+
+        ctx.save_for_backward(*tensors_to_save)
+        ctx.non_tensor_outputs = non_tensor_outputs
+        ctx.num_forward_outputs = len(forward_outputs)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        # Retrieve saved tensors
+        saved_tensors = ctx.saved_tensors
+        q, k, v, out = saved_tensors[:4]
+        saved_forward_tensors = saved_tensors[4:]
+
+        # Reconstruct forward_outputs from saved tensors and non-tensor values
+        forward_outputs = []
+        tensor_idx = 0
+        for output_type, output_value in ctx.non_tensor_outputs:
+            if output_type == "tensor":
+                forward_outputs.append(saved_forward_tensors[tensor_idx])
+                tensor_idx += 1
+            else:
+                forward_outputs.append(output_value)
+        forward_outputs = tuple(forward_outputs)
+
+        # Use the backend-specific wrapper to handle argument ordering
+        wrapper_fn = _CP_BACKWARD_WRAPPERS.get(ctx.op_backward)
+        if wrapper_fn is None:
+            raise RuntimeError(
+                f"No backward wrapper found for {ctx.op_backward}. "
+                "This backend may not be supported yet."
+            )
+
+        grads = wrapper_fn(
+            ctx.mesh,
+            grad_out,
+            q,
+            k,
+            v,
+            out,
+            forward_outputs,
+            ctx.kwargs,
+        )
+
+        # Return gradients matching the forward signature:
+        # (grad_q, grad_k, grad_v, None for op_forward, None for op_backward,
+        #  None for kwargs_keys_str, None for each kwargs_value)
+        num_kwargs = len(ctx.kwargs)
+        return grads[:3] + (None, None, None) + (None,) * num_kwargs
+
+
+# Backend registry for context parallel attention
+_CP_ATTENTION_BACKENDS = {
+    SDPBackend.FLASH_ATTENTION: (
+        _scaled_dot_product_ring_flash_attention,
+        _scaled_dot_product_ring_flash_attention_backward,
+    ),
+    SDPBackend.EFFICIENT_ATTENTION: (
+        _scaled_dot_product_ring_efficient_attention,
+        _scaled_dot_product_ring_efficient_attention_backward,
+    ),
+    SDPBackend.CUDNN_ATTENTION: (
+        _scaled_dot_product_ring_cudnn_attention,
+        _scaled_dot_product_ring_cudnn_attention_backward,
+    ),
+}
+
+
+# TODO: using assume_constant_result isn't strictly correct, but gets the job done for now
+@torch.compiler.assume_constant_result
+def _select_cp_backend(q, k, v, dropout_p=0.0, is_causal=False, attn_mask=None):
+    """
+    Select the best available backend for context parallel attention.
+
+    Uses PyTorch's internal backend selection logic to determine which SDPA
+    backend can handle the given inputs, respecting the priority order.
+
+    Args:
+        q, k, v: Query, key, value tensors
+        dropout_p: Dropout probability
+        is_causal: Whether to use causal attention
+        attn_mask: Optional attention mask
+
+    Returns:
+        SDPBackend enum value for the selected backend
+
+    Raises:
+        RuntimeError: If no suitable backend is available
+    """
+    from torch.backends.cuda import (
+        SDPAParams,
+        can_use_cudnn_attention,
+        can_use_efficient_attention,
+        can_use_flash_attention,
+    )
+
+    # Create params object for backend selection
+    # SDPAParams signature: query, key, value, attn_mask, dropout, is_causal, enable_gqa
+    params = SDPAParams(q, k, v, attn_mask, dropout_p, is_causal, False)
+
+    # Map backend enum values to their can_use functions
+    # Only include backends we support for context parallel
+    backend_checks = {
+        SDPBackend.FLASH_ATTENTION: can_use_flash_attention,
+        SDPBackend.EFFICIENT_ATTENTION: can_use_efficient_attention,
+        SDPBackend.CUDNN_ATTENTION: can_use_cudnn_attention,
+    }
+
+    # Get priority order from PyTorch
+    priority_order = torch._C._get_sdp_priority_order()
+
+    for backend_id in priority_order:
+        for backend_enum, can_use_fn in backend_checks.items():
+            if backend_enum.value == backend_id and can_use_fn(params):
+                return backend_enum
+
+    raise RuntimeError(
+        "No suitable SDPA backend available for context parallel attention. "
+        "Supported backends are: FLASH_ATTENTION, EFFICIENT_ATTENTION, CUDNN_ATTENTION. "
+        "Check that your inputs are compatible with at least one of these backends."
+    )
+
+
+def context_parallel_attention(q, k, v, *, backend=None, **kwargs):
+    """
+    Generic context parallel attention supporting multiple backends.
+
+    Args:
+        q, k, v: Query, key, value tensors
+        backend: SDPBackend to use (FLASH_ATTENTION, EFFICIENT_ATTENTION, or CUDNN_ATTENTION).
+            If None (default), automatically selects the best available backend using
+            PyTorch's internal selection logic.
+        **kwargs: Additional arguments passed to the attention operation
+            (e.g., dropout_p, is_causal, scale, attn_bias)
+
+    Returns:
+        Attention output tensor
+
+    This function is future-proof as it uses **kwargs to pass arguments, so changes
+    to backend signatures won't require updating this function.
+    """
+    # Auto-select backend if not specified
+    if backend is None:
+        backend = _select_cp_backend(
+            q,
+            k,
+            v,
+            dropout_p=kwargs.get("dropout_p", 0.0),
+            is_causal=kwargs.get("is_causal", False),
+            attn_mask=kwargs.get("attn_bias", None),
+        )
+
+    if backend not in _CP_ATTENTION_BACKENDS:
+        raise ValueError(
+            f"Unsupported backend: {backend}. "
+            f"Supported backends: {list(_CP_ATTENTION_BACKENDS.keys())}"
+        )
+
+    op_forward, op_backward = _CP_ATTENTION_BACKENDS[backend]
+
+    mesh = get_mesh_from_global()
+    plc = (Shard(0), Shard(2))
+    out_placements = (plc,)
+
+    # Convert kwargs to a comma-separated string of keys and a tuple of values
+    # Using a string prevents pytree from flattening it
+    kwargs_keys_str = ",".join(kwargs.keys()) if kwargs else ""
+    kwargs_values = tuple(kwargs.values())
+
+    # in_placements for: q, k, v, op_forward, op_backward, kwargs_keys_str, *kwargs_values
+    # Note: mesh is NOT passed through local_map (it would be flattened by pytree).
+    # Instead, we retrieve it inside the autograd function using get_mesh_from_global()
+    num_kwargs = len(kwargs)
+    in_placements = (plc, plc, plc, None, None, None) + (None,) * num_kwargs
+
+    return local_map(
+        _ContextParallelAttention.apply,
+        out_placements=out_placements,
+        in_placements=in_placements,
+        redistribute_inputs=True,
+        in_grad_placements=None,
+        device_mesh=mesh,
+    )(q, k, v, op_forward, op_backward, kwargs_keys_str, *kwargs_values)
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index 7211e668..cd114085 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -52,12 +52,16 @@
 model_type = "8b"
 enable_asynctp = False
 
+import autoparallel.collectives
+
+autoparallel.collectives._local_map_device_mesh = mesh
+
 
 def model_fn():
     if model_type == "8b":
         model_args = TransformerModelArgs(
             dim=4096,
-            n_layers=32,
+            n_layers=1,  # 32,
             n_heads=32,
             n_kv_heads=8,
             ffn_dim_multiplier=1.3,
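
Usage sketch (not part of the patch above): a minimal illustration of how the new entry point would be called, assuming the patch is applied, the global device mesh has been registered as in examples/example_llama3.py (autoparallel.collectives._local_map_device_mesh = mesh, with a mesh exposing a "tp" dimension), and q, k, v are laid out as (batch, heads, seq, head_dim) in a form that autoparallel.collectives.local_map accepts with the (Shard(0), Shard(2)) placements. The helper name cp_sdpa is hypothetical.

import torch
from torch.nn.attention import SDPBackend

from autoparallel.ops import context_parallel_attention


def cp_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # With backend=None the backend is auto-selected via _select_cp_backend;
    # pass an explicit SDPBackend (as here) to force a specific ring-attention op.
    # Remaining kwargs (dropout_p, is_causal, scale, attn_bias) are forwarded
    # unchanged to the chosen backend's forward/backward wrappers.
    return context_parallel_attention(
        q, k, v, backend=SDPBackend.FLASH_ATTENTION, is_causal=True, dropout_p=0.0
    )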