176 changes: 176 additions & 0 deletions tests/pytorch/test_numerics.py
@@ -2729,3 +2729,179 @@ def _run_module(m, inp):
out = _run_module(g2, b)

assert_allclose(out, outT, 1e-7)


def test_fp8_weight_on_demand_transpose():
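"""Check that on-demand transposition of blockwise FP8 weights is numerically transparent.
GroupedLinear, LayerNormLinear, and Linear are each run twice with the same FP8 weight data:
once with the column-wise (transposed) weight data pre-computed at quantization time, and once
with FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE enabled so that only row-wise data is stored and
the transpose is materialized during backward. The two runs must produce bit-wise identical
outputs.
"""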
if not fp8_block_scaling_available:
pytest.skip("blockwise fp8 not available.")

dtype = torch.bfloat16
num_gemms = 4
bs = 4
fp8_recipe = recipe.Float8BlockScaling()
config = model_configs["126m"]

old_value = FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE

# 1. grouped linear module test
FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False

with fp8_model_init(enabled=True, recipe=fp8_recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
device="cuda",
).eval()

# Cache params to reuse in the on-demand transpose run below
with torch.no_grad():
weights_cache = [
Parameter(getattr(grouped_linear, f"weight{i}").clone()) for i in range(num_gemms)
]

for i in range(num_gemms):
assert getattr(grouped_linear, f"weight{i}")._columnwise_data is not None

outputs1 = _test_grouped_linear_accuracy(
grouped_linear,
num_gemms,
bs,
dtype,
config,
fp8_recipe,
True,
False,
False,
)

FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True

with fp8_model_init(enabled=True, recipe=fp8_recipe):
grouped_linear = GroupedLinear(
num_gemms,
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
device="cuda",
).eval()

# Share params
with torch.no_grad():
for i in range(num_gemms):
w = getattr(grouped_linear, f"weight{i}")
assert w._columnwise_data is None
w._rowwise_data = weights_cache[i]._rowwise_data
w._rowwise_scale_inv = weights_cache[i]._rowwise_scale_inv

outputs2 = _test_grouped_linear_accuracy(
grouped_linear,
num_gemms,
bs,
dtype,
config,
fp8_recipe,
True,
False,
False,
)

# Outputs should match bit-wise
for o, o_ref in zip(outputs1, outputs2):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)

# 2. layernorm linear module test
FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
with fp8_model_init(enabled=True, recipe=fp8_recipe):
te_ln_linear = TestReturnBiasModule(
LayerNormLinear,
in_features=config.hidden_size,
out_features=4 * config.hidden_size,
eps=config.eps,
normalization="RMSNorm",
params_dtype=dtype,
return_bias=False,
bias=False,
device="cuda",
)
assert te_ln_linear.te_module.weight._columnwise_data is not None

# Cache params to reuse in the on-demand transpose run below
weights_cache = []
with torch.no_grad():
weights_cache.append(te_ln_linear.te_module.layer_norm_weight.clone())
weights_cache.append(te_ln_linear.te_module.weight.clone())
outputs1 = _test_granular_accuracy(te_ln_linear, bs, dtype, config, recipe=fp8_recipe)

FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True
with fp8_model_init(enabled=True, recipe=fp8_recipe):
te_ln_linear = TestReturnBiasModule(
LayerNormLinear,
in_features=config.hidden_size,
out_features=4 * config.hidden_size,
eps=config.eps,
normalization="RMSNorm",
params_dtype=dtype,
return_bias=False,
bias=False,
device="cuda",
)
assert te_ln_linear.te_module.weight._columnwise_data is None
# update params
te_ln_linear.te_module.layer_norm_weight = Parameter(weights_cache[0])
te_ln_linear.te_module.weight = Parameter(weights_cache[1])

outputs2 = _test_granular_accuracy(te_ln_linear, bs, dtype, config, recipe=fp8_recipe)

# Outputs should match bit-wise
for o, o_ref in zip(outputs1, outputs2):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)

# 3. linear module test
FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
with fp8_model_init(enabled=True, recipe=fp8_recipe):
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=False,
).eval()
assert te_linear.weight._columnwise_data is not None

# Cache params to reuse in the on-demand transpose run below
with torch.no_grad():
weights_cache = te_linear.weight.clone()

te_outputs1 = _test_granular_accuracy(
te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe
)

FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True
with fp8_model_init(enabled=True, recipe=fp8_recipe):
te_linear = Linear(
config.hidden_size,
4 * config.hidden_size,
bias=False,
params_dtype=dtype,
device="cuda",
delay_wgrad_compute=False,
fuse_wgrad_accumulation=False,
).eval()
assert te_linear.weight._columnwise_data is None

te_linear.weight = Parameter(weights_cache)
te_outputs2 = _test_granular_accuracy(
te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe
)

# Outputs should match bit-wise
for o, o_ref in zip(te_outputs1, te_outputs2):
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)

FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = old_value
17 changes: 17 additions & 0 deletions transformer_engine/pytorch/fp8.py
@@ -114,6 +114,13 @@ def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
return Format.E5M2.value.max_fwd


def _get_fp8_blockwise_weight_on_demand_transpose():
# On-demand FP8 weight transpose is only enabled on devices with compute capability < 10.0 (pre-Blackwell)
return int(
os.getenv("NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE", "0")
) > 0 and get_device_compute_capability() < (10, 0)


class FP8GlobalStateManager:
"""Class to keep track of and manipulate the global
FP8 state at different stages of execution.
@@ -128,6 +135,7 @@ class FP8GlobalStateManager:
IS_FIRST_FP8_MODULE = False
FP8_GRAPH_CAPTURING = False
FP8_AUTOCAST_DEPTH = 0
FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = _get_fp8_blockwise_weight_on_demand_transpose()
global_amax_buffer = {}
global_amax_history_buffer = {}
global_scale_buffer = {}
@@ -313,6 +321,15 @@ def with_fp8_parameters(cls) -> bool:
"""Should the parameters be stored as FP8"""
return cls.FP8_PARAMETERS

@classmethod
def is_blockwise_fp8_weight_on_demand_transpose(cls) -> bool:
"""Should the blockwwise fp8 weight on-demand transpose"""
return (
cls.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE
and cls.FP8_RECIPE is not None
and cls.FP8_RECIPE.float8_block_scaling()
)

@classmethod
def with_high_precision_init_val(cls) -> bool:
"""Should the high precision initial values be stored with FP8 parameters"""
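For reference, a minimal usage sketch of how the flag above is expected to be enabled. This is a sketch under the assumption that the environment variable is read at import time, as in _get_fp8_blockwise_weight_on_demand_transpose; the 1024/4096 sizes are arbitrary, and the attribute check mirrors the new test in test_numerics.py.

# Minimal usage sketch (not part of this PR): enabling on-demand FP8 weight
# transpose through the environment flag read above. The flag is evaluated when
# transformer_engine.pytorch.fp8 is imported and only takes effect on devices
# with compute capability < 10.0.
import os
os.environ["NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE"] = "1"  # set before importing transformer_engine

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.Float8BlockScaling()
with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
    linear = te.Linear(1024, 4096, bias=False, params_dtype=torch.bfloat16, device="cuda")

# With the flag active, the blockwise FP8 weight stores only row-wise data;
# the column-wise (transposed) copy is materialized on demand during backward.
assert linear.weight._columnwise_data is None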
8 changes: 7 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -1265,6 +1265,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
# Wrap parameters in QuantizedTensor if needed
fp8_meta_index = self.param_init_meta[name].fp8_meta_index
high_precision_init_val = None
fp8_weight_on_demand_transpose = (
FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
)
if self.primary_weights_in_fp8 and fp8_meta_index is not None:

# Keep high-precision values on CPU if needed
@@ -1275,7 +1278,10 @@
quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
if quantizer is None:
raise RuntimeError("Weight quantizer has not been initialized")
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.set_usage(
rowwise=True,
columnwise=torch.is_grad_enabled() and not fp8_weight_on_demand_transpose,
)
quantizer.internal = False

# Quantize parameter
30 changes: 23 additions & 7 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -49,6 +49,7 @@
prepare_for_saving,
restore_from_saved,
)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import get_columnwise_fp8_tensor

__all__ = ["GroupedLinear"]

@@ -89,6 +90,9 @@ def forward(
biases = weights_and_biases[num_gemms:]
device = inp.device
weight_requires_grad = weights[0].requires_grad
fp8_weight_on_demand_transpose = (
FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
)

# Configure quantizers
if save_original_input and isinstance(input_quantizers[0], Float8Quantizer):
@@ -109,7 +113,10 @@
)
if weight_quantizers[0] is not None:
for weight_quantizer in weight_quantizers:
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
weight_quantizer.set_usage(
rowwise=True,
columnwise=columnwise_usage and not fp8_weight_on_demand_transpose,
)
if output_quantizers[0] is not None:
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
@@ -192,6 +199,7 @@ def forward(
if is_grad_enabled:
ctx.weight_quantizers = weight_quantizers
ctx.weights_shape_1 = weights[0].shape[1]
ctx.fp8_weight_on_demand_transpose = fp8_weight_on_demand_transpose

# TODO: update after #1638 is merged. # pylint: disable=fixme
if weight_requires_grad:
@@ -206,7 +214,10 @@
inputmats = [None] * num_gemms
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensorBase):
if (
isinstance(weight, QuantizedTensorBase)
and not fp8_weight_on_demand_transpose
):
weight.update_usage(columnwise_usage=True)

tensors_to_save, tensor_objects = prepare_for_saving(
@@ -337,14 +348,19 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
device=ctx.device,
)

columnwise_weights = []
for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase):
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
if ctx.fp8_weight_on_demand_transpose:
columnwise_weight = get_columnwise_fp8_tensor(weight)
columnwise_weights.append(columnwise_weight)
else:
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
)
general_grouped_gemm(
weights,
columnwise_weights if ctx.fp8_weight_on_demand_transpose else weights,
grad_output,
[dgrad],
ctx.activation_dtype,
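The LayerNormLinear backward below applies the same on-demand pattern to its single weight. A condensed sketch of that shared dgrad handling follows; the dgrad_weight name is illustrative only, and the internals of get_columnwise_fp8_tensor are assumed rather than shown in this diff.

# Condensed sketch of the dgrad weight handling under on-demand transpose.
# get_columnwise_fp8_tensor is assumed to build the column-wise (transposed)
# blockwise FP8 representation from the stored row-wise data.
if fp8_weight_on_demand_transpose:
    # Only row-wise FP8 data lives with the parameter; materialize the
    # column-wise copy just-in-time for the dgrad GEMM.
    dgrad_weight = get_columnwise_fp8_tensor(weight)
else:
    # Previous behavior: the column-wise copy was produced at quantization
    # time and is merely marked as required here.
    weight.update_usage(columnwise_usage=True)
    dgrad_weight = weight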
16 changes: 13 additions & 3 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -72,6 +72,7 @@
from ..cpp_extensions import (
general_gemm,
)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import get_columnwise_fp8_tensor

__all__ = ["LayerNormLinear"]

@@ -135,6 +136,10 @@ def forward(
if ub_name is not None:
nvtx_label = f"{nvtx_label}.{ub_name}"

fp8_weight_on_demand_transpose = (
FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
)

# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
@@ -275,7 +280,9 @@

# Configure quantizer
if weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
weight_quantizer.set_usage(
rowwise=True, columnwise=is_grad_enabled and not fp8_weight_on_demand_transpose
)

# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
@@ -404,7 +411,7 @@ def forward(
ln_out.update_usage(rowwise_usage=False)

# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensorBase):
if isinstance(weightmat, QuantizedTensorBase) and not fp8_weight_on_demand_transpose:
weightmat.update_usage(columnwise_usage=True)

if cpu_offloading:
@@ -678,7 +685,10 @@ def backward(
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
weight.update_usage(columnwise_usage=True)
if FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose():
weight = get_columnwise_fp8_tensor(weight)
else:
weight.update_usage(columnwise_usage=True)

# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_DGRAD