Commit 5bee2e8

Fix formatting

1 parent 9b9aa59 commit 5bee2e8

6 files changed: +39 −29 lines changed

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
             "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)",
         ],
         "unet": [
-            #"--iree-flow-split-matmul-reduction=5",
+            # "--iree-flow-split-matmul-reduction=5",
             "--iree-codegen-gpu-native-math-precision=true",
             "--iree-codegen-llvmgpu-use-vector-distribution",
         ],
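
For context, these per-submodel lists in utils.py are extra command-line flags forwarded to the IREE compiler, and the change above only normalizes the spacing of a commented-out flag. Below is a minimal sketch of how such a list can be consumed through the iree-compiler Python API; the tiny MLIR module and the rocm target are placeholder assumptions, not taken from this commit.

# Minimal sketch: feeding a flag list like the "unet" entry above to the
# IREE compiler. The MLIR module and target backend are illustrative only.
import iree.compiler as ireec

MLIR_SOURCE = """
module {
  func.func @add(%a: i32, %b: i32) -> i32 {
    %0 = arith.addi %a, %b : i32
    return %0 : i32
  }
}
"""

unet_flags = [
    # "--iree-flow-split-matmul-reduction=5",  # left disabled, as in the diff
    "--iree-codegen-gpu-native-math-precision=true",
    "--iree-codegen-llvmgpu-use-vector-distribution",
]

# compile_str returns the compiled FlatBuffer (.vmfb contents) as bytes.
vmfb_bytes = ireec.compile_str(
    MLIR_SOURCE,
    target_backends=["rocm"],  # assumption: an AMD GPU target
    extra_args=unet_flags,
)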

models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py

Lines changed: 33 additions & 27 deletions

@@ -216,8 +216,18 @@ def load_pipeline(args, vmfbs: dict, weights: dict):
     if args.compiled_pipeline:
         runners["pipe"] = vmfbRunner(
             args.rt_device,
-            [vmfbs["scheduled_unet"], vmfbs["prompt_encoder"], vmfbs["vae_decode"], vmfbs["full_pipeline"]],
-            [weights["scheduled_unet"], weights["prompt_encoder"], weights["vae_decode"], None],
+            [
+                vmfbs["scheduled_unet"],
+                vmfbs["prompt_encoder"],
+                vmfbs["vae_decode"],
+                vmfbs["full_pipeline"],
+            ],
+            [
+                weights["scheduled_unet"],
+                weights["prompt_encoder"],
+                weights["vae_decode"],
+                None,
+            ],
         )
     else:
         runners["pipe"] = vmfbRunner(
@@ -263,7 +273,9 @@ def generate_images(args, runners: dict):
     numpy_images = []

     if args.compiled_pipeline and (args.batch_count > 1):
-        print("Compiled one-shot pipeline only supports 1 image at a time for now. Setting batch count to 1.")
+        print(
+            "Compiled one-shot pipeline only supports 1 image at a time for now. Setting batch count to 1."
+        )
         args.batch_count = 1

     for i in range(args.batch_count):
@@ -319,27 +331,31 @@
             [ireert.asdevicearray(runners["pipe"].config.device, text_input_ids)]
         )
         uncond_input_ids_list.extend(
-            [
-                ireert.asdevicearray(
-                    runners["pipe"].config.device, uncond_input_ids
-                )
-            ]
+            [ireert.asdevicearray(runners["pipe"].config.device, uncond_input_ids)]
         )
     if args.compiled_pipeline:
         inf_start = time.time()
-        image = runners["pipe"].ctx.modules.sdxl_compiled_pipeline["tokens_to_image"](samples[0], guidance_scale, *text_input_ids_list, *uncond_input_ids_list)
+        image = runners["pipe"].ctx.modules.sdxl_compiled_pipeline["tokens_to_image"](
+            samples[0], guidance_scale, *text_input_ids_list, *uncond_input_ids_list
+        )
         inf_end = time.time()
-        print("Total inference time (Tokens to Image): " + str(inf_end - inf_start) + "sec")
+        print(
+            "Total inference time (Tokens to Image): "
+            + str(inf_end - inf_start)
+            + "sec"
+        )
         numpy_images.append(image.to_host())
     else:
         encode_prompts_start = time.time()

-        prompt_embeds, add_text_embeds = runners["prompt_encoder"].ctx.modules.compiled_clip[
-            "encode_prompts"
-        ](*text_input_ids_list, *uncond_input_ids_list)
+        prompt_embeds, add_text_embeds = runners[
+            "prompt_encoder"
+        ].ctx.modules.compiled_clip["encode_prompts"](
+            *text_input_ids_list, *uncond_input_ids_list
+        )

         encode_prompts_end = time.time()
-
+
         for i in range(args.batch_count):
             unet_start = time.time()
@@ -375,12 +391,8 @@ def generate_images(args, runners: dict):
                 "sec\n",
             )
         end = time.time()
-        print(
-            "Total CLIP time:", encode_prompts_end - encode_prompts_start, "sec"
-        )
-        print(
-            "Total tokenize time:", encode_prompts_start - tokenize_start, "sec"
-        )
+        print("Total CLIP time:", encode_prompts_end - encode_prompts_start, "sec")
+        print("Total tokenize time:", encode_prompts_start - tokenize_start, "sec")
         print("Loading time: ", encode_prompts_start - pipe_start, "sec")
         if args.batch_count > 1:
             print(
@@ -390,13 +402,7 @@ def generate_images(args, runners: dict):
             )
     timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
     for idx, image in enumerate(numpy_images):
-        image = (
-            torch.from_numpy(image)
-            .cpu()
-            .permute(0, 2, 3, 1)
-            .float()
-            .numpy()
-        )
+        image = torch.from_numpy(image).cpu().permute(0, 2, 3, 1).float().numpy()
         image = numpy_to_pil_image(image)
         img_path = "sdxl_output_" + timestamp + "_" + str(idx) + ".png"
         image[0].save(img_path)
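
The final hunk above condenses an NCHW-to-NHWC conversion into a single chained call before the batch is handed to PIL. Here is a self-contained sketch of that idiom, with a random array standing in for the pipeline's decoded output and PIL's Image.fromarray used in place of this repo's numpy_to_pil_image helper.

# Standalone sketch of the conversion chain condensed in the diff above.
import numpy as np
import torch
from PIL import Image

# Stand-in for the SDXL VAE output: a batch of one 1024x1024 RGB image, NCHW.
image = np.random.rand(1, 3, 1024, 1024).astype(np.float32)

# Same chain as the diff: to torch, to CPU, channels-last (NHWC), float32, back to numpy.
image = torch.from_numpy(image).cpu().permute(0, 2, 3, 1).float().numpy()

# Scale [0, 1] floats to uint8 and save the first image in the batch,
# mirroring what numpy_to_pil_image(...) and image[0].save(...) do in the source.
pil_image = Image.fromarray((image[0] * 255).round().astype("uint8"))
pil_image.save("sdxl_output_example.png")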

models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py

Lines changed: 1 addition & 1 deletion

@@ -146,7 +146,7 @@ def export_prompt_encoder(
     attn_spec=None,
     weights_only=False,
 ):
-    if (attn_spec in ["default", "", None]):
+    if attn_spec in ["default", "", None]:
         attn_spec = os.path.join(
             os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir"
         )
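
Besides dropping the redundant parentheses, the touched branch shows the fallback behavior: "default", "", or None all resolve to an attention spec file that ships next to the module. A minimal runnable sketch of that __file__-relative lookup, lifted out of export_prompt_encoder:

# Minimal sketch of the fallback in export_prompt_encoder: when no usable
# attention spec is given, resolve the default .mlir file next to this module.
import os

def resolve_attn_spec(attn_spec=None):
    if attn_spec in ["default", "", None]:
        attn_spec = os.path.join(
            os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir"
        )
    return attn_spec

print(resolve_attn_spec())            # absolute path to the bundled default spec
print(resolve_attn_spec("my.mlir"))   # a caller-provided spec passes through unchanged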

models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py

Lines changed: 2 additions & 0 deletions

@@ -251,6 +251,7 @@ def run_forward(
         exit()
     return vmfb

+
 def export_pipeline_module(args):
     pipeline_file = (
         "sdxl_sched_unet_bench_" + "f32"
@@ -288,6 +289,7 @@ def export_pipeline_module(args):
     )
     return pipeline_vmfb_path

+
 if __name__ == "__main__":
     from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args

models/turbine_models/custom_models/sdxl_inference/unet.py

Lines changed: 1 addition & 0 deletions

@@ -192,6 +192,7 @@ def main(

     logging.basicConfig(level=logging.DEBUG)
     from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args
+
     if args.input_mlir:
         unet_model = None
     else:

models/turbine_models/custom_models/sdxl_inference/vae_runner.py

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@

 torch.random.manual_seed(0)

+
 def run_vae(
     device,
     example_input,
