improve type hinting and fix use_cache (#680)
casper-hansen authored Dec 13, 2024
1 parent cbd6a75 commit f1abb8e
Showing 3 changed files with 19 additions and 7 deletions.
7 changes: 2 additions & 5 deletions awq/models/auto.py
@@ -74,18 +74,15 @@ def from_pretrained(
             model_path, trust_remote_code, **model_init_kwargs
         )
 
-        if model_init_kwargs.get("low_cpu_mem_usage") is None:
-            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
-        if model_init_kwargs.get("use_cache") is None:
-            model_init_kwargs["use_cache"] = use_cache
-
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
             model_path,
             model_type,
             trust_remote_code=trust_remote_code,
             safetensors=safetensors,
             device_map=device_map,
             download_kwargs=download_kwargs,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            use_cache=use_cache,
             **model_init_kwargs,
         )
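
Note on the auto.py change: low_cpu_mem_usage and use_cache were previously written into **model_init_kwargs here; they are now forwarded as explicit keyword arguments so the defaulting happens in BaseAWQForCausalLM.from_pretrained (see the base.py diff below). A minimal caller-side sketch; the checkpoint name is only a placeholder:

from awq import AutoAWQForCausalLM

# After this commit, both flags travel as explicit keyword arguments
# down to the underlying model class.
model = AutoAWQForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",  # placeholder checkpoint
    low_cpu_mem_usage=True,
    use_cache=False,
)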

16 changes: 15 additions & 1 deletion awq/models/base.py
@@ -36,6 +36,7 @@
     PretrainedConfig,
     AutoProcessor,
     BaseImageProcessor,
+    ProcessorMixin,
     PreTrainedTokenizer,
 )
 from accelerate.big_modeling import (
@@ -112,7 +113,7 @@ def __init__(
         self.search_result = None
         self.config: PretrainedConfig = config
         self.quant_config: AwqConfig = quant_config
-        self.processor: BaseImageProcessor = processor
+        self.processor: ProcessorMixin = processor
 
     def to(self, device: Annotated[str, Doc("The device to move your model to.")]):
         """A utility function for moving the model to a device."""
@@ -342,6 +343,14 @@ def from_pretrained(
             Dict,
             Doc("Used for configure download model"),
         ] = None,
+        low_cpu_mem_usage: Annotated[
+            bool,
+            Doc("Use low_cpu_mem_usage when loading from transformers.")
+        ] = True,
+        use_cache: Annotated[
+            bool,
+            Doc("Use use_cache argument in transformers")
+        ] = False,
         **model_init_kwargs: Annotated[
             Dict,
             Doc(
@@ -367,6 +376,11 @@
         if target_cls_name == "AutoModelForVision2Seq":
             processor = AutoProcessor.from_pretrained(model_weights_path)
 
+        if model_init_kwargs.get("low_cpu_mem_usage") is None:
+            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+        if model_init_kwargs.get("use_cache") is None and target_cls_name != "AutoModelForVision2Seq":
+            model_init_kwargs["use_cache"] = use_cache
+
         # If not quantized, must load with AutoModelForCausalLM
         model = target_cls.from_pretrained(
             model_weights_path,
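
Note on the defaulting logic: an explicit value passed through **model_init_kwargs always wins, the new parameters only fill in when the key is absent, and use_cache is additionally left untouched for AutoModelForVision2Seq targets (the guard in the diff). A distilled sketch of the pattern; fill_default is a hypothetical helper, not part of the codebase:

def fill_default(kwargs: dict, key: str, value) -> None:
    # Hypothetical helper: only fill the key when the caller
    # did not set it (or set it to None).
    if kwargs.get(key) is None:
        kwargs[key] = value

model_init_kwargs = {"use_cache": True}
fill_default(model_init_kwargs, "use_cache", False)         # caller's True is kept
fill_default(model_init_kwargs, "low_cpu_mem_usage", True)  # absent, so filled with True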
3 changes: 2 additions & 1 deletion awq/models/llava.py
@@ -35,7 +35,8 @@ def move_embed(model: OldLlavaForConditionalGeneration, device: str):
         model.language_model.model.embed_tokens = model.get_input_embeddings().to(
             device
         )
-        model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)
+        if hasattr(model.language_model.model, "rotary_emb"):
+            model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
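
Note on the llava.py change: the hasattr guard makes move_embed tolerant of transformers versions where the language model exposes no model-level rotary_emb module, presumably because older versions keep rotary embeddings inside each attention layer. The same defensive pattern, generalized; move_if_present is a hypothetical helper for illustration:

import torch.nn as nn

def move_if_present(module: nn.Module, attr: str, device: str) -> None:
    # Hypothetical helper: move a submodule only if this
    # transformers version actually defines the attribute.
    sub = getattr(module, attr, None)
    if sub is not None:
        setattr(module, attr, sub.to(device))

# Usage: move_if_present(model.language_model.model, "rotary_emb", device)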
