diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 72fce10872f..0a607abeab1 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -80,6 +80,7 @@ class LatentPreviewMethod(enum.Enum):
     TAESD = "taesd"
 
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
+parser.add_argument("--preview-cpu", action="store_true", help="Use the CPU for previews (slow).")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -99,6 +100,7 @@
 
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
+parser.add_argument("--memory-estimation-multiplier", type=float, default=-1, help="Override the multiplier used for attention memory estimation. Negative values (the default) keep the built-in estimates.")
 
 parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
 parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
 
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 016795a5974..8efa4b75593 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -217,16 +217,21 @@ def attention_split(q, k, v, heads, mask=None):
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
     modifier = 3
+    if args.memory_estimation_multiplier >= 0:
+        modifier = args.memory_estimation_multiplier
     mem_required = tensor_size * modifier
     steps = 1
+    max_steps = q.shape[1] - 1
+    while (q.shape[1] % max_steps) != 0:
+        max_steps -= 1
 
     if mem_required > mem_free_total:
         steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
         # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
         #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
 
-    if steps > 64:
+    if steps > max_steps:
         max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
         raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                            f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
 
@@ -259,8 +264,10 @@ def attention_split(q, k, v, heads, mask=None):
                     cleared_cache = True
                     print("out of memory error, emptying cache and trying again")
                     continue
-                steps *= 2
-                if steps > 64:
+                steps += 1
+                while (q.shape[1] % steps) != 0 and steps < max_steps:
+                    steps += 1
+                if steps > max_steps:
                     raise e
                 print("out of memory error, increasing steps and trying again", steps)
             else:
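Both step searches above now only consider values that divide the sequence length: the old code doubled steps (1, 2, 4, ... 64), which for non-power-of-two sequence lengths can land on a steps value that does not split q evenly. A minimal standalone sketch of the new logic (the helper names largest_proper_divisor and next_divisor, and seq_len standing in for q.shape[1], are illustrative, not part of the patch):

# Sketch only: mirrors the max_steps / steps logic in the attention.py hunks above.

def largest_proper_divisor(seq_len):
    # `max_steps` in the patch: the largest steps value that still divides seq_len,
    # found by scanning down from seq_len - 1.
    max_steps = seq_len - 1
    while (seq_len % max_steps) != 0:
        max_steps -= 1
    return max_steps

def next_divisor(seq_len, steps, max_steps):
    # The OOM retry in the patch: advance steps to the next divisor of seq_len,
    # capped at max_steps (the caller re-raises the OOM once steps exceeds it).
    steps += 1
    while (seq_len % steps) != 0 and steps < max_steps:
        steps += 1
    return steps

# e.g. a 4032-token sequence (illustrative value): divisors are visited in order
assert largest_proper_divisor(4032) == 2016
assert next_divisor(4032, 2, 2016) == 3
assert next_divisor(4032, 4, 2016) == 6

Scanning one integer at a time is O(seq_len) in the worst case (prime lengths degrade to max_steps == 1), but the loop is pure Python arithmetic and negligible next to the attention matmuls it is guarding.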
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index f23417fd216..cdeb43edfc6 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -7,6 +7,7 @@ from typing import Optional, Any
 
 from comfy import model_management
+from comfy.cli_args import args
 import comfy.ops
 
 if model_management.xformers_enabled_vae():
@@ -165,9 +166,15 @@ def slice_attention(q, k, v):
 
     gb = 1024 ** 3
     tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
     modifier = 3 if q.element_size() == 2 else 2.5
+    if args.memory_estimation_multiplier >= 0:
+        modifier = args.memory_estimation_multiplier
     mem_required = tensor_size * modifier
     steps = 1
+    max_steps = q.shape[1] - 1
+    while (q.shape[1] % max_steps) != 0:
+        max_steps -= 1
+
     if mem_required > mem_free_total:
         steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
@@ -186,8 +193,10 @@
             break
         except model_management.OOM_EXCEPTION as e:
             model_management.soft_empty_cache(True)
-            steps *= 2
-            if steps > 128:
+            steps += 1
+            while (q.shape[1] % steps) != 0 and steps < max_steps:
+                steps += 1
+            if steps > max_steps:
                 raise e
             print("out of memory error, increasing steps and trying again", steps)
 
diff --git a/comfy/sd.py b/comfy/sd.py
index c3cc8e72080..1e205d26412 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -217,10 +217,21 @@ def decode(self, samples_in):
             samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
             pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
         except model_management.OOM_EXCEPTION as e:
-            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
-            pixel_samples = self.decode_tiled_(samples_in)
-
-        self.first_stage_model = self.first_stage_model.to(self.offload_device)
+            tile_size = 64
+            while tile_size >= 8:
+                overlap = tile_size // 4
+                print(f"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding with tile size {tile_size} and overlap {overlap}.")
+                try:
+                    pixel_samples = self.decode_tiled_(samples_in, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
+                    break
+                except model_management.OOM_EXCEPTION:
+                    pass
+                tile_size -= 8
+
+            if pixel_samples is None:
+                raise e
+        finally:
+            self.first_stage_model = self.first_stage_model.to(self.offload_device)
 
         pixel_samples = pixel_samples.cpu().movedim(1,-1)
         return pixel_samples
@@ -245,10 +256,21 @@ def encode(self, pixel_samples):
                 samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
         except model_management.OOM_EXCEPTION as e:
-            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
-            samples = self.encode_tiled_(pixel_samples)
-
-        self.first_stage_model = self.first_stage_model.to(self.offload_device)
+            tile_size = 512
+            while tile_size >= 64:
+                overlap = tile_size // 8
+                print(f"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding with tile size {tile_size} and overlap {overlap}.")
+                try:
+                    samples = self.encode_tiled_(pixel_samples, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
+                    break
+                except model_management.OOM_EXCEPTION:
+                    pass
+                tile_size -= 64
+
+            if samples is None:
+                raise e
+        finally:
+            self.first_stage_model = self.first_stage_model.to(self.offload_device)
 
         return samples
 
     def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
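The two sd.py hunks share the same retry shape: fall back to tiled VAE work, shrinking the tile until one fits or the options run out. A condensed sketch of that pattern (the function name decode_with_fallback and its parameters are illustrative; tiled_decode stands in for VAE.decode_tiled_):

# Sketch only: the shrinking-tile OOM fallback used by decode() and encode().

def decode_with_fallback(tiled_decode, samples, oom_error, original_exc):
    pixel_samples = None
    tile_size = 64                 # decode starts at 64 latent pixels; encode uses 512 image pixels
    while tile_size >= 8:
        overlap = tile_size // 4   # keep overlap proportional to tile size, as in the patch
        try:
            pixel_samples = tiled_decode(samples, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
            break                  # first tile size that fits wins
        except oom_error:
            tile_size -= 8         # smaller tiles -> smaller intermediate tensors
    if pixel_samples is None:
        raise original_exc         # every size failed: surface the original OOM
    return pixel_samples

# Illustrative use with a fake decoder that only fits at tile_size <= 32:
class FakeOOM(Exception): pass
def fake_tiled(samples, tile_x, tile_y, overlap):
    if tile_x > 32:
        raise FakeOOM()
    return f"decoded with {tile_x}px tiles"

print(decode_with_fallback(fake_tiled, None, FakeOOM, FakeOOM("oom")))  # decoded with 32px tiles

Stepping linearly (64, 56, 48, ...) rather than halving means the fallback settles on the largest tile that actually fits, keeping the tile count, and therefore decode time and seam count, as low as memory allows.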
diff --git a/latent_preview.py b/latent_preview.py
index 6e758a1a9d1..7399f06568f 100644
--- a/latent_preview.py
+++ b/latent_preview.py
@@ -6,6 +6,7 @@
 from comfy.taesd.taesd import TAESD
 import folder_paths
 import comfy.utils
+from comfy import model_management
 
 MAX_PREVIEW_RESOLUTION = 512
 
@@ -18,11 +19,12 @@ def decode_latent_to_preview_image(self, preview_format, x0):
         return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
 
 class TAESDPreviewerImpl(LatentPreviewer):
-    def __init__(self, taesd):
+    def __init__(self, taesd, device):
         self.taesd = taesd
+        self.device = device
 
     def decode_latent_to_preview(self, x0):
-        x_sample = self.taesd.decoder(x0[:1])[0].detach()
+        x_sample = self.taesd.decoder(x0[:1].to(self.device))[0].detach()
         # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5)  # returns value in [-2, 2]
         x_sample = x_sample.sub(0.5).mul(2)
 
@@ -52,6 +54,8 @@ def decode_latent_to_preview(self, x0):
 def get_previewer(device, latent_format):
     previewer = None
     method = args.preview_method
+    if args.preview_cpu:
+        device = torch.device("cpu")
     if method != LatentPreviewMethod.NoPreviews:
         # TODO previewer methods
         taesd_decoder_path = None
@@ -71,7 +75,7 @@
         if method == LatentPreviewMethod.TAESD:
             if taesd_decoder_path:
                 taesd = TAESD(None, taesd_decoder_path).to(device)
-                previewer = TAESDPreviewerImpl(taesd)
+                previewer = TAESDPreviewerImpl(taesd, device)
             else:
                 print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
 
@@ -94,7 +98,10 @@ def callback(step, x0, x, total_steps):
 
         preview_bytes = None
         if previewer:
-            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+            try:
+                preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+            except model_management.OOM_EXCEPTION:
+                pass
         pbar.update_absolute(step + 1, total_steps, preview_bytes)
     return callback
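Taken together, the patch adds two user-facing knobs. A hypothetical invocation through ComfyUI's usual entry point (the flags are the ones added above; the value 2.0 is only an example):

python main.py --preview-cpu --memory-estimation-multiplier 2.0

--preview-cpu keeps TAESD preview decoding off the GPU entirely, so previews stop competing with sampling for VRAM, and --memory-estimation-multiplier overrides the hard-coded estimate modifiers (3 in attention_split; 3 or 2.5 in slice_attention depending on element size) for cards where the built-in guesses slice too aggressively or not aggressively enough.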