Skip to content

Commit 340177e

Browse files
Disable non blocking on mps.
1 parent 614b7e7 commit 340177e

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

comfy/model_management.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,15 +553,19 @@ def cast_to_device(tensor, device, dtype, copy=False):
553553
elif is_intel_xpu():
554554
device_supports_cast = True
555555

556+
non_blocking = True
557+
if is_device_mps(device):
558+
non_blocking = False #pytorch bug? mps doesn't support non blocking
559+
556560
if device_supports_cast:
557561
if copy:
558562
if tensor.device == device:
559-
return tensor.to(dtype, copy=copy, non_blocking=True)
560-
return tensor.to(device, copy=copy, non_blocking=True).to(dtype, non_blocking=True)
563+
return tensor.to(dtype, copy=copy, non_blocking=non_blocking)
564+
return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
561565
else:
562-
return tensor.to(device, non_blocking=True).to(dtype, non_blocking=True)
566+
return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking)
563567
else:
564-
return tensor.to(device, dtype, copy=copy, non_blocking=True)
568+
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
565569

566570
def xformers_enabled():
567571
global directml_enabled

0 commit comments

Comments
 (0)