diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index f0cf63d4..a8bc702a 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -136,7 +136,7 @@ def test00_sdxl_pipe(self): ) assert output is not None - def test01_sdxl_pipe_i8(self): + def test01_sdxl_pipe_i8_punet(self): from turbine_models.custom_models.sd_inference.sd_pipeline import ( SharkSDPipeline, ) @@ -192,7 +192,7 @@ def test02_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Compilation error on vulkan; To be tested on cuda.") clip_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "bs" + str(arguments["batch_size"]), str(arguments["max_length"]), @@ -200,16 +200,16 @@ def test02_PromptEncoder(self): "text_encoder", arguments["device"], arguments["iree_target_triple"], - ) + ]) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", clip_filename) clip_w_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "text_encoder", arguments["precision"], - ) + ]) + ".safetensors" ) arguments["external_weight_path"] = os.path.join( @@ -241,7 +241,7 @@ def test02_PromptEncoder(self): turbine_output2, ) = sdxl_prompt_encoder_runner.run_prompt_encoder( arguments["vmfb_path"], - arguments["rt_driver"], + arguments["rt_device"], arguments["external_weight_path"], text_input_ids_list, uncond_input_ids_list, @@ -259,7 +259,7 @@ def test02_PromptEncoder(self): "prompt_encoder", arguments["vmfb_path"], arguments["external_weight_path"], - arguments["rt_driver"], + arguments["rt_device"], max_length=arguments["max_length"], tracy_profile=arguments["tracy_profile"], ) @@ -272,7 +272,7 @@ def test03_unet(self): if arguments["device"] in ["vulkan", "cuda"]: self.skipTest("Unknown error on vulkan; To be tested on cuda.") unet_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "bs" + str(arguments["batch_size"]), str(arguments["max_length"]), @@ -281,16 +281,16 @@ def test03_unet(self): "unet", arguments["device"], arguments["iree_target_triple"], - ) + ]) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", unet_filename) unet_w_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "unet", arguments["precision"], - ) + ]) + ".safetensors" ) arguments["external_weight_path"] = os.path.join( @@ -363,7 +363,7 @@ def test04_ExportVaeModelDecode(self): self.skipTest("Compilation error on vulkan; To be tested on cuda.") vae_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "bs" + str(arguments["batch_size"]), str(arguments["height"]) + "x" + str(arguments["width"]), @@ -371,16 +371,16 @@ def test04_ExportVaeModelDecode(self): "vae", arguments["device"], arguments["iree_target_triple"], - ) + ]) + ".vmfb" ) arguments["vmfb_path"] = os.path.join("test_vmfbs", vae_filename) vae_w_filename = ( - "_".join( + "_".join([ create_safe_name(arguments["hf_model_name"], ""), "vae", arguments["precision"], - ) + ]) + ".safetensors" ) arguments["external_weight_path"] = os.path.join(