@@ -374,22 +374,33 @@ def _weight_only_quantization(
     }
 
     low_cpu_mem_usage = True
+
     if use_xpu:
-        try:
-            # TODO: if low_cpu_mem_usage is True, gptj will have an accuracy issue on the CPU device.
-            model = model_class.from_pretrained(
-                model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs
-            )
-        except NotImplementedError:
-            logger.info(
-                "Failed to load models with `low_cpu_mem_usage=True`, will fall back to the traditional load method, resulting in higher memory consumption."
-            )
-            low_cpu_mem_usage = False
-            model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
+        if hasattr(quantization_config, "use_layer_wise") and quantization_config.use_layer_wise:
+            from neural_compressor.torch import load_empty_model
+
+            model = load_empty_model(model_id, **loading_kwargs)
+        else:
+            try:
+                # TODO: if low_cpu_mem_usage is True, gptj will have an accuracy issue on the CPU device.
+                model = model_class.from_pretrained(
+                    model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs
+                )
+            except NotImplementedError:
+                logger.info(
+                    "Failed to load models with `low_cpu_mem_usage=True`, will fall back to the traditional load method, resulting in higher memory consumption."
+                )
+                low_cpu_mem_usage = False
+                model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
 
         quantization_config.update(**{"device": "xpu"})
         quantization_config.post_init_xpu()
     else:
-        model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
+        if hasattr(quantization_config, "use_layer_wise") and quantization_config.use_layer_wise:
+            from neural_compressor.torch import load_empty_model
+
+            model = load_empty_model(model_id, **loading_kwargs)
+        else:
+            model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
         quantization_config.post_init_cpu()
 
     model.config.update({"low_cpu_mem_usage": low_cpu_mem_usage})
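Both device paths gate on the same new condition: when the quantization config requests layer-wise quantization, the checkpoint is never fully materialized; instead `load_empty_model` builds the architecture with uninitialized (meta-device) weights so that `convert_to_quantized_model` can bring real weights in one layer at a time. A standalone sketch of that decision, assuming only `transformers` and `neural-compressor` are installed (the `load_for_woq` helper name is hypothetical, not part of this PR):

# Sketch of the loading decision above; `load_for_woq` is a hypothetical
# helper name, not part of this PR.
from transformers import AutoModelForCausalLM


def load_for_woq(model_id, quantization_config, **loading_kwargs):
    if getattr(quantization_config, "use_layer_wise", False):
        # Layer-wise path: build only the model skeleton; weights stay on
        # the meta device until the quantization pass materializes them
        # layer by layer.
        from neural_compressor.torch import load_empty_model

        return load_empty_model(model_id, **loading_kwargs)
    # Default path: load the full checkpoint, streaming shards to keep
    # peak host memory down.
    return AutoModelForCausalLM.from_pretrained(
        model_id, low_cpu_mem_usage=True, **loading_kwargs
    )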
@@ -398,12 +409,6 @@ def _weight_only_quantization(
     if (not torch.cuda.is_available() or device_map == "cpu") and model.config.model_type == "chatglm":
         model = model.float()
 
-    from neural_compressor.torch import load_empty_model
-
-    model = load_empty_model(
-        model_id,
-        trust_remote_code=trust_remote_code,
-    )
     model = convert_to_quantized_model(model, quantization_config, device=device_map)
     quantization_config.remove_redundant_parameters()
     model.config.quantization_config = quantization_config
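The second hunk removes a leftover unconditional `load_empty_model` call: it rebound `model` right before `convert_to_quantized_model`, so even non-layer-wise runs ended up quantizing an empty skeleton instead of the checkpoint loaded above. For intuition, an "empty model" can be approximated with plain `torch` and `transformers`; this only illustrates the meta-device idea, it is not the neural-compressor implementation, and the checkpoint id is just an example:

# Rough equivalent of an "empty model": allocate parameter shapes on the
# meta device without touching checkpoint weights.
import torch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-125m")  # example id
with torch.device("meta"):
    empty_model = AutoModelForCausalLM.from_config(config)

print(next(empty_model.parameters()).device)  # meta: no storage allocated
print(sum(p.numel() for p in empty_model.parameters()))  # shapes still known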