Skip to content

Commit

Permalink
Add kl_optimal scheduler (#6206)
Browse files Browse the repository at this point in the history
* Add kl_optimal scheduler

* Rename kl_optimal_schedule to kl_optimal_scheduler to be more consistent
  • Loading branch information
blepping authored Dec 30, 2024
1 parent d9b7cfa commit a90aafa
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion comfy/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,13 @@ def linear_quadratic_schedule(model_sampling, steps, threshold_noise=0.025, line
sigma_schedule = [1.0 - x for x in sigma_schedule]
return torch.FloatTensor(sigma_schedule) * model_sampling.sigma_max.cpu()

# Referenced from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
def kl_optimal_scheduler(n: int, sigma_min: float, sigma_max: float) -> torch.Tensor:
adj_idxs = torch.arange(n, dtype=torch.float).div_(n - 1)
sigmas = adj_idxs.new_zeros(n + 1)
sigmas[:-1] = (adj_idxs * math.atan(sigma_min) + (1 - adj_idxs) * math.atan(sigma_max)).tan_()
return sigmas

def get_mask_aabb(masks):
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device, dtype=torch.int)
Expand Down Expand Up @@ -913,7 +920,7 @@ def sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)


SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic"]
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "beta", "linear_quadratic", "kl_optimal"]
SAMPLER_NAMES = KSAMPLER_NAMES + ["ddim", "uni_pc", "uni_pc_bh2"]

def calculate_sigmas(model_sampling, scheduler_name, steps):
Expand All @@ -933,6 +940,8 @@ def calculate_sigmas(model_sampling, scheduler_name, steps):
sigmas = beta_scheduler(model_sampling, steps)
elif scheduler_name == "linear_quadratic":
sigmas = linear_quadratic_schedule(model_sampling, steps)
elif scheduler_name == "kl_optimal":
sigmas = kl_optimal_scheduler(n=steps, sigma_min=float(model_sampling.sigma_min), sigma_max=float(model_sampling.sigma_max))
else:
logging.error("error invalid scheduler {}".format(scheduler_name))
return sigmas
Expand Down

0 comments on commit a90aafa

Please sign in to comment.