From 3d807cb4a6cf078e8d51f178e5721b9af90c7c18 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Sun, 22 Dec 2024 18:48:18 -0800 Subject: [PATCH] fix crash in warmup for xpu Signed-off-by: Wang, Yi A --- optimum/intel/generation/modeling.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py index a6e8a76f4f..d3b53f9a3c 100644 --- a/optimum/intel/generation/modeling.py +++ b/optimum/intel/generation/modeling.py @@ -63,7 +63,11 @@ def prepare_jit_inputs(model: PreTrainedModel, task: str, use_cache: bool = Fals dummy_inputs = onnx_config.generate_dummy_inputs(framework="pt") - return {key: dummy_inputs[key] for key in signature.parameters if dummy_inputs.get(key, None) is not None} + return { + key: dummy_inputs[key].to(model.device) + for key in signature.parameters + if dummy_inputs.get(key, None) is not None + } def jit_trace(model: PreTrainedModel, task: str, use_cache: bool = False):