160 changes: 76 additions & 84 deletions transformer_engine/jax/cpp_extensions/activation.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions transformer_engine/jax/cpp_extensions/misc.py
@@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break
# _quantize_dbias_impl forces 1x quantization for tensor scaling, which switches q_layout to ROWWISE,
# but this fails when bias fusion is turned on with arch < 100.
force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
force_1x_quantization = (
quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise
)
return (
(force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
and arch_l_100
@@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
@return: the output of 'f' with the colwise output calculated
"""
should_apply_war = (
quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
quantizer is not None
and quantizer.scaling_mode.is_tensor_scaling()
and quantizer.q_layout.is_rowwise_colwise
)
if not should_apply_war:
return None
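The checks above now read the layout off an enum property (quantizer.q_layout.is_rowwise_colwise) instead of calling quantizer.is_2x2x(). A minimal sketch of what a layout enum with such helper properties could look like; the member names, values, and properties below are assumptions for illustration, not the actual QuantizeLayout definition in transformer_engine.jax:

from enum import Enum

class QuantizeLayout(Enum):
    # Hypothetical members; the real enum may differ.
    ROWWISE = 0            # quantize along the last dimension only
    COLWISE = 1            # quantize the transposed (column-wise) view only
    ROWWISE_COLWISE = 2    # produce both representations (the old "2x" case)

    @property
    def is_rowwise_colwise(self):
        return self is QuantizeLayout.ROWWISE_COLWISE

    @property
    def is_colwise_only(self):
        return self is QuantizeLayout.COLWISE

    @property
    def has_colwise(self):
        return self in (QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE)

With properties like these, call sites become plain predicates, which is what the diff substitutes for the old boolean is_2x flag.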
102 changes: 53 additions & 49 deletions transformer_engine/jax/cpp_extensions/normalization.py
@@ -11,7 +11,7 @@
import jax
import jax.numpy as jnp
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec

@@ -112,7 +112,7 @@ def abstract(
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
@@ -165,15 +165,15 @@

updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)

colwise_out_shape = x_aval.shape if is_2x else (1,)
colwise_out_shape = x_aval.shape if quantize_layout.has_colwise else (1,)
colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype)

rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)

scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,)
colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,)
colwise_scale_inv_aval = jax.core.ShapedArray(
shape=colwise_scale_inv_shape, dtype=scale_dtype
)
@@ -189,7 +189,7 @@
zero_centered_gamma,
epsilon,
get_forward_sm_margin(),
is_2x,
True, # is_training
)
wkspace_aval = jax.core.ShapedArray(
shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
@@ -245,7 +245,7 @@ def lowering(
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
@@ -287,7 +287,7 @@
epsilon=epsilon,
sm_margin=sm_margin,
scaling_mode=scaling_mode.value,
is_2x=is_2x,
quantize_layout=quantize_layout.value.value,
output_amax_when_no_scaling=output_amax_when_no_scaling,
)

@@ -303,7 +303,7 @@ def impl(
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
@@ -335,7 +335,7 @@
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
@@ -349,7 +349,7 @@
scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape(
rowwise_scale_inv_shape
)
if is_2x:
if quantize_layout.has_colwise:
colwise_scale_inv = colwise_scale_inv.flatten()[
: reduce(operator.mul, colwise_scale_inv_shape, 1)
].reshape(colwise_scale_inv_shape)
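The reshape above trims the padded buffer returned by the kernel down to the unpadded scale_inv shape. A standalone sketch of the same trick, assuming (as the wrapper does) that the real values occupy the leading rows of the padded buffer; the shapes below are made up for illustration:

import operator
from functools import reduce
import jax.numpy as jnp

padded = jnp.ones((4, 100))                 # buffer the kernel filled, padded for alignment
unpadded_shape = (3, 100)                   # shape the caller actually needs
n = reduce(operator.mul, unpadded_shape, 1)
scale_inv = padded.flatten()[:n].reshape(unpadded_shape)
assert scale_inv.shape == unpadded_shape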
@@ -373,7 +373,7 @@ def batcher(
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
@@ -409,7 +409,7 @@
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
@@ -426,7 +426,7 @@ def infer_sharding_from_operands(
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
@@ -450,7 +450,7 @@
)

out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
@@ -488,7 +488,7 @@ def partition(
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
@@ -524,7 +524,7 @@
)

out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out")
colwise_out_spec = out_spec if is_2x else (None,)
colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,)
colwise_out_sharding = NamedSharding(
mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out"
)
@@ -586,7 +586,7 @@ def sharded_impl(x, scale, amax, gamma, beta):
epsilon=epsilon,
out_dtype=out_dtype,
scaling_mode=scaling_mode,
is_2x=is_2x,
quantize_layout=quantize_layout,
scale_dtype=scale_dtype,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
@@ -623,7 +623,7 @@ def shardy_sharding_rule(
epsilon,
out_dtype,
scaling_mode,
is_2x,
quantize_layout,
scale_dtype,
amax_scope,
transpose_batch_sequence,
@@ -646,25 +646,29 @@
result_types,
)

prefix = "NormFwd_"
prefix = "NormFwd"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1
value_types[0].shape,
unique_var=prefix,
flatten_axis=-1,
q_layout=quantize_layout,
)
x_axes = scale_rules.input_spec
input_spec = scale_rules.input_spec

out = x_axes
colwise_out = out if is_2x else (prefix + "out_colwise",)
rsigma = x_axes[:-1]
mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = (prefix + "amax",)
rsigma = input_spec[:-1]
mu = (BATCHING + prefix + "_mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
amax = (BATCHING + prefix + "_amax",)
scale = (BATCHING + prefix + "_scale",)
gamma = (BATCHING + prefix + "_gamma",)
beta = (BATCHING + prefix + "_beta",)

return SdyShardingRule(
(x_axes, ("…1",), amax, ("…2",), ("…3",)),
(input_spec, scale, amax, gamma, beta),
(
out,
colwise_out,
scale_rules.rowwise_rule,
scale_rules.colwise_rule,
scale_rules.rowwise_out_spec,
scale_rules.colwise_out_spec,
scale_rules.rowwise_scale_spec,
scale_rules.colwise_scale_spec,
amax,
mu,
rsigma,
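The rule above is built from one tuple of per-dimension factor names per operand (input, scale, amax, gamma, beta) and per result, with the BATCHING prefix applied to the auxiliary one-dimensional tensors. For orientation only, a construction-only sketch of the same pattern for a made-up elementwise primitive; the factor names are illustrative and wiring the rule into custom_partitioning is omitted:

from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING

x_spec = ("batch", "hidden")
rule = SdyShardingRule(
    (x_spec, (BATCHING + "scale",)),    # operand specs: x and a 1-D scale input
    (x_spec,),                          # result specs: output shaped like x
)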
@@ -987,7 +991,7 @@ def layernorm_fwd(
return (output, mu, rsigma)

# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)

scale = (
@@ -1008,7 +1012,7 @@
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=False,
@@ -1067,10 +1071,11 @@ def layernorm_fwd(
)
return out, mu, rsigma

is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
q_layout = quantizer.q_layout
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
q_layout = QuantizeLayout.ROWWISE

(
rowwise_casted_output,
colwise_casted_output,
@@ -1090,7 +1095,7 @@
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
@@ -1099,8 +1104,7 @@
)
quantizer.update(updated_amax)

# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
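Because the common library only returns the rowwise tensor in this case, the colwise copy is recovered in JAX by moving the last axis to the front, exactly as in the transpose call above. A tiny standalone illustration with a stand-in array:

import jax.numpy as jnp

rowwise = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)   # stand-in for the rowwise-quantized output
colwise = jnp.transpose(rowwise, (-1, *range(rowwise.ndim - 1)))
assert colwise.shape == (4, 2, 3)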
@@ -1238,7 +1242,7 @@ def rmsnorm_fwd(
return (output, rsigma)

# TE/common does not support normalization with colwise only quantization yet
if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
if quantizer is not None and quantizer.q_layout.is_colwise_only:
return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)

scale = (
@@ -1261,7 +1265,7 @@
epsilon=epsilon,
out_dtype=x.dtype,
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
quantize_layout=QuantizeLayout.ROWWISE,
scale_dtype=jnp.float32,
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
@@ -1321,10 +1325,11 @@ def rmsnorm_fwd(
)
return out, rsigma

is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
q_layout = quantizer.q_layout
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
q_layout = QuantizeLayout.ROWWISE

(
rowwise_casted_output,
colwise_casted_output,
@@ -1344,7 +1349,7 @@
epsilon=epsilon,
out_dtype=quantizer.q_dtype,
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
quantize_layout=q_layout,
scale_dtype=quantizer.get_scale_dtype(),
amax_scope=amax_scope,
transpose_batch_sequence=transpose_batch_sequence,
@@ -1353,8 +1358,7 @@
)
quantizer.update(updated_amax)

# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
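For reference on what the unquantized path produces here, a rough pure-JAX RMSNorm forward; this is an illustrative sketch, not the project's _jax_rmsnorm, and treating zero_centered_gamma as "add 1 to gamma" is an assumption:

import jax.numpy as jnp

def rmsnorm_reference(x, gamma, epsilon=1e-6, zero_centered_gamma=False):
    # rsigma = 1 / sqrt(mean(x^2) + eps), computed over the hidden dimension
    rsigma = 1.0 / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + epsilon)
    g = gamma + 1.0 if zero_centered_gamma else gamma
    out = x * rsigma * g
    return out, jnp.squeeze(rsigma, axis=-1)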