Add GPTQ support for Gemma (vllm-project#3200)
TechxGenus authored Mar 7, 2024
1 parent 4cb3b92 commit d3c04b6
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions vllm/model_executor/models/gemma.py
@@ -325,11 +325,17 @@ def load_weights(self,
                 if shard_name not in name:
                     continue
                 name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 # GemmaRMSNorm is different from Llama's in that it multiplies
                 # (1 + weight) to the output, instead of just weight.
                 if "norm.weight" in name:
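
Note on the change: GPTQ-quantized Gemma checkpoints can carry ".bias" tensors for projections that vLLM's bias-free Gemma modules do not declare, so looking up params_dict[name] for such a key would raise a KeyError. The added guard skips any bias weight with no matching parameter instead of failing.

The comment in the fall-through branch refers to Gemma's RMSNorm variant. Below is a minimal standalone sketch of that behavior for illustration; the class name and eps default are assumptions, not vLLM's actual implementation:

import torch
import torch.nn as nn

class GemmaStyleRMSNorm(nn.Module):
    """Sketch of Gemma-style RMSNorm: the output is scaled by (1 + weight),
    unlike Llama's RMSNorm, which scales by weight alone."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Zero-initialized so (1 + weight) starts as an identity scale.
        self.weight = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root mean square over the last dimension,
        # then apply the (1 + weight) scale.
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * (1.0 + self.weight)

Because the checkpoint stores weight rather than (1 + weight), a loader that reuses a plain RMSNorm kernel has to adjust the loaded tensor, which is presumably why the code matches names containing "norm.weight" here (the body of that branch is truncated in this diff view).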
