Skip to content

Commit 119c372

Browse files
committed
More offload checks
1 parent 7edc864 commit 119c372

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

modules/model_stablecascade.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import copy
33
import torch
44
import diffusers
5-
from modules import shared, devices
5+
from modules import shared, devices, sd_models
66

77
def get_timestep_ratio_conditioning(t, alphas_cumprod):
88
s = torch.tensor([0.008]) # diffusers uses 0.003 while the original is 0.008
@@ -191,7 +191,8 @@ def __call__(
191191
callback_on_step_end=None,
192192
callback_on_step_end_tensor_inputs=["latents"],
193193
):
194-
194+
if shared.opts.diffusers_offload_mode == "balanced":
195+
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
195196
# 0. Define commonly used variables
196197
self.guidance_scale = guidance_scale
197198
self.do_classifier_free_guidance = self.guidance_scale > 1
@@ -334,6 +335,8 @@ def __call__(
334335

335336
# Offload all models
336337
self.maybe_free_model_hooks()
338+
if shared.opts.diffusers_offload_mode == "balanced":
339+
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
337340

338341
if not return_dict:
339342
return images

modules/processing_vae.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def full_vae_decode(latents, model):
6969
model.vae.apply(sd_models.convert_to_faketensors)
7070
devices.torch_gc(force=True)
7171

72-
if shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and base_device is not None:
72+
if shared.opts.diffusers_offload_mode == "balanced":
73+
shared.sd_model = sd_models.apply_balanced_offload(shared.sd_model)
74+
elif shared.opts.diffusers_move_unet and not getattr(model, 'has_accelerate', False) and base_device is not None:
7375
sd_models.move_base(model, base_device)
7476
t1 = time.time()
7577
debug(f'VAE decode: name={sd_vae.loaded_vae_file if sd_vae.loaded_vae_file is not None else "baked"} dtype={model.vae.dtype} upcast={upcast} images={latents.shape[0]} latents={latents.shape} time={round(t1-t0, 3)}')

0 commit comments

Comments
 (0)