diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py index b9f9a708da..aa57ec20fa 100644 --- a/optimum/intel/ipex/modeling_base.py +++ b/optimum/intel/ipex/modeling_base.py @@ -105,6 +105,11 @@ def ipex_jit_trace(model, task, use_cache): if not use_cache: sample_inputs.pop("past_key_values") + # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755. + # Only ipex >= 2.3.0 supports tpp. The tpp is only verified for llm in generation tasks. + if is_ipex_version(">=", "2.3.0"): + _enable_tpp() + model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True) # Disable repack while jit tracing to reduce the memory ipex._C.disable_jit_linear_repack()