From 4f76ecb20f9444184f5a80714997ff119a466a49 Mon Sep 17 00:00:00 2001
From: s4rduk4r
Date: Wed, 27 Sep 2023 01:18:12 +0300
Subject: [PATCH 1/4] Offload to cpu

---
 awq/models/auto.py |  8 +++++---
 awq/models/base.py | 31 +++++++++++++++++++++++++------
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/awq/models/auto.py b/awq/models/auto.py
index 91a60b7e..f37607a0 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -40,11 +40,13 @@ def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
     @classmethod
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        trust_remote_code=True, fuse_layers=True,
-                       batch_size=1, safetensors=False) -> BaseAWQForCausalLM:
+                       batch_size=1, safetensors=False,
+                       max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
-            fuse_layers=fuse_layers, safetensors=safetensors
-        )
\ No newline at end of file
+            fuse_layers=fuse_layers, safetensors=safetensors,
+            max_memory=max_memory, offload_folder=offload_folder
+        )
diff --git a/awq/models/base.py b/awq/models/base.py
index 03b3eb1e..065e366f 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -133,7 +133,8 @@ def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = tor
     def from_quantized(self, model_path, model_type, model_filename='',
                        max_new_tokens=None, torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=False, is_quantized=True,
-                       fuse_layers=False, version='GEMM'):
+                       fuse_layers=False, version='GEMM',
+                       max_memory=None, offload_folder=None):
         # [STEP 1-2] Load weights path and configs
         model_weights_path, config, quant_config = self._load_config(
             self, model_path, model_filename, safetensors, version,
@@ -153,21 +154,39 @@ def from_quantized(self, model_path, model_type, model_filename='',
         device_map = infer_auto_device_map(
             model,
             no_split_module_classes=[self.layer_type],
+            max_memory=max_memory,
             dtype=torch_dtype
-        )
+        )
 
         # Load checkpoint
         load_checkpoint_in_model(
             model,
             checkpoint=model_weights_path,
-            device_map=device_map
+            device_map=device_map,
+            offload_folder=offload_folder,
+            dtype=torch_dtype
         )
 
         # Dispath to devices
-        model = simple_dispatch_model(model, device_map)
+        if max_memory is None:
+            # VRAM only
+            model = simple_dispatch_model(model, device_map)
+
+            if fuse_layers:
+                self.fuse_layers(model, quant_config)
+        else:
+            if fuse_layers:
+                self.fuse_layers(model, quant_config)
+
+            # Offloading dispatch
+            from accelerate import dispatch_model
+            model = dispatch_model(
+                model,
+                device_map=device_map,
+                # offload_buffers=offload_folder is not None,
+                offload_dir=offload_folder
+            )
 
-        if fuse_layers:
-            self.fuse_layers(model, quant_config)
 
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
 
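For orientation between the patches, here is a minimal usage sketch of the two parameters that PATCH 1 threads through from_quantized(). The repository id, memory budget, and offload directory below are hypothetical placeholders rather than values from the patch; the max_memory dictionary uses accelerate's per-device format.

# Hedged usage sketch -- the model id, memory limits, and "offload" dir are placeholders.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "some-org/some-awq-model"   # placeholder Hub id or local path

model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    fuse_layers=True,
    safetensors=True,
    # New in this patch: cap per-device memory so layers that do not fit on
    # GPU 0 are mapped to the CPU (or to disk once the CPU budget is spent)...
    max_memory={0: "6GiB", "cpu": "30GiB"},
    # ...and tell accelerate where to park any weights offloaded to disk.
    offload_folder="offload",
)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

With max_memory left at None, PATCH 1 keeps the original simple_dispatch_model() path; the follow-up patches below then collapse both cases into a single accelerate.dispatch_model() call.
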
From 91de3bfac5efd9771bd449b4739ac495f2c564ca Mon Sep 17 00:00:00 2001
From: s4rduk4r
Date: Wed, 27 Sep 2023 01:25:05 +0300
Subject: [PATCH 2/4] Clean

---
 awq/models/base.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/awq/models/base.py b/awq/models/base.py
index 065e366f..a0f0dfbf 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -156,7 +156,7 @@ def from_quantized(self, model_path, model_type, model_filename='',
             no_split_module_classes=[self.layer_type],
             max_memory=max_memory,
             dtype=torch_dtype
-        )
+        )
 
         # Load checkpoint
         load_checkpoint_in_model(
@@ -183,7 +183,6 @@ def from_quantized(self, model_path, model_type, model_filename='',
         model = dispatch_model(
             model,
             device_map=device_map,
-            # offload_buffers=offload_folder is not None,
             offload_dir=offload_folder
         )
 

From ffaaa2595b789e5516a9bcacb3f9f0dc9cb5a295 Mon Sep 17 00:00:00 2001
From: s4rduk4r
Date: Wed, 27 Sep 2023 20:49:09 +0300
Subject: [PATCH 3/4] Rely on accelerate.dispatch_model() only

---
 awq/models/base.py | 25 +++++++++----------------
 1 file changed, 9 insertions(+), 16 deletions(-)

diff --git a/awq/models/base.py b/awq/models/base.py
index a0f0dfbf..3e93b1a0 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -168,23 +168,16 @@ def from_quantized(self, model_path, model_type, model_filename='',
         )
 
         # Dispath to devices
-        if max_memory is None:
-            # VRAM only
-            model = simple_dispatch_model(model, device_map)
+        if fuse_layers:
+            self.fuse_layers(model, quant_config)
 
-            if fuse_layers:
-                self.fuse_layers(model, quant_config)
-        else:
-            if fuse_layers:
-                self.fuse_layers(model, quant_config)
-
-            # Offloading dispatch
-            from accelerate import dispatch_model
-            model = dispatch_model(
-                model,
-                device_map=device_map,
-                offload_dir=offload_folder
-            )
+        # Offloading dispatch
+        from accelerate import dispatch_model
+        model = dispatch_model(
+            model,
+            device_map=device_map,
+            offload_dir=offload_folder
+        )
 
 
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

From 841a231301b6d7698ebe0ce5a5d448cd906a642a Mon Sep 17 00:00:00 2001
From: s4rduk4r
Date: Wed, 27 Sep 2023 20:49:52 +0300
Subject: [PATCH 4/4] Fix apply_rotary_emb() to have both tensors on the same
 device

---
 awq/modules/fused/attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index 73bedd26..5bdb6c1d 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -36,7 +36,7 @@ def apply_rotary_emb(
     xk_ = torch.view_as_complex(
         xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
     )
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
     xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)
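To close, the sketch below walks the same accelerate offloading flow that from_quantized() now delegates to -- infer_auto_device_map(), load_checkpoint_in_model(), then dispatch_model() -- on a toy module instead of a quantized model. The ToyBlock class, checkpoint name, memory limits, offload directory, and the assumption that CUDA device 0 is available are all placeholders for illustration, not part of the patches.

import torch
import torch.nn as nn
from accelerate import dispatch_model, infer_auto_device_map, load_checkpoint_in_model

class ToyBlock(nn.Module):
    # Stand-in for a transformer layer, i.e. the no-split unit of the device map.
    def __init__(self, dim: int = 256):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim)
        self.down = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

model = nn.Sequential(*[ToyBlock() for _ in range(8)])
torch.save(model.state_dict(), "toy_checkpoint.pt")

# Tight GPU budget so only a few blocks land on device 0 and the rest are
# mapped to "cpu" -- the same effect max_memory has on the quantized model.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "4MiB", "cpu": "1GiB"},
    no_split_module_classes=["ToyBlock"],
    dtype=torch.float16,
)

# Place each weight on the device the map assigned it, casting to float16;
# anything mapped to "disk" would be written under offload_folder.
load_checkpoint_in_model(
    model,
    checkpoint="toy_checkpoint.pt",
    device_map=device_map,
    offload_folder="offload",
    dtype=torch.float16,
)

# dispatch_model() attaches the hooks that shuttle offloaded weights onto the
# execution device during forward() -- the reason the patches drop
# simple_dispatch_model(), whose "VRAM only" path assumes everything fits on GPU.
model = dispatch_model(model, device_map=device_map, offload_dir="offload")

out = model(torch.randn(1, 256, dtype=torch.float16, device="cuda:0"))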