From a527d0c795ba5572708095fcf0f9366e2076ba7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 7 Nov 2023 19:33:40 -0500 Subject: [PATCH] Code refactor. --- .../modules/diffusionmodules/openaimodel.py | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 7dfdfc0a29c..6c2113e3e4f 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -251,6 +251,12 @@ def __init__(self, dim): def forward(self, t): return timestep_embedding(t, self.dim) +def apply_control(h, control, name): + if control is not None and name in control and len(control[name]) > 0: + ctrl = control[name].pop() + if ctrl is not None: + h += ctrl + return h class UNetModel(nn.Module): """ @@ -617,25 +623,17 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo for id, module in enumerate(self.input_blocks): transformer_options["block"] = ("input", id) h = forward_timestep_embed(module, h, emb, context, transformer_options) - if control is not None and 'input' in control and len(control['input']) > 0: - ctrl = control['input'].pop() - if ctrl is not None: - h += ctrl + h = apply_control(h, control, 'input') hs.append(h) + transformer_options["block"] = ("middle", 0) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options) - if control is not None and 'middle' in control and len(control['middle']) > 0: - ctrl = control['middle'].pop() - if ctrl is not None: - h += ctrl + h = apply_control(h, control, 'middle') for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) hsp = hs.pop() - if control is not None and 'output' in control and len(control['output']) > 0: - ctrl = control['output'].pop() - if ctrl is not None: - hsp += ctrl + h = apply_control(h, control, 'output') if "output_block_patch" in transformer_patches: patch = transformer_patches["output_block_patch"]