Skip to content

Commit

Permalink
Disable tpp for un-verified models (#822)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng authored and IlyasMoutawwakil committed Aug 6, 2024
1 parent b016fa3 commit 82627b2
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def ipex_jit_trace(model, task, use_cache):
if not use_cache:
sample_inputs.pop("past_key_values")

# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
# Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks.
if is_ipex_version(">=", "2.3.0") and task in _IPEX_EXPORTED_GENERATION_TASKS:
_enable_tpp()
model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
# Disable repack while jit tracing to reduce the memory
ipex._C.disable_jit_linear_repack()
Expand Down

0 comments on commit 82627b2

Please sign in to comment.