
Commit 9e0bb7c

improve model init
Signed-off-by: changwa1 <chang1.wang@intel.com>
1 parent 5f14658 commit 9e0bb7c


optimum/intel/neural_compressor/quantization.py

Lines changed: 23 additions & 18 deletions
@@ -374,22 +374,33 @@ def _weight_only_quantization(
     }

     low_cpu_mem_usage = True
+
     if use_xpu:
-        try:
-            # TODO: if low_cpu_mem_usage is True, gptj will have an accuracy issue on the CPU device.
-            model = model_class.from_pretrained(
-                model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs
-            )
-        except NotImplementedError:
-            logger.info(
-                "Failed to load the model with `low_cpu_mem_usage=True`; falling back to the traditional load method, resulting in higher memory consumption."
-            )
-            low_cpu_mem_usage = False
-            model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
+        if hasattr(quantization_config, "use_layer_wise") and quantization_config.use_layer_wise:
+            from neural_compressor.torch import load_empty_model
+
+            model = load_empty_model(model_id, **loading_kwargs)
+        else:
+            try:
+                # TODO: if low_cpu_mem_usage is True, gptj will have an accuracy issue on the CPU device.
+                model = model_class.from_pretrained(
+                    model_id, low_cpu_mem_usage=low_cpu_mem_usage, device_map="cpu", **loading_kwargs
+                )
+            except NotImplementedError:
+                logger.info(
+                    "Failed to load the model with `low_cpu_mem_usage=True`; falling back to the traditional load method, resulting in higher memory consumption."
+                )
+                low_cpu_mem_usage = False
+                model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
         quantization_config.update(**{"device": "xpu"})
         quantization_config.post_init_xpu()
     else:
-        model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
+        if hasattr(quantization_config, "use_layer_wise") and quantization_config.use_layer_wise:
+            from neural_compressor.torch import load_empty_model
+
+            model = load_empty_model(model_id, **loading_kwargs)
+        else:
+            model = model_class.from_pretrained(model_id, low_cpu_mem_usage=low_cpu_mem_usage, **loading_kwargs)
         quantization_config.post_init_cpu()

     model.config.update({"low_cpu_mem_usage": low_cpu_mem_usage})
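For context, a minimal sketch of how the new layer-wise branch would be reached from user code, assuming neural-compressor's RtnConfig (which exposes a use_layer_wise flag) and optimum-intel's INCModelForCausalLM entry point; the model id and config values below are placeholders, not part of this commit:

    # Hypothetical usage sketch: a quantization config with use_layer_wise=True
    # is what makes the hasattr(...) check above take the load_empty_model path.
    from neural_compressor.transformers import RtnConfig  # assumed import path
    from optimum.intel import INCModelForCausalLM

    quantization_config = RtnConfig(bits=4, use_layer_wise=True)
    model = INCModelForCausalLM.from_pretrained(
        "facebook/opt-125m",  # placeholder model id
        quantization_config=quantization_config,
    )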
@@ -398,12 +409,6 @@ def _weight_only_quantization(
     if (not torch.cuda.is_available() or device_map == "cpu") and model.config.model_type == "chatglm":
         model = model.float()

-    from neural_compressor.torch import load_empty_model
-
-    model = load_empty_model(
-        model_id,
-        trust_remote_code=trust_remote_code,
-    )
     model = convert_to_quantized_model(model, quantization_config, device=device_map)
     quantization_config.remove_redundant_parameters()
     model.config.quantization_config = quantization_config

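The point of the layer-wise branch is that load_empty_model builds the model skeleton without materializing real weights, so host memory stays low until convert_to_quantized_model quantizes the model layer by layer. A standalone sketch follows; the device check reflects an assumption about how load_empty_model initializes parameters, not a documented guarantee:

    # Minimal sketch of the empty-model load used by the use_layer_wise branch.
    from neural_compressor.torch import load_empty_model

    model = load_empty_model("facebook/opt-125m")  # placeholder model id
    # Assumption: parameters live on the meta device, i.e. no storage is allocated.
    print(next(model.parameters()).device)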