
Commit bf6f01a

fix bug in deep fp8 zero (#141)
**Description**
Fix two bugs in the fp8 ZeRO optimizer. MS-AMP will crash when:
- there are param groups that don't have any high-precision parameters
- partition 0 has no parameters
1 parent ca2d8d5 commit bf6f01a

File tree

1 file changed: +26 −5 lines


msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py

Lines changed: 26 additions & 5 deletions
@@ -34,6 +34,15 @@ def __init__(self, init_optimizer, *args, **kwargs):    # noqa: C901
             **kwargs: Arbitrary keyword arguments.
         """
         self.fp8_param_groups = []
+        dtype = torch.float16
+        for pg in init_optimizer.param_groups:
+            for p in pg['params']:
+                if p.requires_grad and not isinstance(p, ScalingTensor):
+                    dtype = p.dtype
+                    break
+
+        fake_param = torch.nn.parameter.Parameter(torch.zeros((), dtype=dtype))
+        fake_index = 0
         for pg in init_optimizer.param_groups:
             fp8_params = []
             hp_params = []
@@ -45,6 +54,13 @@ def __init__(self, init_optimizer, *args, **kwargs):    # noqa: C901
                 else:
                     hp_params.append(p)
             self.fp8_param_groups.append(fp8_params)
+            # DeepSpeedZeroOptimizer will crash if there is no parameters in any parameter group,
+            # so add a fake parameter.
+            if len(hp_params) == 0:
+                param_names = args[0]
+                param_names[fake_param] = 'fake_' + str(fake_index)
+                fake_index += 1
+                hp_params.append(fake_param)
             pg['params'] = hp_params

         assert len(self.fp8_param_groups) == len(init_optimizer.param_groups)
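
The first fix keeps DeepSpeed's ZeRO optimizer from ever seeing an empty high-precision param group: a zero-dimensional placeholder parameter is registered under a synthetic name whenever a group holds only fp8 (ScalingTensor) weights. Outside MS-AMP the same idea can be sketched with plain PyTorch; the helper name and the one-placeholder-per-group variant below are illustrative, not the commit's exact code.

import torch

def ensure_non_empty_groups(param_groups, param_names, dtype=torch.float16):
    """Illustrative sketch: give each empty optimizer param group a tiny placeholder.

    param_groups: optimizer-style list of dicts with a 'params' list.
    param_names: dict mapping parameters to names (a stand-in for the mapping the
    ZeRO optimizer receives); both arguments are assumptions for this sketch.
    """
    fake_index = 0
    for pg in param_groups:
        if len(pg['params']) == 0:
            # A 0-d parameter costs almost nothing but keeps the downstream
            # flatten/partition logic from tripping over an empty group.
            fake_param = torch.nn.parameter.Parameter(torch.zeros((), dtype=dtype))
            param_names[fake_param] = 'fake_' + str(fake_index)
            fake_index += 1
            pg['params'].append(fake_param)

# Example: group 1 has no high-precision params and would otherwise break ZeRO.
groups = [{'params': [torch.nn.Parameter(torch.randn(4))]}, {'params': []}]
names = {groups[0]['params'][0]: 'weight0'}
ensure_non_empty_groups(groups, names)

In the commit itself a single fake parameter is created once, with its dtype taken from the first trainable non-ScalingTensor parameter found across the groups (falling back to torch.float16), and reused for every empty group.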
@@ -139,9 +155,14 @@ def _pad_and_flat(self, values_partitions, group_fp8_mems, group_id):
             torch.Tensor: flat fp8 groups.
         """
         partition_size = dist.get_world_size(group=self.dp_process_group)
-        ref_value = values_partitions[0][0]
-        dtype = ref_value.dtype
-        assert all(v.dtype == dtype for v in chain(*values_partitions))
+        ref_value = None
+        for partition in values_partitions:
+            if len(partition) > 0:
+                ref_value = partition[0]
+                break
+        if ref_value is not None:
+            dtype = ref_value.dtype
+            assert all(v.dtype == dtype for v in chain(*values_partitions))

         align = self.fp8_nccl_start_alignment_factor
         max_flat_numels = max(group_fp8_mems)
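
The second fix matters when partition 0 (rank 0's shard) owns no fp8 parameters: the dtype reference can no longer be taken blindly from values_partitions[0][0]. The guard is easy to exercise standalone; the helper name below is illustrative.

from itertools import chain

import torch

def check_uniform_dtype(values_partitions):
    """Illustrative sketch: pick a reference tensor from the first non-empty
    partition and only assert a shared dtype when such a tensor exists."""
    ref_value = None
    for partition in values_partitions:
        if len(partition) > 0:
            ref_value = partition[0]
            break
    if ref_value is not None:
        dtype = ref_value.dtype
        assert all(v.dtype == dtype for v in chain(*values_partitions))

# Partition 0 is empty; indexing [0][0] here would previously raise IndexError.
check_uniform_dtype([[], [torch.zeros(2, dtype=torch.float16), torch.zeros(3, dtype=torch.float16)]])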
@@ -777,12 +798,12 @@ def all_gather_fp8_metas(self):
                 continue
             partition_size = len(params_partitions)
             scale_invs_partitions = [[p.meta.scale_inv for p in ps] for ps in params_partitions]
-            ref_scale = scale_invs_partitions[0][0]
             align = self.fp8_nccl_start_alignment_factor
             max_flat_numels = (max_flat_numels + align - 1) // align * align
             for pi in range(partition_size):
                 pad = max_flat_numels - numels[pi]
-                scale_invs_partitions[pi].append(ref_scale.new_empty((pad, )))
+                scale_invs_partitions[pi].append(torch.empty((pad, ), dtype=torch.float32, device='cuda'))
+
             scales = list(chain(*scale_invs_partitions))
             scale_invs_groups.append(scales)
             flat = _flatten_dense_tensors(scales)
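
The last hunk applies the same fix to the scale_inv all-gather: padding tensors are allocated directly with torch.empty instead of ref_scale.new_empty, so nothing depends on partition 0 holding a parameter. A CPU-only sketch of the pad-and-flatten step is below; the helper name, alignment, and example sizes are illustrative, and the real code allocates on CUDA with max_flat_numels computed elsewhere.

from itertools import chain

import torch
from torch._utils import _flatten_dense_tensors

def pad_and_flatten_scale_invs(scale_invs_partitions, align=4):
    """Illustrative sketch: pad every partition's scale_inv list to a common
    aligned length with fresh fp32 tensors, then flatten for an all-gather."""
    numels = [sum(s.numel() for s in ps) for ps in scale_invs_partitions]
    max_flat_numels = (max(numels) + align - 1) // align * align
    for pi, partition in enumerate(scale_invs_partitions):
        pad = max_flat_numels - numels[pi]
        # Fresh tensor instead of ref_scale.new_empty: works even if partition 0 is empty.
        partition.append(torch.empty((pad, ), dtype=torch.float32))
    return _flatten_dense_tensors(list(chain(*scale_invs_partitions)))

# Partition 0 owns nothing; partition 1 has two per-tensor scale_invs.
flat = pad_and_flatten_scale_invs([[], [torch.ones(1), torch.ones(1)]])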
