Skip to content

Commit

Permalink
Add docs and *style*
Browse files Browse the repository at this point in the history
  • Loading branch information
Landanjs committed Jun 7, 2024
1 parent 7d398bc commit c4055f0
Showing 1 changed file with 70 additions and 23 deletions.
93 changes: 70 additions & 23 deletions diffusion/datasets/text_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,26 @@

log = logging.getLogger(__name__)


class StreamingTextLatentsDataset(StreamingDataset):
"""Streaming dataset for image-caption datasets with pre-computed text latents.
Args:
streams (Sequence[Stream]): One or more Streams to stream/cache samples from.
caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
microcond_drop_prob (float): The probability of dropping microconditioning. Only relevant for SDXL. Default: ``0.0``.
crop (Callable, optional): The crop transform to apply to the image before ``transform``. Default: ``None``
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
Each shape is a 2-tuple where the first dim is the sequence length and the second dim is the feature size.
Default: ``((512, 4096), (77, 768))``.
attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
**streaming_kwargs: Additional arguments to pass in the construction of the ``StreamingDataset``.
"""

def __init__(
self,
Expand All @@ -30,11 +49,11 @@ def __init__(
transform: Optional[Callable] = None,
image_key: str = 'image',
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
**streaming_kwargs,
):

# Set defaults for vision-friendly streaming args.
streaming_kwargs.setdefault('shuffle_block_size', 1 << 18)
streaming_kwargs.setdefault('shuffle_algo', 'py1s')
Expand Down Expand Up @@ -91,7 +110,7 @@ def __getitem__(self, index):
out['cond_target_size'] = out['cond_target_size'] * 0

# Load text latents, attention masks, and clip pooled embeddings
for i in range(len(self.text_latent_keys)):
for i in range(len(self.text_latent_keys)):
latent_key = self.text_latent_keys[i]
latent_shape = self.text_latent_shapes[i]
attention_key = self.attention_mask_keys[i]
Expand All @@ -111,26 +130,53 @@ def __getitem__(self, index):
clip_pooled = np.frombuffer(sample['CLIP_POOLED_TEXT'], dtype=np.float16).copy()
out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).reshape(latent_shape[1])
return out


def build_streaming_text_latents_dataloader(
remote: Union[str, List],
batch_size: int,
local: Optional[Union[str, List]] = None,
caption_drop_prob: float = 0.0,
microcond_drop_prob: float = 0.0,
resize_size: Union[int, Tuple[int, int], Tuple[Tuple[int, int], ...]] = 256,
ar_bucket_boundaries: Optional[Tuple[float, ...]] = None,
transform: Optional[List[Callable]] = None,
crop_type: Optional[str] = 'square',
image_key: str = 'image',
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
remote: Union[str, List],
batch_size: int,
local: Optional[Union[str, List]] = None,
caption_drop_prob: float = 0.0,
microcond_drop_prob: float = 0.0,
resize_size: Union[int, Tuple[int, int], Tuple[Tuple[int, int], ...]] = 256,
ar_bucket_boundaries: Optional[Tuple[float, ...]] = None,
transform: Optional[List[Callable]] = None,
crop_type: Optional[str] = 'square',
image_key: str = 'image',
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):

"""Builds a streaming dataloader for image-caption pairs with pre-computed text latents.
Args:
remote (str, Sequence[str]): One or more remote directories (S3 or local filesystem) where dataset is stored.
batch_size (int): The batch size to use for both the ``StreamingDataset`` and ``DataLoader``.
local (str, Sequence[str], optional): One or more local filesystem directories where dataset is cached during operation.
caption_drop_prob (float): The probability of dropping a caption. Default: ``0.0``.
microcond_drop_prob (float): The probability of dropping microconditioning. Default: ``0.0``.
resize_size (int, Tuple[int, int], Tuple[Tuple[int, int], ...]): The size to resize the image to. Specify a
tuple of tuples if using 'aspect_ratio' crop_type. Default: ``256``.
ar_bucket_boundaries (Tuple[float, ...], optional): When using ``crop_type='aspect_ratio'``, specifies the
boundary points for bucket assignment. This tuple should be of length len(resize_size) - 1. If set to
``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
Default: ``None``.
transform (List[Callable], optional): The transforms to apply to the image. Default: ``None``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
Default: ``'square'``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
Each shape is a 2-tuple where the first dim is the sequence length and the second dim is the feature size.
Default: ``((512, 4096), (77, 768))``.
attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
# Check crop type
if crop_type is not None:
crop_type = crop_type.lower()
Expand All @@ -139,13 +185,13 @@ def build_streaming_text_latents_dataloader(
if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
raise ValueError(
'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')

# Handle ``None`` kwargs
if streaming_kwargs is None:
streaming_kwargs = {}
if dataloader_kwargs is None:
dataloader_kwargs = {}

# Check types for remote and local

if isinstance(remote, str):
Expand Down Expand Up @@ -203,5 +249,6 @@ def build_streaming_text_latents_dataloader(

return dataloader


def _make_default_local_path(remote_path):
return str(Path(*['/tmp'] + list(Path(remote_path).parts[1:])))
return str(Path(*['/tmp'] + list(Path(remote_path).parts[1:])))

0 comments on commit c4055f0

Please sign in to comment.