Skip to content

Commit

Permalink
fix "Expected all tensors to be on the same device" (#664)
Browse files Browse the repository at this point in the history
  • Loading branch information
casper-hansen authored Dec 2, 2024
1 parent 0d1906f commit 9f13358
Show file tree
Hide file tree
Showing 6 changed files with 0 additions and 6 deletions.
1 change: 0 additions & 1 deletion awq/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def get_act_for_scaling(module: OldCohereDecoderLayer):

@staticmethod
def move_embed(model: OldCohereForCausalLM, device: str):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion awq/models/gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def get_act_for_scaling(module: GPTNeoXLayer):

@staticmethod
def move_embed(model: GPTNeoXForCausalLM, device: str):
model.gpt_neox.rotary_emb = model.gpt_neox.rotary_emb.to(device)
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion awq/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def get_act_for_scaling(module: OldLlamaDecoderLayer):

@staticmethod
def move_embed(model: OldLlamaForCausalLM, device: str):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion awq/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def get_act_for_scaling(module: OldQwen2DecoderLayer):

@staticmethod
def move_embed(model: OldQwen2ForCausalLM, device: str):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion awq/models/stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def get_act_for_scaling(module: OldStableLmForCausalLM):

@staticmethod
def move_embed(model: OldStableLmForCausalLM, device: str):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down
1 change: 0 additions & 1 deletion awq/models/starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def get_act_for_scaling(module: OldStarcoder2DecoderLayer):

@staticmethod
def move_embed(model: OldStarcoder2ForCausalLM, device):
model.model.rotary_emb = model.model.rotary_emb.to(device)
model.model.embed_tokens = model.model.embed_tokens.to(device)

@staticmethod
Expand Down

0 comments on commit 9f13358

Please sign in to comment.