@@ -49,7 +49,10 @@ def __init__(self, device: str = "cuda", dtype: torch.dtype = torch.float16) ->

        self._stream = None  # set in install
        # Per-frame prepared tensor cache to avoid per-step device/dtype alignment and batch repeats
-        self._prepared_cache: Optional[Dict[str, Any]] = None
+        self._prepared_tensors: List[Optional[torch.Tensor]] = []
+        self._prepared_device: Optional[torch.device] = None
+        self._prepared_dtype: Optional[torch.dtype] = None
+        self._prepared_batch: Optional[int] = None
        self._images_version: int = 0

    # ---------- Public API (used by wrapper in a later step) ----------
@@ -66,8 +69,11 @@ def install(self, stream) -> None:
        setattr(stream, 'controlnets', self.controlnets)
        setattr(stream, 'controlnet_scales', self.controlnet_scales)
        setattr(stream, 'preprocessors', self.preprocessors)
-        # Reset caches on install
-        self._prepared_cache = None
+        # Reset prepared tensors on install
+        self._prepared_tensors = []
+        self._prepared_device = None
+        self._prepared_dtype = None
+        self._prepared_batch = None

    def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[str, Any, torch.Tensor]] = None) -> None:
        model = self._load_pytorch_controlnet_model(cfg.model_id)
@@ -120,8 +126,8 @@ def add_controlnet(self, cfg: ControlNetConfig, control_image: Optional[Union[st
        self.controlnet_scales.append(float(cfg.conditioning_scale))
        self.preprocessors.append(preproc)
        self.enabled_list.append(bool(cfg.enabled))
-        # Invalidate prepared cache and bump version when graph changes
-        self._prepared_cache = None
+        # Invalidate prepared tensors and bump version when graph changes
+        self._prepared_tensors = []
        self._images_version += 1

    def update_control_image_efficient(self, control_image: Union[str, Any, torch.Tensor], index: Optional[int] = None) -> None:
@@ -154,9 +160,13 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
            with self._collections_lock:
                if processed is not None and index < len(self.controlnet_images):
                    self.controlnet_images[index] = processed
-                    # Invalidate prepared cache and bump version for per-frame reuse
-                    self._prepared_cache = None
+                    # Invalidate prepared tensors and bump version for per-frame reuse
+                    self._prepared_tensors = []
                    self._images_version += 1
+                    # Pre-prepare tensors if we know the target specs
+                    if self._stream and hasattr(self._stream, 'device') and hasattr(self._stream, 'dtype'):
+                        # Use default batch size of 1 for now, will be adjusted on first use
+                        self.prepare_frame_tensors(self._stream.device, self._stream.dtype, 1)
            return

        # Use intelligent pipelining (automatically detects feedback preprocessors and switches to sync)
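For context on the "default batch size of 1" comment in the hunk above, here is a rough standalone sketch (hypothetical names, not part of this PR) of the intended lifecycle: tensors are prepared eagerly at batch 1 as soon as a control image arrives, then re-prepared exactly once when the first denoising step reveals the real (device, dtype, batch) combination.

```python
import torch

class _FrameCache:
    """Miniature stand-in for the prepared-tensor cache keyed on (device, dtype, batch)."""

    def __init__(self) -> None:
        self.key = None       # (device, dtype, batch_size) of the last preparation
        self.tensors = []     # prepared control images, aligned with the module's image list

    def prepare(self, images, device, dtype, batch_size) -> None:
        key = (device, dtype, batch_size)
        if key == self.key and len(self.tensors) == len(images):
            return  # specs unchanged since the last preparation: nothing to do
        self.tensors = [
            None if img is None else img.to(device=device, dtype=dtype)
            for img in images
        ]
        self.key = key

cache = _FrameCache()
images = [torch.rand(1, 3, 64, 64)]
cache.prepare(images, torch.device("cpu"), torch.float32, 1)  # eager prep at batch 1
cache.prepare(images, torch.device("cpu"), torch.float32, 4)  # adjusted once on first real use
cache.prepare(images, torch.device("cpu"), torch.float32, 4)  # no-op: cache key still matches
```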
@@ -178,8 +188,12 @@ def update_control_image_efficient(self, control_image: Union[str, Any, torch.Te
                if img is not None and i < len(self.controlnet_images):
                    self.controlnet_images[i] = img
            # Invalidate prepared cache and bump version after bulk update
-            self._prepared_cache = None
+            self._prepared_tensors = []
            self._images_version += 1
+            # Pre-prepare tensors if we know the target specs
+            if self._stream and hasattr(self._stream, 'device') and hasattr(self._stream, 'dtype'):
+                # Use default batch size of 1 for now, will be adjusted on first use
+                self.prepare_frame_tensors(self._stream.device, self._stream.dtype, 1)

    def update_controlnet_scale(self, index: int, scale: float) -> None:
        with self._collections_lock:
@@ -203,8 +217,8 @@ def remove_controlnet(self, index: int) -> None:
            del self.preprocessors[index]
        if index < len(self.enabled_list):
            del self.enabled_list[index]
-        # Invalidate prepared cache and bump version
-        self._prepared_cache = None
+        # Invalidate prepared tensors and bump version
+        self._prepared_tensors = []
        self._images_version += 1

    def reorder_controlnets_by_model_ids(self, desired_model_ids: List[str]) -> None:
@@ -260,6 +274,54 @@ def get_current_config(self) -> List[Dict[str, Any]]:
            })
        return cfg

+    def prepare_frame_tensors(self, device: torch.device, dtype: torch.dtype, batch_size: int) -> None:
+        """Prepare control image tensors for the current frame.
+
+        This method is called once per frame to prepare all control images with the correct
+        device, dtype, and batch size. This avoids redundant operations during each denoising step.
+
+        Args:
+            device: Target device for tensors
+            dtype: Target dtype for tensors
+            batch_size: Target batch size
+        """
+        with self._collections_lock:
+            # Check if we need to re-prepare tensors
+            cache_valid = (
+                self._prepared_device == device and
+                self._prepared_dtype == dtype and
+                self._prepared_batch == batch_size and
+                len(self._prepared_tensors) == len(self.controlnet_images)
+            )
+
+            if cache_valid:
+                return
+
+            # Prepare tensors for current frame
+            self._prepared_tensors = []
+            for img in self.controlnet_images:
+                if img is None:
+                    self._prepared_tensors.append(None)
+                    continue
+
+                # Prepare tensor with correct batch size
+                prepared = img
+                if prepared.dim() == 4 and prepared.shape[0] != batch_size:
+                    if prepared.shape[0] == 1:
+                        prepared = prepared.repeat(batch_size, 1, 1, 1)
+                    else:
+                        repeat_factor = max(1, batch_size // prepared.shape[0])
+                        prepared = prepared.repeat(repeat_factor, 1, 1, 1)[:batch_size]
+
+                # Move to correct device and dtype
+                prepared = prepared.to(device=device, dtype=dtype)
+                self._prepared_tensors.append(prepared)
+
+            # Update cache state
+            self._prepared_device = device
+            self._prepared_dtype = dtype
+            self._prepared_batch = batch_size
+
    # ---------- Internal helpers ----------
    def build_unet_hook(self) -> UnetHook:
        def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
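As an aside, the batch-alignment rule inside `prepare_frame_tensors` can be read in isolation. The sketch below (standalone, hypothetical function name) reproduces the same behaviour: a batch-1 conditioning image is repeated up to the target batch, while other batch sizes are tiled and then sliced toward the target size.

```python
import torch

def align_batch(img: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Standalone copy of the batch-alignment rule used in prepare_frame_tensors."""
    if img.dim() == 4 and img.shape[0] != batch_size:
        if img.shape[0] == 1:
            # Single conditioning image: repeat it for every sample in the batch
            img = img.repeat(batch_size, 1, 1, 1)
        else:
            # Mismatched batch: tile, then truncate toward the requested size
            repeat_factor = max(1, batch_size // img.shape[0])
            img = img.repeat(repeat_factor, 1, 1, 1)[:batch_size]
    return img

cond = torch.rand(1, 3, 64, 64)       # one 64x64 RGB control image
print(align_batch(cond, 2).shape)     # torch.Size([2, 3, 64, 64])
```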
@@ -324,40 +386,15 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
            down_samples_list: List[List[torch.Tensor]] = []
            mid_samples_list: List[torch.Tensor] = []

-            # Prepare control images once per frame for current device/dtype/batch
-            try:
-                main_batch = x_t.shape[0]
-                cache_ok = (
-                    isinstance(self._prepared_cache, dict)
-                    and self._prepared_cache.get('device') == x_t.device
-                    and self._prepared_cache.get('dtype') == x_t.dtype
-                    and self._prepared_cache.get('batch') == main_batch
-                    and self._prepared_cache.get('version') == self._images_version
-                )
-                if not cache_ok:
-                    prepared: List[Optional[torch.Tensor]] = [None] * len(self.controlnet_images)
-                    for i, base_img in enumerate(self.controlnet_images):
-                        if base_img is None:
-                            continue
-                        cur = base_img
-                        if cur.dim() == 4 and cur.shape[0] != main_batch:
-                            if cur.shape[0] == 1:
-                                cur = cur.repeat(main_batch, 1, 1, 1)
-                            else:
-                                repeat_factor = max(1, main_batch // cur.shape[0])
-                                cur = cur.repeat(repeat_factor, 1, 1, 1)
-                        cur = cur.to(device=x_t.device, dtype=x_t.dtype)
-                        prepared[i] = cur
-                    self._prepared_cache = {
-                        'device': x_t.device,
-                        'dtype': x_t.dtype,
-                        'batch': main_batch,
-                        'version': self._images_version,
-                        'prepared': prepared,
-                    }
-                prepared_images: List[Optional[torch.Tensor]] = self._prepared_cache['prepared'] if self._prepared_cache else [None] * len(self.controlnet_images)
-            except Exception:
-                prepared_images = active_images  # Fallback to per-step path if cache prep fails
+            # Ensure tensors are prepared for this frame
+            # This should have been called earlier, but we call it here as a safety net
+            if (self._prepared_device != x_t.device or
+                    self._prepared_dtype != x_t.dtype or
+                    self._prepared_batch != x_t.shape[0]):
+                self.prepare_frame_tensors(x_t.device, x_t.dtype, x_t.shape[0])
+
+            # Use pre-prepared tensors
+            prepared_images = self._prepared_tensors

            for cn, img, scale, idx_i in zip(active_controlnets, active_images, active_scales, active_indices):
                # Swap to TRT engine if compiled and available for this model_id
@@ -368,8 +405,8 @@ def _unet_hook(ctx: StepCtx) -> UnetKwargsDelta:
                    # Swapped to TRT engine
                except Exception:
                    pass
-                # Pull from prepared cache if available
-                current_img = prepared_images[idx_i] if 'prepared_images' in locals() and prepared_images and idx_i < len(prepared_images) and prepared_images[idx_i] is not None else img
+                # Use pre-prepared tensor
+                current_img = prepared_images[idx_i] if idx_i < len(prepared_images) else img
                if current_img is None:
                    continue
                kwargs = base_kwargs.copy()
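For illustration, a small standalone sketch (hypothetical helpers, not part of this PR) of the two per-step decisions the hook now makes: a cheap three-way spec comparison before deciding to re-prepare, and index-bounded selection that falls back to the per-step image when no prepared tensor exists.

```python
import torch

def specs_changed(cached_device, cached_dtype, cached_batch, x_t: torch.Tensor) -> bool:
    # Mirrors the safety-net check: re-prepare only if device, dtype, or batch differ
    return (
        cached_device != x_t.device
        or cached_dtype != x_t.dtype
        or cached_batch != x_t.shape[0]
    )

def pick_image(prepared, idx, fallback):
    # Mirrors the selection rule: use the prepared tensor when the index is in range
    return prepared[idx] if idx < len(prepared) else fallback

x_t = torch.rand(2, 4, 64, 64)
print(specs_changed(x_t.device, x_t.dtype, 2, x_t))  # False: cached tensors are reused
print(specs_changed(x_t.device, x_t.dtype, 1, x_t))  # True: one re-preparation this frame
print(pick_image([], 0, "fallback"))                 # falls back when nothing was prepared
```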