diff --git a/awq/models/cohere.py b/awq/models/cohere.py index 9f669535..5db14561 100644 --- a/awq/models/cohere.py +++ b/awq/models/cohere.py @@ -29,6 +29,7 @@ def get_act_for_scaling(module: OldCohereDecoderLayer): @staticmethod def move_embed(model: OldCohereForCausalLM, device: str): + model.model.rotary_emb = model.model.rotary_emb.to(device) model.model.embed_tokens = model.model.embed_tokens.to(device) @staticmethod diff --git a/awq/models/gpt_neox.py b/awq/models/gpt_neox.py index 849dedb8..a0895d89 100644 --- a/awq/models/gpt_neox.py +++ b/awq/models/gpt_neox.py @@ -24,6 +24,7 @@ def get_act_for_scaling(module: GPTNeoXLayer): @staticmethod def move_embed(model: GPTNeoXForCausalLM, device: str): + model.gpt_neox.rotary_emb = model.gpt_neox.rotary_emb.to(device) model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device) @staticmethod diff --git a/awq/models/llama.py b/awq/models/llama.py index be6e8ecc..2664d16c 100644 --- a/awq/models/llama.py +++ b/awq/models/llama.py @@ -30,6 +30,7 @@ def get_act_for_scaling(module: OldLlamaDecoderLayer): @staticmethod def move_embed(model: OldLlamaForCausalLM, device: str): + model.model.rotary_emb = model.model.rotary_emb.to(device) model.model.embed_tokens = model.model.embed_tokens.to(device) @staticmethod diff --git a/awq/models/qwen2.py b/awq/models/qwen2.py index 1d7367b5..a4499480 100644 --- a/awq/models/qwen2.py +++ b/awq/models/qwen2.py @@ -30,6 +30,7 @@ def get_act_for_scaling(module: OldQwen2DecoderLayer): @staticmethod def move_embed(model: OldQwen2ForCausalLM, device: str): + model.model.rotary_emb = model.model.rotary_emb.to(device) model.model.embed_tokens = model.model.embed_tokens.to(device) @staticmethod diff --git a/awq/models/stablelm.py b/awq/models/stablelm.py index b4ad8bb8..38e15d2d 100644 --- a/awq/models/stablelm.py +++ b/awq/models/stablelm.py @@ -30,6 +30,7 @@ def get_act_for_scaling(module: OldStableLmForCausalLM): @staticmethod def move_embed(model: OldStableLmForCausalLM, device: str): + model.model.rotary_emb = model.model.rotary_emb.to(device) model.model.embed_tokens = model.model.embed_tokens.to(device) @staticmethod diff --git a/awq/models/starcoder2.py b/awq/models/starcoder2.py index ab6716c5..6949d5dd 100644 --- a/awq/models/starcoder2.py +++ b/awq/models/starcoder2.py @@ -36,6 +36,7 @@ def get_act_for_scaling(module: OldStarcoder2DecoderLayer): @staticmethod def move_embed(model: OldStarcoder2ForCausalLM, device): + model.model.rotary_emb = model.model.rotary_emb.to(device) model.model.embed_tokens = model.model.embed_tokens.to(device) @staticmethod