diff --git a/ai_diffusion/api.py b/ai_diffusion/api.py
index b72866ea8..f9da1ddd1 100644
--- a/ai_diffusion/api.py
+++ b/ai_diffusion/api.py
@@ -63,6 +63,7 @@ class CheckpointInput:
     rescale_cfg: float = 0.7
     self_attention_guidance: bool = False
     dynamic_caching: bool = False
+    tiled_vae: bool = False


 @dataclass
diff --git a/ai_diffusion/comfy_client.py b/ai_diffusion/comfy_client.py
index e6540b018..2a5eb9be2 100644
--- a/ai_diffusion/comfy_client.py
+++ b/ai_diffusion/comfy_client.py
@@ -475,6 +475,7 @@ def performance_settings(self):
             batch_size=settings.batch_size,
             resolution_multiplier=settings.resolution_multiplier,
             max_pixel_count=settings.max_pixel_count,
+            tiled_vae=settings.tiled_vae,
             dynamic_caching=settings.dynamic_caching and self.features.wave_speed,
         )

diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py
index 215981dcd..2a261119a 100644
--- a/ai_diffusion/comfy_workflow.py
+++ b/ai_diffusion/comfy_workflow.py
@@ -752,6 +752,11 @@ def vae_encode_inpaint(self, vae: Output, image: Output, mask: Output):
     def vae_decode(self, vae: Output, latent_image: Output):
         return self.add("VAEDecode", 1, vae=vae, samples=latent_image)

+    def vae_decode_tiled(self, vae: Output, latent_image: Output):
+        return self.add(
+            "VAEDecodeTiled", 1, vae=vae, samples=latent_image, tile_size=512, overlap=64
+        )
+
     def set_latent_noise_mask(self, latent: Output, mask: Output):
         return self.add("SetLatentNoiseMask", 1, samples=latent, mask=mask)

diff --git a/ai_diffusion/settings.py b/ai_diffusion/settings.py
index 1ec97247a..751697638 100644
--- a/ai_diffusion/settings.py
+++ b/ai_diffusion/settings.py
@@ -59,6 +59,7 @@ class PerformancePresetSettings(NamedTuple):
     batch_size: int = 4
     resolution_multiplier: float = 1.0
     max_pixel_count: int = 6
+    tiled_vae: bool = False


 @dataclass
@@ -67,6 +68,7 @@ class PerformanceSettings:
     resolution_multiplier: float = 1.0
     max_pixel_count: int = 6
     dynamic_caching: bool = False
+    tiled_vae: bool = False


 class Setting:
@@ -267,6 +269,13 @@ class Settings(QObject):
         _("Re-use outputs of previous steps (First Block Cache) to speed up generation."),
     )

+    tiled_vae: bool
+    _tiled_vae = Setting(
+        _("Tiled VAE"),
+        False,
+        _("Conserve memory by processing output images in smaller tiles."),
+    )
+
     _performance_presets = {
         PerformancePreset.cpu: PerformancePresetSettings(
             batch_size=1,
@@ -277,6 +286,7 @@
             batch_size=2,
             resolution_multiplier=1.0,
             max_pixel_count=2,
+            tiled_vae=True,
         ),
         PerformancePreset.medium: PerformancePresetSettings(
             batch_size=4,
diff --git a/ai_diffusion/ui/settings.py b/ai_diffusion/ui/settings.py
index a8cca8031..0f49c6239 100644
--- a/ai_diffusion/ui/settings.py
+++ b/ai_diffusion/ui/settings.py
@@ -592,6 +592,12 @@ def __init__(self):
         self._max_pixel_count.value_changed.connect(self.write)
         advanced_layout.addWidget(self._max_pixel_count)

+        self._tiled_vae = SwitchSetting(
+            Settings._tiled_vae, text=(_("Always"), _("Automatic")), parent=self._advanced
+        )
+        self._tiled_vae.value_changed.connect(self.write)
+        advanced_layout.addWidget(self._tiled_vae)
+
         self._dynamic_caching = SwitchSetting(Settings._dynamic_caching, parent=self)
         self._dynamic_caching.value_changed.connect(self.write)
         self._layout.addWidget(self._dynamic_caching)
@@ -635,6 +641,7 @@ def _read(self):
         )
         self._resolution_multiplier.value = settings.resolution_multiplier
         self._max_pixel_count.value = settings.max_pixel_count
+        self._tiled_vae.value = settings.tiled_vae
         self._dynamic_caching.value = settings.dynamic_caching
         self.update_client_info()
@@ -644,6 +651,7 @@ def _write(self):
         settings.batch_size = int(self._batch_size.value)
         settings.resolution_multiplier = self._resolution_multiplier.value
         settings.max_pixel_count = self._max_pixel_count.value
+        settings.tiled_vae = self._tiled_vae.value
         settings.performance_preset = list(PerformancePreset)[
             self._performance_preset.currentIndex()
         ]
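A note on the constants baked into `vae_decode_tiled`: ComfyUI's built-in `VAEDecodeTiled` node takes `tile_size` and `overlap` as pixel-space values, so 512/64 correspond to 64x64-latent tiles with an 8-latent overlap for the usual 8x VAE. For the UI switch above, "Always" forces the tiled path; the off state is labeled "Automatic", presumably because ComfyUI already retries a failed full-frame decode with tiling when it runs out of memory. For readers new to the technique, here is a rough standalone sketch of overlapping-tile decoding. It is illustrative only: plain NumPy, uniform averaging instead of ComfyUI's feathered blending, and `decode_fn` is a hypothetical stand-in for the real VAE decoder.

import numpy as np

def tiled_decode(latent, decode_fn, tile=64, overlap=8, scale=8):
    """Decode a latent of shape (C, H, W) in overlapping tiles.

    Sketch only, not ComfyUI's implementation: overlapping pixels are
    averaged rather than feathered. `decode_fn` maps a latent tile to an
    image tile of shape (3, h*scale, w*scale); scale is 8 for SD VAEs.
    """
    assert 0 <= overlap < tile
    _, h, w = latent.shape
    out = np.zeros((3, h * scale, w * scale), dtype=np.float32)
    weight = np.zeros_like(out)
    step = tile - overlap
    for y in range(0, h, step):
        for x in range(0, w, step):
            y0 = min(y, max(h - tile, 0))  # clamp the last tile to the edge
            x0 = min(x, max(w - tile, 0))
            decoded = decode_fn(latent[:, y0:y0 + tile, x0:x0 + tile])
            ys, xs = y0 * scale, x0 * scale
            dh, dw = decoded.shape[1], decoded.shape[2]
            out[:, ys:ys + dh, xs:xs + dw] += decoded
            weight[:, ys:ys + dh, xs:xs + dw] += 1.0
    return out / np.maximum(weight, 1.0)

Peak decoder activation memory then scales with the tile area instead of the full canvas; the cost is redundant computation in the overlap regions and potentially visible seams if the overlap is too small.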
diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py
index a4b6b6a46..38007323e 100644
--- a/ai_diffusion/workflow.py
+++ b/ai_diffusion/workflow.py
@@ -153,7 +153,13 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
     if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
         model = w.apply_self_attention_guidance(model)

     return model, clip, vae
+
+
+def vae_decode(w: ComfyWorkflow, vae: Output, latent: Output, tiled: bool):
+    if tiled:
+        return w.vae_decode_tiled(vae, latent)
+    return w.vae_decode(vae, latent)


 class ImageReshape(NamedTuple):
@@ -642,6 +646,7 @@ def scale_refine_and_decode(
     clip: Output,
     vae: Output,
     models: ModelDict,
+    tiled_vae: bool,
 ):
     """Handles scaling images from `initial` to `desired` resolution.
     If it is a substantial upscale, runs a high-res SD refinement pass.
@@ -673,7 +678,7 @@
         w, model, positive, negative, cond.all_control, extent.desired, vae, models
     )
     result = w.sampler_custom_advanced(model, positive, negative, latent, models.arch, **params)
-    image = w.vae_decode(vae, result)
+    image = vae_decode(w, vae, result, tiled_vae)
     return image


@@ -712,7 +717,7 @@ def generate(
         model, positive, negative, latent, models.arch, **_sampler_params(sampling)
     )
     out_image = scale_refine_and_decode(
-        extent, w, cond, sampling, out_latent, model_orig, clip, vae, models
+        extent, w, cond, sampling, out_latent, model_orig, clip, vae, models, checkpoint.tiled_vae
     )
     out_image = w.nsfw_filter(out_image, sensitivity=misc.nsfw_filter)
     out_image = scale_to_target(extent, w, out_image, models)
@@ -905,7 +910,7 @@ def inpaint(
         out_latent = w.sampler_custom_advanced(
             model, positive_up, negative_up, latent, models.arch, **sampler_params
         )
-        out_image = w.vae_decode(vae, out_latent)
+        out_image = vae_decode(w, vae, out_latent, checkpoint.tiled_vae)
         out_image = scale_to_target(upscale_extent, w, out_image, models)
     else:
         desired_bounds = extent.convert(target_bounds, "target", "desired")
@@ -913,7 +918,7 @@
         cropped_extent = ScaledExtent(
             desired_extent, desired_extent, desired_extent, target_bounds.extent
         )
-        out_image = w.vae_decode(vae, out_latent)
+        out_image = vae_decode(w, vae, out_latent, checkpoint.tiled_vae)
         out_image = scale(
             extent.initial, extent.desired, extent.refinement_scaling, w, out_image, models
         )
@@ -952,7 +957,7 @@ def refine(
     sampler = w.sampler_custom_advanced(
         model, positive, negative, latent, models.arch, **_sampler_params(sampling)
     )
-    out_image = w.vae_decode(vae, sampler)
+    out_image = vae_decode(w, vae, sampler, checkpoint.tiled_vae)
     out_image = w.nsfw_filter(out_image, sensitivity=misc.nsfw_filter)
     out_image = scale_to_target(extent, w, out_image, models)
     w.send_image(out_image)
@@ -1006,7 +1011,7 @@ def refine_region(
         inpaint_model, positive, negative, latent, models.arch, **_sampler_params(sampling)
     )
     out_image = scale_refine_and_decode(
-        extent, w, cond, sampling, out_latent, model_orig, clip, vae, models
+        extent, w, cond, sampling, out_latent, model_orig, clip, vae, models, checkpoint.tiled_vae
     )
     out_image = w.nsfw_filter(out_image, sensitivity=misc.nsfw_filter)
     out_image = scale_to_target(extent, w, out_image, models)
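The module-level `vae_decode` helper keeps the tiling decision out of the call sites: every former `w.vae_decode(...)` call in these workflows now routes through it, so the feature only requires threading one boolean from `CheckpointInput` down to the decode step. A minimal smoke test of the dispatch; this is hypothetical and not part of the change, with the helper body mirrored inline and `StubWorkflow` standing in for `ComfyWorkflow` so the snippet runs standalone:

# Hypothetical stand-in recording which decode node would be emitted.
class StubWorkflow:
    def vae_decode(self, vae, latent):
        return "VAEDecode"

    def vae_decode_tiled(self, vae, latent):
        return "VAEDecodeTiled"

# Mirror of the vae_decode helper added in workflow.py above.
def vae_decode(w, vae, latent, tiled):
    if tiled:
        return w.vae_decode_tiled(vae, latent)
    return w.vae_decode(vae, latent)

assert vae_decode(StubWorkflow(), None, None, tiled=False) == "VAEDecode"
assert vae_decode(StubWorkflow(), None, None, tiled=True) == "VAEDecodeTiled"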
@@ -1290,6 +1295,7 @@ def prepare(
     i.conditioning.positive += _collect_lora_triggers(i.models.loras, files)
     i.models.loras = unique(i.models.loras + extra_loras, key=lambda l: l.name)
     i.models.dynamic_caching = perf.dynamic_caching
+    i.models.tiled_vae = perf.tiled_vae
     arch = i.models.version = resolve_arch(style, models)
     _check_server_has_models(i.models, i.conditioning.regions, models, files, style.name)
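One reviewer-facing detail: because `PerformancePresetSettings` is a `NamedTuple` with defaults, the new field is backward compatible. Presets that never mention `tiled_vae` keep `False`, and only the `low` preset opts in. A standalone illustration of that pattern, mirroring (not importing) the class from settings.py:

from typing import NamedTuple

class PerformancePresetSettings(NamedTuple):
    batch_size: int = 4
    resolution_multiplier: float = 1.0
    max_pixel_count: int = 6
    tiled_vae: bool = False  # new field; the default keeps older presets valid

# Values taken from the _performance_presets table in this diff:
low = PerformancePresetSettings(batch_size=2, max_pixel_count=2, tiled_vae=True)
medium = PerformancePresetSettings(batch_size=4)

assert low.tiled_vae is True      # low-VRAM preset forces tiled decoding
assert medium.tiled_vae is False  # every other preset is unchanged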