From 4d6e4aaff9644980c15a3d34f62daca316840356 Mon Sep 17 00:00:00 2001
From: coryMosaicML <83666378+coryMosaicML@users.noreply.github.com>
Date: Thu, 3 Oct 2024 14:21:07 -0700
Subject: [PATCH] Add image generator to generate images for use with geneval
 (#172)

---
 .../evaluation/generate_geneval_images.py     | 177 ++++++++++++++++++
 diffusion/generate.py                         |  46 ++---
 .../mosaic-yamls/geneval-flux-1-schnell.yaml  |  77 ++++++++
 3 files changed, 278 insertions(+), 22 deletions(-)
 create mode 100644 diffusion/evaluation/generate_geneval_images.py
 create mode 100644 yamls/mosaic-yamls/geneval-flux-1-schnell.yaml

diff --git a/diffusion/evaluation/generate_geneval_images.py b/diffusion/evaluation/generate_geneval_images.py
new file mode 100644
index 00000000..8b64e46a
--- /dev/null
+++ b/diffusion/evaluation/generate_geneval_images.py
@@ -0,0 +1,177 @@
+# Copyright 2022 MosaicML Diffusion authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Image generation for running evaluation with geneval."""
+
+import json
+import os
+from typing import Dict, Optional, Union
+from urllib.parse import urlparse
+
+import torch
+from composer.core import get_precision_context
+from composer.utils import dist
+from composer.utils.file_helpers import get_file
+from composer.utils.object_store import OCIObjectStore
+from diffusers import AutoPipelineForText2Image
+from torchvision.transforms.functional import to_pil_image
+from tqdm.auto import tqdm
+
+
+class GenevalImageGenerator:
+    """Image generator that generates images from the geneval prompt set and saves them.
+
+    Args:
+        model (torch.nn.Module or str): The model to evaluate, or a HuggingFace model name if ``hf_model`` is ``True``.
+        geneval_prompts (str): Path to the prompts to use for geneval (ex: ``geneval/prompts/evaluation_metadata.jsonl``).
+        load_path (str, optional): The path to load the model checkpoint from. Default: ``None``.
+        local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
+        load_strict_model_weights (bool): Whether or not to strictly load model weights. Default: ``True``.
+        guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
+        height (int): The height of the generated images. Default: ``1024``.
+        width (int): The width of the generated images. Default: ``1024``.
+        images_per_prompt (int): The number of images to generate per prompt. Default: ``4``.
+        seed (int): The seed to use for generation. Default: ``17``.
+        output_bucket (str, optional): The remote bucket to save images to. Default: ``None``.
+        output_prefix (str, optional): The prefix to save images to. Default: ``None``.
+        local_prefix (str): The local prefix to save images to. Default: ``'/tmp'``.
+        additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the generate call. Default: ``None``.
+        hf_model (bool, optional): Whether the model is a HuggingFace pipeline. Default: ``False``.
+    """
+
+    def __init__(self,
+                 model: Union[torch.nn.Module, str],
+                 geneval_prompts: str,
+                 load_path: Optional[str] = None,
+                 local_checkpoint_path: str = '/tmp/model.pt',
+                 load_strict_model_weights: bool = True,
+                 guidance_scale: float = 7.0,
+                 height: int = 1024,
+                 width: int = 1024,
+                 images_per_prompt: int = 4,
+                 seed: int = 17,
+                 output_bucket: Optional[str] = None,
+                 output_prefix: Optional[str] = None,
+                 local_prefix: str = '/tmp',
+                 additional_generate_kwargs: Optional[Dict] = None,
+                 hf_model: Optional[bool] = False):
+
+        if isinstance(model, str) and not hf_model:
+            raise ValueError('A model name string may only be used when hf_model=True!')
+        self.hf_model = hf_model
+        if hf_model or isinstance(model, str):
+            # Download on local rank zero first so ranks share the HF cache, then load on every rank.
+            if dist.get_local_rank() == 0:
+                self.model = AutoPipelineForText2Image.from_pretrained(
+                    model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
+            dist.barrier()
+            self.model = AutoPipelineForText2Image.from_pretrained(
+                model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}')
+            dist.barrier()
+        else:
+            self.model = model
+        # Load the geneval prompts (one JSON object per line)
+        self.geneval_prompts = geneval_prompts
+        with open(geneval_prompts) as f:
+            self.prompt_metadata = [json.loads(line) for line in f]
+        self.load_path = load_path
+        self.local_checkpoint_path = local_checkpoint_path
+        self.load_strict_model_weights = load_strict_model_weights
+        self.guidance_scale = guidance_scale
+        self.height = height
+        self.width = width
+        self.images_per_prompt = images_per_prompt
+        self.seed = seed
+        self.generator = torch.Generator(device='cuda').manual_seed(self.seed)
+
+        self.output_bucket = output_bucket
+        self.output_prefix = output_prefix if output_prefix is not None else ''
+        self.local_prefix = local_prefix
+        self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
+
+        # Object store for uploading images
+        if self.output_bucket is not None:
+            parsed_remote_bucket = urlparse(self.output_bucket)
+            if parsed_remote_bucket.scheme != 'oci':
+                raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.')
+            self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix)
+
+        # Download the model checkpoint if needed
+        if self.load_path is not None and not isinstance(self.model, str):
+            if dist.get_local_rank() == 0:
+                get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True)
+            with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path):
+                # Load the model
+                state_dict = torch.load(self.local_checkpoint_path, map_location='cpu')
+            for key in list(state_dict['state']['model'].keys()):
+                if 'val_metrics.' in key:
+                    del state_dict['state']['model'][key]
+            self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights)
+            self.model = self.model.cuda().eval()
+
+    def generate(self):
+        """Generates images for each geneval prompt and saves them locally and, optionally, to object storage."""
+        os.makedirs(os.path.join(self.local_prefix, self.output_prefix), exist_ok=True)
+        # Partition the dataset across the ranks. Note this partitions prompts, not repeats.
+        dataset_len = len(self.prompt_metadata)
+        samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
+        start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
+        end_idx = start_idx + samples_per_rank
+        if dist.get_global_rank() < remainder:
+            end_idx += 1
+        print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
+        # Iterate over the dataset
+        for sample_id in tqdm(range(start_idx, end_idx)):
+            metadata = self.prompt_metadata[sample_id]
+            # Write the metadata jsonl
+            output_dir = os.path.join(self.local_prefix, f'{sample_id:0>5}')
+            os.makedirs(output_dir, exist_ok=True)
+            with open(os.path.join(output_dir, 'metadata.jsonl'), 'w') as f:
+                json.dump(metadata, f)
+            caption = metadata['prompt']
+            # Create dir for samples to live in
+            sample_dir = os.path.join(output_dir, 'samples')
+            os.makedirs(sample_dir, exist_ok=True)
+            # Generate images from the captions. Use a different seed per image (the HF path instead advances the shared generator).
+            for i in range(self.images_per_prompt):
+                seed = self.seed + i
+                if self.hf_model:
+                    generated_image = self.model(prompt=caption,
+                                                 height=self.height,
+                                                 width=self.width,
+                                                 guidance_scale=self.guidance_scale,
+                                                 generator=self.generator,
+                                                 **self.additional_generate_kwargs).images[0]
+                    img = generated_image
+                else:
+                    with get_precision_context('amp_fp16'):
+                        generated_image = self.model.generate(prompt=caption,
+                                                              height=self.height,
+                                                              width=self.width,
+                                                              guidance_scale=self.guidance_scale,
+                                                              seed=seed,
+                                                              progress_bar=False,
+                                                              **self.additional_generate_kwargs)  # type: ignore
+                    img = to_pil_image(generated_image[0])
+                # Save the images and metadata locally
+                image_name = f'{i:05}.png'
+                data_name = f'{i:05}.json'
+                img_local_path = os.path.join(sample_dir, image_name)
+                data_local_path = os.path.join(sample_dir, data_name)
+                img.save(img_local_path)
+                metadata = {
+                    'image_name': image_name,
+                    'prompt': caption,
+                    'guidance_scale': self.guidance_scale,
+                    'seed': seed
+                }
+                with open(data_local_path, 'w') as f:
+                    json.dump(metadata, f)
+                # Upload the image and metadata to cloud storage
+                output_sample_prefix = os.path.join(self.output_prefix, f'{sample_id:0>5}', 'samples')
+                if self.output_bucket is not None:
+                    self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, image_name),
+                                                    filename=img_local_path)
+                    # Upload the metadata
+                    self.object_store.upload_object(object_name=os.path.join(output_sample_prefix, data_name),
+                                                    filename=data_local_path)
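
Note: the divmod-based partitioning in generate() above hands each rank a contiguous slice of prompts (not repeats), with the first `remainder` ranks taking one extra prompt. A minimal standalone sketch of the same arithmetic, with world_size and rank as plain-Python stand-ins for the dist.get_world_size() and dist.get_global_rank() calls:

    # Sketch of the contiguous prompt partitioning used in generate().
    # world_size/rank stand in for dist.get_world_size()/dist.get_global_rank().
    def partition(dataset_len: int, world_size: int, rank: int) -> range:
        samples_per_rank, remainder = divmod(dataset_len, world_size)
        start_idx = rank * samples_per_rank + min(remainder, rank)
        end_idx = start_idx + samples_per_rank
        if rank < remainder:
            end_idx += 1
        return range(start_idx, end_idx)

    # Example: 10 prompts over 4 ranks -> slices of sizes [3, 3, 2, 2], covering 0..9 exactly once.
    assert [list(partition(10, 4, r)) for r in range(4)] == [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]
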
+ """ + + def __init__(self, + model: Union[torch.nn.Module, str], + geneval_prompts: str, + load_path: Optional[str] = None, + local_checkpoint_path: str = '/tmp/model.pt', + load_strict_model_weights: bool = True, + guidance_scale: float = 7.0, + height: int = 1024, + width: int = 1024, + images_per_prompt: int = 4, + seed: int = 17, + output_bucket: Optional[str] = None, + output_prefix: Optional[str] = None, + local_prefix: str = '/tmp', + additional_generate_kwargs: Optional[Dict] = None, + hf_model: Optional[bool] = False): + + if isinstance(model, str) and hf_model == False: + raise ValueError('Can only use strings for model with hf models!') + self.hf_model = hf_model + if hf_model or isinstance(model, str): + if dist.get_local_rank() == 0: + self.model = AutoPipelineForText2Image.from_pretrained( + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + dist.barrier() + self.model = AutoPipelineForText2Image.from_pretrained( + model, torch_dtype=torch.float16).to(f'cuda:{dist.get_local_rank()}') + dist.barrier() + else: + self.model = model + # Load the geneval prompts + self.geneval_prompts = geneval_prompts + with open(geneval_prompts) as f: + self.prompt_metadata = [json.loads(line) for line in f] + self.load_path = load_path + self.local_checkpoint_path = local_checkpoint_path + self.load_strict_model_weights = load_strict_model_weights + self.guidance_scale = guidance_scale + self.height = height + self.width = width + self.images_per_prompt = images_per_prompt + self.seed = seed + self.generator = torch.Generator(device='cuda').manual_seed(self.seed) + + self.output_bucket = output_bucket + self.output_prefix = output_prefix if output_prefix is not None else '' + self.local_prefix = local_prefix + self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {} + + # Object store for uploading images + if self.output_bucket is not None: + parsed_remote_bucket = urlparse(self.output_bucket) + if parsed_remote_bucket.scheme != 'oci': + raise ValueError(f'Currently only OCI object stores are supported. Got {parsed_remote_bucket.scheme}.') + self.object_store = OCIObjectStore(self.output_bucket.replace('oci://', ''), self.output_prefix) + + # Download the model checkpoint if needed + if self.load_path is not None and not isinstance(self.model, str): + if dist.get_local_rank() == 0: + get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True) + with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path): + # Load the model + state_dict = torch.load(self.local_checkpoint_path, map_location='cpu') + for key in list(state_dict['state']['model'].keys()): + if 'val_metrics.' in key: + del state_dict['state']['model'][key] + self.model.load_state_dict(state_dict['state']['model'], strict=self.load_strict_model_weights) + self.model = self.model.cuda().eval() + + def generate(self): + """Core image generation function. Generates images at a given guidance scale. + + Args: + guidance_scale (float): The guidance scale to use for image generation. + """ + os.makedirs(os.path.join(self.local_prefix, self.output_prefix), exist_ok=True) + # Partition the dataset across the ranks. Note this partitions prompts, not repeats. 
diff --git a/yamls/mosaic-yamls/geneval-flux-1-schnell.yaml b/yamls/mosaic-yamls/geneval-flux-1-schnell.yaml
new file mode 100644
index 00000000..6c1c80cd
--- /dev/null
+++ b/yamls/mosaic-yamls/geneval-flux-1-schnell.yaml
@@ -0,0 +1,77 @@
+# Example yaml for running geneval on FLUX.1-schnell model
+name: geneval-flux-1-schnell
+compute:
+  cluster: # your cluster name
+  instance: # your instance name
+  gpus: # number of gpus
+env_variables:
+  HYDRA_FULL_ERROR: '1'
+image: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04
+scheduling:
+  resumable: false
+  priority: medium
+  max_retries: 0
+integrations:
+- integration_type: git_repo
+  git_repo: mosaicml/diffusion
+  git_branch: main
+  pip_install: .[all] --no-deps # We install with no deps to use only specific deps needed for geneval
+- integration_type: pip_packages
+  packages:
+  - huggingface-hub[hf_transfer]>=0.23.2
+  - numpy==1.26.4
+  - pandas
+  - open_clip_torch
+  - clip-benchmark
+  - openmim
+  - sentencepiece
+  - mosaicml
+  - mosaicml-streaming
+  - hydra-core
+  - hydra-colorlog
+  - diffusers[torch]==0.30.3
+  - transformers[torch]==4.44.2
+  - torchmetrics[image]
+  - lpips
+  - clean-fid
+  - gradio
+  - datasets
+  - peft
+command: 'cd diffusion
+
+  pip install clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33
+
+  mim install mmengine mmcv-full==1.7.2
+
+  apt-get update && apt-get install libgl1-mesa-glx -y
+
+  git clone https://github.com/djghosh13/geneval.git
+
+  git clone https://github.com/open-mmlab/mmdetection.git
+
+  cd mmdetection; git checkout 2.x; pip install -v -e .; cd ..
+
+  composer run_generation.py --config-path /mnt/config --config-name parameters
+
+  cd geneval
+
+  ./evaluation/download_models.sh eval_models
+
+  python evaluation/evaluate_images.py /tmp/geneval-images --outfile outputs.jsonl --model-path eval_models
+
+  python evaluation/summary_scores.py outputs.jsonl
+  '
+parameters:
+  seed: 18
+  dist_timeout: 300
+  hf_model: true # We will use a model from huggingface
+  model:
+    name: black-forest-labs/FLUX.1-schnell # Model name from huggingface
+  generator:
+    _target_: diffusion.evaluation.generate_geneval_images.GenevalImageGenerator
+    geneval_prompts: geneval/prompts/evaluation_metadata.jsonl # Path to the geneval prompts jsonl
+    height: 1024 # Generated image height
+    width: 1024 # Generated image width
+    local_prefix: /tmp/geneval-images # Local path to save images to; geneval reads images from here
+    output_bucket: # Your output oci bucket name (optional)
+    output_prefix: # Your output prefix (optional)
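
Note: with the settings above, the generator writes the layout that geneval's evaluation/evaluate_images.py reads from under local_prefix (here /tmp/geneval-images). The prompt folders are zero-padded prompt indices, and the four samples per prompt come from the default images_per_prompt=4:

    /tmp/geneval-images/
        00000/
            metadata.jsonl                 # the geneval prompt metadata for this prompt
            samples/
                00000.png ... 00003.png    # generated images, one per seed
                00000.json ... 00003.json  # per-image prompt/seed/guidance metadata
        00001/
            ...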