
Commit b8a1722

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5cefbcf commit b8a1722

3 files changed (+65, -49 lines)
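The edits below are mechanical: the pre-commit hooks re-wrap statements that exceed the configured line length and normalize whitespace and brace spacing, without changing behavior. As a rough illustration of the wrapping style (assuming a Black-style Python formatter with a roughly 100-character limit; the names below are hypothetical, not taken from the repository):

# Hypothetical example of the re-wrapping applied throughout this commit.
limit = 64
value = 128

# Before: a single assert that exceeds the configured line length.
assert value % limit == 0, "value must be a multiple of the configured alignment limit"

# After: the same assert, parenthesized so the condition and message sit on separate lines.
assert (
    value % limit == 0
), "value must be a multiple of the configured alignment limit"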

tests/pytorch/nvfp4/test_nvfp4_group_quantize.py

Lines changed: 32 additions & 19 deletions
@@ -25,6 +25,7 @@

 recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True)

+
 def generate_random_multiples_sum(total=8192, n=4, multiple=64):
     if total % multiple != 0:
         raise ValueError(f"Total ({total}) must be a multiple of {multiple}")
@@ -38,13 +39,14 @@ def generate_random_multiples_sum(total=8192, n=4, multiple=64):
     cuts = sorted(random.sample(range(1, total_units), n - 1))

     # convert to segment lengths
-    parts = [cuts[0]] + \
-            [cuts[i] - cuts[i-1] for i in range(1, len(cuts))] + \
-            [total_units - cuts[-1]]
+    parts = (
+        [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]]
+    )

     # convert back to multiples
     return [p * multiple for p in parts]

+
 def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]:
     least_multiple = 64
     num_chunks = 4
@@ -53,7 +55,7 @@ def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]:
     avg_split = M // num_chunks

     if M == 0 or N == 0:
-        # all zeros
+        # all zeros
         return [0] * num_chunks
     if edge_cases == "regular":
         split_sections = [avg_split] * num_chunks
@@ -73,7 +75,9 @@ def generate_split_sections(M: int, N: int, edge_cases: str) -> list[int]:

     # make sure every split_section is a multiple of least_multiple
     for split_section in split_sections:
-        assert split_section % least_multiple == 0, "The split_sections are not multiples of least_multiple"
+        assert (
+            split_section % least_multiple == 0
+        ), "The split_sections are not multiples of least_multiple"

     return split_sections

@@ -175,8 +179,8 @@ def check_group_quantization_nvfp4_versus_reference(
         )
         for _ in range(len(split_sections))
     ]
-    x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = reference_group_quantize(
-        x, quantizers, split_sections, return_transpose
+    x_qx_ref, x_sx_ref, x_amax_rowwise_ref, x_qx_t_ref, x_sx_t_ref, x_amax_colwise_ref = (
+        reference_group_quantize(x, quantizers, split_sections, return_transpose)
     )

     split_quantize_outputs = tex.split_quantize(x, split_sections, quantizers)
@@ -195,27 +199,31 @@ def check_group_quantization_nvfp4_versus_reference(
         torch.testing.assert_close(x_amax_rowwise[i], x_amax_rowwise_ref[i], atol=0.0, rtol=0.0)
         torch.testing.assert_close(x_qx[i], x_qx_ref[i], atol=0.0, rtol=0.0)
         valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, False)
-        x_sx_valid = x_sx[i][:valid_scale_shape[0], :valid_scale_shape[1]]
-        x_sx_ref_valid = x_sx_ref[i][:valid_scale_shape[0], :valid_scale_shape[1]]
+        x_sx_valid = x_sx[i][: valid_scale_shape[0], : valid_scale_shape[1]]
+        x_sx_ref_valid = x_sx_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]]
         torch.testing.assert_close(x_sx_valid, x_sx_ref_valid, atol=0.0, rtol=0.0)
-
+
     if return_transpose:
-        x_qx_t = [output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs]
+        x_qx_t = [
+            output._columnwise_data.view(dtype=torch.uint8) for output in split_quantize_outputs
+        ]
         x_sx_t = [output._columnwise_scale_inv for output in split_quantize_outputs]
         x_amax_colwise = [output._amax_columnwise for output in split_quantize_outputs]
-        # assert with zero tolerance
+        # assert with zero tolerance
         for i in range(len(x_qx_t)):
             if split_sections[i] == 0:
                 # then just assert the same same and dtype because the buffer won't be zero out
                 assert_same_shape_and_dtype(x_amax_colwise[i], x_amax_colwise_ref[i])
                 assert_same_shape_and_dtype(x_qx_t[i], x_qx_t_ref[i])
                 assert_same_shape_and_dtype(x_sx_t[i], x_sx_t_ref[i])
-            else:
-                torch.testing.assert_close(x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0)
+            else:
+                torch.testing.assert_close(
+                    x_amax_colwise[i], x_amax_colwise_ref[i], atol=0.0, rtol=0.0
+                )
             torch.testing.assert_close(x_qx_t[i], x_qx_t_ref[i], atol=0.0, rtol=0.0)
             valid_scale_shape = get_nvfp4_scale_shape_no_padding(x_splits[i].shape, True)
-            x_sx_t_valid = x_sx_t[i][:valid_scale_shape[0], :valid_scale_shape[1]]
-            x_sx_t_ref_valid = x_sx_t_ref[i][:valid_scale_shape[0], :valid_scale_shape[1]]
+            x_sx_t_valid = x_sx_t[i][: valid_scale_shape[0], : valid_scale_shape[1]]
+            x_sx_t_ref_valid = x_sx_t_ref[i][: valid_scale_shape[0], : valid_scale_shape[1]]
             torch.testing.assert_close(x_sx_t_valid, x_sx_t_ref_valid, atol=0.0, rtol=0.0)


@@ -234,7 +242,14 @@ def check_group_quantization_nvfp4_versus_reference(
 )
 @pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str)
 @pytest.mark.parametrize(
-    "edge_cases", ["regular", "zero_tokens_front", "zero_tokens_end", "zero_tokens_middle", "random_uneven_split"]
+    "edge_cases",
+    [
+        "regular",
+        "zero_tokens_front",
+        "zero_tokens_end",
+        "zero_tokens_middle",
+        "random_uneven_split",
+    ],
 )
 @pytest.mark.parametrize(
     "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"]
@@ -263,5 +278,3 @@ def test_rht_with_quantization_block_tiling_versus_reference(
         with_post_rht_amax=True,
         with_random_sign_mask=with_random_sign_mask,
     )
-
-
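Taken together, the two reformatted helpers above generate split sections that sum to M while keeping each chunk a multiple of 64. A standalone sketch of that splitting logic, adapted from the hunks above (the total_units computation is inferred, since that line falls outside the shown context, and the usage at the bottom is illustrative rather than part of the test):

import random


def generate_random_multiples_sum(total=8192, n=4, multiple=64):
    # Split `total` into n parts, each a positive multiple of `multiple`.
    if total % multiple != 0:
        raise ValueError(f"Total ({total}) must be a multiple of {multiple}")
    total_units = total // multiple  # inferred intermediate step, not shown in the diff
    # pick n - 1 distinct cut points among the units
    cuts = sorted(random.sample(range(1, total_units), n - 1))
    # convert cut points to segment lengths
    parts = (
        [cuts[0]] + [cuts[i] - cuts[i - 1] for i in range(1, len(cuts))] + [total_units - cuts[-1]]
    )
    # scale back up to multiples of `multiple`
    return [p * multiple for p in parts]


# Illustrative usage: four chunks that sum to 512, each a multiple of 64.
sections = generate_random_multiples_sum(total=512, n=4, multiple=64)
assert sum(sections) == 512 and all(s % 64 == 0 for s in sections)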
tests/pytorch/test_numerics.py

Lines changed: 4 additions & 1 deletion
@@ -12,7 +12,10 @@
 import torch.nn as nn
 from torch.nn import Parameter

-from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_align_size_for_quantization
+from transformer_engine.pytorch.quantization import (
+    FP8GlobalStateManager,
+    get_align_size_for_quantization,
+)
 from transformer_engine.pytorch.utils import (
     init_method_normal,
     scaled_init_method_normal,

transformer_engine/pytorch/csrc/extensions/cast.cpp

Lines changed: 29 additions & 29 deletions
@@ -146,7 +146,7 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
     nvte_tensor_input_list.push_back(input_list[i].data());
     nvte_tensor_output_list.push_back(output_list[i].data());
   }
-
+
   // stochastic rounding support for multi tensor
   if (quantizer->stochastic_rounding) {
     // TODO: implement stochastic rounding support for multi tensor
@@ -160,29 +160,26 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,

   // with or without RHT, use nvte_multi_hadamard_transform_amax
   // out.amax is the rowwise amax, out.columnwise_amax is the columnwise amax
-  // rowwise amax will be the amax of original amax(input)
+  // rowwise amax will be the amax of original amax(input)
   // columnwise amax will be the amax of the amax(RHT(input.t))
   if (quantizer->with_rht) {
     // bf16 only for now
-    NVTE_CHECK(input.dtype() == DType::kBFloat16, "NVFP4 multi_quantize: RHT is only supported for bfloat16 input");
+    NVTE_CHECK(input.dtype() == DType::kBFloat16,
+               "NVFP4 multi_quantize: RHT is only supported for bfloat16 input");
     if (quantizer->with_post_rht_amax) {
       // We need:
       // 1. Rowwise amax = amax for input
       // 2. Columnwise amax = amax for RHT(input.t)
       NVTE_SCOPED_GIL_RELEASE({
         nvte_multi_hadamard_transform_amax(
-            input.data(),
-            reinterpret_cast<NVTETensor*>(nvte_tensor_output_list.data()),
-            split_sections.data(),
-            num_tensors,
-            0,
-            quantizer->rht_matrix_random_sign_mask_t,
+            input.data(), reinterpret_cast<NVTETensor *>(nvte_tensor_output_list.data()),
+            split_sections.data(), num_tensors, 0, quantizer->rht_matrix_random_sign_mask_t,
             stream);
       });
-    }else {
+    } else {
       NVTE_CHECK(false, "NVFP4 multi_quantize: Pre-RHT amax is not supported yet");
     }
-  }else {
+  } else {
     // TODO: implement this too when we disable RHT
     NVTE_CHECK(false, "NVFP4 multi_quantize: RHT is not supported when RHT is disabled for now");
   }
@@ -191,7 +188,7 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
   if (quantizer->with_rht) {
     // check the availablibilty of RHT matrix definition for best perf
     NVTE_CHECK(quantizer->rht_matrix.defined() && quantizer->rht_matrix.numel() > 0,
-               "NVFP4 multi_quantize: RHT matrix is not set");
+               "NVFP4 multi_quantize: RHT matrix is not set");
     auto rht_matrix_nvte = makeTransformerEngineTensor(quantizer->rht_matrix);

     NVTE_SCOPED_GIL_RELEASE({
@@ -211,12 +208,15 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
        out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr,
                                           static_cast<DType>(out_identity_scale_inv.dtype),
                                           out_identity_scale_inv.shape);
-        out_identity.set_amax(out_identity_amax.data_ptr, static_cast<DType>(out_identity_amax.dtype),
+        out_identity.set_amax(out_identity_amax.data_ptr,
+                              static_cast<DType>(out_identity_amax.dtype),
                               out_identity_amax.shape);
-
-        NVTE_SCOPED_GIL_RELEASE(
-            { nvte_quantize_v2(input_list[i].data(), out_identity.data(), quant_config_list[i], stream); });
-      }
+
+        NVTE_SCOPED_GIL_RELEASE({
+          nvte_quantize_v2(input_list[i].data(), out_identity.data(), quant_config_list[i],
+                           stream);
+        });
+      }

       // already eligible for RHT columnwise cast fusion after the dimension check
       if (quantizer->columnwise_usage) {
@@ -240,16 +240,17 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
          colwise_data_shape_2d.push_back(last_dim);

          out_transpose.set_rowwise_data(out_columnwise_data.data_ptr,
-                                         static_cast<DType>(out_columnwise_data.dtype),
-                                         colwise_data_shape_2d);
+                                         static_cast<DType>(out_columnwise_data.dtype),
+                                         colwise_data_shape_2d);
          out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr,
                                              static_cast<DType>(out_columnwise_scale_inv.dtype),
                                              out_columnwise_scale_inv.shape);
          out_transpose.set_amax(out_columnwise_amax.data_ptr,
-                                static_cast<DType>(out_columnwise_amax.dtype),
-                                out_columnwise_amax.shape);
-          nvte_hadamard_transform_cast_fusion_columnwise(
-              input_list[i].data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config_list[i], stream);
+                                static_cast<DType>(out_columnwise_amax.dtype),
+                                out_columnwise_amax.shape);
+          nvte_hadamard_transform_cast_fusion_columnwise(input_list[i].data(), out_transpose.data(),
+                                                         rht_matrix_nvte.data(),
+                                                         quant_config_list[i], stream);
        }
      }
    });
@@ -264,7 +265,6 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
      }
    });
  }
-
}

void multi_tensor_quantize_impl(const TensorWrapper &single_input,
@@ -290,7 +290,7 @@ void multi_tensor_quantize_impl(const TensorWrapper &single_input,

  // check if split_sections is just a dummy input
  bool valid_split_sections = split_sections.size() == num_tensors;
-
+
  // Check scaling mode consistency across all tensors
  for (size_t i = 0; i < num_tensors; i++) {
    if (detail::IsFloat8Quantizers(quantizer_py_list[i].ptr())) {
@@ -300,7 +300,7 @@ void multi_tensor_quantize_impl(const TensorWrapper &single_input,
        with_fused_kernel = false;
        break;
      }
-      // check if the scaling mode is fp8 delayed scaling for all quantizers
+      // check if the scaling mode is fp8 delayed scaling for all quantizers
      if (scaling_mode != NVTE_DELAYED_TENSOR_SCALING) {
        with_fused_kernel = false;
        break;
@@ -317,12 +317,12 @@ void multi_tensor_quantize_impl(const TensorWrapper &single_input,
        if (split_sections[i] % 64 != 0) {
          with_fused_kernel = false;
          break;
-        }
-      }else {
+        }
+      } else {
        with_fused_kernel = false;
        break;
      }
-
+
    } else {
      with_fused_kernel = false;
      break;
