Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to use predefined aspect ratio buckets in the cropping transform #157

Merged
merged 5 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions diffusion/datasets/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransorm, RandomCropSquare
from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
RandomCropBucketedAspectRatioTransform, RandomCropSquare)
from diffusion.datasets.utils import make_streams
from diffusion.models.text_encoder import MultiTokenizer

Expand Down Expand Up @@ -45,6 +46,7 @@ class StreamingImageCaptionDataset(StreamingDataset):
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``False``.
**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
Expand All @@ -63,6 +65,7 @@ def __init__(
transform: Optional[Callable] = None,
image_key: str = 'image',
caption_key: str = 'caption',
aspect_ratio_bucket_key: Optional[str] = None,
sdxl_conditioning: bool = False,
zero_dropped_captions: bool = False,
**streaming_kwargs,
Expand Down Expand Up @@ -90,6 +93,9 @@ def __init__(
self.caption_selection = caption_selection
self.image_key = image_key
self.caption_key = caption_key
self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
self.zero_dropped_captions = zero_dropped_captions

self.tokenizer = tokenizer
Expand All @@ -107,7 +113,9 @@ def __getitem__(self, index):
orig_w, orig_h = img.size

# Image transforms
if self.crop is not None:
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
elif self.crop is not None:
img, crop_top, crop_left = self.crop(img)
else:
crop_top, crop_left = 0, 0
Expand Down Expand Up @@ -179,6 +187,7 @@ def build_streaming_image_caption_dataloader(
transform: Optional[List[Callable]] = None,
image_key: str = 'image',
caption_key: str = 'caption',
aspect_ratio_bucket_key: Optional[str] = None,
crop_type: Optional[str] = 'square',
zero_dropped_captions: bool = True,
sdxl_conditioning: bool = False,
Expand Down Expand Up @@ -212,7 +221,8 @@ def build_streaming_image_caption_dataloader(
transform (Optional[Callable]): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_key (str): Key associated with the caption in the streaming dataset. Default: ``'caption'``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
Default: ``'square'``.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``.
sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`.
Expand All @@ -225,12 +235,14 @@ def build_streaming_image_caption_dataloader(
# Check crop type
if crop_type is not None:
crop_type = crop_type.lower()
if crop_type not in ['square', 'random', 'aspect_ratio']:
raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
raise ValueError(
'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')

f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
)
if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
isinstance(resize_size[0], int)):
raise ValueError(
'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
# Handle ``None`` kwargs
if streaming_kwargs is None:
streaming_kwargs = {}
Expand All @@ -246,7 +258,10 @@ def build_streaming_image_caption_dataloader(
elif crop_type == 'random':
crop = RandomCropSquare(resize_size)
elif crop_type == 'aspect_ratio':
crop = RandomCropAspectRatioTransorm(resize_size, ar_bucket_boundaries) # type: ignore
crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore
elif crop_type == 'bucketed_aspect_ratio':
assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore
else:
crop = None

Expand All @@ -265,6 +280,7 @@ def build_streaming_image_caption_dataloader(
transform=transform,
image_key=image_key,
caption_key=caption_key,
aspect_ratio_bucket_key=aspect_ratio_bucket_key,
batch_size=batch_size,
sdxl_conditioning=sdxl_conditioning,
zero_dropped_captions=zero_dropped_captions,
Expand Down
49 changes: 48 additions & 1 deletion diffusion/datasets/laion/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Transforms for the training and eval dataset."""

import math
from typing import Optional, Tuple

import torch
Expand Down Expand Up @@ -45,7 +46,7 @@ def __call__(self, img):
return img, c_top, c_left


class RandomCropAspectRatioTransorm:
class RandomCropAspectRatioTransform:
"""Assigns an image to a arbitrary set of aspect ratio buckets, then resizes and crops to fit into the bucket.

Args:
Expand Down Expand Up @@ -111,3 +112,49 @@ def __call__(self, img):
c_top, c_left, height, width = transforms.RandomCrop.get_params(img, output_size=(target_height, target_width))
img = crop(img, c_top, c_left, height, width)
return img, c_top, c_left


class RandomCropBucketedAspectRatioTransform:
    """Resizes and randomly crops an image to fit a caller-specified aspect ratio bucket.

    This transform requires the desired aspect ratio to be specified manually in the call to the transform.
    The bucket whose aspect ratio (height / width) is closest, in log-space, to the requested one is used.

    Args:
        resize_size (Tuple[Tuple[int, int], ...]): A tuple of 2-tuple integers representing the aspect ratio buckets.
            The format is ((height_bucket1, width_bucket1), (height_bucket2, width_bucket2), ...).

    Raises:
        ValueError: If ``resize_size`` contains no buckets.
    """

    def __init__(
        self,
        resize_size: Tuple[Tuple[int, int], ...],
    ):
        # Fail at construction time rather than with an opaque argmin error on the first sample.
        if not resize_size:
            raise ValueError('resize_size must contain at least one (height, width) bucket.')
        self.height_buckets = torch.tensor([size[0] for size in resize_size])
        self.width_buckets = torch.tensor([size[1] for size in resize_size])
        # Integer tensors use true division here, so the ratios are floating point.
        self.aspect_ratio_buckets = self.height_buckets / self.width_buckets
        self.log_aspect_ratio_buckets = torch.log(self.aspect_ratio_buckets)

    def __call__(self, img, aspect_ratio):
        """Resize and randomly crop ``img`` into the bucket closest to ``aspect_ratio``.

        Args:
            img: Image whose ``size`` attribute unpacks as ``(width, height)``.
            aspect_ratio: Desired height / width ratio. Must be convertible to a positive, finite float.

        Returns:
            A tuple ``(img, c_top, c_left)``: the cropped image and the top/left offsets of the crop,
            measured in the resized image.

        Raises:
            ValueError: If ``aspect_ratio`` is not a positive, finite number.
        """
        # Validate up front: math.log would otherwise raise a bare "math domain error" for values <= 0,
        # and NaN/inf would silently pick an arbitrary bucket.
        ratio = float(aspect_ratio)
        if not math.isfinite(ratio) or ratio <= 0:
            raise ValueError(f'aspect_ratio must be a positive, finite number, got {aspect_ratio!r}')

        orig_w, orig_h = img.size
        orig_aspect_ratio = orig_h / orig_w
        # Figure out target H/W given the input aspect ratio (nearest bucket in log-space)
        bucket_ind = torch.abs(self.log_aspect_ratio_buckets - math.log(ratio)).argmin()
        target_width, target_height = self.width_buckets[bucket_ind].item(), self.height_buckets[bucket_ind].item()
        target_aspect_ratio = target_height / target_width

        # Determine resize size. The max() guards keep the resized image at least as large as the crop
        # even if floating point rounding lands just below the target dimension.
        if orig_aspect_ratio > target_aspect_ratio:
            # Resize width and crop height
            w_scale = target_width / orig_w
            resize_size = (max(target_height, round(w_scale * orig_h)), target_width)
        elif orig_aspect_ratio < target_aspect_ratio:
            # Resize height and crop width
            h_scale = target_height / orig_h
            resize_size = (target_height, max(target_width, round(h_scale * orig_w)))
        else:
            resize_size = (target_height, target_width)
        img = transforms.functional.resize(img, resize_size, antialias=True)

        # Crop based on aspect ratio
        c_top, c_left, height, width = transforms.RandomCrop.get_params(img, output_size=(target_height, target_width))
        img = crop(img, c_top, c_left, height, width)
        return img, c_top, c_left
Loading