diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py index dd2cbcc8..18fb2ffa 100644 --- a/diffusion/datasets/image_caption_latents.py +++ b/diffusion/datasets/image_caption_latents.py @@ -33,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset): image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``. caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``. + aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``. text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset. Default: ``('T5_LATENTS', 'CLIP_LATENTS')``. text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset. @@ -55,6 +56,7 @@ def __init__( image_key: str = 'image', caption_keys: Tuple[str, ...] = ('caption',), caption_selection_probs: Tuple[float, ...] = (1.0,), + aspect_ratio_bucket_key: Optional[str] = None, text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'), text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)), attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), @@ -75,6 +77,9 @@ def __init__( self.image_key = image_key self.caption_keys = caption_keys self.caption_selection_probs = caption_selection_probs + self.aspect_ratio_bucket_key = aspect_ratio_bucket_key + if isinstance(self.crop, RandomCropBucketedAspectRatioTransform): + assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform' self.text_latent_keys = text_latent_keys self.text_latent_shapes = text_latent_shapes self.attention_mask_keys = attention_mask_keys @@ -94,15 +99,16 @@ def __getitem__(self, index): out['cond_original_size'] = torch.tensor(img.size) # Image transforms - if self.crop is not None: + if isinstance(self.crop, RandomCropBucketedAspectRatioTransform): + img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key]) + elif self.crop is not None: img, crop_top, crop_left = self.crop(img) else: crop_top, crop_left = 0, 0 - out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) - if self.transform is not None: img = self.transform(img) out['image'] = img + out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) # Get the new height and width if isinstance(img, torch.Tensor): @@ -264,6 +270,7 @@ def build_streaming_image_caption_latents_dataloader( image_key=image_key, caption_keys=caption_keys, caption_selection_probs=caption_selection_probs, + aspect_ratio_bucket_key=aspect_ratio_bucket_key, text_latent_keys=text_latent_keys, text_latent_shapes=text_latent_shapes, attention_mask_keys=attention_mask_keys,