Skip to content

Commit

Permalink
Memory optimizations to allow bigger images
Browse files (browse the repository at this point in the history)
  • Loading branch information
jn-jairo committed Nov 21, 2023
1 parent 6ff06fa commit 0de7950
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 17 deletions.
2 changes: 2 additions & 0 deletions comfy/cli_args.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -80,6 +80,7 @@ class LatentPreviewMethod(enum.Enum):
TAESD = "taesd"

parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-cpu", action="store_true", help="To use the CPU for preview (slow).")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
Expand All @@ -99,6 +100,7 @@ class LatentPreviewMethod(enum.Enum):

parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")

parser.add_argument("--memory-estimation-multiplier", type=float, default=-1, help="Multiplier for the memory estimation.")

parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
Expand Down
13 changes: 10 additions & 3 deletions comfy/ldm/modules/attention.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -217,16 +217,21 @@ def attention_split(q, k, v, heads, mask=None):
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
modifier = 3
if args.memory_estimation_multiplier >= 0:
modifier = args.memory_estimation_multiplier
mem_required = tensor_size * modifier
steps = 1

max_steps = q.shape[1] - 1
while (q.shape[1] % max_steps) != 0:
max_steps -= 1

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

if steps > 64:
if steps > max_steps:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
Expand Down Expand Up @@ -259,8 +264,10 @@ def attention_split(q, k, v, heads, mask=None):
cleared_cache = True
print("out of memory error, emptying cache and trying again")
continue
steps *= 2
if steps > 64:
steps += 1
while (q.shape[1] % steps) != 0 and steps < max_steps:
steps += 1
if steps > max_steps:
raise e
print("out of memory error, increasing steps and trying again", steps)
else:
Expand Down
13 changes: 11 additions & 2 deletions comfy/ldm/modules/diffusionmodules/model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional, Any

from comfy import model_management
from comfy.cli_args import args
import comfy.ops

if model_management.xformers_enabled_vae():
Expand Down Expand Up @@ -165,9 +166,15 @@ def slice_attention(q, k, v):
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
if args.memory_estimation_multiplier >= 0:
modifier = args.memory_estimation_multiplier
mem_required = tensor_size * modifier
steps = 1

max_steps = q.shape[1] - 1
while (q.shape[1] % max_steps) != 0:
max_steps -= 1

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

Expand All @@ -186,8 +193,10 @@ def slice_attention(q, k, v):
break
except model_management.OOM_EXCEPTION as e:
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
steps += 1
while (q.shape[1] % steps) != 0 and steps < max_steps:
steps += 1
if steps > max_steps:
raise e
print("out of memory error, increasing steps and trying again", steps)

Expand Down
38 changes: 30 additions & 8 deletions comfy/sd.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -217,10 +217,21 @@ def decode(self, samples_in):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
pixel_samples = self.decode_tiled_(samples_in)

self.first_stage_model = self.first_stage_model.to(self.offload_device)
tile_size = 64
while tile_size >= 8:
overlap = tile_size // 4
print(f"Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding with tile size {tile_size} and overlap {overlap}.")
try:
pixel_samples = self.decode_tiled_(samples_in, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
break
except model_management.OOM_EXCEPTION as e:
pass
tile_size -= 8

if pixel_samples is None:
raise e
finally:
self.first_stage_model = self.first_stage_model.to(self.offload_device)
pixel_samples = pixel_samples.cpu().movedim(1,-1)
return pixel_samples

Expand All @@ -245,10 +256,21 @@ def encode(self, pixel_samples):
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()

except model_management.OOM_EXCEPTION as e:
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
samples = self.encode_tiled_(pixel_samples)

self.first_stage_model = self.first_stage_model.to(self.offload_device)
tile_size = 512
while tile_size >= 64:
overlap = tile_size // 8
print(f"Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding with tile size {tile_size} and overlap {overlap}.")
try:
samples = self.encode_tiled_(pixel_samples, tile_x=tile_size, tile_y=tile_size, overlap=overlap)
break
except model_management.OOM_EXCEPTION as e:
pass
tile_size -= 64

if samples is None:
raise e
finally:
self.first_stage_model = self.first_stage_model.to(self.offload_device)
return samples

def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
Expand Down
15 changes: 11 additions & 4 deletions latent_preview.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -6,6 +6,7 @@
from comfy.taesd.taesd import TAESD
import folder_paths
import comfy.utils
from comfy import model_management

MAX_PREVIEW_RESOLUTION = 512

Expand All @@ -18,11 +19,12 @@ def decode_latent_to_preview_image(self, preview_format, x0):
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)

class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
def __init__(self, taesd, device):
self.taesd = taesd
self.device = device

def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0[:1])[0].detach()
x_sample = self.taesd.decoder(x0[:1].to(self.device))[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)

Expand Down Expand Up @@ -52,6 +54,8 @@ def decode_latent_to_preview(self, x0):
def get_previewer(device, latent_format):
previewer = None
method = args.preview_method
if args.preview_cpu:
device = torch.device("cpu")
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = None
Expand All @@ -71,7 +75,7 @@ def get_previewer(device, latent_format):
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
taesd = TAESD(None, taesd_decoder_path).to(device)
previewer = TAESDPreviewerImpl(taesd)
previewer = TAESDPreviewerImpl(taesd, device)
else:
print("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))

Expand All @@ -94,7 +98,10 @@ def callback(step, x0, x, total_steps):

preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
try:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
except model_management.OOM_EXCEPTION as e:
pass
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback

0 comments on commit 0de7950

Please sign in to comment.