@@ -59,14 +59,16 @@ def init():
     yield
 
 
-@partial(jax.jit, static_argnums=(5, 6, 7, 9))
+@partial(jax.jit, static_argnums=(6, 7, 8, 9, 11))
 def general_dot_product_attention(
     query: ArrayLike,
     key: ArrayLike,
     value: ArrayLike,
+    softmax_offset: Optional[ArrayLike],
     bias: ArrayLike,
     mask: ArrayLike,
     deterministic: bool,
+    softmax_type: AttnSoftmaxType,
     scale_factor: float,
     dropout_rate: float,
     dropout_rng: ArrayLike,
@@ -99,7 +101,25 @@ def general_dot_product_attention(
         mask = jnp.expand_dims(mask, axis=-3)
         logits = jnp.where(mask, jnp.finfo(dtype).min, logits)
 
-    softmax_out = jax.nn.softmax(logits).astype(dtype)
+    match softmax_type:
+        case AttnSoftmaxType.VANILLA_SOFTMAX:
+            softmax_out = jax.nn.softmax(logits).astype(dtype)
+        case AttnSoftmaxType.OFF_BY_ONE_SOFTMAX:
+            # Softmax with +1 in the denominator: exp(x_i) / (sum(exp(x_j)) + 1)
+            exp_logits = jnp.exp(logits - jnp.max(logits, axis=-1, keepdims=True))
+            softmax_out = (exp_logits / (jnp.sum(exp_logits, axis=-1, keepdims=True) + 1.0)).astype(dtype)
+        case AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            # Reshape softmax_offset from (1, h_q, 1, 1) to (1, h_kv, num_groups, 1, 1)
+            # to match logits, whose shape is (b, h_kv, num_groups, s_q, s_kv)
+            if softmax_offset is not None and softmax_offset.size > 0:
+                softmax_offset_reshaped = softmax_offset.reshape(1, h_kv, num_groups, 1, 1)
+            else:
+                softmax_offset_reshaped = jnp.zeros((1, h_kv, num_groups, 1, 1), dtype=jnp.float32)
+            exp_logits = jnp.exp(logits - jnp.max(logits, axis=-1, keepdims=True))
+            softmax_out = (exp_logits / (jnp.sum(exp_logits, axis=-1, keepdims=True) + jnp.exp(softmax_offset_reshaped))).astype(dtype)
+        case _:
+            raise NotImplementedError(f"Unknown {softmax_type=}")
+
 
     if not deterministic and dropout_rate > 0.0:
         keep_prob = 1.0 - dropout_rate
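Aside: a minimal standalone JAX sketch of how the three softmax variants added above relate to one another. With the max-subtracted formulation used in the diff, LEARNABLE_SOFTMAX with an offset of 0 reduces to OFF_BY_ONE_SOFTMAX (since exp(0) = 1), and a very negative offset recovers the vanilla softmax (since exp(offset) underflows to 0). The helper names below are illustrative and not part of the test file.

import jax
import jax.numpy as jnp

def stable_exp(x):
    # Subtract the row max for numerical stability, as in the diff above.
    return jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))

def learnable_softmax(logits, offset):
    # exp(x_i - m) / (sum_j exp(x_j - m) + exp(offset))
    e = stable_exp(logits)
    return e / (jnp.sum(e, axis=-1, keepdims=True) + jnp.exp(offset))

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 8))
vanilla = jax.nn.softmax(logits)
e = stable_exp(logits)
off_by_one = e / (jnp.sum(e, axis=-1, keepdims=True) + 1.0)

assert jnp.allclose(learnable_softmax(logits, 0.0), off_by_one)           # offset = 0 -> off-by-one
assert jnp.allclose(learnable_softmax(logits, -1e9), vanilla, atol=1e-6)  # exp(offset) -> 0 -> vanilla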
@@ -219,19 +239,21 @@ def _split_valid_and_invalid(primitive, reference, pad):
     return primitive_valid, primitive_invalid, reference_valid, reference_invalid
 
 
-def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs):
+def jax_dpa(query, key, value, bias, softmax_offset, mask, dropout_rng, **kwargs):
     """
     JAX native dot product attention implementation
     """
     output = general_dot_product_attention(
         query,
         key,
         value,
+        softmax_offset,
         bias,
         mask,
         deterministic=not kwargs["is_training"],
         scale_factor=kwargs["scaling_factor"],
         dropout_rate=kwargs["dropout_probability"],
+        softmax_type=kwargs["softmax_type"],
         dropout_rng=dropout_rng,
         dtype=jnp.float32,
     )
@@ -243,6 +265,7 @@ def customcall_fused_dpa(
     key,
     value,
     bias,
+    softmax_offset,
     sequence_descriptor,
     dropout_rng,
     **kwargs,
@@ -264,7 +287,7 @@ def customcall_fused_dpa(
             qkv_args = (query, key, value)
         case _:
             raise ValueError(f"Unsupported {qkv_layout=}")
-    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype(
+    return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, softmax_offset=softmax_offset, **kwargs).astype(
         query.dtype
     )
 
@@ -412,7 +435,7 @@ def _setup_inputs(self):
         self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1)
 
         key = jax.random.PRNGKey(0)
-        q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5)
+        q_key, k_key, v_key, bias_key, dropout_key, softmax_key = jax.random.split(key, 6)
 
         q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim_qk)
         k_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim_qk)
@@ -462,6 +485,11 @@ def _setup_inputs(self):
             pad_ratio = 0.3
         else:
             pad_ratio = 0.0
+
+        if self.softmax_type == AttnSoftmaxType.LEARNABLE_SOFTMAX:
+            self.softmax_offset = jax.random.uniform(softmax_key, (1, self.num_heads_q, 1, 1), self.dtype, -1.0)
+        else:
+            self.softmax_offset = None
 
         def gen_valid(bs, max_seqlen, pad_ratio):
             pad_len = int(max_seqlen * pad_ratio)
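Aside: in the jax.random.uniform call above, the positional -1.0 is the minval argument (maxval keeps its default of 1.0), so the learnable offset is initialized in [-1.0, 1.0). A standalone sketch with an illustrative head count:

import jax
import jax.numpy as jnp

num_heads_q = 16  # illustrative value; the test derives this from its fixtures
softmax_key = jax.random.PRNGKey(1)
softmax_offset = jax.random.uniform(softmax_key, (1, num_heads_q, 1, 1), jnp.bfloat16, minval=-1.0)
assert softmax_offset.shape == (1, num_heads_q, 1, 1)
assert float(softmax_offset.min()) >= -1.0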
@@ -682,6 +710,10 @@ def to_dp_shardings(x):
         self.bias_pspec = PartitionSpec()
         self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec)
 
+        # Softmax offset sharding: (1, num_heads, 1, 1)
+        self.softmax_offset_pspec = PartitionSpec(None, self.mesh_resource.tpsp_resource, None, None)
+        self.softmax_offset_sharding = NamedSharding(self.mesh, self.softmax_offset_pspec)
+
         self.dropout_rng_pspec = PartitionSpec(
             None,
         )
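Aside: a minimal standalone sketch of the sharding added above, assuming at least two local devices and a hypothetical mesh whose "tp" axis stands in for mesh_resource.tpsp_resource. Only the heads dimension of the (1, num_heads, 1, 1) offset is partitioned; the other dimensions stay replicated.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:2]).reshape(1, 2)  # assumes >= 2 devices
mesh = Mesh(devices, ("dp", "tp"))

num_heads = 8  # must be divisible by the size of the "tp" axis
softmax_offset = jnp.zeros((1, num_heads, 1, 1), dtype=jnp.float32)

# Shard only the heads dimension over the tensor-parallel axis, as in the diff.
offset_sharding = NamedSharding(mesh, PartitionSpec(None, "tp", None, None))
softmax_offset = jax.device_put(softmax_offset, offset_sharding)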
@@ -701,7 +733,7 @@ def test_forward(self):
701733 """
702734 self ._setup_inputs ()
703735
704- args = [self .q , self .k , self .v , self .bias , self .mask , self .dropout_rng ]
736+ args = [self .q , self .k , self .v , self .bias , self .softmax_offset , self . mask , self .dropout_rng ]
705737
706738 customcall_args = [
707739 # Put test data onto each GPU for distributed.
@@ -711,6 +743,7 @@ def test_forward(self):
             jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
             jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
             jax.device_put(self.bias, self.bias_sharding),
+            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
             jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
             jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
         ]
@@ -736,6 +769,7 @@ def test_forward(self):
                 self.qkvo_sharding,
                 self.qkvo_sharding,
                 self.bias_sharding,
+                self.softmax_offset_sharding,
                 self.seq_desc_sharding,
                 self.dropout_rng_sharding,
             ],
@@ -796,14 +830,15 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
                 jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier
             ).astype(self.dtype)
 
-        args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng]
+        args = [self.q, self.k, self.v, self.bias, self.softmax_offset, self.mask, self.dropout_rng]
         customcall_args = [
             # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mask and
             # THD params once we support those features on CP.
             jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding),
             jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
             jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
             jax.device_put(self.bias, self.bias_sharding),
+            jax.device_put(self.softmax_offset, self.softmax_offset_sharding),
             jax.device_put(self.sequence_desciptor, self.seq_desc_sharding),
             jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
         ]
@@ -822,6 +857,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
         }
 
         # We can compute dBias only for the [1, h, s, s] layout
+        # arg positions: q=0, k=1, v=2, bias=3, softmax_offset=4
         if self.bias_shape == BiasShape._1HSS:
             arg_nums = (0, 1, 2, 3)
             grad_shardings = (
@@ -837,8 +873,8 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
         # Summing the results in FP16/BF16 may overflow, so use FP32 for the summation
         jitted_primitive = jit(
             value_and_grad(
-                lambda q, k, v, bias, *args: grad_func(
-                    customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs
+                lambda q, k, v, bias, softmax_offset, *args: grad_func(
+                    customcall_fused_dpa, q, k, v, bias, softmax_offset, *args, cp_reverse_out=True, **kwargs
                 ),
                 arg_nums,
             ),
@@ -847,14 +883,15 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs):
                 self.qkvo_sharding,
                 self.qkvo_sharding,
                 self.bias_sharding,
+                self.softmax_offset_sharding,
                 self.seq_desc_sharding,
                 self.dropout_rng_sharding,
             ),
             out_shardings=(None, grad_shardings),
         )
         jitted_reference = jit(
             value_and_grad(
-                lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
+                lambda q, k, v, bias, softmax_offset, *args: grad_func(jax_dpa, q, k, v, bias, softmax_offset, *args, **kwargs),
                 arg_nums,
             )
         )
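Aside: the arg_nums tuple controls which positional arguments value_and_grad differentiates. With the new lambda signature noted above (q=0, k=1, v=2, bias=3, softmax_offset=4), the arg_nums = (0, 1, 2, 3) shown in this hunk still yields gradients only for q, k, v, and bias; the softmax offset is threaded through but not differentiated. A toy standalone example of that pattern, with a hypothetical loss:

import jax
import jax.numpy as jnp

def loss(q, k, v, bias, softmax_offset, mask):
    # softmax_offset and mask are accepted but play no role in this toy loss.
    return jnp.sum(q * k * v) + jnp.sum(bias)

x = jnp.ones((2, 3))
val, grads = jax.value_and_grad(loss, argnums=(0, 1, 2, 3))(x, x, x, x, x, x)
assert len(grads) == 4  # gradients for q, k, v, bias only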
@@ -1097,6 +1134,7 @@ def _test_forward(
             seq_desc_format,
         )
         runner.test_forward()
+
 
     @staticmethod
     @pytest.mark.parametrize(
@@ -1150,3 +1188,4 @@ def test_backward(
             seq_desc_format,
         )
         runner.test_backward()
+