Skip to content

Commit 7b67c1f

Browse files
committed
use torch.no_grad in jit trace
1 parent 41ca8c4 commit 7b67c1f

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

optimum/intel/ipex/modeling_base.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,16 @@ def ipex_jit_trace(model, task, use_cache):
8989
model.config.return_dict = False
9090

9191
model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
92-
trace_model = torch.jit.trace(
93-
model,
94-
example_kwarg_inputs=sample_inputs,
95-
strict=False,
96-
check_trace=False,
97-
)
98-
trace_model = torch.jit.freeze(trace_model)
99-
trace_model(**sample_inputs)
100-
trace_model(**sample_inputs)
92+
with torch.no_grad():
93+
trace_model = torch.jit.trace(
94+
model,
95+
example_kwarg_inputs=sample_inputs,
96+
strict=False,
97+
check_trace=False,
98+
)
99+
trace_model = torch.jit.freeze(trace_model)
100+
trace_model(**sample_inputs)
101+
trace_model(**sample_inputs)
101102

102103
return trace_model
103104

0 commit comments

Comments
 (0)