Commit f35e5dd

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 59e7dcd commit f35e5dd

7 files changed: +47 −30 lines changed

tests/pytorch/test_numerics.py

Lines changed: 12 additions & 10 deletions
@@ -2733,7 +2733,7 @@ def _run_module(m, inp):
 
 def test_fp8_weight_on_demand_transpose():
     if not fp8_block_scaling_available:
-        pytest.skip("blockwise fp8 not available.")
+        pytest.skip("blockwise fp8 not available.")
 
     dtype = torch.bfloat16
     num_gemms = 4
@@ -2758,7 +2758,9 @@ def test_fp8_weight_on_demand_transpose():
 
     # Share params
     with torch.no_grad():
-        weights_cache = [Parameter(getattr(grouped_linear, f"weight{i}").clone()) for i in range(num_gemms)]
+        weights_cache = [
+            Parameter(getattr(grouped_linear, f"weight{i}").clone()) for i in range(num_gemms)
+        ]
 
     for i in range(num_gemms):
         assert getattr(grouped_linear, f"weight{i}")._columnwise_data is not None
@@ -2811,7 +2813,6 @@ def test_fp8_weight_on_demand_transpose():
     for i, (o, o_ref) in enumerate(zip(outputs1, outputs2)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 
-
     # 2. layernorm linear module test
     FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
     with fp8_model_init(enabled=True, recipe=fp8_recipe):
@@ -2831,9 +2832,7 @@ def test_fp8_weight_on_demand_transpose():
     # Share params
     weights_cache = []
     with torch.no_grad():
-        weights_cache.append(
-            te_ln_linear.te_module.layer_norm_weight.clone()
-        )
+        weights_cache.append(te_ln_linear.te_module.layer_norm_weight.clone())
         weights_cache.append(te_ln_linear.te_module.weight.clone())
     outputs1 = _test_granular_accuracy(te_ln_linear, bs, dtype, config, recipe=fp8_recipe)
 
@@ -2861,7 +2860,6 @@ def test_fp8_weight_on_demand_transpose():
     for i, (o, o_ref) in enumerate(zip(outputs1, outputs2)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 
-
     # 3. linear module test
     FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
     with fp8_model_init(enabled=True, recipe=fp8_recipe):
@@ -2880,7 +2878,9 @@ def test_fp8_weight_on_demand_transpose():
     with torch.no_grad():
         weights_cache = te_linear.weight.clone()
 
-    te_outputs1 = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe)
+    te_outputs1 = _test_granular_accuracy(
+        te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe
+    )
 
     FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True
     with fp8_model_init(enabled=True, recipe=fp8_recipe):
@@ -2896,10 +2896,12 @@ def test_fp8_weight_on_demand_transpose():
     assert te_linear.weight._columnwise_data is None
 
     te_linear.weight = Parameter(weights_cache)
-    te_outputs2 = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe)
+    te_outputs2 = _test_granular_accuracy(
+        te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe
+    )
 
     # should be bit-wise match
     for i, (o, o_ref) in enumerate(zip(outputs1, outputs2)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 
-    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = old_value
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = old_value
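
The assertions on `_columnwise_data` above check the lazy-transpose invariant: with on-demand transpose enabled, the column-wise copy of an FP8 weight starts out absent and is only materialized when needed, and the module outputs must still match bit-wise. A simplified stand-in for that pattern (plain PyTorch, hypothetical `LazyTransposedWeight` class, not Transformer Engine code):

import torch


class LazyTransposedWeight:
    """Simplified stand-in for an FP8 weight with an optional column-wise copy."""

    def __init__(self, rowwise_data: torch.Tensor, on_demand: bool):
        self._rowwise_data = rowwise_data
        # Eager mode keeps the transposed copy from the start; on-demand mode defers it.
        self._columnwise_data = None if on_demand else rowwise_data.t().contiguous()

    def columnwise(self) -> torch.Tensor:
        if self._columnwise_data is None:
            self._columnwise_data = self._rowwise_data.t().contiguous()
        return self._columnwise_data


w = torch.randn(64, 32)
eager = LazyTransposedWeight(w, on_demand=False)
lazy = LazyTransposedWeight(w, on_demand=True)
assert eager._columnwise_data is not None and lazy._columnwise_data is None
# Both paths feed bit-wise identical data to the backward GEMM once it is needed.
torch.testing.assert_close(eager.columnwise(), lazy.columnwise(), rtol=0, atol=0)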

transformer_engine/pytorch/fp8.py

Lines changed: 5 additions & 4 deletions
@@ -113,12 +113,13 @@ def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType:
         return Format.E4M3.value.max_fwd
     return Format.E5M2.value.max_fwd
 
+
 def _get_fp8_blockwise_weight_on_demand_transpose():
     # fp8 blockwise is not supported when sm >= 10.0
-    return (
-        int(os.getenv("NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE", "0")) > 0
-        and get_device_compute_capability() < (10, 0)
-    )
+    return int(
+        os.getenv("NVTE_ON_DEMAND_FP8_WEIGHT_TRANSPOSE", "0")
+    ) > 0 and get_device_compute_capability() < (10, 0)
+
 
 class FP8GlobalStateManager:
     """Class to keep track of and manipulate the global

transformer_engine/pytorch/module/base.py

Lines changed: 7 additions & 2 deletions
@@ -1265,7 +1265,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # Wrap parameters in QuantizedTensor if needed
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
             high_precision_init_val = None
-            fp8_weight_on_demand_transpose = FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+            fp8_weight_on_demand_transpose = (
+                FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+            )
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
 
                 # Keep high-precision values on CPU if needed
@@ -1276,7 +1278,10 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
                 quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                 if quantizer is None:
                     raise RuntimeError("Weight quantizer has not been initialized")
-                quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled() and not fp8_weight_on_demand_transpose)
+                quantizer.set_usage(
+                    rowwise=True,
+                    columnwise=torch.is_grad_enabled() and not fp8_weight_on_demand_transpose,
+                )
                 quantizer.internal = False
 
                 # Quantize parameter

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 8 additions & 3 deletions
@@ -90,7 +90,9 @@ def forward(
         biases = weights_and_biases[num_gemms:]
         device = inp.device
         weight_requires_grad = weights[0].requires_grad
-        fp8_weight_on_demand_transpose = FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        fp8_weight_on_demand_transpose = (
+            FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        )
 
         # Configure quantizers
         if save_original_input and isinstance(input_quantizers[0], Float8Quantizer):
@@ -113,7 +115,7 @@ def forward(
             for weight_quantizer in weight_quantizers:
                 weight_quantizer.set_usage(
                     rowwise=True,
-                    columnwise=columnwise_usage and not fp8_weight_on_demand_transpose
+                    columnwise=columnwise_usage and not fp8_weight_on_demand_transpose,
                 )
             if output_quantizers[0] is not None:
                 for output_quantizer in output_quantizers:
@@ -212,7 +214,10 @@ def forward(
             inputmats = [None] * num_gemms
         if inp.requires_grad:
             for weight in weights_fp8:
-                if isinstance(weight, QuantizedTensorBase) and not fp8_weight_on_demand_transpose:
+                if (
+                    isinstance(weight, QuantizedTensorBase)
+                    and not fp8_weight_on_demand_transpose
+                ):
                     weight.update_usage(columnwise_usage=True)
 
         tensors_to_save, tensor_objects = prepare_for_saving(

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 4 additions & 3 deletions
@@ -136,7 +136,9 @@ def forward(
         if ub_name is not None:
             nvtx_label = f"{nvtx_label}.{ub_name}"
 
-        fp8_weight_on_demand_transpose = FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        fp8_weight_on_demand_transpose = (
+            FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        )
 
         # Make sure input dimensions are compatible
         out_features, in_features = weight.shape
@@ -279,8 +281,7 @@ def forward(
        # Configure quantizer
        if weight_quantizer is not None:
            weight_quantizer.set_usage(
-                rowwise=True,
-                columnwise=is_grad_enabled and not fp8_weight_on_demand_transpose
+                rowwise=True, columnwise=is_grad_enabled and not fp8_weight_on_demand_transpose
            )
 
        # Get quantized weight

transformer_engine/pytorch/module/linear.py

Lines changed: 8 additions & 4 deletions
@@ -129,7 +129,9 @@ def forward(
         out_features, in_features = weight.shape
         assert inp.shape[-1] == in_features, "GEMM not possible"
 
-        fp8_weight_on_demand_transpose = FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        fp8_weight_on_demand_transpose = (
+            FP8GlobalStateManager.is_blockwise_fp8_weight_on_demand_transpose()
+        )
 
         # Configure tensor-parallel communication
         tp_world_size = get_distributed_world_size(tp_group)
@@ -245,8 +247,7 @@ def forward(
                 and not in_fp8_activation_recompute_phase()
             )
             weight_quantizer.set_usage(
-                rowwise=True,
-                columnwise=columnwise_usage and not fp8_weight_on_demand_transpose
+                rowwise=True, columnwise=columnwise_usage and not fp8_weight_on_demand_transpose
             )
 
             # Get quantized weight
@@ -384,7 +385,10 @@ def forward(
 
         # Weight with column-wise usage is needed for dgrad GEMM.
         if inp.requires_grad:
-            if isinstance(weightmat, QuantizedTensorBase) and not fp8_weight_on_demand_transpose:
+            if (
+                isinstance(weightmat, QuantizedTensorBase)
+                and not fp8_weight_on_demand_transpose
+            ):
                 weightmat.update_usage(columnwise_usage=True)
 
         if cpu_offloading and saved_inputmat is not None:
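
The `update_usage(columnwise_usage=True)` call above is skipped when on-demand transpose is enabled because, as the comment notes, the column-wise (transposed) weight is only consumed by the dgrad GEMM in backward. A plain-PyTorch illustration (no FP8, hypothetical shapes) of why deferring that layout conversion cannot change the gradient:

import torch

torch.manual_seed(0)
x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(64, 32)

# fprop GEMM uses the row-wise weight: y = x @ W^T
y = x @ w.t()
dy = torch.randn_like(y)
(dx_reference,) = torch.autograd.grad(y, x, dy)

# dgrad GEMM wants the weight the other way around: dx = dy @ W.
# Producing the transposed copy only now, instead of during fprop, is equivalent.
w_columnwise = w.t().contiguous()
dx_on_demand = dy @ w_columnwise.t()

torch.testing.assert_close(dx_reference, dx_on_demand)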

transformer_engine/pytorch/tensor/float8_blockwise_tensor.py

Lines changed: 3 additions & 4 deletions
@@ -803,9 +803,8 @@ def backward(
 def get_columnwise_fp8_tensor(rowwise_tensor, requires_grad=False):
     columnwise_scale_inv = rowwise_tensor._rowwise_scale_inv.transpose(-2, -1).contiguous()
     M, N = rowwise_tensor.shape
-    columnwise_data = torch.empty((N, M),
-                                  device=rowwise_tensor.device,
-                                  dtype=rowwise_tensor._rowwise_data.dtype
+    columnwise_data = torch.empty(
+        (N, M), device=rowwise_tensor.device, dtype=rowwise_tensor._rowwise_data.dtype
     )
     tex.fp8_transpose(rowwise_tensor._rowwise_data, rowwise_tensor._fp8_dtype, out=columnwise_data)
 
@@ -820,4 +819,4 @@ def get_columnwise_fp8_tensor(rowwise_tensor, requires_grad=False):
         quantizer=rowwise_tensor._quantizer,
         is_2D_scaled=True,
         requires_grad=requires_grad,
-    )
+    )
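
`get_columnwise_fp8_tensor` builds the column-wise view of a block-scaled FP8 tensor by transposing both the payload (via `tex.fp8_transpose`) and the per-block scale-inverse grid. A simplified stand-in with ordinary tensors (uint8 payload, `.t().contiguous()` in place of the FP8 transpose kernel, hypothetical `dequant` helper) showing that both views decode to the same matrix:

import torch

BLOCK = 4
M, N = 8, 12
rowwise_data = torch.randint(0, 255, (M, N), dtype=torch.uint8)  # stand-in for FP8 payload
rowwise_scale_inv = torch.rand(M // BLOCK, N // BLOCK)           # one scale per BLOCK x BLOCK tile

# Column-wise view: transpose the scale grid and the payload together.
columnwise_scale_inv = rowwise_scale_inv.transpose(-2, -1).contiguous()
columnwise_data = rowwise_data.t().contiguous()                  # stand-in for tex.fp8_transpose


def dequant(data, scale_inv):
    # Spread each per-tile scale back over its BLOCK x BLOCK tile, then rescale.
    scales = scale_inv.repeat_interleave(BLOCK, 0).repeat_interleave(BLOCK, 1)
    return data.float() * scales


torch.testing.assert_close(
    dequant(rowwise_data, rowwise_scale_inv),
    dequant(columnwise_data, columnwise_scale_inv).t(),
)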
