Skip to content

Commit 1ffa885

Browse files
Move model sampling code to comfy/model_sampling.py
1 parent ae2acfc commit 1ffa885

File tree

2 files changed

+79
-76
lines changed

2 files changed

+79
-76
lines changed

comfy/model_base.py

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import torch
22
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel
33
from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
4-
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
54
from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
65
import comfy.model_management
76
import comfy.conds
8-
import numpy as np
97
from enum import Enum
108
from . import utils
119

@@ -14,79 +12,7 @@ class ModelType(Enum):
1412
V_PREDICTION = 2
1513

1614

17-
#NOTE: all this sampling stuff will be moved
18-
class EPS:
19-
def calculate_input(self, sigma, noise):
20-
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
21-
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
22-
23-
def calculate_denoised(self, sigma, model_output, model_input):
24-
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
25-
return model_input - model_output * sigma
26-
27-
28-
class V_PREDICTION(EPS):
29-
def calculate_denoised(self, sigma, model_output, model_input):
30-
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
31-
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
32-
33-
34-
class ModelSamplingDiscrete(torch.nn.Module):
35-
def __init__(self, model_config=None):
36-
super().__init__()
37-
beta_schedule = "linear"
38-
if model_config is not None:
39-
beta_schedule = model_config.beta_schedule
40-
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
41-
self.sigma_data = 1.0
42-
43-
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
44-
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
45-
if given_betas is not None:
46-
betas = given_betas
47-
else:
48-
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
49-
alphas = 1. - betas
50-
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
51-
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
52-
53-
timesteps, = betas.shape
54-
self.num_timesteps = int(timesteps)
55-
self.linear_start = linear_start
56-
self.linear_end = linear_end
57-
58-
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
59-
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
60-
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
61-
62-
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
63-
64-
self.register_buffer('sigmas', sigmas)
65-
self.register_buffer('log_sigmas', sigmas.log())
66-
67-
@property
68-
def sigma_min(self):
69-
return self.sigmas[0]
70-
71-
@property
72-
def sigma_max(self):
73-
return self.sigmas[-1]
74-
75-
def timestep(self, sigma):
76-
log_sigma = sigma.log()
77-
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
78-
return dists.abs().argmin(dim=0).view(sigma.shape)
79-
80-
def sigma(self, timestep):
81-
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
82-
low_idx = t.floor().long()
83-
high_idx = t.ceil().long()
84-
w = t.frac()
85-
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
86-
return log_sigma.exp()
87-
88-
def percent_to_sigma(self, percent):
89-
return self.sigma(torch.tensor(percent * 999.0))
15+
from comfy.model_sampling import EPS, V_PREDICTION, ModelSamplingDiscrete
9016

9117
def model_sampling(model_config, model_type):
9218
if model_type == ModelType.EPS:
@@ -102,7 +28,6 @@ class ModelSampling(s, c):
10228
return ModelSampling(model_config)
10329

10430

105-
10631
class BaseModel(torch.nn.Module):
10732
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
10833
super().__init__()

comfy/model_sampling.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import torch
2+
import numpy as np
3+
from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule
4+
5+
6+
class EPS:
7+
def calculate_input(self, sigma, noise):
8+
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
9+
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
10+
11+
def calculate_denoised(self, sigma, model_output, model_input):
12+
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
13+
return model_input - model_output * sigma
14+
15+
16+
class V_PREDICTION(EPS):
17+
def calculate_denoised(self, sigma, model_output, model_input):
18+
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
19+
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
20+
21+
22+
class ModelSamplingDiscrete(torch.nn.Module):
23+
def __init__(self, model_config=None):
24+
super().__init__()
25+
beta_schedule = "linear"
26+
if model_config is not None:
27+
beta_schedule = model_config.beta_schedule
28+
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
29+
self.sigma_data = 1.0
30+
31+
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
32+
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
33+
if given_betas is not None:
34+
betas = given_betas
35+
else:
36+
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
37+
alphas = 1. - betas
38+
alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
39+
# alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
40+
41+
timesteps, = betas.shape
42+
self.num_timesteps = int(timesteps)
43+
self.linear_start = linear_start
44+
self.linear_end = linear_end
45+
46+
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
47+
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
48+
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
49+
50+
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
51+
52+
self.register_buffer('sigmas', sigmas)
53+
self.register_buffer('log_sigmas', sigmas.log())
54+
55+
@property
56+
def sigma_min(self):
57+
return self.sigmas[0]
58+
59+
@property
60+
def sigma_max(self):
61+
return self.sigmas[-1]
62+
63+
def timestep(self, sigma):
64+
log_sigma = sigma.log()
65+
dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
66+
return dists.abs().argmin(dim=0).view(sigma.shape)
67+
68+
def sigma(self, timestep):
69+
t = torch.clamp(timestep.float(), min=0, max=(len(self.sigmas) - 1))
70+
low_idx = t.floor().long()
71+
high_idx = t.ceil().long()
72+
w = t.frac()
73+
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
74+
return log_sigma.exp()
75+
76+
def percent_to_sigma(self, percent):
77+
return self.sigma(torch.tensor(percent * 999.0))
78+

0 commit comments

Comments
 (0)