Skip to content

Commit

Permalink
Change unet_runner timestep input to int64
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Feb 27, 2024
1 parent 9ffaf82 commit d2c7ebb
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def forward(
sample = torch.rand(
args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype
)
timestep = torch.zeros(1, dtype=dtype)
timestep = torch.zeros(1, dtype=torch.int64)
prompt_embeds = torch.rand(2 * args.batch_size, args.max_length, 2048, dtype=dtype)
text_embeds = torch.rand(2 * args.batch_size, 1280, dtype=dtype)
time_ids = torch.zeros(2 * args.batch_size, 6, dtype=dtype)
Expand Down

0 comments on commit d2c7ebb

Please sign in to comment.