diff --git a/comfy/sd.py b/comfy/sd.py index b4842b260eb..fb9a104a8a4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -247,9 +247,10 @@ def encode(self, pixel_samples): if hasattr(self.device, 'type') and (self.device.type != 'cpu' and self.device.type != 'mps'): devices.append(torch.device("cpu")) + pixel_samples = pixel_samples.clone().movedim(-1,1) + for device in devices: self.first_stage_model = self.first_stage_model.to(device) - pixel_samples = pixel_samples.movedim(-1,1) try: memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change. model_management.free_memory(memory_used, device)