Skip to content

Commit

Permalink
fix for "two devices" issue due to RoPE changes (#630)
Browse files Browse the repository at this point in the history
  • Loading branch information
davedgd authored Nov 14, 2024
1 parent 7954766 commit 12c91b7
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions awq/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get_act_for_scaling(module: OldCohereDecoderLayer):

@staticmethod
def move_embed(model: OldCohereForCausalLM, device: str):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions awq/models/gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_act_for_scaling(module: GPTNeoXLayer):

@staticmethod
def move_embed(model: GPTNeoXForCausalLM, device: str):
model.gpt_neox.rotary_emb = model.gpt_neox.rotary_emb.to(device)
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions awq/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def get_act_for_scaling(module: OldLlamaDecoderLayer):

@staticmethod
def move_embed(model: OldLlamaForCausalLM, device: str):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions awq/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def get_act_for_scaling(module: OldQwen2DecoderLayer):

@staticmethod
def move_embed(model: OldQwen2ForCausalLM, device: str):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions awq/models/stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def get_act_for_scaling(module: OldStableLmForCausalLM):

@staticmethod
def move_embed(model: OldStableLmForCausalLM, device: str):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions awq/models/starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def get_act_for_scaling(module: OldStarcoder2DecoderLayer):

@staticmethod
def move_embed(model: OldStarcoder2ForCausalLM, device):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down

0 comments on commit 12c91b7

Please sign in to comment.