Skip to content

Commit

Permalink
Option to force tiled VAE decode #1650
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Mar 11, 2025
1 parent b96b9ff commit 6a80e3b
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 7 deletions.
1 change: 1 addition & 0 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class CheckpointInput:
rescale_cfg: float = 0.7
self_attention_guidance: bool = False
dynamic_caching: bool = False
tiled_vae: bool = False


@dataclass
Expand Down
1 change: 1 addition & 0 deletions ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def performance_settings(self):
batch_size=settings.batch_size,
resolution_multiplier=settings.resolution_multiplier,
max_pixel_count=settings.max_pixel_count,
tiled_vae=settings.tiled_vae,
dynamic_caching=settings.dynamic_caching and self.features.wave_speed,
)

Expand Down
5 changes: 5 additions & 0 deletions ai_diffusion/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,11 @@ def vae_encode_inpaint(self, vae: Output, image: Output, mask: Output):
def vae_decode(self, vae: Output, latent_image: Output):
    """Add a single-pass VAEDecode node and return its image output."""
    inputs = {"vae": vae, "samples": latent_image}
    return self.add("VAEDecode", 1, **inputs)

def vae_decode_tiled(self, vae: Output, latent_image: Output, tile_size: int = 512, overlap: int = 64):
    """Add a VAEDecodeTiled node, decoding the latent in tiles to reduce peak memory.

    Parameters:
        vae: the VAE model output to decode with.
        latent_image: the latent samples to decode.
        tile_size: edge length of each decode tile in pixels (was hard-coded to 512;
            now overridable, default preserves previous behavior).
        overlap: pixel overlap between neighboring tiles to hide seams (default 64,
            matching the previous hard-coded value).

    Returns the node's single image output.
    """
    return self.add(
        "VAEDecodeTiled", 1, vae=vae, samples=latent_image, tile_size=tile_size, overlap=overlap
    )

def set_latent_noise_mask(self, latent: Output, mask: Output):
    # Add a SetLatentNoiseMask node attaching `mask` to the latent samples;
    # returns the node's single output (the masked latent).
    return self.add("SetLatentNoiseMask", 1, samples=latent, mask=mask)

Expand Down
10 changes: 10 additions & 0 deletions ai_diffusion/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class PerformancePresetSettings(NamedTuple):
batch_size: int = 4
resolution_multiplier: float = 1.0
max_pixel_count: int = 6
tiled_vae: bool = False


@dataclass
Expand All @@ -67,6 +68,7 @@ class PerformanceSettings:
resolution_multiplier: float = 1.0
max_pixel_count: int = 6
dynamic_caching: bool = False
tiled_vae: bool = False


class Setting:
Expand Down Expand Up @@ -267,6 +269,13 @@ class Settings(QObject):
_("Re-use outputs of previous steps (First Block Cache) to speed up generation."),
)

tiled_vae: bool
_tiled_vae = Setting(
_("Tiled VAE"),
False,
_("Conserve memory by processing output images in smaller tiles."),
)

_performance_presets = {
PerformancePreset.cpu: PerformancePresetSettings(
batch_size=1,
Expand All @@ -277,6 +286,7 @@ class Settings(QObject):
batch_size=2,
resolution_multiplier=1.0,
max_pixel_count=2,
tiled_vae=True,
),
PerformancePreset.medium: PerformancePresetSettings(
batch_size=4,
Expand Down
8 changes: 8 additions & 0 deletions ai_diffusion/ui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,12 @@ def __init__(self):
self._max_pixel_count.value_changed.connect(self.write)
advanced_layout.addWidget(self._max_pixel_count)

self._tiled_vae = SwitchSetting(
Settings._tiled_vae, text=(_("Always"), _("Automatic")), parent=self._advanced
)
self._tiled_vae.value_changed.connect(self.write)
advanced_layout.addWidget(self._tiled_vae)

self._dynamic_caching = SwitchSetting(Settings._dynamic_caching, parent=self)
self._dynamic_caching.value_changed.connect(self.write)
self._layout.addWidget(self._dynamic_caching)
Expand Down Expand Up @@ -635,6 +641,7 @@ def _read(self):
)
self._resolution_multiplier.value = settings.resolution_multiplier
self._max_pixel_count.value = settings.max_pixel_count
self._tiled_vae.value = settings.tiled_vae
self._dynamic_caching.value = settings.dynamic_caching
self.update_client_info()

Expand All @@ -644,6 +651,7 @@ def _write(self):
settings.batch_size = int(self._batch_size.value)
settings.resolution_multiplier = self._resolution_multiplier.value
settings.max_pixel_count = self._max_pixel_count.value
settings.tiled_vae = self._tiled_vae.value
settings.performance_preset = list(PerformancePreset)[
self._performance_preset.currentIndex()
]
Expand Down
20 changes: 13 additions & 7 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,11 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
model = w.apply_self_attention_guidance(model)

return model, clip, vae

def vae_decode(w: ComfyWorkflow, vae: Output, latent: Output, tiled: bool):
    """Decode `latent` to an image, choosing tiled decoding when `tiled` is set."""
    decode = w.vae_decode_tiled if tiled else w.vae_decode
    return decode(vae, latent)


class ImageReshape(NamedTuple):
Expand Down Expand Up @@ -642,6 +646,7 @@ def scale_refine_and_decode(
clip: Output,
vae: Output,
models: ModelDict,
tiled_vae: bool,
):
"""Handles scaling images from `initial` to `desired` resolution.
If it is a substantial upscale, runs a high-res SD refinement pass.
Expand Down Expand Up @@ -673,7 +678,7 @@ def scale_refine_and_decode(
w, model, positive, negative, cond.all_control, extent.desired, vae, models
)
result = w.sampler_custom_advanced(model, positive, negative, latent, models.arch, **params)
image = w.vae_decode(vae, result)
image = vae_decode(w, vae, result, tiled_vae)
return image


Expand Down Expand Up @@ -712,7 +717,7 @@ def generate(
model, positive, negative, latent, models.arch, **_sampler_params(sampling)
)
out_image = scale_refine_and_decode(
extent, w, cond, sampling, out_latent, model_orig, clip, vae, models
extent, w, cond, sampling, out_latent, model_orig, clip, vae, models, checkpoint.tiled_vae
)
out_image = w.nsfw_filter(out_image, sensitivity=misc.nsfw_filter)
out_image = scale_to_target(extent, w, out_image, models)
Expand Down Expand Up @@ -905,15 +910,15 @@ def inpaint(
out_latent = w.sampler_custom_advanced(
model, positive_up, negative_up, latent, models.arch, **sampler_params
)
out_image = w.vae_decode(vae, out_latent)
out_image = vae_decode(w, vae, out_latent, checkpoint.tiled_vae)
out_image = scale_to_target(upscale_extent, w, out_image, models)
else:
desired_bounds = extent.convert(target_bounds, "target", "desired")
desired_extent = desired_bounds.extent
cropped_extent = ScaledExtent(
desired_extent, desired_extent, desired_extent, target_bounds.extent
)
out_image = w.vae_decode(vae, out_latent)
out_image = vae_decode(w, vae, out_latent, checkpoint.tiled_vae)
out_image = scale(
extent.initial, extent.desired, extent.refinement_scaling, w, out_image, models
)
Expand Down Expand Up @@ -952,7 +957,7 @@ def refine(
sampler = w.sampler_custom_advanced(
model, positive, negative, latent, models.arch, **_sampler_params(sampling)
)
out_image = w.vae_decode(vae, sampler)
out_image = vae_decode(w, vae, sampler, checkpoint.tiled_vae)
out_image = w.nsfw_filter(out_image, sensitivity=misc.nsfw_filter)
out_image = scale_to_target(extent, w, out_image, models)
w.send_image(out_image)
Expand Down Expand Up @@ -1006,7 +1011,7 @@ def refine_region(
inpaint_model, positive, negative, latent, models.arch, **_sampler_params(sampling)
)
out_image = scale_refine_and_decode(
extent, w, cond, sampling, out_latent, model_orig, clip, vae, models
extent, w, cond, sampling, out_latent, model_orig, clip, vae, models, checkpoint.tiled_vae
)
out_image = w.nsfw_filter(out_image, sensitivity=misc.nsfw_filter)
out_image = scale_to_target(extent, w, out_image, models)
Expand Down Expand Up @@ -1290,6 +1295,7 @@ def prepare(
i.conditioning.positive += _collect_lora_triggers(i.models.loras, files)
i.models.loras = unique(i.models.loras + extra_loras, key=lambda l: l.name)
i.models.dynamic_caching = perf.dynamic_caching
i.models.tiled_vae = perf.tiled_vae
arch = i.models.version = resolve_arch(style, models)

_check_server_has_models(i.models, i.conditioning.regions, models, files, style.name)
Expand Down

0 comments on commit 6a80e3b

Please sign in to comment.