Skip to content

Commit

Permalink
Expose option to set per-stream weighting in image and image_caption …
Browse files Browse the repository at this point in the history
…datasets (#156)
  • Loading branch information
coryMosaicML authored Jul 10, 2024
1 parent 0b79104 commit c64e18c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 39 deletions.
25 changes: 10 additions & 15 deletions diffusion/datasets/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.utils import make_streams

log = logging.getLogger(__name__)

# Disable PIL max image size limit
Expand Down Expand Up @@ -93,6 +95,9 @@ def build_streaming_image_dataloader(
transform: Optional[List[Callable]] = None,
image_key: str = 'image',
image_output_key: Optional[str] = 'image',
proportion: Optional[list] = None,
repeat: Optional[list] = None,
choose: Optional[list] = None,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
Expand All @@ -106,6 +111,9 @@ def build_streaming_image_dataloader(
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
image_output_key (optional, str): Optional output key for the image. If none, the value of `image_key` will
be used. Default: ``image``.
proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``.
repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``.
choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
Expand All @@ -115,21 +123,8 @@ def build_streaming_image_dataloader(
if dataloader_kwargs is None:
dataloader_kwargs = {}

# Check types for remote and local
if isinstance(remote, str) and isinstance(local, str):
remote, local = [remote], [local]
elif isinstance(remote, Sequence) and isinstance(local, Sequence):
if len(remote) != len(local):
raise ValueError(
f'remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}')
else:
raise ValueError(
f'remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.')

# Create a Stream for each (remote, local) pair
streams = []
for r, l in zip(remote, local):
streams.append(Stream(remote=r, local=l))
# Set up streams
streams = make_streams(remote, local=local, proportion=proportion, repeat=repeat, choose=choose)

if transform is None:
transform = [transforms.ToTensor()]
Expand Down
33 changes: 9 additions & 24 deletions diffusion/datasets/image_caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
import random
from io import BytesIO
from pathlib import Path
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
Expand All @@ -17,6 +16,7 @@
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransorm, RandomCropSquare
from diffusion.datasets.utils import make_streams
from diffusion.models.text_encoder import MultiTokenizer

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -182,6 +182,9 @@ def build_streaming_image_caption_dataloader(
crop_type: Optional[str] = 'square',
zero_dropped_captions: bool = True,
sdxl_conditioning: bool = False,
proportion: Optional[list] = None,
repeat: Optional[list] = None,
choose: Optional[list] = None,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
Expand Down Expand Up @@ -213,6 +216,9 @@ def build_streaming_image_caption_dataloader(
Default: ``'square'``.
zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``.
sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`.
proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``.
repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``.
choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
Expand All @@ -231,25 +237,8 @@ def build_streaming_image_caption_dataloader(
if dataloader_kwargs is None:
dataloader_kwargs = {}

# Check types for remote and local

if isinstance(remote, str):
remote = [remote]
if isinstance(local, str):
local = [local]
if not local:
local = [_make_default_local_path(r) for r in remote]
if isinstance(remote, Sequence) and isinstance(local, Sequence):
if len(remote) != len(local):
ValueError(
f'remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}')
else:
ValueError(f'remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.')

# Create a Stream for each (remote, local) pair
streams = []
for r, l in zip(remote, local):
streams.append(Stream(remote=r, local=l))
# Set up streams
streams = make_streams(remote, local=local, proportion=proportion, repeat=repeat, choose=choose)

# Set the crop to apply
if crop_type == 'square':
Expand Down Expand Up @@ -290,7 +279,3 @@ def build_streaming_image_caption_dataloader(
)

return dataloader


def _make_default_local_path(remote_path):
return str(Path(*['/tmp'] + list(Path(remote_path).parts[1:])))
64 changes: 64 additions & 0 deletions diffusion/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Useful functions for dealing with streaming datasets."""

from pathlib import Path
from typing import Sequence

from streaming import Stream


def make_streams(remote, local=None, proportion=None, repeat=None, choose=None):
    """Build one ``Stream`` per remote path, attaching per-stream sampling weights.

    Args:
        remote (Union[str, Sequence[str]]): The remote path or paths to stream from.
        local (Union[str, Sequence[str]], optional): The local path or paths to cache the data. If not provided,
            local paths are filled in by ``_make_remote_and_local_sequences``. Default: ``None``.
        proportion (list, optional): One entry per remote, passed as the ``proportion`` argument of the
            corresponding ``Stream``. Default: ``None``.
        repeat (list, optional): One entry per remote, passed as the ``repeat`` argument of the corresponding
            ``Stream``. Default: ``None``.
        choose (list, optional): One entry per remote, passed as the ``choose`` argument of the corresponding
            ``Stream``. Default: ``None``.

    Returns:
        List[Stream]: A list of Stream objects, one per (remote, local) pair.
    """
    remotes, locals_ = _make_remote_and_local_sequences(remote, local)
    proportions, repeats, chooses = _make_weighting_sequences(remotes, proportion, repeat, choose)

    # Weights are looked up by position so that each stream gets the entry matching its remote.
    return [
        Stream(
            remote=remote_path,
            local=local_path,
            proportion=proportions[idx],
            repeat=repeats[idx],
            choose=chooses[idx],
        ) for idx, (remote_path, local_path) in enumerate(zip(remotes, locals_))
    ]


def _make_remote_and_local_sequences(remote, local=None):
if isinstance(remote, str):
remote = [remote]
if isinstance(local, str):
local = [local]
if not local:
local = [_make_default_local_path(r) for r in remote]

if isinstance(remote, Sequence) and isinstance(local, Sequence):
if len(remote) != len(local):
ValueError(
f'remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}')
else:
ValueError(f'remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.')
return remote, local


def _make_default_local_path(remote_path):
return str(Path(*['/tmp'] + list(Path(remote_path).parts[1:])))


def _make_weighting_sequences(remote, proportion=None, repeat=None, choose=None):
weights = {'proportion': proportion, 'repeat': repeat, 'choose': choose}
for name, weight in weights.items():
if weight is not None and len(remote) != len(weight):
ValueError(f'{name} must be the same length as remote, got lengths {len(remote)} and {len(weight)}')
proportion = weights['proportion'] if weights['proportion'] is not None else [None] * len(remote)
repeat = weights['repeat'] if weights['repeat'] is not None else [None] * len(remote)
choose = weights['choose'] if weights['choose'] is not None else [None] * len(remote)
return proportion, repeat, choose

0 comments on commit c64e18c

Please sign in to comment.