From 329c57199302f6b9ccfebb86c96e937c386da92f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 14 Dec 2023 11:41:49 -0500 Subject: [PATCH] Improve code legibility. --- comfy/samplers.py | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 7dc27528aa4..39bc3774a4c 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -2,6 +2,7 @@ from .extra_samplers import uni_pc import torch import enum +import collections from comfy import model_management import math from comfy import model_base @@ -61,9 +62,7 @@ def get_area_and_mult(conds, x_in, timestep_in): for c in model_conds: conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - control = None - if 'control' in conds: - control = conds['control'] + control = conds.get('control', None) patches = None if 'gligen' in conds: @@ -78,7 +77,8 @@ def get_area_and_mult(conds, x_in, timestep_in): patches['middle_patch'] = [gligen_patch] - return (input_x, mult, conditioning, area, control, patches) + cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) + return cond_obj(input_x, mult, conditioning, area, control, patches) def cond_equal_size(c1, c2): if c1 is c2: @@ -91,24 +91,24 @@ def cond_equal_size(c1, c2): return True def can_concat_cond(c1, c2): - if c1[0].shape != c2[0].shape: + if c1.input_x.shape != c2.input_x.shape: return False - #control - if (c1[4] is None) != (c2[4] is None): - return False - if c1[4] is not None: - if c1[4] is not c2[4]: + def objects_concatable(obj1, obj2): + if (obj1 is None) != (obj2 is None): return False + if obj1 is not None: + if obj1 is not obj2: + return False + return True - #patches - if (c1[5] is None) != (c2[5] is None): + if not objects_concatable(c1.control, c2.control): + return False + + if not objects_concatable(c1.patches, c2.patches): return False - if (c1[5] is not None): - if c1[5] is not c2[5]: - return False - return cond_equal_size(c1[2], c2[2]) + return cond_equal_size(c1.conditioning, c2.conditioning) def cond_cat(c_list): c_crossattn = [] @@ -184,13 +184,13 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): for x in to_batch: o = to_run.pop(x) p = o[0] - input_x += [p[0]] - mult += [p[1]] - c += [p[2]] - area += [p[3]] - cond_or_uncond += [o[1]] - control = p[4] - patches = p[5] + input_x.append(p.input_x) + mult.append(p.mult) + c.append(p.conditioning) + area.append(p.area) + cond_or_uncond.append(o[1]) + control = p.control + patches = p.patches batch_chunks = len(cond_or_uncond) input_x = torch.cat(input_x)