From 23d7be2401482154b00545638b95ff34160f0617 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 10 Nov 2025 08:05:46 -0800 Subject: [PATCH 01/13] shardy + quantize_layout rework Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 160 +++++++++--------- transformer_engine/jax/cpp_extensions/misc.py | 8 +- .../jax/cpp_extensions/normalization.py | 102 +++++------ .../jax/cpp_extensions/quantization.py | 86 +++++----- transformer_engine/jax/csrc/extensions.h | 6 +- .../jax/csrc/extensions/activation.cpp | 45 ++--- transformer_engine/jax/csrc/extensions/misc.h | 14 +- .../jax/csrc/extensions/normalization.cpp | 29 ++-- .../jax/csrc/extensions/pybind.cpp | 9 +- .../jax/csrc/extensions/quantization.cpp | 40 ++--- transformer_engine/jax/quantize/__init__.py | 1 + transformer_engine/jax/quantize/misc.py | 64 +++++++ transformer_engine/jax/quantize/quantizer.py | 53 ++---- .../jax/quantize/scaling_modes.py | 133 +++++++++++---- transformer_engine/jax/quantize/tensor.py | 33 ++-- 15 files changed, 442 insertions(+), 341 deletions(-) create mode 100644 transformer_engine/jax/quantize/misc.py diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index bb3c56bcf1..2ec13f915e 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -10,7 +10,7 @@ import jax import jax.numpy as jnp from jax import dtypes, ffi -from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.sharding import PartitionSpec import numpy as np @@ -159,7 +159,7 @@ class ActLuPrimitive(BasePrimitive): 11, 12, 13, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer + ) # out_dtype, act_enum, act_len, scaling_mode, quantize_layout, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer inner_primitive = None outer_primitive = None @@ -173,7 +173,7 @@ def abstract( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -210,7 +210,7 @@ def abstract( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1) - if not is_2x: + if quantize_layout.is_rowwise_only: out_shape = (1,) colwise_scale_inv_shape = (1,) colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) @@ -232,7 +232,7 @@ def lowering( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -259,7 +259,7 @@ def lowering( amax, act_enum=act_enum, scaling_mode=scaling_mode.value, - is_2x=is_2x, + quantize_layout=quantize_layout.value.value, act_params=act_params.to_ffi_lowering_dict(), output_amax_when_no_scaling=output_amax_when_no_scaling, ) @@ -274,7 +274,7 @@ def impl( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -297,7 +297,7 @@ def impl( act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, act_params=act_params, amax_scope=amax_scope, @@ -313,7 +313,7 @@ def impl( scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv = 
jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -329,7 +329,7 @@ def batcher( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -356,7 +356,7 @@ def batcher( act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, act_params=act_params, amax_scope=amax_scope, @@ -373,7 +373,7 @@ def infer_sharding_from_operands( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -402,7 +402,7 @@ def infer_sharding_from_operands( out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: @@ -419,7 +419,7 @@ def infer_sharding_from_operands( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -444,7 +444,7 @@ def partition( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -462,7 +462,7 @@ def partition( out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: @@ -479,7 +479,7 @@ def partition( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -514,7 +514,7 @@ def sharded_impl(x, scale, amax): act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, act_params=act_params, amax_scope=amax_scope, @@ -550,7 +550,7 @@ def shardy_sharding_rule( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -574,37 +574,28 @@ def shardy_sharding_rule( mesh, result_types, ) - prefix = "ActLu_" + prefix = "ActLu" input_shape = value_types[0].shape output_shape = input_shape[:-2] + input_shape[-1:] # Here we pass len of output so that the scales are propagated correctly scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - output_shape, unique_var=prefix + "x", flatten_axis=-1 + output_shape, unique_var=prefix, flatten_axis=-1, q_layout=quantize_layout ) - x_axes = scale_rules.input_spec - # Correct input spec with act dim - x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:] - out = scale_rules.input_spec - - colwise_out = (prefix + "out_colwise",) - colwise_scale_inv = (prefix + "scale_inv_colwise",) - if is_2x: - colwise_scale_inv = scale_rules.colwise_rule - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = multidim_transpose(out, transpose_axis=-1) - else: - colwise_out = out - colwise_scale_inv = scale_rules.colwise_rule - - amax = (prefix + "amax",) + # Correct the input spec with act dim + input_spec = scale_rules.input_spec + input_spec = input_spec[:-1] + (prefix + "_act_dim",) + input_spec[-1:] + amax = (BATCHING + prefix + "_amax",) + 
scale = (BATCHING + prefix + "_scale",) return SdyShardingRule( + (tuple(input_spec), scale, amax), ( - x_axes, - ("…1",), + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, amax, ), - (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), **scale_rules.factor_sizes, ) @@ -612,7 +603,6 @@ def shardy_sharding_rule( register_primitive(ActLuPrimitive) -# TODO(Jeremy): replace is_2x with q_layout class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): """ DActLu DBias Cast Transpose Primitive @@ -620,7 +610,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer + # out_dtype, scaling_mode, quantize_layout, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -634,7 +624,7 @@ def abstract( *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -678,7 +668,7 @@ def abstract( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) - if is_2x: + if quantize_layout.is_rowwise_colwise: if ScalingMode(scaling_mode).is_tensor_scaling(): colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: @@ -700,7 +690,7 @@ def abstract( jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), scaling_mode, - is_2x, + quantize_layout.value, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -741,7 +731,7 @@ def lowering( *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -777,7 +767,7 @@ def lowering( scale, amax, scaling_mode=scaling_mode.value, - is_2x=is_2x, + quantize_layout=quantize_layout.value.value, is_dbias=is_dbias, act_enum=int(act_enum), act_params=act_params.to_ffi_lowering_dict(), @@ -792,7 +782,7 @@ def impl( amax, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -816,7 +806,7 @@ def impl( amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, @@ -835,7 +825,7 @@ def impl( scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -848,7 +838,7 @@ def batcher( *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -883,7 +873,7 @@ def batcher( amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, @@ -901,7 +891,7 @@ def batcher( def infer_sharding_from_operands( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -928,7 +918,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" ) - if is_2x: + if 
quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: @@ -954,7 +944,7 @@ def infer_sharding_from_operands( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -981,7 +971,7 @@ def infer_sharding_from_operands( def partition( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -1003,7 +993,7 @@ def partition( mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" ) - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: @@ -1029,7 +1019,7 @@ def partition( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -1066,7 +1056,7 @@ def sharded_impl(dz, x, scale, amax): amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, @@ -1102,7 +1092,7 @@ def sharded_impl(dz, x, scale, amax): def shardy_sharding_rule( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -1132,28 +1122,30 @@ def shardy_sharding_rule( ) prefix = "DActLuDBias_" + # get sharding rules base on the input shape scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 + value_types[1].shape, + unique_var=prefix, + flatten_axis=-2, + q_layout=quantize_layout, ) - x_axes = scale_rules.input_spec - dz_axes = (*x_axes[:-2], x_axes[-1]) - out = x_axes - - colwise_out = (prefix + "out_colwise",) - colwise_scale_inv = (prefix + "scale_inv_colwise",) - if is_2x: - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) - else: - colwise_out = out - colwise_scale_inv = scale_rules.colwise_rule - dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) - amax = (prefix + "amax",) + input_spec = scale_rules.input_spec + dz_spec = (*input_spec[:-2], input_spec[-1]) + dbias = input_spec[-2:] if is_dbias else (prefix + "_dbias",) + amax = (prefix + "_amax",) + scale = (prefix + "_scale",) return SdyShardingRule( - (dz_axes, x_axes, ("…2",), amax), - (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + (tuple(dz_spec), tuple(input_spec), scale, amax), + ( + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, + amax, + dbias, + ), **scale_rules.factor_sizes, ) @@ -1269,7 +1261,7 @@ def act_lu( return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + if quantizer is not None and quantizer.q_layout.is_colwise_only: return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( @@ -1298,7 +1290,7 @@ def act_lu( act_enum=act_type_id, act_len=act_len, scaling_mode=ScalingMode.NO_SCALING.value, - 
is_2x=False, + quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, act_params=act_params, amax_scope=amax_scope, @@ -1354,7 +1346,7 @@ def act_lu( act_enum=act_type_id, act_len=act_len, scaling_mode=quantizer.scaling_mode.value, - is_2x=quantizer.is_2x2x(), + quantize_layout=quantizer.q_layout, scale_dtype=quantizer.get_scale_dtype(), act_params=act_params, amax_scope=amax_scope, @@ -1415,7 +1407,7 @@ def quantize_dact_dbias( act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive if not PrimitiveClass.enabled() or ( - quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE + quantizer is not None and quantizer.q_layout.is_colwise_only ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: @@ -1428,7 +1420,7 @@ def quantize_dact_dbias( out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, # unused + quantize_layout=QuantizeLayout.ROWWISE, # unused scale_dtype=jnp.float32, # unused is_dbias=False, act_enum=act_type_id, @@ -1555,7 +1547,7 @@ def quantize_dact_dbias( amax, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=quantizer.is_2x2x(), + quantize_layout=quantizer.q_layout, scale_dtype=quantizer.get_scale_dtype(), is_dbias=is_dbias, act_enum=act_type_id, @@ -1568,7 +1560,7 @@ def quantize_dact_dbias( ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise: colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 572d82f18d..f15fe72bad 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant break # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, # but this fails when bias fusion is turned on with arch < 100. 
- force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() + force_1x_quantization = ( + quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise + ) return ( (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) and arch_l_100 @@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, @return: the output of 'f' with the colwise output calculated """ should_apply_war = ( - quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() + quantizer is not None + and quantizer.scaling_mode.is_tensor_scaling() + and quantizer.q_layout.is_rowwise_colwise ) if not should_apply_war: return None diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index d09ce7ef74..c6bc053073 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -11,7 +11,7 @@ import jax import jax.numpy as jnp from jax import dtypes, ffi -from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec @@ -112,7 +112,7 @@ def abstract( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -165,7 +165,7 @@ def abstract( updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - colwise_out_shape = x_aval.shape if is_2x else (1,) + colwise_out_shape = x_aval.shape if quantize_layout.has_colwise else (1,) colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -173,7 +173,7 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) + colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,) colwise_scale_inv_aval = jax.core.ShapedArray( shape=colwise_scale_inv_shape, dtype=scale_dtype ) @@ -189,7 +189,7 @@ def abstract( zero_centered_gamma, epsilon, get_forward_sm_margin(), - is_2x, + True, # is_training ) wkspace_aval = jax.core.ShapedArray( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -245,7 +245,7 @@ def lowering( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -287,7 +287,7 @@ def lowering( epsilon=epsilon, sm_margin=sm_margin, scaling_mode=scaling_mode.value, - is_2x=is_2x, + quantize_layout=quantize_layout.value.value, output_amax_when_no_scaling=output_amax_when_no_scaling, ) @@ -303,7 +303,7 @@ def impl( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -335,7 +335,7 @@ def impl( epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -349,7 +349,7 @@ def impl( scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.has_colwise: colwise_scale_inv = colwise_scale_inv.flatten()[ : reduce(operator.mul, 
colwise_scale_inv_shape, 1) ].reshape(colwise_scale_inv_shape) @@ -373,7 +373,7 @@ def batcher( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -409,7 +409,7 @@ def batcher( epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -426,7 +426,7 @@ def infer_sharding_from_operands( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -450,7 +450,7 @@ def infer_sharding_from_operands( ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") - colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,) colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) @@ -488,7 +488,7 @@ def partition( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -524,7 +524,7 @@ def partition( ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") - colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,) colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) @@ -586,7 +586,7 @@ def sharded_impl(x, scale, amax, gamma, beta): epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -623,7 +623,7 @@ def shardy_sharding_rule( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -646,25 +646,29 @@ def shardy_sharding_rule( result_types, ) - prefix = "NormFwd_" + prefix = "NormFwd" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1 + value_types[0].shape, + unique_var=prefix, + flatten_axis=-1, + q_layout=quantize_layout, ) - x_axes = scale_rules.input_spec + input_spec = scale_rules.input_spec - out = x_axes - colwise_out = out if is_2x else (prefix + "out_colwise",) - rsigma = x_axes[:-1] - mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma - amax = (prefix + "amax",) + rsigma = input_spec[:-1] + mu = (BATCHING + prefix + "_mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + amax = (BATCHING + prefix + "_amax",) + scale = (BATCHING + prefix + "_scale",) + gamma = (BATCHING + prefix + "_gamma",) + beta = (BATCHING + prefix + "_beta",) return SdyShardingRule( - (x_axes, ("…1",), amax, ("…2",), ("…3",)), + (tuple(input_spec), scale, amax, gamma, beta), ( - out, - colwise_out, - scale_rules.rowwise_rule, - scale_rules.colwise_rule, + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, amax, mu, rsigma, @@ -987,7 +991,7 @@ def layernorm_fwd( return (output, mu, rsigma) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + if quantizer is not None and quantizer.q_layout.is_colwise_only: return _jax_layernorm(x, gamma, beta, zero_centered_gamma, 
epsilon, quantizer) scale = ( @@ -1008,7 +1012,7 @@ def layernorm_fwd( epsilon=epsilon, out_dtype=x.dtype, scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, + quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, amax_scope=amax_scope, transpose_batch_sequence=False, @@ -1067,10 +1071,11 @@ def layernorm_fwd( ) return out, mu, rsigma - is_2x2x = quantizer.is_2x2x() - # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): - is_2x2x = False + # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose + q_layout = quantizer.q_layout + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): + q_layout = QuantizeLayout.ROWWISE + ( rowwise_casted_output, colwise_casted_output, @@ -1090,7 +1095,7 @@ def layernorm_fwd( epsilon=epsilon, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=is_2x2x, + quantize_layout=q_layout, scale_dtype=quantizer.get_scale_dtype(), amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -1099,8 +1104,7 @@ def layernorm_fwd( ) quantizer.update(updated_amax) - # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1238,7 +1242,7 @@ def rmsnorm_fwd( return (output, rsigma) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + if quantizer is not None and quantizer.q_layout.is_colwise_only: return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -1261,7 +1265,7 @@ def rmsnorm_fwd( epsilon=epsilon, out_dtype=x.dtype, scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, + quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -1321,10 +1325,11 @@ def rmsnorm_fwd( ) return out, rsigma - is_2x2x = quantizer.is_2x2x() - # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): - is_2x2x = False + # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose + q_layout = quantizer.q_layout + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): + q_layout = QuantizeLayout.ROWWISE + ( rowwise_casted_output, colwise_casted_output, @@ -1344,7 +1349,7 @@ def rmsnorm_fwd( epsilon=epsilon, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=is_2x2x, + quantize_layout=q_layout, scale_dtype=quantizer.get_scale_dtype(), amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -1353,8 +1358,7 @@ def rmsnorm_fwd( ) quantizer.update(updated_amax) - # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py 
b/transformer_engine/jax/cpp_extensions/quantization.py index 67c505bc98..424917a30f 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -11,7 +11,7 @@ import jax import jax.numpy as jnp from jax import dtypes, ffi -from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.sharding import PartitionSpec import transformer_engine_jax @@ -122,7 +122,7 @@ def abstract( f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" ) - if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if QuantizeLayout(q_layout).has_rowwise: rowwise_out_shape = out_shape else: rowwise_out_shape = (1,) @@ -170,7 +170,7 @@ def abstract( broadcast_2d_scale_shape_to_1d=True, ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if QuantizeLayout(q_layout).has_colwise: if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: @@ -194,9 +194,7 @@ def abstract( jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(scale_dtype), scaling_mode, - QuantizeLayout( - q_layout - ), # For now until we have auto-decoding for QuantizeLayout enum + q_layout.value, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -272,7 +270,7 @@ def lowering( post_rht_amax, rht_matrix, scaling_mode=scaling_mode.value, - q_layout=q_layout, + q_layout=q_layout.value.value, flatten_axis=flatten_axis, is_dbias=is_dbias, stochastic_rounding=stochastic_rounding, @@ -335,7 +333,7 @@ def impl( scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -424,7 +422,7 @@ def infer_sharding_from_operands( PartitionSpec(*x_spec), desc="BaseDBiasQuantizePrimitive.out_sharding", ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: @@ -448,7 +446,7 @@ def infer_sharding_from_operands( if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: if ( ScalingMode(scaling_mode).is_block_scaling and ScalingMode(scaling_mode).is_colwise_transposed @@ -505,7 +503,7 @@ def partition( desc="BaseDBiasQuantizePrimitive.out_sharding", ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: @@ -529,7 +527,7 @@ def partition( if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: if ( ScalingMode(scaling_mode).is_block_scaling and ScalingMode(scaling_mode).is_colwise_transposed @@ -643,39 +641,37 @@ def shardy_sharding_rule( result_types, ) - prefix = "DBiasQuantize_" + prefix = "DBiasQuantize" scale_rules 
= ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[0].shape, - unique_var=prefix + "x", + unique_var=prefix, flatten_axis=flatten_axis, + q_layout=q_layout, broadcast_2d_scale_shape_to_1d=True, ) - x_axes = scale_rules.input_spec - - out = x_axes - colwise_out = (prefix + "out_colwise",) - colwise_scale_inv = (prefix + "colwise_scale_inv",) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - colwise_scale_inv = scale_rules.colwise_rule - if ScalingMode(scaling_mode).is_colwise_transposed: - colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) - colwise_scale_inv = tuple( - multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis) - ) - else: - colwise_out = x_axes - - dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",) - amax = (prefix + "amax",) - sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis") + input_spec = scale_rules.input_spec + dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",) + amax = (BATCHING + prefix + "_amax",) + scale = (BATCHING + prefix + "_scale",) + sr_rng_state = ( + BATCHING + prefix + "_sr_rng_state_partition_axis", + BATCHING + prefix + "sr_rng_state_data_axis", + ) - post_rht_amax = (prefix + "post_rht_amax",) - rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2") + post_rht_amax = (BATCHING + prefix + "_post_rht_amax",) + rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2") return SdyShardingRule( - (x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix), - (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + (tuple(input_spec), scale, amax, sr_rng_state, post_rht_amax, rht_matrix), + ( + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, + amax, + dbias, + ), **scale_rules.factor_sizes, ) @@ -762,7 +758,7 @@ def _quantize_dbias_impl( # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # fall back on the native-JAX quantize implementation PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not ( + is_unsupported = quantizer.q_layout.is_colwise_only and not ( quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING and hasattr(quantizer, "use_rht") and quantizer.use_rht @@ -845,7 +841,7 @@ def _quantize_dbias_impl( is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( quantizer.scaling_mode.is_tensor_scaling() - and quantizer.is_2x2x() + and quantizer.q_layout.is_rowwise_colwise and is_1x_kernel_supported ) q_layout = quantizer.q_layout @@ -879,7 +875,7 @@ def _quantize_dbias_impl( rht_matrix, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_layout=q_layout.value, + q_layout=q_layout, flatten_axis=flatten_axis, scale_dtype=quantizer.get_scale_dtype(), is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, @@ -888,10 +884,10 @@ def _quantize_dbias_impl( use_rht=use_rht, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise: colwise_scale_inv = rowwise_scale_inv - if q_layout == QuantizeLayout.ROWWISE: + if q_layout.is_rowwise_only: # 
Quantizer requires 2x quantization, but we are using 1x quantization # for performance reasons, so we need to generate the colwise data in JAX if flatten_axis < 0: @@ -1043,7 +1039,7 @@ def abstract( flatten_axis=flatten_axis, ) - if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_rowwise: rowwise_out_shape = out_shape else: rowwise_out_shape = (1,) @@ -1052,7 +1048,7 @@ def abstract( amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: colwise_out_shape = out_shape else: colwise_out_shape = (1,) @@ -1240,7 +1236,7 @@ def grouped_quantize( ) # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet # So we performance ROWWISE_COLWISE and use the colwise_tensor_output - apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE + apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout ( rowwise_casted_output, @@ -1262,7 +1258,7 @@ def grouped_quantize( # For DelayedScaling2x and CurrentScaling2x, the scale buffer # is shared between rowwise and colwise - if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war: + if is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise or apply_colwise_war: colwise_scale_inv = rowwise_scale_inv # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead? diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 87c6fa91cd..c1c7e0d665 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -57,7 +57,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - JAXX_Scaling_Mode scaling_mode, bool is_2x); + JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout); // Normalization XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler); @@ -87,7 +88,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, - QuantizeLayout q_layout); + JAXX_Quantize_Layout quantize_layout); // Softmax XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); @@ -162,5 +163,6 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout); #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index f512321c38..34ce29ae13 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,7 +18,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type 
colwise_scale_inv_buf, Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) { + JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params, + bool output_amax_when_no_scaling) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; @@ -40,7 +41,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto n = input_dims.back(); auto act_type = static_cast(act_enum); auto act_len = input_dims[input_dims.size() - 2]; - auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis auto input_shape = std::vector{m, static_cast(act_len * n)}; @@ -77,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal } } - if (is_2x) { + if (is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; @@ -158,7 +158,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("act_params") .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); @@ -167,11 +167,12 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, - int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int, - ActivationConfig act_params, bool output_amax_when_no_scaling) { + int64_t act_enum, JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params, + bool output_amax_when_no_scaling) { return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, - updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params, + updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params, output_amax_when_no_scaling); } @@ -188,13 +189,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("act_params") .Attr("output_amax_when_no_scaling")); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - JAXX_Scaling_Mode scaling_mode, bool is_2x) { + JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; @@ -226,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid std::vector{1}); } - if (is_2x) { + if (is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? 
output_trans_shape : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); @@ -260,9 +262,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, - bool is_dbias, ActivationConfig act_params, - bool output_amax_when_no_scaling) { + JAXX_Scaling_Mode scaling_mode, int64_t act_enum, + JAXX_Quantize_Layout quantize_layout, bool is_dbias, + ActivationConfig act_params, bool output_amax_when_no_scaling) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; @@ -340,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, } } - if (is_2x) { + if (is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; @@ -370,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && + is_quantize_2x2x(quantize_layout) && act_len == 2), "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { @@ -465,7 +468,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Ret() // wkspace .Attr("scaling_mode") .Attr("act_enum") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("is_dbias") .Attr("act_params") .Attr("output_amax_when_no_scaling"), @@ -476,13 +479,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI( Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params, - bool output_amax_when_no_scaling) { + int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias, + ActivationConfig act_params, bool output_amax_when_no_scaling) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params, - output_amax_when_no_scaling); + workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias, + act_params, output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, @@ -502,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, .Ret() // wkspace .Attr("scaling_mode") .Attr("act_enum") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("is_dbias") .Attr("act_params") .Attr("output_amax_when_no_scaling")); diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 07e9aec7e9..21b50c1af4 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h 
+++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -34,12 +34,24 @@ inline size_t product(const std::vector &shape) { return ret; } -enum class QuantizeLayout { +enum class JAXX_Quantize_Layout : int64_t { ROWWISE, COLWISE, ROWWISE_COLWISE, }; +inline bool is_quantize_rowwise(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::ROWWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + +inline bool is_quantize_colwise(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::COLWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + +inline bool is_quantize_2x2x(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + enum class JAXX_Scaling_Mode : int64_t { NO_SCALING = 0, DELAYED_TENSOR_SCALING = 1, diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 378e009c83..b01e23c128 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -66,7 +66,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, - bool is_2x, bool output_amax_when_no_scaling) { + JAXX_Quantize_Layout quantize_layout, bool output_amax_when_no_scaling) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -86,7 +86,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); auto _norm_type = static_cast(norm_type); - auto _is_2x = static_cast(is_2x); auto x_size = product(x_buf.dimensions()); auto gamma_size = product(gamma_buf.dimensions()); @@ -134,7 +133,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); } - if (_is_2x) { + if (is_quantize_2x2x(quantize_layout)) { output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(), static_cast(out_dtype), input_shape); output_tensor.set_columnwise_scale_inv( @@ -185,25 +184,23 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("epsilon") .Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); -Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, - Buffer_Type amax_buf, Buffer_Type gamma_buf, - Buffer_Type beta_buf, Result_Type output_buf, - Result_Type colwise_output_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, - Result_Type mu_buf, Result_Type rsigma_buf, - Result_Type wkspace_buf, int norm_type, - bool zero_centered_gamma, double epsilon, int64_t sm_margin, - JAXX_Scaling_Mode scaling_mode, bool is_2x, - bool output_amax_when_no_scaling) { +Error_Type NormForwardInitializeFFI( + cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Buffer_Type amax_buf, + Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, + Result_Type colwise_output_buf, 
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, + int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout, + bool output_amax_when_no_scaling) { return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf, gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf, wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, - scaling_mode, is_2x, output_amax_when_no_scaling); + scaling_mode, quantize_layout, output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, @@ -227,7 +224,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ .Attr("epsilon") .Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("output_amax_when_no_scaling")); pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index d740df0e2a..e57d07872e 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -176,11 +176,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) .export_values(); - pybind11::enum_(m, "QuantizeLayout", - pybind11::module_local()) - .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) - .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) - .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) + pybind11::enum_(m, "JAXX_Quantize_Layout", pybind11::module_local()) + .value("ROWWISE", JAXX_Quantize_Layout::ROWWISE) + .value("COLWISE", JAXX_Quantize_Layout::COLWISE) + .value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE) .export_values(); pybind11::enum_(m, "JAXX_Collective_Op", pybind11::module_local()) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index a45a698822..1f7db84383 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -20,7 +20,7 @@ namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, - QuantizeLayout q_layout) { + JAXX_Quantize_Layout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; @@ -42,7 +42,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto scale_shape = std::vector{1}; // Only the pointers will be checked for scale_inv, thus the shapes do not matter - if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { + if (is_quantize_rowwise(q_layout)) { output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { if (is_nvfp4) @@ -52,7 +52,7 @@ pybind11::tuple 
GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ } } - if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { + if (is_quantize_colwise(q_layout)) { auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); @@ -90,8 +90,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis, - bool stochastic_rounding, bool use_rht) { + JAXX_Quantize_Layout quantize_layout, bool is_dbias, + int64_t flatten_axis, bool stochastic_rounding, bool use_rht) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -101,8 +101,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); - auto const quantize_layout = static_cast(quantize_layout_enum); - auto *output = output_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data(); @@ -127,15 +125,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; - bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4."); NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling"); - if (quantize_layout == QuantizeLayout::ROWWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + if (is_quantize_rowwise(quantize_layout)) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_tensor_scaling) { @@ -180,10 +176,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T quant_config.set_rng_state(sr_rng_state_tensor.data()); } - if (quantize_layout == QuantizeLayout::COLWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + if (is_quantize_colwise(quantize_layout)) { if (is_nvfp4 && use_rht) { - if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + if (is_quantize_2x2x(quantize_layout)) { // Do regular rowwise quantization without RHT nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); } @@ -281,7 +276,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // dbias .Ret() // wkspace .Attr("scaling_mode") - .Attr("q_layout") + .Attr("q_layout") .Attr("is_dbias") .Attr("flatten_axis") .Attr("stochastic_rounding") @@ -323,7 +318,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty Buffer_Type group_sizes, Result_Type outputs, Result_Type colwise_outputs, Result_Type scale_invs, Result_Type colwise_scale_invs, Result_Type amaxs, - JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + JAXX_Scaling_Mode scaling_mode, 
JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast(scaling_mode)); @@ -336,7 +331,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type()); auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type()); auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type()); - auto const quantize_layout = static_cast(quantize_layout_enum); auto *input_ptr = reinterpret_cast(inputs.untyped_data()); auto *scale_ptr = reinterpret_cast(scales.untyped_data()); @@ -346,10 +340,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto *colwise_sinv_ptr = reinterpret_cast(colwise_scale_invs->untyped_data()); auto *amax_ptr = reinterpret_cast(amaxs->untyped_data()); - bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE; - bool has_colwise = quantize_layout == QuantizeLayout::COLWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE; bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING; bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; @@ -359,8 +349,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t output_dtype_bytes = te_dtype_bytes(out_dtype); size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype); size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype); - size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0; - size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0; + size_t colwise_output_dtype_bytes = is_quantize_colwise(quantize_layout) ? output_dtype_bytes : 0; + size_t colwise_sinv_dtype_bytes = is_quantize_colwise(quantize_layout) ? sinv_dtype_bytes : 0; size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0; size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0; @@ -423,7 +413,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto inp_i = TensorWrapper(static_cast(input_ptr), shape_i, in_dtype); auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - if (has_rowwise) { + if (is_quantize_rowwise(quantize_layout)) { out_i.set_rowwise_data(static_cast(output_ptr), out_dtype, shape_i); if (is_fp8_dtype(out_dtype)) { @@ -442,7 +432,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty } } - if (has_colwise) { + if (is_quantize_colwise(quantize_layout)) { auto &tmp_shape = is_tensor_scaling ? 
shape_trans_i : shape_i; out_i.set_columnwise_data(static_cast(colwise_output_ptr), out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling @@ -501,7 +491,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // scale_inv colwise .Ret() // amax .Attr("scaling_mode") - .Attr("q_layout") + .Attr("q_layout") .Attr("flatten_axis")); } // namespace jax diff --git a/transformer_engine/jax/quantize/__init__.py b/transformer_engine/jax/quantize/__init__.py index 9616965c75..878067a783 100644 --- a/transformer_engine/jax/quantize/__init__.py +++ b/transformer_engine/jax/quantize/__init__.py @@ -17,3 +17,4 @@ from .hadamard import * from .helper import * from .device_utils import * +from .misc import * diff --git a/transformer_engine/jax/quantize/misc.py b/transformer_engine/jax/quantize/misc.py new file mode 100644 index 0000000000..21af4813d8 --- /dev/null +++ b/transformer_engine/jax/quantize/misc.py @@ -0,0 +1,64 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +This module provides additional enum and utilities for quantizing tensors in JAX. +""" +# from jax.tree_util import register_pytree_node_class +# from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum + +from transformer_engine_jax import JAXX_Quantize_Layout + +__all__ = [ + "QuantizeLayout", +] + + +# @register_pytree_node_class +@dataclass(frozen=True) +class QuantizeLayout(Enum): + "Wrapper for JAXX_Quantize_Layout" + + ROWWISE = JAXX_Quantize_Layout.ROWWISE + COLWISE = JAXX_Quantize_Layout.COLWISE + ROWWISE_COLWISE = JAXX_Quantize_Layout.ROWWISE_COLWISE + + @property + def has_rowwise(self) -> bool: + """If the layout has the rowwise component""" + return self.value in (JAXX_Quantize_Layout.ROWWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE) + + @property + def has_colwise(self) -> bool: + """If the layout has the colwise component""" + return self.value in (JAXX_Quantize_Layout.COLWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE) + + @property + def is_rowwise_colwise(self) -> bool: + """If layout is both rowwise and colwise""" + return self.value == JAXX_Quantize_Layout.ROWWISE_COLWISE + + @property + def is_rowwise_only(self) -> bool: + """If layout is rowwise only""" + return self.value == JAXX_Quantize_Layout.ROWWISE + + @property + def is_colwise_only(self) -> bool: + """If layout is colwise only""" + return self.value == JAXX_Quantize_Layout.COLWISE + + def __eq__(self, other): + """Compare this quantize layout with another. 
+ + Args: + other: The other quantize layout to compare with + + Returns: + True if the modes are equal, False otherwise + """ + if not isinstance(other, QuantizeLayout): + return False + return self.value == other.value diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index eb2b7b5924..705c463de0 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -15,10 +15,10 @@ import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeLayout from transformer_engine.common import recipe from .scaling_modes import ScalingMode +from .misc import QuantizeLayout from .hadamard import apply_rht from .tensor import ( ScaledTensor, @@ -37,7 +37,6 @@ from ..sharding import get_num_devices_in_mesh __all__ = [ - "QuantizeLayout", "Quantizer", "QuantizerSet", "CurrentScaleQuantizer", @@ -118,14 +117,6 @@ def update(self, *args, **kwargs): """Update quantizer state (no-op in base class).""" del args, kwargs - def is_2x2x(self) -> bool: - """Check if quantizer uses both row-wise and column-wise quantization. - - Returns: - True if using both row-wise and column-wise quantization - """ - return self.q_layout == QuantizeLayout.ROWWISE_COLWISE - def get_data_layout(self) -> str: """Get the data data_layout string. @@ -135,11 +126,11 @@ def get_data_layout(self) -> str: Raises: ValueError: If quantization axis is invalid """ - if self.q_layout == QuantizeLayout.ROWWISE_COLWISE: + if self.q_layout.is_rowwise_colwise: return self.data_layout - if self.q_layout == QuantizeLayout.ROWWISE: + if self.q_layout.is_rowwise_only: return self.data_layout[0] - if self.q_layout == QuantizeLayout.COLWISE: + if self.q_layout.is_colwise_only: return self.data_layout[1] raise ValueError(f"Invalid q_layout: {self.q_layout}") @@ -174,16 +165,8 @@ def quantize( """ del kwargs - is_rowwise = ( - is_rowwise - if is_rowwise is not None - else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) - ) - is_colwise = ( - is_colwise - if is_colwise is not None - else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) - ) + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise if (is_rowwise and is_colwise) or self.is_2x2x(): rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) @@ -299,16 +282,8 @@ def quantize( flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" - is_rowwise = ( - is_rowwise - if is_rowwise is not None - else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) - ) - is_colwise = ( - is_colwise - if is_colwise is not None - else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) - ) + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = None @@ -974,16 +949,8 @@ def quantize( flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" 
- is_rowwise = ( - is_rowwise - if is_rowwise is not None - else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) - ) - is_colwise = ( - is_colwise - if is_colwise is not None - else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) - ) + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise assert is_rowwise or is_colwise, "No quantization layout is specified" original_shape = x.shape diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index d490e02752..5df7103782 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -21,7 +21,8 @@ from jax.tree_util import register_pytree_node_class import jax.numpy as jnp -from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout +from transformer_engine_jax import JAXX_Scaling_Mode +from .misc import QuantizeLayout from .device_utils import is_fp8_gemm_with_all_layouts_supported @@ -72,16 +73,18 @@ class QuantizeShardyRules: Attributes: input_spec: Specification for the input axes - rowwise_rule: Sharding rule for the row-wise scale tensor, depends on - the axes in `input_spec` - colwise_rule: Likewise for the column-wise scale tensor. - factor_sizes: For block scaling, contains the block size factor, which is - used in `input_spec`. + rowwise_out_spec: Sharding spec for the rowwise quantized data + rowwise_scale_spec: Sharding spec for the rowwise scale + colwise_out_spec: Sharding spec for the colwise quantized data + colwise_scale_spec: Sharding spec for the colwise scale + factor_sizes: For block scaling, contains the block size factor """ input_spec: Tuple[str] - rowwise_rule: Tuple[str] - colwise_rule: Tuple[str] + rowwise_out_spec: Tuple[str] + rowwise_scale_spec: Tuple[str] + colwise_out_spec: Tuple[str] + colwise_scale_spec: Tuple[str] factor_sizes: Dict[str, int] @@ -166,7 +169,9 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -174,7 +179,9 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization + q_layout: The layout of the quantized tensor broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. + is_colwise_transposed: Whether the column-wise tensors are transposed. Returns: The Shardy rules for the scaling mode @@ -268,7 +275,9 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. 
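For reference, the reworked rule set above can be exercised in isolation. The sketch below is illustrative only: `QuantizeShardyRulesSketch`, `_transpose_spec`, and the plain-string `BATCHING` marker are local stand-ins for the patch's `QuantizeShardyRules`, `multidim_transpose`, and the `BATCHING` constant imported from `jax.experimental.custom_partitioning`. It shows how the rowwise/colwise output and scale specs could be derived for plain tensor scaling when the colwise copy is stored transposed, following the `jnp.transpose(x, (*range(flatten_axis, ndim), *range(flatten_axis)))` convention used elsewhere in this series.

    # Minimal standalone sketch, not part of the patch. Names suffixed with
    # "Sketch" and the "*" BATCHING marker are assumptions for illustration.
    from dataclasses import dataclass
    from typing import Dict, Tuple

    BATCHING = "*"  # placeholder for the Shardy batching prefix used in the patch


    @dataclass
    class QuantizeShardyRulesSketch:
        input_spec: Tuple[str, ...]
        rowwise_out_spec: Tuple[str, ...]
        rowwise_scale_spec: Tuple[str, ...]
        colwise_out_spec: Tuple[str, ...]
        colwise_scale_spec: Tuple[str, ...]
        factor_sizes: Dict[str, int]


    def _transpose_spec(spec, flatten_axis):
        # Stand-in for multidim_transpose: move the trailing dims (from
        # flatten_axis onward) to the front, matching the data-side transpose
        # (*range(flatten_axis, ndim), *range(flatten_axis)).
        return spec[flatten_axis:] + spec[:flatten_axis]


    def tensor_scaling_rules(ndim, unique_var, flatten_axis, has_colwise, is_colwise_transposed):
        # Rowwise output shares the input axes; the scale is a scalar tagged as a
        # batching-only factor, as in the tensor-scaling implementations above.
        input_spec = tuple(f"{unique_var}x_{i}" for i in range(ndim))
        colwise_out = (BATCHING + f"{unique_var}_colwise_output",)
        if has_colwise:
            colwise_out = (
                _transpose_spec(input_spec, flatten_axis)
                if is_colwise_transposed
                else input_spec
            )
        scale = (BATCHING + unique_var + "_scale_inv",)
        return QuantizeShardyRulesSketch(input_spec, input_spec, scale, colwise_out, scale, {})


    # Example: a 3D activation quantized 2x with the colwise copy stored transposed.
    rules = tensor_scaling_rules(3, "q", flatten_axis=2, has_colwise=True, is_colwise_transposed=True)
    assert rules.colwise_out_spec == ("qx_2", "qx_0", "qx_1")

The point of the five separate spec fields is that the primitive's `shardy_sharding_rule` can now hand each output (rowwise data, colwise data, and their scales) its own factor tuple instead of reusing a single rowwise/colwise rule pair, which is what the callers in `normalization.py` and `quantization.py` rely on later in this series.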
@@ -281,10 +290,17 @@ def get_shardy_sharding_rules( Returns: The Shardy rules for the scaling mode """ - del flatten_axis, broadcast_2d_scale_shape_to_1d - input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) - scale_var = BATCHING + unique_var + "_scale_inv" - return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + del broadcast_2d_scale_shape_to_1d + input_spec = tuple(f"{unique_var}_x_{i}" for i in range(len(input_shape))) + output_spec = tuple(input_spec) + return QuantizeShardyRules( + input_spec, + output_spec, + (BATCHING + f"{unique_var}_scale",), + (BATCHING + f"{unique_var}_colwise_output",), + (BATCHING + f"{unique_var}_colwise_scale",), + {}, + ) class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): @@ -376,7 +392,9 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -385,14 +403,39 @@ def get_shardy_sharding_rules( unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. - + q_layout: The layout of the quantized tensor + is_colwise_transposed: Whether the colwise scaling is transposed Returns: The Shardy rules for the scaling mode """ del flatten_axis, broadcast_2d_scale_shape_to_1d - input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) - scale_var = BATCHING + unique_var + "_scale_inv" - return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + input_spec = tuple(f"{unique_var}x_{i}" for i in range(len(input_shape))) + output_spec = input_spec.copy() + colwise_output_spec = BATCHING + f"{unique_var}_colwise_output" + + if q_layout.has_colwise: + from ..cpp_extensions.misc import multidim_transpose + + colwise_output_spec = input_spec.copy() + if is_colwise_transposed: + colwise_output_spec = multidim_transpose( + colwise_output_spec, transpose_axis=flatten_axis + ) + scale = BATCHING + unique_var + "_scale_inv" + return QuantizeShardyRules( + tuple(input_spec), + tuple(output_spec), + tuple( + scale, + ), + tuple( + colwise_output_spec, + ), + tuple( + scale, + ), + {}, + ) class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): @@ -658,7 +701,9 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -666,15 +711,18 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization + q_layout: The layout of the quantized tensor broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. - + is_colwise_transposed: Whether the column-wise tensors are transposed. 
Returns: The Shardy rules for the scaling mode """ - # TODO(Phuong): to rework the shardy rule to handle transposes after NVFP4 is upstreamed + is_rowwise = q_layout.has_rowwise + is_colwise = q_layout.has_colwise + input_rank = len(input_shape) - input_spec = [f"{unique_var}_{i}" for i in range(input_rank)] flatten_axis = (flatten_axis + input_rank) % input_rank + input_spec = [f"{unique_var}_x_{i}" for i in range(input_rank)] assert ( self._block_dims[1] != 1 @@ -690,30 +738,48 @@ def get_shardy_sharding_rules( # We have to use two different factors in the two CompoundFactors because of Shardy # verifier requirements, even though they are the same. + # No CompoundFactor is needed if the dim has the same size as the blocksize blocksizes = {} - colwise_var = f"{unique_var}_None" rowwise_var = f"{unique_var}_None" - if not input_shape[-1] == block_size_1d: + colwise_var = f"{unique_var}_None" + if is_rowwise and not input_shape[-1] == block_size_1d: rowwise_var = input_spec[-1] + "_compound" input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") blocksizes["blocksize_x"] = block_size_1d - if not input_shape[flatten_axis - 1] == block_size_1d: + if is_colwise and not input_shape[flatten_axis - 1] == block_size_1d: colwise_var = input_spec[flatten_axis - 1] + "_compound" input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") blocksizes["blocksize_y"] = block_size_1d # The rowwise and colwise scale tensors should be sharded the same way as the input. # However, we need to adjust the dimensions where the block scaling factor applies. - rowwise = input_spec.copy() - rowwise[-1] = rowwise_var + if is_rowwise: + rowwise_out = input_spec.copy() + rowwise_scale = input_spec.copy() + rowwise_scale[-1] = rowwise_var + else: + rowwise_out = f"{unique_var}_rowwise_output" + rowwise_scale = f"{unique_var}_rowwise_scale_inv" - colwise = input_spec.copy() - colwise[flatten_axis - 1] = colwise_var + if is_colwise: + from ..cpp_extensions.misc import multidim_transpose + + colwise_out = input_spec.copy() + colwise_scale = input_spec.copy() + colwise_scale[flatten_axis - 1] = colwise_var + if is_colwise_transposed: + colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis) + colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis) + else: + colwise_out = f"{unique_var}_colwise_output" + colwise_scale = f"{unique_var}_colwise_scale_inv" return QuantizeShardyRules( tuple(input_spec), - tuple(rowwise), - tuple(colwise), + tuple(rowwise_out), + tuple(rowwise_scale), + tuple(colwise_out), + tuple(colwise_scale), blocksizes, ) @@ -850,7 +916,8 @@ def get_shardy_sharding_rules( self, input_shape, unique_var, - flatten_axis=-1, + flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d=False, ) -> Tuple[Tuple[str]]: """Sharding rules for the input and (row, col)wise scale tensors. @@ -859,13 +926,19 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. + q_layout: The layout of the quantized tensor broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. 
Returns: The Shardy rules for the scaling mode """ return self._get_impl().get_shardy_sharding_rules( - input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d + input_shape, + unique_var, + flatten_axis, + q_layout, + broadcast_2d_scale_shape_to_1d, + self.is_colwise_transposed, ) def get_grouped_scale_shape_2x( diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 6c358a044e..72256944a7 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -15,10 +15,10 @@ import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode, TensorUsage from .dequantizer import ScalingModeToDequantizerMap +from .misc import QuantizeLayout from ..sharding import ( with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, ) @@ -128,9 +128,7 @@ def dequantize(self): def get_tensor(self, usage: TensorUsage): """Returns the tensor based on the tensor usage.""" q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) - assert ( - q_layout == QuantizeLayout.ROWWISE - ), "Only ROWWISE layout is supported for NoScaleTensor" + assert q_layout.is_rowwise_only, "Only ROWWISE layout is supported for NoScaleTensor" return self def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): @@ -264,8 +262,8 @@ def dequantize(self): def get_tensor(self, usage: TensorUsage): """Returns the tensor based on the tensor usage.""" q_layout = self.scaling_mode.get_quantize_layout(usage) - colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise - rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise + colwise_usage_valid = q_layout.is_colwise_only and self.is_colwise + rowwise_usage_valid = q_layout.is_rowwise_only and not self.is_colwise if colwise_usage_valid or rowwise_usage_valid: return self @@ -309,8 +307,8 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st return ScaledTensor1x( data=data, - scale_inv=scale_inv, amax=self.amax, + scale_inv=scale_inv, scaling_mode=self.scaling_mode, dq_dtype=self.dq_dtype, _dq_func=self._dq_func, @@ -467,10 +465,10 @@ def get_tensor(self, usage: TensorUsage): q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage) q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage) - if q_layout_rowwise == QuantizeLayout.ROWWISE: + if q_layout_rowwise.is_rowwise_only: return self.rowwise_tensor - if q_layout_colwise == QuantizeLayout.COLWISE: + if q_layout_colwise.is_colwise_only: return self.colwise_tensor raise ValueError( @@ -548,13 +546,13 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) if group_sizes is not None: - flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" # Handling attrs of transposed tensors - group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis + group_axis = (len(original_shape) + group_axis) % len(original_shape) if data_layout == "T": if original_shape[0] == group_sizes.size: original_shape = ( @@ -587,7 +585,7 @@ def create_1x( ) # Handling attrs of transposed tensors - flatten_axis = data.ndim + flatten_axis if 
flatten_axis < 0 else flatten_axis + flatten_axis = (data.ndim + flatten_axis) % data.ndim if data_layout == "T": flatten_axis = data.ndim - flatten_axis @@ -669,7 +667,7 @@ def create_2x( colwise_amax, scaling_mode, dq_dtype, - is_colwise=True, # TODO(Phuong): set this correctly + is_colwise=True, data_layout=data_layout[1], flatten_axis=flatten_axis, group_sizes=group_sizes, @@ -721,7 +719,7 @@ def create( """ assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet" - if q_layout == QuantizeLayout.ROWWISE_COLWISE: + if q_layout.is_rowwise_colwise: return ScaledTensorFactory.create_2x( data, scale_inv, @@ -740,15 +738,14 @@ def create( colwise_has_rht_applied=colwise_has_rht_applied, ) - is_colwise = q_layout == QuantizeLayout.COLWISE - if is_colwise: + if q_layout.is_colwise_only: return ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, colwise_amax if colwise_amax is not None else amax, scaling_mode, dq_dtype, - is_colwise=is_colwise, + is_colwise=True, data_layout=data_layout[0], flatten_axis=flatten_axis, group_sizes=group_sizes, @@ -763,7 +760,7 @@ def create( amax, scaling_mode, dq_dtype, - is_colwise=is_colwise, + is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, group_sizes=group_sizes, From 6c166c662daae7105d369be8c1b94a683bc58b03 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 10 Nov 2025 13:12:14 -0800 Subject: [PATCH 02/13] minor fix Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 705c463de0..2da3ce9aaf 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -168,7 +168,7 @@ def quantize( is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise - if (is_rowwise and is_colwise) or self.is_2x2x(): + if (is_rowwise and is_colwise) or self.q_layout.is_rowwise_colwise: rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = self._quantize_func( x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis From a1aba7f1d665c178aee21881ff7df407504d53be Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 10 Nov 2025 14:31:46 -0800 Subject: [PATCH 03/13] fix rules for fp8 Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/normalization.py | 2 +- .../jax/cpp_extensions/quantization.py | 2 +- .../jax/quantize/scaling_modes.py | 33 ++++++------------- 3 files changed, 12 insertions(+), 25 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index c6bc053073..8f4f5a47fc 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -663,7 +663,7 @@ def shardy_sharding_rule( beta = (BATCHING + prefix + "_beta",) return SdyShardingRule( - (tuple(input_spec), scale, amax, gamma, beta), + (input_spec, scale, amax, gamma, beta), ( scale_rules.rowwise_out_spec, scale_rules.colwise_out_spec, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 424917a30f..35ea73ae9a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py 
@@ -663,7 +663,7 @@ def shardy_sharding_rule( rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2") return SdyShardingRule( - (tuple(input_spec), scale, amax, sr_rng_state, post_rht_amax, rht_matrix), + (input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix), ( scale_rules.rowwise_out_spec, scale_rules.colwise_out_spec, diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 5df7103782..16420e014b 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -408,34 +408,21 @@ def get_shardy_sharding_rules( Returns: The Shardy rules for the scaling mode """ - del flatten_axis, broadcast_2d_scale_shape_to_1d + del broadcast_2d_scale_shape_to_1d input_spec = tuple(f"{unique_var}x_{i}" for i in range(len(input_shape))) - output_spec = input_spec.copy() - colwise_output_spec = BATCHING + f"{unique_var}_colwise_output" + output_spec = input_spec + colwise_output_spec = (BATCHING + f"{unique_var}_colwise_output",) if q_layout.has_colwise: from ..cpp_extensions.misc import multidim_transpose - colwise_output_spec = input_spec.copy() + colwise_output_spec = input_spec if is_colwise_transposed: colwise_output_spec = multidim_transpose( colwise_output_spec, transpose_axis=flatten_axis ) - scale = BATCHING + unique_var + "_scale_inv" - return QuantizeShardyRules( - tuple(input_spec), - tuple(output_spec), - tuple( - scale, - ), - tuple( - colwise_output_spec, - ), - tuple( - scale, - ), - {}, - ) + scale = (BATCHING + unique_var + "_scale_inv",) + return QuantizeShardyRules(input_spec, output_spec, scale, colwise_output_spec, scale, {}) class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): @@ -758,8 +745,8 @@ def get_shardy_sharding_rules( rowwise_scale = input_spec.copy() rowwise_scale[-1] = rowwise_var else: - rowwise_out = f"{unique_var}_rowwise_output" - rowwise_scale = f"{unique_var}_rowwise_scale_inv" + rowwise_out = BATCHING + f"{unique_var}_rowwise_output" + rowwise_scale = BATCHING + f"{unique_var}_rowwise_scale_inv" if is_colwise: from ..cpp_extensions.misc import multidim_transpose @@ -771,8 +758,8 @@ def get_shardy_sharding_rules( colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis) colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis) else: - colwise_out = f"{unique_var}_colwise_output" - colwise_scale = f"{unique_var}_colwise_scale_inv" + colwise_out = BATCHING + f"{unique_var}_colwise_output" + colwise_scale = BATCHING + f"{unique_var}_colwise_scale_inv" return QuantizeShardyRules( tuple(input_spec), From 3864b87afb550565c79ee7e012fd1e0823df9d1f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 10 Nov 2025 14:52:50 -0800 Subject: [PATCH 04/13] fix Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/scaling_modes.py | 11 +++++------ transformer_engine/jax/quantize/tensor.py | 3 +-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 16420e014b..b151efaf18 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -745,21 +745,20 @@ def get_shardy_sharding_rules( rowwise_scale = input_spec.copy() rowwise_scale[-1] = rowwise_var else: - rowwise_out = BATCHING + f"{unique_var}_rowwise_output" - rowwise_scale = BATCHING + f"{unique_var}_rowwise_scale_inv" + rowwise_out = 
[BATCHING + f"{unique_var}_rowwise_output",] + rowwise_scale = [BATCHING + f"{unique_var}_rowwise_scale_inv",] if is_colwise: - from ..cpp_extensions.misc import multidim_transpose - colwise_out = input_spec.copy() colwise_scale = input_spec.copy() colwise_scale[flatten_axis - 1] = colwise_var if is_colwise_transposed: + from ..cpp_extensions.misc import multidim_transpose colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis) colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis) else: - colwise_out = BATCHING + f"{unique_var}_colwise_output" - colwise_scale = BATCHING + f"{unique_var}_colwise_scale_inv" + colwise_out = [BATCHING + f"{unique_var}_colwise_output", ] + colwise_scale = [BATCHING + f"{unique_var}_colwise_scale_inv", ] return QuantizeShardyRules( tuple(input_spec), diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 72256944a7..c93ece771d 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -299,8 +299,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: - # TODO(Phuong): Handle padding !? + if self.scaling_mode.is_1d_block_scaling: # Both MXFP8 and NVFP4 scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv From fbbf5524737b489b117c925b61ee7b0295b3a9d6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 22:53:47 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/quantize/scaling_modes.py | 17 +++++++++++++---- transformer_engine/jax/quantize/tensor.py | 2 +- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index b151efaf18..eea27a35d8 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -745,8 +745,12 @@ def get_shardy_sharding_rules( rowwise_scale = input_spec.copy() rowwise_scale[-1] = rowwise_var else: - rowwise_out = [BATCHING + f"{unique_var}_rowwise_output",] - rowwise_scale = [BATCHING + f"{unique_var}_rowwise_scale_inv",] + rowwise_out = [ + BATCHING + f"{unique_var}_rowwise_output", + ] + rowwise_scale = [ + BATCHING + f"{unique_var}_rowwise_scale_inv", + ] if is_colwise: colwise_out = input_spec.copy() @@ -754,11 +758,16 @@ def get_shardy_sharding_rules( colwise_scale[flatten_axis - 1] = colwise_var if is_colwise_transposed: from ..cpp_extensions.misc import multidim_transpose + colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis) colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis) else: - colwise_out = [BATCHING + f"{unique_var}_colwise_output", ] - colwise_scale = [BATCHING + f"{unique_var}_colwise_scale_inv", ] + colwise_out = [ + BATCHING + f"{unique_var}_colwise_output", + ] + colwise_scale = [ + BATCHING + f"{unique_var}_colwise_scale_inv", + ] return QuantizeShardyRules( tuple(input_spec), diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c93ece771d..4354941aca 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ 
b/transformer_engine/jax/quantize/tensor.py @@ -299,7 +299,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode.is_1d_block_scaling: # Both MXFP8 and NVFP4 + if self.scaling_mode.is_1d_block_scaling: # Both MXFP8 and NVFP4 scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv From 20529693355cd7f69b298be6ae14ff7af0c463d4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 12 Nov 2025 08:12:05 -0800 Subject: [PATCH 06/13] cleanup Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/misc.py | 5 ----- transformer_engine/jax/quantize/tensor.py | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/transformer_engine/jax/quantize/misc.py b/transformer_engine/jax/quantize/misc.py index 21af4813d8..8d3442c09d 100644 --- a/transformer_engine/jax/quantize/misc.py +++ b/transformer_engine/jax/quantize/misc.py @@ -4,9 +4,6 @@ """ This module provides additional enum and utilities for quantizing tensors in JAX. """ -# from jax.tree_util import register_pytree_node_class -# from abc import ABC, abstractmethod -from dataclasses import dataclass from enum import Enum from transformer_engine_jax import JAXX_Quantize_Layout @@ -16,8 +13,6 @@ ] -# @register_pytree_node_class -@dataclass(frozen=True) class QuantizeLayout(Enum): "Wrapper for JAXX_Quantize_Layout" diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 4354941aca..25db844098 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -299,7 +299,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode.is_1d_block_scaling: # Both MXFP8 and NVFP4 + if self.scaling_mode.is_block_scaling: # Both MXFP8 and NVFP4 scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv From c918975ee3e44aef6aa26a6578fe0be5191760f6 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 12 Nov 2025 11:35:52 -0800 Subject: [PATCH 07/13] add dataclass back Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/misc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/jax/quantize/misc.py b/transformer_engine/jax/quantize/misc.py index 8d3442c09d..c1e169d005 100644 --- a/transformer_engine/jax/quantize/misc.py +++ b/transformer_engine/jax/quantize/misc.py @@ -4,6 +4,7 @@ """ This module provides additional enum and utilities for quantizing tensors in JAX. 
""" +from dataclasses import dataclass from enum import Enum from transformer_engine_jax import JAXX_Quantize_Layout @@ -13,6 +14,7 @@ ] +@dataclass(frozen=True) class QuantizeLayout(Enum): "Wrapper for JAXX_Quantize_Layout" From 9f97504983ea5fa06e201bb60cac794619cad1cc Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 12 Nov 2025 15:00:51 -0800 Subject: [PATCH 08/13] minor fix Signed-off-by: Phuong Nguyen --- transformer_engine/jax/quantize/quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 2da3ce9aaf..8a54f0b1db 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -168,7 +168,7 @@ def quantize( is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise - if (is_rowwise and is_colwise) or self.q_layout.is_rowwise_colwise: + if is_rowwise and is_colwise: rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = self._quantize_func( x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis From 0b11309e3a481769d15f11a92a19b33abd32cc18 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 12 Nov 2025 15:07:04 -0800 Subject: [PATCH 09/13] fix layout.value in grouped_quantize Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/quantization.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 35ea73ae9a..a0e1a6406f 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1113,7 +1113,7 @@ def lowering( scale, group_sizes, scaling_mode=scaling_mode.value, - q_layout=q_layout, + q_layout=q_layout.value.value, flatten_axis=flatten_axis, ) @@ -1250,7 +1250,7 @@ def grouped_quantize( group_sizes, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_layout=q_layout.value, + q_layout=q_layout, flatten_axis=flatten_axis, group_axis=group_axis, scale_dtype=quantizer.get_scale_dtype(), From 38d5be31702bbfcc846a0d0a2ed8020ecb675f30 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 10 Nov 2025 13:08:49 -0800 Subject: [PATCH 10/13] init Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/quantization.py | 37 ++++++++++++++++++- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index a0e1a6406f..977b28d9fb 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -837,10 +837,10 @@ def _quantize_dbias_impl( if amax is None: amax = jnp.zeros((1,), jnp.float32) - # It is faster to use 1x quantization for tensor scaling + # It is faster to use 1x quantization for tensor scaling and 2D NVFP4_1D_SCALING is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( - quantizer.scaling_mode.is_tensor_scaling() + (quantizer.scaling_mode.is_tensor_scaling() or quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING) and quantizer.q_layout.is_rowwise_colwise and is_1x_kernel_supported ) @@ -895,6 +895,39 @@ def _quantize_dbias_impl( colwise_casted_output = jnp.transpose( 
rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis)) ) + + if quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING and quantizer.q_layout.is_rowwise_colwise: + assert q_layout.is_rowwise_only + # Quantizer requires 2x quantization, but we are using 1x quantization + # for performance reasons, so we need to generate the colwise data in JAX + flatten_axis = (flatten_axis + x.ndim) % x.ndim + colwise_casted_output = jnp.transpose( + rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis)) + ) + # Interleave + transpose the scale + scale_shape = rowwise_scale_inv.shape + flatten_axis = (flatten_axis + x.ndim) % x.ndim + ## Split the dim before the flatten_axis to (its size / block_size, block_size) + colwise_scale_inv = rowwise_scale_inv.reshape( + *scale_shape[:flatten_axis - 1], + int(scale_shape[flatten_axis - 1] / 16), + 16, # <-- block_dim + *scale_shape[flatten_axis:], + ) + # now flatten_axis = flatten_axis + 1 + colwise_scale_inv = jnp.transpose(colwise_scale_inv, + (*range(flatten_axis + 1, colwise_scale_inv.ndim), + flatten_axis, # <-- block_dim after transpose + *range(0, flatten_axis)), + ) + block_dim = colwise_scale_inv.ndim - flatten_axis - 1 + assert block_dim >= 1 + # Merge the block_dim back + colwise_scale_inv = colwise_scale_inv.reshape( + *colwise_scale_inv.shape[:block_dim - 1], + -1, + *colwise_scale_inv.shape[block_dim + 1:], + ) quantizer.update(updated_amax) if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias: dbias = _jax_dbias(x, flatten_axis=flatten_axis) From b5dd94a5e977b812a46ad265881ce919d2cb233f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 10 Nov 2025 22:40:12 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/cpp_extensions/quantization.py | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 977b28d9fb..082e3516a9 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -840,7 +840,10 @@ def _quantize_dbias_impl( # It is faster to use 1x quantization for tensor scaling and 2D NVFP4_1D_SCALING is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( - (quantizer.scaling_mode.is_tensor_scaling() or quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING) + ( + quantizer.scaling_mode.is_tensor_scaling() + or quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING + ) and quantizer.q_layout.is_rowwise_colwise and is_1x_kernel_supported ) @@ -896,7 +899,10 @@ def _quantize_dbias_impl( rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis)) ) - if quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING and quantizer.q_layout.is_rowwise_colwise: + if ( + quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING + and quantizer.q_layout.is_rowwise_colwise + ): assert q_layout.is_rowwise_only # Quantizer requires 2x quantization, but we are using 1x quantization # for performance reasons, so we need to generate the colwise data in JAX @@ -909,24 +915,27 @@ def _quantize_dbias_impl( flatten_axis = (flatten_axis + x.ndim) % x.ndim ## Split the dim before the flatten_axis to (its size / block_size, block_size) colwise_scale_inv = rowwise_scale_inv.reshape( - 
*scale_shape[:flatten_axis - 1], + *scale_shape[: flatten_axis - 1], int(scale_shape[flatten_axis - 1] / 16), - 16, # <-- block_dim + 16, # <-- block_dim *scale_shape[flatten_axis:], ) # now flatten_axis = flatten_axis + 1 - colwise_scale_inv = jnp.transpose(colwise_scale_inv, - (*range(flatten_axis + 1, colwise_scale_inv.ndim), - flatten_axis, # <-- block_dim after transpose - *range(0, flatten_axis)), - ) + colwise_scale_inv = jnp.transpose( + colwise_scale_inv, + ( + *range(flatten_axis + 1, colwise_scale_inv.ndim), + flatten_axis, # <-- block_dim after transpose + *range(0, flatten_axis), + ), + ) block_dim = colwise_scale_inv.ndim - flatten_axis - 1 assert block_dim >= 1 # Merge the block_dim back colwise_scale_inv = colwise_scale_inv.reshape( - *colwise_scale_inv.shape[:block_dim - 1], + *colwise_scale_inv.shape[: block_dim - 1], -1, - *colwise_scale_inv.shape[block_dim + 1:], + *colwise_scale_inv.shape[block_dim + 1 :], ) quantizer.update(updated_amax) if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias: From fc648742ee365f05b8750202af47ba907cc311b0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 12 Nov 2025 17:32:41 -0800 Subject: [PATCH 12/13] added AG weight before transpose Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/quantization.py | 26 +++++++++++++++---- transformer_engine/jax/dense.py | 2 +- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 082e3516a9..07fb1c30b2 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -32,6 +32,7 @@ all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp, get_num_devices_in_mesh, + global_mesh_resource, ) from ..quantize import ( ScaledTensor2x, @@ -497,17 +498,27 @@ def partition( x_spec = get_padded_spec(arg_infos[0]) amax_spec = get_padded_spec(arg_infos[2]) + out_spec = x_spec + + # Optimization for NVFP4 2D 1x1x + FSDP + gsr = global_mesh_resource() + fsdp_all_gather_dim = None + # if ScalingMode(scaling_mode) == ScalingMode.NVFP4_2D_SCALING and q_layout.is_rowwise_only and gsr.fsdp_resource in out_spec: + if ScalingMode(scaling_mode) == ScalingMode.NVFP4_2D_SCALING and q_layout.is_rowwise_only and gsr.fsdp_resource == out_spec[0]: + fsdp_all_gather_dim = out_spec.index(gsr.fsdp_resource) + out_spec = tuple(s if s != gsr.fsdp_resource else None for s in out_spec) + out_sharding = NamedSharding( mesh, - PartitionSpec(*x_spec), + PartitionSpec(*out_spec), desc="BaseDBiasQuantizePrimitive.out_sharding", ) if q_layout.has_colwise: if ScalingMode(scaling_mode).is_colwise_transposed: - colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) + colwise_out_spec = multidim_transpose(out_spec, transpose_axis=flatten_axis) else: - colwise_out_spec = x_spec + colwise_out_spec = out_spec else: colwise_out_spec = (None,) colwise_out_sharding = NamedSharding( @@ -516,7 +527,7 @@ def partition( desc="BaseDBiasQuantizePrimitive.colwise_out_sharding", ) - dbias_spec = x_spec[flatten_axis:] if is_dbias else (None,) + dbias_spec = out_spec[flatten_axis:] if is_dbias else (None,) dbias_sharding = NamedSharding( mesh, PartitionSpec(*dbias_spec), @@ -525,7 +536,7 @@ def partition( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - scale_inv_spec = x_spec + scale_inv_spec = out_spec if q_layout.has_colwise: if ( @@ -605,6 +616,11 @@ def sharded_impl(x, 
scale, amax, sr_rng_state, post_rht_amax, rht_matrix): else: global_dbias = local_dbias + if fsdp_all_gather_dim is not None: + local_x = jax.lax.all_gather(local_x, gsr.fsdp_resource, axis=fsdp_all_gather_dim, tiled=True) + local_scale_inv = jax.lax.all_gather(local_scale_inv, gsr.fsdp_resource, axis=fsdp_all_gather_dim, + tiled=True) + return ( local_x, local_colwise_x, diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 44c73a5b1e..750408f9c4 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -199,7 +199,7 @@ def _dense_fwd_rule( amax_scope=AmaxScope.TPSP, transpose_batch_sequence=transpose_batch_sequence, ) - casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) + # casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) casted_kernel = tex.quantize( kernel, From 006874a14c6a17e0c0bd0a6cc568ba7ea7932fef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 13 Nov 2025 01:33:28 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jax/cpp_extensions/quantization.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 07fb1c30b2..303c558b07 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -504,7 +504,11 @@ def partition( gsr = global_mesh_resource() fsdp_all_gather_dim = None # if ScalingMode(scaling_mode) == ScalingMode.NVFP4_2D_SCALING and q_layout.is_rowwise_only and gsr.fsdp_resource in out_spec: - if ScalingMode(scaling_mode) == ScalingMode.NVFP4_2D_SCALING and q_layout.is_rowwise_only and gsr.fsdp_resource == out_spec[0]: + if ( + ScalingMode(scaling_mode) == ScalingMode.NVFP4_2D_SCALING + and q_layout.is_rowwise_only + and gsr.fsdp_resource == out_spec[0] + ): fsdp_all_gather_dim = out_spec.index(gsr.fsdp_resource) out_spec = tuple(s if s != gsr.fsdp_resource else None for s in out_spec) @@ -617,9 +621,12 @@ def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix): global_dbias = local_dbias if fsdp_all_gather_dim is not None: - local_x = jax.lax.all_gather(local_x, gsr.fsdp_resource, axis=fsdp_all_gather_dim, tiled=True) - local_scale_inv = jax.lax.all_gather(local_scale_inv, gsr.fsdp_resource, axis=fsdp_all_gather_dim, - tiled=True) + local_x = jax.lax.all_gather( + local_x, gsr.fsdp_resource, axis=fsdp_all_gather_dim, tiled=True + ) + local_scale_inv = jax.lax.all_gather( + local_scale_inv, gsr.fsdp_resource, axis=fsdp_all_gather_dim, tiled=True + ) return ( local_x,