@@ -146,7 +146,7 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
146146 nvte_tensor_input_list.push_back (input_list[i].data ());
147147 nvte_tensor_output_list.push_back (output_list[i].data ());
148148 }
149-
149+
150150 // stochastic rounding support for multi tensor
151151 if (quantizer->stochastic_rounding ) {
152152 // TODO: implement stochastic rounding support for multi tensor
@@ -160,29 +160,26 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
160160
161161 // with or without RHT, use nvte_multi_hadamard_transform_amax
162162 // out.amax is the rowwise amax, out.columnwise_amax is the columnwise amax
163- // rowwise amax will be the amax of original amax(input)
163+ // rowwise amax will be the amax of original amax(input)
164164 // columnwise amax will be the amax of the amax(RHT(input.t))
165165 if (quantizer->with_rht ) {
166166 // bf16 only for now
167- NVTE_CHECK (input.dtype () == DType::kBFloat16 , " NVFP4 multi_quantize: RHT is only supported for bfloat16 input" );
167+ NVTE_CHECK (input.dtype () == DType::kBFloat16 ,
168+ " NVFP4 multi_quantize: RHT is only supported for bfloat16 input" );
168169 if (quantizer->with_post_rht_amax ) {
169170 // We need:
170171 // 1. Rowwise amax = amax for input
171172 // 2. Columnwise amax = amax for RHT(input.t)
172173 NVTE_SCOPED_GIL_RELEASE ({
173174 nvte_multi_hadamard_transform_amax (
174- input.data (),
175- reinterpret_cast <NVTETensor*>(nvte_tensor_output_list.data ()),
176- split_sections.data (),
177- num_tensors,
178- 0 ,
179- quantizer->rht_matrix_random_sign_mask_t ,
175+ input.data (), reinterpret_cast <NVTETensor *>(nvte_tensor_output_list.data ()),
176+ split_sections.data (), num_tensors, 0 , quantizer->rht_matrix_random_sign_mask_t ,
180177 stream);
181178 });
182- }else {
179+ } else {
183180 NVTE_CHECK (false , " NVFP4 multi_quantize: Pre-RHT amax is not supported yet" );
184181 }
185- }else {
182+ } else {
186183 // TODO: implement this too when we disable RHT
187184 NVTE_CHECK (false , " NVFP4 multi_quantize: RHT is not supported when RHT is disabled for now" );
188185 }
@@ -191,7 +188,7 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
191188 if (quantizer->with_rht ) {
192189 // check the availability of RHT matrix definition for best perf
193190 NVTE_CHECK (quantizer->rht_matrix .defined () && quantizer->rht_matrix .numel () > 0 ,
194- " NVFP4 multi_quantize: RHT matrix is not set" );
191+ " NVFP4 multi_quantize: RHT matrix is not set" );
195192 auto rht_matrix_nvte = makeTransformerEngineTensor (quantizer->rht_matrix );
196193
197194 NVTE_SCOPED_GIL_RELEASE ({
@@ -211,12 +208,15 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
211208 out_identity.set_rowwise_scale_inv (out_identity_scale_inv.data_ptr ,
212209 static_cast <DType>(out_identity_scale_inv.dtype ),
213210 out_identity_scale_inv.shape );
214- out_identity.set_amax (out_identity_amax.data_ptr , static_cast <DType>(out_identity_amax.dtype ),
211+ out_identity.set_amax (out_identity_amax.data_ptr ,
212+ static_cast <DType>(out_identity_amax.dtype ),
215213 out_identity_amax.shape );
216-
217- NVTE_SCOPED_GIL_RELEASE (
218- { nvte_quantize_v2 (input_list[i].data (), out_identity.data (), quant_config_list[i], stream); });
219- }
214+
215+ NVTE_SCOPED_GIL_RELEASE ({
216+ nvte_quantize_v2 (input_list[i].data (), out_identity.data (), quant_config_list[i],
217+ stream);
218+ });
219+ }
220220
221221 // already eligible for RHT columnwise cast fusion after the dimension check
222222 if (quantizer->columnwise_usage ) {
@@ -240,16 +240,17 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
240240 colwise_data_shape_2d.push_back (last_dim);
241241
242242 out_transpose.set_rowwise_data (out_columnwise_data.data_ptr ,
243- static_cast <DType>(out_columnwise_data.dtype ),
244- colwise_data_shape_2d);
243+ static_cast <DType>(out_columnwise_data.dtype ),
244+ colwise_data_shape_2d);
245245 out_transpose.set_rowwise_scale_inv (out_columnwise_scale_inv.data_ptr ,
246246 static_cast <DType>(out_columnwise_scale_inv.dtype ),
247247 out_columnwise_scale_inv.shape );
248248 out_transpose.set_amax (out_columnwise_amax.data_ptr ,
249- static_cast <DType>(out_columnwise_amax.dtype ),
250- out_columnwise_amax.shape );
251- nvte_hadamard_transform_cast_fusion_columnwise (
252- input_list[i].data (), out_transpose.data (), rht_matrix_nvte.data (), quant_config_list[i], stream);
249+ static_cast <DType>(out_columnwise_amax.dtype ),
250+ out_columnwise_amax.shape );
251+ nvte_hadamard_transform_cast_fusion_columnwise (input_list[i].data (), out_transpose.data (),
252+ rht_matrix_nvte.data (),
253+ quant_config_list[i], stream);
253254 }
254255 }
255256 });
@@ -264,7 +265,6 @@ void multi_tensor_quantize_nvfp4_impl(const TensorWrapper &input,
264265 }
265266 });
266267 }
267-
268268}
269269
270270void multi_tensor_quantize_impl (const TensorWrapper &single_input,
@@ -290,7 +290,7 @@ void multi_tensor_quantize_impl(const TensorWrapper &single_input,
290290
291291 // check if split_sections is just a dummy input
292292 bool valid_split_sections = split_sections.size () == num_tensors;
293-
293+
294294 // Check scaling mode consistency across all tensors
295295 for (size_t i = 0 ; i < num_tensors; i++) {
296296 if (detail::IsFloat8Quantizers (quantizer_py_list[i].ptr ())) {
@@ -300,7 +300,7 @@ void multi_tensor_quantize_impl(const TensorWrapper &single_input,
300300 with_fused_kernel = false ;
301301 break ;
302302 }
303- // check if the scaling mode is fp8 delayed scaling for all quantizers
303+ // check if the scaling mode is fp8 delayed scaling for all quantizers
304304 if (scaling_mode != NVTE_DELAYED_TENSOR_SCALING) {
305305 with_fused_kernel = false ;
306306 break ;
@@ -317,12 +317,12 @@ void multi_tensor_quantize_impl(const TensorWrapper &single_input,
317317 if (split_sections[i] % 64 != 0 ) {
318318 with_fused_kernel = false ;
319319 break ;
320- }
321- }else {
320+ }
321+ } else {
322322 with_fused_kernel = false ;
323323 break ;
324324 }
325-
325+
326326 } else {
327327 with_fused_kernel = false ;
328328 break ;
0 commit comments