@@ -34,6 +34,15 @@ def __init__(self, init_optimizer, *args, **kwargs):    # noqa: C901
             **kwargs: Arbitrary keyword arguments.
         """
         self.fp8_param_groups = []
+        dtype = torch.float16
+        for pg in init_optimizer.param_groups:
+            for p in pg['params']:
+                if p.requires_grad and not isinstance(p, ScalingTensor):
+                    dtype = p.dtype
+                    break
+
+        fake_param = torch.nn.parameter.Parameter(torch.zeros((), dtype=dtype))
+        fake_index = 0
         for pg in init_optimizer.param_groups:
             fp8_params = []
             hp_params = []
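Note: this hunk infers the dtype of the first regular (non-ScalingTensor) trainable parameter and pre-builds a 0-dim placeholder parameter in that dtype, falling back to float16. The `break` only exits the inner loop, so when several groups contain regular parameters the dtype found in the last such group wins; in practice all high-precision parameters are expected to share one dtype. A minimal, self-contained sketch of the placeholder construction (the bfloat16 group here is illustrative, not from the commit):

import torch

param_groups = [{'params': [torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))]}]

dtype = torch.float16    # fallback when no regular trainable parameter exists
for pg in param_groups:
    for p in pg['params']:
        if p.requires_grad:
            dtype = p.dtype
            break    # exits the inner loop only

# 0-dim scalar parameter: a single element to flatten, in the model's dtype.
fake_param = torch.nn.parameter.Parameter(torch.zeros((), dtype=dtype))
print(fake_param.shape, fake_param.dtype)    # torch.Size([]) torch.bfloat16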
@@ -45,6 +54,13 @@ def __init__(self, init_optimizer, *args, **kwargs):    # noqa: C901
                 else:
                     hp_params.append(p)
             self.fp8_param_groups.append(fp8_params)
+            # DeepSpeedZeroOptimizer will crash if there are no parameters in any parameter group,
+            # so add a fake parameter.
+            if len(hp_params) == 0:
+                param_names = args[0]
+                param_names[fake_param] = 'fake_' + str(fake_index)
+                fake_index += 1
+                hp_params.append(fake_param)
             pg['params'] = hp_params
 
         assert len(self.fp8_param_groups) == len(init_optimizer.param_groups)
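Note: this is the core of the fix. A group whose parameters are all FP8 ScalingTensors would leave hp_params empty, and DeepSpeedZeroOptimizer cannot handle an empty parameter group, so the shared placeholder is appended and registered under a synthetic name. This assumes args[0] is the param_names mapping (Parameter -> name) that DeepSpeed passes positionally. A hedged sketch of the bookkeeping:

import torch

param_names = {}    # assumed shape of args[0]: Parameter -> name
fake_param = torch.nn.parameter.Parameter(torch.zeros(()))
fake_index = 0
hp_params = []    # imagine a group whose params were all FP8

if len(hp_params) == 0:
    param_names[fake_param] = 'fake_' + str(fake_index)
    fake_index += 1
    hp_params.append(fake_param)

print(list(param_names.values()), len(hp_params))    # ['fake_0'] 1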
@@ -139,9 +155,14 @@ def _pad_and_flat(self, values_partitions, group_fp8_mems, group_id):
             torch.Tensor: flat fp8 groups.
         """
         partition_size = dist.get_world_size(group=self.dp_process_group)
-        ref_value = values_partitions[0][0]
-        dtype = ref_value.dtype
-        assert all(v.dtype == dtype for v in chain(*values_partitions))
+        ref_value = None
+        for partition in values_partitions:
+            if len(partition) > 0:
+                ref_value = partition[0]
+                break
+        if ref_value is not None:
+            dtype = ref_value.dtype
+            assert all(v.dtype == dtype for v in chain(*values_partitions))
 
         align = self.fp8_nccl_start_alignment_factor
         max_flat_numels = max(group_fp8_mems)
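Note: with fake parameters in play, a rank's partition can now be empty, so _pad_and_flat can no longer blindly take values_partitions[0][0] as its dtype reference. It scans for the first non-empty partition instead, and skips the dtype check entirely when every partition is empty. Self-contained sketch of the scan:

import torch
from itertools import chain

values_partitions = [[], [torch.zeros(2), torch.zeros(3)]]    # rank 0 owns nothing

ref_value = None
for partition in values_partitions:
    if len(partition) > 0:
        ref_value = partition[0]
        break
if ref_value is not None:
    dtype = ref_value.dtype
    assert all(v.dtype == dtype for v in chain(*values_partitions))    # passes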
@@ -777,12 +798,12 @@ def all_gather_fp8_metas(self):
                 continue
             partition_size = len(params_partitions)
             scale_invs_partitions = [[p.meta.scale_inv for p in ps] for ps in params_partitions]
-            ref_scale = scale_invs_partitions[0][0]
             align = self.fp8_nccl_start_alignment_factor
             max_flat_numels = (max_flat_numels + align - 1) // align * align
             for pi in range(partition_size):
                 pad = max_flat_numels - numels[pi]
-                scale_invs_partitions[pi].append(ref_scale.new_empty((pad, )))
+                scale_invs_partitions[pi].append(torch.empty((pad, ), dtype=torch.float32, device='cuda'))
+
             scales = list(chain(*scale_invs_partitions))
             scale_invs_groups.append(scales)
             flat = _flatten_dense_tensors(scales)
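Note: same root cause here. scale_invs_partitions[0] may now be empty, so ref_scale = scale_invs_partitions[0][0] would raise an IndexError; the padding buffer is instead allocated directly, hard-coding float32 and 'cuda', which the change assumes match the scale_inv tensors. The surrounding round-up, (max_flat_numels + align - 1) // align * align, snaps the flattened length up to the next multiple of align; a quick check:

align = 8
for n in (0, 1, 8, 13):
    print(n, (n + align - 1) // align * align)    # 0->0, 1->8, 8->8, 13->16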