@@ -251,6 +251,12 @@ def __init__(self, dim):
251
251
def forward (self , t ):
252
252
return timestep_embedding (t , self .dim )
253
253
254
+ def apply_control (h , control , name ):
255
+ if control is not None and name in control and len (control [name ]) > 0 :
256
+ ctrl = control [name ].pop ()
257
+ if ctrl is not None :
258
+ h += ctrl
259
+ return h
254
260
255
261
class UNetModel (nn .Module ):
256
262
"""
@@ -617,25 +623,17 @@ def forward(self, x, timesteps=None, context=None, y=None, control=None, transfo
617
623
for id , module in enumerate (self .input_blocks ):
618
624
transformer_options ["block" ] = ("input" , id )
619
625
h = forward_timestep_embed (module , h , emb , context , transformer_options )
620
- if control is not None and 'input' in control and len (control ['input' ]) > 0 :
621
- ctrl = control ['input' ].pop ()
622
- if ctrl is not None :
623
- h += ctrl
626
+ h = apply_control (h , control , 'input' )
624
627
hs .append (h )
628
+
625
629
transformer_options ["block" ] = ("middle" , 0 )
626
630
h = forward_timestep_embed (self .middle_block , h , emb , context , transformer_options )
627
- if control is not None and 'middle' in control and len (control ['middle' ]) > 0 :
628
- ctrl = control ['middle' ].pop ()
629
- if ctrl is not None :
630
- h += ctrl
631
+ h = apply_control (h , control , 'middle' )
631
632
632
633
for id , module in enumerate (self .output_blocks ):
633
634
transformer_options ["block" ] = ("output" , id )
634
635
hsp = hs .pop ()
635
- if control is not None and 'output' in control and len (control ['output' ]) > 0 :
636
- ctrl = control ['output' ].pop ()
637
- if ctrl is not None :
638
- hsp += ctrl
636
+ h = apply_control (h , control , 'output' )
639
637
640
638
if "output_block_patch" in transformer_patches :
641
639
patch = transformer_patches ["output_block_patch" ]
0 commit comments