@@ -254,13 +254,13 @@ def UntypedStorage_to(self, *args, device=None, **kwargs):
254
254
else :
255
255
return original_UntypedStorage_to (self , * args , device = device , ** kwargs )
256
256
257
- original_UntypedStorage_cuda = torch .UntypedStorage .cuda
258
- @wraps (torch .UntypedStorage .cuda )
259
- def UntypedStorage_cuda (self , device = None , non_blocking = False , ** kwargs ):
260
- if device is None or check_cuda (device ):
261
- return self .to (device = return_xpu (device ), non_blocking = non_blocking , ** kwargs )
262
- else :
263
- return original_UntypedStorage_cuda (self , device = device , non_blocking = non_blocking , ** kwargs )
257
+ original_UntypedStorage_cuda = torch .UntypedStorage .cuda
258
+ @wraps (torch .UntypedStorage .cuda )
259
+ def UntypedStorage_cuda (self , device = None , non_blocking = False , ** kwargs ):
260
+ if device is None or check_cuda (device ):
261
+ return self .to (device = return_xpu (device ), non_blocking = non_blocking , ** kwargs )
262
+ else :
263
+ return original_UntypedStorage_cuda (self , device = device , non_blocking = non_blocking , ** kwargs )
264
264
265
265
original_torch_empty = torch .empty
266
266
@wraps (torch .empty )
@@ -347,16 +347,16 @@ def torch_cuda_synchronize(device=None):
347
347
# Hijack Functions:
348
348
def ipex_hijacks (legacy = True ):
349
349
global device_supports_fp64 , can_allocate_plus_4gb
350
- if legacy and float (torch .__version__ [:3 ]) < 2.5 :
350
+ if float (torch .__version__ [:3 ]) >= 2.4 :
351
+ torch .UntypedStorage .cuda = UntypedStorage_cuda
352
+ torch .UntypedStorage .to = UntypedStorage_to
353
+ else : # ipex 2.3 and below
351
354
torch .nn .functional .interpolate = interpolate
352
355
torch .tensor = torch_tensor
353
356
torch .Tensor .to = Tensor_to
354
357
torch .Tensor .cuda = Tensor_cuda
355
358
torch .Tensor .pin_memory = Tensor_pin_memory
356
359
torch .UntypedStorage .__init__ = UntypedStorage_init
357
- torch .UntypedStorage .cuda = UntypedStorage_cuda
358
- if float (torch .__version__ [:3 ]) >= 2.4 :
359
- torch .UntypedStorage .to = UntypedStorage_to
360
360
torch .empty = torch_empty
361
361
torch .randn = torch_randn
362
362
torch .ones = torch_ones
0 commit comments