@@ -9,6 +9,7 @@
 import gc
 import re
 import inspect
+from typing import List, Literal, NamedTuple, Optional
 
 import torch
 import nodes
@@ -298,8 +299,15 @@ def reset(self):
         self.outputs = {}
         self.object_storage = {}
         self.outputs_ui = {}
+        self.status_messages = []
+        self.success = True
         self.old_prompt = {}
 
+    def add_message(self, event, data, broadcast: bool):
+        self.status_messages.append((event, data))
+        if self.server.client_id is not None or broadcast:
+            self.server.send_sync(event, data, self.server.client_id)
+
     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
         node_id = error["node_id"]
         class_type = prompt[node_id]["class_type"]
@@ -313,23 +321,22 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e
                 "node_type": class_type,
                 "executed": list(executed),
             }
-            self.server.send_sync("execution_interrupted", mes, self.server.client_id)
+            self.add_message("execution_interrupted", mes, broadcast=True)
         else:
-            if self.server.client_id is not None:
-                mes = {
-                    "prompt_id": prompt_id,
-                    "node_id": node_id,
-                    "node_type": class_type,
-                    "executed": list(executed),
-
-                    "exception_message": error["exception_message"],
-                    "exception_type": error["exception_type"],
-                    "traceback": error["traceback"],
-                    "current_inputs": error["current_inputs"],
-                    "current_outputs": error["current_outputs"],
-                }
-                self.server.send_sync("execution_error", mes, self.server.client_id)
-
+            mes = {
+                "prompt_id": prompt_id,
+                "node_id": node_id,
+                "node_type": class_type,
+                "executed": list(executed),
+
+                "exception_message": error["exception_message"],
+                "exception_type": error["exception_type"],
+                "traceback": error["traceback"],
+                "current_inputs": error["current_inputs"],
+                "current_outputs": error["current_outputs"],
+            }
+            self.add_message("execution_error", mes, broadcast=False)
+
         # Next, remove the subsequent outputs since they will not be executed
         to_delete = []
         for o in self.outputs:
@@ -350,8 +357,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
         else:
             self.server.client_id = None
 
-        if self.server.client_id is not None:
-            self.server.send_sync("execution_start", {"prompt_id": prompt_id}, self.server.client_id)
+        self.status_messages = []
+        self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
 
         with torch.inference_mode():
             #delete cached outputs if nodes don't exist for them
@@ -384,8 +391,9 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
                 del d
 
             comfy.model_management.cleanup_models()
-            if self.server.client_id is not None:
-                self.server.send_sync("execution_cached", {"nodes": list(current_outputs), "prompt_id": prompt_id}, self.server.client_id)
+            self.add_message("execution_cached",
+                             {"nodes": list(current_outputs), "prompt_id": prompt_id},
+                             broadcast=False)
             executed = set()
             output_node_id = None
             to_execute = []
@@ -401,8 +409,8 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
                 # This call shouldn't raise anything if there's an error deep in
                 # the actual SD code, instead it will report the node where the
                 # error was raised
-                success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
-                if success is not True:
+                self.success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui, self.object_storage)
+                if self.success is not True:
                     self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex)
                     break
 
@@ -754,14 +762,27 @@ def get(self, timeout=None):
             self.server.queue_updated()
             return (item, i)
 
-    def task_done(self, item_id, outputs):
+    class ExecutionStatus(NamedTuple):
+        status_str: Literal['success', 'error']
+        completed: bool
+        messages: List[str]
+
+    def task_done(self, item_id, outputs,
+                  status: Optional['PromptQueue.ExecutionStatus']):
         with self.mutex:
             prompt = self.currently_running.pop(item_id)
             if len(self.history) > MAXIMUM_HISTORY_SIZE:
                 self.history.pop(next(iter(self.history)))
-            self.history[prompt[1]] = { "prompt": prompt, "outputs": {} }
-            for o in outputs:
-                self.history[prompt[1]]["outputs"][o] = outputs[o]
+
+            status_dict: dict | None = None
+            if status is not None:
+                status_dict = copy.deepcopy(status._asdict())
+
+            self.history[prompt[1]] = {
+                "prompt": prompt,
+                "outputs": copy.deepcopy(outputs),
+                'status': status_dict,
+            }
             self.server.queue_updated()
 
     def get_current_queue(self):
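
With this change, `task_done` expects an `ExecutionStatus` assembled from the executor's new `success` flag and the `status_messages` accumulated by `add_message`. The calling side is not part of this diff; below is a minimal sketch of how a queue consumer might wire the two together. The `executor` and `queue` names, and the `prompt`/`prompt_id`/`extra_data`/`execute_outputs` bindings, are assumptions for illustration only.

# Hypothetical consumer loop (not in this diff): run a queued prompt, then
# record the outcome into history via the new task_done signature.
item, item_id = queue.get()

executor.execute(prompt, prompt_id, extra_data, execute_outputs)

queue.task_done(
    item_id,
    executor.outputs_ui,
    status=PromptQueue.ExecutionStatus(
        status_str='success' if executor.success else 'error',
        completed=executor.success,
        # (event, data) pairs recorded by add_message during execution
        messages=executor.status_messages,
    ),
)

Because `add_message` appends to `status_messages` regardless of whether a client is connected, errors raised while no client is listening still end up in the stored history entry, whose `status` field is the deep-copied `_asdict()` form `{"status_str": ..., "completed": ..., "messages": [...]}`.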