Skip to content

Commit e6efaf9

Browse files
authored
Merge branch 'master' into linear
2 parents 5ee9a87 + fa8db5c commit e6efaf9

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

.github/workflows/nv-ds-chat.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
       - name: Install deepspeed
         run: |
-          pip install transformers==4.45.2
+          pip install transformers
           pip install .[dev]
           ds_report

SECURITY.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,7 @@ We prefer all communications to be in English.
 Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).

 <!-- END MICROSOFT SECURITY.MD BLOCK -->
+
+---
+
+Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models.

deepspeed/module_inject/layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def __init__(self, weight_shape=None, weight=None, bias=None):
         self.offset = 2
         super().__init__(weight_shape, weight=weight)

-    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0, position_ids: int = 0):
         """`input_ids_shape` is expected to be [bsz x seqlen]."""
         attention_mask = attention_mask.long()

deepspeed/module_inject/replace_module.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -342,13 +342,11 @@ def set_lm_head(module):
             module.lm_head, "weight") and module.lm_head.weight.is_meta:
         module.lm_head.weight = embedding_weight
     # enable tensor parallel for the last linear
-    if hasattr(module, "lm_head") and hasattr(module.lm_head,
-                                              "weight") and not module.lm_head.weight.is_meta and isinstance(
-                                                  module.lm_head, torch.nn.Linear):
+    if hasattr(module, "lm_head") and hasattr(module.lm_head, "weight") and isinstance(
+            module.lm_head, torch.nn.Linear):
         module = replace_wo_policy(module, ("lm_head", ), 0, "lm_head")
-    elif hasattr(module, "embed_out") and hasattr(module.embed_out,
-                                                  "weight") and not module.embed_out.weight.is_meta and isinstance(
-                                                      module.embed_out, torch.nn.Linear):
+    elif hasattr(module, "embed_out") and hasattr(module.embed_out, "weight") and isinstance(
+            module.embed_out, torch.nn.Linear):
         module = replace_wo_policy(module, ("embed_out", ), 0, "embed_out")
     elif hasattr(module, "language_model") and hasattr(module.language_model, "lm_head"):
         module = replace_wo_policy(module.language_model, ("lm_head", ), 0, "lm_head")
@@ -389,7 +387,6 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                                           checkpoint=checkpoint_file)
                 pbar.update(1)
             gc.collect()
-        replaced_module = set_lm_head(replaced_module)
         # conv2d tp module replace
         # Now is for yuan model. Add model list and conv policy to decide whether to replace conv.
         if 'Yuan' in str(replaced_module):
@@ -399,6 +396,9 @@ def conv2d_parallel_shard_weights(model, rank, world_size):
                                        orig_class=orig_layer_impl,
                                        replace_fn=replace_fn,
                                        _replace_policy=config.injection_policy_tuple)
+    # AutoTP default set lm_head tp
+    if not config.replace_with_kernel_inject:
+        replaced_module = set_lm_head(replaced_module)

     quantizer = GroupQuantizer(q_int8=quantize)
     world_size = dist.get_world_size() if dist.is_initialized() else 1

0 commit comments

Comments (0)