Skip to content

Commit 2ccdac8

Browse files
committed
Merge branch 'master' into beta
2 parents 5585614 + 97015b6 commit 2ccdac8

File tree

11 files changed: 47 additions and 36 deletions.

comfy/clip_vision.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,10 +54,10 @@ def encode_image(self, image):
5454
t = outputs[k]
5555
if t is not None:
5656
if k == 'hidden_states':
57-
outputs["penultimate_hidden_states"] = t[-2].cpu()
57+
outputs["penultimate_hidden_states"] = t[-2].to(comfy.model_management.intermediate_device())
5858
outputs["hidden_states"] = None
5959
else:
60-
outputs[k] = t.cpu()
60+
outputs[k] = t.to(comfy.model_management.intermediate_device())
6161

6262
return outputs
6363

comfy/model_management.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -522,6 +522,12 @@ def text_encoder_dtype(device=None):
522522
else:
523523
return torch.float32
524524

525+
def intermediate_device():
526+
if args.gpu_only:
527+
return get_torch_device()
528+
else:
529+
return torch.device("cpu")
530+
525531
def vae_device():
526532
return get_torch_device()
527533

comfy/model_sampling.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,17 @@ def calculate_denoised(self, sigma, model_output, model_input):
2222
class ModelSamplingDiscrete(torch.nn.Module):
2323
def __init__(self, model_config=None):
2424
super().__init__()
25-
beta_schedule = "linear"
25+
2626
if model_config is not None:
27-
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
28-
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
27+
sampling_settings = model_config.sampling_settings
28+
else:
29+
sampling_settings = {}
30+
31+
beta_schedule = sampling_settings.get("beta_schedule", "linear")
32+
linear_start = sampling_settings.get("linear_start", 0.00085)
33+
linear_end = sampling_settings.get("linear_end", 0.012)
34+
35+
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
2936
self.sigma_data = 1.0
3037

3138
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,

comfy/sample.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -98,7 +98,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
9898
sampler = comfy.samplers.KSampler(real_model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
9999

100100
samples = sampler.sample(noise, positive_copy, negative_copy, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
101-
samples = samples.cpu()
101+
samples = samples.to(comfy.model_management.intermediate_device())
102102

103103
cleanup_additional_models(models)
104104
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
@@ -111,7 +111,7 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
111111
sigmas = sigmas.to(model.load_device)
112112

113113
samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
114-
samples = samples.cpu()
114+
samples = samples.to(comfy.model_management.intermediate_device())
115115
cleanup_additional_models(models)
116116
cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
117117
return samples

comfy/samplers.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -276,10 +276,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, model_option
276276
x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
277277
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, model_options=model_options, seed=seed)
278278
if denoise_mask is not None:
279-
out *= denoise_mask
280-
281-
if denoise_mask is not None:
282-
out += self.latent_image * latent_mask
279+
out = out * denoise_mask + self.latent_image * latent_mask
283280
return out
284281

285282
def simple_scheduler(model, steps):

comfy/sd.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -190,6 +190,7 @@ def __init__(self, sd=None, device=None, config=None):
190190
offload_device = model_management.vae_offload_device()
191191
self.vae_dtype = model_management.vae_dtype()
192192
self.first_stage_model.to(self.vae_dtype)
193+
self.output_device = model_management.intermediate_device()
193194

194195
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
195196

@@ -201,9 +202,9 @@ def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
201202

202203
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
203204
output = torch.clamp((
204-
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, pbar = pbar) +
205-
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, pbar = pbar) +
206-
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, pbar = pbar))
205+
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
206+
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar) +
207+
comfy.utils.tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount = 8, output_device=self.output_device, pbar = pbar))
207208
/ 3.0) / 2.0, min=0.0, max=1.0)
208209
return output
209210

@@ -214,9 +215,9 @@ def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
214215
pbar = comfy.utils.ProgressBar(steps)
215216

216217
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
217-
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
218-
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
219-
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, pbar=pbar)
218+
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
219+
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
220+
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/8), out_channels=4, output_device=self.output_device, pbar=pbar)
220221
samples /= 3.0
221222
return samples
222223

@@ -230,10 +231,10 @@ def decode(self, samples_in):
230231
batch_number = int(free_memory / memory_used)
231232
batch_number = max(1, batch_number)
232233

233-
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device="cpu")
234+
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * 8), round(samples_in.shape[3] * 8)), device=self.output_device)
234235
for x in range(0, samples_in.shape[0], batch_number):
235236
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
236-
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).cpu().float() + 1.0) / 2.0, min=0.0, max=1.0)
237+
pixel_samples[x:x+batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
237238

238239
except model_management.OOM_EXCEPTION as e:
239240
pixel_samples = None
@@ -252,7 +253,7 @@ def decode(self, samples_in):
252253
if pixel_samples is None:
253254
raise e
254255

255-
pixel_samples = pixel_samples.cpu().movedim(1,-1)
256+
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
256257
return pixel_samples
257258

258259
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap = 16):
@@ -270,10 +271,10 @@ def encode(self, pixel_samples):
270271
free_memory = model_management.get_free_memory(self.device)
271272
batch_number = int(free_memory / memory_used)
272273
batch_number = max(1, batch_number)
273-
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device="cpu")
274+
samples = torch.empty((pixel_samples.shape[0], 4, round(pixel_samples.shape[2] // 8), round(pixel_samples.shape[3] // 8)), device=self.output_device)
274275
for x in range(0, pixel_samples.shape[0], batch_number):
275276
pixels_in = (2. * pixel_samples[x:x+batch_number] - 1.).to(self.vae_dtype).to(self.device)
276-
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).cpu().float()
277+
samples[x:x+batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
277278

278279
except model_management.OOM_EXCEPTION as e:
279280
samples = None

comfy/sd1_clip.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ def encode_token_weights(self, token_weight_pairs):
3939

4040
out, pooled = self.encode(to_encode)
4141
if pooled is not None:
42-
first_pooled = pooled[0:1].cpu()
42+
first_pooled = pooled[0:1].to(model_management.intermediate_device())
4343
else:
4444
first_pooled = pooled
4545

@@ -56,8 +56,8 @@ def encode_token_weights(self, token_weight_pairs):
5656
output.append(z)
5757

5858
if (len(output) == 0):
59-
return out[-1:].cpu(), first_pooled
60-
return torch.cat(output, dim=-2).cpu(), first_pooled
59+
return out[-1:].to(model_management.intermediate_device()), first_pooled
60+
return torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled
6161

6262
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
6363
"""Uses the CLIP transformer encoder for text (from huggingface)"""

comfy/utils.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -402,7 +402,7 @@ def lanczos(samples, width, height):
402402
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
403403
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
404404
result = torch.stack(images)
405-
return result
405+
return result.to(samples.device, samples.dtype)
406406

407407
def common_upscale(samples, width, height, upscale_method, crop):
408408
if crop == "center":
@@ -431,17 +431,17 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
431431
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
432432

433433
@torch.inference_mode()
434-
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, pbar = None):
435-
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device="cpu")
434+
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
435+
output = torch.empty((samples.shape[0], out_channels, round(samples.shape[2] * upscale_amount), round(samples.shape[3] * upscale_amount)), device=output_device)
436436
for b in range(samples.shape[0]):
437437
s = samples[b:b+1]
438-
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
439-
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu")
438+
out = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
439+
out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device=output_device)
440440
for y in range(0, s.shape[2], tile_y - overlap):
441441
for x in range(0, s.shape[3], tile_x - overlap):
442442
s_in = s[:,:,y:y+tile_y,x:x+tile_x]
443443

444-
ps = function(s_in).cpu()
444+
ps = function(s_in).to(output_device)
445445
mask = torch.ones_like(ps)
446446
feather = round(overlap * upscale_amount)
447447
for t in range(feather):

comfy_extras/nodes_canny.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -291,7 +291,7 @@ def INPUT_TYPES(s):
291291

292292
def detect_edge(self, image, low_threshold, high_threshold):
293293
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
294-
img_out = output[1].cpu().repeat(1, 3, 1, 1).movedim(1, -1)
294+
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
295295
return (img_out,)
296296

297297
NODE_CLASS_MAPPINGS = {

comfy_extras/nodes_post_processing.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -226,7 +226,7 @@ def sharpen(self, image: torch.Tensor, sharpen_radius: int, sigma:float, alpha:
226226
batch_size, height, width, channels = image.shape
227227

228228
kernel_size = sharpen_radius * 2 + 1
229-
kernel = gaussian_kernel(kernel_size, sigma) * -(alpha*10)
229+
kernel = gaussian_kernel(kernel_size, sigma, device=image.device) * -(alpha*10)
230230
center = kernel_size // 2
231231
kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0
232232
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)

nodes.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -947,8 +947,8 @@ def append(self, conditioning_to, clip, gligen_textbox_model, text, width, heigh
947947
return (c, )
948948

949949
class EmptyLatentImage:
950-
def __init__(self, device="cpu"):
951-
self.device = device
950+
def __init__(self):
951+
self.device = comfy.model_management.intermediate_device()
952952

953953
@classmethod
954954
def INPUT_TYPES(s):
@@ -961,7 +961,7 @@ def INPUT_TYPES(s):
961961
CATEGORY = "latent"
962962

963963
def generate(self, width, height, batch_size=1):
964-
latent = torch.zeros([batch_size, 4, height // 8, width // 8])
964+
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
965965
return ({"samples":latent}, )
966966

967967

0 commit comments

Comments (0)