
Commit dcb23b3

fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
1 parent a1e0c51 commit dcb23b3

File tree

4 files changed: +114 -28 lines changed


tests/jax/test_fused_attn.py

Lines changed: 49 additions & 10 deletions
@@ -59,14 +59,16 @@ def init():
     yield


-@partial(jax.jit, static_argnums=(5, 6, 7, 9))
+@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
 def general_dot_product_attention(
     query: ArrayLike,
     key: ArrayLike,
     value: ArrayLike,
+    softmax_offset: Optional[ArrayLike],
     bias: ArrayLike,
     mask: ArrayLike,
     deterministic: bool,
+    softmax_type: AttnSoftmaxType,
     scale_factor: float,
     dropout_rate: float,
     dropout_rng: ArrayLike,
@@ -99,7 +101,25 @@ def general_dot_product_attention(
         mask = jnp.expand_dims(mask, axis=-3)
         logits = jnp.where(mask, jnp.finfo(dtype).min, logits)

-    softmax_out = jax.nn.softmax(logits).astype(dtype)
+    match softmax_type:
+        case AttnSoftmaxType.VANILLA_SOFTMAX:
+            softmax_out = jax.nn.softmax(logits).astype(dtype)
+        case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
+            # Softmax with +1 in denominator: exp(x_i) / (sum(exp(x_j)) + 1)
+            exp_logits = jnp.exp(logits - jnp.max(logits, axis=-1, keepdims=True))
+            softmax_out = (exp_logits / (jnp.sum(exp_logits, axis=-1, keepdims=True) + 1.0)).astype(dtype)
+        case AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            # Reshape softmax_offset from (1, h_q, 1, 1) to (1, h_kv, num_groups, 1, 1) to match logits
+            # logits shape: (b, h_kv, num_groups, s_q, s_kv)
+            if softmax_offset is not None and softmax_offset.size > 0:
+                softmax_offset_reshaped = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
+            else:
+                softmax_offset_reshaped = jnp.zeros((1, h_kv, num_groups, 1, 1), dtype=jnp.float32)
+            exp_logits = jnp.exp(logits - jnp.max(logits, axis=-1, keepdims=True))
+            softmax_out = (exp_logits / (jnp.sum(exp_logits, axis=-1, keepdims=True) + jnp.exp(softmax_offset_reshaped))).astype(dtype)
+        case _:
+            raise NotImplementedError(f"Unknown {softmax_type=}")
+

     if not deterministic and dropout_rate > 0.0:
         keep_prob = 1.0 - dropout_rate
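Note: a minimal standalone JAX sketch (not part of this diff) of the three denominators over the last axis. Folding the max-shift m into the extra term keeps exp() in range and makes the result equal exp(x_i) / (sum_j exp(x_j) + exp(offset)) exactly:

import jax.numpy as jnp

def softmax_variants(logits, offset=0.0):
    m = jnp.max(logits, axis=-1, keepdims=True)
    e = jnp.exp(logits - m)
    denom = jnp.sum(e, axis=-1, keepdims=True)
    vanilla = e / denom                             # standard softmax
    off_by_one = e / (denom + jnp.exp(-m))          # "+1" added to the un-shifted denominator
    learnable = e / (denom + jnp.exp(offset - m))   # sink weight exp(offset); offset is trainable
    return vanilla, off_by_one, learnable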
@@ -219,19 +239,21 @@ def _split_valid_and_invalid(primitive, reference, pad):
     return primitive_valid, primitive_invalid, reference_valid, reference_invalid


-def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
+def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
     """
     JAX native dot product attention implementation
     """
     output = general_dot_product_attention(
         query,
         key,
         value,
+        softmax_offset,
         bias,
         mask,
         deterministic=not kwargs["is_training"],
         scale_factor=kwargs["scaling_factor"],
         dropout_rate=kwargs["dropout_probability"],
+        softmax_type=kwargs["softmax_type"],
         dropout_rng=dropout_rng,
         dtype=jnp.float32,
     )
@@ -243,6 +265,7 @@ def customcall_fused_dpa(
     key,
     value,
     bias,
+    softmax_offset,
     sequence_descriptor,
     dropout_rng,
     **kwargs,
@@ -264,7 +287,7 @@ def customcall_fused_dpa(
             qkv_args = (query, key, value)
         case _:
             raise ValueError(f"Unsupported {qkv_layout=}")
-    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
+    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs).astype(
        query.dtype
    )

@@ -412,7 +435,7 @@ def _setup_inputs(self):
         self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)

         key = jax.random.PRNGKey(0)
-        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
+        q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)

         q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
         k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
@@ -462,6 +485,11 @@ def _setup_inputs(self):
             pad_ratio = 0.3
         else:
             pad_ratio = 0.0
+
+        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            self.softmax_offset = jax.random.uniform(softmax_key, (1, self.num_heads_q, 1, 1), self.dtype, -1.0)
+        else:
+            self.softmax_offset = None

         def gen_valid(bs, max_seqlen, pad_ratio):
             pad_len = int(max_seqlen * pad_ratio)
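Note: in the uniform call above the positional arguments map to jax.random.uniform(key, shape, dtype, minval), so the offsets are drawn from [-1.0, 1.0) with maxval left at its default of 1.0. An equivalent, more explicit form (the head count of 16 is an arbitrary example value):

import jax
import jax.numpy as jnp

softmax_key = jax.random.PRNGKey(0)
softmax_offset = jax.random.uniform(
    softmax_key, shape=(1, 16, 1, 1), dtype=jnp.bfloat16, minval=-1.0, maxval=1.0
)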
@@ -682,6 +710,10 @@ def to_dp_shardings(x):
         self.bias_pspec = PartitionSpec()
         self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)

+        # Softmax offset sharding (1, num_heads, 1, 1)
+        self.softmax_offset_pspec = PartitionSpec(None, self.mesh_resource.tpsp_resource, None, None)
+        self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)
+
         self.dropout_rng_pspec = PartitionSpec(
             None,
         )
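Note: an illustrative sketch (not part of this diff) of what that PartitionSpec means for a (1, num_heads, 1, 1) offset: only the heads axis is sharded, along the tensor-parallel mesh axis. The mesh layout and the axis name "tp" below are assumptions for the example, and the head count must be divisible by the tensor-parallel size:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()).reshape(1, -1)   # (dp, tp) mesh; dp = 1 here
mesh = Mesh(devices, ("dp", "tp"))
offset_sharding = NamedSharding(mesh, PartitionSpec(None, "tp", None, None))
offset = jax.device_put(jnp.zeros((1, 16, 1, 1), jnp.float32), offset_sharding)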
@@ -701,7 +733,7 @@ def test_forward(self):
         """
         self._setup_inputs()

-        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
+        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]

         customcall_args = [
             # Put test data onto each GPU for distributed.
@@ -711,6 +743,7 @@ def test_forward(self):
             jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
             jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
             jax.device_put(self.bias, self.bias_sharding),
+            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
             jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
             jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
         ]
@@ -736,6 +769,7 @@ def test_forward(self):
                 self.qkvo_sharding,
                 self.qkvo_sharding,
                 self.bias_sharding,
+                self.softmax_offset_sharding,
                 self.seq_desc_sharding,
                 self.dropout_rng_sharding,
             ],
@@ -796,14 +830,15 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
                 jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
             ).astype(self.dtype)

-        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
+        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
         customcall_args = [
             # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and
             # THD params once we support those features on CP.
             jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
             jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
             jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
             jax.device_put(self.bias, self.bias_sharding),
+            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
             jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
             jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
         ]
@@ -822,6 +857,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
         }

         # We can compute dBias only for the [1, h, s, s] layout
+        # arg positions: q=0, k=1, v=2, bias=3, softmax_offset=4
         if self.bias_shape == BiasShape._1HSS:
             arg_nums = (0, 1, 2, 3)
             grad_shardings = (
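Note: a minimal sketch (not part of this diff) of how value_and_grad's argnums selects which positional arguments receive gradients; positions left out, such as the offset at index 4 here, simply do not appear in the returned gradient tuple:

import jax
import jax.numpy as jnp

def loss(q, k, v, bias, offset):
    return jnp.sum(jax.nn.softmax(q @ k.T + bias) @ v) + jnp.sum(offset)

vg = jax.value_and_grad(loss, argnums=(0, 1, 2, 3))
inputs = (jnp.ones((4, 8)), jnp.ones((4, 8)), jnp.ones((4, 8)), jnp.zeros((4, 4)), jnp.zeros((1,)))
value, (dq, dk, dv, dbias) = vg(*inputs)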
@@ -837,8 +873,8 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
         # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
         jitted_primitive = jit(
             value_and_grad(
-                lambda q, k, v, bias, *args: grad_func(
-                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
+                lambda q, k, v, bias, softmax_offset, *args: grad_func(
+                    customcall_fused_dpa, q, k, v, bias, softmax_offset, *args, cp_reverse_out=True, **kwargs
                 ),
                 arg_nums,
             ),
@@ -847,14 +883,15 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
                 self.qkvo_sharding,
                 self.qkvo_sharding,
                 self.bias_sharding,
+                self.softmax_offset_sharding,
                 self.seq_desc_sharding,
                 self.dropout_rng_sharding,
             ),
             out_shardings=(None, grad_shardings),
         )
         jitted_reference = jit(
             value_and_grad(
-                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
+                lambda q, k, v, bias, softmax_offset, *args: grad_func(jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs),
                 arg_nums,
             )
         )
@@ -1097,6 +1134,7 @@ def _test_forward(
             seq_desc_format,
         )
         runner.test_forward()
+

     @staticmethod
     @pytest.mark.parametrize(
@@ -1150,3 +1188,4 @@ def test_backward(
             seq_desc_format,
         )
         runner.test_backward()
+

transformer_engine/jax/attention.py

Lines changed: 13 additions & 2 deletions
@@ -706,6 +706,7 @@ def _legacy_fused_attn(
     context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
+    softmax_offset: Optional[jnp.ndarray] = None,
 ):
     """
     Perform non-THD (non-packed) cuDNN fused attention.
@@ -777,7 +778,7 @@ def _legacy_fused_attn(
     output = _fused_attn(
         qkv,
         bias,
-        None,
+        softmax_offset,
         SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)),
         seed,
         attn_bias_type=attn_bias_type,
@@ -816,6 +817,7 @@ def fused_attn_thd(
     context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT,
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
+    softmax_offset: Optional[jnp.ndarray] = None,
 ):
     """
     Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor
@@ -853,7 +855,7 @@ def fused_attn_thd(
     output = _fused_attn(
         qkv,
         bias,
-        None,
+        softmax_offset,
         SequenceDescriptor.from_seqlens_and_offsets(
             (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets)
         ),
@@ -1023,6 +1025,8 @@ def _fused_attn_bwd_rule(
     )
     if attn_bias_type == AttnBiasType.NO_BIAS:
         grad_bias = None
+    if softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX:
+        grad_softmax_offset = None
     return (
         grad_qkv,
         grad_bias,
@@ -1053,6 +1057,7 @@ def fused_attn(
     context_parallel_causal_load_balanced: bool = False,
     context_parallel_axis: str = "",
     context_checkpoint_name: str = "context",
+    softmax_offset: Optional[jnp.ndarray] = None,
 ):
     """
     Perform cuDNN fused attention.
@@ -1087,6 +1092,9 @@ def fused_attn(
         Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
         context_parallel_axis (str): The name of the context parallel axis.
         context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass.
+        softmax_offset (Optional[jnp.ndarray]): An optional learnable softmax offset tensor with shape
+            [1, num_heads, 1, 1]. Used when softmax_type is AttnSoftmaxType.LEARNABLE_SOFTMAX.
+            If provided, this parameter will receive gradients during backpropagation.
     Returns:
         (jnp.ndarray): The output tensor from the fused attention.

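Note: a hedged usage sketch of the new argument. Only fused_attn, SequenceDescriptor, AttnBiasType, AttnSoftmaxType, and softmax_offset appear in this diff; the import paths, the mask and layout enum values, the input shapes, and the bias=None handling are illustrative assumptions rather than a verified call:

import jax
import jax.numpy as jnp
from transformer_engine.jax.attention import (
    fused_attn, SequenceDescriptor, AttnBiasType, AttnMaskType, QKVLayout, AttnSoftmaxType,
)

q = k = v = jnp.zeros((2, 128, 16, 64), jnp.bfloat16)    # (batch, seq, heads, dim), assumed layout
softmax_offset = jnp.zeros((1, 16, 1, 1), jnp.float32)   # [1, num_heads, 1, 1], learnable
seq_lens = jnp.full((2,), 128, jnp.int32)
out = fused_attn(
    (q, k, v),
    None,                                                 # no bias
    SequenceDescriptor.from_seqlens((seq_lens, seq_lens)),
    jax.random.PRNGKey(0),                                # dropout seed
    attn_bias_type=AttnBiasType.NO_BIAS,
    attn_mask_type=AttnMaskType.NO_MASK,                  # assumed enum value
    qkv_layout=QKVLayout.BSHD_BSHD_BSHD,                  # assumed enum value
    softmax_type=AttnSoftmaxType.LEARNABLE_SOFTMAX,
    scaling_factor=1.0 / 8.0,                             # 1/sqrt(head_dim=64)
    dropout_probability=0.0,
    is_training=True,
    softmax_offset=softmax_offset,
)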
@@ -1143,15 +1151,18 @@ def fused_attn(
             context_parallel_strategy=context_parallel_strategy,
             context_parallel_causal_load_balanced=context_parallel_causal_load_balanced,
             context_parallel_axis=context_parallel_axis,
+            softmax_offset=softmax_offset,
         )
     output = _fused_attn(
         qkv,
         bias,
+        softmax_offset,
         sequence_descriptor,
         seed,
         attn_bias_type=attn_bias_type,
         attn_mask_type=attn_mask_type,
         qkv_layout=qkv_layout,
+        softmax_type=softmax_type,
         scaling_factor=scaling_factor,
         dropout_probability=dropout_probability,
         is_training=is_training,

transformer_engine/jax/cpp_extensions/attention.py

Lines changed: 28 additions & 9 deletions
@@ -813,9 +813,16 @@ def abstract(
             shape=wkspace_shape, dtype=te_dtype_to_jax_dtype(wkspace_dtype)
         )

-        dsoftmax_offset_aval = q_aval.update(
-            shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype
-        )
+        # dsoftmax_offset should always have shape [1, attn_heads, 1, 1] when softmax_type is not VANILLA_SOFTMAX
+        # This matches the cuDNN graph requirements and PyTorch implementation
+        if config.softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
+            dsoftmax_offset_aval = q_aval.update(
+                shape=softmax_offset_aval.shape, dtype=softmax_offset_aval.dtype
+            )
+        else:
+            dsoftmax_offset_aval = q_aval.update(
+                shape=(1, attn_heads, 1, 1), dtype=jnp.float32
+            )

         return dq_aval, dk_aval, dv_aval, dbias_aval, dsoftmax_offset_aval, wkspace_aval
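Note: an illustrative contrast (not part of this diff) between the empty placeholder passed around for vanilla softmax and the fixed-shape fp32 buffer the backward pass expects otherwise; the head count of 16 is an arbitrary example value:

import jax.numpy as jnp

attn_heads = 16
vanilla_placeholder = jnp.zeros(0, jnp.bfloat16)                 # size-0 stand-in, no gradient
dsoftmax_offset = jnp.zeros((1, attn_heads, 1, 1), jnp.float32)  # [1, h, 1, 1] gradient buffer
assert vanilla_placeholder.size == 0 and dsoftmax_offset.shape == (1, attn_heads, 1, 1)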

@@ -2664,9 +2671,15 @@ def fused_attn_fwd(
     if softmax_offset is None:
         assert softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX, f"Unknown {softmax_type=}"
         if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
-            raise NotImplementedError(
-                "Off-by-one softmax is not supported when softmax_offset is None"
-            )
+            # Extract number of heads from qkv shape
+            # For qkvpacked (BS3HD): shape is (..., seq, 3, heads, dim) → index -2
+            # For separate/kvpacked (BSHD): shape is (..., seq, heads, dim) → index -2
+            if qkv_layout.is_qkvpacked():
+                num_heads = qkv[0].shape[-2]  # heads is at index -2 for BS3HD
+            else:
+                num_heads = qkv[0].shape[-2]  # heads is at index -2 for BSHD
+            # Create properly-sized tensor [1, h, 1, 1] filled with zeros
+            softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
         elif softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
             softmax_offset = jnp.zeros(0, dtype=qkv[0].dtype)
         else:
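Note: a quick standalone shape check (not part of this diff) of why both branches above read the same index: for the packed BS3HD layout and the separate BSHD layout alike, the heads axis is the second-to-last dimension, so shape[-2] yields num_heads in either case:

import jax.numpy as jnp

qkv_packed = jnp.zeros((2, 128, 3, 16, 64))   # BS3HD: (batch, seq, 3, heads, dim)
q_separate = jnp.zeros((2, 128, 16, 64))      # BSHD:  (batch, seq, heads, dim)
assert qkv_packed.shape[-2] == q_separate.shape[-2] == 16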
@@ -2803,9 +2816,15 @@ def fused_attn_bwd(
     if softmax_offset is None:
         assert softmax_type != AttnSoftmaxType.LEARNABLE_SOFTMAX, f"Unknown {softmax_type=}"
         if softmax_type == AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
-            raise NotImplementedError(
-                "Off-by-one softmax is not supported when softmax_offset is None"
-            )
+            # Extract number of heads from qkv shape
+            # For qkvpacked (BS3HD): shape is (..., seq, 3, heads, dim) → index -2
+            # For separate/kvpacked (BSHD): shape is (..., seq, heads, dim) → index -2
+            if qkv_layout.is_qkvpacked():
+                num_heads = qkv[0].shape[-2]  # heads is at index -2 for BS3HD
+            else:
+                num_heads = qkv[0].shape[-2]  # heads is at index -2 for BSHD
+            # Create properly-sized tensor [1, h, 1, 1] filled with zeros
+            softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)
         elif softmax_type == AttnSoftmaxType.VANILLA_SOFTMAX:
             softmax_offset = jnp.zeros(0, dtype=qkv[0].dtype)
         else:
