@@ -216,8 +216,18 @@ def load_pipeline(args, vmfbs: dict, weights: dict):
216
216
if args .compiled_pipeline :
217
217
runners ["pipe" ] = vmfbRunner (
218
218
args .rt_device ,
219
- [vmfbs ["scheduled_unet" ], vmfbs ["prompt_encoder" ], vmfbs ["vae_decode" ], vmfbs ["full_pipeline" ]],
220
- [weights ["scheduled_unet" ], weights ["prompt_encoder" ], weights ["vae_decode" ], None ],
219
+ [
220
+ vmfbs ["scheduled_unet" ],
221
+ vmfbs ["prompt_encoder" ],
222
+ vmfbs ["vae_decode" ],
223
+ vmfbs ["full_pipeline" ],
224
+ ],
225
+ [
226
+ weights ["scheduled_unet" ],
227
+ weights ["prompt_encoder" ],
228
+ weights ["vae_decode" ],
229
+ None ,
230
+ ],
221
231
)
222
232
else :
223
233
runners ["pipe" ] = vmfbRunner (
@@ -263,7 +273,9 @@ def generate_images(args, runners: dict):
263
273
numpy_images = []
264
274
265
275
if args .compiled_pipeline and (args .batch_count > 1 ):
266
- print ("Compiled one-shot pipeline only supports 1 image at a time for now. Setting batch count to 1." )
276
+ print (
277
+ "Compiled one-shot pipeline only supports 1 image at a time for now. Setting batch count to 1."
278
+ )
267
279
args .batch_count = 1
268
280
269
281
for i in range (args .batch_count ):
@@ -319,27 +331,31 @@ def generate_images(args, runners: dict):
319
331
[ireert .asdevicearray (runners ["pipe" ].config .device , text_input_ids )]
320
332
)
321
333
uncond_input_ids_list .extend (
322
- [
323
- ireert .asdevicearray (
324
- runners ["pipe" ].config .device , uncond_input_ids
325
- )
326
- ]
334
+ [ireert .asdevicearray (runners ["pipe" ].config .device , uncond_input_ids )]
327
335
)
328
336
if args .compiled_pipeline :
329
337
inf_start = time .time ()
330
- image = runners ["pipe" ].ctx .modules .sdxl_compiled_pipeline ["tokens_to_image" ](samples [0 ], guidance_scale , * text_input_ids_list , * uncond_input_ids_list )
338
+ image = runners ["pipe" ].ctx .modules .sdxl_compiled_pipeline ["tokens_to_image" ](
339
+ samples [0 ], guidance_scale , * text_input_ids_list , * uncond_input_ids_list
340
+ )
331
341
inf_end = time .time ()
332
- print ("Total inference time (Tokens to Image): " + str (inf_end - inf_start ) + "sec" )
342
+ print (
343
+ "Total inference time (Tokens to Image): "
344
+ + str (inf_end - inf_start )
345
+ + "sec"
346
+ )
333
347
numpy_images .append (image .to_host ())
334
348
else :
335
349
encode_prompts_start = time .time ()
336
350
337
- prompt_embeds , add_text_embeds = runners ["prompt_encoder" ].ctx .modules .compiled_clip [
338
- "encode_prompts"
339
- ](* text_input_ids_list , * uncond_input_ids_list )
351
+ prompt_embeds , add_text_embeds = runners [
352
+ "prompt_encoder"
353
+ ].ctx .modules .compiled_clip ["encode_prompts" ](
354
+ * text_input_ids_list , * uncond_input_ids_list
355
+ )
340
356
341
357
encode_prompts_end = time .time ()
342
-
358
+
343
359
for i in range (args .batch_count ):
344
360
unet_start = time .time ()
345
361
@@ -375,12 +391,8 @@ def generate_images(args, runners: dict):
375
391
"sec\n " ,
376
392
)
377
393
end = time .time ()
378
- print (
379
- "Total CLIP time:" , encode_prompts_end - encode_prompts_start , "sec"
380
- )
381
- print (
382
- "Total tokenize time:" , encode_prompts_start - tokenize_start , "sec"
383
- )
394
+ print ("Total CLIP time:" , encode_prompts_end - encode_prompts_start , "sec" )
395
+ print ("Total tokenize time:" , encode_prompts_start - tokenize_start , "sec" )
384
396
print ("Loading time: " , encode_prompts_start - pipe_start , "sec" )
385
397
if args .batch_count > 1 :
386
398
print (
@@ -390,13 +402,7 @@ def generate_images(args, runners: dict):
390
402
)
391
403
timestamp = dt .now ().strftime ("%Y-%m-%d_%H-%M-%S" )
392
404
for idx , image in enumerate (numpy_images ):
393
- image = (
394
- torch .from_numpy (image )
395
- .cpu ()
396
- .permute (0 , 2 , 3 , 1 )
397
- .float ()
398
- .numpy ()
399
- )
405
+ image = torch .from_numpy (image ).cpu ().permute (0 , 2 , 3 , 1 ).float ().numpy ()
400
406
image = numpy_to_pil_image (image )
401
407
img_path = "sdxl_output_" + timestamp + "_" + str (idx ) + ".png"
402
408
image [0 ].save (img_path )
0 commit comments