@@ -403,11 +403,12 @@ def calib_func(prepared_model):
403
403
max_seq_length = args .gptq_max_seq_length ,
404
404
)
405
405
dataloader_for_calibration = dataloaderPreprocessor .get_prepared_dataloader ()
406
- from neural_compressor .torch .algorithms . weight_only . utility import move_input_to_device
406
+ from neural_compressor .torch .utils import get_model_device , move_input_device
407
407
from tqdm import tqdm
408
408
def run_fn_for_gptq (model , dataloader_for_calibration , * args ):
409
409
for batch in tqdm (dataloader_for_calibration ):
410
- batch = move_input_to_device (batch , device = None )
410
+ device = get_model_device (model )
411
+ batch = move_input_device (batch , device = device )
411
412
if isinstance (batch , tuple ) or isinstance (batch , list ):
412
413
model (batch [0 ])
413
414
elif isinstance (batch , dict ):
@@ -525,11 +526,12 @@ def run_fn_for_autoround(model, dataloader):
525
526
)
526
527
dataloader = dataloaderPreprocessor .get_prepared_dataloader ()
527
528
custom_tune_config = TuningConfig (config_set = get_woq_tuning_config ())
528
- from neural_compressor .torch .algorithms . weight_only . utility import move_input_to_device
529
+ from neural_compressor .torch .utils import get_model_device , move_input_device
529
530
from tqdm import tqdm
530
531
def run_fn_for_gptq (model , dataloader_for_calibration , * args ):
531
532
for batch in tqdm (dataloader_for_calibration ):
532
- batch = move_input_to_device (batch , device = None )
533
+ device = get_model_device (model )
534
+ batch = move_input_device (batch , device = device )
533
535
if isinstance (batch , tuple ) or isinstance (batch , list ):
534
536
model (batch [0 ])
535
537
elif isinstance (batch , dict ):
0 commit comments