Skip to content

Commit

Permalink
Add option for pre-bucketed aspect ratio buckets
Browse files (browse the repository at this point in the history)
  • Loading branch information
corystephenson-db committed Sep 27, 2024
1 parent 40ecb59 commit 479fe54
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions diffusion/datasets/image_caption_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
RandomCropBucketedAspectRatioTransform, RandomCropSquare)
from diffusion.datasets.utils import make_streams

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -172,6 +173,7 @@ def build_streaming_image_caption_latents_dataloader(
text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: str = 'torch.bfloat16',
aspect_ratio_bucket_key: Optional[str] = None,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
Expand All @@ -190,11 +192,12 @@ def build_streaming_image_caption_latents_dataloader(
``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
Default: ``None``.
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
Default: ``'square'``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
Expand All @@ -204,18 +207,22 @@ def build_streaming_image_caption_latents_dataloader(
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset.
Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
# Check crop type
if crop_type is not None:
crop_type = crop_type.lower()
if crop_type not in ['square', 'random', 'aspect_ratio']:
raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
raise ValueError(
'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')

f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
)
if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
isinstance(resize_size[0], int)):
raise ValueError(
'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
# Check latent dtype
dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
Expand All @@ -237,6 +244,9 @@ def build_streaming_image_caption_latents_dataloader(
crop = RandomCropSquare(resize_size)
elif crop_type == 'aspect_ratio':
crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore
elif crop_type == 'bucketed_aspect_ratio':
assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore
else:
crop = None

Expand Down

0 comments on commit 479fe54

Please sign in to comment.