Commit bce9aa9

upgrade minimum torch version to 2.5

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
1 parent: 2c54045

5 files changed: +22 additions, -9 deletions


.github/workflows/test_ipex.yml

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ jobs:
       fail-fast: false
       matrix:
         transformers-version: ["4.46.0", "4.46.3"]
-        torch-version: ["2.4.0", "2.5.*"]
+        torch-version: ["2.5.*"]
 
     runs-on: ubuntu-22.04

optimum/exporters/ipex/modeling_utils.py

Lines changed: 15 additions & 5 deletions

@@ -32,8 +32,7 @@
 
 logger = logging.getLogger(__name__)
 
-_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.4.0"
-_IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN = "2.5.0"
+_IPEX_MINIMUM_VERSION_FOR_PATCHING = "2.5.0"
 
 
 if is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_PATCHING):

@@ -213,6 +212,8 @@ def _llama_model_forward(
         position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask=attention_mask,
             input_shape=(input_ids.shape[0], input_ids.shape[-1]),

@@ -334,6 +335,8 @@ def _falcon_model_forward(
         position_embeddings = (cos.unsqueeze(1), sin.unsqueeze(1))
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask=attention_mask,
             input_shape=(input_ids.shape[0], input_ids.shape[-1]),

@@ -463,6 +466,8 @@ def _gpt2_model_forward(
         hidden_states = (hidden_states.view(-1, hidden_states.shape[-1]))[index]
     else:
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    if past_key_values is None:
         attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
             attention_mask=attention_mask,
             input_shape=(input_ids.shape[0], input_ids.shape[-1]),

@@ -660,11 +665,16 @@ def forward(
 
         if past_len == 0:
             # prefill
-            if past_key_value is None or is_ipex_version("<", _IPEX_MINIMUM_VERSION_FOR_FLASH_VARLEN_ATTN):
+            if past_key_value is None:
+                n_rep = query.shape[1] // key.shape[1]
                 attn_output = torch.nn.functional.scaled_dot_product_attention(
                     query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-                    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1]).transpose(1, 2),
-                    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1]).transpose(1, 2),
+                    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+                    .transpose(1, 2)
+                    .repeat_interleave(n_rep, 1),
+                    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+                    .transpose(1, 2)
+                    .repeat_interleave(n_rep, 1),
                     attn_mask=attention_mask,
                     dropout_p=0.0,
                     is_causal=True,
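
A note on the diff above: the new "if past_key_values is None:" guards restrict the 4D causal mask preparation to the prefill step, and the repeat_interleave calls expand the key/value head dimension to match the query's head count before scaled_dot_product_attention, which models using grouped-query attention (GQA) need now that the flash-varlen branch is gone. A minimal, self-contained sketch of that expansion follows; the shapes are illustrative, not taken from this commit.

import torch

# Illustrative GQA shapes (assumed, not from the commit): 8 query heads
# sharing 2 key/value heads.
batch, q_heads, kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64

query = torch.randn(batch, q_heads, seq_len, head_dim)
key = torch.randn(batch, kv_heads, seq_len, head_dim)
value = torch.randn(batch, kv_heads, seq_len, head_dim)

# Same ratio the patch computes: how many query heads share one KV head.
n_rep = query.shape[1] // key.shape[1]

# Repeat each KV head n_rep times along dim 1 so all three tensors carry
# q_heads heads, which plain SDPA requires (for MHA models n_rep is 1 and
# this is a no-op).
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query,
    key.repeat_interleave(n_rep, 1),
    value.repeat_interleave(n_rep, 1),
    dropout_p=0.0,
    is_causal=True,
)
assert attn_output.shape == (batch, q_heads, seq_len, head_dim)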

optimum/intel/ipex/modeling_base.py

Lines changed: 2 additions & 2 deletions

@@ -299,9 +299,9 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
         return self.model.prepare_inputs_for_generation(*args, **kwargs)
 
     def generate(self, *args, **kwargs):
-        if is_ipex_version("<", "2.4.0") and self._add_patch and kwargs.get("assistant_model", None):
+        if self._add_patch and kwargs.get("assistant_model", None):
             raise ValueError(
-                f"Assisted decoding is not supported for patched models if ipex < 2.4, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
+                f"Assisted decoding is not supported for patched models for now, support methods are {_IPEX_EXPORTED_GENERATION_METHODS}"
             )
         # Patch functions to support ipex_paged cache
         if self._add_patch:
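
Behaviorally, the generate() change above means the ValueError no longer depends on the installed ipex version: any patched model now rejects assisted decoding outright. A hypothetical usage sketch, assuming IPEXModelForCausalLM and placeholder model ids:

from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

# Placeholder ids, for illustration only.
model_id = "your-org/some-causal-lm"
model = IPEXModelForCausalLM.from_pretrained(model_id)
assistant = IPEXModelForCausalLM.from_pretrained("your-org/smaller-draft-lm")
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, world", return_tensors="pt")

try:
    # On a patched model (_add_patch is True), passing assistant_model
    # now raises regardless of the installed ipex version.
    model.generate(**inputs, assistant_model=assistant, max_new_tokens=8)
except ValueError as err:
    print(err)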

setup.py

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@
     "nncf": ["nncf>=2.14.0"],
     "openvino": ["nncf>=2.14.0", "openvino>=2024.5.0", "openvino-tokenizers>=2024.5.0"],
     "neural-compressor": ["neural-compressor[pt]>3.0", "accelerate", "transformers<4.46"],
-    "ipex": ["intel-extension-for-pytorch>=2.4", "transformers>4.45,<4.47"],
+    "ipex": ["intel-extension-for-pytorch>=2.5", "transformers>4.45,<4.47"],
     "diffusers": ["diffusers"],
     "quality": QUALITY_REQUIRE,
     "tests": TESTS_REQUIRE,

tests/ipex/test_modeling.py

Lines changed: 3 additions & 0 deletions

@@ -377,6 +377,9 @@ def test_compare_with_and_without_past_key_values(self):
         outputs_model_without_pkv = model_without_pkv.generate(
             **tokens, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1
         )
+        import pdb
+
+        pdb.set_trace()
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
         self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1])
         self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + tokens.input_ids.shape[1])
