File tree 2 files changed +8
-3
lines changed
2 files changed +8
-3
lines changed Original file line number Diff line number Diff line change 2
2
import copy
3
3
import torch
4
4
import diffusers
5
- from modules import shared , devices
5
+ from modules import shared , devices , sd_models
6
6
7
7
def get_timestep_ratio_conditioning (t , alphas_cumprod ):
8
8
s = torch .tensor ([0.008 ]) # diffusers uses 0.003 while the original is 0.008
@@ -191,7 +191,8 @@ def __call__(
191
191
callback_on_step_end = None ,
192
192
callback_on_step_end_tensor_inputs = ["latents" ],
193
193
):
194
-
194
+ if shared .opts .diffusers_offload_mode == "balanced" :
195
+ shared .sd_model = sd_models .apply_balanced_offload (shared .sd_model )
195
196
# 0. Define commonly used variables
196
197
self .guidance_scale = guidance_scale
197
198
self .do_classifier_free_guidance = self .guidance_scale > 1
@@ -334,6 +335,8 @@ def __call__(
334
335
335
336
# Offload all models
336
337
self .maybe_free_model_hooks ()
338
+ if shared .opts .diffusers_offload_mode == "balanced" :
339
+ shared .sd_model = sd_models .apply_balanced_offload (shared .sd_model )
337
340
338
341
if not return_dict :
339
342
return images
Original file line number Diff line number Diff line change @@ -69,7 +69,9 @@ def full_vae_decode(latents, model):
69
69
model .vae .apply (sd_models .convert_to_faketensors )
70
70
devices .torch_gc (force = True )
71
71
72
- if shared .opts .diffusers_move_unet and not getattr (model , 'has_accelerate' , False ) and base_device is not None :
72
+ if shared .opts .diffusers_offload_mode == "balanced" :
73
+ shared .sd_model = sd_models .apply_balanced_offload (shared .sd_model )
74
+ elif shared .opts .diffusers_move_unet and not getattr (model , 'has_accelerate' , False ) and base_device is not None :
73
75
sd_models .move_base (model , base_device )
74
76
t1 = time .time ()
75
77
debug (f'VAE decode: name={ sd_vae .loaded_vae_file if sd_vae .loaded_vae_file is not None else "baked" } dtype={ model .vae .dtype } upcast={ upcast } images={ latents .shape [0 ]} latents={ latents .shape } time={ round (t1 - t0 , 3 )} ' )
You can’t perform that action at this time.
0 commit comments