add latent logger + small fix to Image.Lanczos
rishab-partha committed Jun 18, 2024
1 parent 93a5469 commit 5e2c7a7
Showing 2 changed files with 99 additions and 13 deletions.
110 changes: 98 additions & 12 deletions diffusion/callbacks/log_diffusion_images.py
@@ -3,13 +3,15 @@

"""Logger for generated images."""

import gc
from math import ceil
from typing import List, Optional, Tuple, Union

import torch
from composer import Callback, Logger, State
from composer.core import TimeUnit, get_precision_context
from torch.nn.parallel import DistributedDataParallel
from transformers import AutoModel, AutoTokenizer, CLIPTextModel


class LogDiffusionImages(Callback):
@@ -45,14 +47,18 @@ def __init__(self,
guidance_scale: float = 0.0,
rescaled_guidance: Optional[float] = None,
seed: Optional[int] = 1138,
use_table: bool = False):
use_table: bool = False,
text_encoder: Optional[str] = None,
clip_encoder: Optional[str] = None,
cache_dir: Optional[str] = '/tmp/hf_files'):
self.prompts = prompts
self.size = (size, size) if isinstance(size, int) else size
self.num_inference_steps = num_inference_steps
self.guidance_scale = guidance_scale
self.rescaled_guidance = rescaled_guidance
self.seed = seed
self.use_table = use_table
self.cache_dir = cache_dir

# Batch prompts
batch_size = len(prompts) if batch_size is None else batch_size
@@ -62,6 +68,66 @@ def __init__(self,
start, end = i * batch_size, (i + 1) * batch_size
self.batched_prompts.append(prompts[start:end])

if (text_encoder is not None and clip_encoder is None) or (text_encoder is None and clip_encoder is not None):
raise ValueError('Cannot specify only one of text encoder and CLIP encoder.')

self.precomputed_latents = False
self.batched_latents = []
if text_encoder:
self.precomputed_latents = True
t5_tokenizer = AutoTokenizer.from_pretrained(text_encoder, cache_dir=self.cache_dir, local_files_only=True)
clip_tokenizer = AutoTokenizer.from_pretrained(clip_encoder,
subfolder='tokenizer',
cache_dir=self.cache_dir,
local_files_only=True)

t5_model = AutoModel.from_pretrained(text_encoder,
torch_dtype=torch.float16,
cache_dir=self.cache_dir,
local_files_only=True).encoder.cuda().eval()
clip_model = CLIPTextModel.from_pretrained(clip_encoder,
subfolder='text_encoder',
torch_dtype=torch.float16,
cache_dir=self.cache_dir,
local_files_only=True).cuda().eval()

for batch in self.batched_prompts:
latent_batch = {}
tokenized_t5 = t5_tokenizer(batch,
padding='max_length',
max_length=t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
t5_ids = tokenized_t5['input_ids'].cuda()
t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
t5_attention_mask = t5_attention_mask.cpu().to(torch.long)

tokenized_clip = clip_tokenizer(batch,
padding='max_length',
max_length=t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_attention_mask = tokenized_clip['attention_mask'].cuda()
clip_ids = tokenized_clip['input_ids'].cuda()
clip_outputs = clip_model(input_ids=clip_ids,
attention_mask=clip_attention_mask,
output_hidden_states=True)
clip_latents = clip_outputs.hidden_states[-2].cpu()
clip_pooled = clip_outputs.pooler_output.cpu()
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

latent_batch['T5_LATENTS'] = t5_latents
latent_batch['CLIP_LATENTS'] = clip_latents
latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1)
latent_batch['CLIP_POOLED'] = clip_pooled
self.batched_latents.append(latent_batch)

del t5_model
del clip_model
gc.collect()
torch.cuda.empty_cache()

def eval_start(self, state: State, logger: Logger):
# Get the model object if it has been wrapped by DDP to access the image generation function.
if isinstance(state.model, DistributedDataParallel):
@@ -72,17 +138,37 @@ def eval_start(self, state: State, logger: Logger):
# Generate images
with get_precision_context(state.precision):
all_gen_images = []
for batch in self.batched_prompts:
gen_images = model.generate(
prompt=batch, # type: ignore
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
if self.precomputed_latents:
for batch in self.batched_latents:
pooled_prompt = batch['CLIP_POOLED'].cuda()
prompt_mask = batch['ATTENTION_MASK'].cuda()
t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda())
clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda())
prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)

gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
else:
for batch in self.batched_prompts:
gen_images = model.generate(
prompt=batch, # type: ignore
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
gen_images = torch.cat(all_gen_images)

# Log images to wandb
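For reference, a minimal sketch of how the updated callback might be wired up with the new precomputed-latent path. The import path mirrors the file location in this diff; the Hugging Face checkpoint names and prompt list are illustrative assumptions, not part of the commit. Because the encoders are loaded with local_files_only=True, the checkpoints must already be present under cache_dir.

from diffusion.callbacks.log_diffusion_images import LogDiffusionImages

# Hypothetical configuration; the checkpoint names below are assumptions.
log_images = LogDiffusionImages(
    prompts=['a photo of a corgi on the beach', 'an oil painting of a lighthouse'],
    size=512,
    guidance_scale=7.0,
    seed=1138,
    # Supplying both encoders turns on the precomputed-latent path added here;
    # supplying neither keeps the original prompt-based generation.
    text_encoder='google/t5-v1_1-xxl',
    clip_encoder='stabilityai/stable-diffusion-xl-base-1.0',
    cache_dir='/tmp/hf_files',
)

Once the prompt latents are cached in __init__, the T5 and CLIP models are deleted and gc.collect() / torch.cuda.empty_cache() reclaim the GPU memory, so eval_start only pays for the projection and generation steps.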
2 changes: 1 addition & 1 deletion scripts/batched-llava-caption.py
@@ -164,7 +164,7 @@ def resize_and_pad(self, image: Image.Image) -> Image.Image:
resize_width = round(resize_height * aspect_ratio)
else:
raise ValueError('Invalid image dimensions')
resized_image = image.resize((resize_width, resize_height), Image.Resampling.LANCZOS)
resized_image = image.resize((resize_width, resize_height), Image.LANCZOS)

# Calculate padding
pad_width_left = (self.width - resize_width) // 2
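The one-line change above swaps Image.Resampling.LANCZOS for Image.LANCZOS, which is also defined on Pillow releases older than 9.1.0 (the version that introduced the Resampling enum) while remaining a valid alias on newer ones. A small, version-agnostic sketch of the same idea follows; the helper name is an assumption, not part of the script.

from PIL import Image

# Pillow >= 9.1 exposes resampling filters through the Resampling enum; older
# releases only define module-level constants such as Image.LANCZOS.
try:
    _LANCZOS = Image.Resampling.LANCZOS
except AttributeError:  # Pillow < 9.1
    _LANCZOS = Image.LANCZOS

def resize_lanczos(image: Image.Image, width: int, height: int) -> Image.Image:
    """Resize with Lanczos filtering regardless of the installed Pillow version."""
    return image.resize((width, height), _LANCZOS)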
