1
1
import torch
2
2
from comfy .ldm .modules .diffusionmodules .openaimodel import UNetModel
3
3
from comfy .ldm .modules .encoders .noise_aug_modules import CLIPEmbeddingNoiseAugmentation
4
- from comfy .ldm .modules .diffusionmodules .util import make_beta_schedule
5
4
from comfy .ldm .modules .diffusionmodules .openaimodel import Timestep
6
5
import comfy .model_management
7
6
import comfy .conds
8
- import numpy as np
9
7
from enum import Enum
10
8
from . import utils
11
9
@@ -14,79 +12,7 @@ class ModelType(Enum):
14
12
V_PREDICTION = 2
15
13
16
14
17
- #NOTE: all this sampling stuff will be moved
18
- class EPS :
19
- def calculate_input (self , sigma , noise ):
20
- sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (noise .ndim - 1 ))
21
- return noise / (sigma ** 2 + self .sigma_data ** 2 ) ** 0.5
22
-
23
- def calculate_denoised (self , sigma , model_output , model_input ):
24
- sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
25
- return model_input - model_output * sigma
26
-
27
-
28
- class V_PREDICTION (EPS ):
29
- def calculate_denoised (self , sigma , model_output , model_input ):
30
- sigma = sigma .view (sigma .shape [:1 ] + (1 ,) * (model_output .ndim - 1 ))
31
- return model_input * self .sigma_data ** 2 / (sigma ** 2 + self .sigma_data ** 2 ) - model_output * sigma * self .sigma_data / (sigma ** 2 + self .sigma_data ** 2 ) ** 0.5
32
-
33
-
34
- class ModelSamplingDiscrete (torch .nn .Module ):
35
- def __init__ (self , model_config = None ):
36
- super ().__init__ ()
37
- beta_schedule = "linear"
38
- if model_config is not None :
39
- beta_schedule = model_config .beta_schedule
40
- self ._register_schedule (given_betas = None , beta_schedule = beta_schedule , timesteps = 1000 , linear_start = 0.00085 , linear_end = 0.012 , cosine_s = 8e-3 )
41
- self .sigma_data = 1.0
42
-
43
- def _register_schedule (self , given_betas = None , beta_schedule = "linear" , timesteps = 1000 ,
44
- linear_start = 1e-4 , linear_end = 2e-2 , cosine_s = 8e-3 ):
45
- if given_betas is not None :
46
- betas = given_betas
47
- else :
48
- betas = make_beta_schedule (beta_schedule , timesteps , linear_start = linear_start , linear_end = linear_end , cosine_s = cosine_s )
49
- alphas = 1. - betas
50
- alphas_cumprod = torch .tensor (np .cumprod (alphas , axis = 0 ), dtype = torch .float32 )
51
- # alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
52
-
53
- timesteps , = betas .shape
54
- self .num_timesteps = int (timesteps )
55
- self .linear_start = linear_start
56
- self .linear_end = linear_end
57
-
58
- # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
59
- # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
60
- # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
61
-
62
- sigmas = ((1 - alphas_cumprod ) / alphas_cumprod ) ** 0.5
63
-
64
- self .register_buffer ('sigmas' , sigmas )
65
- self .register_buffer ('log_sigmas' , sigmas .log ())
66
-
67
- @property
68
- def sigma_min (self ):
69
- return self .sigmas [0 ]
70
-
71
- @property
72
- def sigma_max (self ):
73
- return self .sigmas [- 1 ]
74
-
75
- def timestep (self , sigma ):
76
- log_sigma = sigma .log ()
77
- dists = log_sigma .to (self .log_sigmas .device ) - self .log_sigmas [:, None ]
78
- return dists .abs ().argmin (dim = 0 ).view (sigma .shape )
79
-
80
- def sigma (self , timestep ):
81
- t = torch .clamp (timestep .float (), min = 0 , max = (len (self .sigmas ) - 1 ))
82
- low_idx = t .floor ().long ()
83
- high_idx = t .ceil ().long ()
84
- w = t .frac ()
85
- log_sigma = (1 - w ) * self .log_sigmas [low_idx ] + w * self .log_sigmas [high_idx ]
86
- return log_sigma .exp ()
87
-
88
- def percent_to_sigma (self , percent ):
89
- return self .sigma (torch .tensor (percent * 999.0 ))
15
+ from comfy .model_sampling import EPS , V_PREDICTION , ModelSamplingDiscrete
90
16
91
17
def model_sampling (model_config , model_type ):
92
18
if model_type == ModelType .EPS :
@@ -102,7 +28,6 @@ class ModelSampling(s, c):
102
28
return ModelSampling (model_config )
103
29
104
30
105
-
106
31
class BaseModel (torch .nn .Module ):
107
32
def __init__ (self , model_config , model_type = ModelType .EPS , device = None ):
108
33
super ().__init__ ()
0 commit comments