Skip to content

Commit

Permalink
Filename fixes, explicit input dtypes for i8 punet
Browse files Browse the repository at this point in the history
  • Loading branch information
eagarvey-amd committed Oct 4, 2024
1 parent d3c8e80 commit 40808db
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ def setup_punet(self):
]
self.map["unet"]["keywords"] += "punet"
if self.use_i8_punet:
self.map["unet"]["np_dtype"] = "int8"
self.map["unet"]["torch_dtype"] = torch.int8
if self.add_tk_kernels:
self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels
self.map["unet"]["export_args"]["tk_kernels_dir"] = self.tk_kernels_dir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def export_prompt_encoder(

safe_name = utils.create_safe_name(
hf_model_name,
f"_bs{batch_size}_{str(max_length)}-{precision}-text-encoder-{device}",
f"_bs{batch_size}_{str(max_length)}-{precision}-text-encoder",
)
if pipeline_dir not in [None, ""]:
safe_name = os.path.join(pipeline_dir, safe_name)
Expand Down
95 changes: 52 additions & 43 deletions models/turbine_models/tests/sdxl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,24 +192,27 @@ def test02_PromptEncoder(self):
if arguments["device"] in ["vulkan", "cuda"]:
self.skipTest("Compilation error on vulkan; To be tested on cuda.")
clip_filename = (
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["max_length"]),
arguments["precision"],
"text_encoder",
arguments["device"],
arguments["iree_target_triple"],
])
"_".join(
[
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["max_length"]),
arguments["precision"],
"text_encoder",
arguments["iree_target_triple"],
]
)
+ ".vmfb"
)
arguments["vmfb_path"] = os.path.join("test_vmfbs", clip_filename)
clip_w_filename = (
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"text_encoder",
arguments["precision"],
])
"_".join(
[
create_safe_name(arguments["hf_model_name"], ""),
"text_encoder",
arguments["precision"],
]
)
+ ".safetensors"
)
arguments["external_weight_path"] = os.path.join(
Expand Down Expand Up @@ -272,25 +275,28 @@ def test03_unet(self):
if arguments["device"] in ["vulkan", "cuda"]:
self.skipTest("Unknown error on vulkan; To be tested on cuda.")
unet_filename = (
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["max_length"]),
str(arguments["height"]) + "x" + str(arguments["width"]),
arguments["precision"],
"unet",
arguments["device"],
arguments["iree_target_triple"],
])
"_".join(
[
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["max_length"]),
str(arguments["height"]) + "x" + str(arguments["width"]),
arguments["precision"],
"unet",
arguments["iree_target_triple"],
]
)
+ ".vmfb"
)
arguments["vmfb_path"] = os.path.join("test_vmfbs", unet_filename)
unet_w_filename = (
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"unet",
arguments["precision"],
])
"_".join(
[
create_safe_name(arguments["hf_model_name"], ""),
"unet",
arguments["precision"],
]
)
+ ".safetensors"
)
arguments["external_weight_path"] = os.path.join(
Expand Down Expand Up @@ -363,24 +369,27 @@ def test04_ExportVaeModelDecode(self):
self.skipTest("Compilation error on vulkan; To be tested on cuda.")

vae_filename = (
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["height"]) + "x" + str(arguments["width"]),
arguments["precision"],
"vae",
arguments["device"],
arguments["iree_target_triple"],
])
"_".join(
[
create_safe_name(arguments["hf_model_name"], ""),
"bs" + str(arguments["batch_size"]),
str(arguments["height"]) + "x" + str(arguments["width"]),
arguments["precision"],
"vae",
arguments["iree_target_triple"],
]
)
+ ".vmfb"
)
arguments["vmfb_path"] = os.path.join("test_vmfbs", vae_filename)
vae_w_filename = (
"_".join([
create_safe_name(arguments["hf_model_name"], ""),
"vae",
arguments["precision"],
])
"_".join(
[
create_safe_name(arguments["hf_model_name"], ""),
"vae",
arguments["precision"],
]
)
+ ".safetensors"
)
arguments["external_weight_path"] = os.path.join(
Expand Down

0 comments on commit 40808db

Please sign in to comment.