diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2c407ecfd919..ba3c59f36532 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2170,7 +2170,7 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict[str, Any], generation_ return False # Base logic - valid_hardware = self.device.type == "cuda" or bool( + valid_hardware = self.device.type in ["cuda", "xpu"] or bool( generation_config.compile_config is not None and generation_config.compile_config._compile_all_devices ) using_compilable_cache = (