@@ -99,13 +99,13 @@ def short_str(self):
 
     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
-            assert (
-                self.static_scale is not None
-            ), "static_scale must be specified for static scaling"
+            assert self.static_scale is not None, (
+                "static_scale must be specified for static scaling"
+            )
         if self.scaling_granularity is ScalingGranularity.AXISWISE:
-            assert (
-                self.scaling_type is ScalingType.DYNAMIC
-            ), "only dynamic scaling type is supported for axiswise scaling granularity"
+            assert self.scaling_type is ScalingType.DYNAMIC, (
+                "only dynamic scaling type is supported for axiswise scaling granularity"
+            )
         assert self.target_dtype is None or (
             self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
         ), "must specify a 8-bit floating-point dtype"
@@ -130,9 +130,9 @@ class DelayedScalingConfig:
     scale_fn_name: str = "max"
 
     def __post_init__(self):
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+        assert self.scale_fn_name == "max", (
+            f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+        )
 
 
 @dataclass(frozen=True)
@@ -148,7 +148,6 @@ class Float8GemmConfig:
 
 # Pre-made recipes for common configurations
 class Float8LinearRecipeName(enum.Enum):
-
     # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
     TENSORWISE = "tensorwise"
 
@@ -291,7 +290,9 @@ def __post_init__(self):
 
         # float8 all-gather only supports tensorwise, in the future may support blockwise
         if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
-            assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            assert not self.enable_fsdp_float8_all_gather, (
+                f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            )
 
         # save some characters in the compatibility checks below
         cc_i = self.cast_config_input
@@ -310,9 +311,9 @@ def __post_init__(self):
         ):
             is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
             is_disabled_2 = cc1.scaling_type is ScalingType.DISABLED
-            assert (
-                is_disabled_1 == is_disabled_2
-            ), f"incompatible operand precision for {gemm_name}"
+            assert is_disabled_1 == is_disabled_2, (
+                f"incompatible operand precision for {gemm_name}"
+            )
 
         for cc1, cc2, operand_name, default_dtype in [
             (cc_i, cc_i_gw, "input", e4m3_dtype),
@@ -324,9 +325,9 @@ def __post_init__(self):
                 object.__setattr__(cc1, "target_dtype", default_dtype)
             if cc2.target_dtype is None:
                 object.__setattr__(cc2, "target_dtype", default_dtype)
-            assert (
-                cc1.target_dtype == cc2.target_dtype
-            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            assert cc1.target_dtype == cc2.target_dtype, (
+                f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            )
 
         # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
         if (
@@ -357,9 +358,9 @@ def from_recipe_name(
357358 """
358359 if type (recipe_name ) == str :
359360 valid_names = [n .value for n in Float8LinearRecipeName ]
360- assert (
361- recipe_name in valid_names
362- ), f"recipe_name { recipe_name } not in valid names { valid_names } "
361+ assert recipe_name in valid_names , (
362+ f" recipe_name { recipe_name } not in valid names { valid_names } "
363+ )
363364 recipe_name = Float8LinearRecipeName (recipe_name )
364365
365366 if recipe_name is Float8LinearRecipeName .TENSORWISE :
@@ -385,7 +386,6 @@ def from_recipe_name(
             )
 
         elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
-
             # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
             cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
             cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
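For reference, a short usage sketch of the recipe lookup validated above. The enclosing class and module path (Float8LinearConfig in torchao.float8.config) are assumptions inferred from the surrounding code, not shown in this diff.

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName  # assumed path and class name

# A string argument is checked against the enum values, then converted to the enum.
config = Float8LinearConfig.from_recipe_name("tensorwise")

# Passing an enum member directly skips the string-validation branch.
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)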