@@ -342,13 +342,11 @@ def set_lm_head(module):
             module.lm_head, "weight") and module.lm_head.weight.is_meta:
         module.lm_head.weight = embedding_weight
     # enable tensor parallel for the last linear
-    if hasattr(module, "lm_head") and hasattr(module.lm_head,
-               "weight") and not module.lm_head.weight.is_meta and isinstance(
-                   module.lm_head, torch.nn.Linear):
+    if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance(
+            module.lm_head, torch.nn.Linear):
         module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
-    elif hasattr(module, "embed_out") and hasattr(module.embed_out,
-                 "weight") and not module.embed_out.weight.is_meta and isinstance(
-                     module.embed_out, torch.nn.Linear):
+    elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance(
+            module.embed_out, torch.nn.Linear):
         module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
     elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
         module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
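The hunk above drops the "not ...weight.is_meta" guard, so the output head is picked up for tensor-parallel sharding even while its weight is still an unmaterialized meta tensor. Below is a minimal runnable sketch of the resulting dispatch (assumes PyTorch >= 2.0 for the device context manager; shard_linear and set_output_head_tp are hypothetical stand-ins for DeepSpeed's replace_wo_policy and set_lm_head, not the library's API):

import torch

def shard_linear(module, attr):
    # Hypothetical stand-in for replace_wo_policy: only reports what would be
    # sharded; the real helper splits the weight across tensor-parallel ranks.
    layer = getattr(module, attr)
    print(f"would shard {attr}; weight on meta device: {layer.weight.is_meta}")
    return module

def set_output_head_tp(module):
    # Same dispatch order as set_lm_head after this change: lm_head (most HF
    # causal LMs), then embed_out (GPT-NeoX style), then a nested
    # language_model.lm_head. With the is_meta guard gone, heads whose weights
    # are still unmaterialized get sharded as well.
    if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance(
            module.lm_head, torch.nn.Linear):
        module = shard_linear(module, "lm_head")
    elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance(
            module.embed_out, torch.nn.Linear):
        module = shard_linear(module, "embed_out")
    elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
        module = shard_linear(module.language_model, "lm_head")
    return module

class ToyLM(torch.nn.Module):
    # Toy module whose lm_head weight lives on the meta device, as with
    # deferred checkpoint loading.
    def __init__(self):
        super().__init__()
        with torch.device("meta"):
            self.lm_head = torch.nn.Linear(16, 32, bias=False)

set_output_head_tp(ToyLM())  # prints: would shard lm_head; weight on meta device: True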
@@ -389,7 +387,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                                              checkpoint=checkpoint_file)
             pbar.update(1)
             gc.collect()
-        replaced_module = set_lm_head(replaced_module)
         # conv2d tp module replace
         # Now is for yuan model. Add model list and conv policy to decide whether to replace conv.
         if 'Yuan' in str(replaced_module):
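This hunk deletes the set_lm_head call from inside the checkpoint-loading branch; together with the next hunk, the call moves below both branches. The practical effect, sketched with hypothetical stubs below, is that the no-checkpoint else path now gets lm_head tensor parallelism too, rather than only the shard-loading path:

def replace_module(checkpoint=None):
    # Stub standing in for DeepSpeed's replace_module.
    return {"checkpoint": checkpoint, "lm_head_tp": False}

def set_lm_head(module):
    # Stub: the real helper shards lm_head/embed_out via replace_wo_policy.
    module["lm_head_tp"] = True
    return module

def load_and_shard(checkpoint_files, replace_with_kernel_inject):
    if checkpoint_files and not replace_with_kernel_inject:
        # AutoTP shard-by-shard loading; set_lm_head no longer runs in here.
        for checkpoint_file in checkpoint_files:
            replaced_module = replace_module(checkpoint=checkpoint_file)
    else:
        replaced_module = replace_module()
    # Moved here (see the next hunk): runs once on every non-kernel-inject
    # path, including the no-checkpoint else branch that previously skipped it.
    if not replace_with_kernel_inject:
        replaced_module = set_lm_head(replaced_module)
    return replaced_module

print(load_and_shard(None, False))                       # lm_head_tp: True
print(load_and_shard(["shard0.pt", "shard1.pt"], False))  # lm_head_tp: True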
@@ -399,6 +396,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                                          orig_class=orig_layer_impl,
                                          replace_fn=replace_fn,
                                          _replace_policy=config.injection_policy_tuple)
+    # AutoTP sets lm_head tensor parallel by default
+    if not config.replace_with_kernel_inject:
+        replaced_module = set_lm_head(replaced_module)
 
     quantizer = GroupQuantizer(q_int8=quantize)
     world_size = dist.get_world_size() if dist.is_initialized() else 1
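For context, the new guarded call runs on DeepSpeed's AutoTP path, i.e. init_inference with kernel injection disabled. A usage sketch under that assumption (model name and tp_size are placeholders; launch with something like deepspeed --num_gpus 2 script.py):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
model = deepspeed.init_inference(
    model,
    tensor_parallel={"tp_size": 2},    # shard weights across 2 ranks
    dtype=torch.float16,
    replace_with_kernel_inject=False,  # AutoTP; triggers the set_lm_head call above
)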