From 3ca52db5a260b79bf08565071b759bea476672b7 Mon Sep 17 00:00:00 2001 From: Ian Wang Date: Fri, 18 Oct 2024 13:11:26 -0500 Subject: [PATCH] Support FP8 for dot_product_attention --- jax/_src/cudnn/fused_attention_stablehlo.py | 933 +++++++++++++++++--- tests/fused_attention_stablehlo_test.py | 198 ++++- 2 files changed, 1011 insertions(+), 120 deletions(-) diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index e20271f66301..561aa8f825a0 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -99,18 +99,10 @@ def element_type_to_backend_config_type_mapping(dtype): def default_layouts(*shapes): return [range(len(shape) - 1, -1, -1) for shape in shapes] -def create_dot_product_attention_backend_config(batch, - num_heads, - seq_q, - seq_kv, - dtype, - fmha_scale, - seed, - dropout_rate, - mask_type, - layout, - sliding_window_length, - is_bwd): + +def create_dot_product_attention_backend_config_base( + batch, num_heads, seq_q, seq_kv, dtype,fmha_scale, mask_type, layout, is_bwd +): # Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H # P: BMM1 output in shape of BNTS # O: BMM2 output in the same shape with Q @@ -120,8 +112,6 @@ def create_dot_product_attention_backend_config(batch, # BMM1Grad2: dP @ K -> dQ # BMM2Grad1: P @ dO -> dV # BMM2Grad2: dO @ V -> dP - if sliding_window_length is None: - sliding_window_length = 0 cudnn_fmha_backend_config = { "algorithm": { "algo_id": "0", @@ -131,7 +121,6 @@ def create_dot_product_attention_backend_config(batch, "workspace_size": "0", }, "fmha_scale": fmha_scale, - "dropout_rate": dropout_rate, "intermediate_tensor_shape": { "element_type": element_type_to_backend_config_type_mapping(dtype), "dimensions": [str(batch), str(num_heads), str(seq_q), str(seq_kv)], @@ -150,10 +139,8 @@ def create_dot_product_attention_backend_config(batch, }, "is_dynamic_dimension": [False, False, False, False], }, - "seed": seed, "is_flash_attention": True, "mask_type": convert_mask_type_to_string(mask_type), - "sliding_window_length": sliding_window_length, } # We define the contracting and batch dims in the format of @@ -200,31 +187,80 @@ def create_dot_product_attention_backend_config(batch, cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **bwd_dot_number} else: cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **fwd_dot_number} - backend_config = { "operation_queue_id":"0", "wait_on_operation_queues":[], "cudnn_fmha_backend_config": cudnn_fmha_backend_config } - backend_config = json.dumps(backend_config) return backend_config +def create_dot_product_attention_backend_config( + batch, + num_heads, + seq_q, + seq_kv, + dtype, + fmha_scale, + seed, + dropout_rate, + mask_type, + layout, + sliding_window_length, + is_bwd +): + backend_config = create_dot_product_attention_backend_config_base( + batch, num_heads, seq_q, seq_kv, dtype, + fmha_scale, mask_type, layout, is_bwd + ) + if sliding_window_length is None: + sliding_window_length = 0 + backend_config['cudnn_fmha_backend_config']["dropout_rate"] = dropout_rate + backend_config['cudnn_fmha_backend_config']["seed"] = seed + backend_config['cudnn_fmha_backend_config']["sliding_window_length"] = sliding_window_length + return json.dumps(backend_config) + +def create_dot_product_attention_fp8_backend_config(batch, + num_heads, + seq_q, + seq_kv, + dtype, + fmha_scale, + mask_type, + layout, + is_bwd): + backend_config = create_dot_product_attention_backend_config_base(batch, + num_heads, + 
seq_q, + seq_kv, + dtype, + fmha_scale, + mask_type, + layout, + is_bwd) + return json.dumps(backend_config) + # mapping from (is_bwd, has_dropout, has_bias) to custom call name _custom_name_maps = { # fMHA forward call targets. - (False, False, False): "__cudnn$fmhaSoftmax", - (False, False, True): "__cudnn$fmhaScaleBiasSoftmax", - (False, True, False): "__cudnn$fmhaSoftmaxDropout", - (False, True, True): "__cudnn$fmhaScaleBiasSoftmaxDropout", + (False, False, False, False): "__cudnn$fmhaSoftmax", + (False, False, True, False): "__cudnn$fmhaScaleBiasSoftmax", + (False, True, False, False): "__cudnn$fmhaSoftmaxDropout", + (False, True, True, False): "__cudnn$fmhaScaleBiasSoftmaxDropout", + (False, False, False, True): "__cudnn$fmhaSoftmaxF8", # fMHA backward call targets. - (True, False, False): "__cudnn$fmhaSoftmaxBackward", - (True, False, True): "__cudnn$fmhaScaleBiasSoftmaxBackward", - (True, True, False): "__cudnn$fmhaSoftmaxDropoutBackward", - (True, True, True): "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward", + (True, False, False, False): "__cudnn$fmhaSoftmaxBackward", + (True, False, True, False): "__cudnn$fmhaScaleBiasSoftmaxBackward", + (True, True, False, False): "__cudnn$fmhaSoftmaxDropoutBackward", + (True, True, True, False): "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward", + (True, False, False, True): "__cudnn$fmhaSoftmaxBackwardF8", } -def get_custom_call_name(has_bias, has_dropout, is_bwd): - return _custom_name_maps[(is_bwd, has_dropout, has_bias)] +def get_custom_call_name(has_bias, has_dropout, is_bwd, is_fp8=False): + return _custom_name_maps[(is_bwd, has_dropout, has_bias, is_fp8)] + +get_fp8_custom_call_name = functools.partial( + get_custom_call_name, has_bias=False, has_dropout=False, is_fp8=True +) def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout): def check_eq(a, b, c, msg): @@ -237,8 +273,8 @@ def check_eq(a, b, c, msg): check_eq(q_rank, k_rank, v_rank, "QKV rank") q_dtype, k_dtype, v_dtype = query.dtype, key.dtype, value.dtype - if q_dtype not in [jnp.bfloat16, jnp.float16]: - raise NotImplementedError(f"Q must be fp16 or bf16, got {q_dtype}") + if q_dtype not in [jnp.bfloat16, jnp.float16, jnp.float8_e4m3fn, jnp.float8_e5m2]: + raise NotImplementedError(f"Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got {q_dtype}") check_eq(q_dtype, k_dtype, v_dtype, "QKV dtype") if layout == AttentionLayout.BNTH: @@ -287,24 +323,44 @@ def check_eq(a, b, c, msg): raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}") def check_is_flash_attention( - query, key, layout: int, cudnn_version, has_bias, is_training): - if layout == AttentionLayout.BNTH.value: - _, _, T, H = query.shape - _, _, S, _ = key.shape - else: - _, T, _, H = query.shape - _, S, _, _ = key.shape - - if not ((H <= 128 and H % 8 == 0) and - (not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)): - # check if flash attention is supported - # for training, for patterns with bias, seqlen should be divisible by 2 - raise NotImplementedError( - f"Unsupported sequence length Q {T}, KV {S} and head dim {H}.") - # check if minimum cudnn version requirement is satisfied - if cudnn_version < 8904: - raise RuntimeError( - "JAX requires cuDNN >= 8.9.4 to use flash cross attention.") + query, key, layout: int, cudnn_version, has_bias, is_training, is_fp8=False): + # Check minimum cuDNN version requirement based on FP8 flag + if is_fp8: + if cudnn_version < 9010: + raise RuntimeError( + "JAX requires cuDNN >= 9.1.0 to use fp8 flash cross attention." 
+ ) + if has_bias: + raise ValueError("fp8 flash attention doesn't support bias yet") + else: + if cudnn_version < 8904: + raise RuntimeError( + "JAX requires cuDNN >= 8.9.4 to use flash cross attention." + ) + + # Extract sequence length (T) and head dim (H) based on layout + if layout == AttentionLayout.BNTH.value: + _, _, T, H = query.shape + _, _, S, _ = key.shape + else: + _, T, _, H = query.shape + _, S, _, _ = key.shape + + # Flash attention conditions + if is_fp8: + # FP8 specific conditions + if not ((is_training and H == 128 and T % 128 == 0 and S % 128 == 0) or + (not is_training and H <= 256 and H % 16 == 0)): + raise NotImplementedError( + f"Unsupported sequence length Q {T}, KV {S} and head dim {H} for FP8." + ) + else: + # Regular attention conditions + if not ((H <= 128 and H % 8 == 0) and + (not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)): + raise NotImplementedError( + f"Unsupported sequence length Q {T}, KV {S} and head dim {H}." + ) def check_cudnn_version(): # check if cuDNN is installed @@ -730,11 +786,14 @@ def _get_padded_spec(arg_info): return spec + (None,) * (ndim - len(spec)) def _check_qkv_bias_mask_spec( - query_spec, key_spec, value_spec, bias_spec): + query_spec, key_spec, value_spec, bias_spec, layout): # check qkv spec if not query_spec == key_spec == value_spec: raise ValueError("Query, key and value should have same sharding.") - *batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec + if layout == AttentionLayout.BNTH.value: + *batch_spec, num_head_spec, q_seq_spec, head_spec = query_spec + else: + *batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec if q_seq_spec is not None: raise ValueError("Sharding on sequence dim is not allowed.") if head_spec is not None: @@ -749,8 +808,9 @@ def _check_qkv_bias_mask_spec( if bias_q_seq_spec is not None or bias_kv_seq_spec is not None: raise ValueError("Sharding on bias sequence dim is not allowed.") + # fwd custom partition -def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training): +def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args,is_training, layout): # only sharding on batch and num_head dim is allowed # (*batch, q_seq, num_head, head) query_spec = _get_padded_spec(arg_shapes[0]) @@ -761,7 +821,7 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training): bias_spec = _get_padded_spec(arg_shapes[3]) if has_bias else None _check_qkv_bias_mask_spec( - query_spec, key_spec, value_spec, bias_spec) + query_spec, key_spec, value_spec, bias_spec, layout) # keep out sharding same as query sharding since they have same shape out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) if is_training: @@ -778,7 +838,7 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training): def _dot_product_attention_fwd_infer_sharding_from_operands( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, is_training, mesh, arg_shapes, result_shape): - return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training) + return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training, layout) def _dot_product_attention_fwd_partition( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, @@ -786,7 +846,7 @@ def _dot_product_attention_fwd_partition( # args sharding arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) out_shardings = _infer_fwd_output_sharding( - mesh, arg_shapes, variadic_args, is_training) + mesh, arg_shapes, 
variadic_args, is_training, layout) impl = functools.partial( _dot_product_attention_fwd_impl, scale=scale, @@ -801,7 +861,7 @@ def _dot_product_attention_fwd_partition( return mesh, impl, out_shardings, arg_shardings # bwd custom partition -def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args): +def _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args): # (*batch, q_seq, num_head, head) query_spec = _get_padded_spec(arg_shapes[0]) # (*batch, kv_seq, num_head, head) @@ -810,7 +870,7 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args): has_bias, has_dbias = variadic_args bias_spec = _get_padded_spec(arg_shapes[3]) if has_bias else None _check_qkv_bias_mask_spec( - query_spec, key_spec, value_spec, bias_spec) + query_spec, key_spec, value_spec, bias_spec, layout) # keep grad query sharding same as query sharding grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec)) grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec)) @@ -828,12 +888,12 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args): def _dot_product_attention_bwd_infer_sharding_from_operands( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, mesh, arg_shapes, result_shape): - return _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) + return _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args) def _dot_product_attention_bwd_partition( scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, mesh, arg_shapes, result_shape): - out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) + out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args) # args sharding arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) def sharded_impl(*args): @@ -942,7 +1002,6 @@ def sharded_impl(*args): _dot_product_attention_bwd_p_wrapper ) - @functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12, 13)) def _dot_product_attention(query: Array, key: Array, @@ -965,24 +1024,624 @@ def _dot_product_attention(query: Array, cudnn_version=cudnn_version) return output -# _dot_product_attention_fwd must have the same func signature as _dot_product_attention -_dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_attention_bwd_rule) +fp8_params_keys = [ + 'amax_dQ', 'amax_dK', 'amax_dV', 'amax_dP', # place holder for bwd output + 'descale_q', 'descale_k', 'descale_v', 'descale_s', + 'scale_s', 'scale_o', 'descale_o', 'descale_dO', + 'descale_dP', 'scale_dQ', 'scale_dK', 'scale_dV', + 'scale_dP' +] + +fp8_params_keys_fwd = [ + 'descale_q', 'descale_k', 'descale_v', 'descale_s', 'scale_s', 'scale_o' +] +fp8_params_keys_bwd = [ + 'descale_q', 'descale_k', 'descale_v', 'descale_o', 'descale_dO', 'descale_s', + 'descale_dP', 'scale_s', 'scale_dQ', 'scale_dK', 'scale_dV', 'scale_dP', +] +params_from_keys = lambda params, keys: [params[key] for key in keys] + +def check_fp8_params(params): + # Check if all required keys are present + missing_keys = [key for key in fp8_params_keys if key not in params] + if missing_keys: + raise ValueError(f"The following keys are missing from fp8_params: {', '.join(missing_keys)}") + +check_is_flash_attention_fp8 = functools.partial( + check_is_flash_attention, + has_bias=False, + is_fp8=True +) + +def _dot_product_attention_fp8_fwd( + query, key, value, + fp8_params_fwd, + scale, use_causal_mask, layout, cudnn_version): + # check if flash attention is supported for this attention 
pattern + check_is_flash_attention_fp8( + query, key, layout, cudnn_version, is_training=False) + descale_q, descale_k, descale_v, descale_s, scale_s, scale_o = fp8_params_fwd + outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind( + query, key, value, + descale_q, descale_k, descale_v, descale_s, + scale_s, scale_o, + scale=scale, use_causal_mask=use_causal_mask, layout=layout, is_training=False) + return outputs + +def _dot_product_attention_fp8_fwd_rule( + query, key, value, + fp8_params, + scale, use_causal_mask, layout, cudnn_version): + # check if flash attention is supported for this attention pattern + check_is_flash_attention_fp8( + query, key, layout, cudnn_version, is_training=True) + + outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind( + query, key, value, *params_from_keys(fp8_params, fp8_params_keys_fwd), + scale=scale, use_causal_mask=use_causal_mask, layout=layout, is_training=True) + res = (query, key, value, outputs[3], outputs[0], params_from_keys(fp8_params, fp8_params_keys_bwd)) + return (outputs[0], outputs[1], outputs[2]), res + +def _dot_product_attention_fp8_bwd_rule( + scale, use_causal_mask, layout, cudnn_version, res, g): + (query, key, value, activation, fwd_output, aux_params) = res + grad_output = g[0] + # dQ, dK, dV, amax_dq, amax_dk ,amax_dv, amax_dp + grads = _dot_product_attention_fp8_bwd_p_wrapper.bind( + query, + key, + value, + fwd_output, + grad_output, + activation, + *aux_params, + scale=scale, + use_causal_mask=use_causal_mask, + layout=layout, + ) + + fp8_params_grads = dict.fromkeys(fp8_params_keys) + keys_to_grad_indices = ['amax_dQ', 'amax_dK', 'amax_dV', 'amax_dP'] + for i, key in enumerate(keys_to_grad_indices, start=3): + fp8_params_grads[key] = grads[i] + + return (grads[0], grads[1], grads[2], fp8_params_grads) + +def _dot_product_attention_fp8_fwd_impl( + query, key, value, + descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, + scale, use_causal_mask, layout, is_training): + # args: {Q, K, V, mask*} + outputs = _dot_product_attention_fp8_fwd_p.bind( + query, + key, + value, + descale_q, + descale_k, + descale_v, + descale_s, + scale_s, + scale_o, + scale=scale, + use_causal_mask=use_causal_mask, + layout=layout, + is_training=is_training, + ) + return outputs + +def _dot_product_attention_fp8_bwd_impl( + query, key, value, fwd_output, grad_output, activation, + descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + scale, use_causal_mask, layout): + grads = _dot_product_attention_fp8_bwd_p.bind( + query, key, value, fwd_output, grad_output, activation, + descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + scale=scale, use_causal_mask=use_causal_mask, layout=layout) + return grads # dQ, dK, dV, amax_dq, amax_dk ,amax_dv, amax_dp + + +def _dot_product_attention_fp8_fwd_abstract( + query, key, value, + descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, + scale, use_causal_mask, layout, is_training): + query_dtype = dtypes.canonicalize_dtype(query.dtype) + if layout == AttentionLayout.BNTH.value: + B, N, T, _ = query.shape + _, _, S, _ = key.shape + else: + B, T, N, _ = query.shape + _, S, _, _ = key.shape + output_shape = query.shape + softmax_stat_shape = (B, N, T) + + if is_training: + return ( + core.ShapedArray(output_shape, query_dtype), # output + core.ShapedArray((1,1,1,1), jnp.float32), # amax_s + core.ShapedArray((1,1,1,1), jnp.float32), # amax_o + 
core.ShapedArray(softmax_stat_shape, jnp.float32), # M: softmax_stat + ) + else: + return ( + core.ShapedArray(output_shape, query_dtype), # output + core.ShapedArray((1,1,1,1), jnp.float32), # amax_s + core.ShapedArray((1,1,1,1), jnp.float32), # amax_o + ) + +def _dot_product_attention_fp8_bwd_abstract( + query, key, value, fwd_output, grad_output, activation, + descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + scale, use_causal_mask, layout): + query_dtype = dtypes.canonicalize_dtype(query.dtype) + key_dtype = dtypes.canonicalize_dtype(key.dtype) + value_dtype = dtypes.canonicalize_dtype(value.dtype) + + amax_shape = (1,1,1,1) + + return ( + core.ShapedArray( + query.shape, query_dtype + ), # grad query + core.ShapedArray( + key.shape, key_dtype + ), # grad key + core.ShapedArray( + value.shape, value_dtype + ), # grad value + core.ShapedArray( + amax_shape, jnp.float32 + ), # amax of grad of query + core.ShapedArray( + amax_shape, jnp.float32 + ), # amax of grad key + core.ShapedArray( + amax_shape, jnp.float32 + ), # amax of grad value + core.ShapedArray( + amax_shape, jnp.float32 + ), # amax of grad of P + ) + +def _dot_product_attention_fp8_fwd_cuda_lowering( + ctx, query, key, value, + descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, + scale, use_causal_mask, layout, is_training): + query_type = ir.RankedTensorType(query.type) + query_shape = query_type.shape + key_type = ir.RankedTensorType(key.type) + key_shape = key_type.shape + + if layout == AttentionLayout.BNTH.value: + B, N, T, H = query_shape + _, _, S, _ = key_shape + output_layout = (3, 2, 1, 0) + output_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) + else: + B, T, N, H = query_shape + _, S, _, _ = key_shape + output_layout = (3, 1, 2, 0) + output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) + + output_shape = (B, N, T, H) + softmax_stat_shape = (B, N, T) + workspace_shape = (0,) + amax_shape = (1,1,1,1) + workspace_type = ir.IntegerType.get_unsigned(8) + mask_type = MaskType.CAUSAL if use_causal_mask else MaskType.NO_MASK + backend_config = create_dot_product_attention_fp8_backend_config( + B, N, T, S, ir.BF16Type.get(),#query_type.element_type, + scale, mask_type, layout, is_bwd=False, + ) + # {Q, K, V, mask*, q_seqlen*, kv_seqlen*} + # {output, activation*, workspace} + operands = [query, key, value, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o] + + custom_call_name = get_fp8_custom_call_name(is_bwd=False) + + # create output types and layouts + if is_training: + result_types = [ + ir.RankedTensorType.get(output_shape, query_type.element_type), + ir.RankedTensorType.get((1,1,1,1), ir.F32Type.get()), + ir.RankedTensorType.get((1,1,1,1), ir.F32Type.get()), + ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()), + ir.RankedTensorType.get(workspace_shape, workspace_type), + ] + result_layouts = [output_layout] + default_layouts(amax_shape, amax_shape, softmax_stat_shape, workspace_shape) + else: + result_types = [ + ir.RankedTensorType.get(output_shape, query_type.element_type), + ir.RankedTensorType.get((1,1,1,1), ir.F32Type.get()), + ir.RankedTensorType.get((1,1,1,1), ir.F32Type.get()), + ir.RankedTensorType.get(workspace_shape, workspace_type) + ] + result_layouts = [output_layout] + default_layouts(amax_shape, amax_shape, workspace_shape) + + tmp_shapes = [ir.RankedTensorType(operand.type).shape for operand in operands[:3]] + tmp_shapes += [[1, 1, 1, 1]] * 6 + operand_layouts = 
default_layouts(*tmp_shapes) + out = mlir.custom_call( + custom_call_name, + result_types=result_types, + operands=operands, + backend_config=backend_config, + operand_layouts=operand_layouts, + result_layouts=result_layouts, + ) + # drop workspace memory + # output should be (B, T, N, H) instead of (B, N, T, H) + if is_training: + return [hlo.transpose(out.results[0], output_transpose_perm), out.results[1], out.results[2], out.results[3]] + else: + return [hlo.transpose(out.results[0], output_transpose_perm), out.results[1], out.results[2]] + + + +def _dot_product_attention_fp8_bwd_cuda_lowering( + ctx, query, key, value, fwd_output, grad_output, activation, + descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, + descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, scale, + use_causal_mask, layout): + query_type = ir.RankedTensorType(query.type) + query_shape = query_type.shape + key_type = ir.RankedTensorType(key.type) + key_shape = key_type.shape + value_type = ir.RankedTensorType(value.type) + + if layout == AttentionLayout.BNTH.value: + B, q_N, T, H = query_shape + _, k_N, S, _ = key_shape + grad_layout = (3, 2, 1, 0) + grad_transpose_perm = mlir.dense_int_array((0, 1, 2, 3)) + else: + B, T, q_N, H = query_shape + _, S, k_N, _ = key_shape + grad_layout = (3, 1, 2, 0) + grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3)) + + workspace_shape = (0,) + workspace_type = ir.IntegerType.get_unsigned(8) + amax_shape = (1,1,1,1) + + grad_query_shape = (B, q_N, T, H) + grad_key_shape = (B, k_N, S, H) + grad_value_shape = (B, k_N, S, H) + mask_type = MaskType.CAUSAL if use_causal_mask else MaskType.NO_MASK + + backend_config = create_dot_product_attention_fp8_backend_config( + B, q_N, T, S, ir.BF16Type.get(), + scale, mask_type, layout, is_bwd=True, + ) + + # create operands + operands = [ + query, + key, + value, + fwd_output, + grad_output, + activation, + descale_q, + descale_k, + descale_v, + descale_o, + descale_dO, + descale_s, + descale_dP, + scale_s, + scale_dQ, + scale_dK, + scale_dV, + scale_dP, + ] + + # get custom call name + custom_call_name = get_fp8_custom_call_name(is_bwd=True) + + # create output types and layouts + # grad_query, grad_key, grad_value, amax_dQ, amax_dK, amax_dV, amax_dP + result_types = [ + ir.RankedTensorType.get(grad_query_shape, query_type.element_type), + ir.RankedTensorType.get(grad_key_shape, key_type.element_type), + ir.RankedTensorType.get(grad_value_shape, value_type.element_type), + ir.RankedTensorType.get(amax_shape, ir.F32Type.get()), + ir.RankedTensorType.get(amax_shape, ir.F32Type.get()), + ir.RankedTensorType.get(amax_shape, ir.F32Type.get()), + ir.RankedTensorType.get(amax_shape, ir.F32Type.get()), + ] + result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(amax_shape, amax_shape, amax_shape, amax_shape) + + # workspace + result_types.append(ir.RankedTensorType.get(workspace_shape, workspace_type)) + result_layouts = result_layouts + default_layouts(workspace_shape) + out = mlir.custom_call( + custom_call_name, + result_types=result_types, + operands=operands, + backend_config=backend_config, + operand_layouts=default_layouts( + *[ir.RankedTensorType(operand.type).shape for operand in operands]), + result_layouts=result_layouts, + ) + dqkv_amaxs = (hlo.transpose(out.results[0], grad_transpose_perm), + hlo.transpose(out.results[1], grad_transpose_perm), + hlo.transpose(out.results[2], grad_transpose_perm), + out.results[3], out.results[4], out.results[5], out.results[6]) + # Only keep dQ, dK, dV, 
amax_dQ, amax_dK, amax_dV, amax_dP here + return dqkv_amaxs + +def _dot_product_attention_fp8_fwd_batcher( + batched_args, batch_dims, *, scale, use_causal_mask, layout, is_training): + _check_valid_batch_dims(batch_dims) + query, key, value,\ + descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, = batched_args + query_bdim = batch_dims[0] + if is_training: + out_bdims = query_bdim, query_bdim + else: + out_bdims = (query_bdim,) + + if layout == AttentionLayout.BNTH.value: + *Bs, N, T, _ = query.shape + *_, _, S, _ = key.shape + else: + *Bs, T, N, _ = query.shape + *_, S, _, _ = key.shape + B = math.prod(Bs) + + # reshape to 4D shape + query = jnp.reshape(query, (B,) + query.shape[-3:]) + key = jnp.reshape(key, (B,) + key.shape[-3:]) + value = jnp.reshape(value, (B,) + key.shape[-3:]) + + outputs = _dot_product_attention_fp8_fwd_p_wrapper.bind( + query, key, value, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, + scale=scale, use_causal_mask=use_causal_mask, layout=layout, is_training=is_training) + + # reshape to original shape + output, amax_s, amax_o = outputs[0], outputs[1], outputs[2] + output = jnp.reshape(output, query.shape) + if is_training: + activation = outputs[3] + activation = jnp.reshape(activation, (*Bs, N, T)) + return (output, amax_s, amax_o, activation), out_bdims + else: + return (output, amax_s, amax_o), out_bdims + +def _dot_product_attention_fp8_bwd_batcher( + batched_args, batch_dims, *, scale, use_causal_mask, layout): + _check_valid_batch_dims(batch_dims) + query, key, value, fwd_output, grad_output, activation,\ + descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP,\ + scale_s, scale_dQ, scale_dK, scale_dV, scale_dP = batched_args + query_bdim = batch_dims[0] + out_bdims = query_bdim, query_bdim, query_bdim + + if layout == AttentionLayout.BNTH.value: + *Bs, N, T, _ = query.shape + *_, _, S, _ = key.shape + else: + *Bs, T, N, _ = query.shape + *_, S, _, _ = key.shape + B = math.prod(Bs) + + # reshape to 4D shape + query = jnp.reshape(query, (B,) + query.shape[-3:]) + key = jnp.reshape(key, (B,) + key.shape[-3:]) + value = jnp.reshape(value, (B,) + key.shape[-3:]) + + activation = jnp.reshape(activation, (B, N, T)) + fwd_output = jnp.reshape(fwd_output, (B,) + query.shape[-3:]) + grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:]) + + grads = _dot_product_attention_fp8_bwd_p_wrapper.bind( + query, key, value, fwd_output, grad_output, activation, + descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + scale=scale, use_causal_mask=use_causal_mask, layout=layout, + ) + + grad_query, grad_key, grad_value = grads[:3] + # reshape to original shape + grad_query = jnp.reshape(grad_query, query.shape) + grad_key = jnp.reshape(grad_key, key.shape) + grad_value = jnp.reshape(grad_value, value.shape) + + return grads, out_bdims + +def _infer_fp8_fwd_output_sharding(mesh, arg_shapes, is_training, layout): + # Prepare variadic_args for the original function + has_bias = False # Adjust as needed + variadic_args = (has_bias, None) # Dummy value, adjust as necessary + + # Call the original function with the required parameters + output_sharding = _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training, layout) + amax_sharding = NamedSharding(mesh, PartitionSpec()) + if is_training: + out_sharding, activation_sharding = output_sharding[0], output_sharding[1] + return [out_sharding, amax_sharding, amax_sharding, 
activation_sharding] + return output_sharding + [amax_sharding, amax_sharding] + +_dot_product_attention_fp8_fwd_lower = custom_partitioning( + _dot_product_attention_fp8_fwd_impl, static_argnums=(9, 10, 11, 12)) + +def _dot_product_attention_fp8_fwd_infer_sharding_from_operands( + scale, use_causal_mask, layout, is_training, + mesh, arg_shapes, result_shape): + return _infer_fp8_fwd_output_sharding(mesh, arg_shapes, is_training, layout) + +def _dot_product_attention_fp8_fwd_partition( + scale, use_causal_mask, layout, is_training, + mesh, arg_shapes, result_shape): + # args sharding + arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) + out_shardings = _infer_fp8_fwd_output_sharding( + mesh, arg_shapes, is_training, layout) + impl = functools.partial( + _dot_product_attention_fp8_fwd_impl, scale=scale, use_causal_mask=use_causal_mask, + layout=layout, is_training=is_training) + return mesh, impl, out_shardings, arg_shardings + +def _infer_fp8_bwd_output_sharding(mesh, arg_shapes, layout): + # Prepare variadic_args for the original function + has_bias = False # Adjust as needed + has_dbias = False # Adjust as needed + variadic_args = (has_bias, has_dbias) # Dummy value, adjust as necessary + + # Call the original function with the required parameters + output_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, layout, variadic_args) + + # Prepare amax_sharding + amax_sharding = NamedSharding(mesh, PartitionSpec()) # Use a default spec or adjust as needed + + # Append amax_sharding for each output sharding + out_shardings_with_amax = output_shardings + [amax_sharding] * 4 + + return out_shardings_with_amax + +_dot_product_attention_fp8_bwd_lower = custom_partitioning( + _dot_product_attention_fp8_bwd_impl, static_argnums=(18,19,20) +) + +def _dot_product_attention_fp8_bwd_infer_sharding_from_operands( + scale, use_causal_mask, layout, mesh, + arg_shapes, result_shape): + return _infer_fp8_bwd_output_sharding(mesh, arg_shapes, layout) + +def _dot_product_attention_fp8_bwd_partition( + scale, use_causal_mask, layout, mesh, + arg_shapes, result_shape): + out_shardings = _infer_fp8_bwd_output_sharding(mesh, arg_shapes, layout) + # args sharding + arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) + impl = functools.partial( + _dot_product_attention_fp8_bwd_impl, scale=scale, + use_causal_mask=use_causal_mask, layout=layout + ) + return mesh, impl, out_shardings, arg_shardings + +# Create dot_product_attention_fp8_fwd_p for forward operation. +_dot_product_attention_fp8_fwd_p = core.Primitive("dot_product_attention_fp8_fwd") +_dot_product_attention_fp8_fwd_p.multiple_results = True +_dot_product_attention_fp8_fwd_p.def_impl( + functools.partial(xla.apply_primitive, _dot_product_attention_fp8_fwd_p) +) +_dot_product_attention_fp8_fwd_p.def_abstract_eval( + _dot_product_attention_fp8_fwd_abstract +) + +mlir.register_lowering( + _dot_product_attention_fp8_fwd_p, + _dot_product_attention_fp8_fwd_cuda_lowering, + platform="cuda", +) + +_dot_product_attention_fp8_fwd_p_wrapper = core.Primitive( + "dot_product_attention_fp8_fwd_wrapper" +) +_dot_product_attention_fp8_fwd_p_wrapper.multiple_results = True +_dot_product_attention_fp8_fwd_p_wrapper.def_impl(_dot_product_attention_fp8_fwd_impl) +_dot_product_attention_fp8_fwd_p_wrapper.def_abstract_eval( + _dot_product_attention_fp8_fwd_abstract +) + +# Create dot_product_attention_bwd_p for backward operation. 
+_dot_product_attention_fp8_bwd_p = core.Primitive("dot_product_attention_fp8_bwd") +_dot_product_attention_fp8_bwd_p.multiple_results = True +_dot_product_attention_fp8_bwd_p.def_impl( + functools.partial(xla.apply_primitive, _dot_product_attention_fp8_bwd_p) +) +_dot_product_attention_fp8_bwd_p.def_abstract_eval( + _dot_product_attention_fp8_bwd_abstract +) + +mlir.register_lowering( + _dot_product_attention_fp8_bwd_p, + _dot_product_attention_fp8_bwd_cuda_lowering, + platform="cuda", +) + +_dot_product_attention_fp8_bwd_p_wrapper = core.Primitive( + "dot_product_attention_fp8_bwd_wrapper" +) +_dot_product_attention_fp8_bwd_p_wrapper.multiple_results = True +_dot_product_attention_fp8_bwd_p_wrapper.def_impl(_dot_product_attention_fp8_bwd_impl) +_dot_product_attention_fp8_bwd_p_wrapper.def_abstract_eval( + _dot_product_attention_fp8_bwd_abstract +) + +batching.primitive_batchers[ + _dot_product_attention_fp8_fwd_p_wrapper +] = _dot_product_attention_fp8_fwd_batcher +batching.primitive_batchers[ + _dot_product_attention_fp8_bwd_p_wrapper +] = _dot_product_attention_fp8_bwd_batcher + +_dot_product_attention_fp8_fwd_lower.def_partition( + infer_sharding_from_operands=_dot_product_attention_fp8_fwd_infer_sharding_from_operands, + partition=_dot_product_attention_fp8_fwd_partition) + +mlir.register_lowering(_dot_product_attention_fp8_fwd_p_wrapper, + mlir.lower_fun(_dot_product_attention_fp8_fwd_lower, multiple_results=True)) + +_dot_product_attention_fp8_bwd_lower.def_partition( + infer_sharding_from_operands=_dot_product_attention_fp8_bwd_infer_sharding_from_operands, + partition=_dot_product_attention_fp8_bwd_partition) + +mlir.register_lowering(_dot_product_attention_fp8_bwd_p_wrapper, + mlir.lower_fun(_dot_product_attention_fp8_bwd_lower, multiple_results=True)) + +dispatch.prim_requires_devices_during_lowering.add( + _dot_product_attention_fp8_fwd_p +) +dispatch.prim_requires_devices_during_lowering.add( + _dot_product_attention_fp8_fwd_p_wrapper +) +dispatch.prim_requires_devices_during_lowering.add( + _dot_product_attention_fp8_bwd_p +) +dispatch.prim_requires_devices_during_lowering.add( + _dot_product_attention_fp8_bwd_p_wrapper +) + +@functools.partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7)) +def _dot_product_attention_fp8(query: Array, + key: Array, + value: Array, + fp8_params: dict[str, Array], + scale: float, + use_causal_mask: bool, + layout: int, + cudnn_version: int): + output, amax_s, amax_o = _dot_product_attention_fp8_fwd( + query, key, value, params_from_keys(fp8_params, fp8_params_keys_fwd), + scale, use_causal_mask, layout, cudnn_version + ) + return output, amax_s, amax_o + +_dot_product_attention_fp8.defvjp(_dot_product_attention_fp8_fwd_rule, _dot_product_attention_fp8_bwd_rule) # User interface -def dot_product_attention(query: Array, - key: Array, - value: Array, - bias: Array | None = None, - mask: Array | None = None, - q_seqlen: Array | None = None, - kv_seqlen: Array | None = None, - *, - scale: float = 1.0, - mask_type: MaskType = MaskType.NO_MASK, - seed: int = 42, - dropout_rate: float = 0., - qkv_layout: str = "BTNH", - sliding_window_length: int | None = None): +def dot_product_attention( + query: Array, + key: Array, + value: Array, + bias: Array | None = None, + mask: Array | None = None, + q_seqlen: Array | None = None, + kv_seqlen: Array | None = None, + fp8_params: dict[str, Array] | None = None, + *, + scale: float = 1.0, + mask_type: MaskType = MaskType.NO_MASK, + seed: int = 42, + dropout_rate: float = 0., + qkv_layout: str = "BTNH", + 
sliding_window_length: int | None = None, + use_fp8: bool = False +): """Computes dot-product attention given query (Q), key (K), and value (V). This function serves as the core operation for applying attention @@ -1018,56 +1677,92 @@ def dot_product_attention(query: Array, token's left local window (pos - sliding_window_length, pos] where `pos` is the index of each token. E.g., if sliding_window_length == 3 and the sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. - + use_fp8: Whether to use FP8 attention mechanism. + fp8_params: Dictionary containing FP8-specific parameters. That includes: + amax_dQ: amax of gradient of query. + amax_dK: amax of gradient of key. + amax_dV: amax of gradient of value. + amax_dP: amax of gradient of state. + descale_q: Descaling factor of query. + descale_k: Descaling factor of key. + descale_v: Descaling factor of value. + descale_s: Descaling factor of attention score. + scale_s: Scale factor for S tensor. + scale_o: Scale factor for output. + descale_o (bwd): Descale factor for output. + descale_dO (bwd): Descale factor for output gradient. + descale_dP (bwd): Descale factor for P gradient tensor. + scale_dQ (bwd): Scale factor for query gradient. + scale_dK (bwd): Scale factor for key gradient. + scale_dV (bwd): Scale factor for value gradient. + scale_dP (bwd): Scale factor for state gradient. Returns: Output of the same shape as the query. + amax_s: amax of state. (fp8 only) + amax_o: amax of output. (fp8 only) """ # check if cuDNN is installed cudnn_version = check_cudnn_version() - # only support at least Ampere if not check_compute_capability("8.0"): raise RuntimeError("Require at least Ampere arch to run") layout = _normalize_layout(qkv_layout) - if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): - raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask") - if sliding_window_length is not None and sliding_window_length <= 0: - raise ValueError( - f"Require sliding_window_length > 0, got {sliding_window_length}") - if bias is not None: - # reshape bias to have 4D shape - bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) - - if mask is not None: - if mask.dtype == jnp.bool: - large_negative_number = get_large_negative_number(query.dtype) - mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) - # reshape mask to have 4D shape - mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] - - # combine bias and mask - if bias is None: - bias = mask + if use_fp8: + assert fp8_params is not None + assert mask_type in (MaskType.NO_MASK, MaskType.CAUSAL) + assert all(x is None for x in [bias, mask, q_seqlen, kv_seqlen]), \ + ( + f"Expected 'None' for bias, mask, q_seqlen, and kv_seqlen, " + f"but got: bias={bias}, mask={mask}, q_seqlen={q_seqlen}, kv_seqlen={kv_seqlen}" + ) + check_fp8_params(fp8_params) + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout) + output, amax_s, amax_o = _dot_product_attention_fp8( + query, key, value, fp8_params, + scale, mask_type == MaskType.CAUSAL, layout.value, cudnn_version + ) + return output, amax_s, amax_o else: - if mask is not None: - # should be broadcast to same shape - bias = bias + mask - - # check if input shape and data type is compatiable - check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout) - has_bias = bias is not None - has_dbias = has_bias and \ - should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] - variadic_args = (has_bias, 
has_dbias) - - if bias is None: - bias = jnp.zeros(0, dtype=query.dtype) - if q_seqlen is None: - q_seqlen = jnp.zeros(0, dtype=query.dtype) - if kv_seqlen is None: - kv_seqlen = jnp.zeros(0, dtype=query.dtype) - output = _dot_product_attention( - query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length, - cudnn_version) - return output + if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): + raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask") + if sliding_window_length is not None and sliding_window_length <= 0: + raise ValueError(f"Require sliding_window_length > 0, got {sliding_window_length}") + if bias is not None: + # reshape bias to have 4D shape + bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape) + + if mask is not None: + if mask.dtype == jnp.bool: + large_negative_number = get_large_negative_number(query.dtype) + mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) + # reshape mask to have 4D shape + mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr] + + # combine bias and mask + if bias is None: + bias = mask + else: + if mask is not None: + # should be broadcast to same shape + bias = bias + mask + + # check if input shape and data type is compatiable + check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout) + has_bias = bias is not None + has_dbias = has_bias and \ + should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr] + variadic_args = (has_bias, has_dbias) + + if bias is None: + bias = jnp.zeros(0, dtype=query.dtype) + if q_seqlen is None: + q_seqlen = jnp.zeros(0, dtype=query.dtype) + if kv_seqlen is None: + kv_seqlen = jnp.zeros(0, dtype=query.dtype) + + output = _dot_product_attention( + query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, + dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length, + cudnn_version + ) + return output diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index 2cfcfa7c5ec6..903c3e5ec276 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -34,10 +34,22 @@ MaskType, AttentionLayout, ) +from flax.linen.fp8_ops import qdq, quantize config.parse_flags_with_absl() Array = jnp.ndarray +fp8_meta_names = [ + 'amax_dQ', 'amax_dK', 'amax_dV', 'amax_dP', + 'descale_q', 'descale_k', 'descale_v', 'descale_s', + 'scale_s', 'scale_o', 'descale_o', 'descale_dO', + 'descale_dP', 'scale_dQ', 'scale_dK', 'scale_dV', 'scale_dP', +] + +fp8_metas = {name: jnp.ones((1, 1, 1, 1), dtype=jnp.float32) for name in fp8_meta_names} + +cast_to_representable = partial(qdq, scale=jnp.ones((1,)), compute_dtype=jnp.bfloat16) + def sdpa_train(query: Array, key: Array, value: Array, @@ -171,6 +183,25 @@ def sdpa_train_ref(query: Array, return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref) return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) +def sdpa_train_fp8(query: Array, + key: Array, + value: Array, + grad: Array, + fp8_metas: dict[Array], + scale: float = 0.5, + mask_type: MaskType = MaskType.NO_MASK): + def dot_product_attention_fp8(query, key, value, fp8_metas): + f_p = partial( + dot_product_attention, scale=scale, mask_type=mask_type, use_fp8=True) + return f_p(query, key, value, None, None, None, None, fp8_metas) + out, sdpa_vjp = jax.vjp( + dot_product_attention_fp8, query, key, value, 
fp8_metas) + + grad_amax_s = jnp.ones((1,1,1,1), dtype=jnp.float32) + grad_amax_o = jnp.ones((1,1,1,1), dtype=jnp.float32) + query_grad, key_grad, value_grad, *_ = sdpa_vjp((grad, grad_amax_s, grad_amax_o)) + return out[0], (query_grad, key_grad, value_grad) + class DotProductAttentionTest(jtu.JaxTestCase): def setUp(self): super().setUp() @@ -202,7 +233,7 @@ def setUp(self): def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int, head_dim: int, use_mask: bool, use_bias: bool, mask_type: MaskType, dropout_rate: float, scale: float, dtype: jnp.dtype): - if len(jax.local_devices()) <= 4: + if len(jax.local_devices()) < 4: self.skipTest("Require at least 4 devices to run sharding tests.") if use_mask and mask_type != MaskType.NO_MASK: self.skipTest("Either pass in mask or generate mask directly in cuDNN.") @@ -543,5 +574,170 @@ def test_sdpa_utils(self): query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, is_training) + +@jtu.with_config(jax_numpy_dtype_promotion='standard') +class DotProductAttentionF8Test(jtu.JaxTestCase): + def setUp(self): + super().setUp() + if jax.device_count() < 4: + self.skipTest("Requires more than 4 devices.") + try: + cudnn_version = check_cudnn_version() + except RuntimeError as e: + self.skipTest(str(e)) + return + if cudnn_version < 9010: + self.skipTest("Requires >= cuDNN 9.1.0") + if not jtu.is_cuda_compute_capability_at_least("9.0"): + self.skipTest("Requires at least Hopper arch") + + @jtu.sample_product( + batch_size=[2, 4], + seq_len=[128, 256], + num_heads=[4, 8], + head_dim=[128], + mask_type=[MaskType.NO_MASK], + scale=[1.0, 0.75], + dtype=[jnp.bfloat16, jnp.float16] + ) + @jtu.run_on_devices("cuda") + def test_sdpa_fp8(self, batch_size: int, seq_len: int, num_heads: int, + head_dim: int, mask_type: MaskType, + scale: float, dtype: jnp.dtype): + k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) + input_shape = (batch_size, seq_len, num_heads, head_dim) # only test the default BTNH + query_h = jax.random.normal( + k1, input_shape, dtype=dtype) + key_h = jax.random.normal( + k2, input_shape, dtype=dtype) + value_h = jax.random.normal( + k3, input_shape, dtype=dtype) + grad_h = jax.random.normal( + k4, input_shape, dtype=dtype) + query = cast_to_representable(query_h, jnp.float8_e4m3fn) + key = cast_to_representable(key_h, jnp.float8_e4m3fn) + value = cast_to_representable(value_h, jnp.float8_e4m3fn) + grad = cast_to_representable(grad_h, jnp.float8_e4m3fn) + + query_quantized = quantize(query, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32) + key_quantized = quantize(key, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32) + value_quantized = quantize(value, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32) + grad_quantized = quantize(grad, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32) + + devices = np.array(jax.local_devices()[:4]) + devices = devices.reshape((2, 2)) + + with Mesh(devices, ("dp", "tp")) as mesh: + qkv_spec = PartitionSpec("dp", None, "tp", None) + qkv_sharding = NamedSharding(mesh, qkv_spec) + query = jax.device_put(query, qkv_sharding) + key = jax.device_put(key, qkv_sharding) + value = jax.device_put(value, qkv_sharding) + grad = jax.device_put(grad, qkv_sharding) + + query_quantized = jax.device_put(query_quantized , qkv_sharding) + key_quantized = jax.device_put(key_quantized , qkv_sharding) + value_quantized = jax.device_put(value_quantized , qkv_sharding) + grad_quantized = jax.device_put(grad_quantized , qkv_sharding) + + fp8_meta_shardings = {name: None for name in fp8_meta_names} + in_shardings = 
(qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, fp8_meta_shardings) + out_shardings = (qkv_sharding, (qkv_sharding, qkv_sharding, qkv_sharding)) + + in_shardings_ref = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, None, None) + out_shardings_ref = out_shardings + sdpa_train_fp8_p = partial(sdpa_train_fp8, scale=scale, mask_type=mask_type) + jitted_sdpa_train_fp8 = jax.jit(sdpa_train_fp8_p, in_shardings=in_shardings, out_shardings=out_shardings) + jitted_sdpa_train_ref = jax.jit( + partial( + sdpa_train_ref, scale=scale, mask_type=mask_type, dropout_rate=0.0), + in_shardings=in_shardings_ref, + out_shardings=out_shardings_ref + ) + + out, (query_grad, key_grad, value_grad) = \ + jitted_sdpa_train_fp8(query_quantized, key_quantized, value_quantized, grad_quantized, fp8_metas) + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ + jitted_sdpa_train_ref(query, key, value, grad, None, None) + + self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-1, atol=5e-1) + self.assertArraysAllClose(query_grad_ref, query_grad.astype(dtype), rtol=5e-1, atol=3e0) + self.assertArraysAllClose(key_grad_ref, key_grad.astype(dtype), rtol=5e-1, atol=3e0) + self.assertArraysAllClose(value_grad_ref, value_grad.astype(dtype), rtol=5e-1, atol=5e-1) + + @jtu.sample_product( + batch_size=[4, 2], + seq_len=[4, 16], + num_heads=[4, 16], + head_dim=[16, 32], + mask_type=[MaskType.NO_MASK], + qkv_layout=["BNTH", "BTNH"], + scale=[1.0, 0.75], + dtype=[jnp.bfloat16, jnp.float16] + ) + @jtu.run_on_devices("cuda") + def test_sdpa_fp8_inference(self, batch_size: int, seq_len: int, num_heads: int, + head_dim: int, mask_type: MaskType, qkv_layout: str, + scale: float, dtype: jnp.dtype): + k1, k2, k3 = jax.random.split(jax.random.key(0), 3) + if qkv_layout == "BNTH": + input_shape = (batch_size, num_heads, seq_len, head_dim) + else: + input_shape = (batch_size, seq_len, num_heads, head_dim) + query_h = jax.random.normal(k1, input_shape, dtype=dtype) + key_h = jax.random.normal(k2, input_shape, dtype=dtype) + value_h = jax.random.normal(k3, input_shape, dtype=dtype) + + query = cast_to_representable(query_h, jnp.float8_e4m3fn) + key = cast_to_representable(key_h, jnp.float8_e4m3fn) + value = cast_to_representable(value_h, jnp.float8_e4m3fn) + + query_quantized = quantize(query, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32) + key_quantized = quantize(key, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32) + value_quantized = quantize(value, jnp.float8_e4m3fn, jnp.ones((1,)), jnp.float32) + + devices = np.array(jax.local_devices()[:4]) + devices = devices.reshape((2, 2)) + + with Mesh(devices, ("dp", "tp")) as mesh: + if qkv_layout == "BNTH": + qkv_spec = PartitionSpec("dp", "tp", None, None) + else: + qkv_spec = PartitionSpec("dp", None, "tp", None) + + qkv_sharding = NamedSharding(mesh, qkv_spec) + fp8_meta_shardings = {name: None for name in fp8_meta_names} + in_shardings = ( + qkv_sharding, qkv_sharding, qkv_sharding, fp8_meta_shardings) + out_shardings = (qkv_sharding, None, None) + + in_shardings_ref = ( + qkv_sharding, qkv_sharding, qkv_sharding) + out_shardings_ref = qkv_sharding + + query = jax.device_put(query, qkv_sharding) + key = jax.device_put(key, qkv_sharding) + value = jax.device_put(value, qkv_sharding) + def dot_product_attention_fp8(query, key, value, fp8_metas): + f_p = partial( + dot_product_attention, scale=scale, mask_type=mask_type, qkv_layout=qkv_layout, use_fp8=True) + return f_p(query, key, value, None, None, None, None, fp8_metas) + + jitted_sdpa_inference = jax.jit( + 
dot_product_attention_fp8, + in_shardings=in_shardings, + out_shardings=out_shardings + ) + + jitted_sdpa_inference_ref = jax.jit( + partial( + dot_product_attention, scale=scale, mask_type=mask_type, qkv_layout=qkv_layout), + in_shardings=in_shardings_ref, + out_shardings=out_shardings_ref + ) + out, _, _ = jitted_sdpa_inference(query_quantized, key_quantized, value_quantized, fp8_metas) + out_ref = jitted_sdpa_inference_ref(query, key, value) + self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-2, atol=5e-2) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())
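
Usage note (illustrative, not part of the patch): a minimal sketch of calling the FP8 path introduced above. It assumes a Hopper-class GPU with cuDNN >= 9.1.0, the private module path that the tests import from (jax._src.cudnn.fused_attention_stablehlo), and all-ones scale/descale factors chosen purely for illustration; a real workload would maintain these factors with its own delayed-scaling recipe.

import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType

# Default BTNH layout; for FP8 inference the head dim must satisfy H <= 256 and H % 16 == 0.
B, T, N, H = 2, 256, 4, 128
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
query = jax.random.normal(k1, (B, T, N, H), dtype=jnp.bfloat16).astype(jnp.float8_e4m3fn)
key   = jax.random.normal(k2, (B, T, N, H), dtype=jnp.bfloat16).astype(jnp.float8_e4m3fn)
value = jax.random.normal(k3, (B, T, N, H), dtype=jnp.bfloat16).astype(jnp.float8_e4m3fn)

# Every key listed in fp8_params_keys must be present; unit factors are placeholders.
fp8_keys = [
    "amax_dQ", "amax_dK", "amax_dV", "amax_dP",
    "descale_q", "descale_k", "descale_v", "descale_s",
    "scale_s", "scale_o", "descale_o", "descale_dO",
    "descale_dP", "scale_dQ", "scale_dK", "scale_dV", "scale_dP",
]
fp8_params = {name: jnp.ones((1, 1, 1, 1), dtype=jnp.float32) for name in fp8_keys}

# With use_fp8=True, bias/mask/q_seqlen/kv_seqlen must be None, only NO_MASK or
# CAUSAL masking is supported, and the call returns (output, amax_s, amax_o).
out, amax_s, amax_o = dot_product_attention(
    query, key, value, None, None, None, None, fp8_params,
    scale=1.0, mask_type=MaskType.CAUSAL, use_fp8=True)

Gradients follow the pattern of sdpa_train_fp8 in the test file: taking jax.vjp of this call yields dQ, dK, dV in FP8 plus a dictionary whose amax_dQ/amax_dK/amax_dV/amax_dP entries carry the amax values produced by the backward kernel.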