Add an initial warmup step to `IPEXModel`s #543
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM
optimum/intel/ipex/modeling_base.py
Outdated
@wraps(IPEXModel.forward)
def forward(self, *args, **kwargs):
    outputs = self.model(*args, **kwargs)
    outputs = super().forward(*args, **kwargs)
Why is this needed?
`prepare_jit_inputs` looks at the signature of the function, and `wraps` and `super()` help avoid copying code.
I would prefer we avoid this, as it will fail in case `outputs` is not a dict:
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
Also, I'm not sure I see the link with `prepare_jit_inputs`.
- In `_init_warmup` we call `prepare_jit_inputs`, which examines the passed model's `forward` signature to see which dummy inputs exist in the signature. If we don't use `wraps`, we get the signature `self, *args, **kwargs` instead of `self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor = None, **kwargs,`.
- `outputs` will always be a dict because this is the output of `IPEXModel.forward`, no?
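To illustrate the `wraps` point above, here is a minimal standalone sketch. It assumes the signature is read with `inspect.signature` (the actual lookup inside `prepare_jit_inputs` may differ), and the classes `Base`, `PlainChild`, and `WrappedChild` are hypothetical stand-ins, not code from this PR.

```python
import inspect
from functools import wraps


class Base:
    # Stand-in for IPEXModel.forward, with explicit input names.
    def forward(self, input_ids, attention_mask, token_type_ids=None, **kwargs):
        return {"logits": input_ids}


class PlainChild(Base):
    # Without wraps, only the generic signature is visible.
    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


class WrappedChild(Base):
    # wraps sets __wrapped__, which inspect.signature follows by default.
    @wraps(Base.forward)
    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


plain_params = list(inspect.signature(PlainChild.forward).parameters)
wrapped_params = list(inspect.signature(WrappedChild.forward).parameters)

print(plain_params)    # parameter names seen without wraps
print(wrapped_params)  # parameter names seen with wraps
```

With `wraps`, any helper that picks dummy inputs by parameter name can still find `input_ids`, `attention_mask`, and `token_type_ids`, while the body stays a one-line call to `super().forward`.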
- OK, I understand. I was thinking that `prepare_jit_inputs` was only used for the TorchScript export, but I see that it's also used in `_init_warmup`. Thanks for the clarification.
- Here I'm talking about `outputs`: https://github.com/huggingface/optimum-intel/blob/8ee487dc2ade5bd0023d1bbe0a0103d6af8821e0/optimum/intel/ipex/modeling_base.py#L192C9-L192C16
What does this PR do?
The first two forward passes of an `IPEXModel` after trace/load include background optimization steps that make the outputs of these passes unpredictable and inconsistent with those of the model after the optimizations. To fix that, an initial warmup step was added to the `__init__` of `IPEXModel`s.

Depends on PR #542
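As a rough illustration of the warmup idea described above (not the implementation in this PR), the toy sketch below runs two throwaway forward passes at construction time so that later calls see the settled behavior. `ToyTracedModel` and `WarmedUpModel` are hypothetical names, and the "unstable first two calls" are simulated.

```python
class ToyTracedModel:
    """Hypothetical stand-in for a traced model whose first two calls differ."""

    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        # Simulate background optimizations perturbing the first two outputs.
        return x * 2 if self.calls > 2 else x * 2 + 0.5


class WarmedUpModel:
    """Wrapper that optionally consumes the unstable calls in __init__."""

    def __init__(self, model, warmup=True):
        self.model = model
        if warmup:
            self._init_warmup()

    def _init_warmup(self):
        dummy_input = 1.0  # assumed dummy input, for illustration only
        for _ in range(2):
            self.model(dummy_input)

    def forward(self, x):
        return self.model(x)


warm = WarmedUpModel(ToyTracedModel())
cold = WarmedUpModel(ToyTracedModel(), warmup=False)
print(warm.forward(3.0), cold.forward(3.0))
```

The `warmup` flag is only there to make the contrast visible in the sketch.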
@echarlaix
Before submitting