From d2c7ebb1e68bcc449a02ed0755c2e2eaa92c7965 Mon Sep 17 00:00:00 2001 From: Ean Garvey Date: Tue, 27 Feb 2024 01:48:53 -0600 Subject: [PATCH] Change unet_runner timestep input to int64 --- .../turbine_models/custom_models/sdxl_inference/unet_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py index d79602d94..904d1ba65 100644 --- a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py +++ b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py @@ -152,7 +152,7 @@ def forward( sample = torch.rand( args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype ) - timestep = torch.zeros(1, dtype=dtype) + timestep = torch.zeros(1, dtype=torch.int64) prompt_embeds = torch.rand(2 * args.batch_size, args.max_length, 2048, dtype=dtype) text_embeds = torch.rand(2 * args.batch_size, 1280, dtype=dtype) time_ids = torch.zeros(2 * args.batch_size, 6, dtype=dtype)