Skip to content

Commit

Permalink
fix traced model patch check
Browse files Browse the repository at this point in the history
  • Loading branch information
jiqing-feng committed Jul 5, 2024
1 parent e1ef5c8 commit 66f3365
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions optimum/intel/ipex/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def _is_patched_with_ipex(model, task):

if isinstance(model, torch.jit.ScriptModule):
for node in model.graph.nodes():
# Only patched model enabled tpp linear by _enable_tpp().
if "torch_ipex::tpp_linear" in node.__str__():
# Only patched model enabled fusion linear.
if "/fusions/" in node.__str__():
return True
return False
elif task in _IPEX_EXPORTED_GENERATION_TASKS and model.config.hidden_size < 64:
Expand All @@ -99,8 +99,6 @@ def ipex_jit_trace(model, task, use_cache):

if _is_patched_with_ipex(model, task):
model = _patch_model(model)
# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
_enable_tpp()

sample_inputs = _prepare_inputs_for_ipex_model(model, task, use_cache)

Expand All @@ -111,6 +109,8 @@ def ipex_jit_trace(model, task, use_cache):
if not use_cache:
sample_inputs.pop("past_key_values")

# Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
_enable_tpp()
model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
# Disable repack while jit tracing to reduce the memory
ipex._C.disable_jit_linear_repack()
Expand Down

0 comments on commit 66f3365

Please sign in to comment.