fix jit model #566

Merged
21 commits, merged on Mar 19, 2024
Changes from 9 commits
31 changes: 15 additions & 16 deletions optimum/intel/ipex/modeling_base.py
@@ -90,13 +90,13 @@ def _from_transformers(
cls,
model_id: str,
config: PretrainedConfig,
- use_cache: bool = True,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
subfolder: str = "",
local_files_only: bool = False,
+ use_cache: bool = True,
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
trust_remote_code: bool = False,
):
@@ -124,6 +124,7 @@ def _from_transformers(
save_dir_path = Path(save_dir.name)
torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
config.torchscript = True
+ config.torch_dtype = torch_dtype

return cls._from_pretrained(
model_id=save_dir_path,
@@ -192,7 +193,7 @@ def forward(
if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

- outputs = self._call_model(**inputs)
+ outputs = self.model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])

def eval(self):
@@ -207,6 +208,11 @@ def device(self) -> torch.device:
def dtype(self) -> torch.dtype:
return self._dtype

+ @property
+ def model_dtype(self):
+     logger.warning("model_dtype will be removed after v1.18.0")
+     return self._dtype

def to(self, device: Union[torch.device, str]):
self._device = device if isinstance(device, torch.device) else torch.device(device)
self.model.to(self._device)
@@ -215,14 +221,6 @@ def to(self, device: Union[torch.device, str]):
def can_generate(self):
return isinstance(self, GenerationMixin)

- def _call_model(self, *args, **kwargs):
Collaborator:
It might also make sense to add a test to verify this doesn't break support for models traced with autocasting enabled. cc @ofirzaf, do you know if there is any tiny model on the Hub we can use for this (https://huggingface.co/Intel/q8_tiny_starcoder_py/ ?)

Contributor:
You can use this one, yes

Collaborator:
We actually should have a test for this to verify no modifications will break support; would you mind adding it @jiqing-feng in https://github.com/huggingface/optimum-intel/blob/v1.15.2/tests/ipex/test_modeling.py#L201 ? Also, given @ofirzaf's explanation above, removing _call_model will likely break support for q8_tiny_starcoder_py.
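To make the suggestion concrete, here is a rough sketch of such a test (hypothetical: the test name is made up, and it assumes Intel/q8_tiny_starcoder_py loads directly through IPEXModelForCausalLM.from_pretrained like the models already used in tests/ipex/test_modeling.py):

```python
# Hypothetical test sketch: check that a model traced with autocast enabled
# (Intel/q8_tiny_starcoder_py) still loads and generates once the
# _call_model/autocast wrapper is removed.
import torch
from transformers import AutoTokenizer

from optimum.intel import IPEXModelForCausalLM


def test_autocast_traced_model_still_generates():
    model_id = "Intel/q8_tiny_starcoder_py"  # tiny model suggested in the thread above
    model = IPEXModelForCausalLM.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    inputs = tokenizer("def print_hello():", return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=8)

    # The exact tokens do not matter here; the test only asserts that the
    # forward/generate path runs end to end without the autocast wrapper.
    assert outputs.shape[1] >= inputs["input_ids"].shape[1]
```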

- try:
-     with torch.autocast(self.device.type, self.dtype):
-         out = self.model(*args, **kwargs)
- except RuntimeError:
-     out = self.model(*args, **kwargs)
- return out

def _init_warmup(self):
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
# the results of the compute are unpredictable
@@ -260,7 +258,7 @@ def forward(
"pixel_values": pixel_values,
}

- outputs = self._call_model(**inputs)
+ outputs = self.model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


@@ -281,7 +279,7 @@ def forward(
if "attention_mask" in self.input_names:
inputs["attention_mask"] = attention_mask

- outputs = self._call_model(**inputs)
+ outputs = self.model(**inputs)
return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])


@@ -304,7 +302,7 @@ def forward(
if "token_type_ids" in self.input_names:
inputs["token_type_ids"] = token_type_ids

- outputs = self._call_model(**inputs)
+ outputs = self.model(**inputs)
start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0]
end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1]
return ModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -325,10 +323,10 @@ def __init__(
):
# Perform the initial warmup at the end of __init__
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
+ GenerationMixin.__init__(self)

model_type = config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
- self.model_dtype = kwargs.get("model_dtype", self.dtype)
self.use_cache = "past_key_values" in self.input_names

if use_cache ^ self.use_cache:
@@ -348,6 +346,7 @@ def __init__(
)
except AttributeError:
self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)

self._reorder_cache = self.model_cls._reorder_cache.__get__(self)

if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}:
@@ -414,7 +413,7 @@ def forward(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- position_ids = position_ids[:, -1].unsqueeze(-1)
+ position_ids = position_ids[:, -input_ids.shape[-1] :]
Collaborator:
Why is this modification needed?

Collaborator (Author):
position_ids should always have the same size as input_ids; we cannot assume the length is 1 whenever past_key_values exist (for example, with assisted decoding).
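To illustrate (with made-up tensor shapes), the old slicing only keeps one position while assisted decoding can feed several new tokens alongside past_key_values:

```python
import torch

# Made-up shapes: a prompt of 5 tokens is already cached in past_key_values and
# 3 candidate tokens are verified in one step (as in assisted decoding), so
# input_ids for this forward has length 3, not 1.
attention_mask = torch.ones(1, 8, dtype=torch.long)  # 5 cached + 3 new positions
input_ids = torch.tensor([[101, 102, 103]])          # the 3 new tokens

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

old_slice = position_ids[:, -1].unsqueeze(-1)        # shape (1, 1): no longer matches input_ids
new_slice = position_ids[:, -input_ids.shape[-1] :]  # shape (1, 3): same length as input_ids

print(old_slice.shape, new_slice.shape)  # torch.Size([1, 1]) torch.Size([1, 3])
```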

Collaborator:
Could you add a test that would fail without this fix and pass with it, then?

Contributor:
Shouldn't this be taken care of by prepare_inputs_for_generation?

Collaborator (Author):
It's also an option, WDYT @echarlaix?

Collaborator:
It's already integrated in prepare_inputs_for_generation for each modeling file (https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1237), so it could actually make sense to remove it here altogether.
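For reference, the position_ids handling in the linked prepare_inputs_for_generation looks roughly like the sketch below (paraphrased, not a verbatim copy of the transformers source):

```python
# Paraphrased sketch of the position_ids logic in transformers' llama
# prepare_inputs_for_generation (see the link above); not an exact copy.
def prepare_inputs_for_generation_sketch(input_ids, past_key_values=None, attention_mask=None, position_ids=None, **kwargs):
    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            # keep only the positions of the tokens passed in this step
            position_ids = position_ids[:, -input_ids.shape[1] :]

    return {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "past_key_values": past_key_values,
        "attention_mask": attention_mask,
    }
```

Since generate already builds position_ids this way before each forward, duplicating the computation inside the IPEX forward is arguably redundant.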


if "position_ids" in self.input_names or not self.input_names:
inputs["position_ids"] = position_ids
@@ -426,7 +425,7 @@ def forward(
inputs["past_key_values"] = past_key_values

# 2. Model forward
- outputs = self._call_model(**inputs)
+ outputs = self.model(**inputs)

# 3. Process model outputs
if isinstance(outputs, (list, tuple)):