multi-gpu fix (#668)
casper-hansen authored Dec 3, 2024
1 parent 9f13358 commit f2171f3
Showing 24 changed files with 86 additions and 85 deletions.
4 changes: 1 addition & 3 deletions README.md
@@ -115,9 +115,7 @@ quant_path = 'mistral-instruct-v0.2-awq'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, use_cache=False
-)
+model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 # Quantize
1 change: 1 addition & 0 deletions awq/models/aquila.py
@@ -31,6 +31,7 @@ def get_act_for_scaling(module: OldAquilaDecoderLayer):
     @staticmethod
     def move_embed(model: OldAquilaForCausalLM, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(
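The same one-line rotary_emb move is repeated for every model class touched below. A minimal sketch of the pattern, with a getattr guard added here purely as an assumption for architectures that may not expose rotary_emb (the guard is not part of the actual diff):

import torch.nn as nn

def move_embed(model: nn.Module, device: str) -> None:
    # Everything the first decoder block needs during calibration must sit on
    # the same device as the calibration inputs: the token embeddings and,
    # since transformers 4.45, the shared rotary embedding on the top-level model.
    model.model.embed_tokens = model.model.embed_tokens.to(device)
    rotary_emb = getattr(model.model, "rotary_emb", None)
    if rotary_emb is not None:  # older checkpoints may not expose it
        model.model.rotary_emb = rotary_emb.to(device)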
9 changes: 8 additions & 1 deletion awq/models/auto.py
@@ -64,14 +64,21 @@ def from_pretrained(
         model_path,
         trust_remote_code=True,
         safetensors=True,
-        device_map="auto",
+        device_map=None,
         download_kwargs=None,
+        low_cpu_mem_usage=True,
+        use_cache=False,
         **model_init_kwargs,
     ) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(
             model_path, trust_remote_code, **model_init_kwargs
         )
 
+        if model_init_kwargs.get("low_cpu_mem_usage") is None:
+            model_init_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+        if model_init_kwargs.get("use_cache") is None:
+            model_init_kwargs["use_cache"] = use_cache
+
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
             model_path,
             model_type,
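With the new signature, callers no longer pass low_cpu_mem_usage or use_cache for quantization, and no device_map is applied at load time. A minimal sketch of both call styles, reusing the model path from the README example above:

from awq import AutoAWQForCausalLM

# New defaults: weights stay on CPU (device_map=None), and low_cpu_mem_usage /
# use_cache are filled in automatically unless the caller overrides them.
model = AutoAWQForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# Sharding across GPUs at load time is still possible, but now has to be
# requested explicitly:
sharded_model = AutoAWQForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2", device_map="auto"
)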
1 change: 1 addition & 0 deletions awq/models/baichuan.py
@@ -29,6 +29,7 @@ def get_act_for_scaling(module):
     @staticmethod
     def move_embed(model, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     # def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/cohere.py
@@ -30,6 +30,7 @@ def get_act_for_scaling(module: OldCohereDecoderLayer):
     @staticmethod
     def move_embed(model: OldCohereForCausalLM, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(
1 change: 1 addition & 0 deletions awq/models/exaone.py
@@ -35,6 +35,7 @@ def get_act_for_scaling(module: OldExaoneBlock):
     @staticmethod
     def move_embed(model: OldExaoneForCausalLM, device: str):
         model.transformer.wte = model.transformer.wte.to(device)
+        model.transformer.rotary = model.transformer.rotary.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldExaoneBlock, input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/falcon.py
@@ -33,6 +33,7 @@ def get_act_for_scaling(module: OldFalconDecoderLayer):
     @staticmethod
     def move_embed(model: FalconForCausalLM, device):
         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
+        model.transformer.rotary_emb = model.transformer.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(
1 change: 1 addition & 0 deletions awq/models/gpt_neox.py
@@ -25,6 +25,7 @@ def get_act_for_scaling(module: GPTNeoXLayer):
     @staticmethod
     def move_embed(model: GPTNeoXForCausalLM, device: str):
         model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
+        model.gpt_neox.rotary_emb = model.gpt_neox.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: GPTNeoXLayer, input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/llama.py
@@ -31,6 +31,7 @@ def get_act_for_scaling(module: OldLlamaDecoderLayer):
     @staticmethod
     def move_embed(model: OldLlamaForCausalLM, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/llava.py
@@ -35,6 +35,7 @@ def move_embed(model: OldLlavaForConditionalGeneration, device: str):
         model.language_model.model.embed_tokens = model.get_input_embeddings().to(
             device
         )
+        model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/llava_next.py
@@ -32,6 +32,7 @@ def move_embed(model: LlavaNextForConditionalGeneration, device: str):
         model.language_model.model.embed_tokens = model.get_input_embeddings().to(
             device
         )
+        model.language_model.model.rotary_emb = model.language_model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/qwen2.py
@@ -31,6 +31,7 @@ def get_act_for_scaling(module: OldQwen2DecoderLayer):
     @staticmethod
     def move_embed(model: OldQwen2ForCausalLM, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/qwen2vl.py
@@ -22,6 +22,7 @@ def get_act_for_scaling(module: "Qwen2VLForConditionalGeneration"):
     def move_embed(model: "Qwen2VLForConditionalGeneration", device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
         model.visual = model.visual.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: "Qwen2VLDecoderLayer", input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/stablelm.py
@@ -31,6 +31,7 @@ def get_act_for_scaling(module: OldStableLmForCausalLM):
     @staticmethod
     def move_embed(model: OldStableLmForCausalLM, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(
1 change: 1 addition & 0 deletions awq/models/starcoder2.py
@@ -37,6 +37,7 @@ def get_act_for_scaling(module: OldStarcoder2DecoderLayer):
     @staticmethod
     def move_embed(model: OldStarcoder2ForCausalLM, device):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module: OldStarcoder2DecoderLayer, input_feat, module_kwargs):
1 change: 1 addition & 0 deletions awq/models/yi.py
@@ -27,6 +27,7 @@ def get_act_for_scaling(module):
     @staticmethod
     def move_embed(model, device: str):
         model.model.embed_tokens = model.model.embed_tokens.to(device)
+        model.model.rotary_emb = model.model.rotary_emb.to(device)
 
     @staticmethod
     def get_layers_for_scaling(module, input_feat, module_kwargs):
54 changes: 28 additions & 26 deletions awq/modules/triton/gemm.py
@@ -280,17 +280,18 @@ def awq_dequantize_triton(
         triton.cdiv(X, META["BLOCK_SIZE_X"]),
         triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
     )
-    awq_dequantize_kernel[grid](
-        qweight,
-        scales,
-        zeros,
-        group_size,
-        result,
-        X,
-        Y,
-        BLOCK_SIZE_X=block_size_x,
-        BLOCK_SIZE_Y=block_size_y,
-    )
+    with torch.cuda.device(qweight.device.index):
+        awq_dequantize_kernel[grid](
+            qweight,
+            scales,
+            zeros,
+            group_size,
+            result,
+            X,
+            Y,
+            BLOCK_SIZE_X=block_size_x,
+            BLOCK_SIZE_Y=block_size_y,
+        )
 
     return result
 
@@ -332,20 +333,21 @@ def awq_gemm_triton(
 
     # A = input, B = qweight, C = result
     # A = M x K, B = K x N, C = M x N
-    awq_gemm_kernel[grid](
-        input,
-        qweight,
-        result,
-        qzeros,
-        scales,
-        M,
-        N,
-        K,
-        group_size,
-        BLOCK_SIZE_M=block_size_m,
-        BLOCK_SIZE_N=block_size_n,
-        BLOCK_SIZE_K=block_size_k,
-        SPLIT_K=split_k_iters,
-    )
+    with torch.cuda.device(qweight.device.index):
+        awq_gemm_kernel[grid](
+            input,
+            qweight,
+            result,
+            qzeros,
+            scales,
+            M,
+            N,
+            K,
+            group_size,
+            BLOCK_SIZE_M=block_size_m,
+            BLOCK_SIZE_N=block_size_n,
+            BLOCK_SIZE_K=block_size_k,
+            SPLIT_K=split_k_iters,
+        )
 
     return result
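The fix wraps each Triton launch in torch.cuda.device so the kernel runs on the GPU that actually holds qweight rather than on the current default device. A small standalone illustration of that context manager, using made-up tensors rather than the AWQ kernels themselves:

import torch

def launch_on_owning_device(qweight: torch.Tensor) -> torch.Tensor:
    # On a multi-GPU box the default device may be cuda:0 while the layer being
    # processed lives on cuda:1; switching the current device before launching
    # keeps the work and its outputs on the tensor's own GPU.
    result = torch.empty_like(qweight)
    with torch.cuda.device(qweight.device.index):
        result.copy_(qweight * 2)  # stand-in for the Triton kernel launch
    return result

if torch.cuda.device_count() > 1:
    w = torch.randn(4, 4, device="cuda:1")
    print(launch_on_owning_device(w).device)  # cuda:1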
13 changes: 13 additions & 0 deletions awq/quantize/quantizer.py
@@ -148,6 +148,19 @@ def quantize(self):
 
             self.inps = self.inps.to(common_device)
 
+            # We need to move the rotary embedding every time we move to a new module.
+            # Transformers 4.45.0 moved rotary embedding to model definition as of this PR:
+            # https://github.com/huggingface/transformers/pull/32617
+            self.awq_model.move_embed(self.model, common_device)
+
+            for k, v in self.module_kwargs.items():
+                # position embeddings found in tuple
+                if isinstance(v, tuple):
+                    self.module_kwargs[k] = tuple(
+                        item.to(common_device) if isinstance(item, (torch.Tensor, nn.Module))
+                        else item for item in v
+                    )
+
             # [STEP 1]: Get layer, extract linear modules, extract input features
             named_linears = get_named_linears(self.modules[i])
 
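Since transformers 4.45 the decoder layers receive position embeddings as a (cos, sin) tuple in their keyword arguments, so the tuple comprehension above moves each tensor to the device of the block being quantized. A tiny self-contained version of that handling, with shapes made up for illustration:

import torch
import torch.nn as nn

common_device = "cuda:0" if torch.cuda.is_available() else "cpu"

# e.g. {"position_embeddings": (cos, sin)} as passed to each decoder layer
module_kwargs = {"position_embeddings": (torch.randn(1, 16, 64), torch.randn(1, 16, 64))}

for k, v in module_kwargs.items():
    if isinstance(v, tuple):
        module_kwargs[k] = tuple(
            item.to(common_device) if isinstance(item, (torch.Tensor, nn.Module)) else item
            for item in v
        )

print(module_kwargs["position_embeddings"][0].device)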
29 changes: 8 additions & 21 deletions docs/examples.md
@@ -20,9 +20,7 @@ quant_path = 'mistral-instruct-v0.2-awq'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, use_cache=False
-)
+model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 # Quantize
@@ -50,9 +48,7 @@ quant_path = 'vicuna-7b-v1.5-awq'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
-)
+model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 # Define data loading methods
@@ -107,9 +103,7 @@ quant_path = 'qwen2-7b-awq'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
-)
+model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 def load_cosmopedia():
@@ -150,9 +144,7 @@ quant_path = 'deepseek-coder-v2-lite-instruct-awq'
 quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
-)
+model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 def load_openhermes_coding():
@@ -197,7 +189,7 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, device_map="cuda",
+    model_path, low_cpu_mem_usage=True,
 )
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
@@ -236,9 +228,7 @@ llama_cpp_path = '/workspace/llama.cpp'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 6, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, use_cache=False, device_map="cuda",
-)
+model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 # Quantize
@@ -293,7 +283,7 @@ quant_path = "qwen2-vl-7b-instruct"
 quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
 
 model = AutoAWQForCausalLM.from_pretrained(
-    model_path, use_cache=False, attn_implementation="flash_attention_2"
+    model_path, attn_implementation="flash_attention_2"
 )
 
 # We define our own quantizer by extending the AwqQuantizer.
@@ -505,9 +495,7 @@ quant_path = 'minicpm3-4b-awq'
 quant_config = { "zero_point": True, "q_group_size": 64, "w_bit": 4, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, use_cache=False, safetensors=False
-)
+model = AutoAWQForCausalLM.from_pretrained(model_path, safetensors=False)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 # Quantize
@@ -591,7 +579,6 @@ model = AutoModelForCausalLM.from_pretrained(
     model_id,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
-    device_map="cuda:0"
 )
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
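All of the documentation examples now reduce to the same two-step pattern: load on CPU without a device_map, then quantize. A condensed sketch of that flow, under the assumption that quantize and save_quantized keep their documented signatures:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-Instruct-v0.2"
quant_path = "mistral-instruct-v0.2-awq"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load on CPU; the quantizer moves one decoder block at a time onto the GPU(s).
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)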
16 changes: 1 addition & 15 deletions examples/cli.py
@@ -16,13 +16,7 @@ def main():
     parser.add_argument("--version", type=str, default="GEMM", help="Quantization version")
 
     # Model config arguments
-    parser.add_argument("--low_cpu_mem_usage", action="store_true", help="Use low CPU memory")
-    parser.add_argument("--no-low_cpu_mem_usage", action="store_false", dest="low_cpu_mem_usage", help="Don't use low CPU memory")
-    parser.add_argument("--use_cache", action="store_true", help="Use cache")
-    parser.add_argument("--no-use_cache", action="store_false", dest="use_cache", help="Don't use cache")
-    parser.add_argument("--device_map", type=str, default="auto", help="Device map for loading the pretrained model")
-
-    parser.set_defaults(zero_point=True, low_cpu_mem_usage=True, use_cache=None)
+    parser.add_argument("--device_map", type=str, default=None, help="Device map for loading the pretrained model")
 
     args = parser.parse_args()
 
@@ -33,18 +27,10 @@ def main():
         "version": args.version
     }
 
-    model_config = {
-        "low_cpu_mem_usage": args.low_cpu_mem_usage,
-    }
-
-    if args.use_cache is not None:
-        model_config["use_cache"] = args.use_cache
-
     print(f"Loading model from: {args.hf_model_path}")
     model = AutoAWQForCausalLM.from_pretrained(
         args.hf_model_path,
         device_map=args.device_map,
-        **model_config
     )
     tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path, trust_remote_code=True)
 
1 change: 0 additions & 1 deletion examples/generate.py
@@ -35,5 +35,4 @@
     do_sample=True,
     max_new_tokens=256,
     streamer=streamer,
-    eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
 )
8 changes: 3 additions & 5 deletions examples/quantize.py
@@ -1,14 +1,12 @@
 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer
 
-model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
-quant_path = 'mistral-instruct-v0.2-awq'
+model_path = 'Qwen/Qwen2.5-14B-Instruct'
+quant_path = 'Qwen2.5-14B-Instruct-awq'
 quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
-model = AutoAWQForCausalLM.from_pretrained(
-    model_path, low_cpu_mem_usage=True, use_cache=False
-)
+model = AutoAWQForCausalLM.from_pretrained(model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
 # Quantize