Skip to content

Commit

Permalink
Add image generator to generate images for use with geneval (#172)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Oct 3, 2024
1 parent ab5a2f0 commit 4d6e4aa
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 22 deletions.
180 changes: 180 additions & 0 deletions diffusion/evaluation/generate_geneval_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Image generation for runnning evaluation with geneval."""

import json
import os
from typing import Dict, Optional, Union
from urllib.parse import urlparse

import torch
from composer.core import get_precision_context
from composer.utils import dist
from composer.utils.file_helpers import get_file
from composer.utils.object_store import OCIObjectStore
from diffusers import AutoPipelineForText2Image
from torchvision.transforms.functional import to_pil_image
from tqdm.auto import tqdm


class GenevalImageGenerator:
"""Image generator that generates images from the geneval prompt set and saves them.
Args:
model (torch.nn.Module): The model to evaluate.
geneval_prompts (str): Path to the prompts to use for geneval (ex: `geneval/prompts/evaluation_metadata.json`).
load_path (str, optional): The path to load the model from. Default: ``None``.
local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
height (int): The height of the generated images. Default: ``1024``.
width (int): The width of the generated images. Default: ``1024``.
images_per_prompt (int): The number of images to generate per prompt. Default: ``4``.
load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
seed (int): The seed to use for generation. Default: ``17``.
output_bucket (str, Optional): The remote to save images to. Default: ``None``.
output_prefix (str, Optional): The prefix to save images to. Default: ``None``.
local_prefix (str): The local prefix to save images to. Default: ``/tmp``.
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
hf_model: (bool, Optional): whether the model is HF or not. Default: ``False``.
"""

def __init__(self,
model: Union[torch.nn.Module, str],
geneval_prompts: str,
load_path: Optional[str] = None,
local_checkpoint_path: str = '/tmp/model.pt',
load_strict_model_weights: bool = True,
guidance_scale: float = 7.0,
height: int = 1024,
width: int = 1024,
images_per_prompt: int = 4,
seed: int = 17,
output_bucket: Optional[str] = None,
output_prefix: Optional[str] = None,
local_prefix: str = '/tmp',
additional_generate_kwargs: Optional[Dict] = None,
hf_model: Optional[bool] = False):

if isinstance(model, str) and hf_model == False:
raise ValueError('Can only use strings for model with hf models!')
self.hf_model = hf_model
if hf_model or isinstance(model, str):
if dist.get_local_rank() == 0:
self.model = AutoPipelineForText2Image.from_pretrained(
model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
dist.barrier()
self.model = AutoPipelineForText2Image.from_pretrained(
model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
dist.barrier()
else:
self.model = model
# Load the geneval prompts
self.geneval_prompts = geneval_prompts
with open(geneval_prompts) as f:
self.prompt_metadata = [json.loads(line) for line in f]
self.load_path = load_path
self.local_checkpoint_path = local_checkpoint_path
self.load_strict_model_weights = load_strict_model_weights
self.guidance_scale = guidance_scale
self.height = height
self.width = width
self.images_per_prompt = images_per_prompt
self.seed = seed
self.generator = torch.Generator(device='cuda').manual_seed(self.seed)

self.output_bucket = output_bucket
self.output_prefix = output_prefix if output_prefix is not None else ''
self.local_prefix = local_prefix
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}

# Object store for uploading images
if self.output_bucket is not None:
parsed_remote_bucket = urlparse(self.output_bucket)
if parsed_remote_bucket.scheme != 'oci':
raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.')
self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix)

# Download the model checkpoint if needed
if self.load_path is not None and not isinstance(self.model, str):
if dist.get_local_rank() == 0:
get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
# Load the model
state_dict = torch.load(self.local_checkpoint_path, map_location='cpu')
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
del state_dict['state']['model'][key]
self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights)
self.model = self.model.cuda().eval()

def generate(self):
"""Core image generation function. Generates images at a given guidance scale.
Args:
guidance_scale (float): The guidance scale to use for image generation.
"""
os.makedirs(os.path.join(self.local_prefix, self.output_prefix), exist_ok=True)
# Partition the dataset across the ranks. Note this partitions prompts, not repeats.
dataset_len = len(self.prompt_metadata)
samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
end_idx = start_idx + samples_per_rank
if dist.get_global_rank() < remainder:
end_idx += 1
print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
# Iterate over the dataset
for sample_id in tqdm(range(start_idx, end_idx)):
metadata = self.prompt_metadata[sample_id]
# Write the metadata jsonl
output_dir = os.path.join(self.local_prefix, f'{sample_id:0>5}')
os.makedirs(output_dir, exist_ok=True)
with open(os.path.join(output_dir, 'metadata.jsonl'), 'w') as f:
json.dump(metadata, f)
caption = metadata['prompt']
# Create dir for samples to live in
sample_dir = os.path.join(output_dir, 'samples')
os.makedirs(sample_dir, exist_ok=True)
# Generate images from the captions. Take care to use a different seed for each image
for i in range(self.images_per_prompt):
seed = self.seed + i
if self.hf_model:
generated_image = self.model(prompt=caption,
height=self.height,
width=self.width,
guidance_scale=self.guidance_scale,
generator=self.generator,
**self.additional_generate_kwargs).images[0]
img = generated_image
else:
with get_precision_context('amp_fp16'):
generated_image = self.model.generate(prompt=caption,
height=self.height,
width=self.width,
guidance_scale=self.guidance_scale,
seed=seed,
progress_bar=False,
**self.additional_generate_kwargs) # type: ignore
img = to_pil_image(generated_image[0])
# Save the images and metadata locally
image_name = f'{i:05}.png'
data_name = f'{i:05}.json'
img_local_path = os.path.join(sample_dir, image_name)
data_local_path = os.path.join(sample_dir, data_name)
img.save(img_local_path)
metadata = {
'image_name': image_name,
'prompt': caption,
'guidance_scale': self.guidance_scale,
'seed': seed
}
json.dump(metadata, open(f'{data_local_path}', 'w'))
# Upload the image and metadata to cloud storage
output_sample_prefix = os.path.join(self.output_prefix, f'{sample_id:0>5}', 'samples')
if self.output_bucket is not None:
self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, image_name),
filename=img_local_path)
# Upload the metadata
self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, data_name),
filename=data_local_path)
46 changes: 24 additions & 22 deletions diffusion/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""Generate images from a model."""

import operator
from typing import List
from typing import Any, List, Optional

import hydra
from composer import Algorithm, ComposerModel
Expand All @@ -16,7 +16,20 @@
from omegaconf import DictConfig
from torch.utils.data import Dataset

from diffusion.evaluation.generate_images import ImageGenerator

def _make_dataset(config: DictConfig, tokenizer: Optional[Any] = None) -> Dataset:
if config.hf_dataset:
if dist.get_local_rank() == 0:
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
elif tokenizer:
dataset = hydra.utils.instantiate(config.dataset)

else:
dataset: Dataset = hydra.utils.instantiate(config.dataset)
return dataset


def generate(config: DictConfig) -> None:
Expand All @@ -37,20 +50,6 @@ def generate(config: DictConfig) -> None:

tokenizer = model.tokenizer if hasattr(model, 'tokenizer') else None

# The dataset to use for evaluation

if config.hf_dataset:
if dist.get_local_rank() == 0:
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
dataset = load_dataset(config.dataset.name, split=config.dataset.split)
dist.barrier()
elif tokenizer:
dataset = hydra.utils.instantiate(config.dataset)

else:
dataset: Dataset = hydra.utils.instantiate(config.dataset)

# Build list of algorithms.
algorithms: List[Algorithm] = []

Expand Down Expand Up @@ -78,12 +77,15 @@ def generate(config: DictConfig) -> None:
precision=Precision(ag_conf['precision']),
optimizers=None,
)

image_generator: ImageGenerator = hydra.utils.instantiate(config.generator,
model=model,
dataset=dataset,
hf_model=config.hf_model,
hf_dataset=config.hf_dataset)
if 'dataset' in config:
dataset = _make_dataset(config, tokenizer)
image_generator = hydra.utils.instantiate(config.generator,
model=model,
dataset=dataset,
hf_model=config.hf_model,
hf_dataset=config.hf_dataset)
else:
image_generator = hydra.utils.instantiate(config.generator, model=model, hf_model=config.hf_model)

def generate_from_model():
image_generator.generate()
Expand Down
77 changes: 77 additions & 0 deletions yamls/mosaic-yamls/geneval-flux-1-schnell.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Example yaml for running geneval on FLUX.1-schnell model
name: geneval-flux-1-schnell
compute:
cluster: # your cluster name
instance: # your instance name
gpus: # number of gpus
env_variables:
HYDRA_FULL_ERROR: '1'
image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04
scheduling:
resumable: false
priority: medium
max_retries: 0
integrations:
- integration_type: git_repo
git_repo: mosaicml/diffusion
git_branch: main
pip_install: .[all] --no-deps # We install with no deps to use only specific deps needed for geneval
- integration_type: pip_packages
packages:
- huggingface-hub[hf_transfer]>=0.23.2
- numpy==1.26.4
- pandas
- open_clip_torch
- clip-benchmark
- openmim
- sentencepiece
- mosaicml
- mosaicml-streaming
- hydra-core
- hydra-colorlog
- diffusers[torch]==0.30.3
- transformers[torch]==4.44.2
- torchmetrics[image]
- lpips
- clean-fid
- gradio
- datasets
- peft
command: 'cd diffusion
pip install clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
mim install mmengine mmcv-full==1.7.2
apt-get update && apt-get install libgl1-mesa-glx -y
git clone https://github.com/djghosh13/geneval.git
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection; git checkout 2.x; pip install -v -e .; cd ..
composer run_generation.py --config-path /mnt/config --config-name parameters
cd geneval
./evaluation/download_models.sh eval_models
python evaluation/evaluate_images.py /tmp/geneval-images --outfile outputs.jsonl --model-path eval_models
python evaluation/summary_scores.py outputs.jsonl
'
parameters:
seed: 18
dist_timeout: 300
hf_model: true # We will use a model from huggingface
model:
name: black-forest-labs/FLUX.1-schnell # Model name from huggingface
generator:
_target_: diffusion.evaluation.generate_geneval_images.GenevalImageGenerator
geneval_prompts: geneval/prompts/evaluation_metadata.jsonl # Path to geneval prompts json
height: 1024 # Generated image height
width: 1024 # Generated image width
local_prefix: /tmp/geneval-images # Local path to save images to. Needed for geneval to read images from.
output_bucket: # Your output oci bucket name (optional)
output_prefix: # Your output prefix (optional)

0 comments on commit 4d6e4aa

Please sign in to comment.