
Commit 3613b25

adaptor func _allgather_params
Signed-off-by: aeeeeeep <aeeeeeep@proton.me>
1 parent 77a51f7 commit 3613b25
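
In short, this commit adds an allgather_single_param branch to _allgather_params: instead of packing every parameter in param_list into one shared flat buffer and issuing a single collective, each parameter is gathered into its own flat buffer (via the narrow_buffer adaptor) and then reshaped back to ds_shape. Below is a minimal, self-contained sketch of that per-parameter path; it is not DeepSpeed code, it uses plain torch.Tensor.narrow instead of the adaptor, it assumes an already-initialized process group, and the function name and arguments are illustrative only.

# Simplified sketch (illustrative, not part of this commit): gather one sharded parameter.
import torch
import torch.distributed as dist

def allgather_one_param(shard: torch.Tensor, full_shape: torch.Size,
                        num_partitions: int, rank: int, group=None) -> torch.Tensor:
    """Rebuild a full parameter from equal-sized 1-D shards, one collective per parameter."""
    partition_size = shard.numel()
    flat = torch.empty(partition_size * num_partitions, dtype=shard.dtype, device=shard.device)
    flat.requires_grad = False
    # Each rank writes its own shard into its slice of the flat buffer ...
    local_slice = flat.narrow(0, rank * partition_size, partition_size)
    local_slice.copy_(shard)
    # ... and a single all-gather fills in every other rank's slice.
    dist.all_gather_into_tensor(flat, local_slice, group=group, async_op=False)
    # Drop any padding and restore the original parameter shape.
    return flat.narrow(0, 0, full_shape.numel()).view(full_shape)

Compared with the existing batched path (kept under the new else: branch), this trades one large collective for one collective per parameter, which the allgather_single_param switch presumably exists to toggle.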

File tree

1 file changed: 103 additions, 54 deletions


deepspeed/runtime/zero/partition_parameters.py

Lines changed: 103 additions & 54 deletions
@@ -1994,78 +1994,127 @@ def _allgather_params(self, param_list, hierarchy=0):
         if len(param_list) == 0:
             return
 
-        partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
+        if self.allgather_single_param:
+            for param in param_list:
+                partition_size = param.ds_tensor.ds_numel
+                tensor_size = partition_size * self.num_partitions
 
-        tensor_size = partition_size * self.num_partitions
-        flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
-        partitions = []
-        for i in range(self.num_partitions):
-            start = partition_size * i
+                flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device)
 
-            partitions.append(flat_tensor.narrow(0, start, partition_size))
+                flat_tensor.requires_grad = False
 
-            if i == self.get_partition_rank():
-                offset = 0
-                for param in param_list:
-                    param_numel = param.ds_tensor.ds_numel
+                partitions = []
+                for i in range(self.num_partitions):
+                    start = partition_size * i
+                    partitioned_tensor = narrow_buffer(flat_tensor, start, partition_size, sbp_id=param.ds_sbp)
+                    partitions.append(partitioned_tensor)
+
+                    if i == self.get_partition_rank():
+                        partitioned_tensor.copy_(param.ds_tensor.data)
+
+                if hasattr(param, 'ds_quant_scale'):
+                    scale_size = param.ds_tensor.ds_quant_scale.numel()
+                    scale_tensor_size = scale_size * self.num_partitions
+                    flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                    dtype=param.ds_tensor.ds_quant_scale.dtype,
+                                                    device=self.local_device)
+                    flat_scale_tensor.requires_grad = False
+
+                    scale_partitions = []
+                    for i in range(self.num_partitions):
+                        start = scale_size * i
+                        scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_size))
+                        if i == self.get_partition_rank():
+                            scale_partitions[i].copy_(param.ds_tensor.ds_quant_scale.data)
+
+                dist.all_gather_into_tensor(flat_tensor,
+                                            partitions[self.get_partition_rank()],
+                                            group=self.get_partition_dp_group(param),
+                                            async_op=False)
+
+                if hasattr(param, 'ds_quant_scale'):
+                    dist.all_gather(flat_scale_tensor,
+                                    param.ds_tensor.ds_quant_scale,
+                                    group=self.get_partition_dp_group(param),
+                                    async_op=False)
 
-                    partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
+                param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
 
-                    offset += param_numel
+                if hasattr(param, 'ds_quant_scale'):
+                    param.data = self.quantizer_module.dequantize(param.data, flat_scale_tensor)
+        else:
+            partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
 
-        if hasattr(param_list[0], 'ds_quant_scale'):
-            scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
-            scale_tensor_size = scale_size * self.world_size
-            flat_scale_tensor = torch.empty(scale_tensor_size,
-                                            dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
-                                            device=self.local_device)
-            scale_partitions = []
-            for i in range(self.world_size):
-                start = scale_tensor_size * i
-                scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
-                if i == self.rank:
+            tensor_size = partition_size * self.num_partitions
+            flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
+            partitions = []
+            for i in range(self.num_partitions):
+                start = partition_size * i
+
+                partitions.append(flat_tensor.narrow(0, start, partition_size))
+
+                if i == self.get_partition_rank():
                     offset = 0
                     for param in param_list:
-                        param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+                        param_numel = param.ds_tensor.ds_numel
 
-                        scale_partitions[i].narrow(0, offset,
-                                                   param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+                        partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
 
-                        offset += param_scale_numel
+                        offset += param_numel
 
-        dist.all_gather_into_tensor(flat_tensor,
-                                    partitions[self.get_partition_rank()],
-                                    group=self.get_partition_dp_group(param),
-                                    async_op=False)
-        if hasattr(param_list[0], 'ds_quant_scale'):
-            dist.all_gather(flat_scale_tensor,
-                            param_list[0].ds_quant_scale,
-                            group=self.get_partition_dp_group(param),
-                            async_op=False)
-        param_offset = 0
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
+                scale_tensor_size = scale_size * self.world_size
+                flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
+                                                device=self.local_device)
+                scale_partitions = []
+                for i in range(self.world_size):
+                    start = scale_tensor_size * i
+                    scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
+                    if i == self.rank:
+                        offset = 0
+                        for param in param_list:
+                            param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+
+                            scale_partitions[i].narrow(0, offset,
+                                                       param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+
+                            offset += param_scale_numel
+
+            dist.all_gather_into_tensor(flat_tensor,
+                                        partitions[self.get_partition_rank()],
+                                        group=self.get_partition_dp_group(param),
+                                        async_op=False)
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                dist.all_gather(flat_scale_tensor,
+                                param_list[0].ds_quant_scale,
+                                group=self.get_partition_dp_group(param),
+                                async_op=False)
+            param_offset = 0
 
-        for param in param_list:
-            param_partition_size = param.ds_tensor.ds_numel
-            param_size = param.ds_numel
-            replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)
+            for param in param_list:
+                param_partition_size = param.ds_tensor.ds_numel
+                param_size = param.ds_numel
+                replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)
 
-            for i in range(self.num_partitions):
+                for i in range(self.num_partitions):
 
-                start = i * partition_size
+                    start = i * partition_size
 
-                param_start = i * param_partition_size
+                    param_start = i * param_partition_size
 
-                if param_start < param_size:
-                    numel_to_copy = min(param_size - param_start, param_partition_size)
+                    if param_start < param_size:
+                        numel_to_copy = min(param_size - param_start, param_partition_size)
 
-                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
+                        part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
 
-                    replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
-                    #param_offset += param.data.numel()
-                param_offset += param.ds_tensor.ds_numel
-            if hasattr(param_list[0], 'ds_quant_scale'):
-                replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
-            param.data = replicated_tensor.data
+                        replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
+                        #param_offset += param.data.numel()
+                    param_offset += param.ds_tensor.ds_numel
+                if hasattr(param_list[0], 'ds_quant_scale'):
+                    replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
+                param.data = replicated_tensor.data
 
         return None
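The new branch also calls a narrow_buffer(flat_tensor, start, partition_size, sbp_id=param.ds_sbp) helper and reads a param.ds_sbp attribute, neither of which is defined in this diff. Assuming it behaves like a backend-aware wrapper around torch.Tensor.narrow, a minimal stand-in (purely illustrative, not the fork's actual implementation) could look like this:

import torch

# Hypothetical stand-in for the narrow_buffer adaptor referenced in the diff.
# The real helper presumably uses sbp_id for backend/placement handling;
# this version only slices the flat buffer along dim 0.
def narrow_buffer(flat_tensor: torch.Tensor, start: int, length: int, sbp_id=None) -> torch.Tensor:
    # sbp_id is accepted for signature compatibility but ignored here.
    return flat_tensor.narrow(0, start, length)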