Skip to content

Commit

Permalink
Fix some missing keys
Browse files Browse the repository at this point in the history
  • Loading branch information
corystephenson-db committed Sep 28, 2024
1 parent 479fe54 commit e7fcb59
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions diffusion/datasets/image_caption_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
Expand All @@ -55,6 +56,7 @@ def __init__(
image_key: str = 'image',
caption_keys: Tuple[str, ...] = ('caption',),
caption_selection_probs: Tuple[float, ...] = (1.0,),
aspect_ratio_bucket_key: Optional[str] = None,
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
Expand All @@ -75,6 +77,9 @@ def __init__(
self.image_key = image_key
self.caption_keys = caption_keys
self.caption_selection_probs = caption_selection_probs
self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
self.text_latent_keys = text_latent_keys
self.text_latent_shapes = text_latent_shapes
self.attention_mask_keys = attention_mask_keys
Expand All @@ -94,15 +99,16 @@ def __getitem__(self, index):
out['cond_original_size'] = torch.tensor(img.size)

# Image transforms
if self.crop is not None:
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
elif self.crop is not None:
img, crop_top, crop_left = self.crop(img)
else:
crop_top, crop_left = 0, 0
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

if self.transform is not None:
img = self.transform(img)
out['image'] = img
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

# Get the new height and width
if isinstance(img, torch.Tensor):
Expand Down Expand Up @@ -264,6 +270,7 @@ def build_streaming_image_caption_latents_dataloader(
image_key=image_key,
caption_keys=caption_keys,
caption_selection_probs=caption_selection_probs,
aspect_ratio_bucket_key=aspect_ratio_bucket_key,
text_latent_keys=text_latent_keys,
text_latent_shapes=text_latent_shapes,
attention_mask_keys=attention_mask_keys,
Expand Down

0 comments on commit e7fcb59

Please sign in to comment.