Skip to content

Commit

Permalink
Add option to shift noise schedules when changing resolution (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Jul 8, 2024
1 parent 273cd43 commit 0b79104
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 1 deletion.
21 changes: 21 additions & 0 deletions diffusion/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from diffusion.models.stable_diffusion import StableDiffusion
from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
from diffusion.schedulers.schedulers import ContinuousTimeScheduler
from diffusion.schedulers.utils import shift_noise_schedule

try:
import xformers # type: ignore
Expand Down Expand Up @@ -51,6 +52,7 @@ def stable_diffusion_2(
beta_schedule: str = 'scaled_linear',
zero_terminal_snr: bool = False,
offset_noise: Optional[float] = None,
scheduler_shift_resolution: int = 256,
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
quasirandomness: bool = False,
Expand Down Expand Up @@ -97,6 +99,7 @@ def stable_diffusion_2(
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
be used. Default `None`.
scheduler_shift_resolution (int): The resolution to shift the noise scheduler to. Default: `256`.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
mask_pad_tokens (bool): Whether to mask pad tokens in cross attention. Defaults to False.
fsdp (bool): Whether to use FSDP. Defaults to True.
Expand Down Expand Up @@ -178,6 +181,14 @@ def stable_diffusion_2(
set_alpha_to_one=False,
prediction_type=prediction_type)

# Shift noise scheduler to correct for resolution changes
noise_scheduler = shift_noise_schedule(noise_scheduler,
base_dim=32,
shift_dim=scheduler_shift_resolution // downsample_factor)
inference_noise_scheduler = shift_noise_schedule(inference_noise_scheduler,
base_dim=32,
shift_dim=scheduler_shift_resolution // downsample_factor)

# Make the composer model
model = StableDiffusion(
unet=unet,
Expand Down Expand Up @@ -236,6 +247,7 @@ def stable_diffusion_xl(
zero_terminal_snr: bool = False,
use_karras_sigmas: bool = False,
offset_noise: Optional[float] = None,
scheduler_shift_resolution: int = 256,
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
quasirandomness: bool = False,
Expand Down Expand Up @@ -283,6 +295,7 @@ def stable_diffusion_xl(
use_karras_sigmas (bool): Whether to use the Karras sigmas for the diffusion process noise. Default: `False`.
offset_noise (float, optional): The scale of the offset noise. If not specified, offset noise will not
be used. Default `None`.
scheduler_shift_resolution (int): The resolution to shift the noise scheduler to. Default: `256`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
[MeanSquaredError()].
val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
Expand Down Expand Up @@ -433,6 +446,14 @@ def stable_diffusion_xl(
steps_offset=1,
rescale_betas_zero_snr=zero_terminal_snr)

# Shift noise scheduler to correct for resolution changes
noise_scheduler = shift_noise_schedule(noise_scheduler,
base_dim=32,
shift_dim=scheduler_shift_resolution // downsample_factor)
inference_noise_scheduler = shift_noise_schedule(inference_noise_scheduler,
base_dim=32,
shift_dim=scheduler_shift_resolution // downsample_factor)

# Make the composer model
model = StableDiffusion(
unet=unet,
Expand Down
3 changes: 2 additions & 1 deletion diffusion/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
"""Diffusion schedulers."""

from diffusion.schedulers.schedulers import ContinuousTimeScheduler
from diffusion.schedulers.utils import shift_noise_schedule

__all__ = ['ContinuousTimeScheduler']
__all__ = ['ContinuousTimeScheduler', 'shift_noise_schedule']
39 changes: 39 additions & 0 deletions diffusion/schedulers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Utils for working with diffusion schedulers."""

import torch


def shift_noise_schedule(noise_scheduler, base_dim: int = 64, shift_dim: int = 64):
    """Shifts the function SNR(t) for a noise scheduler to correct for resolution changes.

    Implements the technique from https://arxiv.org/abs/2301.11093: the signal-to-noise ratio
    implied by the scheduler's ``alphas_cumprod`` is multiplied by ``(base_dim / shift_dim)**2``
    and ``alphas_cumprod``, ``alphas`` and ``betas`` are rebuilt from the rescaled SNR.

    The scheduler is modified in place and also returned. Only the three attributes named above
    are overwritten.

    Args:
        noise_scheduler (diffusers.SchedulerMixin): The noise scheduler to shift. It must expose
            a 1-D tensor attribute ``alphas_cumprod``.
        base_dim (int): The base side length of the schedule resolution. Must be positive.
        shift_dim (int): The new side length of the schedule resolution. Must be positive.

    Returns:
        diffusers.SchedulerMixin: The shifted noise scheduler (the same object that was passed in).

    Raises:
        ValueError: If ``base_dim`` or ``shift_dim`` is not strictly positive.
    """
    # A zero shift_dim would divide by zero below and a zero base_dim would silently turn the
    # schedule into NaNs (0 / 0), so reject both up front with a clear message.
    if base_dim <= 0 or shift_dim <= 0:
        raise ValueError(f'base_dim and shift_dim must be positive, got base_dim={base_dim}, shift_dim={shift_dim}.')

    # First, we need to get the original SNR(t) function
    alpha_bar = noise_scheduler.alphas_cumprod
    snr = alpha_bar / (1 - alpha_bar)
    # Shift the SNR according to the resolution change
    snr_shifted = (base_dim / shift_dim)**2 * snr
    # Get the new alpha_bars. Where alpha_bar == 1 the SNR is infinite and SNR / (1 + SNR) would
    # evaluate to NaN, so those entries are pinned to 1. ones_like keeps the schedule's dtype and
    # device instead of introducing a default-dtype CPU scalar.
    alpha_bar_shifted = torch.where(
        snr_shifted == float('inf'),
        torch.ones_like(snr_shifted),
        snr_shifted / (1 + snr_shifted),
    )
    # Get the new alpha values: alpha_bar is the cumulative product of alpha, so each alpha is the
    # ratio of consecutive alpha_bars (the first one is alpha_bar[0] itself).
    alpha_shifted = torch.empty_like(alpha_bar_shifted)
    alpha_shifted[0] = alpha_bar_shifted[0]
    alpha_shifted[1:] = alpha_bar_shifted[1:] / alpha_bar_shifted[:-1]
    # Get the new beta values
    beta_shifted = 1 - alpha_shifted
    # Update the noise scheduler.
    # NOTE(review): any other state a scheduler may have derived from the old schedule when it
    # was constructed is not recomputed here -- confirm against the scheduler classes in use.
    noise_scheduler.alphas = alpha_shifted
    noise_scheduler.betas = beta_shifted
    noise_scheduler.alphas_cumprod = alpha_bar_shifted
    return noise_scheduler

0 comments on commit 0b79104

Please sign in to comment.