Separate punet run
eagarvey-amd committed Oct 4, 2024
1 parent 9fe20a6 commit f39b2d2
138 changes: 41 additions & 97 deletions models/turbine_models/tests/sdxl_test.py
@@ -84,7 +84,7 @@ def command_line_args(request):
 
 @pytest.mark.usefixtures("command_line_args")
 class StableDiffusionXLTest(unittest.TestCase):
-    def test00_compile_pipe(self):
+    def test00_sdxl_pipe(self):
         from turbine_models.custom_models.sd_inference.sd_pipeline import (
             SharkSDPipeline,
         )
@@ -136,12 +136,43 @@ def test00_compile_pipe(self):
         )
         assert output is not None
 
-        # Switch to punet.
-        self.pipe.unload_submodel("unet")
-        self.pipe.use_punet = True
-        self.pipe.use_i8_punet = True
-        self.pipe.setup_punet()
-        self.pipe.map["unet"]["export_args"]["attn_spec"] = None
+    def test01_sdxl_pipe_i8(self):
+        from turbine_models.custom_models.sd_inference.sd_pipeline import (
+            SharkSDPipeline,
+        )
+
+        self.safe_model_name = create_safe_name(arguments["hf_model_name"], "")
+        decomp_attn = {
+            "text_encoder": True,
+            "unet": False,
+            "vae": (
+                False
+                if any(x in arguments["device"] for x in ["hip", "rocm"])
+                else True
+            ),
+        }
+        self.pipe = SharkSDPipeline(
+            arguments["hf_model_name"],
+            arguments["height"],
+            arguments["width"],
+            arguments["batch_size"],
+            arguments["max_length"],
+            arguments["precision"],
+            arguments["device"],
+            arguments["iree_target_triple"],
+            ireec_flags=None,
+            attn_spec=arguments["attn_spec"],
+            decomp_attn=decomp_attn,
+            pipeline_dir="test_vmfbs",
+            external_weights_dir="test_weights",
+            external_weights=arguments["external_weights"],
+            num_inference_steps=arguments["num_inference_steps"],
+            cpu_scheduling=True,
+            scheduler_id=arguments["scheduler_id"],
+            shift=None,
+            use_i8_punet=True,
+            vae_harness=False,
+        )
         self.pipe.prepare_all()
         self.pipe.load_map()
         output = self.pipe.generate_images(
@@ -157,7 +188,7 @@ def test00_compile_pipe(self):
         )
         assert output is not None
 
-    def test01_PromptEncoder(self):
+    def test02_PromptEncoder(self):
         if arguments["device"] in ["vulkan", "cuda"]:
             self.skipTest("Compilation error on vulkan; To be tested on cuda.")
         clip_filename = (
@@ -237,7 +268,7 @@ def test01_PromptEncoder(self):
         np.testing.assert_allclose(torch_output1, turbine_output1, rtol, atol)
         np.testing.assert_allclose(torch_output2, turbine_output2, rtol, atol)
 
-    def test02_unet(self):
+    def test03_unet(self):
         if arguments["device"] in ["vulkan", "cuda"]:
             self.skipTest("Unknown error on vulkan; To be tested on cuda.")
         unet_filename = (
@@ -327,7 +358,7 @@ def test02_unet(self):
         atol = 4e-1
         np.testing.assert_allclose(torch_output, turbine, rtol, atol)
 
-    def test03_ExportVaeModelDecode(self):
+    def test04_ExportVaeModelDecode(self):
         if arguments["device"] in ["vulkan", "cuda"]:
             self.skipTest("Compilation error on vulkan; To be tested on cuda.")
 
@@ -393,93 +424,6 @@ def test03_ExportVaeModelDecode(self):
         atol = 4e-1
         np.testing.assert_allclose(torch_output, turbine, rtol, atol)
 
-    # def test04_punet(self):
-    #     if arguments["device"] in ["vulkan", "cuda"]:
-    #         self.skipTest("Unknown error on vulkan; To be tested on cuda.")
-    #     unet_filename = "_".join(
-    #         create_safe_name(arguments["hf_model_name"], ""),
-    #         "bs" + str(arguments["batch_size"]),
-    #         str(arguments["max_length"]),
-    #         str(arguments["height"]) + "x" + str(arguments["width"]),
-    #         arguments["precision"],
-    #         "unet",
-    #         arguments["device"],
-    #         arguments["iree_target_triple"],
-    #     ) + ".vmfb"
-    #     arguments["vmfb_path"] = os.path.join(
-    #         "test_vmfbs",
-    #         unet_filename
-    #     )
-    #     unet_w_filename = "_".join(
-    #         create_safe_name(arguments["hf_model_name"], ""),
-    #         "unet",
-    #         arguments["precision"],
-    #     ) + ".safetensors"
-    #     arguments["external_weight_path"] = os.path.join(
-    #         "test_weights",
-    #         unet_w_filename,
-    #     )
-    #     dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32
-    #     sample = torch.rand(
-    #         (
-    #             arguments["batch_size"],
-    #             arguments["in_channels"],
-    #             arguments["height"] // 8,
-    #             arguments["width"] // 8,
-    #         ),
-    #         dtype=dtype,
-    #     )
-    #     timestep = torch.zeros(1, dtype=dtype)
-    #     prompt_embeds = torch.rand(
-    #         (2 * arguments["batch_size"], arguments["max_length"], 2048),
-    #         dtype=dtype,
-    #     )
-    #     text_embeds = torch.rand(2 * arguments["batch_size"], 1280, dtype=dtype)
-    #     time_ids = torch.zeros(2 * arguments["batch_size"], 6, dtype=dtype)
-    #     guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype)
-
-    #     turbine = unet_runner.run_punet(
-    #         arguments["rt_device"],
-    #         sample,
-    #         timestep,
-    #         prompt_embeds,
-    #         text_embeds,
-    #         time_ids,
-    #         guidance_scale,
-    #         arguments["vmfb_path"],
-    #         arguments["hf_model_name"],
-    #         arguments["hf_auth_token"],
-    #         arguments["external_weight_path"],
-    #     )
-    #     torch_output = unet_runner.run_torch_unet(
-    #         arguments["hf_model_name"],
-    #         arguments["hf_auth_token"],
-    #         sample.float(),
-    #         timestep,
-    #         prompt_embeds.float(),
-    #         text_embeds.float(),
-    #         time_ids.float(),
-    #         guidance_scale.float(),
-    #         precision=arguments["precision"],
-    #     )
-    #     if arguments["benchmark"] or arguments["tracy_profile"]:
-    #         run_benchmark(
-    #             "unet",
-    #             arguments["vmfb_path"],
-    #             arguments["external_weight_path"],
-    #             arguments["rt_device"],
-    #             max_length=arguments["max_length"],
-    #             height=arguments["height"],
-    #             width=arguments["width"],
-    #             batch_size=arguments["batch_size"],
-    #             in_channels=arguments["in_channels"],
-    #             precision=arguments["precision"],
-    #             tracy_profile=arguments["tracy_profile"],
-    #         )
-    #     rtol = 4e-2
-    #     atol = 4e-1
-    #     np.testing.assert_allclose(torch_output, turbine, rtol, atol)
-
     def tearDown(self):
         gc.collect()
 
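For context on what the new test exercises, here is a minimal standalone sketch of the i8-punet path, mirroring the constructor call added in test01_sdxl_pipe_i8. Every literal value below is an illustrative assumption; the test itself reads all of them from its pytest command-line fixtures.

# Sketch of the int8 partitioned-UNet ("punet") pipeline path that this
# commit separates into its own test. Literal values are illustrative
# guesses, not the test's real arguments (those come from the fixtures).
from turbine_models.custom_models.sd_inference.sd_pipeline import (
    SharkSDPipeline,
)

pipe = SharkSDPipeline(
    "stabilityai/stable-diffusion-xl-base-1.0",  # hf_model_name (assumed)
    1024,  # height (assumed)
    1024,  # width (assumed)
    1,  # batch_size (assumed)
    64,  # max_length (assumed)
    "fp16",  # precision (assumed)
    "rocm",  # device (assumed)
    "gfx942",  # iree_target_triple (assumed)
    ireec_flags=None,
    attn_spec=None,
    # The test keeps VAE attention fused on hip/rocm and decomposes it elsewhere.
    decomp_attn={"text_encoder": True, "unet": False, "vae": False},
    pipeline_dir="test_vmfbs",
    external_weights_dir="test_weights",
    external_weights="safetensors",  # assumed
    num_inference_steps=5,  # assumed
    cpu_scheduling=True,
    scheduler_id="EulerDiscrete",  # assumed
    shift=None,
    use_i8_punet=True,  # the int8 punet run isolated by this commit
    vae_harness=False,
)
pipe.prepare_all()
pipe.load_map()
# Generation then goes through pipe.generate_images(...); its prompt arguments
# are collapsed in the diff above, so they are not reproduced here.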