Skip to content

Fix IPEXModel input names #775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 31 additions & 34 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def ipex_jit_trace(model, task, use_cache):

if _is_patched_with_ipex(model, task):
model = _patch_model(model)
# Todo: integerate in prepare_jit_inputs.
# TODO: integerate in prepare_jit_inputs.
sample_inputs = get_dummy_input(model, return_dict=True)
# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
_enable_tpp()
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
logger.warning("The model has been exported already.")
else:
config = model.config if config is None else config
use_cache = kwargs.get("use_cache", None)
use_cache = kwargs.get("use_cache", True)
model = ipex_jit_trace(model, self.export_feature, use_cache)
config.torchscript = True

Expand All @@ -162,11 +162,13 @@ def __init__(
self.model_save_dir = model_save_dir
self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)

self.input_names = (
{inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"}
if isinstance(model, torch.jit.RecursiveScriptModule)
else inspect.signature(model.forward).parameters
)
if isinstance(model, torch.jit.RecursiveScriptModule):
self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
else:
self.input_names = set(inspect.signature(model.forward).parameters)

# Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
# a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
AutoConfig.register(self.base_model_prefix, AutoConfig)
Expand All @@ -184,7 +186,6 @@ def _from_pretrained(
cls,
model_id: Union[str, Path],
config: PretrainedConfig,
use_cache: bool = True,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
Expand All @@ -194,7 +195,6 @@ def _from_pretrained(
local_files_only: bool = False,
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
trust_remote_code: bool = False,
_commit_hash: str = None,
file_name: Optional[str] = WEIGHTS_NAME,
**kwargs,
):
Expand All @@ -209,6 +209,17 @@ def _from_pretrained(
)
token = use_auth_token

commit_hash = kwargs.pop("_commit_hash", None)

model_kwargs = {
"revision": revision,
"token": token,
"cache_dir": cache_dir,
"subfolder": subfolder,
"local_files_only": local_files_only,
"force_download": force_download,
}

if not getattr(config, "torchscript", False):
logger.warning("Detect torchscript is false. Convert to torchscript model!")

Expand All @@ -217,44 +228,30 @@ def _from_pretrained(

task = cls.export_feature
config.torch_dtype = torch_dtype
model_kwargs = {
"revision": revision,
"token": token,
"cache_dir": cache_dir,
"subfolder": subfolder,
"local_files_only": local_files_only,
"force_download": force_download,
"torch_dtype": torch_dtype,
"trust_remote_code": trust_remote_code,
"_commit_hash": _commit_hash,
}

model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
model = TasksManager.get_model_from_task(
task,
model_id,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
_commit_hash=commit_hash,
**model_kwargs,
)

return cls(model, config=config, export=True, use_cache=use_cache, **kwargs)
return cls(model, config=config, export=True, **kwargs)

# Load the model from local directory
if os.path.isdir(model_id):
model_cache_path = os.path.join(model_id, file_name)
model_save_dir = model_id
# Download the model from the hub
else:
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=file_name,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
subfolder=subfolder,
)
model_cache_path = hf_hub_download(repo_id=model_id, filename=file_name, **model_kwargs)
model_save_dir = Path(model_cache_path).parent

model = torch.jit.load(model_cache_path)
torch.jit.freeze(model.eval())

return cls(model, config=config, model_save_dir=model_save_dir, use_cache=use_cache, **kwargs)
return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)

def _save_pretrained(self, save_directory: Union[str, Path]):
output_path = os.path.join(save_directory, WEIGHTS_NAME)
Expand Down
Loading