Add taesd and taesdxl to VAELoader node.
They will show up if both the encoder and decoder model files (taesd_encoder and taesd_decoder, or their taesdxl equivalents) are present in the models/vae_approx directory.
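
For illustration only (not part of this commit): a minimal sketch of the presence check described above, outside of ComfyUI. It assumes the models/vae_approx layout mentioned in the commit message; the directory path and function name are hypothetical.

import os

def available_taesd_options(vae_approx_dir="models/vae_approx"):
    # A "taesd" or "taesdxl" entry appears only when both the encoder
    # and decoder files for that variant are present in the directory.
    files = os.listdir(vae_approx_dir)
    options = []
    for prefix in ("taesd", "taesdxl"):
        has_encoder = any(f.startswith(prefix + "_encoder.") for f in files)
        has_decoder = any(f.startswith(prefix + "_decoder.") for f in files)
        if has_encoder and has_decoder:
            options.append(prefix)
    return options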
comfyanonymous committed Nov 21, 2023
1 parent 6ff06fa commit cd4fc77
Showing 4 changed files with 79 additions and 17 deletions.
17 changes: 12 additions & 5 deletions comfy/sd.py
@@ -23,6 +23,7 @@
import comfy.lora
import comfy.t2i_adapter.adapter
import comfy.supported_models_base
import comfy.taesd.taesd

def load_model_weights(model, sd):
m, u = model.load_state_dict(sd, strict=False)
@@ -154,10 +155,16 @@ def __init__(self, sd=None, device=None, config=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)

self.memory_used_encode = lambda shape: (2078 * shape[2] * shape[3]) * 1.7 #These are for AutoencoderKL and need tweaking
self.memory_used_decode = lambda shape: (2562 * shape[2] * shape[3] * 64) * 1.7

if config is None:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
if "taesd_decoder.1.weight" in sd:
self.first_stage_model = comfy.taesd.taesd.TAESD()
else:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
else:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
@@ -206,7 +213,7 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
def decode(self, samples_in):
self.first_stage_model = self.first_stage_model.to(self.device)
try:
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
memory_used = self.memory_used_decode(samples_in.shape)
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
@@ -234,7 +241,7 @@ def encode(self, pixel_samples):
self.first_stage_model = self.first_stage_model.to(self.device)
pixel_samples = pixel_samples.movedim(-1,1)
try:
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
memory_used = self.memory_used_encode(pixel_samples.shape) #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
model_management.free_memory(memory_used, self.device)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
19 changes: 14 additions & 5 deletions comfy/taesd/taesd.py
@@ -46,15 +46,16 @@ class TAESD(nn.Module):
latent_magnitude = 3
latent_shift = 0.5

def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"):
def __init__(self, encoder_path=None, decoder_path=None):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
self.taesd_encoder = Encoder()
self.taesd_decoder = Decoder()
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
if encoder_path is not None:
self.encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
self.taesd_encoder.load_state_dict(comfy.utils.load_torch_file(encoder_path, safe_load=True))
if decoder_path is not None:
self.decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))

@staticmethod
def scale_latents(x):
@@ -65,3 +66,11 @@ def scale_latents(x):
def unscale_latents(x):
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)

def decode(self, x):
x_sample = self.taesd_decoder(x * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample

def encode(self, x):
return self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale
5 changes: 1 addition & 4 deletions latent_preview.py
@@ -22,10 +22,7 @@ def __init__(self, taesd):
self.taesd = taesd

def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decoder(x0[:1])[0].detach()
# x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2]
x_sample = x_sample.sub(0.5).mul(2)

x_sample = self.taesd.decode(x0[:1])[0].detach()
x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
55 changes: 52 additions & 3 deletions nodes.py
@@ -573,18 +573,67 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
return (model_lora, clip_lora)

class VAELoader:
@staticmethod
def vae_list():
vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False

for v in approx_vaes:
if v.startswith("taesd_decoder."):
sd1_taesd_dec = True
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
vaes.append("taesdxl")
return vaes

@staticmethod
def load_taesd(name):
sd = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")

encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))

enc = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]

dec = comfy.utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]

if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd

@classmethod
def INPUT_TYPES(s):
return {"required": { "vae_name": (folder_paths.get_filename_list("vae"), )}}
return {"required": { "vae_name": (s.vae_list(), )}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"

CATEGORY = "loaders"

#TODO: scale factor?
def load_vae(self, vae_name):
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
return (vae,)

