diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py
index 2b6eab9b..f0cf63d4 100644
--- a/models/turbine_models/tests/sdxl_test.py
+++ b/models/turbine_models/tests/sdxl_test.py
@@ -84,7 +84,7 @@ def command_line_args(request):
 
 @pytest.mark.usefixtures("command_line_args")
 class StableDiffusionXLTest(unittest.TestCase):
-    def test00_compile_pipe(self):
+    def test00_sdxl_pipe(self):
         from turbine_models.custom_models.sd_inference.sd_pipeline import (
             SharkSDPipeline,
         )
@@ -136,12 +136,43 @@ def test00_compile_pipe(self):
         )
         assert output is not None
 
-        # Switch to punet.
-        self.pipe.unload_submodel("unet")
-        self.pipe.use_punet = True
-        self.pipe.use_i8_punet = True
-        self.pipe.setup_punet()
-        self.pipe.map["unet"]["export_args"]["attn_spec"] = None
+    def test01_sdxl_pipe_i8(self):
+        from turbine_models.custom_models.sd_inference.sd_pipeline import (
+            SharkSDPipeline,
+        )
+
+        self.safe_model_name = create_safe_name(arguments["hf_model_name"], "")
+        decomp_attn = {
+            "text_encoder": True,
+            "unet": False,
+            "vae": (
+                False
+                if any(x in arguments["device"] for x in ["hip", "rocm"])
+                else True
+            ),
+        }
+        self.pipe = SharkSDPipeline(
+            arguments["hf_model_name"],
+            arguments["height"],
+            arguments["width"],
+            arguments["batch_size"],
+            arguments["max_length"],
+            arguments["precision"],
+            arguments["device"],
+            arguments["iree_target_triple"],
+            ireec_flags=None,
+            attn_spec=arguments["attn_spec"],
+            decomp_attn=decomp_attn,
+            pipeline_dir="test_vmfbs",
+            external_weights_dir="test_weights",
+            external_weights=arguments["external_weights"],
+            num_inference_steps=arguments["num_inference_steps"],
+            cpu_scheduling=True,
+            scheduler_id=arguments["scheduler_id"],
+            shift=None,
+            use_i8_punet=True,
+            vae_harness=False,
+        )
         self.pipe.prepare_all()
         self.pipe.load_map()
         output = self.pipe.generate_images(
@@ -157,7 +188,7 @@ def test00_compile_pipe(self):
         )
         assert output is not None
 
-    def test01_PromptEncoder(self):
+    def test02_PromptEncoder(self):
         if arguments["device"] in ["vulkan", "cuda"]:
             self.skipTest("Compilation error on vulkan; To be tested on cuda.")
         clip_filename = (
@@ -237,7 +268,7 @@ def test01_PromptEncoder(self):
         np.testing.assert_allclose(torch_output1, turbine_output1, rtol, atol)
         np.testing.assert_allclose(torch_output2, turbine_output2, rtol, atol)
 
-    def test02_unet(self):
+    def test03_unet(self):
         if arguments["device"] in ["vulkan", "cuda"]:
             self.skipTest("Unknown error on vulkan; To be tested on cuda.")
         unet_filename = (
@@ -327,7 +358,7 @@ def test02_unet(self):
         atol = 4e-1
         np.testing.assert_allclose(torch_output, turbine, rtol, atol)
 
-    def test03_ExportVaeModelDecode(self):
+    def test04_ExportVaeModelDecode(self):
         if arguments["device"] in ["vulkan", "cuda"]:
             self.skipTest("Compilation error on vulkan; To be tested on cuda.")
 
@@ -393,93 +424,6 @@ def test03_ExportVaeModelDecode(self):
         atol = 4e-1
         np.testing.assert_allclose(torch_output, turbine, rtol, atol)
 
-    # def test04_punet(self):
-    #     if arguments["device"] in ["vulkan", "cuda"]:
-    #         self.skipTest("Unknown error on vulkan; To be tested on cuda.")
-    #     unet_filename = "_".join(
-    #         create_safe_name(arguments["hf_model_name"], ""),
-    #         "bs" + str(arguments["batch_size"]),
-    #         str(arguments["max_length"]),
-    #         str(arguments["height"]) + "x" + str(arguments["width"]),
-    #         arguments["precision"],
-    #         "unet",
-    #         arguments["device"],
-    #         arguments["iree_target_triple"],
-    #     ) + ".vmfb"
-    #     arguments["vmfb_path"] = os.path.join(
-    #         "test_vmfbs",
-    #         unet_filename
-    #     )
-    #     unet_w_filename = "_".join(
-    #         create_safe_name(arguments["hf_model_name"], ""),
-    #         "unet",
-    #         arguments["precision"],
-    #     ) + ".safetensors"
-    #     arguments["external_weight_path"] = os.path.join(
-    #         "test_weights",
-    #         unet_w_filename,
-    #     )
-    #     dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32
-    #     sample = torch.rand(
-    #         (
-    #             arguments["batch_size"],
-    #             arguments["in_channels"],
-    #             arguments["height"] // 8,
-    #             arguments["width"] // 8,
-    #         ),
-    #         dtype=dtype,
-    #     )
-    #     timestep = torch.zeros(1, dtype=dtype)
-    #     prompt_embeds = torch.rand(
-    #         (2 * arguments["batch_size"], arguments["max_length"], 2048),
-    #         dtype=dtype,
-    #     )
-    #     text_embeds = torch.rand(2 * arguments["batch_size"], 1280, dtype=dtype)
-    #     time_ids = torch.zeros(2 * arguments["batch_size"], 6, dtype=dtype)
-    #     guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype)
-
-    #     turbine = unet_runner.run_punet(
-    #         arguments["rt_device"],
-    #         sample,
-    #         timestep,
-    #         prompt_embeds,
-    #         text_embeds,
-    #         time_ids,
-    #         guidance_scale,
-    #         arguments["vmfb_path"],
-    #         arguments["hf_model_name"],
-    #         arguments["hf_auth_token"],
-    #         arguments["external_weight_path"],
-    #     )
-    #     torch_output = unet_runner.run_torch_unet(
-    #         arguments["hf_model_name"],
-    #         arguments["hf_auth_token"],
-    #         sample.float(),
-    #         timestep,
-    #         prompt_embeds.float(),
-    #         text_embeds.float(),
-    #         time_ids.float(),
-    #         guidance_scale.float(),
-    #         precision=arguments["precision"],
-    #     )
-    #     if arguments["benchmark"] or arguments["tracy_profile"]:
-    #         run_benchmark(
-    #             "unet",
-    #             arguments["vmfb_path"],
-    #             arguments["external_weight_path"],
-    #             arguments["rt_device"],
-    #             max_length=arguments["max_length"],
-    #             height=arguments["height"],
-    #             width=arguments["width"],
-    #             batch_size=arguments["batch_size"],
-    #             in_channels=arguments["in_channels"],
-    #             precision=arguments["precision"],
-    #             tracy_profile=arguments["tracy_profile"],
-    #         )
-    #     rtol = 4e-2
-    #     atol = 4e-1
-    #     np.testing.assert_allclose(torch_output, turbine, rtol, atol)
-
 
     def tearDown(self):
         gc.collect()