From e05557ab61cd97e46b6a47851a39f3be85d8c3ad Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Sun, 18 Feb 2024 11:21:30 -0500 Subject: [PATCH 01/16] fix jit model --- optimum/intel/ipex/modeling_base.py | 38 ++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 67810ae06..1869efc76 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -90,13 +90,13 @@ def _from_transformers( cls, model_id: str, config: PretrainedConfig, + use_cache: bool = True, use_auth_token: Optional[Union[bool, str]] = None, revision: Optional[str] = None, force_download: bool = False, cache_dir: Optional[str] = None, subfolder: str = "", local_files_only: bool = False, - use_cache: bool = True, torch_dtype: Optional[Union[str, "torch.dtype"]] = None, trust_remote_code: bool = False, ): @@ -134,6 +134,7 @@ def _from_transformers( cache_dir=cache_dir, local_files_only=local_files_only, use_cache=use_cache, + model_dtype=torch_dtype, ) @classmethod @@ -325,9 +326,11 @@ def __init__( ): # Perform the initial warmup at the end of __init__ super().__init__(model, config, model_save_dir=model_save_dir, warmup=False) + GenerationMixin.__init__(self) self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config) self.model_dtype = kwargs.get("model_dtype", self.dtype) + self._dtype = self.model_dtype self.use_cache = "past_key_values" in self.input_names if use_cache ^ self.use_cache: @@ -346,8 +349,6 @@ def __init__( ) except AttributeError: self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) - self._reorder_cache = self.model_cls._reorder_cache.__get__(self) - self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self) if hasattr(self.model_cls, "_convert_to_standard_cache"): self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache if hasattr(self.model_cls, "_convert_to_bloom_cache"): @@ -355,6 +356,37 @@ def __init__( if warmup: self._init_warmup() + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + past_key_values = past_key_values or kwargs.get("past", None) + + if self.use_cache and past_key_values is not None: + input_ids = input_ids[:, -1:] + + # `past_key_values` may be in the stardard format (e.g. 
in contrastive search), converts to bloom's format if needed + if past_key_values is not None and self.config.model_type == "bloom": + if past_key_values[0][0].shape[0] == input_ids.shape[0]: + past_key_values = self._convert_to_bloom_cache(past_key_values) + + position_ids = kwargs.get("position_ids", None) + + attention_mask = kwargs.get("attention_mask", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": self.use_cache, + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": None, + } + def _prepare_past_key_values(self, input_ids): model_type = self.config.model_type.replace("_", "-") nb_pkv = 2 From 151712d565516ca9137af82c202cc5e3984d663e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Sun, 18 Feb 2024 11:59:26 -0500 Subject: [PATCH 02/16] rm autocast in model --- optimum/intel/ipex/modeling_base.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 1869efc76..a6936bc88 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -193,7 +193,7 @@ def forward( if "token_type_ids" in self.input_names: inputs["token_type_ids"] = token_type_ids - outputs = self._call_model(**inputs) + outputs = self.model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) def eval(self): @@ -216,14 +216,6 @@ def to(self, device: Union[torch.device, str]): def can_generate(self): return isinstance(self, GenerationMixin) - def _call_model(self, *args, **kwargs): - try: - with torch.autocast(self.device.type, self.dtype): - out = self.model(*args, **kwargs) - except RuntimeError: - out = self.model(*args, **kwargs) - return out - def _init_warmup(self): # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and # the results of the compute are unpredictable @@ -261,7 +253,7 @@ def forward( "pixel_values": pixel_values, } - outputs = self._call_model(**inputs) + outputs = self.model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) @@ -282,7 +274,7 @@ def forward( if "attention_mask" in self.input_names: inputs["attention_mask"] = attention_mask - outputs = self._call_model(**inputs) + outputs = self.model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) @@ -305,7 +297,7 @@ def forward( if "token_type_ids" in self.input_names: inputs["token_type_ids"] = token_type_ids - outputs = self._call_model(**inputs) + outputs = self.model(**inputs) start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0] end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1] return ModelOutput(start_logits=start_logits, end_logits=end_logits) @@ -451,7 +443,7 @@ def forward( inputs["past_key_values"] = past_key_values # 2. Model forward - outputs = self._call_model(**inputs) + outputs = self.model(**inputs) # 3. 
Process model outputs if isinstance(outputs, (list, tuple)): From 1782a5008b9770c5cf0ad3d9a7ecd386aebdb72d Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 21 Feb 2024 10:15:03 -0500 Subject: [PATCH 03/16] support assisted decoding and add reorder cache function --- optimum/intel/ipex/modeling_base.py | 99 ++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 3 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index a6936bc88..acf27c431 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -348,11 +348,95 @@ def __init__( if warmup: self._init_warmup() + def _reorder_cache( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. + This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + if self.config.model_type == "bloom": + return self._reorder_cache_bloom(past_key_values, beam_idx) + + # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache + def _reorder_cache_bloom( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called for bloom architecture. + This is required to match `past_key_values` with the correct beam_idx at every generation step. + """ + standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) + + # Get a copy of `beam_idx` on all the devices where we need those indices. + device_to_beam_idx = { + past_state.device: beam_idx.to(past_state.device) + for layer_past in past_key_values + for past_state in layer_past + } + reordered_past = tuple( + ( + layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), + layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), + ) + for layer_past in standardized_past + ) + return self._convert_to_bloom_cache(reordered_past) + + # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache + @staticmethod + def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]: + """ + Converts the cache to the format expected by Bloom, i.e. 
to tuple(tuple([batch_size * num_heads, ...])) + """ + batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size_times_num_heads = batch_size * num_heads + # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] + # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), + layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + + # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache + def _convert_to_standard_cache( + self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int + ) -> Tuple[Tuple[torch.Tensor]]: + """ + Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size, num_heads, ...])) + """ + if self.config.model_type != "bloom": + return past_key_value + + batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape + num_heads = batch_size_times_num_heads // batch_size + # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] + # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] + return tuple( + ( + layer_past[0].view(batch_size, num_heads, head_dim, seq_length), + layer_past[1].view(batch_size, num_heads, seq_length, head_dim), + ) + for layer_past in past_key_value + ) + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): past_key_values = past_key_values or kwargs.get("past", None) if self.use_cache and past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = self.get_past_length(past_key_values) + input_ids = input_ids[:, past_length:] # `past_key_values` may be in the stardard format (e.g. 
in contrastive search), converts to bloom's format if needed if past_key_values is not None and self.config.model_type == "bloom": @@ -368,7 +452,7 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[-1] :] return { "input_ids": input_ids, @@ -379,6 +463,15 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg "token_type_ids": None, } + def get_past_length(self, past_key_values): + model_type = self.config.model_type.replace("_", "-") + if model_type == "bloom": + return past_key_values[0][0].shape[-1] + elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS: + return past_key_values[0].shape[1] + else: + return past_key_values[0][0].shape[-2] + def _prepare_past_key_values(self, input_ids): model_type = self.config.model_type.replace("_", "-") nb_pkv = 2 @@ -431,7 +524,7 @@ def forward( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[-1] :] if "position_ids" in self.input_names or not self.input_names: inputs["position_ids"] = position_ids From 41bf0f5bd8b9e9ba1ad22cfe36e1716140dfcea0 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 26 Feb 2024 08:23:16 -0500 Subject: [PATCH 04/16] add comment for _prepare_past_key_values --- optimum/intel/ipex/modeling_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index acf27c431..b1ac5b009 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -472,6 +472,7 @@ def get_past_length(self, past_key_values): else: return past_key_values[0][0].shape[-2] + # Rewrite it to avoid jit failed, original function may call attributes which jit model don't have def _prepare_past_key_values(self, input_ids): model_type = self.config.model_type.replace("_", "-") nb_pkv = 2 From dd63ee7a98fda2fb943b70b02d63c58dbd6b5064 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 26 Feb 2024 09:08:46 -0500 Subject: [PATCH 05/16] rebase main --- optimum/intel/ipex/modeling_base.py | 128 +--------------------------- 1 file changed, 2 insertions(+), 126 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index c66d1cae9..b8f5bc465 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -344,12 +344,12 @@ def __init__( except AttributeError: self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) - self._reorder_cache = self.model_cls._reorder_cache.__get__(self) + self._reorder_cache = self.model_cls._reorder_cache if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}: self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama else: - self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self) + self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation if hasattr(self.model_cls, "_convert_to_standard_cache"): self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache @@ -358,130 +358,6 @@ def __init__( if warmup: self._init_warmup() - def _reorder_cache( 
- self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. - This is required to match `past_key_values` with the correct beam_idx at every generation step. - """ - if self.config.model_type == "bloom": - return self._reorder_cache_bloom(past_key_values, beam_idx) - - # from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - - # Copied from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache - def _reorder_cache_bloom( - self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called for bloom architecture. - This is required to match `past_key_values` with the correct beam_idx at every generation step. - """ - standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx)) - - # Get a copy of `beam_idx` on all the devices where we need those indices. - device_to_beam_idx = { - past_state.device: beam_idx.to(past_state.device) - for layer_past in past_key_values - for past_state in layer_past - } - reordered_past = tuple( - ( - layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]), - layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]), - ) - for layer_past in standardized_past - ) - return self._convert_to_bloom_cache(reordered_past) - - # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_bloom_cache - @staticmethod - def _convert_to_bloom_cache(past_key_value: Tuple[Tuple[torch.Tensor]]) -> Tuple[Tuple[torch.Tensor]]: - """ - Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...])) - """ - batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape - batch_size_times_num_heads = batch_size * num_heads - # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] - # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] - return tuple( - ( - layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length), - layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim), - ) - for layer_past in past_key_value - ) - - # Adapted from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._convert_to_standard_cache - def _convert_to_standard_cache( - self, past_key_value: Tuple[Tuple[torch.Tensor]], batch_size: int - ) -> Tuple[Tuple[torch.Tensor]]: - """ - Standardizes the format of the cache so as to match most implementations, i.e. 
to tuple(tuple([batch_size, num_heads, ...])) - """ - if self.config.model_type != "bloom": - return past_key_value - - batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape - num_heads = batch_size_times_num_heads // batch_size - # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length] - # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim] - return tuple( - ( - layer_past[0].view(batch_size, num_heads, head_dim, seq_length), - layer_past[1].view(batch_size, num_heads, seq_length, head_dim), - ) - for layer_past in past_key_value - ) - - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): - past_key_values = past_key_values or kwargs.get("past", None) - - if self.use_cache and past_key_values is not None: - past_length = self.get_past_length(past_key_values) - input_ids = input_ids[:, past_length:] - - # `past_key_values` may be in the stardard format (e.g. in contrastive search), converts to bloom's format if needed - if past_key_values is not None and self.config.model_type == "bloom": - if past_key_values[0][0].shape[0] == input_ids.shape[0]: - past_key_values = self._convert_to_bloom_cache(past_key_values) - - position_ids = kwargs.get("position_ids", None) - - attention_mask = kwargs.get("attention_mask", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[-1] :] - - return { - "input_ids": input_ids, - "past_key_values": past_key_values, - "use_cache": self.use_cache, - "position_ids": position_ids, - "attention_mask": attention_mask, - "token_type_ids": None, - } - - def get_past_length(self, past_key_values): - model_type = self.config.model_type.replace("_", "-") - if model_type == "bloom": - return past_key_values[0][0].shape[-1] - elif model_type.replace("-", "_") in MULTI_QUERY_ATTN_MODELS: - return past_key_values[0].shape[1] - else: - return past_key_values[0][0].shape[-2] - # Rewrite it to avoid jit failed, original function may call attributes which jit model don't have def _prepare_past_key_values(self, input_ids): model_type = self.config.model_type.replace("_", "-") From 16706d324d86671521229f6e81ba48b13f396d28 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 26 Feb 2024 09:25:37 -0500 Subject: [PATCH 06/16] fix model_dtype --- optimum/intel/ipex/modeling_base.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index b8f5bc465..8424eed5a 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -124,6 +124,7 @@ def _from_transformers( save_dir_path = Path(save_dir.name) torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME) config.torchscript = True + config.torch_dtype = torch_dtype return cls._from_pretrained( model_id=save_dir_path, @@ -134,7 +135,6 @@ def _from_transformers( cache_dir=cache_dir, local_files_only=local_files_only, use_cache=use_cache, - model_dtype=torch_dtype, ) @classmethod @@ -208,6 +208,11 @@ def device(self) -> torch.device: def dtype(self) -> torch.dtype: return self._dtype + @property + def model_dtype(self): + logger.warning("model_dtype will be removed after v1.18.0") + return self._dtype + def to(self, 
device: Union[torch.device, str]): self._device = device if isinstance(device, torch.device) else torch.device(device) self.model.to(self._device) @@ -322,8 +327,6 @@ def __init__( model_type = config.model_type.replace("_", "-") self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config) - self.model_dtype = kwargs.get("model_dtype", self.dtype) - self._dtype = self.model_dtype self.use_cache = "past_key_values" in self.input_names if use_cache ^ self.use_cache: From 124477222571ad3627389c667e8c6487c29ee7b5 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 26 Feb 2024 09:29:31 -0500 Subject: [PATCH 07/16] rm useless comments --- optimum/intel/ipex/modeling_base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 8424eed5a..002142db8 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -361,7 +361,6 @@ def __init__( if warmup: self._init_warmup() - # Rewrite it to avoid jit failed, original function may call attributes which jit model don't have def _prepare_past_key_values(self, input_ids): model_type = self.config.model_type.replace("_", "-") nb_pkv = 2 From ccad4b5c1cffcf9baf0cdfdd1aaa70803e26741c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 26 Feb 2024 10:21:37 -0500 Subject: [PATCH 08/16] fix class name --- optimum/intel/ipex/modeling_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 002142db8..16845aeb5 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -347,12 +347,12 @@ def __init__( except AttributeError: self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping) - self._reorder_cache = self.model_cls._reorder_cache + self._reorder_cache = self.model_cls._reorder_cache.__get__(self) if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}: self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama else: - self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation + self.prepare_inputs_for_generation = self.model_cls.prepare_inputs_for_generation.__get__(self) if hasattr(self.model_cls, "_convert_to_standard_cache"): self._convert_to_standard_cache = self.model_cls._convert_to_standard_cache From 740af9420a85195b5d2af639acd0f21075ded7f9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 1 Mar 2024 08:20:25 -0500 Subject: [PATCH 09/16] revert _call_model --- optimum/intel/ipex/modeling_base.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 16845aeb5..68f225f9d 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -193,7 +193,7 @@ def forward( if "token_type_ids" in self.input_names: inputs["token_type_ids"] = token_type_ids - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) def eval(self): @@ -213,6 +213,14 @@ def model_dtype(self): logger.warning("model_dtype will be removed after v1.18.0") return self._dtype + def _call_model(self, *args, **kwargs): + try: + with torch.autocast(self.device.type, self.dtype): + out = self.model(*args, **kwargs) + except RuntimeError: + out = self.model(*args, 
**kwargs) + return out + def to(self, device: Union[torch.device, str]): self._device = device if isinstance(device, torch.device) else torch.device(device) self.model.to(self._device) @@ -258,7 +266,7 @@ def forward( "pixel_values": pixel_values, } - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) @@ -279,7 +287,7 @@ def forward( if "attention_mask" in self.input_names: inputs["attention_mask"] = attention_mask - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0]) @@ -302,7 +310,7 @@ def forward( if "token_type_ids" in self.input_names: inputs["token_type_ids"] = token_type_ids - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) start_logits = outputs["start_logits"] if isinstance(outputs, dict) else outputs[0] end_logits = outputs["end_logits"] if isinstance(outputs, dict) else outputs[1] return ModelOutput(start_logits=start_logits, end_logits=end_logits) @@ -425,7 +433,7 @@ def forward( inputs["past_key_values"] = past_key_values # 2. Model forward - outputs = self.model(**inputs) + outputs = self._call_model(**inputs) # 3. Process model outputs if isinstance(outputs, (list, tuple)): From 26ebb310db97854051c1c70377176c5181462189 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 1 Mar 2024 08:32:30 -0500 Subject: [PATCH 10/16] fix model_dtype warning liog --- optimum/intel/ipex/modeling_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 68f225f9d..6dea3cac2 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -210,7 +210,9 @@ def dtype(self) -> torch.dtype: @property def model_dtype(self): - logger.warning("model_dtype will be removed after v1.18.0") + logger.warning( + "access to the `model_dtype` attribute is deprecated and will be removed after v1.18.0, please use `_dtype` instead." 
+ ) return self._dtype def _call_model(self, *args, **kwargs): From 3a966c5c905e712c5b83baed82ba4c8caa25ef1a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 8 Mar 2024 06:07:58 -0500 Subject: [PATCH 11/16] testiong low precision ipex model --- tests/ipex/test_modeling.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index ffc2ca6a8..a38ba86f9 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -144,6 +144,26 @@ def test_pipeline(self, model_arch): _ = pipe(text) self.assertEqual(pipe.device, model.device) + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_low_precision(self, model_arch): + model_id = MODEL_NAMES[model_arch] + ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16) + self.assertEqual(ipex_model._dtype, torch.bfloat16) + transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained( + model_id, torch_dtype=torch.bfloat16 + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = "This is a sample input" + tokens = tokenizer(inputs, return_tensors="pt") + with torch.no_grad(): + transformers_outputs = transformers_model(**tokens) + outputs = ipex_model(**tokens) + # Compare tensor outputs + for output_name in {"logits", "last_hidden_state"}: + if output_name in transformers_outputs: + self.assertEqual(outputs[output_name].dtype, torch.bfloat16) + self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-1)) + class IPEXModelForSequenceClassificationTest(IPEXModelTest): IPEX_MODEL_CLASS = IPEXModelForTokenClassification From 55a59e39ef6d2dfe47641a3e7f166af81b07b1f4 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 8 Mar 2024 06:23:56 -0500 Subject: [PATCH 12/16] add assisted decoding --- tests/ipex/test_modeling.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index a38ba86f9..087b1e415 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -275,6 +275,17 @@ def test_pipeline(self, model_arch): self.assertEqual(pipe.device, model.device) self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs)) + @parameterized.expand(SUPPORTED_ARCHITECTURES) + def test_assisted_decoding(self, model_arch): + model_id = MODEL_NAMES[model_arch] + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + assistant_model = AutoModelForCausalLM.from_pretrained(model_id) + tokens = tokenizer("This is a sample input", return_tensors="pt") + output = model.generate(**tokens, do_sample=False) + output_assisted = model.generate(**tokens, do_sample=False, assistant_model=assistant_model) + self.assertTrue(torch.equal(output, output_assisted)) + def test_compare_with_and_without_past_key_values(self): model_id = "echarlaix/tiny-random-gpt2-torchscript" tokenizer = AutoTokenizer.from_pretrained(model_id) From 248f0d2805838dab558133bb9ea3c44a007dcec1 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 8 Mar 2024 06:33:27 -0500 Subject: [PATCH 13/16] remove low-precision testing as CI node does not support bf16 --- tests/ipex/test_modeling.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 087b1e415..50293bfad 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -144,26 
+144,6 @@ def test_pipeline(self, model_arch): _ = pipe(text) self.assertEqual(pipe.device, model.device) - @parameterized.expand(SUPPORTED_ARCHITECTURES) - def test_low_precision(self, model_arch): - model_id = MODEL_NAMES[model_arch] - ipex_model = self.IPEX_MODEL_CLASS.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16) - self.assertEqual(ipex_model._dtype, torch.bfloat16) - transformers_model = self.IPEX_MODEL_CLASS.auto_model_class.from_pretrained( - model_id, torch_dtype=torch.bfloat16 - ) - tokenizer = AutoTokenizer.from_pretrained(model_id) - inputs = "This is a sample input" - tokens = tokenizer(inputs, return_tensors="pt") - with torch.no_grad(): - transformers_outputs = transformers_model(**tokens) - outputs = ipex_model(**tokens) - # Compare tensor outputs - for output_name in {"logits", "last_hidden_state"}: - if output_name in transformers_outputs: - self.assertEqual(outputs[output_name].dtype, torch.bfloat16) - self.assertTrue(torch.allclose(outputs[output_name], transformers_outputs[output_name], atol=1e-1)) - class IPEXModelForSequenceClassificationTest(IPEXModelTest): IPEX_MODEL_CLASS = IPEXModelForTokenClassification From 32232bb02d2b7ad94c4cc198099af36ebc030195 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 11 Mar 2024 05:33:02 -0400 Subject: [PATCH 14/16] fix conflict --- optimum/intel/ipex/modeling_base.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 1d13b65ac..0ba42fb5e 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -272,14 +272,6 @@ def model_dtype(self): ) return self._dtype - def _call_model(self, *args, **kwargs): - try: - with torch.autocast(self.device.type, self.dtype), torch.no_grad(): - out = self.model(*args, **kwargs) - except RuntimeError: - out = self.model(*args, **kwargs) - return out - def to(self, device: Union[torch.device, str]): self._device = device if isinstance(device, torch.device) else torch.device(device) self.model.to(self._device) @@ -288,6 +280,14 @@ def to(self, device: Union[torch.device, str]): def can_generate(self): return isinstance(self, GenerationMixin) + def _call_model(self, *args, **kwargs): + try: + with torch.autocast(self.device.type, self.dtype), torch.no_grad(): + out = self.model(*args, **kwargs) + except RuntimeError: + out = self.model(*args, **kwargs) + return out + def _init_warmup(self): # warmup, the first 2 forwards of an IPEX model include some preprocessing steps and # the results of the compute are unpredictable From 05510648405b3ef6e858e869c6264b474dbe55e8 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 14 Mar 2024 05:43:11 -0400 Subject: [PATCH 15/16] remove prepare position_ids in forward --- optimum/intel/ipex/modeling_base.py | 6 ------ tests/ipex/test_modeling.py | 21 ++++++++++----------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index 0ba42fb5e..a12519425 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -506,12 +506,6 @@ def forward( "attention_mask": attention_mask, } - if "position_ids" in self.input_names and position_ids is None: - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[-1] :] - if "position_ids" in self.input_names or not 
self.input_names: inputs["position_ids"] = position_ids diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index cdf6dffff..7866836cc 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -32,7 +32,6 @@ set_seed, ) -from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS from optimum.intel import ( IPEXModel, IPEXModelForAudioClassification, @@ -236,11 +235,8 @@ def test_compare_to_transformers(self, model_arch): return_tensors="pt", return_token_type_ids=False if model_arch in ("llama", "llama2") else None, ) - position_ids = None - if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: - input_shape = tokens["input_ids"].shape - position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) - outputs = ipex_model(**tokens, position_ids=position_ids) + inputs = ipex_model.prepare_inputs_for_generation(**tokens) + outputs = ipex_model(**inputs) self.assertIsInstance(outputs.logits, torch.Tensor) self.assertIsInstance(outputs.past_key_values, (tuple, list)) @@ -267,12 +263,15 @@ def test_pipeline(self, model_arch): def test_assisted_decoding(self, model_arch): model_id = MODEL_NAMES[model_arch] tokenizer = AutoTokenizer.from_pretrained(model_id) - model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) - assistant_model = AutoModelForCausalLM.from_pretrained(model_id) + ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id) tokens = tokenizer("This is a sample input", return_tensors="pt") - output = model.generate(**tokens, do_sample=False) - output_assisted = model.generate(**tokens, do_sample=False, assistant_model=assistant_model) - self.assertTrue(torch.equal(output, output_assisted)) + ipex_output = ipex_model.generate(**tokens, do_sample=False) + ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model) + transformers_output = transformers_model.generate(**tokens, do_sample=False) + transformers_output_assisted = transformers_model.generate(**tokens, do_sample=False, assistant_model=ipex_model) + self.assertTrue(torch.equal(ipex_output, ipex_output_assisted)) + self.assertTrue(torch.equal(transformers_output, transformers_output_assisted)) @parameterized.expand( grid_parameters( From a80b073d5b726d8e6b9b6cffa5699b8ce810803e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 14 Mar 2024 05:48:42 -0400 Subject: [PATCH 16/16] fix code style --- tests/ipex/test_modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py index 7866836cc..c46ce1cdc 100644 --- a/tests/ipex/test_modeling.py +++ b/tests/ipex/test_modeling.py @@ -269,7 +269,9 @@ def test_assisted_decoding(self, model_arch): ipex_output = ipex_model.generate(**tokens, do_sample=False) ipex_output_assisted = ipex_model.generate(**tokens, do_sample=False, assistant_model=transformers_model) transformers_output = transformers_model.generate(**tokens, do_sample=False) - transformers_output_assisted = transformers_model.generate(**tokens, do_sample=False, assistant_model=ipex_model) + transformers_output_assisted = transformers_model.generate( + **tokens, do_sample=False, assistant_model=ipex_model + ) self.assertTrue(torch.equal(ipex_output, ipex_output_assisted)) self.assertTrue(torch.equal(transformers_output, transformers_output_assisted))
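
Taken together, the series exposes two user-visible changes: the export dtype is now propagated through `config.torch_dtype` instead of a `model_dtype` kwarg, and IPEX causal LMs can take part in assisted decoding (either driving generation or serving as the assistant model). The following is a minimal usage sketch that mirrors the tests added and removed in this series (tests/ipex/test_modeling.py); the checkpoint name "gpt2" and the `max_new_tokens` value are placeholders rather than anything from the patches, and bfloat16 only pays off on hardware with native BF16 support (the dedicated low-precision test was dropped precisely because the CI node lacks it).

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from optimum.intel import IPEXModelForCausalLM

    model_id = "gpt2"  # placeholder; any causal LM supported by the IPEX export should work
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # torch_dtype is forwarded via config.torch_dtype during export (PATCH 06/16)
    ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, torch_dtype=torch.bfloat16)

    # assisted decoding with a plain transformers model as the draft model (PATCH 03/16, 12/16, 15/16)
    assistant = AutoModelForCausalLM.from_pretrained(model_id)
    tokens = tokenizer("This is a sample input", return_tensors="pt")
    greedy = ipex_model.generate(**tokens, do_sample=False, max_new_tokens=32)
    assisted = ipex_model.generate(**tokens, do_sample=False, max_new_tokens=32, assistant_model=assistant)
    print(tokenizer.decode(assisted[0], skip_special_tokens=True))

With greedy decoding the two `generate` calls should return identical token sequences, which is what `test_assisted_decoding` asserts in both directions (IPEX model assisted by a transformers model, and vice versa).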