fix jit model #566
Changes from 9 commits
@@ -90,13 +90,13 @@ def _from_transformers(
        cls,
        model_id: str,
        config: PretrainedConfig,
        use_cache: bool = True,
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        use_cache: bool = True,
        torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
        trust_remote_code: bool = False,
    ):
@@ -124,6 +124,7 @@ def _from_transformers(
        save_dir_path = Path(save_dir.name)
        torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
        config.torchscript = True
+       config.torch_dtype = torch_dtype

        return cls._from_pretrained(
            model_id=save_dir_path,
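For context, `torch_dtype` can arrive either as a string (e.g. "bfloat16") or as a `torch.dtype` before it is stored on the config. A minimal sketch of the kind of normalization a caller might apply; the `resolve_dtype` helper below is hypothetical and not part of optimum-intel:

```python
import torch
from typing import Optional, Union

def resolve_dtype(torch_dtype: Optional[Union[str, torch.dtype]]) -> Optional[torch.dtype]:
    # Hypothetical helper: map a string such as "float16" to torch.float16,
    # pass torch.dtype values through unchanged, and keep None as "unspecified".
    if torch_dtype is None or isinstance(torch_dtype, torch.dtype):
        return torch_dtype
    return getattr(torch, torch_dtype)

assert resolve_dtype("bfloat16") is torch.bfloat16
assert resolve_dtype(torch.float32) is torch.float32
```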
@@ -192,7 +193,7 @@ def forward(
        if "token_type_ids" in self.input_names:
            inputs["token_type_ids"] = token_type_ids

-       outputs = self._call_model(**inputs)
+       outputs = self.model(**inputs)
        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

    def eval(self):
@@ -207,6 +208,11 @@ def device(self) -> torch.device:
    def dtype(self) -> torch.dtype:
        return self._dtype

+   @property
+   def model_dtype(self):
+       logger.warning("model_dtype will be removed after v1.18.0")
+       return self._dtype
+
    def to(self, device: Union[torch.device, str]):
        self._device = device if isinstance(device, torch.device) else torch.device(device)
        self.model.to(self._device)
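The new `model_dtype` property keeps older code working while steering callers toward `dtype`. A small self-contained sketch of the same deprecation-shim pattern, using a stand-in class rather than the actual IPEX wrapper:

```python
import logging
import torch

logger = logging.getLogger(__name__)

class _DtypeHolder:
    # Minimal stand-in for the model wrapper: `dtype` is the supported accessor,
    # `model_dtype` is kept only for backward compatibility and warns on use.
    def __init__(self, dtype: torch.dtype = torch.float32):
        self._dtype = dtype

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    @property
    def model_dtype(self) -> torch.dtype:
        logger.warning("model_dtype will be removed after v1.18.0, use dtype instead")
        return self._dtype

holder = _DtypeHolder(torch.bfloat16)
assert holder.model_dtype is holder.dtype  # same value, but model_dtype logs a deprecation warning
```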
@@ -215,14 +221,6 @@ def to(self, device: Union[torch.device, str]):
    def can_generate(self):
        return isinstance(self, GenerationMixin)

-   def _call_model(self, *args, **kwargs):
-       try:
-           with torch.autocast(self.device.type, self.dtype):
-               out = self.model(*args, **kwargs)
-       except RuntimeError:
-           out = self.model(*args, **kwargs)
-       return out
-
    def _init_warmup(self):
        # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
        # the results of the compute are unpredictable
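For reference, the removed `_call_model` wrapped the traced model call in `torch.autocast` using the wrapper's dtype and fell back to a plain call when autocast raised a `RuntimeError`. A standalone sketch of that pattern; the `call_with_autocast` helper and the `nn.Linear` stand-in are illustrative, not from the PR:

```python
import torch

def call_with_autocast(model, dtype: torch.dtype, device: torch.device, *args, **kwargs):
    # Sketch of the removed _call_model behaviour: try to run the traced model
    # under torch.autocast with the wrapper's dtype, and fall back to a plain
    # call if autocast is unsupported for this device/dtype combination.
    try:
        with torch.autocast(device.type, dtype):
            return model(*args, **kwargs)
    except RuntimeError:
        return model(*args, **kwargs)

# Example with a plain eager module standing in for the traced model:
linear = torch.nn.Linear(4, 4)
out = call_with_autocast(linear, torch.bfloat16, torch.device("cpu"), torch.randn(2, 4))
```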
@@ -260,7 +258,7 @@ def forward(
            "pixel_values": pixel_values,
        }

-       outputs = self._call_model(**inputs)
+       outputs = self.model(**inputs)
        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
@@ -281,7 +279,7 @@ def forward(
        if "attention_mask" in self.input_names:
            inputs["attention_mask"] = attention_mask

-       outputs = self._call_model(**inputs)
+       outputs = self.model(**inputs)
        return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
@@ -304,7 +302,7 @@ def forward(
        if "token_type_ids" in self.input_names:
            inputs["token_type_ids"] = token_type_ids

-       outputs = self._call_model(**inputs)
+       outputs = self.model(**inputs)
        start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
        end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
        return ModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -325,10 +323,10 @@ def __init__(
    ):
        # Perform the initial warmup at the end of __init__
        super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
        GenerationMixin.__init__(self)

        model_type = config.model_type.replace("_", "-")
        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
        self.model_dtype = kwargs.get("model_dtype", self.dtype)
        self.use_cache = "past_key_values" in self.input_names

        if use_cache ^ self.use_cache:
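The `use_cache ^ self.use_cache` check uses boolean XOR, which is true exactly when the requested cache setting and the traced graph's actual inputs disagree. A tiny illustration; the helper and its warning text are made up for this example:

```python
def check_cache_mismatch(requested_use_cache: bool, traced_has_cache: bool) -> None:
    # XOR is True exactly when the two flags disagree.
    if requested_use_cache ^ traced_has_cache:
        print(
            "Warning (illustrative): `use_cache` was set to "
            f"{requested_use_cache} but the traced model "
            f"{'exposes' if traced_has_cache else 'does not expose'} past_key_values."
        )

check_cache_mismatch(True, False)   # disagreement -> warns
check_cache_mismatch(True, True)    # agreement    -> silent
```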
@@ -348,6 +346,7 @@ def __init__(
            )
        except AttributeError:
            self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)

        self._reorder_cache = self.model_cls._reorder_cache.__get__(self)

        if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}:
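Binding `self.model_cls._reorder_cache.__get__(self)` reuses the beam-reordering implementation of the underlying transformers model class as a bound method of the wrapper. A minimal sketch of that descriptor trick with hypothetical stand-in classes:

```python
class Donor:
    # Stand-in for the transformers model class whose method we want to reuse.
    def _reorder_cache(self, past_key_values, beam_idx):
        return [layer_past[beam_idx] for layer_past in past_key_values]

class Wrapper:
    # Stand-in for the IPEX wrapper that borrows Donor's implementation.
    def __init__(self):
        # __get__(self) binds the plain function to *this* instance, so later
        # calls look like an ordinary method call on Wrapper.
        self._reorder_cache = Donor._reorder_cache.__get__(self)

w = Wrapper()
print(w._reorder_cache([["a", "b"], ["c", "d"]], 1))  # ['b', 'd']
```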
@@ -414,7 +413,7 @@ def forward(
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
-               position_ids = position_ids[:, -1].unsqueeze(-1)
+               position_ids = position_ids[:, -input_ids.shape[-1] :]

Review comment: why is this modification needed ?
Review comment: could you add a test that would fail without this fix / pass with it then ?
Review comment: Shouldn't this be taken care by
Review comment: It's also an option, WDYT @echarlaix ?
Review comment: it's already integrated in

        if "position_ids" in self.input_names or not self.input_names:
            inputs["position_ids"] = position_ids
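One plausible reading of this change (an inference, not stated in the thread): when `past_key_values` is present but more than one new token is fed in a single step, slicing only the last position drops positions, whereas slicing the last `input_ids.shape[-1]` positions keeps one per new token. A small illustration under that assumption:

```python
import torch

# Scenario (assumed): 3 cached tokens plus 2 new tokens arriving in one step.
attention_mask = torch.tensor([[1, 1, 1, 1, 1]])
input_ids = torch.tensor([[11, 12]])  # the 2 new tokens for this step

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

old = position_ids[:, -1].unsqueeze(-1)           # tensor([[4]])    -> only one position
new = position_ids[:, -input_ids.shape[-1]:]      # tensor([[3, 4]]) -> one position per new token
print(old, new)
```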
@@ -426,7 +425,7 @@ def forward(
            inputs["past_key_values"] = past_key_values

        # 2. Model forward
-       outputs = self._call_model(**inputs)
+       outputs = self.model(**inputs)

        # 3. Process model outputs
        if isinstance(outputs, (list, tuple)):
Review comment: also might make sense to add a test to verify this doesn't break support for models traced with autocasting enabled. cc @ofirzaf, do you know if there is any tiny model on the hub we can use for this (https://huggingface.co/Intel/q8_tiny_starcoder_py/ ?)
Review comment: You can use this one, yes
Review comment: we actually should have a test for this to verify no modifications will break support, would you mind adding it @jiqing-feng in https://github.com/huggingface/optimum-intel/blob/v1.15.2/tests/ipex/test_modeling.py#L201 ? Also, given @ofirzaf's explanations above, removing _call_model will likely break support for q8_tiny_starcoder_py
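A rough sketch of the kind of regression test the reviewers ask for, assuming `IPEXModelForCausalLM.from_pretrained` can load the quantized TorchScript checkpoint directly; the exact loading arguments and test placement may differ:

```python
import unittest

import torch
from transformers import AutoTokenizer

from optimum.intel import IPEXModelForCausalLM


class IPEXQuantizedTracedModelTest(unittest.TestCase):
    # Sketch of the suggested test: make sure a model that was traced/quantized
    # with autocasting enabled still loads and generates after the changes.
    MODEL_ID = "Intel/q8_tiny_starcoder_py"

    def test_quantized_traced_model_generates(self):
        tokenizer = AutoTokenizer.from_pretrained(self.MODEL_ID)
        model = IPEXModelForCausalLM.from_pretrained(self.MODEL_ID)
        inputs = tokenizer("def print_hello():", return_tensors="pt")
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=8)
        self.assertEqual(output_ids.shape[0], 1)
        self.assertGreater(output_ids.shape[1], inputs["input_ids"].shape[1])
```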