Skip to content

Commit 88a2f95

Browse files
committed
Fix IPEX 2.3 x2
1 parent 04f61d8 commit 88a2f95

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

modules/intel/ipex/hijacks.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,13 @@ def UntypedStorage_to(self, *args, device=None, **kwargs):
254254
else:
255255
return original_UntypedStorage_to(self, *args, device=device, **kwargs)
256256

257-
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
258-
@wraps(torch.UntypedStorage.cuda)
259-
def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs):
260-
if device is None or check_cuda(device):
261-
return self.to(device=return_xpu(device), non_blocking=non_blocking, **kwargs)
262-
else:
263-
return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)
257+
original_UntypedStorage_cuda = torch.UntypedStorage.cuda
258+
@wraps(torch.UntypedStorage.cuda)
259+
def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs):
260+
if device is None or check_cuda(device):
261+
return self.to(device=return_xpu(device), non_blocking=non_blocking, **kwargs)
262+
else:
263+
return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)
264264

265265
original_torch_empty = torch.empty
266266
@wraps(torch.empty)
@@ -347,16 +347,16 @@ def torch_cuda_synchronize(device=None):
347347
# Hijack Functions:
348348
def ipex_hijacks(legacy=True):
349349
global device_supports_fp64, can_allocate_plus_4gb
350-
if legacy and float(torch.__version__[:3]) < 2.5:
350+
if float(torch.__version__[:3]) >= 2.4:
351+
torch.UntypedStorage.cuda = UntypedStorage_cuda
352+
torch.UntypedStorage.to = UntypedStorage_to
353+
else: # ipex 2.3 and below
351354
torch.nn.functional.interpolate = interpolate
352355
torch.tensor = torch_tensor
353356
torch.Tensor.to = Tensor_to
354357
torch.Tensor.cuda = Tensor_cuda
355358
torch.Tensor.pin_memory = Tensor_pin_memory
356359
torch.UntypedStorage.__init__ = UntypedStorage_init
357-
torch.UntypedStorage.cuda = UntypedStorage_cuda
358-
if float(torch.__version__[:3]) >= 2.4:
359-
torch.UntypedStorage.to = UntypedStorage_to
360360
torch.empty = torch_empty
361361
torch.randn = torch_randn
362362
torch.ones = torch_ones

0 commit comments

Comments
 (0)