We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b862400 commit 5e1da44Copy full SHA for 5e1da44
modules/intel/ipex/__init__.py
@@ -149,8 +149,8 @@ def ipex_init(): # pylint: disable=too-many-statements
149
150
# AMP:
151
if legacy:
152
- torch.xpu.amp.custom_fwd = torch.amp.custom_fwd
153
- torch.xpu.amp.custom_bwd = torch.amp.custom_bwd
+ torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
+ torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
154
torch.cuda.amp = torch.xpu.amp
155
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
156
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
0 commit comments