From f1abb8ef8e261db78eb6c603f691801797fbb293 Mon Sep 17 00:00:00 2001
From: Casper
Date: Fri, 13 Dec 2024 21:17:09 +0000
Subject: [PATCH] improve type hinting and fix use_cache (#680)

---
 awq/models/auto.py  |  7 ++-----
 awq/models/base.py  | 16 +++++++++++++++-
 awq/models/llava.py |  3 ++-
 3 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/awq/models/auto.py b/awq/models/auto.py
index 9c9d6f22..f711fef5 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -74,11 +74,6 @@ def from_pretrained(
             model_path, trust_remote_code, **model_init_kwargs
         )
 
-        if model_init_kwargs.get("low_cpu_mem_usage") is None:
-            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-        if model_init_kwargs.get("use_cache") is None:
-            model_init_kwargs["use_cache"] = use_cache
-
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
             model_path,
             model_type,
@@ -86,6 +81,8 @@ def from_pretrained(
             safetensors=safetensors,
             device_map=device_map,
             download_kwargs=download_kwargs,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            use_cache=use_cache,
             **model_init_kwargs,
         )
 
diff --git a/awq/models/base.py b/awq/models/base.py
index 16a63b67..201bdeed 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -36,6 +36,7 @@
     PretrainedConfig,
     AutoProcessor,
     BaseImageProcessor,
+    ProcessorMixin,
     PreTrainedTokenizer,
 )
 from accelerate.big_modeling import (
@@ -112,7 +113,7 @@ def __init__(
         self.search_result = None
         self.config: PretrainedConfig = config
         self.quant_config: AwqConfig = quant_config
-        self.processor: BaseImageProcessor = processor
+        self.processor: ProcessorMixin = processor
 
     def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
         """A utility function for moving the model to a device."""
@@ -342,6 +343,14 @@ def from_pretrained(
             Dict,
             Doc("Used for configure download model"),
         ] = None,
+        low_cpu_mem_usage: Annotated[
+            bool,
+            Doc("Use low_cpu_mem_usage when loading from transformers.")
+        ] = True,
+        use_cache: Annotated[
+            bool,
+            Doc("Use use_cache argument in transformers")
+        ] = False,
         **model_init_kwargs: Annotated[
             Dict,
             Doc(
@@ -367,6 +376,11 @@ def from_pretrained(
         if target_cls_name == "AutoModelForVision2Seq":
             processor = AutoProcessor.from_pretrained(model_weights_path)
 
+        if model_init_kwargs.get("low_cpu_mem_usage") is None:
+            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+        if model_init_kwargs.get("use_cache") is None and target_cls_name != "AutoModelForVision2Seq":
+            model_init_kwargs["use_cache"] = use_cache
+
         # If not quantized, must load with AutoModelForCausalLM
         model = target_cls.from_pretrained(
             model_weights_path,
diff --git a/awq/models/llava.py b/awq/models/llava.py
index 418a4fc4..126074c3 100644
--- a/awq/models/llava.py
+++ b/awq/models/llava.py
@@ -35,7 +35,8 @@ def move_embed(model: OldLlavaForConditionalGeneration, device: str):
         model.language_model.model.embed_tokens = model.get_input_embeddings().to(
             device
         )
-        model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)
+        if hasattr(model.language_model.model, "rotary_emb"):
+            model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
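
Usage sketch (not part of the patch): with this change, low_cpu_mem_usage and
use_cache are explicit from_pretrained parameters that the auto wrapper forwards
into the model class, where they are applied only if the caller did not already
set them via model_init_kwargs (use_cache is skipped for AutoModelForVision2Seq
targets). A minimal example, assuming the public entry point exported from
awq/models/auto.py is AutoAWQForCausalLM as in upstream AutoAWQ, and using a
hypothetical model path:

    from awq import AutoAWQForCausalLM

    # Forwarded to transformers' loading call unless already present in
    # model_init_kwargs; the values shown are the defaults added by this patch.
    model = AutoAWQForCausalLM.from_pretrained(
        "org/unquantized-model",  # hypothetical model path
        low_cpu_mem_usage=True,
        use_cache=False,
    )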