Skip to content

Commit a02dcc1

Browse files
authored
fix device issue during calibration (#2100)
Signed-off-by: Xin He <xinhe3@habana.ai>
1 parent fa8ad83 commit a02dcc1

File tree

1 file changed

+6
-4
lines changed
  • examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only

1 file changed

+6
-4
lines changed

examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,12 @@ def calib_func(prepared_model):
403403
max_seq_length=args.gptq_max_seq_length,
404404
)
405405
dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader()
406-
from neural_compressor.torch.algorithms.weight_only.utility import move_input_to_device
406+
from neural_compressor.torch.utils import get_model_device, move_input_device
407407
from tqdm import tqdm
408408
def run_fn_for_gptq(model, dataloader_for_calibration, *args):
409409
for batch in tqdm(dataloader_for_calibration):
410-
batch = move_input_to_device(batch, device=None)
410+
device = get_model_device(model)
411+
batch = move_input_device(batch, device=device)
411412
if isinstance(batch, tuple) or isinstance(batch, list):
412413
model(batch[0])
413414
elif isinstance(batch, dict):
@@ -525,11 +526,12 @@ def run_fn_for_autoround(model, dataloader):
525526
)
526527
dataloader = dataloaderPreprocessor.get_prepared_dataloader()
527528
custom_tune_config = TuningConfig(config_set=get_woq_tuning_config())
528-
from neural_compressor.torch.algorithms.weight_only.utility import move_input_to_device
529+
from neural_compressor.torch.utils import get_model_device, move_input_device
529530
from tqdm import tqdm
530531
def run_fn_for_gptq(model, dataloader_for_calibration, *args):
531532
for batch in tqdm(dataloader_for_calibration):
532-
batch = move_input_to_device(batch, device=None)
533+
device = get_model_device(model)
534+
batch = move_input_device(batch, device=device)
533535
if isinstance(batch, tuple) or isinstance(batch, list):
534536
model(batch[0])
535537
elif isinstance(batch, dict):

0 commit comments

Comments
 (0)