diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index e720673675..729cf71170 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -2729,3 +2729,179 @@ def _run_module(m, inp):
 
     out = _run_module(g2, b)
     assert_allclose(out, outT, 1e-7)
+
+
+def test_fp8_weight_on_demand_transpose():
+    if not fp8_block_scaling_available:
+        pytest.skip("blockwise fp8 not available.")
+
+    dtype = torch.bfloat16
+    num_gemms = 4
+    bs = 4
+    fp8_recipe = recipe.Float8BlockScaling()
+    config = model_configs["126m"]
+
+    old_value = FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE
+
+    # 1. grouped linear module test
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
+
+    with fp8_model_init(enabled=True, recipe=fp8_recipe):
+        grouped_linear = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+        ).eval()
+
+    # Share params
+    with torch.no_grad():
+        weights_cache = [
+            Parameter(getattr(grouped_linear, f"weight{i}").clone()) for i in range(num_gemms)
+        ]
+
+    for i in range(num_gemms):
+        assert getattr(grouped_linear, f"weight{i}")._columnwise_data is not None
+
+    outputs1 = _test_grouped_linear_accuracy(
+        grouped_linear,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        fp8_recipe,
+        True,
+        False,
+        False,
+    )
+
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True
+
+    with fp8_model_init(enabled=True, recipe=fp8_recipe):
+        grouped_linear = GroupedLinear(
+            num_gemms,
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+        ).eval()
+
+    # Share params
+    with torch.no_grad():
+        for i in range(num_gemms):
+            w = getattr(grouped_linear, f"weight{i}")
+            assert w._columnwise_data is None
+            w._rowwise_data = weights_cache[i]._rowwise_data
+            w._rowwise_scale_inv = weights_cache[i]._rowwise_scale_inv
+
+    outputs2 = _test_grouped_linear_accuracy(
+        grouped_linear,
+        num_gemms,
+        bs,
+        dtype,
+        config,
+        fp8_recipe,
+        True,
+        False,
+        False,
+    )
+
+    # should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(outputs1, outputs2)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+    # 2. layernorm linear module test
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
+    with fp8_model_init(enabled=True, recipe=fp8_recipe):
+        te_ln_linear = TestReturnBiasModule(
+            LayerNormLinear,
+            in_features=config.hidden_size,
+            out_features=4 * config.hidden_size,
+            eps=config.eps,
+            normalization="RMSNorm",
+            params_dtype=dtype,
+            return_bias=False,
+            bias=False,
+            device="cuda",
+        )
+    assert te_ln_linear.te_module.weight._columnwise_data is not None
+
+    # Share params
+    weights_cache = []
+    with torch.no_grad():
+        weights_cache.append(te_ln_linear.te_module.layer_norm_weight.clone())
+        weights_cache.append(te_ln_linear.te_module.weight.clone())
+    outputs1 = _test_granular_accuracy(te_ln_linear, bs, dtype, config, recipe=fp8_recipe)
+
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True
+    with fp8_model_init(enabled=True, recipe=fp8_recipe):
+        te_ln_linear = TestReturnBiasModule(
+            LayerNormLinear,
+            in_features=config.hidden_size,
+            out_features=4 * config.hidden_size,
+            eps=config.eps,
+            normalization="RMSNorm",
+            params_dtype=dtype,
+            return_bias=False,
+            bias=False,
+            device="cuda",
+        )
+    assert te_ln_linear.te_module.weight._columnwise_data is None
+    # update params
+    te_ln_linear.te_module.layer_norm_weight = Parameter(weights_cache[0])
+    te_ln_linear.te_module.weight = Parameter(weights_cache[1])
+
+    outputs2 = _test_granular_accuracy(te_ln_linear, bs, dtype, config, recipe=fp8_recipe)
+
+    # should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(outputs1, outputs2)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+    # 3. linear module test
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
+    with fp8_model_init(enabled=True, recipe=fp8_recipe):
+        te_linear = Linear(
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            delay_wgrad_compute=False,
+            fuse_wgrad_accumulation=False,
+        ).eval()
+    assert te_linear.weight._columnwise_data is not None
+
+    # Share params
+    with torch.no_grad():
+        weights_cache = te_linear.weight.clone()
+
+    te_outputs1 = _test_granular_accuracy(
+        te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe
+    )
+
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True
+    with fp8_model_init(enabled=True, recipe=fp8_recipe):
+        te_linear = Linear(
+            config.hidden_size,
+            4 * config.hidden_size,
+            bias=False,
+            params_dtype=dtype,
+            device="cuda",
+            delay_wgrad_compute=False,
+            fuse_wgrad_accumulation=False,
+        ).eval()
+    assert te_linear.weight._columnwise_data is None
+
+    te_linear.weight = Parameter(weights_cache)
+    te_outputs2 = _test_granular_accuracy(
+        te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe
+    )
+
+    # should be bit-wise match
+    for i, (o, o_ref) in enumerate(zip(te_outputs1, te_outputs2)):
+        torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
+
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = old_value
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 8f9dbd88d0..03febb6f6e 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -114,6 +114,13 @@ def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
     return Format.E5M2.value.max_fwd
 
 
+def _get_fp8_blockwise_weight_on_demand_transpose():
+    # FP8 block scaling is not supported when sm >= 10.0, so the flag only takes effect below that
+    return int(
+        os.getenv("NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE", "0")
+    ) > 0 and get_device_compute_capability() < (10, 0)
+
+
 class FP8GlobalStateManager:
     """Class to keep track of and manipulate the global
     FP8 state at different stages of execution.
@@ -128,6 +135,7 @@ class FP8GlobalStateManager:
     IS_FIRST_FP8_MODULE = False
     FP8_GRAPH_CAPTURING = False
     FP8_AUTOCAST_DEPTH = 0
+    FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = _get_fp8_blockwise_weight_on_demand_transpose()
     global_amax_buffer = {}
     global_amax_history_buffer = {}
     global_scale_buffer = {}
@@ -313,6 +321,15 @@ def with_fp8_parameters(cls) -> bool:
         """Should the parameters be stored as FP8"""
         return cls.FP8_PARAMETERS
 
+    @classmethod
+    def is_blockwise_fp8_weight_on_demand_transpose(cls) -> bool:
+        """Should the blockwise FP8 weight transpose be computed on demand"""
+        return (
+            cls.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE
+            and cls.FP8_RECIPE is not None
+            and cls.FP8_RECIPE.float8_block_scaling()
+        )
+
     @classmethod
     def with_high_precision_init_val(cls) -> bool:
         """Should the high precision initial values be stored with FP8 parameters"""
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 0f2e3c4de1..6c66d76545 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -1265,6 +1265,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # Wrap parameters in QuantizedTensor if needed
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
             high_precision_init_val = None
+            fp8_weight_on_demand_transpose = (
+                FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+            )
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
 
                 # Keep high-precision values on CPU if needed
@@ -1275,7 +1278,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
                 quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                 if quantizer is None:
                     raise RuntimeError("Weight quantizer has not been initialized")
-                quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
+                quantizer.set_usage(
+                    rowwise=True,
+                    columnwise=torch.is_grad_enabled() and not fp8_weight_on_demand_transpose,
+                )
                 quantizer.internal = False
 
                 # Quantize parameter
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index e9189ccc59..d9a7f153d3 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -49,6 +49,7 @@
     prepare_for_saving,
     restore_from_saved,
 )
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import get_columnwise_fp8_tensor
 
 __all__ = ["GroupedLinear"]
 
@@ -89,6 +90,9 @@ def forward(
         biases = weights_and_biases[num_gemms:]
         device = inp.device
         weight_requires_grad = weights[0].requires_grad
+        fp8_weight_on_demand_transpose = (
+            FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        )
 
         # Configure quantizers
         if save_original_input and isinstance(input_quantizers[0], Float8Quantizer):
@@ -109,7 +113,10 @@ def forward(
             )
             if weight_quantizers[0] is not None:
                 for weight_quantizer in weight_quantizers:
-                    weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+                    weight_quantizer.set_usage(
+                        rowwise=True,
+                        columnwise=columnwise_usage and not fp8_weight_on_demand_transpose,
+                    )
             if output_quantizers[0] is not None:
                 for output_quantizer in output_quantizers:
                     output_quantizer.set_usage(rowwise=True, columnwise=False)
@@ -192,6 +199,7 @@ def forward(
         if is_grad_enabled:
             ctx.weight_quantizers = weight_quantizers
             ctx.weights_shape_1 = weights[0].shape[1]
+            ctx.fp8_weight_on_demand_transpose = fp8_weight_on_demand_transpose
 
             # TODO: update after #1638 is merged. # pylint: disable=fixme
             if weight_requires_grad:
@@ -206,7 +214,10 @@ def forward(
                 inputmats = [None] * num_gemms
             if inp.requires_grad:
                 for weight in weights_fp8:
-                    if isinstance(weight, QuantizedTensorBase):
+                    if (
+                        isinstance(weight, QuantizedTensorBase)
+                        and not fp8_weight_on_demand_transpose
+                    ):
                         weight.update_usage(columnwise_usage=True)
 
             tensors_to_save, tensor_objects = prepare_for_saving(
@@ -337,14 +348,19 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 device=ctx.device,
             )
 
+            columnwise_weights = []
             for weight, quantizer in zip(weights, ctx.weight_quantizers):
                 if quantizer is not None and isinstance(weight, QuantizedTensorBase):
-                    weight.update_usage(
-                        rowwise_usage=quantizer.rowwise_usage,
-                        columnwise_usage=quantizer.columnwise_usage,
-                    )
+                    if ctx.fp8_weight_on_demand_transpose:
+                        columnwise_weight = get_columnwise_fp8_tensor(weight)
+                        columnwise_weights.append(columnwise_weight)
+                    else:
+                        weight.update_usage(
+                            rowwise_usage=quantizer.rowwise_usage,
+                            columnwise_usage=quantizer.columnwise_usage,
+                        )
             general_grouped_gemm(
-                weights,
+                columnwise_weights if ctx.fp8_weight_on_demand_transpose else weights,
                 grad_output,
                 [dgrad],
                 ctx.activation_dtype,
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index cd02f31132..de29e10a7c 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -72,6 +72,7 @@
 from ..cpp_extensions import (
     general_gemm,
 )
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import get_columnwise_fp8_tensor
 
 __all__ = ["LayerNormLinear"]
 
@@ -135,6 +136,10 @@ def forward(
         if ub_name is not None:
             nvtx_label = f"{nvtx_label}.{ub_name}"
 
+        fp8_weight_on_demand_transpose = (
+            FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        )
+
        # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
         inp_shape = inp.shape
@@ -275,7 +280,9 @@ def forward(
 
             # Configure quantizer
             if weight_quantizer is not None:
-                weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled)
+                weight_quantizer.set_usage(
+                    rowwise=True, columnwise=is_grad_enabled and not fp8_weight_on_demand_transpose
+                )
 
             # Get quantized weight
             update_workspace = is_first_microbatch is None or is_first_microbatch
@@ -404,7 +411,7 @@ def forward(
                     ln_out.update_usage(rowwise_usage=False)
 
             # Weight with column-wise usage is needed for dgrad GEMM.
-            if isinstance(weightmat, QuantizedTensorBase):
+            if isinstance(weightmat, QuantizedTensorBase) and not fp8_weight_on_demand_transpose:
                 weightmat.update_usage(columnwise_usage=True)
 
             if cpu_offloading:
@@ -678,7 +685,10 @@ def backward(
             if isinstance(grad_output, QuantizedTensorBase):
                 grad_output.update_usage(rowwise_usage=True)
             if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
-                weight.update_usage(columnwise_usage=True)
+                if FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose():
+                    weight = get_columnwise_fp8_tensor(weight)
+                else:
+                    weight.update_usage(columnwise_usage=True)
 
             # Choose whether to use GEMM kernel with split accumulator
             use_split_accumulator = _2X_ACC_DGRAD
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 2ce6fb4c1d..7b98104a99 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -68,6 +68,7 @@
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
+from transformer_engine.pytorch.tensor.float8_blockwise_tensor import get_columnwise_fp8_tensor
 
 __all__ = ["Linear"]
 
@@ -128,6 +129,10 @@ def forward(
         out_features, in_features = weight.shape
         assert inp.shape[-1] == in_features, "GEMM not possible"
 
+        fp8_weight_on_demand_transpose = (
+            FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        )
+
         # Configure tensor-parallel communication
         tp_world_size = get_distributed_world_size(tp_group)
         backward_needs_input = is_grad_enabled and weight.requires_grad
@@ -241,7 +246,9 @@ def forward(
                     is_fp8_activation_recompute_enabled()
                     and not in_fp8_activation_recompute_phase()
                 )
-            weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+            weight_quantizer.set_usage(
+                rowwise=True, columnwise=columnwise_usage and not fp8_weight_on_demand_transpose
+            )
 
             # Get quantized weight
             update_workspace = is_first_microbatch is None or is_first_microbatch
@@ -378,7 +385,10 @@ def forward(
 
             # Weight with column-wise usage is needed for dgrad GEMM.
             if inp.requires_grad:
-                if isinstance(weightmat, QuantizedTensorBase):
+                if (
+                    isinstance(weightmat, QuantizedTensorBase)
+                    and not fp8_weight_on_demand_transpose
+                ):
                     weightmat.update_usage(columnwise_usage=True)
 
             if cpu_offloading and saved_inputmat is not None:
@@ -666,7 +676,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             if isinstance(grad_output, QuantizedTensorBase):
                 grad_output.update_usage(rowwise_usage=True)
             if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
-                weight_fp8.update_usage(columnwise_usage=True)
+                if FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose():
+                    weight_fp8 = get_columnwise_fp8_tensor(weight_fp8)
+                else:
+                    weight_fp8.update_usage(columnwise_usage=True)
 
             # Choose whether to use GEMM kernel with split accumulator
             use_split_accumulator = _2X_ACC_DGRAD
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index 0e41fc9c51..8b92cf0050 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -798,3 +798,25 @@ def backward(
             )
             return dgrad, None
         return grad.view(ctx.shape), None
+
+
+def get_columnwise_fp8_tensor(rowwise_tensor, requires_grad=False):
+    columnwise_scale_inv = rowwise_tensor._rowwise_scale_inv.transpose(-2, -1).contiguous()
+    M, N = rowwise_tensor.shape
+    columnwise_data = torch.empty(
+        (N, M), device=rowwise_tensor.device, dtype=rowwise_tensor._rowwise_data.dtype
+    )
+    tex.fp8_transpose(rowwise_tensor._rowwise_data, rowwise_tensor._fp8_dtype, out=columnwise_data)
+
+    return Float8BlockwiseQTensor(
+        shape=rowwise_tensor.shape,
+        dtype=rowwise_tensor.dtype,
+        fp8_dtype=rowwise_tensor._fp8_dtype,
+        rowwise_data=None,
+        rowwise_scale_inv=None,
+        columnwise_data=columnwise_data,
+        columnwise_scale_inv=columnwise_scale_inv,
+        quantizer=rowwise_tensor._quantizer,
+        is_2D_scaled=True,
+        requires_grad=requires_grad,
+    )
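
Illustrative usage, not part of the patch: a minimal sketch of how the on-demand transpose path added above would be exercised, assuming a GPU with sm < 10.0 where Float8BlockScaling is supported. The environment variable must be set before transformer_engine is imported, since FP8GlobalStateManager reads it at class-definition time; the model dimensions and batch size below are arbitrary.

# Sketch: enabling on-demand blockwise FP8 weight transpose (illustrative only).
import os

os.environ["NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE"] = "1"  # must be set before importing TE

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.Float8BlockScaling()

# With the flag enabled, weights created under fp8_model_init keep only row-wise
# FP8 data; the transposed (column-wise) copy needed by the dgrad GEMM is built
# on the fly in backward via get_columnwise_fp8_tensor.
with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
    linear = te.Linear(1024, 4096, bias=False, params_dtype=torch.bfloat16, device="cuda")
assert linear.weight._columnwise_data is None

x = torch.randn(32, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = linear(x)
out.sum().backward()

As the new test asserts, forward and backward results with the flag on or off should match bit-wise; the trade-off is recomputing the weight transpose in backward instead of storing both copies.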