@@ -2733,7 +2733,7 @@ def _run_module(m, inp):
 
 def test_fp8_weight_on_demand_transpose():
     if not fp8_block_scaling_available:
-        pytest.skip("blockwise fp8 not available.")
+        pytest.skip("blockwise fp8 not available.")
 
     dtype = torch.bfloat16
     num_gemms = 4
@@ -2758,7 +2758,9 @@ def test_fp8_weight_on_demand_transpose():
 
     # Share params
     with torch.no_grad():
-        weights_cache = [Parameter(getattr(grouped_linear, f"weight{i}").clone()) for i in range(num_gemms)]
+        weights_cache = [
+            Parameter(getattr(grouped_linear, f"weight{i}").clone()) for i in range(num_gemms)
+        ]
 
     for i in range(num_gemms):
         assert getattr(grouped_linear, f"weight{i}")._columnwise_data is not None
@@ -2811,7 +2813,6 @@ def test_fp8_weight_on_demand_transpose():
     for i, (o, o_ref) in enumerate(zip(outputs1, outputs2)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 
-
     # 2. layernorm linear module test
     FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
     with fp8_model_init(enabled=True, recipe=fp8_recipe):
@@ -2831,9 +2832,7 @@ def test_fp8_weight_on_demand_transpose():
     # Share params
     weights_cache = []
     with torch.no_grad():
-        weights_cache.append(
-            te_ln_linear.te_module.layer_norm_weight.clone()
-        )
+        weights_cache.append(te_ln_linear.te_module.layer_norm_weight.clone())
         weights_cache.append(te_ln_linear.te_module.weight.clone())
     outputs1 = _test_granular_accuracy(te_ln_linear, bs, dtype, config, recipe=fp8_recipe)
 
@@ -2861,7 +2860,6 @@ def test_fp8_weight_on_demand_transpose():
     for i, (o, o_ref) in enumerate(zip(outputs1, outputs2)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 
-
     # 3. linear module test
     FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
     with fp8_model_init(enabled=True, recipe=fp8_recipe):
@@ -2880,7 +2878,9 @@ def test_fp8_weight_on_demand_transpose():
     with torch.no_grad():
         weights_cache = te_linear.weight.clone()
 
-    te_outputs1 = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe)
+    te_outputs1 = _test_granular_accuracy(
+        te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe
+    )
 
     FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True
     with fp8_model_init(enabled=True, recipe=fp8_recipe):
@@ -2896,10 +2896,12 @@ def test_fp8_weight_on_demand_transpose():
     assert te_linear.weight._columnwise_data is None
 
     te_linear.weight = Parameter(weights_cache)
-    te_outputs2 = _test_granular_accuracy(te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe)
+    te_outputs2 = _test_granular_accuracy(
+        te_linear, bs, dtype, config, delay_wgrad_compute=True, recipe=fp8_recipe
+    )
 
     # should be bit-wise match
     for i, (o, o_ref) in enumerate(zip(outputs1, outputs2)):
         torch.testing.assert_close(o, o_ref, rtol=0, atol=0)
 
-    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = old_value
+    FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = old_value
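For reference, the pattern every module variant above exercises is the same: build the module with the on-demand-transpose flag off (columnwise FP8 weight data kept), build it again with the flag on (columnwise data dropped at init, so `_columnwise_data is None`), share the cached weights between the two, and require bit-wise identical outputs. Below is a minimal standalone sketch of that pattern, not the test itself: the recipe class name `Float8BlockScaling`, the 128-wide layer size, and availability of CUDA hardware with blockwise FP8 support are assumptions; `fp8_model_init(recipe=...)`, `FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE`, and `_columnwise_data` are taken from the diff.

# Minimal sketch of the toggle-and-compare pattern (assumptions noted above).
import torch
from torch.nn import Parameter
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

fp8_recipe = recipe.Float8BlockScaling()  # assumed blockwise FP8 recipe class
x = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
old_value = FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE

# Reference path: columnwise (transposed) FP8 weight data is kept at init.
FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = False
with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
    ref = te.Linear(128, 128, params_dtype=torch.bfloat16)
with torch.no_grad():
    cached_weight = ref.weight.clone()
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out_ref = ref(x)

# On-demand path: columnwise data is dropped at init and rebuilt when needed.
FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = True
with te.fp8_model_init(enabled=True, recipe=fp8_recipe):
    mod = te.Linear(128, 128, params_dtype=torch.bfloat16)
assert mod.weight._columnwise_data is None
mod.weight = Parameter(cached_weight)  # share params so outputs are comparable
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = mod(x)

# The two paths should match bit-for-bit, then restore the global flag.
torch.testing.assert_close(out, out_ref, rtol=0, atol=0)
FP8GlobalStateManager.FP8_BLOCKWISE_WEIGHT_ON_DEMAND_TRANSPOSE = old_value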