From 4b0c6470073a4b41cfd6acbfca79809b9295edb7 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 09:45:44 +0100 Subject: [PATCH 1/8] initial scheme of unified-sp --- src/diffusers/models/attention_dispatch.py | 58 ++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 0c247b76d039..c7c395d16920 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1025,6 +1025,14 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: x = _wait_tensor(x) return x +def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: + pass + + +class SeqAllToAllDouble(torch.autograd.Function): + pass + + class TemplatedRingAttention(torch.autograd.Function): @staticmethod @@ -1244,6 +1252,56 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None +class TemplatedUnifiedAttention(torch.nn.Module): + @staticmethod + def forward(ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + ulysses_group = ulysses_mesh.get_group() + ring_mesh = _parallel_config.context_parallel_config._ring_mesh + ring_group = ring_mesh.get_group() + scatter_idx = 2 + gather_idx = 1 + + query = SeqAllToAllDouble.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDouble.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDouble.apply(ulysses_group, value, scatter_idx, gather_idx) + out = TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + if return_lse: + context_layer, lse, *_ = out + else: + context_layer = out + output = SeqAllToAllDouble.apply( + ulysses_group, + context_layer, + gather_idx, + scatter_idx, + ) def _templated_context_parallel_attention( query: torch.Tensor, From 81494b8ef4b5896bfd3af9df27c65f5cf4427e17 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 10:57:55 +0100 Subject: [PATCH 2/8] initial all_to_all_double --- src/diffusers/models/attention_dispatch.py | 47 +++++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index c7c395d16920..33084fe19f28 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1026,11 +1026,54 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: return x def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: - pass + group_world_size = funcol.get_world_size(group) + #dist.get_world_size(group) + + if scatter_idx == 2 and gather_idx == 1: + B, S_LOCAL, H, D = x.shape + S = S_LOCAL * group_world_size + H_LOCAL = H // group_world_size + + x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D) + .permute(0, 2, 1, 3, 4).contiguous() + ) + + out = torch.empty_like(x_temp) + if group_world_size >1: + 
funcol.all_to_all_single(out, x_temp, None, None, group) + else: + out = x_temp + out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous() + out = out.reshape(B, S, H_LOCAL, D) + return out + elif scatter_idx == 1 and gather_idx == 2: + B, S, H_LOCAL, D = x.shape + H = H_LOCAL * group_world_size + S_LOCAL = S // group_world_size + + # + x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D) + .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)) + output = torch.empty_like(x_temp) + if group_world_size >1: + funcol.all_to_all_single(output, x_temp, None, None, group) + else: + output = x_temp + output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous() + output = output.reshape(B, S_LOCAL, H, D) + return output + else: + raise ValueError("Invalid scatter/gather indices for all_to_all_double.") class SeqAllToAllDouble(torch.autograd.Function): - pass + @staticmethod + def forward(): + pass + + @staticmethod + def backward(): + pass From 83fc6067b8429614a9151c5b0a55e5925180ac31 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 12:15:01 +0100 Subject: [PATCH 3/8] bug fixes, added cmnts --- src/diffusers/models/attention_dispatch.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 33084fe19f28..225e0776fe78 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1025,22 +1025,23 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: x = _wait_tensor(x) return x -def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: +def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: group_world_size = funcol.get_world_size(group) - #dist.get_world_size(group) if scatter_idx == 2 and gather_idx == 1: B, S_LOCAL, H, D = x.shape S = S_LOCAL * group_world_size H_LOCAL = H // group_world_size + # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D) - .permute(0, 2, 1, 3, 4).contiguous() + .transpose(0, 2).contiguous() ) - out = torch.empty_like(x_temp) if group_world_size >1: - funcol.all_to_all_single(out, x_temp, None, None, group) + #maybe here need to use the _all_to_all_single helper to avoid contiguity issues + out = funcol.all_to_all_single(x_temp, None, None, group=group) + out = _wait_tensor(out) else: out = x_temp out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous() @@ -1054,19 +1055,20 @@ def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = # x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D) .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)) - output = torch.empty_like(x_temp) + if group_world_size >1: - funcol.all_to_all_single(output, x_temp, None, None, group) + output = funcol.all_to_all_single(x_temp, None, None, group) + output = _wait_tensor(output) else: output = x_temp output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous() output = output.reshape(B, S_LOCAL, H, D) return output else: - raise ValueError("Invalid scatter/gather indices for all_to_all_double.") + raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.") -class SeqAllToAllDouble(torch.autograd.Function): +class SeqAllToAllDim(torch.autograd.Function): 
@staticmethod def forward(): pass From fcb06e52f5395d4c04b24a61dc86e0b27a513029 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 14:36:42 +0100 Subject: [PATCH 4/8] unified attention prototype done --- src/diffusers/models/attention_dispatch.py | 27 +++++++++++++++------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 225e0776fe78..cb64c3ae76da 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1044,6 +1044,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: out = _wait_tensor(out) else: out = x_temp + # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous() out = out.reshape(B, S, H_LOCAL, D) return out @@ -1057,6 +1058,7 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)) if group_world_size >1: + #maybe here need to use the _all_to_all_single helper to avoid contiguity issues output = funcol.all_to_all_single(x_temp, None, None, group) output = _wait_tensor(output) else: @@ -1070,12 +1072,15 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: class SeqAllToAllDim(torch.autograd.Function): @staticmethod - def forward(): - pass + def forward(ctx, group, input, scatter_id=2, gather_id=1): + ctx.group = group + ctx.scatter_id = scatter_id + ctx.gather_id = gather_id + return _all_to_all_dim_exchange(input, scatter_id, gather_id, group) @staticmethod - def backward(): - pass + def backward(ctx, *grad_outputs): + return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None) @@ -1317,12 +1322,13 @@ def forward(ctx: torch.autograd.function.FunctionCtx, ulysses_group = ulysses_mesh.get_group() ring_mesh = _parallel_config.context_parallel_config._ring_mesh ring_group = ring_mesh.get_group() + #hardcoded for now scatter_idx = 2 gather_idx = 1 - query = SeqAllToAllDouble.apply(ulysses_group, query, scatter_idx, gather_idx) - key = SeqAllToAllDouble.apply(ulysses_group, key, scatter_idx, gather_idx) - value = SeqAllToAllDouble.apply(ulysses_group, value, scatter_idx, gather_idx) + query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) out = TemplatedRingAttention.apply( query, key, @@ -1341,12 +1347,17 @@ def forward(ctx: torch.autograd.function.FunctionCtx, context_layer, lse, *_ = out else: context_layer = out - output = SeqAllToAllDouble.apply( + output = SeqAllToAllDim.apply( ulysses_group, context_layer, gather_idx, scatter_idx, ) + if return_lse: + # not sure if this is correct + lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) + return (output, lse) + return output def _templated_context_parallel_attention( query: torch.Tensor, From 4b7177750146aa93098454260a9fc96d7877cf14 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Thu, 20 Nov 2025 15:46:42 +0100 Subject: [PATCH 5/8] remove raising value error in contextParallelConfig to enable unified attention --- src/diffusers/models/_modeling_parallel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/_modeling_parallel.py 
b/src/diffusers/models/_modeling_parallel.py index 2a4eb520c796..8d9e4193616c 100644 --- a/src/diffusers/models/_modeling_parallel.py +++ b/src/diffusers/models/_modeling_parallel.py @@ -90,10 +90,10 @@ def __post_init__(self): ) if self.ring_degree < 1 or self.ulysses_degree < 1: raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.") - if self.ring_degree > 1 and self.ulysses_degree > 1: - raise ValueError( - "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." - ) + # if self.ring_degree > 1 and self.ulysses_degree > 1: + # raise ValueError( + # "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1." + # ) if self.rotate_method != "allgather": raise NotImplementedError( f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}." From e0ed41e36cce350534c9cca254a84338cbf2bb7c Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Fri, 21 Nov 2025 09:15:50 +0100 Subject: [PATCH 6/8] bug fix --- src/diffusers/models/attention_dispatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index cb64c3ae76da..e73a0df6815a 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1026,7 +1026,7 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: return x def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor: - group_world_size = funcol.get_world_size(group) + group_world_size = torch.distributed.get_world_size(group) if scatter_idx == 2 and gather_idx == 1: B, S_LOCAL, H, D = x.shape From 3a407d8dbc96086812278d5b385cb613ecb5a0ea Mon Sep 17 00:00:00 2001 From: KarthikSundar2002 Date: Fri, 21 Nov 2025 11:52:22 +0000 Subject: [PATCH 7/8] feat: Adds Test for Unified SP Attention and Fixes a bug in Template Ring Attention --- src/diffusers/models/attention_dispatch.py | 2 +- tests/others/test_unified_sp_attention.py | 131 +++++++++++++++++++++ 2 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 tests/others/test_unified_sp_attention.py diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index e73a0df6815a..adcc673b573d 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1205,7 +1205,7 @@ def backward( grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value)) - return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None + return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None class TemplatedUlyssesAttention(torch.autograd.Function): diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py new file mode 100644 index 000000000000..00c4403bf3d2 --- /dev/null +++ b/tests/others/test_unified_sp_attention.py @@ -0,0 +1,131 @@ +import math +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from diffusers.models.attention_dispatch import TemplatedUnifiedAttention +import os + +def run(rank, world_size): + dist.init_process_group( + backend="gloo", + rank=rank, + world_size=world_size + ) + + torch.manual_seed(0) + + B, S, H, D = 2, 8, 4, 16 # small toy + q = torch.randn(B, S, H, D) + k = 
torch.randn(B, S, H, D) + v = torch.randn(B, S, H, D) + + q.requires_grad_(True) + + from diffusers.models._modeling_parallel import ( + ParallelConfig, + ContextParallelConfig + ) + + pc = ParallelConfig( + context_parallel_config=ContextParallelConfig( + ring_degree=2, + ulysses_degree=2, + ) + ) + + pc.context_parallel_config.setup( + rank=rank, + world_size=world_size, + device=torch.device("cpu"), + mesh=dist.device_mesh.init_device_mesh("cpu", + (2,2), + mesh_dim_names=["ring", "ulysses"], + ) + ) + + def dummy_forward_op( + ctx, + q, + k, + v, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + *, + _save_ctx=True, + _parallel_config=None, + ): + head_scale = math.sqrt(D) + attn = (q @ k.transpose(-1, -2)) / head_scale + out = attn @ v + lse = torch.logsumexp(attn, dim=-1) + + if _save_ctx: + ctx.save_for_backward(q, k, v) + ctx._cached_qkv = [] + ctx._cached_iter = 0 + + if not hasattr(ctx, "_cached_qkv"): + ctx._cached_qkv = [] + + ctx._cached_qkv.append((q.detach(), k.detach(), v.detach())) + + return (out, lse) if return_lse else out + + def dummy_backward_op(ctx, grad_out, *args, **kwargs): + if not hasattr(ctx, "_cached_qkv"): + raise RuntimeError("No cached tensors for backward.") + + if not hasattr(ctx, "_cached_iter"): + ctx._cached_iter = 0 + + if ctx._cached_iter >= len(ctx._cached_qkv): + raise RuntimeError("Backward called more times than cached forwards.") + + q, k, v = ctx._cached_qkv[ctx._cached_iter] + ctx._cached_iter += 1 + + head_scale = math.sqrt(D) + attn = (q @ k.transpose(-1, -2)) / head_scale + + grad_v = attn.transpose(-1, -2) @ grad_out + grad_attn = grad_out @ v.transpose(-1, -2) + grad_q = (grad_attn @ k) / head_scale + grad_k = (grad_attn.transpose(-1, -2) @ q) / head_scale + + return ( + grad_q, + grad_k, + grad_v, + ) + + attn = TemplatedUnifiedAttention() + + out = attn( + None, + q, k, v, None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + return_lse=False, + forward_op=dummy_forward_op, + backward_op=dummy_backward_op, + _parallel_config=pc, + ) + + print(f"[RANK {rank}] output:", out.shape) + + out.sum().backward() + print(f"[RANK {rank}] grad:", q.grad.shape) + + dist.destroy_process_group() + +if __name__ == "__main__": + world_size = 4 + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + mp.spawn(run, args=(world_size,), nprocs=world_size) \ No newline at end of file From 9ebcff56080b408e5199809b9ffff6dc5e506d93 Mon Sep 17 00:00:00 2001 From: Bissmella Bahaduri Date: Sun, 23 Nov 2025 12:34:44 +0100 Subject: [PATCH 8/8] bug fix, lse calculation, testing bug fixes, lse calculation - switched to _all_to_all_single helper in _all_to_all_dim_exchange due contiguity issues bug fix bug fix bug fix --- src/diffusers/models/attention_dispatch.py | 155 ++++++++++++--------- tests/others/test_unified_sp_attention.py | 4 +- 2 files changed, 89 insertions(+), 70 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index adcc673b573d..823d119c9753 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1034,14 +1034,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: H_LOCAL = H // group_world_size # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D - x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D) - .transpose(0, 2).contiguous() - ) + x_temp = x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D).transpose(0, 
2).contiguous() + if group_world_size >1: #maybe here need to use the _all_to_all_single helper to avoid contiguity issues - out = funcol.all_to_all_single(x_temp, None, None, group=group) - out = _wait_tensor(out) + out = _all_to_all_single(x_temp, group=group) + #out = _wait_tensor(out) else: out = x_temp # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D @@ -1053,14 +1052,13 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: H = H_LOCAL * group_world_size S_LOCAL = S // group_world_size - # - x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D) - .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D)) + #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D + x_temp = x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D) if group_world_size >1: #maybe here need to use the _all_to_all_single helper to avoid contiguity issues - output = funcol.all_to_all_single(x_temp, None, None, group) - output = _wait_tensor(output) + output = _all_to_all_single(x_temp, group) + #output = _wait_tensor(output) else: output = x_temp output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous() @@ -1079,8 +1077,14 @@ def forward(ctx, group, input, scatter_id=2, gather_id=1): return _all_to_all_dim_exchange(input, scatter_id, gather_id, group) @staticmethod - def backward(ctx, *grad_outputs): - return (None, _all_to_all_dim_exchange(grad_outputs[0], ctx.gather_id, ctx.scatter_id, ctx.group), None, None) + def backward(ctx, grad_outputs): + grad_input = SeqAllToAllDim.apply( + ctx.group, + grad_outputs, + ctx.gather_id, # reversed + ctx.scatter_id, # reversed + ) + return (None, grad_input, None, None) @@ -1302,62 +1306,64 @@ def backward( return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None -class TemplatedUnifiedAttention(torch.nn.Module): - @staticmethod - def forward(ctx: torch.autograd.function.FunctionCtx, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_mask: Optional[torch.Tensor], - dropout_p: float, - is_causal: bool, - scale: Optional[float], - enable_gqa: bool, - return_lse: bool, +def TemplatedUnifiedAttention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor], + dropout_p: float, + is_causal: bool, + scale: Optional[float], + enable_gqa: bool, + return_lse: bool, + forward_op, + backward_op, + _parallel_config: Optional["ParallelConfig"] = None, + ): + ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh + ulysses_group = ulysses_mesh.get_group() + ring_mesh = _parallel_config.context_parallel_config._ring_mesh + ring_group = ring_mesh.get_group() + #hardcoded for now + scatter_idx = 2 + gather_idx = 1 + + query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) + key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) + value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) + out = TemplatedRingAttention.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, forward_op, backward_op, - _parallel_config: Optional["ParallelConfig"] = None, - ): - ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh - ulysses_group = ulysses_mesh.get_group() - ring_mesh = _parallel_config.context_parallel_config._ring_mesh - ring_group = ring_mesh.get_group() - #hardcoded for now - scatter_idx = 2 - 
gather_idx = 1 - - query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx) - key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx) - value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx) - out = TemplatedRingAttention.apply( - query, - key, - value, - attn_mask, - dropout_p, - is_causal, - scale, - enable_gqa, - return_lse, - forward_op, - backward_op, - _parallel_config, - ) - if return_lse: - context_layer, lse, *_ = out - else: - context_layer = out - output = SeqAllToAllDim.apply( - ulysses_group, - context_layer, - gather_idx, - scatter_idx, - ) - if return_lse: - # not sure if this is correct - lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx) - return (output, lse) - return output + _parallel_config, + ) + if return_lse: + context_layer, lse, *_ = out + else: + context_layer = out + # Assuming (based on forward ops implementations) context_layer is of shape (B, S, H_LOCAL, D) + output = SeqAllToAllDim.apply( + ulysses_group, + context_layer, + gather_idx, + scatter_idx, + ) + if return_lse: + # not sure if this is correct: Assuming (based on forward ops in ringAttention) + # the lse is of shape (B, S, H_LOCAL) + lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1) + lse = SeqAllToAllDim.apply(ulysses_group, lse, scatter_idx=2, gather_idx=1) + lse = lse.squeeze(-1) + return (output, lse) + return output def _templated_context_parallel_attention( query: torch.Tensor, @@ -1382,7 +1388,22 @@ def _templated_context_parallel_attention( raise ValueError("GQA is not yet supported for templated attention.") # TODO: add support for unified attention with ring/ulysses degree both being > 1 - if _parallel_config.context_parallel_config.ring_degree > 1: + if _parallel_config.context_parallel_config.ring_degree > 1 and _parallel_config.context_parallel_config.ulysses_degree > 1: + return TemplatedUnifiedAttention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op, + backward_op, + _parallel_config, + ) + elif _parallel_config.context_parallel_config.ring_degree > 1: return TemplatedRingAttention.apply( query, key, diff --git a/tests/others/test_unified_sp_attention.py b/tests/others/test_unified_sp_attention.py index 00c4403bf3d2..4c0621999bd0 100644 --- a/tests/others/test_unified_sp_attention.py +++ b/tests/others/test_unified_sp_attention.py @@ -102,10 +102,8 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs): grad_v, ) - attn = TemplatedUnifiedAttention() - out = attn( - None, + out = TemplatedUnifiedAttention( q, k, v, None, dropout_p=0.0, is_causal=False,
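
Note on the layout transform (illustrative sketch, not part of the patches above): the exchange that _all_to_all_dim_exchange performs on each Ulysses rank can be sanity-checked on a single process by emulating all_to_all_single with plain indexing, so no torch.distributed setup is needed. The sketch below assumes the same (B, S_LOCAL, H, D) layout and the scatter_idx/gather_idx conventions used in the patches; the helper names (emulated_all_to_all, scatter_heads_gather_seq, gather_heads_scatter_seq) are made up for the example. It checks that the scatter_idx=2/gather_idx=1 direction leaves each emulated rank with the full sequence but only its own head block, and that the scatter_idx=1/gather_idx=2 direction (applied to the ring-attention output in TemplatedUnifiedAttention) inverts it.

import torch

W = 4                      # emulated Ulysses group size
B, S, H, D = 2, 8, 8, 16   # global shapes; S and H must be divisible by W
S_LOCAL, H_LOCAL = S // W, H // W

X = torch.randn(B, S, H, D)           # the "global" tensor
shards = list(X.chunk(W, dim=1))      # sequence-sharded inputs, one per emulated rank


def emulated_all_to_all(xs):
    # Emulate all_to_all_single over dim 0: rank r receives block r from every rank, in rank order.
    return [torch.cat([xs[j][r : r + 1] for j in range(W)], dim=0) for r in range(W)]


def scatter_heads_gather_seq(x):
    # Pre-collective reshape for scatter_idx=2, gather_idx=1:
    # (B, S_LOCAL, H, D) -> (W, S_LOCAL, B, H_LOCAL, D); dim 0 indexes the head block sent to each rank.
    return x.reshape(B, S_LOCAL, W, H_LOCAL, D).transpose(0, 2).contiguous()


def post_forward(out):
    # (W, S_LOCAL, B, H_LOCAL, D) -> (B, S, H_LOCAL, D); sequence chunks arrive in rank order.
    return out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous().reshape(B, S, H_LOCAL, D)


def gather_heads_scatter_seq(y):
    # Pre-collective reshape for the inverse (scatter_idx=1, gather_idx=2):
    # (B, S, H_LOCAL, D) -> (W, H_LOCAL, S_LOCAL, B, D); dim 0 indexes the sequence chunk sent to each rank.
    return y.reshape(B, W, S_LOCAL, H_LOCAL, D).permute(1, 3, 2, 0, 4).contiguous()


def post_inverse(out):
    # (W, H_LOCAL, S_LOCAL, B, D) -> (B, S_LOCAL, H, D); head blocks arrive in rank order.
    return out.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous().reshape(B, S_LOCAL, H, D)


# scatter heads / gather sequence: each rank ends with the full sequence and its own head block.
exchanged = [post_forward(o) for o in emulated_all_to_all([scatter_heads_gather_seq(s) for s in shards])]
for r in range(W):
    assert torch.equal(exchanged[r], X[:, :, r * H_LOCAL : (r + 1) * H_LOCAL, :])

# scatter sequence / gather heads: round-trips back to the original sequence shards.
restored = [post_inverse(o) for o in emulated_all_to_all([gather_heads_scatter_seq(e) for e in exchanged])]
for r in range(W):
    assert torch.equal(restored[r], shards[r])

print("forward and inverse dim exchanges verified")

In the real path, TemplatedRingAttention runs between the two exchanges over the ring mesh, so after the first exchange each rank computes attention for its own head block over the full (ring-sharded) sequence, and the second exchange returns the output to the sequence-sharded, full-head layout expected by the caller.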