Skip to content

Commit 5e1da44

Browse files
committed
IPEX fix custom_fwd x2
1 parent b862400 commit 5e1da44

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

modules/intel/ipex/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ def ipex_init(): # pylint: disable=too-many-statements
149149

150150
# AMP:
151151
if legacy:
152-
torch.xpu.amp.custom_fwd = torch.amp.custom_fwd
153-
torch.xpu.amp.custom_bwd = torch.amp.custom_bwd
152+
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
153+
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
154154
torch.cuda.amp = torch.xpu.amp
155155
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
156156
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype

0 commit comments

Comments
 (0)