@@ -17,41 +17,19 @@ def calculate_denoised(self, sigma, model_output, model_input):
17
17
18
18
return c_out * x0 + c_skip * model_input
19
19
20
- class ModelSamplingDiscreteDistilled (torch . nn . Module ):
20
+ class ModelSamplingDiscreteDistilled (comfy . model_sampling . ModelSamplingDiscrete ):
21
21
original_timesteps = 50
22
22
23
- def __init__ (self ):
24
- super ().__init__ ()
25
- self .sigma_data = 1.0
26
- timesteps = 1000
27
- beta_start = 0.00085
28
- beta_end = 0.012
23
+ def __init__ (self , model_config = None ):
24
+ super ().__init__ (model_config )
29
25
30
- betas = torch .linspace (beta_start ** 0.5 , beta_end ** 0.5 , timesteps , dtype = torch .float32 ) ** 2
31
- alphas = 1.0 - betas
32
- alphas_cumprod = torch .cumprod (alphas , dim = 0 )
26
+ self .skip_steps = self .num_timesteps // self .original_timesteps
33
27
34
- self .skip_steps = timesteps // self .original_timesteps
35
-
36
-
37
- alphas_cumprod_valid = torch .zeros ((self .original_timesteps ), dtype = torch .float32 )
28
+ sigmas_valid = torch .zeros ((self .original_timesteps ), dtype = torch .float32 )
38
29
for x in range (self .original_timesteps ):
39
- alphas_cumprod_valid [self .original_timesteps - 1 - x ] = alphas_cumprod [timesteps - 1 - x * self .skip_steps ]
40
-
41
- sigmas = ((1 - alphas_cumprod_valid ) / alphas_cumprod_valid ) ** 0.5
42
- self .set_sigmas (sigmas )
43
-
44
- def set_sigmas (self , sigmas ):
45
- self .register_buffer ('sigmas' , sigmas )
46
- self .register_buffer ('log_sigmas' , sigmas .log ())
30
+ sigmas_valid [self .original_timesteps - 1 - x ] = self .sigmas [self .num_timesteps - 1 - x * self .skip_steps ]
47
31
48
- @property
49
- def sigma_min (self ):
50
- return self .sigmas [0 ]
51
-
52
- @property
53
- def sigma_max (self ):
54
- return self .sigmas [- 1 ]
32
+ self .set_sigmas (sigmas_valid )
55
33
56
34
def timestep (self , sigma ):
57
35
log_sigma = sigma .log ()
@@ -66,14 +44,6 @@ def sigma(self, timestep):
66
44
log_sigma = (1 - w ) * self .log_sigmas [low_idx ] + w * self .log_sigmas [high_idx ]
67
45
return log_sigma .exp ().to (timestep .device )
68
46
69
- def percent_to_sigma (self , percent ):
70
- if percent <= 0.0 :
71
- return 999999999.9
72
- if percent >= 1.0 :
73
- return 0.0
74
- percent = 1.0 - percent
75
- return self .sigma (torch .tensor (percent * 999.0 )).item ()
76
-
77
47
78
48
def rescale_zero_terminal_snr_sigmas (sigmas ):
79
49
alphas_cumprod = 1 / ((sigmas * sigmas ) + 1 )
@@ -154,7 +124,7 @@ def patch(self, model, sampling, sigma_max, sigma_min):
154
124
class ModelSamplingAdvanced (comfy .model_sampling .ModelSamplingContinuousEDM , sampling_type ):
155
125
pass
156
126
157
- model_sampling = ModelSamplingAdvanced ()
127
+ model_sampling = ModelSamplingAdvanced (model . model . model_config )
158
128
model_sampling .set_sigma_range (sigma_min , sigma_max )
159
129
m .add_object_patch ("model_sampling" , model_sampling )
160
130
return (m , )
0 commit comments