Skip to content

Commit 014c8bf

Browse files
Refactor LCM to support more model types.
1 parent 9cad2f0 commit 014c8bf

File tree

1 file changed

+8
-38
lines changed

1 file changed

+8
-38
lines changed

comfy_extras/nodes_model_advanced.py

Lines changed: 8 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,41 +17,19 @@ def calculate_denoised(self, sigma, model_output, model_input):
1717

1818
return c_out * x0 + c_skip * model_input
1919

20-
class ModelSamplingDiscreteDistilled(torch.nn.Module):
20+
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
2121
original_timesteps = 50
2222

23-
def __init__(self):
24-
super().__init__()
25-
self.sigma_data = 1.0
26-
timesteps = 1000
27-
beta_start = 0.00085
28-
beta_end = 0.012
23+
def __init__(self, model_config=None):
24+
super().__init__(model_config)
2925

30-
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
31-
alphas = 1.0 - betas
32-
alphas_cumprod = torch.cumprod(alphas, dim=0)
26+
self.skip_steps = self.num_timesteps // self.original_timesteps
3327

34-
self.skip_steps = timesteps // self.original_timesteps
35-
36-
37-
alphas_cumprod_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
28+
sigmas_valid = torch.zeros((self.original_timesteps), dtype=torch.float32)
3829
for x in range(self.original_timesteps):
39-
alphas_cumprod_valid[self.original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
40-
41-
sigmas = ((1 - alphas_cumprod_valid) / alphas_cumprod_valid) ** 0.5
42-
self.set_sigmas(sigmas)
43-
44-
def set_sigmas(self, sigmas):
45-
self.register_buffer('sigmas', sigmas)
46-
self.register_buffer('log_sigmas', sigmas.log())
30+
sigmas_valid[self.original_timesteps - 1 - x] = self.sigmas[self.num_timesteps - 1 - x * self.skip_steps]
4731

48-
@property
49-
def sigma_min(self):
50-
return self.sigmas[0]
51-
52-
@property
53-
def sigma_max(self):
54-
return self.sigmas[-1]
32+
self.set_sigmas(sigmas_valid)
5533

5634
def timestep(self, sigma):
5735
log_sigma = sigma.log()
@@ -66,14 +44,6 @@ def sigma(self, timestep):
6644
log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
6745
return log_sigma.exp().to(timestep.device)
6846

69-
def percent_to_sigma(self, percent):
70-
if percent <= 0.0:
71-
return 999999999.9
72-
if percent >= 1.0:
73-
return 0.0
74-
percent = 1.0 - percent
75-
return self.sigma(torch.tensor(percent * 999.0)).item()
76-
7747

7848
def rescale_zero_terminal_snr_sigmas(sigmas):
7949
alphas_cumprod = 1 / ((sigmas * sigmas) + 1)
@@ -154,7 +124,7 @@ def patch(self, model, sampling, sigma_max, sigma_min):
154124
class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingContinuousEDM, sampling_type):
155125
pass
156126

157-
model_sampling = ModelSamplingAdvanced()
127+
model_sampling = ModelSamplingAdvanced(model.model.model_config)
158128
model_sampling.set_sigma_range(sigma_min, sigma_max)
159129
m.add_object_patch("model_sampling", model_sampling)
160130
return (m, )

0 commit comments

Comments
 (0)