Skip to content

Commit

Permalink
unpatch target model's generation
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Sep 9, 2024
1 parent 0748782 commit 9837e46
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,17 +682,23 @@ def generate(self, *args, **kwargs):
f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
)
# Patch functions to support IAKV cache
if self._is_ipex_exported:
_patch_crop_past_key_values()

result = super().generate(*args, **kwargs)
if self._is_ipex_exported and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
elif self._is_ipex_exported:
transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values

return result
try:
result = super().generate(*args, **kwargs)
except Exception as e:
transformers.generation.utils._crop_past_key_values = _crop_past_key_values
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values
raise e

if self._is_ipex_exported and kwargs.get("assistant_model", None):
transformers.generation.utils._crop_past_key_values = _crop_past_key_values
transformers.generation.candidate_generator._crop_past_key_values = _crop_past_key_values

def _patch_crop_past_key_values():
transformers.generation.candidate_generator._crop_past_key_values = _ipex_crop_past_key_values
transformers.generation.utils._crop_past_key_values = _ipex_crop_past_key_values
return result


def _ipex_prepare_inputs_for_generation(
Expand Down

0 comments on commit 9837e46

Please sign in to comment.