Skip to content

Commit

Permalink
Fix IPEXModel input names for nn.module
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Jun 20, 2024
1 parent eb50967 commit d4abbd1
Showing 1 changed file with 31 additions and 34 deletions.
65 changes: 31 additions & 34 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def ipex_jit_trace(model, task, use_cache):

if _is_patched_with_ipex(model, task):
model = _patch_model(model)
# Todo: integerate in prepare_jit_inputs.
# TODO: integerate in prepare_jit_inputs.
sample_inputs = get_dummy_input(model, return_dict=True)
# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
_enable_tpp()
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
logger.warning("The model has been exported already.")
else:
config = model.config if config is None else config
use_cache = kwargs.get("use_cache", None)
use_cache = kwargs.get("use_cache", True)
model = ipex_jit_trace(model, self.export_feature, use_cache)
config.torchscript = True

Expand All @@ -162,11 +162,13 @@ def __init__(
self.model_save_dir = model_save_dir
self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)

self.input_names = (
{inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"}
if isinstance(model, torch.jit.RecursiveScriptModule)
else inspect.signature(model.forward).parameters
)
if isinstance(model, torch.jit.RecursiveScriptModule):
self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
else:
self.input_names = set(inspect.signature(model.forward).parameters)

# Registers the IPEXModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
# a pipeline https://github.com/huggingface/transformers/blob/cad61b68396a1a387287a8e2e2fef78a25b79383/src/transformers/pipelines/base.py#L863
AutoConfig.register(self.base_model_prefix, AutoConfig)
Expand All @@ -184,7 +186,6 @@ def _from_pretrained(
cls,
model_id: Union[str, Path],
config: PretrainedConfig,
use_cache: bool = True,
use_auth_token: Optional[Union[bool, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
Expand All @@ -194,7 +195,6 @@ def _from_pretrained(
local_files_only: bool = False,
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
trust_remote_code: bool = False,
_commit_hash: str = None,
file_name: Optional[str] = WEIGHTS_NAME,
**kwargs,
):
Expand All @@ -209,6 +209,17 @@ def _from_pretrained(
)
token = use_auth_token

commit_hash = kwargs.pop("_commit_hash", None)

model_kwargs = {
"revision": revision,
"token": token,
"cache_dir": cache_dir,
"subfolder": subfolder,
"local_files_only": local_files_only,
"force_download": force_download,
}

if not getattr(config, "torchscript", False):
logger.warning("Detect torchscript is false. Convert to torchscript model!")

Expand All @@ -217,44 +228,30 @@ def _from_pretrained(

task = cls.export_feature
config.torch_dtype = torch_dtype
model_kwargs = {
"revision": revision,
"token": token,
"cache_dir": cache_dir,
"subfolder": subfolder,
"local_files_only": local_files_only,
"force_download": force_download,
"torch_dtype": torch_dtype,
"trust_remote_code": trust_remote_code,
"_commit_hash": _commit_hash,
}

model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
model = TasksManager.get_model_from_task(
task,
model_id,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
_commit_hash=commit_hash,
**model_kwargs,
)

return cls(model, config=config, export=True, use_cache=use_cache, **kwargs)
return cls(model, config=config, export=True, **kwargs)

# Load the model from local directory
if os.path.isdir(model_id):
model_cache_path = os.path.join(model_id, file_name)
model_save_dir = model_id
# Download the model from the hub
else:
model_cache_path = hf_hub_download(
repo_id=model_id,
filename=file_name,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
subfolder=subfolder,
)
model_cache_path = hf_hub_download(repo_id=model_id, filename=file_name, **model_kwargs)
model_save_dir = Path(model_cache_path).parent

model = torch.jit.load(model_cache_path)
torch.jit.freeze(model.eval())

return cls(model, config=config, model_save_dir=model_save_dir, use_cache=use_cache, **kwargs)
return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)

def _save_pretrained(self, save_directory: Union[str, Path]):
output_path = os.path.join(save_directory, WEIGHTS_NAME)
Expand Down

0 comments on commit d4abbd1

Please sign in to comment.