@@ -1994,78 +1994,127 @@ def _allgather_params(self, param_list, hierarchy=0):
         if len(param_list) == 0:
             return

-        partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
+        if self.allgather_single_param:
+            for param in param_list:
+                partition_size = param.ds_tensor.ds_numel
+                tensor_size = partition_size * self.num_partitions

-        tensor_size = partition_size * self.num_partitions
-        flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
-        partitions = []
-        for i in range(self.num_partitions):
-            start = partition_size * i
+                flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device)

-            partitions.append(flat_tensor.narrow(0, start, partition_size))
+                flat_tensor.requires_grad = False

-            if i == self.get_partition_rank():
-                offset = 0
-                for param in param_list:
-                    param_numel = param.ds_tensor.ds_numel
+                partitions = []
+                for i in range(self.num_partitions):
+                    start = partition_size * i
+                    partitioned_tensor = narrow_buffer(flat_tensor, start, partition_size, sbp_id=param.ds_sbp)
+                    partitions.append(partitioned_tensor)
+
+                    if i == self.get_partition_rank():
+                        partitioned_tensor.copy_(param.ds_tensor.data)
+
+                if hasattr(param, 'ds_quant_scale'):
+                    scale_size = param.ds_tensor.ds_quant_scale.numel()
+                    scale_tensor_size = scale_size * self.num_partitions
+                    flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                    dtype=param.ds_tensor.ds_quant_scale.dtype,
+                                                    device=self.local_device)
+                    flat_scale_tensor.requires_grad = False
+
+                    scale_partitions = []
+                    for i in range(self.num_partitions):
+                        start = scale_size * i
+                        scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_size))
+                        if i == self.get_partition_rank():
+                            scale_partitions[i].copy_(param.ds_tensor.ds_quant_scale.data)
+
+                dist.all_gather_into_tensor(flat_tensor,
+                                            partitions[self.get_partition_rank()],
+                                            group=self.get_partition_dp_group(param),
+                                            async_op=False)
+
+                if hasattr(param, 'ds_quant_scale'):
+                    dist.all_gather(flat_scale_tensor,
+                                    param.ds_tensor.ds_quant_scale,
+                                    group=self.get_partition_dp_group(param),
+                                    async_op=False)

-                    partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
+                param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data

-                    offset += param_numel
+                if hasattr(param, 'ds_quant_scale'):
+                    param.data = self.quantizer_module.dequantize(param.data, flat_scale_tensor)
+        else:
+            partition_size = sum([param.ds_tensor.ds_numel for param in param_list])

-        if hasattr(param_list[0], 'ds_quant_scale'):
-            scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
-            scale_tensor_size = scale_size * self.world_size
-            flat_scale_tensor = torch.empty(scale_tensor_size,
-                                            dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
-                                            device=self.local_device)
-            scale_partitions = []
-            for i in range(self.world_size):
-                start = scale_tensor_size * i
-                scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
-                if i == self.rank:
+            tensor_size = partition_size * self.num_partitions
+            flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
+            partitions = []
+            for i in range(self.num_partitions):
+                start = partition_size * i
+
+                partitions.append(flat_tensor.narrow(0, start, partition_size))
+
+                if i == self.get_partition_rank():
                     offset = 0
                     for param in param_list:
-                        param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+                        param_numel = param.ds_tensor.ds_numel

-                        scale_partitions[i].narrow(0, offset,
-                                                   param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+                        partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)

-                        offset += param_scale_numel
+                        offset += param_numel

-        dist.all_gather_into_tensor(flat_tensor,
-                                    partitions[self.get_partition_rank()],
-                                    group=self.get_partition_dp_group(param),
-                                    async_op=False)
-        if hasattr(param_list[0], 'ds_quant_scale'):
-            dist.all_gather(flat_scale_tensor,
-                            param_list[0].ds_quant_scale,
-                            group=self.get_partition_dp_group(param),
-                            async_op=False)
-        param_offset = 0
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
+                scale_tensor_size = scale_size * self.world_size
+                flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
+                                                device=self.local_device)
+                scale_partitions = []
+                for i in range(self.world_size):
+                    start = scale_tensor_size * i
+                    scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
+                    if i == self.rank:
+                        offset = 0
+                        for param in param_list:
+                            param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+
+                            scale_partitions[i].narrow(0, offset,
+                                                       param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+
+                            offset += param_scale_numel
+
+            dist.all_gather_into_tensor(flat_tensor,
+                                        partitions[self.get_partition_rank()],
+                                        group=self.get_partition_dp_group(param),
+                                        async_op=False)
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                dist.all_gather(flat_scale_tensor,
+                                param_list[0].ds_quant_scale,
+                                group=self.get_partition_dp_group(param),
+                                async_op=False)
+            param_offset = 0

-        for param in param_list:
-            param_partition_size = param.ds_tensor.ds_numel
-            param_size = param.ds_numel
-            replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)
+            for param in param_list:
+                param_partition_size = param.ds_tensor.ds_numel
+                param_size = param.ds_numel
+                replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)

-            for i in range(self.num_partitions):
+                for i in range(self.num_partitions):

-                start = i * partition_size
+                    start = i * partition_size

-                param_start = i * param_partition_size
+                    param_start = i * param_partition_size

-                if param_start < param_size:
-                    numel_to_copy = min(param_size - param_start, param_partition_size)
+                    if param_start < param_size:
+                        numel_to_copy = min(param_size - param_start, param_partition_size)

-                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
+                        part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)

-                    replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
-            #param_offset += param.data.numel()
-            param_offset += param.ds_tensor.ds_numel
-            if hasattr(param_list[0], 'ds_quant_scale'):
-                replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
-            param.data = replicated_tensor.data
+                        replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
+                #param_offset += param.data.numel()
+                param_offset += param.ds_tensor.ds_numel
+                if hasattr(param_list[0], 'ds_quant_scale'):
+                    replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
+                param.data = replicated_tensor.data

         return None

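For context on the change: when `allgather_single_param` is set, the rewritten method issues one collective per parameter instead of packing the whole `param_list` into a single flat buffer. The sketch below is a minimal reconstruction of that per-parameter pattern in plain `torch.distributed`; the function name `allgather_one_param`, its arguments, and the assumption that every rank holds an equally padded shard are illustrative, not the PR's API.

```python
import torch
import torch.distributed as dist


def allgather_one_param(local_shard: torch.Tensor, full_numel: int, full_shape, group=None):
    """Rebuild one parameter from equally sized shards held by each rank (illustrative sketch)."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    partition_size = local_shard.numel()  # assumed identical (padded) on every rank

    # Flat output buffer that will hold all partitions back to back.
    flat = torch.empty(partition_size * world_size,
                       dtype=local_shard.dtype,
                       device=local_shard.device,
                       requires_grad=False)

    # Stage this rank's shard in its own slice, then gather in place,
    # mirroring how the diff feeds partitions[rank] to the collective.
    my_slice = flat.narrow(0, rank * partition_size, partition_size)
    my_slice.copy_(local_shard)
    dist.all_gather_into_tensor(flat, my_slice, group=group, async_op=False)

    # Drop the trailing padding and restore the original parameter shape.
    return flat.narrow(0, 0, full_numel).view(full_shape)
```

The trade-off is more, smaller collectives per call, in exchange for skipping the combined flat buffer and the per-parameter offset bookkeeping that the batched `else` path still performs.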