From 1955a49c390557ecabb17bf5586e22b6c9abb5bc Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Tue, 21 May 2024 16:49:55 -0700 Subject: [PATCH 01/10] add precompute script --- scripts/t5_precompute.py | 80 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 scripts/t5_precompute.py diff --git a/scripts/t5_precompute.py b/scripts/t5_precompute.py new file mode 100644 index 00000000..845f04a6 --- /dev/null +++ b/scripts/t5_precompute.py @@ -0,0 +1,80 @@ +import os +import json +import torch +from streaming import MDSWriter, StreamingDataset +from transformers import AutoTokenizer, AutoModel, CLIPTextModel +from tqdm import trange +from argparse import ArgumentParser + +# TODO: Implement batching? 10% faster (when using t5-only), but a lot more complicated code + +arg_parser = ArgumentParser() +arg_parser.add_argument('--remote_src_base', type=str, required=True, help='Remote base to download MDS-formatted shards.') +arg_parser.add_argument('--remote_dst_base', type=str, required=True, help='Remote base to write MDS-formatted shards.') +arg_parser.add_argument('--subdir_path', type=str, required=True, help='Path to the subdirectory to process.') +args = arg_parser.parse_args() + +remote_src = os.path.join(args.remote_src_base, args.subdir_path) +remote_dst = os.path.join(args.remote_dst_base, args.subdir_path) + +# Dataset +print('Building dataset') +dataset = StreamingDataset(remote=remote_src, local=os.path.join('/tmp', args.subdir_path), download_timeout=300, shuffle=False) + +# Instantiate tokenizers +t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl') +clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer') + +print('Building models') +t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl', torch_dtype=torch.float16).encoder.cuda().eval() +clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='text_encoder', torch_dtype=torch.float16).cuda().eval() + +# Get columns +with open(os.path.join('/tmp/', args.subdir_path, 'index.json')) as f: + index_json = json.load(f) +columns = {k: v for k, v in zip(index_json['shards'][0]['column_names'], index_json['shards'][0]['column_encodings'])} +columns['T5_ATTENTION_MASK'] = 'bytes' +columns['T5_LATENTS'] = 'bytes' +columns['CLIP_ATTENTION_MASK'] = 'bytes' +columns['CLIP_LATENTS'] = 'bytes' +columns['CLIP_POOLED_TEXT'] = 'bytes' +print(columns) + +# Make writer +writer = MDSWriter(out=remote_dst, columns=columns, compression='zstd', hashes=[], size_limit='1GB') + +print('Loading batch') +with torch.no_grad(): + for i in trange(len(dataset)): + sample = dataset[i] + captions = sample['DESCRIPTION'] + + # Pre-compute T5 + t5_tokenizer_out = t5_tokenizer(captions, + padding='max_length', + max_length=t5_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + tokenized_captions = t5_tokenizer_out['input_ids'].cuda() + attention_masks = t5_tokenizer_out['attention_mask'].to(torch.bool).cuda() + sample['T5_ATTENTION_MASK'] = t5_tokenizer_out['attention_mask'].squeeze(0).numpy().tobytes() + t5_out = t5_model(input_ids=tokenized_captions, attention_mask=attention_masks) + sample['T5_LATENTS'] = t5_out[0].squeeze(0).cpu().numpy().tobytes() + + # Pre-compute CLIP + clip_tokenizer_out = clip_tokenizer(captions, + padding='max_length', + max_length=clip_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + tokenized_captions = clip_tokenizer_out['input_ids'].cuda() + 
attention_masks = clip_tokenizer_out['attention_mask'].cuda() + sample['CLIP_ATTENTION_MASK'] = clip_tokenizer_out['attention_mask'].squeeze(0).to(torch.bool).numpy().tobytes() + clip_out = clip_model(input_ids=tokenized_captions, attention_mask=attention_masks, output_hidden_states=True) + sample['CLIP_LATENTS'] = clip_out.hidden_states[-2].squeeze(0).cpu().numpy().tobytes() + sample['CLIP_POOLED_TEXT'] = clip_out[1].squeeze(0).cpu().numpy().tobytes() + + writer.write(sample) +writer.finish() + + From 6fe47ecac9b2d2abe46a5aab6ca3c926a2c3d3f2 Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Tue, 28 May 2024 14:06:22 -0700 Subject: [PATCH 02/10] Delete extra endlines --- scripts/t5_precompute.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/t5_precompute.py b/scripts/t5_precompute.py index 845f04a6..f7ced660 100644 --- a/scripts/t5_precompute.py +++ b/scripts/t5_precompute.py @@ -76,5 +76,3 @@ writer.write(sample) writer.finish() - - From 33648d9ff2f405996c31a72dc1be937b259da922 Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Wed, 29 May 2024 13:43:53 -0700 Subject: [PATCH 03/10] Process multiple subdirs per run --- scripts/t5_precompute.py | 102 ++++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 49 deletions(-) diff --git a/scripts/t5_precompute.py b/scripts/t5_precompute.py index f7ced660..d4f31781 100644 --- a/scripts/t5_precompute.py +++ b/scripts/t5_precompute.py @@ -11,15 +11,10 @@ arg_parser = ArgumentParser() arg_parser.add_argument('--remote_src_base', type=str, required=True, help='Remote base to download MDS-formatted shards.') arg_parser.add_argument('--remote_dst_base', type=str, required=True, help='Remote base to write MDS-formatted shards.') -arg_parser.add_argument('--subdir_path', type=str, required=True, help='Path to the subdirectory to process.') +arg_parser.add_argument('--subdir_paths', nargs='+', type=str, required=True, help='Path to the subdirectory to process.') args = arg_parser.parse_args() -remote_src = os.path.join(args.remote_src_base, args.subdir_path) -remote_dst = os.path.join(args.remote_dst_base, args.subdir_path) -# Dataset -print('Building dataset') -dataset = StreamingDataset(remote=remote_src, local=os.path.join('/tmp', args.subdir_path), download_timeout=300, shuffle=False) # Instantiate tokenizers t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl') @@ -29,50 +24,59 @@ t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl', torch_dtype=torch.float16).encoder.cuda().eval() clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='text_encoder', torch_dtype=torch.float16).cuda().eval() -# Get columns -with open(os.path.join('/tmp/', args.subdir_path, 'index.json')) as f: - index_json = json.load(f) -columns = {k: v for k, v in zip(index_json['shards'][0]['column_names'], index_json['shards'][0]['column_encodings'])} -columns['T5_ATTENTION_MASK'] = 'bytes' -columns['T5_LATENTS'] = 'bytes' -columns['CLIP_ATTENTION_MASK'] = 'bytes' -columns['CLIP_LATENTS'] = 'bytes' -columns['CLIP_POOLED_TEXT'] = 'bytes' -print(columns) +columns = None +for subdir_path in args.subdir_paths: + remote_src = os.path.join(args.remote_src_base, subdir_path) + remote_dst = os.path.join(args.remote_dst_base, subdir_path) + # Dataset + print('Building dataset') + dataset = StreamingDataset(remote=remote_src, local=os.path.join('/tmp', subdir_path), download_timeout=300, shuffle=False) -# Make writer -writer = MDSWriter(out=remote_dst, columns=columns, compression='zstd', hashes=[], 
size_limit='1GB') + # Get columns + if columns is None: + with open(os.path.join('/tmp/', subdir_path, 'index.json')) as f: + index_json = json.load(f) + columns = {k: v for k, v in zip(index_json['shards'][0]['column_names'], index_json['shards'][0]['column_encodings'])} + columns['T5_ATTENTION_MASK'] = 'bytes' + columns['T5_LATENTS'] = 'bytes' + columns['CLIP_ATTENTION_MASK'] = 'bytes' + columns['CLIP_LATENTS'] = 'bytes' + columns['CLIP_POOLED_TEXT'] = 'bytes' + print(columns) -print('Loading batch') -with torch.no_grad(): - for i in trange(len(dataset)): - sample = dataset[i] - captions = sample['DESCRIPTION'] - - # Pre-compute T5 - t5_tokenizer_out = t5_tokenizer(captions, - padding='max_length', - max_length=t5_tokenizer.model_max_length, - truncation=True, - return_tensors='pt') - tokenized_captions = t5_tokenizer_out['input_ids'].cuda() - attention_masks = t5_tokenizer_out['attention_mask'].to(torch.bool).cuda() - sample['T5_ATTENTION_MASK'] = t5_tokenizer_out['attention_mask'].squeeze(0).numpy().tobytes() - t5_out = t5_model(input_ids=tokenized_captions, attention_mask=attention_masks) - sample['T5_LATENTS'] = t5_out[0].squeeze(0).cpu().numpy().tobytes() + # Make writer + writer = MDSWriter(out=remote_dst, columns=columns, compression='zstd', hashes=[], size_limit='1GB') - # Pre-compute CLIP - clip_tokenizer_out = clip_tokenizer(captions, - padding='max_length', - max_length=clip_tokenizer.model_max_length, - truncation=True, - return_tensors='pt') - tokenized_captions = clip_tokenizer_out['input_ids'].cuda() - attention_masks = clip_tokenizer_out['attention_mask'].cuda() - sample['CLIP_ATTENTION_MASK'] = clip_tokenizer_out['attention_mask'].squeeze(0).to(torch.bool).numpy().tobytes() - clip_out = clip_model(input_ids=tokenized_captions, attention_mask=attention_masks, output_hidden_states=True) - sample['CLIP_LATENTS'] = clip_out.hidden_states[-2].squeeze(0).cpu().numpy().tobytes() - sample['CLIP_POOLED_TEXT'] = clip_out[1].squeeze(0).cpu().numpy().tobytes() + print('Loading batch') + with torch.no_grad(): + for i in trange(len(dataset)): + sample = dataset[i] + captions = sample['DESCRIPTION'] - writer.write(sample) -writer.finish() + # Pre-compute T5 + t5_tokenizer_out = t5_tokenizer(captions, + padding='max_length', + max_length=t5_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + tokenized_captions = t5_tokenizer_out['input_ids'].cuda() + attention_masks = t5_tokenizer_out['attention_mask'].to(torch.bool).cuda() + sample['T5_ATTENTION_MASK'] = t5_tokenizer_out['attention_mask'].squeeze(0).numpy().tobytes() + t5_out = t5_model(input_ids=tokenized_captions, attention_mask=attention_masks) + sample['T5_LATENTS'] = t5_out[0].squeeze(0).cpu().numpy().tobytes() + + # Pre-compute CLIP + clip_tokenizer_out = clip_tokenizer(captions, + padding='max_length', + max_length=clip_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + tokenized_captions = clip_tokenizer_out['input_ids'].cuda() + attention_masks = clip_tokenizer_out['attention_mask'].cuda() + sample['CLIP_ATTENTION_MASK'] = clip_tokenizer_out['attention_mask'].squeeze(0).to(torch.bool).numpy().tobytes() + clip_out = clip_model(input_ids=tokenized_captions, attention_mask=attention_masks, output_hidden_states=True) + sample['CLIP_LATENTS'] = clip_out.hidden_states[-2].squeeze(0).cpu().numpy().tobytes() + sample['CLIP_POOLED_TEXT'] = clip_out[1].squeeze(0).cpu().numpy().tobytes() + + writer.write(sample) + writer.finish() From 84b40f1382124a31c143d9e5d2e24c16fc754818 Mon Sep 17 00:00:00 
2001 From: Landan Seguin Date: Wed, 29 May 2024 14:37:58 -0700 Subject: [PATCH 04/10] Add dist to download on rank 0 only --- scripts/t5_precompute.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/scripts/t5_precompute.py b/scripts/t5_precompute.py index d4f31781..002c4884 100644 --- a/scripts/t5_precompute.py +++ b/scripts/t5_precompute.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer, AutoModel, CLIPTextModel from tqdm import trange from argparse import ArgumentParser +from composer.utils import dist # TODO: Implement batching? 10% faster (when using t5-only), but a lot more complicated code @@ -17,13 +18,17 @@ # Instantiate tokenizers -t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl') -clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer') - print('Building models') -t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl', torch_dtype=torch.float16).encoder.cuda().eval() -clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='text_encoder', torch_dtype=torch.float16).cuda().eval() +with dist.run_local_rank_zero_first(): + t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl') + clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer') + + t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl', torch_dtype=torch.float16) + clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='text_encoder', torch_dtype=torch.float16) +print('Moving models to GPUs') +t5_model = t5_model.encoder.cuda().eval() +clip_model = clip_model.cuda().eval() columns = None for subdir_path in args.subdir_paths: remote_src = os.path.join(args.remote_src_base, subdir_path) From 945e5231bdd9d5e5471330d5aa90478f9919b901 Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Wed, 29 May 2024 14:56:08 -0700 Subject: [PATCH 05/10] abort abort --- scripts/t5_precompute.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/scripts/t5_precompute.py b/scripts/t5_precompute.py index 002c4884..5ec7e93f 100644 --- a/scripts/t5_precompute.py +++ b/scripts/t5_precompute.py @@ -5,7 +5,6 @@ from transformers import AutoTokenizer, AutoModel, CLIPTextModel from tqdm import trange from argparse import ArgumentParser -from composer.utils import dist # TODO: Implement batching? 
10% faster (when using t5-only), but a lot more complicated code @@ -16,19 +15,16 @@ args = arg_parser.parse_args() - +cache_dir = '/tmp/hf_files' # Instantiate tokenizers print('Building models') -with dist.run_local_rank_zero_first(): - t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl') - clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer') +t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True) +clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer', cache_dir=cache_dir, local_files_only=True) - t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl', torch_dtype=torch.float16) - clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='text_encoder', torch_dtype=torch.float16) +print('Building models') +t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl', torch_dtype=torch.float16, cache_dir=cache_dir, local_files_only=True).encoder.cuda().eval() +clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='text_encoder', torch_dtype=torch.float16, cache_dir=cache_dir, local_files_only=True).cuda().eval() -print('Moving models to GPUs') -t5_model = t5_model.encoder.cuda().eval() -clip_model = clip_model.cuda().eval() columns = None for subdir_path in args.subdir_paths: remote_src = os.path.join(args.remote_src_base, subdir_path) From 7d1fffcc6907800b60c7c79b572dd371d24a01ca Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Thu, 6 Jun 2024 15:01:44 -0700 Subject: [PATCH 06/10] Add text latent dataset --- diffusion/datasets/__init__.py | 3 + diffusion/datasets/text_latents.py | 207 +++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+) create mode 100644 diffusion/datasets/text_latents.py diff --git a/diffusion/datasets/__init__.py b/diffusion/datasets/__init__.py index 283b5cd0..4276cc55 100644 --- a/diffusion/datasets/__init__.py +++ b/diffusion/datasets/__init__.py @@ -8,6 +8,7 @@ from diffusion.datasets.laion import StreamingLAIONDataset, build_streaming_laion_dataloader from diffusion.datasets.synthetic_image_caption import (SyntheticImageCaptionDataset, build_synthetic_image_caption_dataloader) +from diffusion.datasets.text_latents import StreamingTextLatentsDataset, build_streaming_text_latents_dataloader __all__ = [ 'build_streaming_laion_dataloader', @@ -18,4 +19,6 @@ 'StreamingImageCaptionDataset', 'build_synthetic_image_caption_dataloader', 'SyntheticImageCaptionDataset', + 'build_streaming_text_latents_dataloader', + 'StreamingTextLatentsDataset', ] diff --git a/diffusion/datasets/text_latents.py b/diffusion/datasets/text_latents.py new file mode 100644 index 00000000..e108a7ab --- /dev/null +++ b/diffusion/datasets/text_latents.py @@ -0,0 +1,207 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Streaming Image-Caption Dataset for SDXL with Pre-computed Text Latents.""" + +import logging +from io import BytesIO +from pathlib import Path +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from streaming import Stream, StreamingDataset +from torch.utils.data import DataLoader +from torchvision import transforms + +from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransorm, RandomCropSquare + +log = logging.getLogger(__name__) + 
+class StreamingTextLatentsDataset(StreamingDataset): + + def __init__( + self, + streams: Sequence[Stream], + caption_drop_prob: float = 0.0, + microcond_drop_prob: float = 0.0, + crop: Optional[Callable] = None, + transform: Optional[Callable] = None, + image_key: str = 'image', + text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'), + text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)), + attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), + **streaming_kwargs, + ): + + # Set defaults for vision-friendly streaming args. + streaming_kwargs.setdefault('shuffle_block_size', 1 << 18) + streaming_kwargs.setdefault('shuffle_algo', 'py1s') + super().__init__(streams=streams, **streaming_kwargs) + + self.crop = crop + self.transform = transform + self.caption_drop_prob = caption_drop_prob + self.microcond_drop_prob = microcond_drop_prob + self.image_key = image_key + self.text_latent_keys = text_latent_keys + self.text_latent_shapes = text_latent_shapes + self.attention_mask_keys = attention_mask_keys + + def __getitem__(self, index): + sample = super().__getitem__(index) + out = {} + + # Image + img = sample[self.image_key] + if not isinstance(img, Image.Image): + img = Image.open(BytesIO(sample[self.image_key])) + if img.mode != 'RGB': + img = img.convert('RGB') + out['cond_original_size'] = torch.tensor(img.size) + + # Image transforms + if self.crop is not None: + img, crop_top, crop_left = self.crop(img) + else: + crop_top, crop_left = 0, 0 + out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) + + if self.transform is not None: + img = self.transform(img) + out['image'] = img + + # Get the new height and width + if isinstance(img, torch.Tensor): + img_h, img_w = img.shape[-2], img.shape[-1] + elif isinstance(img, Image.Image): + img_w, img_h = img.size + else: + raise ValueError('Image after transformations must either be a PIL Image or Torch Tensor') + out['cond_target_size'] = torch.tensor([img_w, img_h]) + + # Microconditioning dropout as in Stability repo + # https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_original_size'] = out['cond_original_size'] * 0 + if torch.rand(1) < self.microcond_drop_prob: + out['cond_target_size'] = out['cond_target_size'] * 0 + + # Load text latents, attention masks, and clip pooled embeddings + for i in range(len(self.text_latent_keys)): + latent_key = self.text_latent_keys[i] + latent_shape = self.text_latent_shapes[i] + attention_key = self.attention_mask_keys[i] + + if torch.rand(1) < self.caption_drop_prob: + out[latent_key] = torch.zeros(latent_shape, dtype=torch.float16) + out[attention_key] = torch.zeros(latent_shape[0]) + if latent_key == 'CLIP_LATENTS': + out['CLIP_POOLED'] = torch.zeros(latent_shape[0]) + else: + text_latent = np.frombuffer(sample[latent_key], dtype=np.float16).copy() + out[latent_key] = torch.from_numpy(text_latent).reshape(latent_shape) + print(i, len(sample[attention_key])) + attention_mask = np.frombuffer(sample[attention_key], dtype=[np.int64, np.bool_][i]).copy() + out[attention_key] = torch.from_numpy(attention_mask).reshape(latent_shape[0]) + if latent_key == 'CLIP_LATENTS': + clip_pooled = np.frombuffer(sample['CLIP_POOLED_TEXT'], dtype=np.float16).copy() + out['CLIP_POOLED'] 
= torch.from_numpy(clip_pooled).reshape(latent_shape[1]) + return out + + +def build_streaming_text_latents_dataloader( + remote: Union[str, List], + batch_size: int, + local: Optional[Union[str, List]] = None, + caption_drop_prob: float = 0.0, + microcond_drop_prob: float = 0.0, + resize_size: Union[int, Tuple[int, int], Tuple[Tuple[int, int], ...]] = 256, + ar_bucket_boundaries: Optional[Tuple[float, ...]] = None, + transform: Optional[List[Callable]] = None, + crop_type: Optional[str] = 'square', + image_key: str = 'image', + text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'), + text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)), + attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), + streaming_kwargs: Optional[Dict] = None, + dataloader_kwargs: Optional[Dict] = None, +): + + # Check crop type + if crop_type is not None: + crop_type = crop_type.lower() + if crop_type not in ['square', 'random', 'aspect_ratio']: + raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]') + if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)): + raise ValueError( + 'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.') + + # Handle ``None`` kwargs + if streaming_kwargs is None: + streaming_kwargs = {} + if dataloader_kwargs is None: + dataloader_kwargs = {} + + # Check types for remote and local + + if isinstance(remote, str): + remote = [remote] + if isinstance(local, str): + local = [local] + if not local: + local = [_make_default_local_path(r) for r in remote] + if isinstance(remote, Sequence) and isinstance(local, Sequence): + if len(remote) != len(local): + ValueError( + f'remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}') + else: + ValueError(f'remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.') + + # Create a Stream for each (remote, local) pair + streams = [] + for r, l in zip(remote, local): + streams.append(Stream(remote=r, local=l)) + + # Set the crop to apply + if crop_type == 'square': + crop = LargestCenterSquare(resize_size) + elif crop_type == 'random': + crop = RandomCropSquare(resize_size) + elif crop_type == 'aspect_ratio': + crop = RandomCropAspectRatioTransorm(resize_size, ar_bucket_boundaries) # type: ignore + else: + crop = None + + if transform is None: + transform = [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + transform = transforms.Compose(transform) + assert isinstance(transform, Callable) + + dataset = StreamingTextLatentsDataset( + streams=streams, + caption_drop_prob=caption_drop_prob, + microcond_drop_prob=microcond_drop_prob, + crop=crop, + transform=transform, + image_key=image_key, + text_latent_keys=text_latent_keys, + text_latent_shapes=text_latent_shapes, + attention_mask_keys=attention_mask_keys, + **streaming_kwargs, + ) + + dataloader = DataLoader( + dataset=dataset, + batch_size=batch_size, + sampler=None, + **dataloader_kwargs, + ) + + return dataloader + +def _make_default_local_path(remote_path): + return str(Path(*['/tmp'] + list(Path(remote_path).parts[1:]))) \ No newline at end of file From 2dc2935c009fa72885640f5b8d645857a6d23c84 Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Thu, 6 Jun 2024 15:03:07 -0700 Subject: [PATCH 07/10] adjust precision --- diffusion/datasets/text_latents.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/diffusion/datasets/text_latents.py b/diffusion/datasets/text_latents.py index e108a7ab..4980c2d1 100644 --- a/diffusion/datasets/text_latents.py +++ b/diffusion/datasets/text_latents.py @@ -105,7 +105,7 @@ def __getitem__(self, index): text_latent = np.frombuffer(sample[latent_key], dtype=np.float16).copy() out[latent_key] = torch.from_numpy(text_latent).reshape(latent_shape) print(i, len(sample[attention_key])) - attention_mask = np.frombuffer(sample[attention_key], dtype=[np.int64, np.bool_][i]).copy() + attention_mask = np.frombuffer(sample[attention_key], dtype=np.bool_).copy() out[attention_key] = torch.from_numpy(attention_mask).reshape(latent_shape[0]) if latent_key == 'CLIP_LATENTS': clip_pooled = np.frombuffer(sample['CLIP_POOLED_TEXT'], dtype=np.float16).copy() From 3ea9c82ef15db5f80ebe702d8e1e9cb1b01501bf Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Thu, 6 Jun 2024 15:05:13 -0700 Subject: [PATCH 08/10] oops --- diffusion/datasets/text_latents.py | 207 ----------------------------- 1 file changed, 207 deletions(-) delete mode 100644 diffusion/datasets/text_latents.py diff --git a/diffusion/datasets/text_latents.py b/diffusion/datasets/text_latents.py deleted file mode 100644 index 4980c2d1..00000000 --- a/diffusion/datasets/text_latents.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright 2022 MosaicML Diffusion authors -# SPDX-License-Identifier: Apache-2.0 - -"""Streaming Image-Caption Dataset for SDXL with Pre-computed Text Latents.""" - -import logging -from io import BytesIO -from pathlib import Path -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union - -import numpy as np -import torch -from PIL import Image -from streaming import Stream, StreamingDataset -from torch.utils.data import DataLoader -from torchvision import transforms - -from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransorm, RandomCropSquare - -log = logging.getLogger(__name__) - -class StreamingTextLatentsDataset(StreamingDataset): - - def __init__( - self, - streams: Sequence[Stream], - caption_drop_prob: float = 0.0, - microcond_drop_prob: float = 0.0, - crop: Optional[Callable] = None, - transform: Optional[Callable] = None, - image_key: str = 'image', - text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'), - text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)), - attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), - **streaming_kwargs, - ): - - # Set defaults for vision-friendly streaming args. 
- streaming_kwargs.setdefault('shuffle_block_size', 1 << 18) - streaming_kwargs.setdefault('shuffle_algo', 'py1s') - super().__init__(streams=streams, **streaming_kwargs) - - self.crop = crop - self.transform = transform - self.caption_drop_prob = caption_drop_prob - self.microcond_drop_prob = microcond_drop_prob - self.image_key = image_key - self.text_latent_keys = text_latent_keys - self.text_latent_shapes = text_latent_shapes - self.attention_mask_keys = attention_mask_keys - - def __getitem__(self, index): - sample = super().__getitem__(index) - out = {} - - # Image - img = sample[self.image_key] - if not isinstance(img, Image.Image): - img = Image.open(BytesIO(sample[self.image_key])) - if img.mode != 'RGB': - img = img.convert('RGB') - out['cond_original_size'] = torch.tensor(img.size) - - # Image transforms - if self.crop is not None: - img, crop_top, crop_left = self.crop(img) - else: - crop_top, crop_left = 0, 0 - out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) - - if self.transform is not None: - img = self.transform(img) - out['image'] = img - - # Get the new height and width - if isinstance(img, torch.Tensor): - img_h, img_w = img.shape[-2], img.shape[-1] - elif isinstance(img, Image.Image): - img_w, img_h = img.size - else: - raise ValueError('Image after transformations must either be a PIL Image or Torch Tensor') - out['cond_target_size'] = torch.tensor([img_w, img_h]) - - # Microconditioning dropout as in Stability repo - # https://github.com/Stability-AI/generative-models/blob/477d8b9a7730d9b2e92b326a770c0420d00308c9/sgm/modules/encoders/modules.py#L151-L160 - if torch.rand(1) < self.microcond_drop_prob: - out['cond_crops_coords_top_left'] = out['cond_crops_coords_top_left'] * 0 - if torch.rand(1) < self.microcond_drop_prob: - out['cond_original_size'] = out['cond_original_size'] * 0 - if torch.rand(1) < self.microcond_drop_prob: - out['cond_target_size'] = out['cond_target_size'] * 0 - - # Load text latents, attention masks, and clip pooled embeddings - for i in range(len(self.text_latent_keys)): - latent_key = self.text_latent_keys[i] - latent_shape = self.text_latent_shapes[i] - attention_key = self.attention_mask_keys[i] - - if torch.rand(1) < self.caption_drop_prob: - out[latent_key] = torch.zeros(latent_shape, dtype=torch.float16) - out[attention_key] = torch.zeros(latent_shape[0]) - if latent_key == 'CLIP_LATENTS': - out['CLIP_POOLED'] = torch.zeros(latent_shape[0]) - else: - text_latent = np.frombuffer(sample[latent_key], dtype=np.float16).copy() - out[latent_key] = torch.from_numpy(text_latent).reshape(latent_shape) - print(i, len(sample[attention_key])) - attention_mask = np.frombuffer(sample[attention_key], dtype=np.bool_).copy() - out[attention_key] = torch.from_numpy(attention_mask).reshape(latent_shape[0]) - if latent_key == 'CLIP_LATENTS': - clip_pooled = np.frombuffer(sample['CLIP_POOLED_TEXT'], dtype=np.float16).copy() - out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).reshape(latent_shape[1]) - return out - - -def build_streaming_text_latents_dataloader( - remote: Union[str, List], - batch_size: int, - local: Optional[Union[str, List]] = None, - caption_drop_prob: float = 0.0, - microcond_drop_prob: float = 0.0, - resize_size: Union[int, Tuple[int, int], Tuple[Tuple[int, int], ...]] = 256, - ar_bucket_boundaries: Optional[Tuple[float, ...]] = None, - transform: Optional[List[Callable]] = None, - crop_type: Optional[str] = 'square', - image_key: str = 'image', - text_latent_keys: Tuple[str, ...] 
= ('T5_LATENTS', 'CLIP_LATENTS'), - text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)), - attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), - streaming_kwargs: Optional[Dict] = None, - dataloader_kwargs: Optional[Dict] = None, -): - - # Check crop type - if crop_type is not None: - crop_type = crop_type.lower() - if crop_type not in ['square', 'random', 'aspect_ratio']: - raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]') - if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)): - raise ValueError( - 'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.') - - # Handle ``None`` kwargs - if streaming_kwargs is None: - streaming_kwargs = {} - if dataloader_kwargs is None: - dataloader_kwargs = {} - - # Check types for remote and local - - if isinstance(remote, str): - remote = [remote] - if isinstance(local, str): - local = [local] - if not local: - local = [_make_default_local_path(r) for r in remote] - if isinstance(remote, Sequence) and isinstance(local, Sequence): - if len(remote) != len(local): - ValueError( - f'remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}') - else: - ValueError(f'remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.') - - # Create a Stream for each (remote, local) pair - streams = [] - for r, l in zip(remote, local): - streams.append(Stream(remote=r, local=l)) - - # Set the crop to apply - if crop_type == 'square': - crop = LargestCenterSquare(resize_size) - elif crop_type == 'random': - crop = RandomCropSquare(resize_size) - elif crop_type == 'aspect_ratio': - crop = RandomCropAspectRatioTransorm(resize_size, ar_bucket_boundaries) # type: ignore - else: - crop = None - - if transform is None: - transform = [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] - transform = transforms.Compose(transform) - assert isinstance(transform, Callable) - - dataset = StreamingTextLatentsDataset( - streams=streams, - caption_drop_prob=caption_drop_prob, - microcond_drop_prob=microcond_drop_prob, - crop=crop, - transform=transform, - image_key=image_key, - text_latent_keys=text_latent_keys, - text_latent_shapes=text_latent_shapes, - attention_mask_keys=attention_mask_keys, - **streaming_kwargs, - ) - - dataloader = DataLoader( - dataset=dataset, - batch_size=batch_size, - sampler=None, - **dataloader_kwargs, - ) - - return dataloader - -def _make_default_local_path(remote_path): - return str(Path(*['/tmp'] + list(Path(remote_path).parts[1:]))) \ No newline at end of file From 0788cbbc55eec6695d72540c666fa2032495281e Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Fri, 7 Jun 2024 13:54:47 -0700 Subject: [PATCH 09/10] oops again --- diffusion/datasets/__init__.py | 3 --- scripts/t5_precompute.py | 5 +++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/diffusion/datasets/__init__.py b/diffusion/datasets/__init__.py index 4276cc55..283b5cd0 100644 --- a/diffusion/datasets/__init__.py +++ b/diffusion/datasets/__init__.py @@ -8,7 +8,6 @@ from diffusion.datasets.laion import StreamingLAIONDataset, build_streaming_laion_dataloader from diffusion.datasets.synthetic_image_caption import (SyntheticImageCaptionDataset, build_synthetic_image_caption_dataloader) -from diffusion.datasets.text_latents import StreamingTextLatentsDataset, 
build_streaming_text_latents_dataloader __all__ = [ 'build_streaming_laion_dataloader', @@ -19,6 +18,4 @@ 'StreamingImageCaptionDataset', 'build_synthetic_image_caption_dataloader', 'SyntheticImageCaptionDataset', - 'build_streaming_text_latents_dataloader', - 'StreamingTextLatentsDataset', ] diff --git a/scripts/t5_precompute.py b/scripts/t5_precompute.py index 5ec7e93f..7c88ddaf 100644 --- a/scripts/t5_precompute.py +++ b/scripts/t5_precompute.py @@ -16,8 +16,9 @@ cache_dir = '/tmp/hf_files' + # Instantiate tokenizers -print('Building models') +print('Building tokenizers') t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True) clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer', cache_dir=cache_dir, local_files_only=True) @@ -62,7 +63,7 @@ return_tensors='pt') tokenized_captions = t5_tokenizer_out['input_ids'].cuda() attention_masks = t5_tokenizer_out['attention_mask'].to(torch.bool).cuda() - sample['T5_ATTENTION_MASK'] = t5_tokenizer_out['attention_mask'].squeeze(0).numpy().tobytes() + sample['T5_ATTENTION_MASK'] = t5_tokenizer_out['attention_mask'].squeeze(0).to(torch.bool).numpy().tobytes() t5_out = t5_model(input_ids=tokenized_captions, attention_mask=attention_masks) sample['T5_LATENTS'] = t5_out[0].squeeze(0).cpu().numpy().tobytes() From 7f906621d4af4668903cc8dd7f3bcda5da37e447 Mon Sep 17 00:00:00 2001 From: Landan Seguin Date: Fri, 7 Jun 2024 14:18:53 -0700 Subject: [PATCH 10/10] pre-commit style --- scripts/t5_precompute.py | 62 +++++++++++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/scripts/t5_precompute.py b/scripts/t5_precompute.py index 7c88ddaf..11b1f8ec 100644 --- a/scripts/t5_precompute.py +++ b/scripts/t5_precompute.py @@ -1,30 +1,52 @@ -import os +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Script to stream text from a dataset, compute CLIP and T5 latents, and write the latents to streaming dataset.""" + import json +import os +from argparse import ArgumentParser + import torch from streaming import MDSWriter, StreamingDataset -from transformers import AutoTokenizer, AutoModel, CLIPTextModel from tqdm import trange -from argparse import ArgumentParser +from transformers import AutoModel, AutoTokenizer, CLIPTextModel # TODO: Implement batching? 
10% faster (when using t5-only), but a lot more complicated code arg_parser = ArgumentParser() -arg_parser.add_argument('--remote_src_base', type=str, required=True, help='Remote base to download MDS-formatted shards.') +arg_parser.add_argument('--remote_src_base', + type=str, + required=True, + help='Remote base to download MDS-formatted shards.') arg_parser.add_argument('--remote_dst_base', type=str, required=True, help='Remote base to write MDS-formatted shards.') -arg_parser.add_argument('--subdir_paths', nargs='+', type=str, required=True, help='Path to the subdirectory to process.') +arg_parser.add_argument('--subdir_paths', + nargs='+', + type=str, + required=True, + help='Path to the subdirectory to process.') args = arg_parser.parse_args() - cache_dir = '/tmp/hf_files' # Instantiate tokenizers print('Building tokenizers') t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True) -clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='tokenizer', cache_dir=cache_dir, local_files_only=True) +clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', + subfolder='tokenizer', + cache_dir=cache_dir, + local_files_only=True) print('Building models') -t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl', torch_dtype=torch.float16, cache_dir=cache_dir, local_files_only=True).encoder.cuda().eval() -clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder='text_encoder', torch_dtype=torch.float16, cache_dir=cache_dir, local_files_only=True).cuda().eval() +t5_model = AutoModel.from_pretrained('google/t5-v1_1-xxl', + torch_dtype=torch.float16, + cache_dir=cache_dir, + local_files_only=True).encoder.cuda().eval() +clip_model = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', + subfolder='text_encoder', + torch_dtype=torch.float16, + cache_dir=cache_dir, + local_files_only=True).cuda().eval() columns = None for subdir_path in args.subdir_paths: @@ -32,13 +54,16 @@ remote_dst = os.path.join(args.remote_dst_base, subdir_path) # Dataset print('Building dataset') - dataset = StreamingDataset(remote=remote_src, local=os.path.join('/tmp', subdir_path), download_timeout=300, shuffle=False) + dataset = StreamingDataset(remote=remote_src, + local=os.path.join('/tmp', subdir_path), + download_timeout=300, + shuffle=False) # Get columns if columns is None: with open(os.path.join('/tmp/', subdir_path, 'index.json')) as f: index_json = json.load(f) - columns = {k: v for k, v in zip(index_json['shards'][0]['column_names'], index_json['shards'][0]['column_encodings'])} + columns = dict(zip(index_json['shards'][0]['column_names'], index_json['shards'][0]['column_encodings'])) columns['T5_ATTENTION_MASK'] = 'bytes' columns['T5_LATENTS'] = 'bytes' columns['CLIP_ATTENTION_MASK'] = 'bytes' @@ -57,10 +82,10 @@ # Pre-compute T5 t5_tokenizer_out = t5_tokenizer(captions, - padding='max_length', - max_length=t5_tokenizer.model_max_length, - truncation=True, - return_tensors='pt') + padding='max_length', + max_length=t5_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') tokenized_captions = t5_tokenizer_out['input_ids'].cuda() attention_masks = t5_tokenizer_out['attention_mask'].to(torch.bool).cuda() sample['T5_ATTENTION_MASK'] = t5_tokenizer_out['attention_mask'].squeeze(0).to(torch.bool).numpy().tobytes() @@ -75,8 +100,11 @@ return_tensors='pt') tokenized_captions = 
clip_tokenizer_out['input_ids'].cuda() attention_masks = clip_tokenizer_out['attention_mask'].cuda() - sample['CLIP_ATTENTION_MASK'] = clip_tokenizer_out['attention_mask'].squeeze(0).to(torch.bool).numpy().tobytes() - clip_out = clip_model(input_ids=tokenized_captions, attention_mask=attention_masks, output_hidden_states=True) + sample['CLIP_ATTENTION_MASK'] = clip_tokenizer_out['attention_mask'].squeeze(0).to( + torch.bool).numpy().tobytes() + clip_out = clip_model(input_ids=tokenized_captions, + attention_mask=attention_masks, + output_hidden_states=True) sample['CLIP_LATENTS'] = clip_out.hidden_states[-2].squeeze(0).cpu().numpy().tobytes() sample['CLIP_POOLED_TEXT'] = clip_out[1].squeeze(0).cpu().numpy().tobytes()
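
For reference, below is a minimal sketch (not part of the patch series) of how the columns written by scripts/t5_precompute.py could be decoded when streaming the output shards back. It assumes the dtypes the final version of the script serializes (float16 for T5_LATENTS, CLIP_LATENTS, and CLIP_POOLED_TEXT; bool for both attention masks after the "oops again" fix) and the sequence-length/hidden-size pairs used as defaults in the text_latents dataset from patch 06: (512, 4096) for the T5-XXL encoder and (77, 768) for the SDXL CLIP text encoder. The remote/local paths are placeholders.

import numpy as np
import torch
from streaming import StreamingDataset

# Example invocation of the precompute script (paths and subdirs are placeholders):
#   python scripts/t5_precompute.py --remote_src_base s3://src --remote_dst_base s3://dst \
#       --subdir_paths shard_00 shard_01

T5_SHAPE = (512, 4096)   # assumed: t5-v1_1-xxl max length x encoder hidden size
CLIP_SHAPE = (77, 768)   # assumed: CLIP max length x hidden size of the SDXL text_encoder

dataset = StreamingDataset(remote='<remote_dst_base>/<subdir>', local='/tmp/readback', shuffle=False)
sample = dataset[0]

# Latents and the pooled CLIP embedding were written as float16 bytes.
t5_latents = torch.from_numpy(np.frombuffer(sample['T5_LATENTS'], dtype=np.float16).copy()).reshape(T5_SHAPE)
clip_latents = torch.from_numpy(np.frombuffer(sample['CLIP_LATENTS'], dtype=np.float16).copy()).reshape(CLIP_SHAPE)
clip_pooled = torch.from_numpy(np.frombuffer(sample['CLIP_POOLED_TEXT'], dtype=np.float16).copy()).reshape(CLIP_SHAPE[1])

# Both attention masks are serialized as bool in the final script.
t5_mask = torch.from_numpy(np.frombuffer(sample['T5_ATTENTION_MASK'], dtype=np.bool_).copy()).reshape(T5_SHAPE[0])
clip_mask = torch.from_numpy(np.frombuffer(sample['CLIP_ATTENTION_MASK'], dtype=np.bool_).copy()).reshape(CLIP_SHAPE[0])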