From 1dea19e694dab1d39fe600551c189cf00874ef84 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Thu, 11 Apr 2024 15:13:03 -0500 Subject: [PATCH] Adds SDXL support and CI testing, benchmarks. (#271) - Updates turbine-models requirements to use nod-ai fork of diffusers - Adds SDXL implementations and tests, benchmarks - Updates sd_inference/utils with newest flags and adds scheduler scaffolding to sd1.5/2.1 Co-authored-by: jinchen62 Co-authored-by: jinchen <49575973+jinchen62@users.noreply.github.com> Co-authored-by: PhaneeshB Co-authored-by: Avinash Sharma Co-authored-by: gpetters94 Co-authored-by: George Petterson Co-authored-by: aviator19941 Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com> --- .github/workflows/test_models.yml | 6 +- core/iree-requirements.txt | 2 +- core/shark_turbine/aot/builtins/jittable.py | 2 +- core/shark_turbine/dynamo/decompositions.py | 2 - models/requirements.txt | 2 +- .../README.md | 0 .../benchmark.mlir | 0 .../benchmark_forward.mlir | 0 .../benchmark_module.py | 0 .../stateless_llama_benchmark.py | 130 +-- .../sd_inference/schedulers_runner.py | 148 ++- .../custom_models/sd_inference/unet.py | 52 +- .../custom_models/sd_inference/unet_runner.py | 18 +- .../custom_models/sd_inference/utils.py | 193 ++- .../custom_models/sd_inference/vae.py | 75 +- .../custom_models/sd_inference/vae_runner.py | 64 +- .../custom_models/sdxl_inference/COMMANDS.md | 126 ++ .../custom_models/sdxl_inference/README.md | 32 + .../custom_models/sdxl_inference/clip.py | 207 ++++ .../sdxl_inference/clip_runner.py | 274 +++++ .../default_mfma_attn_spec.mlir | 1040 +++++++++++++++++ .../sdxl_inference/sdxl_benchmark.py | 73 ++ .../sdxl_inference/sdxl_cmd_opts.py | 289 +++++ .../sdxl_inference/sdxl_compiled_pipeline.py | 693 +++++++++++ .../sdxl_pipeline_bench_f16.mlir | 23 + .../sdxl_pipeline_bench_f32.mlir | 23 + .../sdxl_inference/sdxl_prompt_encoder.py | 247 ++++ .../sdxl_prompt_encoder_runner.py | 164 +++ .../sdxl_sched_unet_bench_f16.mlir | 19 + .../sdxl_sched_unet_bench_f32.mlir | 19 + .../sdxl_inference/sdxl_scheduled_unet.py | 344 ++++++ .../sdxl_scheduled_unet_runner.py | 358 ++++++ .../sdxl_inference/sdxl_schedulers.py | 197 ++++ .../custom_models/sdxl_inference/unet.py | 253 ++++ .../sdxl_inference/unet_runner.py | 166 +++ .../custom_models/sdxl_inference/vae.py | 213 ++++ .../sdxl_inference/vae_runner.py | 124 ++ models/turbine_models/model_runner.py | 61 +- models/turbine_models/tests/conftest.py | 53 + models/turbine_models/tests/sd_test.py | 65 +- models/turbine_models/tests/sdxl_test.py | 601 ++++++++++ .../turbine_tank/turbine_tank.py | 12 +- models/turbine_models/utils/benchmark.py | 137 +++ models/turbine_models/utils/sdxl_benchmark.py | 75 ++ serving/setup.py | 1 + 45 files changed, 6286 insertions(+), 297 deletions(-) rename models/turbine_models/custom_models/{llama-benchmark => llama_benchmark}/README.md (100%) rename models/turbine_models/custom_models/{llama-benchmark => llama_benchmark}/benchmark.mlir (100%) rename models/turbine_models/custom_models/{llama-benchmark => llama_benchmark}/benchmark_forward.mlir (100%) rename models/turbine_models/custom_models/{llama-benchmark => llama_benchmark}/benchmark_module.py (100%) rename models/turbine_models/custom_models/{llama-benchmark => llama_benchmark}/stateless_llama_benchmark.py (62%) create mode 100644 models/turbine_models/custom_models/sdxl_inference/COMMANDS.md create mode 100644 models/turbine_models/custom_models/sdxl_inference/README.md create mode 100644 models/turbine_models/custom_models/sdxl_inference/clip.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/clip_runner.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/default_mfma_attn_spec.mlir create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_benchmark.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f16.mlir create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f32.mlir create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f16.mlir create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f32.mlir create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/unet.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/unet_runner.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/vae.py create mode 100644 models/turbine_models/custom_models/sdxl_inference/vae_runner.py create mode 100644 models/turbine_models/tests/conftest.py create mode 100644 models/turbine_models/tests/sdxl_test.py create mode 100644 models/turbine_models/utils/benchmark.py create mode 100644 models/turbine_models/utils/sdxl_benchmark.py diff --git a/.github/workflows/test_models.yml b/.github/workflows/test_models.yml index 3a478f179..09129fdc2 100644 --- a/.github/workflows/test_models.yml +++ b/.github/workflows/test_models.yml @@ -34,6 +34,7 @@ jobs: uses: actions/checkout@v2 - name: Sync source deps + # build IREE from source with -DIREE_BUILD_TRACY=ON if getting tracy profile run: | python -m venv turbine_venv source turbine_venv/bin/activate @@ -44,7 +45,7 @@ jobs: pip install -r core/pytorch-cpu-requirements.txt pip install --pre --upgrade -r core/requirements.txt pip install --pre -e core[testing] - pip install --pre -e models + pip install --pre --upgrade -e models -r models/requirements.txt - name: Show current free memory run: | @@ -59,3 +60,6 @@ jobs: run: | source turbine_venv/bin/activate pytest models/turbine_models/tests/sd_test.py + pytest models/turbine_models/tests/sdxl_test.py --device cpu --rt_device local-task --iree_target_triple x86_64-linux-gnu + pytest models/turbine_models/tests/sdxl_test.py --device vulkan --rt_device vulkan --iree_target_triple rdna3-unknown-linux + pytest models/turbine_models/tests/sdxl_test.py --device rocm --rt_device hip --iree_target_triple gfx90a --precision fp16 diff --git a/core/iree-requirements.txt b/core/iree-requirements.txt index da59aeae6..eaa171b4e 100644 --- a/core/iree-requirements.txt +++ b/core/iree-requirements.txt @@ -1,2 +1,2 @@ iree-compiler==20240410.859 -iree-runtime==20240410.859 +iree-runtime==20240410.859 \ No newline at end of file diff --git a/core/shark_turbine/aot/builtins/jittable.py b/core/shark_turbine/aot/builtins/jittable.py index fedbcae55..29a90617b 100644 --- a/core/shark_turbine/aot/builtins/jittable.py +++ b/core/shark_turbine/aot/builtins/jittable.py @@ -226,7 +226,7 @@ def flat_wrapped_f(*args): ) logger.debug("Invoking dynamo trace") gm, guards = exported_f(*flat_pytorch_args) - logger.debug("Dyanmo trace complete") + logger.debug("Dynamo trace complete") # TODO: Add debug logging for the exported graph module. # gm.print_readable() diff --git a/core/shark_turbine/dynamo/decompositions.py b/core/shark_turbine/dynamo/decompositions.py index 8e4b1fea5..84f630c23 100644 --- a/core/shark_turbine/dynamo/decompositions.py +++ b/core/shark_turbine/dynamo/decompositions.py @@ -115,8 +115,6 @@ def _get_default_decomposition_ops() -> DecompositionOpsList: aten.lift_fresh_copy.default, aten._unsafe_index.Tensor, aten.unbind.int, - # decompositions added manually in this file - aten._scaled_dot_product_flash_attention.default, ] diff --git a/models/requirements.txt b/models/requirements.txt index d779002c9..ed2a0b0c1 100644 --- a/models/requirements.txt +++ b/models/requirements.txt @@ -3,7 +3,7 @@ sentencepiece shark_turbine transformers==4.37.1 accelerate -diffusers==0.24.0 +diffusers @ git+https://github.com/nod-ai/diffusers@v0.24.0-release brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b # turbine tank downloading/uploading azure-storage-blob diff --git a/models/turbine_models/custom_models/llama-benchmark/README.md b/models/turbine_models/custom_models/llama_benchmark/README.md similarity index 100% rename from models/turbine_models/custom_models/llama-benchmark/README.md rename to models/turbine_models/custom_models/llama_benchmark/README.md diff --git a/models/turbine_models/custom_models/llama-benchmark/benchmark.mlir b/models/turbine_models/custom_models/llama_benchmark/benchmark.mlir similarity index 100% rename from models/turbine_models/custom_models/llama-benchmark/benchmark.mlir rename to models/turbine_models/custom_models/llama_benchmark/benchmark.mlir diff --git a/models/turbine_models/custom_models/llama-benchmark/benchmark_forward.mlir b/models/turbine_models/custom_models/llama_benchmark/benchmark_forward.mlir similarity index 100% rename from models/turbine_models/custom_models/llama-benchmark/benchmark_forward.mlir rename to models/turbine_models/custom_models/llama_benchmark/benchmark_forward.mlir diff --git a/models/turbine_models/custom_models/llama-benchmark/benchmark_module.py b/models/turbine_models/custom_models/llama_benchmark/benchmark_module.py similarity index 100% rename from models/turbine_models/custom_models/llama-benchmark/benchmark_module.py rename to models/turbine_models/custom_models/llama_benchmark/benchmark_module.py diff --git a/models/turbine_models/custom_models/llama-benchmark/stateless_llama_benchmark.py b/models/turbine_models/custom_models/llama_benchmark/stateless_llama_benchmark.py similarity index 62% rename from models/turbine_models/custom_models/llama-benchmark/stateless_llama_benchmark.py rename to models/turbine_models/custom_models/llama_benchmark/stateless_llama_benchmark.py index fdf1657bf..2ce93cb73 100644 --- a/models/turbine_models/custom_models/llama-benchmark/stateless_llama_benchmark.py +++ b/models/turbine_models/custom_models/llama_benchmark/stateless_llama_benchmark.py @@ -4,19 +4,17 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import sys +import argparse import numpy as np -import re import os +import re +import sys from transformers import AutoTokenizer from iree import runtime as ireert +from turbine_models.utils.benchmark import benchmark_module import turbine_models.custom_models.stateless_llama as llama -import argparse - -import subprocess -from collections import namedtuple parser = argparse.ArgumentParser() parser.add_argument( @@ -71,16 +69,20 @@ def run_benchmark(args): input.append(temp) input.append(np.array(args.steps)) + vmfbs = [] + vmfbs.append(args.llama_vmfb_path) + vmfbs.append(args.benchmark_vmfb_path) + if args.external_weight_file: results = benchmark_module( benchmark_mod, - args, "run", + vmfbs, input, parameters=f"model={args.external_weight_file}", ) else: - results = benchmark_module(benchmark_mod, args, "run", input) + results = benchmark_module(benchmark_mod, "run", vmfbs, input) for benchmark_result in results: print( @@ -146,16 +148,20 @@ def run_forward_benchmark(args): input.append(temp) input.append(np.array(args.steps)) + vmfbs = [] + vmfbs.append(args.llama_vmfb_path) + vmfbs.append(args.benchmark_vmfb_path) + if args.external_weight_file: results = benchmark_module( benchmark_mod, - args, "run", + vmfbs, input, parameters=f"model={args.external_weight_file}", ) else: - results = benchmark_module(benchmark_mod, args, "run", input) + results = benchmark_module(benchmark_mod, "run", vmfbs, input) for benchmark_result in results: print( @@ -178,110 +184,6 @@ def run_forward_benchmark(args): np.dtype(np.bool_): "i1", } -BenchmarkResult = namedtuple( - "BenchmarkResult", "benchmark_name time cpu_time iterations user_counters" -) - - -class BenchmarkToolError(Exception): - """Benchmark exception that preserves the command line and error output.""" - - def __init__(self, message): - self.message = message - super().__init__(self.message) - - -class BenchmarkTimeoutError(Exception): - """Exception raised if the benchmark is cancelled by the user specified timeout.""" - - pass - - -def benchmark_module( - module, bench_args, entry_function=None, inputs=[], timeout=None, **kwargs -): - funcs = [a for a in module.function_names if a != "__init"] - if entry_function is None: - if len(funcs) > 1: - raise ValueError(f"No function specified with multiple options {funcs}") - entry_function = funcs[0] - if entry_function not in funcs: - raise ValueError( - f"Attempted to benchmark unknown function {entry_function} of options {funcs}" - ) - - args = [ireert.benchmark_exe()] - args.append(f"--function={entry_function}") - - for inp in inputs: - if isinstance(inp, str): - args.append(f"--input={inp}") - continue - shape = "x".join([str(d) for d in inp.shape]) - abitype = DTYPE_TO_ABI_TYPE[inp.dtype] - values = inp.flatten() - if np.all(values[0] == values): - values = str(values[0]) - else: - values = ",".join([str(v) for v in values]) - - args.append(f"--input={shape}x{abitype}={values}") - - for k in kwargs: - v = kwargs[k] - args.append(f"--{k}={v}") - - args.append(f"--module={bench_args.llama_vmfb_path}") - args.append(f"--module={bench_args.benchmark_vmfb_path}") - - try: - benchmark_process = subprocess.run( - args=args, - # input=flatbuffer, - timeout=timeout, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - except subprocess.TimeoutExpired: - raise BenchmarkTimeoutError(f"Benchmark timed out after {timeout} seconds") - out = benchmark_process.stdout - err = benchmark_process.stderr - - err = err.decode() - if "INVALID_ARGUMENT;" in err: - raise ValueError("Invalid inputs specified for benchmarking") - - # In the event benchmarking runs but encounteres an internal error, - # return the internal error instead of benchmark results. - if "INTERNAL; CUDA driver error" in str(out): - raise BenchmarkToolError(str(out)) - - # Grab individual results by line (skip header lines) - bench_lines = out.decode().split("\n")[3:] - benchmark_results = [] - for line in bench_lines: - split = line.split() - if len(split) == 0: - continue - benchmark_name = split[0] - time = " ".join(split[1:3]) - cpu_time = " ".join(split[3:5]) - iterations = split[5] - user_counters = None - if len(split) > 5: - user_counters = split[6] - benchmark_results.append( - BenchmarkResult( - benchmark_name=benchmark_name, - time=time, - cpu_time=cpu_time, - iterations=iterations, - user_counters=user_counters, - ) - ) - - return benchmark_results - if __name__ == "__main__": args = parser.parse_args() diff --git a/models/turbine_models/custom_models/sd_inference/schedulers_runner.py b/models/turbine_models/custom_models/sd_inference/schedulers_runner.py index 2490f8ebf..45663c0a6 100644 --- a/models/turbine_models/custom_models/sd_inference/schedulers_runner.py +++ b/models/turbine_models/custom_models/sd_inference/schedulers_runner.py @@ -43,7 +43,7 @@ "--hf_model_name", type=str, help="HF model name", - default="CompVis/stable-diffusion-v1-4", + default="stabilityai/stable-diffusion-xl-base-1.0", ) parser.add_argument( "--hf_auth_token", @@ -60,9 +60,9 @@ "--batch_size", type=int, default=1, help="Batch size for inference" ) parser.add_argument( - "--height", type=int, default=512, help="Height of Stable Diffusion" + "--height", type=int, default=1024, help="Height of Stable Diffusion" ) -parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") +parser.add_argument("--width", type=int, default=1024, help="Width of Stable Diffusion") def run_scheduler( @@ -84,42 +84,114 @@ def run_scheduler( return results +def run_sdxl_scheduler( + device, + sample, + prompt_embeds, + text_embeds, + time_ids, + vmfb_path, + hf_model_name, + hf_auth_token, + external_weight_path, +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + + inputs = [ + ireert.asdevicearray(runner.config.device, sample), + ireert.asdevicearray(runner.config.device, prompt_embeds), + ireert.asdevicearray(runner.config.device, text_embeds), + ireert.asdevicearray(runner.config.device, time_ids), + ] + results = runner.ctx.modules.compiled_scheduler["main"](*inputs) + return results + + def run_torch_scheduler( - hf_model_name, scheduler, num_inference_steps, sample, encoder_hidden_states + hf_model_name, + scheduler, + num_inference_steps, + sample, + prompt_embeds, + text_embeds, + time_ids, ): - class Scheduler(torch.nn.Module): - def __init__(self, hf_model_name, num_inference_steps, scheduler): + class SDXLScheduler(torch.nn.Module): + def __init__( + self, + hf_model_name, + num_inference_steps, + scheduler, + hf_auth_token=None, + precision="fp32", + ): super().__init__() self.scheduler = scheduler self.scheduler.set_timesteps(num_inference_steps) - self.unet = UNet2DConditionModel.from_pretrained( - hf_model_name, - subfolder="unet", - ) self.guidance_scale = 7.5 - - def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor: - latents = latents * self.scheduler.init_noise_sigma - for t in self.scheduler.timesteps: - latent_model_input = torch.cat([latents] * 2) - t = t.unsqueeze(0) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, timestep=t + if precision == "fp16": + try: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + variant="fp16", + ) + except: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + else: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, ) - unet_out = self.unet.forward( - latent_model_input, t, encoder_hidden_states, return_dict=False - )[0] - noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - latents = self.scheduler.step( - noise_pred, t, latents, return_dict=False - )[0] - return latents - scheduler_module = Scheduler(hf_model_name, num_inference_steps, scheduler) - results = scheduler_module.forward(sample, encoder_hidden_states) + def forward(self, sample, prompt_embeds, text_embeds, time_ids): + sample = sample * self.scheduler.init_noise_sigma + for t in self.scheduler.timesteps: + with torch.no_grad(): + added_cond_kwargs = { + "text_embeds": text_embeds, + "time_ids": time_ids, + } + latent_model_input = torch.cat([sample] * 2) + t = t.unsqueeze(0) + # print('UNSQUEEZE T:', t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, timestep=t + ) + noise_pred = self.unet.forward( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + sample = self.scheduler.step( + noise_pred, t, sample, return_dict=False + )[0] + return sample + + scheduler_module = SDXLScheduler( + hf_model_name, + num_inference_steps, + scheduler, + hf_auth_token=None, + precision="fp16", + ) + results = scheduler_module.forward(sample, prompt_embeds, text_embeds, time_ids) np_torch_output = results.detach().cpu().numpy() return np_torch_output @@ -134,10 +206,16 @@ def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor: elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) - turbine_output = run_scheduler( + sample = torch.rand(args.batch_size, 4, args.height // 8, args.width // 8) + prompt_embeds = torch.rand(2, 77, 2048) + text_embeds = torch.rand(2, 1280) + time_ids = torch.rand(2, 6) + turbine_output = run_sdxl_scheduler( args.device, sample, - encoder_hidden_states, + prompt_embeds, + text_embeds, + time_ids, args.vmfb_path, args.hf_model_name, args.hf_auth_token, @@ -161,7 +239,9 @@ def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor: scheduler, args.num_inference_steps, sample, - encoder_hidden_states, + prompt_embeds, + text_embeds, + time_ids, ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) err = utils.largest_error(torch_output, turbine_output) diff --git a/models/turbine_models/custom_models/sd_inference/unet.py b/models/turbine_models/custom_models/sd_inference/unet.py index 7ac419d3b..18657ae86 100644 --- a/models/turbine_models/custom_models/sd_inference/unet.py +++ b/models/turbine_models/custom_models/sd_inference/unet.py @@ -11,6 +11,9 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo @@ -37,6 +40,12 @@ "--height", type=int, default=512, help="Height of Stable Diffusion" ) parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") +parser.add_argument( + "--precision", type=str, default="fp16", help="Precision of Stable Diffusion" +) +parser.add_argument( + "--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion" +) parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") parser.add_argument("--external_weight_path", type=str, default="") parser.add_argument( @@ -57,22 +66,20 @@ class UnetModel(torch.nn.Module): - def __init__(self, hf_model_name, hf_auth_token): + def __init__(self, hf_model_name, hf_auth_token=None): super().__init__() self.unet = UNet2DConditionModel.from_pretrained( hf_model_name, subfolder="unet", - token=hf_auth_token, ) - self.guidance_scale = 7.5 - def forward(self, sample, timestep, encoder_hidden_states): + def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): samples = torch.cat([sample] * 2) unet_out = self.unet.forward( samples, timestep, encoder_hidden_states, return_dict=False )[0] noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred = noise_pred_uncond + guidance_scale * ( noise_pred_text - noise_pred_uncond ) return noise_pred @@ -84,6 +91,8 @@ def export_unet_model( batch_size, height, width, + precision="fp32", + max_length=77, hf_auth_token=None, compile_to="torch", external_weights=None, @@ -92,15 +101,27 @@ def export_unet_model( target_triple=None, max_alloc=None, upload_ir=False, + decomp_attn=True, ): mapper = {} + decomp_list = DEFAULT_DECOMPOSITIONS + if decomp_attn: + decomp_list.extend( + [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + ] + ) + dtype = torch.float16 if precision == "fp16" else torch.float32 + unet_model = unet_model.to(dtype) utils.save_external_weights( mapper, unet_model, external_weights, external_weight_path ) - - encoder_hidden_states_sizes = (2, 77, 768) - if hf_model_name == "stabilityai/stable-diffusion-2-1-base": - encoder_hidden_states_sizes = (2, 77, 1024) + encoder_hidden_states_sizes = ( + unet_model.unet.config.layers_per_block, + max_length, + unet_model.unet.config.cross_attention_dim, + ) sample = (batch_size, unet_model.unet.config.in_channels, height // 8, width // 8) @@ -114,13 +135,16 @@ class CompiledUnet(CompiledModule): def main( self, - sample=AbstractTensor(*sample, dtype=torch.float32), - timestep=AbstractTensor(1, dtype=torch.float32), + sample=AbstractTensor(*sample, dtype=dtype), + timestep=AbstractTensor(1, dtype=dtype), encoder_hidden_states=AbstractTensor( - *encoder_hidden_states_sizes, dtype=torch.float32 + *encoder_hidden_states_sizes, dtype=dtype ), + guidance_scale=AbstractTensor(1, dtype=dtype), ): - return jittable(unet_model.forward)(sample, timestep, encoder_hidden_states) + return jittable(unet_model.forward, decompose_ops=decomp_list)( + sample, timestep, encoder_hidden_states, guidance_scale + ) import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledUnet(context=Context(), import_to=import_to) @@ -156,6 +180,8 @@ def main( args.batch_size, args.height, args.width, + args.precision, + args.max_length, args.hf_auth_token, args.compile_to, args.external_weights, diff --git a/models/turbine_models/custom_models/sd_inference/unet_runner.py b/models/turbine_models/custom_models/sd_inference/unet_runner.py index 2f73493a2..1b8c5d101 100644 --- a/models/turbine_models/custom_models/sd_inference/unet_runner.py +++ b/models/turbine_models/custom_models/sd_inference/unet_runner.py @@ -52,6 +52,7 @@ def run_unet( sample, timestep, encoder_hidden_states, + guidance_scale, vmfb_path, hf_model_name, hf_auth_token, @@ -63,13 +64,19 @@ def run_unet( ireert.asdevicearray(runner.config.device, sample), ireert.asdevicearray(runner.config.device, timestep), ireert.asdevicearray(runner.config.device, encoder_hidden_states), + ireert.asdevicearray(runner.config.device, guidance_scale), ] results = runner.ctx.modules.compiled_unet["main"](*inputs) return results def run_torch_unet( - hf_model_name, hf_auth_token, sample, timestep, encoder_hidden_states + hf_model_name, + hf_auth_token, + sample, + timestep, + encoder_hidden_states, + guidance_scale, ): from diffusers import UNet2DConditionModel @@ -83,7 +90,7 @@ def __init__(self, hf_model_name, hf_auth_token): ) self.guidance_scale = 7.5 - def forward(self, sample, timestep, encoder_hidden_states): + def forward(self, sample, timestep, encoder_hidden_states, guidance_scale): samples = torch.cat([sample] * 2) unet_out = self.unet.forward( samples, timestep, encoder_hidden_states, return_dict=False @@ -98,7 +105,9 @@ def forward(self, sample, timestep, encoder_hidden_states): hf_model_name, hf_auth_token, ) - results = unet_model.forward(sample, timestep, encoder_hidden_states) + results = unet_model.forward( + sample, timestep, encoder_hidden_states, guidance_scale + ) np_torch_output = results.detach().cpu().numpy() return np_torch_output @@ -109,6 +118,7 @@ def forward(self, sample, timestep, encoder_hidden_states): args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32 ) timestep = torch.zeros(1, dtype=torch.float32) + guidance_scale = torch.Tensor([7.5], dtype=torch.float32) if args.hf_model_name == "CompVis/stable-diffusion-v1-4": encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base": @@ -119,6 +129,7 @@ def forward(self, sample, timestep, encoder_hidden_states): sample, timestep, encoder_hidden_states, + guidance_scale, args.vmfb_path, args.hf_model_name, args.hf_auth_token, @@ -141,6 +152,7 @@ def forward(self, sample, timestep, encoder_hidden_states): sample, timestep, encoder_hidden_states, + guidance_scale, ) print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) err = utils.largest_error(torch_output, turbine_output) diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 2182ee168..2ce0ef601 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -1,47 +1,79 @@ import iree.compiler as ireec import numpy as np +import os import safetensors import re from diffusers import ( PNDMScheduler, + EulerDiscreteScheduler, ) +# If flags are verified to work on a specific model and improve performance without regressing numerics, add them to this dictionary. If you are working with bleeding edge flags, please add them manually with the --ireec_flags argument. +gfx94X_flags = { + "all": [ + "--iree-global-opt-propagate-transposes=true", + "--iree-opt-outer-dim-concat=true", + "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-llvmgpu-enable-prefetch=true", + "--verify=false", + "--iree-rocm-waves-per-eu=2", + "--iree-opt-data-tiling=false", + "--iree-codegen-log-swizzle-tile=4", + "--iree-llvmgpu-promote-filter=true", + "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)", + ], + "unet": [ + "--iree-codegen-llvmgpu-use-conv-vector-distribute-pipeline", + "--iree-codegen-llvmgpu-reduce-skinny-matmuls", + "--iree-codegen-gpu-native-math-precision=true", + "--iree-codegen-llvmgpu-use-vector-distribution", + "--iree-codegen-winograd-use-forall", + ], + "clip": [ + "--iree-codegen-llvmgpu-use-vector-distribution", + "--iree-codegen-llvmgpu-reduce-skinny-matmuls", + "--iree-global-opt-only-sink-transposes=true", + ], + "vae": [ + "--iree-codegen-llvmgpu-use-conv-vector-distribute-pipeline", + "--iree-codegen-llvmgpu-use-vector-distribution", + "--iree-global-opt-only-sink-transposes=true", + "--iree-codegen-winograd-use-forall", + ], +} -def save_external_weights( - mapper, - model, - external_weights=None, - external_weight_file=None, -): - if external_weights is not None: - if external_weights == "safetensors": - mod_params = dict(model.named_parameters()) - for name in mod_params: - mapper["params." + name] = name - if external_weight_file: - safetensors.torch.save_file(mod_params, external_weight_file) - print("Saved params to", external_weight_file) - - -def largest_error(array1, array2): - absolute_diff = np.abs(array1 - array2) - max_error = np.max(absolute_diff) - return max_error - -def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name): - flags = [ - "--iree-input-type=torch", - "--mlir-print-debuginfo", - "--mlir-print-op-on-diagnostic=false", - "--iree-llvmcpu-target-cpu-features=host", - "--iree-llvmcpu-target-triple=x86_64-linux-gnu", - "--iree-stream-resource-index-bits=64", - "--iree-vm-target-index-bits=64", - "--iree-flow-inline-constants-max-byte-length=1", - ] +def compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags=[""], + safe_name="model", + return_path=False, + const_expr_hoisting=True, + mlir_source="str", + max_alloc="4294967296", + save_mlir=False, + attn_spec=None, +): + flags = [] + if target_triple in ["", None]: + if device == "cpu": + target_triple = "x86_64-linux-gnu" + else: + raise ValueError( + "target_triple must be set. Usually this can be fixed by setting --iree_target_triple in the CLI." + ) if device == "cpu": - flags.append("--iree-llvmcpu-enable-ukernels=all") + flags.extend( + [ + "--iree-llvmcpu-target-triple=" + target_triple, + "--iree-llvmcpu-target-cpu-features=host", + "--iree-llvmcpu-fail-on-out-of-bounds-stack-allocation=false", + "--iree-llvmcpu-distribution-size=32", + ] + ) device = "llvm-cpu" elif device == "vulkan": flags.extend( @@ -49,18 +81,19 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name): "--iree-hal-target-backends=vulkan-spirv", "--iree-vulkan-target-triple=" + target_triple, "--iree-stream-resource-max-allocation-size=" + max_alloc, + "--iree-stream-resource-index-bits=64", + "--iree-vm-target-index-bits=64", + "--iree-flow-inline-constants-max-byte-length=1", ] ) + device = "vulkan-spirv" elif device == "rocm": flags.extend( [ "--iree-hal-target-backends=rocm", "--iree-rocm-target-chip=" + target_triple, - "--iree-rocm-link-bc=true", - "--iree-rocm-bc-dir=/opt/rocm/amdgcn/bitcode", - "--iree-vm-bytecode-module-strip-source-map=true", - "--iree-opt-strip-assertions=true", - "--iree-vm-target-truncate-unsupported-floats", + "--verify=false", + "--iree-opt-const-eval=false", ] ) elif device == "cuda": @@ -68,29 +101,97 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name): [ "--iree-hal-target-backends=cuda", "--iree-hal-cuda-llvm-target-arch=" + target_triple, - "--iree-vm-bytecode-module-strip-source-map=true", "--iree-vm-target-truncate-unsupported-floats", ] ) else: print("incorrect device: ", device) + if isinstance(ireec_flags, str): + if ireec_flags != "": + ireec_flags = ireec_flags.split(",") + elif ireec_flags == None: + ireec_flags = [] - flatbuffer_blob = ireec.compile_str( - module_str, - target_backends=[device], - extra_args=flags, - ) + for i, flag in enumerate(ireec_flags): + k = flag.strip().split("=")[0] + for idx, default in enumerate(flags): + if k == default.split("=")[0]: + flags[idx] = flag + ireec_flags[i] = "" + if flag not in [None, "", " "]: + flags.append(flag) + + if target_triple in ["gfx940", "gfx941", "gfx942"]: + if "unet" in safe_name: + flags.extend(gfx94X_flags["unet"]) + elif any(x in safe_name for x in ["clip", "prompt_encoder"]): + flags.extend(gfx94X_flags["clip"]) + elif "vae" in safe_name: + flags.extend(gfx94X_flags["vae"]) + flags.extend(gfx94X_flags["all"]) + + if attn_spec not in [None, "", " "]: + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) + + print("Compiling to", device, "with flags:", flags) + + if mlir_source == "file": + flatbuffer_blob = ireec.compile_file( + module_str, + target_backends=[device], + input_type="torch", + extra_args=flags, + ) + elif mlir_source == "str": + if save_mlir: + with open(f"{safe_name}.mlir", "w+") as f: + f.write(module_str) + print("Saved to", safe_name + ".mlir") + flatbuffer_blob = ireec.compile_str( + module_str, + target_backends=[device], + input_type="torch", + extra_args=flags, + ) + else: + raise ValueError("mlir_source must be either 'file' or 'str'") with open(f"{safe_name}.vmfb", "wb+") as f: f.write(flatbuffer_blob) print("Saved to", safe_name + ".vmfb") + if return_path == True: + return safe_name + ".vmfb" def create_safe_name(hf_model_name, model_name_str): safe_name = hf_model_name.split("/")[-1].strip() + model_name_str safe_name = re.sub("-", "_", safe_name) + safe_name = re.sub("\.", "_", safe_name) return safe_name +def save_external_weights( + mapper, + model, + external_weights=None, + external_weight_file=None, +): + if external_weights is not None: + if external_weights in ["safetensors", "irpa"]: + mod_params = dict(model.named_parameters()) + for name in mod_params: + mapper["params." + name] = name + if external_weight_file and not os.path.isfile(external_weight_file): + safetensors.torch.save_file(mod_params, external_weight_file) + print("Saved params to", external_weight_file) + + +def largest_error(array1, array2): + absolute_diff = np.abs(array1 - array2) + max_error = np.max(absolute_diff) + print("Max error:", max_error) + return max_error + + def get_schedulers(model_id): # TODO: Robust scheduler setup on pipeline creation -- if we don't # set batch_size here, the SHARK schedulers will @@ -106,4 +207,8 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) + schedulers["Euler"] = EulerDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", + ) return schedulers diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index 46b758f15..0916acda0 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -11,19 +11,17 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) from turbine_models.custom_models.sd_inference import utils import torch import torch._dynamo as dynamo from diffusers import AutoencoderKL - -import safetensors import argparse from turbine_models.turbine_tank import turbine_tank parser = argparse.ArgumentParser() -parser.add_argument( - "--hf_auth_token", type=str, help="The Hugging Face auth token, required" -) parser.add_argument( "--hf_model_name", type=str, @@ -37,6 +35,9 @@ "--height", type=int, default=512, help="Height of Stable Diffusion" ) parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion") +parser.add_argument( + "--precision", type=str, default="fp32", help="Precision of Stable Diffusion" +) parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb") parser.add_argument("--external_weight_path", type=str, default="") parser.add_argument( @@ -58,18 +59,42 @@ class VaeModel(torch.nn.Module): - def __init__(self, hf_model_name, hf_auth_token): + def __init__( + self, + hf_model_name, + custom_vae="", + ): super().__init__() - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - token=hf_auth_token, - ) + self.vae = None + if custom_vae in ["", None]: + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + elif not isinstance(custom_vae, dict): + try: + # custom HF repo with no vae subfolder + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + ) + except: + # some larger repo with vae subfolder + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + subfolder="vae", + ) + else: + # custom vae as a HF state dict + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + self.vae.load_state_dict(custom_vae) def decode_inp(self, inp): - with torch.no_grad(): - x = self.vae.decode(inp, return_dict=False)[0] - return x + inp = 1 / 0.18215 * inp + x = self.vae.decode(inp, return_dict=False)[0] + return (x / 2 + 0.5).clamp(0, 1) def encode_inp(self, inp): latents = self.vae.encode(inp).latent_dist.sample() @@ -82,7 +107,7 @@ def export_vae_model( batch_size, height, width, - hf_auth_token=None, + precision, compile_to="torch", external_weights=None, external_weight_path=None, @@ -91,8 +116,19 @@ def export_vae_model( max_alloc=None, variant="decode", upload_ir=False, + decomp_attn=True, ): mapper = {} + decomp_list = DEFAULT_DECOMPOSITIONS + if decomp_attn: + decomp_list.extend( + [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + ] + ) + dtype = torch.float16 if precision == "fp16" else torch.float32 + vae_model = vae_model.to(dtype) utils.save_external_weights( mapper, vae_model, external_weights, external_weight_path ) @@ -104,11 +140,11 @@ def export_vae_model( class CompiledVae(CompiledModule): params = export_parameters(vae_model) - def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): + def main(self, inp=AbstractTensor(*sample, dtype=dtype)): if variant == "decode": - return jittable(vae_model.decode_inp)(inp) + return jittable(vae_model.decode_inp, decompose_ops=decomp_list)(inp) elif variant == "encode": - return jittable(vae_model.encode_inp)(inp) + return jittable(vae_model.encode_inp, decompose_ops=decomp_list)(inp) import_to = "INPUT" if compile_to == "linalg" else "IMPORT" inst = CompiledVae(context=Context(), import_to=import_to) @@ -136,7 +172,6 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): args = parser.parse_args() vae_model = VaeModel( args.hf_model_name, - args.hf_auth_token, ) mod_str = export_vae_model( vae_model, @@ -144,7 +179,7 @@ def main(self, inp=AbstractTensor(*sample, dtype=torch.float32)): args.batch_size, args.height, args.width, - args.hf_auth_token, + args.precision, args.compile_to, args.external_weights, args.external_weight_path, diff --git a/models/turbine_models/custom_models/sd_inference/vae_runner.py b/models/turbine_models/custom_models/sd_inference/vae_runner.py index fa5e430ac..cce53c118 100644 --- a/models/turbine_models/custom_models/sd_inference/vae_runner.py +++ b/models/turbine_models/custom_models/sd_inference/vae_runner.py @@ -27,11 +27,6 @@ help="HF model name", default="CompVis/stable-diffusion-v1-4", ) -parser.add_argument( - "--hf_auth_token", - type=str, - help="The Hugging face auth token, required for some models", -) parser.add_argument( "--device", type=str, @@ -48,32 +43,59 @@ parser.add_argument("--variant", type=str, default="decode") -def run_vae( - device, example_input, vmfb_path, hf_model_name, hf_auth_token, external_weight_path -): +def run_vae(device, example_input, vmfb_path, hf_model_name, external_weight_path): runner = vmfbRunner(device, vmfb_path, external_weight_path) inputs = [ireert.asdevicearray(runner.config.device, example_input)] - results = runner.ctx.modules.compiled_vae["main"](*inputs) + + results = runner.ctx.modules.compiled_vae["main"](*inputs).to_host() + return results -def run_torch_vae(hf_model_name, hf_auth_token, variant, example_input): +def run_torch_vae(hf_model_name, variant, example_input): from diffusers import AutoencoderKL class VaeModel(torch.nn.Module): - def __init__(self, hf_model_name, hf_auth_token): + def __init__( + self, + hf_model_name, + base_vae=False, + custom_vae="", + low_cpu_mem_usage=False, + hf_auth_token="", + ): super().__init__() - self.vae = AutoencoderKL.from_pretrained( - hf_model_name, - subfolder="vae", - token=hf_auth_token, - ) - - def decode_inp(self, inp): + self.vae = None + if custom_vae == "": + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + low_cpu_mem_usage=low_cpu_mem_usage, + hf_auth_token=hf_auth_token, + ) + elif not isinstance(custom_vae, dict): + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + subfolder="vae", + low_cpu_mem_usage=low_cpu_mem_usage, + hf_auth_token=hf_auth_token, + ) + else: + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + low_cpu_mem_usage=low_cpu_mem_usage, + hf_auth_token=hf_auth_token, + ) + self.vae.load_state_dict(custom_vae) + self.base_vae = base_vae + + def decode_inp(self, input): with torch.no_grad(): - x = self.vae.decode(inp, return_dict=False)[0] - return x + input = 1 / 0.18215 * input + x = self.vae.decode(input, return_dict=False)[0] + return (x / 2 + 0.5).clamp(0, 1) def encode_inp(self, inp): latents = self.vae.encode(inp).latent_dist.sample() @@ -81,7 +103,6 @@ def encode_inp(self, inp): vae_model = VaeModel( hf_model_name, - hf_auth_token, ) if variant == "decode": @@ -108,7 +129,6 @@ def encode_inp(self, inp): example_input, args.vmfb_path, args.hf_model_name, - args.hf_auth_token, args.external_weight_path, ) print( diff --git a/models/turbine_models/custom_models/sdxl_inference/COMMANDS.md b/models/turbine_models/custom_models/sdxl_inference/COMMANDS.md new file mode 100644 index 000000000..220916383 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/COMMANDS.md @@ -0,0 +1,126 @@ + +# SHARK-Turbine SDXL CLI usage (ROCM) + +## Pipeline (txt2img): + +Note: These commands are generally for unix, and use `$WEIGHTS_DIR`, `$PIPELINE_DIR`, and `$TARGET_TRIPLE` in place of actual values. You can set these env variables or replace them in the commands as desired. + +```shell +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py --precision=fp16 --external_weights=irpa --device=rocm --rt_device=rocm --iree_target_triple=$TARGET_TRIPLE --scheduler_id=PNDM --num_inference_steps=30 --pipeline_dir=$PIPELINE_DIR --external_weights_dir=$WEIGHTS_DIR --attn_spec=default --compiled_pipeline + +iree-benchmark-module \ + --module=$PWD/stable_diffusion_xl_base_1_0_64_fp16_prompt_encoder_rocm.vmfb \ + --parameters=model=$WEIGHTS_DIR/prompt_encoder.irpa \ + --module=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_scheduled_unet_rocm.vmfb \ + --parameters=model=$WEIGHTS_DIR/unet.irpa \ + --module=$PWD/stable_diffusion_xl_base_1_0_1024x1024_fp16_vae_decode_rocm.vmfb \ + --parameters=model=$WEIGHTS_DIR/vae_decode.irpa \ + --module=$PWD/sdxl_pipeline_fp16_$TARGET_TRIPLE.vmfb \ + --function=tokens_to_image \ + --input=1x4x128x128xf16 \ + --input=1xf16 \ + --input=1x64xi64 \ + --input=1x64xi64 \ + --input=1x64xi64 \ + --input=1x64xi64 \ + --device_allocator=caching \ + --benchmark_repetitions=1 \ + --device=rocm +``` +Note: you can either manually compile the pipeline vmfb from the .mlir in sdxl_inference, or by running the sdxl_scheduled_unet.py script. +The sdxl_compiled_pipeline script will do this for you, and you can switch between the segmented pipeline and the 'tokens->image' one-shot pipeline using `--compiled_pipeline` (if present, script will run the latter.) + +## Scheduled UNet + +``` +# Import to MLIR: + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=mlir --external_weight_path=$WEIGHTS_DIR/unet.safetensors + +# Compile to VMFB (MLIR not necessary here but this is faster if you are compiling more than once): + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=vmfb --iree_target_triple=$TARGET_TRIPLE --input_mlir=$PWD/stable_diffusion_xl_base_1_0_PNDM_64_1024x1024_fp16_unet_30.mlir + +# Test numerics (validate against pytorch cpu): + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py --compare_vs_torch --precision=fp16 --device=rocm --external_weight_path=$WEIGHTS_DIR/scheduled_unet.irpa --max_length=64 --pipeline_vmfb_path=./sdxl_pipeline_fp16_$TARGET_TRIPLE.vmfb --vmfb_path=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_scheduled_unet_rocm.vmfb + +# Benchmark with IREE CLI: + +iree-benchmark-module --benchmark_repetitions=5 --device=rocm --module=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_scheduled_unet_rocm.vmfb --parameters=model=$WEIGHTS_DIR/scheduled_unet.irpa --function=run_forward --input=1x4x128x128xf16 --input=2x64x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=1xf16 --input=1xi64 --device_allocator=caching + +iree-benchmark-module --benchmark_repetitions=5 --device=rocm --module=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_scheduled_unet_rocm.vmfb --module=$PWD/sdxl_pipeline_fp16_$TARGET_TRIPLE.vmfb --parameters=model=$WEIGHTS_DIR/scheduled_unet.irpa --function=run_forward --input=1x4x128x128xf16 --input=2x64x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=1xf16 --input=1xi64 --device_allocator=caching +``` + +## UNet + +``` +# Import to MLIR: + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/unet.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=mlir + +# Compile to VMFB (MLIR not necessary here but this is faster if you are compiling more than once): + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/unet.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=vmfb --iree_target_triple=$TARGET_TRIPLE --input_mlir=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet.mlir + +# Convert weights to IREE parameter archive format: + +iree-convert-parameters --parameters=$WEIGHTS_DIR/unet.safetensors --output=$WEIGHTS_DIR/scheduled_unet.irpa + +# Test numerics (validate against pytorch cpu): + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/unet_runner.py --compare_vs_torch --precision=fp16 --device=rocm --external_weight_path=$WEIGHTS_DIR/scheduled_unet.irpa --max_length=64 --vmfb_path=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet_rocm.vmfb + +# Benchmark with IREE CLI: + +iree-benchmark-module --benchmark_repetitions=5 --device=rocm --module=$PWD/stable_diffusion_xl_base_1_0_64_1024x1024_fp16_unet_rocm.vmfb --parameters=model=$WEIGHTS_DIR/scheduled_unet.irpa --function=main --input=1x4x128x128xf16 --input=1xi64 --input=2x64x2048xf16 --input=2x1280xf16 --input=2x6xf16 --input=1xf16 --device_allocator=caching +``` + +## CLIP + +``` +# Import to MLIR: + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py --precision=fp16 --external_weights=safetensors --device=rocm --rt_device=rocm --compile_to=mlir --iree_target_triple=$TARGET_TRIPLE --external_weight_path=$WEIGHTS_DIR/prompt_encoder.safetensors + +# Compile to VMFB (MLIR not necessary here but this is faster if you are compiling more than once): + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py --precision=fp16 --external_weights=safetensors --device=rocm --rt_device=rocm --compile_to=vmfb --iree_target_triple=$TARGET_TRIPLE --input_mlir=$PWD/stable_diffusion_xl_base_1_0_64_fp16_prompt_encoder.mlir + +# Convert weights to IREE parameter archive format: + +iree-convert-parameters --parameters=$WEIGHTS_DIR/prompt_encoder.safetensors --output=$WEIGHTS_DIR/prompt_encoder.irpa + +# Test numerics (validate against pytorch cpu): + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py --compare_vs_torch --precision=fp16 --device=rocm --external_weight_path=$WEIGHTS_DIR/prompt_encoder.irpa --max_length=64 --vmfb_path=$PWD/stable_diffusion_xl_base_1_0_64_fp16_prompt_encoder_rocm.vmfb + +# Benchmark with IREE CLI: + +iree-benchmark-module --benchmark_repetitions=5 --device=rocm --module=$PWD/stable_diffusion_xl_base_1_0_64_fp16_prompt_encoder_rocm.vmfb --parameters=model=$WEIGHTS_DIR/prompt_encoder.irpa --function=encode_prompts --input=1x64xi64 --input=1x64xi64 --input=1x64xi64 --input=1x64xi64 --device_allocator=caching +``` + + +## VAE + +``` +# Import to MLIR: + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/vae.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=mlir --iree_target_triple=$TARGET_TRIPLE --external_weight_path=$WEIGHTS_DIR/vae_decode.safetensors + +# Compile to VMFB (MLIR not necessary here but this is faster if you are compiling more than once): + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/vae.py --precision=fp16 --external_weights=safetensors --device=rocm --compile_to=vmfb --iree_target_triple=$TARGET_TRIPLE --input_mlir=$PWD/stable_diffusion_xl_base_1_0_1024x1024_fp16_vae_decode.mlir + +# Convert weights to IREE parameter archive format: + +iree-convert-parameters --parameters=$WEIGHTS_DIR/vae_decode.safetensors --output=$WEIGHTS_DIR/vae_decode.irpa + +# Test numerics (validate against pytorch cpu): + +python /home/eagarvey/sdxl/SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/vae_runner.py --precision=fp16 --external_weights=irpa --device=rocm --iree_target_triple=$TARGET_TRIPLE --vmfb_path=$PWD/stable_diffusion_xl_base_1_0_1024x1024_fp16_vae_decode_rocm.vmfb --external_weight_path=$WEIGHTS_DIR/vae_decode.irpa --compare_vs_torch + +# Benchmark with IREE CLI: + +iree-benchmark-module --benchmark_repetitions=5 --module=$PWD/stable_diffusion_xl_base_1_0_1024x1024_fp16_vae_decode_rocm.vmfb --parameters=model=$WEIGHTS_DIR/vae_decode.irpa --device=rocm --input=1x4x128x128xf16 --device-allocator=caching --function=main +``` diff --git a/models/turbine_models/custom_models/sdxl_inference/README.md b/models/turbine_models/custom_models/sdxl_inference/README.md new file mode 100644 index 000000000..19783c146 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/README.md @@ -0,0 +1,32 @@ +# Stable Diffusion Commands + +## Run and benchmark the entire SDXL pipeline on MI300 + - note: the command below is specifically for use on the ppac-pla-s22-35 instance. you may need to tweak paths accordingly. + - follow "setup repository" in the next section + - optional: set HF_HOME to save dl time/ disk usage +``` +export HF_HOME=/mnt/dcgpuval/huggingface/ #ppac +export HF_HOME=/data/huggingface-cache #banff +``` + - make sure you have ROCM working with IREE, check `iree-run-module --dump_devices` + - make a file called "mfma_spec.mlir" and drop in the contents of the TD script https://github.com/nod-ai/2024-q1-sdxl-sprint/tree/main/specs. + +### Newest pipeline command, weights (as of [SHARK-Turbine@ean-sd-fp16:6251fbef9233c406093dab056a08cd42cfc54a0b](https://github.com/nod-ai/SHARK-Turbine/commit/6251fbef9233c406093dab056a08cd42cfc54a0b)): + + +gfx940: +``` +python SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py --precision=fp16 --external_weights=safetensors --device=rocm --rt_device=rocm --iree_target_triple=gfx942 --scheduler_id=PNDM --num_inference_steps=30 --pipeline_dir=./sdxl_fp16_1024x1024_gfx940/ --external_weights_dir=./weights_fp16/ --attn_spec=default +``` + +gfx942: +``` +python SHARK-Turbine/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py --precision=fp16 --external_weights=safetensors --device=rocm --rt_device=rocm --iree_target_triple=gfx940 --scheduler_id=PNDM --num_inference_steps=30 --pipeline_dir=./sdxl_fp16_1024x1024_gfx940/ --external_weights_dir=./weights_fp16/ --attn_spec=default +``` + +Note: the following "prompt_encoder_f16.irpa" contains weights for both clip1 and clip2. +The pipeline script will look for these filenames in the specified "external_weights_dir" under "prompt_encoder.irpa", "vae_decode.irpa", "scheduled_unet.irpa". +It's not ideal in current state, but will be smoothed out now that general pipeline structure and file management needs are stable. + - [prompt_encoder_f16.irpa](https://sharkpublic.blob.core.windows.net/sharkpublic/SDXL/SDXL_weights_fp16/prompt_encoder_fp16.irpa) + - [scheduled_unet_f16.irpa](https://sharkpublic.blob.core.windows.net/sharkpublic/SDXL/SDXL_weights_fp16/scheduled_unet_f16.irpa) + - [vae_decode_f16.irpa](https://sharkpublic.blob.core.windows.net/sharkpublic/SDXL/SDXL_weights_fp16/vae_encode_fp16.irpa) diff --git a/models/turbine_models/custom_models/sdxl_inference/clip.py b/models/turbine_models/custom_models/sdxl_inference/clip.py new file mode 100644 index 000000000..20b0aa7ae --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/clip.py @@ -0,0 +1,207 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +from iree import runtime as ireert +import iree.compiler as ireec +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + + +class ClipModel(torch.nn.Module): + def __init__(self, hf_model_name, hf_auth_token=None, index=1): + super().__init__() + if index == 1: + self.text_encoder_model = CLIPTextModel.from_pretrained( + hf_model_name, + subfolder="text_encoder", + token=hf_auth_token, + ) + if index == 2: + self.text_encoder_model = CLIPTextModelWithProjection.from_pretrained( + hf_model_name, + subfolder="text_encoder_2", + token=hf_auth_token, + ) + + def forward(self, input): + with torch.no_grad(): + prompt_embeds = self.text_encoder_model( + input, + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + return prompt_embeds, pooled_prompt_embeds + + +def export_clip_model( + hf_model_name, + hf_auth_token=None, + max_length=77, + precision="fp16", + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, + index=1, + exit_on_vmfb=True, + pipeline_dir=None, + input_mlir=None, + attn_spec=None, + weights_only=False, +): + if pipeline_dir not in [None, ""]: + safe_name = os.path.join(pipeline_dir, "clip_" + str(index)) + else: + safe_name = utils.create_safe_name( + hf_model_name, f"-{str(max_length)}-{precision}-clip-{index}-{device}" + ) + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return vmfb_path + # Load the tokenizer and text encoder to tokenize and encode the text. + if index == 1: + tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + model_max_length=max_length, + ) + elif index == 2: + tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer_2", + token=hf_auth_token, + model_max_length=max_length, + ) + text_encoder_model = ClipModel(hf_model_name, hf_auth_token, index=index) + if compile_to == "tokenizer_only": + return None, tokenizer + if precision == "fp16": + text_encoder_model = text_encoder_model.half() + mapper = {} + if external_weight_path: + weights_path = ( + external_weight_path.split(f".{external_weights}")[0] + + f"_{index}" + + f".{external_weights}" + ) + else: + weights_path = None + + utils.save_external_weights( + mapper, text_encoder_model, external_weights, weights_path + ) + + if weights_only: + return weights_path + + class CompiledClip(CompiledModule): + if external_weights: + params = export_parameters( + text_encoder_model, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(text_encoder_model) + + def main(self, inp=AbstractTensor(1, max_length, dtype=torch.int64)): + return jittable(text_encoder_model.forward)(inp) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledClip(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str, tokenizer + else: + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return None, vmfb_path + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + mod_1_str, _ = export_clip_model( + args.hf_model_name, + args.hf_auth_token, + args.max_length, + args.precision, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags + args.clip_flags, + 1, + exit_on_vmfb=False, + pipeline_dir=args.pipeline_dir, + input_mlir=args.input_mlir, + attn_spec=args.attn_spec, + ) + mod_2_str, _ = export_clip_model( + args.hf_model_name, + args.hf_auth_token, + args.max_length, + args.precision, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags + args.clip_flags, + 2, + exit_on_vmfb=True, + pipeline_dir=args.pipeline_dir, + input_mlir=args.input_mlir, + attn_spec=args.attn_spec, + ) + if args.input_mlir: + exit() + safe_name_1 = safe_name = utils.create_safe_name( + args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_clip_1" + ) + safe_name_2 = safe_name = utils.create_safe_name( + args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_clip_2" + ) + with open(f"{safe_name_1}.mlir", "w+") as f: + f.write(mod_1_str) + print("Saved to", safe_name_1 + ".mlir") + with open(f"{safe_name_2}.mlir", "w+") as f: + f.write(mod_2_str) + print("Saved to", safe_name_2 + ".mlir") diff --git a/models/turbine_models/custom_models/sdxl_inference/clip_runner.py b/models/turbine_models/custom_models/sdxl_inference/clip_runner.py new file mode 100644 index 000000000..d9a905d2f --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/clip_runner.py @@ -0,0 +1,274 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from transformers import CLIPTokenizer +from iree import runtime as ireert +import torch +import numpy as np + + +def run_encode_prompts( + device, + prompt, + negative_prompt, + vmfb_path_1, + vmfb_path_2, + hf_model_name, + hf_auth_token, + external_weight_path_1, + external_weight_path_2, + max_length, +): + runner_1 = vmfbRunner(device, vmfb_path_1, external_weight_path_1) + runner_2 = vmfbRunner(device, vmfb_path_2, external_weight_path_2) + text_encoders = [runner_1, runner_2] + + tokenizer_1 = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer_2", + token=hf_auth_token, + ) + tokenizers = [tokenizer_1, tokenizer_2] + prompt_embeds_list = [] + prompts = [prompt, prompt] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, max_length - 1 : -1] + ) + print( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + text_input_ids = [ + ireert.asdevicearray(text_encoder.config.device, text_input_ids) + ] + text_encoder_output = text_encoder.ctx.modules.compiled_clip["main"]( + *text_input_ids + ) + prompt_embeds = torch.from_numpy(text_encoder_output[0].to_host()) + pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1].to_host()) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + uncond_tokens = [negative_prompt, negative_prompt] + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip( + uncond_tokens, tokenizers, text_encoders + ): + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids + uncond_input_ids = [ + ireert.asdevicearray(text_encoder.config.device, uncond_input_ids) + ] + + text_encoder_output = text_encoder.ctx.modules.compiled_clip["main"]( + *uncond_input_ids + ) + negative_prompt_embeds = torch.from_numpy(text_encoder_output[0].to_host()) + negative_pooled_prompt_embeds = torch.from_numpy( + text_encoder_output[1].to_host() + ) + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + do_classifier_free_guidance = True + + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, 1).view( + 1, -1 + ) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, 1, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * 1, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1) + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + +def run_torch_clip(hf_model_name, hf_auth_token, prompt, max_length=64): + # TODO: Integrate with HFTransformerBuilder + from turbine_models.custom_models.sdxl_inference.clip import ClipModel + + model_1 = ClipModel(hf_model_name, hf_auth_token, index=1) + model_2 = ClipModel(hf_model_name, hf_auth_token, index=2) + tokenizer_1 = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer_2", + token=hf_auth_token, + ) + text_input_1 = tokenizer_1( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_2 = tokenizer_2( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + example_input_1 = text_input_1.input_ids + example_input_2 = text_input_2.input_ids + + results_1 = model_1.forward(example_input_1) + results_2 = model_2.forward(example_input_2) + np_torch_output_1 = results_1[0].detach().cpu().numpy().astype(np.float16) + np_torch_output_2 = results_2[0].detach().cpu().numpy().astype(np.float16) + return np_torch_output_1, np_torch_output_2 + + +def run_clip( + device, + prompt, + vmfb_path, + hf_model_name, + hf_auth_token, + external_weight_path, + max_length, + index, +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + + if index == 1: + tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + ) + elif index == 2: + tokenizer = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer_2", + token=hf_auth_token, + ) + else: + print("Incorrect CLIP model index, please use 1 or 2") + exit(1) + + text_input = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + example_input = text_input.input_ids + inp = [ireert.asdevicearray(runner.config.device, example_input)] + results = runner.ctx.modules.compiled_clip["main"](*inp) + + return results + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + vmfb_path_1 = "_clip_1".join(args.vmfb_path.split("_clip")) + vmfb_path_2 = "_clip_2".join(args.vmfb_path.split("_clip")) + external_weight_path_1 = "_clip_1".join(args.external_weight_path.split("_clip")) + external_weight_path_2 = "_clip_2".join(args.external_weight_path.split("_clip")) + turbine_output1 = run_clip( + args.device, + args.prompt, + vmfb_path_1, + args.hf_model_name, + args.hf_auth_token, + external_weight_path_1, + args.max_length, + index=1, + ) + print( + "TURBINE OUTPUT 1:", + turbine_output1[0].to_host(), + turbine_output1[0].to_host().shape, + turbine_output1[0].to_host().dtype, + ) + + turbine_output2 = run_clip( + args.device, + args.prompt, + vmfb_path_2, + args.hf_model_name, + args.hf_auth_token, + external_weight_path_2, + args.max_length, + index=2, + ) + print( + "TURBINE OUTPUT 2:", + turbine_output2[0].to_host(), + turbine_output2[0].to_host().shape, + turbine_output2[0].to_host().dtype, + ) + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + + torch_output1, torch_output2 = run_torch_clip( + args.hf_model_name, + args.hf_auth_token, + args.prompt, + args.max_length, + ) + print( + "TORCH OUTPUT 1:", torch_output1, torch_output1.shape, torch_output1.dtype + ) + + print( + "TORCH OUTPUT 2:", torch_output2, torch_output2.shape, torch_output2.dtype + ) + rtol = 4e-1 + atol = 4e-2 + np.testing.assert_allclose( + torch_output1, turbine_output1[0], rtol, atol, verbose=True + ) + np.testing.assert_allclose( + torch_output2, turbine_output2[0], rtol, atol, verbose=True + ) + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output1, turbine_output2 = (None, None) diff --git a/models/turbine_models/custom_models/sdxl_inference/default_mfma_attn_spec.mlir b/models/turbine_models/custom_models/sdxl_inference/default_mfma_attn_spec.mlir new file mode 100644 index 000000000..794c83d99 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/default_mfma_attn_spec.mlir @@ -0,0 +1,1040 @@ +// Transform dialect specification for attention on MI300 with MFMA. +// This script only supports variants of attention with a sequence +// length that is a multiple of 64. There are two near duplicate +// because we need different tile sizes when the head dimension is 512. +// TODO: Figure out how to parameterize the tile sizes without duplicating +// the attention function. + +#layout_16 = #iree_gpu.mfma_layout +#layout = #iree_gpu.mfma_layout + +module attributes { transform.with_named_sequence } { +//===----------------------------------------------------------------------===// +// Attention +//===----------------------------------------------------------------------===// + + // Utility matching for finding all undistributed fills. + transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + transform.match.operation_name %arg0 ["linalg.fill"] : !transform.any_op + %0 = transform.get_parent_op %arg0 {allow_empty_results, nth_parent = 2 : i64, op_name = "scf.forall"} : (!transform.any_op) -> !transform.any_op + transform.match.operation_empty %0 : !transform.any_op + transform.yield %arg0 : !transform.any_op + } + + transform.named_sequence @get_undistributed_fills(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op { + %0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op + transform.yield %0 : !transform.any_op + } + + // Script for FA2 transform pipeline when head_dim % 64 = 0. + transform.named_sequence @__attention_main(%variant_op: !transform.any_op {transform.consumed}) { + // Get attention op + // ========================================== + %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op + + // Tile and distribute to workgroups + // ========================================== + %tiled_attention, %forall_grid = + transform.structured.tile_using_forall %attention tile_sizes [1, 128] + ( mapping = [#gpu.block, #gpu.block] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> () + + // Tile batch dimensions of attention + // ========================================== + %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %top_level_func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %top_level_func : !transform.any_op + + // Promote query and output operands + // ========================================== + //%attention3 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op + //%promoted_attention, %alloc_a0, %alloc_a1 = transform.iree.promote_operands %attention3 [0, 3] + // : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + // Tile and decompose attention + // ========================================== + %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.iree.tile_attention %attention4 {tile_size = 32} : + (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul + = transform.iree.decompose_tiled_attention %blocked_attention {tile_size = 32} : + (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + + // Promote key and value operands + // ========================================== + %promoted_first_matmul, %alloc0 = transform.iree.promote_operands %first_matmul [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %promoted_second_matmul, %alloc1 = transform.iree.promote_operands %second_matmul [1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile and fuse attention ops + // ========================================== + %tiled_matmul, %forall = transform.structured.tile_using_forall %promoted_second_matmul tile_sizes [32] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %tiled_reduce_sum, %forall_reduce = transform.structured.tile_using_forall %reduce_sum tile_sizes [32] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + + %f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.apply_cse to %func : !transform.any_op + + %loop4 = transform.loop.fuse_sibling %forall_reduce into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_cse to %func : !transform.any_op + + %f5_1, %loop5_1 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.apply_cse to %func : !transform.any_op + + %f5, %loop5 = transform.structured.fuse_into_containing_op %scale_factor into %loop5_1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f6, %loop6 = transform.structured.fuse_into_containing_op %partial_softmax into %loop5 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.apply_cse to %func : !transform.any_op + + %f7, %loop7 = transform.structured.fuse_into_containing_op %reduce_max into %loop6 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f8, %loop8 = transform.structured.fuse_into_containing_op %promoted_first_matmul into %loop7 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func : !transform.any_op + + %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func : !transform.any_op + + %f10, %loop10 = transform.structured.fuse_into_containing_op %scale_q into %loop9 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func : !transform.any_op + + // Distribute fills + // ========================================== + + // Get all fills that haven't been distributed to warps. + %fills = transform.include @get_undistributed_fills failures(propagate) (%variant_op) : (!transform.any_op) -> !transform.any_op + %tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[32] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Distribute last_truncate and fuse final_scaling into it + // ========================================== + %tiled_truncate, %loop_truncate = transform.structured.tile_using_forall %last_truncate tile_sizes[32] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.structured.fuse_into_containing_op %final_scaling into %loop_truncate : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func : !transform.any_op + + // Vectorize function + // ========================================== + transform.apply_patterns to %func { + transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface + transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices + transform.apply_patterns.vector.cast_away_vector_leading_one_dim + } : !transform.any_op + %func_3 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> (!transform.any_op) + + // Bufferization + // ========================================== + transform.apply_patterns to %func_3 { + transform.apply_patterns.tensor.reassociative_reshape_folding + transform.apply_patterns.canonicalization + transform.apply_patterns.iree.fold_fill_into_pad + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + } : !transform.any_op + transform.apply_cse to %func_3 : !transform.any_op + transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> () + transform.apply_patterns to %func_3 { transform.apply_patterns.linalg.erase_unnecessary_inputs } : !transform.any_op + %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op) + + // Step 5. Pre-process the contract and transfer ops to put it in the right form. + // =========================================================================== + %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_2 { + transform.apply_patterns.iree.fold_arith_ext_into_contraction + } : !transform.any_op + + // Step 6. Post-bufferization vector distribution + // =========================================================================== + %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> () + transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [64, 4, 1] subgroup_size = 64 : (!transform.any_op) -> () + + transform.apply_patterns to %func_7 { + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + transform.iree.apply_licm %func_7 : !transform.any_op + transform.apply_patterns to %func_7 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_7 : !transform.any_op + %func_8 = transform.structured.hoist_redundant_vector_transfers %func_7 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_8 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_8 : !transform.any_op + transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> () + + // Apply chained matmul optimization. + transform.apply_registered_pass "iree-amdgpu-prepare-chained-matmul" to %func_8 : (!transform.any_op) -> (!transform.any_op) + + // Get the vector.contract ops. + %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %contract1, %contract2 = transform.split_handle %contracts : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + %layout16x16x16 = transform.param.constant #layout -> !transform.any_param + transform.iree.set_contraction_layout_attributes %contract1, %layout16x16x16 { read_layout_indices = array } : !transform.any_op, !transform.any_param + transform.iree.set_contraction_layout_attributes %contract2, %layout16x16x16 : !transform.any_op, !transform.any_param + + %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.iree.amdgpu_distribute_vectors %distribute_func : !transform.any_op + + %distribute_func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %distribute_func_2 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %distribute_func_2 : !transform.any_op + + // Distribute shared memory copies + // ========================================== + %func_10 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.iree.gpu_distribute_shared_memory_copy %func_10 : (!transform.any_op) -> () + transform.apply_patterns to %func_10 { + transform.apply_patterns.memref.fold_memref_alias_ops + transform.apply_patterns.canonicalization + transform.apply_patterns.linalg.tiling_canonicalization + } : !transform.any_op + transform.apply_cse to %func_10 : !transform.any_op + + %forop = transform.structured.match ops{["scf.for"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %prefetched_forop = transform.iree.prefetch_shared_memory_copies %forop : (!transform.any_op) -> (!transform.any_op) + + transform.apply_patterns to %func_10 { + transform.apply_patterns.memref.fold_memref_alias_ops + transform.apply_patterns.canonicalization + transform.apply_patterns.linalg.tiling_canonicalization + } : !transform.any_op + transform.apply_cse to %func_10 : !transform.any_op + + %func_11 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.amdgpu.optimize_shared_memory_reads_and_writes %func_11 : (!transform.any_op) -> () + + transform.yield + } + + // Script for FA2 transform pipeline for head_dim = 512. + // For head_dim = 512, since the matmul is so big, and just try to do a single wave big load + big mfma. + transform.named_sequence @__attention_main_len_512(%variant_op: !transform.any_op {transform.consumed}) { + // Get attention op + // ========================================== + %attention = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op + + // Tile and distribute to workgroups + // ========================================== + %tiled_attention, %forall_grid = + transform.structured.tile_using_forall %attention tile_sizes [1, 64] + ( mapping = [#gpu.block, #gpu.block] ) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.iree.populate_workgroup_count_region_using_num_threads_slice %forall_grid : (!transform.any_op) -> () + + // Tile batch dimensions of attention + // ========================================== + %attention2 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %batch_tiled_attn, %loop = transform.structured.tile_using_for %attention2 [1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %top_level_func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %top_level_func : !transform.any_op + + // Promote query and output operands + // ========================================== + //%attention3 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op + //%promoted_attention, %alloc_a0, %alloc_a1 = transform.iree.promote_operands %attention3 [0, 3] + // : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + // Tile and decompose attention + // ========================================== + %attention4 = transform.structured.match ops{["iree_linalg_ext.attention"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %acc_fill, %max_fill, %sum_fill, %inner_loop, %final_scaling, %last_truncate, %blocked_attention = transform.iree.tile_attention %attention4 {tile_size = 64} : + (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + %scale_q, %fill_op, %first_matmul, %reduce_max, %partial_softmax, %scale_factor, %update, %reduce_sum, %truncate, %scale_acc, %second_matmul + = transform.iree.decompose_tiled_attention %blocked_attention {tile_size = 64} : + (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + + // Promote key and value operands + // ========================================== + // %promoted_first_matmul, %alloc0 = transform.iree.promote_operands %first_matmul [1] + // : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + // %promoted_second_matmul, %alloc1 = transform.iree.promote_operands %second_matmul [1] + // : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Tile and fuse attention ops + // ========================================== + %tiled_matmul, %forall = transform.structured.tile_using_forall %second_matmul tile_sizes [16] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %tiled_reduce_sum, %forall_reduce = transform.structured.tile_using_forall %reduce_sum tile_sizes [16] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + + %f0, %loop0 = transform.structured.fuse_into_containing_op %scale_acc into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f1, %loop1 = transform.structured.fuse_into_containing_op %truncate into %loop0 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + %func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.apply_cse to %func : !transform.any_op + + %loop4 = transform.loop.fuse_sibling %forall_reduce into %loop1 : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.apply_cse to %func : !transform.any_op + + %f5_1, %loop5_1 = transform.structured.fuse_into_containing_op %update into %loop4 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.apply_cse to %func : !transform.any_op + + %f5, %loop5 = transform.structured.fuse_into_containing_op %scale_factor into %loop5_1 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f6, %loop6 = transform.structured.fuse_into_containing_op %partial_softmax into %loop5 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.apply_cse to %func : !transform.any_op + + %f7, %loop7 = transform.structured.fuse_into_containing_op %reduce_max into %loop6 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %f8, %loop8 = transform.structured.fuse_into_containing_op %first_matmul into %loop7 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func : !transform.any_op + + %f9, %loop9 = transform.structured.fuse_into_containing_op %fill_op into %loop8 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func : !transform.any_op + + %f10, %loop10 = transform.structured.fuse_into_containing_op %scale_q into %loop9 : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func : !transform.any_op + + // Distribute fills + // ========================================== + + // Get all fills that haven't been distributed to warps. + %fills = transform.include @get_undistributed_fills failures(propagate) (%variant_op) : (!transform.any_op) -> !transform.any_op + %tiled_fill, %fill_grid = transform.structured.tile_using_forall %fills tile_sizes[16] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Distribute last_truncate and fuse final_scaling into it + // ========================================== + %tiled_truncate, %loop_truncate = transform.structured.tile_using_forall %last_truncate tile_sizes[16] (mapping = [#gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.structured.fuse_into_containing_op %final_scaling into %loop_truncate : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + + transform.apply_patterns to %func { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func : !transform.any_op + + // Vectorize function + // ========================================== + transform.apply_patterns to %func { + transform.apply_patterns.iree.fold_reshape_into_tensor_hal_interface + transform.apply_patterns.linalg.fold_unit_extent_dims_via_slices + transform.apply_patterns.vector.cast_away_vector_leading_one_dim + } : !transform.any_op + %func_3 = transform.structured.vectorize_children_and_apply_patterns %func : (!transform.any_op) -> (!transform.any_op) + + // Bufferization + // ========================================== + transform.apply_patterns to %func_3 { + transform.apply_patterns.tensor.reassociative_reshape_folding + transform.apply_patterns.canonicalization + transform.apply_patterns.iree.fold_fill_into_pad + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + } : !transform.any_op + transform.apply_cse to %func_3 : !transform.any_op + transform.iree.eliminate_empty_tensors %variant_op : (!transform.any_op) -> () + transform.apply_patterns to %func_3 { transform.apply_patterns.linalg.erase_unnecessary_inputs } : !transform.any_op + %variant_op_3 = transform.iree.bufferize { target_gpu } %variant_op : (!transform.any_op) -> (!transform.any_op) + + // Step 5. Pre-process the contract and transfer ops to put it in the right form. + // =========================================================================== + %func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_2 { + transform.apply_patterns.iree.fold_arith_ext_into_contraction + } : !transform.any_op + + // Step 6. Post-bufferization vector distribution + // =========================================================================== + %func_7 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.iree.forall_to_workgroup %func_7 : (!transform.any_op) -> () + transform.iree.map_nested_forall_to_gpu_threads %func_7 workgroup_dims = [64, 4, 1] subgroup_size = 64 : (!transform.any_op) -> () + + transform.apply_patterns to %func_7 { + transform.apply_patterns.memref.fold_memref_alias_ops + } : !transform.any_op + transform.iree.apply_licm %func_7 : !transform.any_op + transform.apply_patterns to %func_7 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_7 : !transform.any_op + %func_8 = transform.structured.hoist_redundant_vector_transfers %func_7 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_8 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_8 : !transform.any_op + transform.memref.erase_dead_alloc_and_stores %func_8 : (!transform.any_op) -> () + + // Apply chained matmul optimization. + transform.apply_registered_pass "iree-amdgpu-prepare-chained-matmul" to %func_8 : (!transform.any_op) -> (!transform.any_op) + + // transform.print %variant_op_3 : !transform.any_op + + // Get the vector.contract ops. + %contracts = transform.structured.match ops{["vector.contract"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %contract1, %contract2 = transform.split_handle %contracts : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + %layout16x16x16 = transform.param.constant #layout_16 -> !transform.any_param + transform.iree.set_contraction_layout_attributes %contract1, %layout16x16x16 { read_layout_indices = array } : !transform.any_op, !transform.any_param + transform.iree.set_contraction_layout_attributes %contract2, %layout16x16x16 : !transform.any_op, !transform.any_param + + %distribute_func = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.iree.amdgpu_distribute_vectors %distribute_func : !transform.any_op + + %distribute_func_2 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %distribute_func_2 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %distribute_func_2 : !transform.any_op + + // Distribute shared memory copies + // ========================================== + %func_10 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.iree.gpu_distribute_shared_memory_copy %func_10 : (!transform.any_op) -> () + transform.apply_patterns to %func_10 { + transform.apply_patterns.memref.fold_memref_alias_ops + transform.apply_patterns.canonicalization + transform.apply_patterns.linalg.tiling_canonicalization + } : !transform.any_op + transform.apply_cse to %func_10 : !transform.any_op + + %forop = transform.structured.match ops{["scf.for"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + %prefetched_forop = transform.iree.prefetch_shared_memory_copies %forop : (!transform.any_op) -> (!transform.any_op) + + transform.apply_patterns to %func_10 { + transform.apply_patterns.memref.fold_memref_alias_ops + transform.apply_patterns.canonicalization + transform.apply_patterns.linalg.tiling_canonicalization + } : !transform.any_op + transform.apply_cse to %func_10 : !transform.any_op + + %func_11 = transform.structured.match ops{["func.func"]} in %variant_op_3 : (!transform.any_op) -> !transform.any_op + transform.amdgpu.optimize_shared_memory_reads_and_writes %func_11 : (!transform.any_op) -> () + + transform.yield + } + + // Send it down a custom transform dialect pipeline. + transform.named_sequence @custom_attention_len_512(%attention: !transform.any_op {transform.readonly}) { + %variant_op = transform.get_parent_op %attention {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op + %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %attn = transform.param.constant #iree_codegen.translation_info -> !transform.any_param + transform.annotate %exports "translation_info" = %attn : !transform.any_op, !transform.any_param + transform.yield + } + + transform.named_sequence @match_attention_len_512(%attention: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.match.operation_name %attention ["iree_linalg_ext.attention"] : !transform.any_op + %in0 = transform.get_operand %attention[0] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %in0 = tensor : !transform.any_value + transform.yield %attention : !transform.any_op + } + + // Send it down a custom transform dialect pipeline. + transform.named_sequence @custom_attention(%attention: !transform.any_op {transform.readonly}) { + %variant_op = transform.get_parent_op %attention {op_name = "hal.executable.variant"} : (!transform.any_op) -> !transform.any_op + %exports = transform.structured.match ops{["hal.executable.export"]} in %variant_op : (!transform.any_op) -> !transform.any_op + %attn = transform.param.constant #iree_codegen.translation_info -> !transform.any_param + transform.annotate %exports "translation_info" = %attn : !transform.any_op, !transform.any_param + transform.yield + } + + transform.named_sequence @match_attention(%attention: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.match.operation_name %attention ["iree_linalg_ext.attention"] : !transform.any_op + %in0 = transform.get_operand %attention[0] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %in0 = tensor : !transform.any_value + transform.iree.match.dim_is_multiple_of %in0[2], 64 : !transform.any_value + transform.yield %attention : !transform.any_op + } + +//===----------------------------------------------------------------------===// +// Matmul tuning +//===----------------------------------------------------------------------===// + + transform.named_sequence @match_mmt_f16_f16_f32(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.match.operation_name %root ["linalg.generic"] : !transform.any_op + // transform.print %root {name = "Generic"} : !transform.any_op + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root { + ^bb0(%lhs: tensor, %rhs: tensor, %out: tensor): + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor, tensor) outs(%out : tensor) { + ^bb0(%in: f16, %in_0: f16, %acc: f32): + %8 = arith.extf %in : f16 to f32 + %9 = arith.extf %in_0 : f16 to f32 + %10 = arith.mulf %8, %9 : f32 + %11 = arith.addf %acc, %10 : f32 + linalg.yield %11 : f32 + } -> tensor + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + transform.yield %root : !transform.any_op + } + + transform.named_sequence @match_mmt_f16_f16_f16(%root: !transform.any_op {transform.readonly}) -> (!transform.any_op) { + transform.match.operation_name %root ["linalg.generic"] : !transform.any_op + // transform.print %root {name = "Generic"} : !transform.any_op + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %root { + ^bb0(%lhs: tensor, %rhs: tensor, %out: tensor): + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction"]} + ins(%lhs, %rhs : tensor, tensor) outs(%out : tensor) { + ^bb0(%in: f16, %in_0: f16, %acc: f16): + %10 = arith.mulf %in, %in_0 : f16 + %11 = arith.addf %acc, %10 : f16 + linalg.yield %11 : f16 + } -> tensor + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + transform.yield %root : !transform.any_op + } + + transform.named_sequence @apply_op_config(%op: !transform.any_op {transform.readonly}, %config: !transform.any_param {transform.readonly}) { + transform.annotate %op "compilation_info" = %config : !transform.any_op, !transform.any_param + // transform.print %op {name = "Applied"} : !transform.any_op + transform.yield + } + + transform.named_sequence @match_mmt_2048x10240x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<10240x1280xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 2, + subgroup_m_tile_count = 4, + subgroup_n_tile_count = 4, + subgroup_k_tile_count = 2>, no_reorder_workgroups}>, + workgroup_size = [128, 2, 1], subgroup_size = 64 + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_mmt_2048x1280x1280(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<2048x1280xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<1280x1280xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 1, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 5, + subgroup_k_tile_count = 4>}>, + workgroup_size = [64, 2, 1], subgroup_size = 64 + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_mmt_2048x1280x5120(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<2048x5120xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<1280x5120xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 2, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 5, + subgroup_k_tile_count = 4>}>, + workgroup_size = [128, 2, 1], subgroup_size = 64 + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_mmt_128x1280x2048(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f16 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<128x2048xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<1280x2048xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 2, + subgroup_m_tile_count = 1, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 16>}>, + workgroup_size = [128, 2, 1], subgroup_size = 64 + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_mmt_8192x640x2560(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<8192x2560xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<640x2560xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 2, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 5, + subgroup_k_tile_count = 4>}>, + workgroup_size = [128, 2, 1], subgroup_size = 64 + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_mmt_8192x5120x640(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<8192x640xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<5120x640xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 2, + subgroup_m_tile_count = 4, + subgroup_n_tile_count = 4, + subgroup_k_tile_count = 2>}>, + workgroup_size = [128, 2, 1], subgroup_size = 64 + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_mmt_128x640x2048(%matmul: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) { + %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op + %lhs = transform.get_operand %matmul[0] : (!transform.any_op) -> !transform.any_value + %rhs = transform.get_operand %matmul[1] : (!transform.any_op) -> !transform.any_value + transform.iree.match.cast_compatible_type %lhs = tensor<128x2048xf16> : !transform.any_value + transform.iree.match.cast_compatible_type %rhs = tensor<640x2048xf16> : !transform.any_value + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 1, + subgroup_m_tile_count = 1, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 32>}>, + workgroup_size = [64, 2, 1], subgroup_size = 64 + > -> !transform.any_param + transform.yield %matmul, %config : !transform.any_op, !transform.any_param + } + +//===----------------------------------------------------------------------===// +// Convolution tuning +//===----------------------------------------------------------------------===// + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x1280xf16>, %rhs: tensor<3x3x1280x1280xf16>, %out: tensor<2x32x32x1280xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x1280xf16>, tensor<3x3x1280x1280xf16>) + outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 5, + subgroup_m_tile_count = 1, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 8>}>, + workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1920(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x1920xf16>, %rhs: tensor<3x3x1920x1280xf16>, %out: tensor<2x32x32x1280xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x1920xf16>, tensor<3x3x1920x1280xf16>) + outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 5, + subgroup_m_tile_count = 1, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 8>}>, + workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x2560(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x2560xf16>, %rhs: tensor<3x3x2560x1280xf16>, %out: tensor<2x32x32x1280xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x2560xf16>, tensor<3x3x2560x1280xf16>) + outs(%out : tensor<2x32x32x1280xf32>) -> tensor<2x32x32x1280xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 5, + subgroup_m_tile_count = 1, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 8>}>, + workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x640(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x640xf16>, %rhs: tensor<3x3x640x640xf16>, %out: tensor<2x64x64x640xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x640xf16>, tensor<3x3x640x640xf16>) + outs(%out : tensor<2x64x64x640xf32>) -> tensor<2x64x64x640xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 5, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 4>}>, + workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x1280(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x1280xf16>, %rhs: tensor<3x3x1280x640xf16>, %out: tensor<2x64x64x640xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x1280xf16>, tensor<3x3x1280x640xf16>) + outs(%out : tensor<2x64x64x640xf32>) -> tensor<2x64x64x640xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 1, subgroup_n_count = 10, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 10>}>, + workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x1920(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x1920xf16>, %rhs: tensor<3x3x1920x640xf16>, %out: tensor<2x64x64x640xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x1920xf16>, tensor<3x3x1920x640xf16>) + outs(%out : tensor<2x64x64x640xf32>) -> tensor<2x64x64x640xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 1, subgroup_n_count = 10, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 10>}>, + workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x64x64x1280x3x3x1280(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + transform.match.operation_name %conv ["linalg.conv_2d_nhwc_hwcf"] : !transform.any_op + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x66x66x1280xf16>, %rhs: tensor<3x3x1280x1280xf16>, %out: tensor<2x64x64x1280xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x66x66x1280xf16>, tensor<3x3x1280x1280xf16>) + outs(%out : tensor<2x64x64x1280xf32>) -> tensor<2x64x64x1280xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 1, subgroup_n_count = 10, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 10>}>, + workgroup_size = [640, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x320(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x320xf16>, %rhs: tensor<3x3x320x320xf16>, %out: tensor<2x128x128x320xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x320xf16>, tensor<3x3x320x320xf16>) + outs(%out : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 1, subgroup_n_count = 5, + subgroup_m_tile_count = 4, + subgroup_n_tile_count = 2, + subgroup_k_tile_count = 5>}>, + workgroup_size = [320, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x640(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x640xf16>, %rhs: tensor<3x3x640x320xf16>, %out: tensor<2x128x128x320xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x640xf16>, tensor<3x3x640x320xf16>) + outs(%out : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 1, subgroup_n_count = 5, + subgroup_m_tile_count = 4, + subgroup_n_tile_count = 2, + subgroup_k_tile_count = 5>}>, + workgroup_size = [320, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x960(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x960xf16>, %rhs: tensor<3x3x960x320xf16>, %out: tensor<2x128x128x320xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x960xf16>, tensor<3x3x960x320xf16>) + outs(%out : tensor<2x128x128x320xf32>) -> tensor<2x128x128x320xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 1, subgroup_n_count = 5, + subgroup_m_tile_count = 4, + subgroup_n_tile_count = 2, + subgroup_k_tile_count = 5>}>, + workgroup_size = [320, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_conv_2d_nhwc_hwcf_2x128x128x640x3x3x640(%conv: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %conv { + ^bb0(%lhs: tensor<2x?x?x640xf16>, %rhs: tensor<3x3x640x640xf16>, %out: tensor<2x128x128x640xf32>): + %13 = linalg.conv_2d_nhwc_hwcf { dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64> } + ins(%lhs, %rhs : tensor<2x?x?x640xf16>, tensor<3x3x640x640xf16>) + outs(%out : tensor<2x128x128x640xf32>) -> tensor<2x128x128x640xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 1, subgroup_n_count = 4, + subgroup_m_tile_count = 4, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 4>}>, + workgroup_size = [256, 1, 1], subgroup_size = 64> -> !transform.any_param + transform.yield %conv, %config : !transform.any_op, !transform.any_param + } + +//===----------------------------------------------------------------------===// +// Contraction tuning +//===----------------------------------------------------------------------===// + + transform.named_sequence @match_contract_2x1024x1280x20x64(%contract: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract { + ^bb0(%lhs: tensor<2x20x1024x64xf16>, %rhs: tensor<1280x20x64xf16>, %out: tensor<2x1024x1280xf32>): + %20 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d1, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, + affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"] + } ins(%lhs, %rhs : tensor<2x20x1024x64xf16>, tensor<1280x20x64xf16>) + outs(%out : tensor<2x1024x1280xf32>) { + ^bb0(%in: f16, %in_0: f16, %acc: f32): + %22 = arith.extf %in : f16 to f32 + %23 = arith.extf %in_0 : f16 to f32 + %24 = arith.mulf %22, %23 : f32 + %25 = arith.addf %acc, %24 : f32 + linalg.yield %25 : f32 + } -> tensor<2x1024x1280xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 2, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 5, + subgroup_k_tile_count = 4>}>, + workgroup_size = [128, 2, 1], subgroup_size = 64> -> !transform.any_param + // transform.print %contract {name = "Contract"} : !transform.any_op + transform.yield %contract, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_contract_2x2x20x64x64x2048(%contract: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract { + ^bb0(%lhs: tensor<2x64x2048xf16>, %rhs: tensor<2x20x64x2048xf16>, %out: tensor<2x2x20x64x64xf32>): + %10 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"] + } ins(%lhs, %rhs : tensor<2x64x2048xf16>, tensor<2x20x64x2048xf16>) + outs(%out : tensor<2x2x20x64x64xf32>) { + ^bb0(%in: f16, %in_0: f16, %acc: f32): + %12 = arith.extf %in : f16 to f32 + %13 = arith.extf %in_0 : f16 to f32 + %14 = arith.mulf %12, %13 : f32 + %15 = arith.addf %acc, %14 : f32 + linalg.yield %15 : f32 + } -> tensor<2x2x20x64x64xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 1, + subgroup_m_tile_count = 1, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 8>}>, + workgroup_size = [64, 2, 1], subgroup_size = 64> -> !transform.any_param + // transform.print %contract {name = "Contract"} : !transform.any_op + transform.yield %contract, %config : !transform.any_op, !transform.any_param + } + + transform.named_sequence @match_contract_3x2x20x64x64x1280(%contract: !transform.any_op {transform.readonly}) + -> (!transform.any_op, !transform.any_param) { + %ins, %outs = transform.iree.match.cast_compatible_dag_from_root %contract { + ^bb0(%lhs: tensor<2x1024x1280xf16>, %rhs: tensor<3x20x64x1280xf16>, %out: tensor<3x2x20x1024x64xf32>): + %14 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d3, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>, + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction"] + } ins(%lhs, %rhs : tensor<2x1024x1280xf16>, tensor<3x20x64x1280xf16>) + outs(%out : tensor<3x2x20x1024x64xf32>) { + ^bb0(%in: f16, %in_0: f16, %acc: f32): + %16 = arith.extf %in : f16 to f32 + %17 = arith.extf %in_0 : f16 to f32 + %18 = arith.mulf %16, %17 : f32 + %19 = arith.addf %acc, %18 : f32 + linalg.yield %19 : f32 + } -> tensor<3x2x20x1024x64xf32> + } : (!transform.any_op) -> (!transform.any_value, !transform.any_value) + %config = transform.param.constant #iree_codegen.compilation_info< + lowering_config = #iree_codegen.lowering_config, + translation_info = #iree_codegen.translation_info, + subgroup_m_count = 2, subgroup_n_count = 2, + subgroup_m_tile_count = 2, + subgroup_n_tile_count = 1, + subgroup_k_tile_count = 8>}>, + workgroup_size = [128, 2, 1], subgroup_size = 64> -> !transform.any_param + // transform.print %contract {name = "Contract"} : !transform.any_op + transform.yield %contract, %config : !transform.any_op, !transform.any_param + } + +//===----------------------------------------------------------------------===// +// Entry point +//===----------------------------------------------------------------------===// + + transform.named_sequence @__kernel_config(%variant_op: !transform.any_op {transform.consumed}) { + transform.foreach_match in %variant_op + // Attention. + @match_attention_len_512 -> @custom_attention_len_512, + @match_attention -> @custom_attention, + // Matmul tuning. + @match_mmt_2048x10240x1280 -> @apply_op_config, + @match_mmt_2048x1280x1280 -> @apply_op_config, + @match_mmt_2048x1280x5120 -> @apply_op_config, + @match_mmt_128x1280x2048 -> @apply_op_config, + @match_mmt_128x640x2048 -> @apply_op_config, + @match_mmt_8192x640x2560 -> @apply_op_config, + @match_mmt_8192x5120x640 -> @apply_op_config, + // Convolution tuning. + @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1280 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x1920 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x32x32x1280x3x3x2560 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x640 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x1280 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x64x64x640x3x3x1920 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x64x64x1280x3x3x1280 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x320 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x640 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x128x128x320x3x3x960 -> @apply_op_config, + @match_conv_2d_nhwc_hwcf_2x128x128x640x3x3x640 -> @apply_op_config, + // Contract tuning. + @match_contract_2x1024x1280x20x64 -> @apply_op_config, + @match_contract_2x2x20x64x64x2048 -> @apply_op_config, + @match_contract_3x2x20x64x64x1280 -> @apply_op_config + : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} //// module \ No newline at end of file diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_benchmark.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_benchmark.py new file mode 100644 index 000000000..4c78fb764 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_benchmark.py @@ -0,0 +1,73 @@ +# Copyright 2024 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import numpy as np +import torch +import sys + +from iree import runtime as ireert +from turbine_models.utils.benchmark import benchmark_module + + +def run_benchmark(args): + config = ireert.Config(args.rt_device) + + if args.external_weight_file: + index = ireert.ParameterIndex() + index.load(args.external_weight_file) + + if not args.benchmark_vmfb_path: + sys.exit("no --benchmark_vmfb_path provided, required for run_benchmark") + benchmark_mod = ireert.VmModule.mmap(config.vm_instance, args.benchmark_vmfb_path) + + if not args.scheduled_unet_vmfb_path: + sys.exit("no --scheduled_unet_vmfb_path provided, required for run_benchmark") + + dtype = np.float16 if args.precision == "fp16" else np.float32 + sample = np.random.randn( + args.batch_size, 4, args.height // 8, args.width // 8 + ).astype(dtype) + prompt_embeds = np.random.randn(2 * args.batch_size, args.max_length, 2048).astype( + dtype + ) + text_embeds = np.random.randn(2 * args.batch_size, 1280).astype(dtype) + guidance_scale = np.array([7.5], dtype=dtype) + num_iters = np.array(args.num_inference_steps) + input = [ + sample, + prompt_embeds, + text_embeds, + guidance_scale, + num_iters, + ] + + vmfbs = [] + vmfbs.append(args.scheduled_unet_vmfb_path) + vmfbs.append(args.benchmark_vmfb_path) + + if args.external_weight_file: + results = benchmark_module( + benchmark_mod, + "produce_image_latents", + vmfbs, + input, + parameters=f"model={args.external_weight_file}", + ) + else: + results = benchmark_module(benchmark_mod, "produce_image_latents", vmfbs, input) + for benchmark_result in results: + print( + f"benchmark_name: {benchmark_result.benchmark_name}, time: {benchmark_result.time}, cpu_time: {benchmark_result.cpu_time}, iterations: {benchmark_result.iterations}, user_counters: {benchmark_result.user_counters}" + ) + + +# Python Benchmarking Support for multiple modules + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + run_benchmark(args) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py new file mode 100644 index 000000000..f2faa0323 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py @@ -0,0 +1,289 @@ +import argparse +import os +from pathlib import Path + + +def path_expand(s): + return Path(s).expanduser().resolve() + + +def is_valid_file(arg): + if not os.path.exists(arg): + return None + else: + return arg + + +# Note: this is where command-line options for the scripts in this directory +# are defined along with their defaults. Thus, they should not be referenced +# within modelling or inference code, only at the entry point to the script. + +# We should consider separating out the options that are "model configs" from +# the options that control the compiler, runtime, and script behavior, +# when applicable, as the formermost would best be kept in a separate +# config or imported from huggingface. + +p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter +) + +############################################################################## +# SDXL Huggingface Options +############################################################################## + +p.add_argument( + "--hf_auth_token", + type=str, + help="The Hugging Face auth token, if required", + default=None, +) +p.add_argument( + "--hf_model_name", + type=str, + help="HF model name", + default="stabilityai/stable-diffusion-xl-base-1.0", +) +p.add_argument( + "--scheduler_id", + type=str, + help="Scheduler ID", + default="PNDM", +) + +############################################################################## +# SDXL Inference Options +# These options are used to control runtime parameters for SDXL inference. +############################################################################## + +p.add_argument( + "--prompt", + type=str, + default=" a cat under the snow with blue eyes, covered by snow, cinematic style, medium shot, professional photo, animal", + help="Prompt input to stable diffusion.", +) + +p.add_argument( + "--negative_prompt", + type=str, + default="Watermark, blurry, oversaturated, low resolution, pollution", + help="Negative prompt input to stable diffusion.", +) + +p.add_argument( + "--num_inference_steps", type=int, default=30, help="Number of UNet inference steps" +) + +p.add_argument( + "--batch_count", + type=int, + default=1, + help="Number of batches to run for a single prompt", +) + +p.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="Scale by which to adjust prompt guidance to the unconditional noise prediction output of UNet after each iteration.", +) + +p.add_argument( + "--seed", type=float, default=0, help="Seed for random number/latents generation." +) + +p.add_argument( + "--external_weight_path", + type=str, + default="", + help="Path to external weights file, for jobs with one weights filepath. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from.", +) + +p.add_argument( + "--external_weights_dir", + type=str, + default="", + help="Directory containing external weights for a job that requires more than one weights file. When importing, this is used to specify where to save the model weights, and at runtime, this is used to specify where to load the model weights from. Files will then be saved according to the parameters that make them unique, i.e. ___.", +) + +p.add_argument( + "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module" +) + +p.add_argument( + "--pipeline_vmfb_path", + type=str, + default="", + help="path to vmfb containing compiled meta-module", +) + +p.add_argument( + "--external_weight_file", + type=str, + default=None, + help="Path to external weights, used in benchmark scripts.", +) + +p.add_argument( + "--pipeline_dir", + type=str, + default=None, + help="Directory to save pipeline artifacts", +) + +p.add_argument( + "--compiled_pipeline", + default=False, + action="store_true", + help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.", +) + +############################################################################## +# SDXL Modelling Options +# These options are used to control model defining parameters for SDXL. +# These are MLIR - changing variables! If you change them, you will need +# to import/download and recompile the model. +############################################################################## + +p.add_argument("--batch_size", type=int, default=1, help="Batch size for inference") +p.add_argument( + "--height", type=int, default=1024, help="Height of Stable Diffusion output image." +) +p.add_argument( + "--width", type=int, default=1024, help="Width of Stable Diffusion output image" +) +p.add_argument( + "--precision", + type=str, + default="fp16", + help="Precision of Stable Diffusion weights and graph.", +) +p.add_argument( + "--max_length", type=int, default=64, help="Sequence Length of Stable Diffusion" +) +p.add_argument("--vae_variant", type=str, default="decode", help="encode, decode") +p.add_argument( + "--return_index", + action="store_true", + help="Make scheduled unet compiled module return the step index.", +) + +p.add_argument( + "--vae_decomp_attn", + type=bool, + default=True, + help="Decompose attention for VAE decode only at fx graph level", +) + +############################################################################## +# SDXL script general options. +############################################################################## + +p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb") + +p.add_argument( + "--external_weights", + type=str, + default=None, + choices=["safetensors", "irpa", "gguf", None], + help="Externalizes model weights from the torch dialect IR and its successors", +) + +# See --external_weight_path and --external_weight_dir to specify where to save the model weights. + +p.add_argument( + "--compare_vs_torch", + action="store_true", + help="Runs both turbine vmfb and a torch model to compare results", +) +p.add_argument( + "--decomp_attn", + default=False, + action="store_true", + help="Decompose attention at fx graph level", +) +p.add_argument( + "--exit_on_vmfb", + default=True, + action="store_false", + help="Exit program on vmfb compilation completion. Most scripts will also save .mlir if this is disabled.", +) +p.add_argument( + "--input_mlir", + type=str, + default=None, + help="Path to input mlir file to compile. Comma-separate paths to provide more than one input to pipelines.", +) +p.add_argument( + "--download_mlir", + default=False, + action="store_true", + help="Download missing mlir files from Azure storage.", +) +p.add_argument( + "--container_name", + type=str, + default=None, + help="Azure storage container name to download mlir files from.", +) + + +############################################################################## +# IREE Compiler Options +############################################################################## + +p.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm") + +p.add_argument( + "--rt_device", + type=str, + default="local-task", + help="local-task, local-sync, vulkan://0, rocm://0, cuda://0, etc.", +) + +# TODO: Bring in detection for target triple +p.add_argument( + "--iree_target_triple", + type=str, + default="", + help="Specify vulkan target triple or rocm/cuda target device.", +) + +p.add_argument("--ireec_flags", type=str, default="", help="extra iree-compile options") + +p.add_argument( + "--attn_flags", + type=str, + default="", + help="extra iree-compile options for models with iree_linalg_ext.attention ops.", +) + +p.add_argument( + "--attn_spec", + type=str, + default=None, + help="extra iree-compile options for models with iree_linalg_ext.attention ops. Set this to 'default' if you are using mfma-capable hardware with ROCM.", +) + +p.add_argument( + "--clip_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling CLIP/prompt_encoder. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--vae_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling VAE. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + +p.add_argument( + "--unet_flags", + type=str, + default="", + help="extra iree-compile options to send for compiling unet. Only use this for testing bleeding edge flags! Any default options should be added to sd_inference/utils.py", +) + + +args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py new file mode 100644 index 000000000..f17a17f60 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py @@ -0,0 +1,693 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import torch +from turbine_models.custom_models.sdxl_inference import ( + sdxl_prompt_encoder, + sdxl_scheduled_unet, + vae, +) +import iree.runtime as ireert +from turbine_models.custom_models.sd_inference import utils +from turbine_models.utils.sdxl_benchmark import run_benchmark +from turbine_models.model_runner import vmfbRunner +from transformers import CLIPTokenizer + +from PIL import Image +import os +import numpy as np +import time +from datetime import datetime as dt + +device_list = [ + "cpu", + "vulkan", + "cuda", + "rocm", +] + +rt_device_list = [ + "local-task", + "local-sync", + "vulkan", + "cuda", + "rocm", +] + +empty_pipe_dict = { + "vae_decode": None, + "prompt_encoder": None, + "scheduled_unet": None, + "pipeline": None, + "full_pipeline": None, +} + + +class SharkSDXLPipeline: + def __init__( + self, + hf_model_name: str, + scheduler_id: str, + height: int, + width: int, + precision: str, + max_length: int, + batch_size: int, + num_inference_steps: int, + device: str, + iree_target_triple: str, + ireec_flags: dict, + attn_spec: str = None, + decomp_attn: bool = False, + pipeline_dir: str = "./shark_vmfbs", + external_weights_dir: str = "./shark_weights", + external_weights: str = "safetensors", + vae_decomp_attn: bool = True, + ): + self.hf_model_name = hf_model_name + self.scheduler_id = scheduler_id + self.height = height + self.width = width + self.precision = precision + self.max_length = max_length + self.batch_size = batch_size + self.num_inference_steps = num_inference_steps + self.device = device + self.iree_target_triple = iree_target_triple + self.ireec_flags = ireec_flags + self.attn_spec = attn_spec + self.decomp_attn = decomp_attn + self.pipeline_dir = pipeline_dir + self.external_weights_dir = external_weights_dir + self.external_weights = external_weights + self.vae_decomp_attn = vae_decomp_attn + + # FILE MANAGEMENT AND PIPELINE SETUP + + def check_prepared( + self, + mlirs: dict, + vmfbs: dict, + weights: dict, + interactive: bool = True, + ): + ready, vmfbs, weights = self.is_prepared(vmfbs, weights) + if not ready: + if interactive: + do_continue = input( + f"\nIt seems you are missing some necessary files. Would you like to generate them now? (y/n)" + ) + if do_continue.lower() != "y": + exit() + else: + do_continue = "y" + if do_continue.lower() == "y": + for submodel in vmfbs.keys(): + if vmfbs[submodel] == None: + vmfb, weight = self.export_submodel(submodel, input_mlir=mlirs) + vmfbs[submodel] = vmfb + if weights[submodel] is None: + weights[submodel] = weight + elif weights[submodel] is None and "pipeline" not in submodel: + _, weight = self.export_submodel(submodel, weights_only=True) + weights[submodel] = weight + ready, vmfbs, weights = self.is_prepared(vmfbs, weights) + if ready: + print("All necessary files found. Generating images.") + return vmfbs, weights + else: + print("There was an error generating the necessary files.") + exit() + else: + print("All necessary files found. Generating images.") + return vmfbs, weights + + def is_prepared(self, vmfbs, weights): + missing = [] + for key in vmfbs: + if key == "scheduled_unet": + val = f"{self.scheduler_id}_unet_{self.num_inference_steps}" + default_filepath = os.path.join(self.pipeline_dir, val + ".vmfb") + else: + val = vmfbs[key] + default_filepath = os.path.join(self.pipeline_dir, key + ".vmfb") + if vmfbs[key] is not None and os.path.exists(vmfbs[key]): + continue + elif vmfbs[key] == None and os.path.exists(default_filepath): + vmfbs[key] = default_filepath + elif val is None: + missing.append(key + ".vmfb") + else: + missing.append(val + ".vmfb") + for w_key in weights: + if "pipeline" in w_key: + continue + if weights[w_key] is not None and os.path.exists(weights[w_key]): + continue + default_name = os.path.join( + self.external_weights_dir, w_key + "." + self.external_weights + ) + if weights[w_key] is None and os.path.exists(default_name): + weights[w_key] = os.path.join(default_name) + else: + missing.append(w_key + "." + self.external_weights) + if len(missing) > 0: + print(f"Missing files: " + ", ".join(missing)) + return False, vmfbs, weights + else: + return True, vmfbs, weights + + def get_mlir_from_turbine_tank(self, submodel, container_name): + from turbine_models.turbine_tank import downloadModelArtifacts + + safe_name = utils.create_safe_name( + self.hf_model_name, + f"_{self.max_length}_{self.height}x{self.width}_{self.precision}_{submodel}.mlir", + ) + mlir_path = downloadModelArtifacts( + safe_name, + container_name, + ) + return mlir_path + + # IMPORT / COMPILE PHASE + + def get_torch_models(self, submodel): + match submodel: + case "scheduled_unet": + scheduled_unet_torch = sdxl_scheduled_unet.SDXLScheduledUnet( + # This is a public model, so no auth required + self.hf_model_name, + self.scheduler_id, + self.height, + self.width, + self.batch_size, + None, + precision=self.precision, + num_inference_steps=self.num_inference_steps, + ) + return scheduled_unet_torch + case "vae_decode": + vae_torch = vae.VaeModel( + # This is a public model, so no auth required + self.hf_model_name, + custom_vae=( + "madebyollin/sdxl-vae-fp16-fix" + if self.precision == "fp16" + else None + ), + ) + return vae_torch + + def export_submodel( + self, + submodel: str, + input_mlir: str = None, + weights_only: bool = False, + ): + if not os.path.exists(self.pipeline_dir): + os.makedirs(self.pipeline_dir) + if self.external_weights_dir: + if not os.path.exists(self.external_weights_dir): + os.makedirs(external_weights_dir, exist_ok=True) + vae_external_weight_path = os.path.join( + self.external_weights_dir, "vae_decode." + self.external_weights + ) + unet_external_weight_path = os.path.join( + self.external_weights_dir, "scheduled_unet." + self.external_weights + ) + prompt_encoder_external_weight_path = os.path.join( + self.external_weights_dir, "prompt_encoder." + self.external_weights + ) + elif self.external_weights is None: + print( + "No external weights type specified using --external_weights, weights for imported .mlir files will not be externalized." + ) + vae_external_weight_path = None + unet_external_weight_path = None + prompt_encoder_external_weight_path = None + else: + print( + f"No external weights directory specified using --external_weights_dir, we assume you have your own weights in {self.pipeline_dir}." + ) + external_weights_dir = self.pipeline_dir + if not os.path.exists(self.pipeline_dir): + os.makedirs(self.pipeline_dir, exist_ok=True) + vae_external_weight_path = os.path.join( + self.pipeline_dir, "vae_decode." + self.external_weights + ) + unet_external_weight_path = os.path.join( + self.pipeline_dir, "scheduled_unet." + self.external_weights + ) + prompt_encoder_external_weight_path = os.path.join( + self.pipeline_dir, "prompt_encoder." + self.external_weights + ) + if weights_only: + input_mlir = { + "vae_decode": None, + "prompt_encoder": None, + "scheduled_unet": None, + "pipeline": None, + "full_pipeline": None, + } + match submodel: + case "scheduled_unet": + if not input_mlir[submodel]: + scheduled_unet_torch = self.get_torch_models("scheduled_unet") + else: + scheduled_unet_torch = None + unet_vmfb = sdxl_scheduled_unet.export_scheduled_unet_model( + scheduled_unet_torch, + self.scheduler_id, + self.num_inference_steps, + self.hf_model_name, + self.batch_size, + self.height, + self.width, + self.precision, + self.max_length, + None, + "vmfb", + self.external_weights, + unet_external_weight_path, + self.device, + self.iree_target_triple, + self.ireec_flags["unet"], + self.decomp_attn, + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + attn_spec=self.attn_spec, + input_mlir=input_mlir["scheduled_unet"], + weights_only=weights_only, + ) + return unet_vmfb, unet_external_weight_path + case "vae_decode": + if not input_mlir[submodel]: + vae_torch = self.get_torch_models("vae_decode") + else: + vae_torch = None + vae_decode_vmfb = vae.export_vae_model( + vae_torch, + self.hf_model_name, + self.batch_size, + self.height, + self.width, + self.precision, + "vmfb", + self.external_weights, + vae_external_weight_path, + self.device, + self.iree_target_triple, + self.ireec_flags["vae"], + "decode", + self.vae_decomp_attn, + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + attn_spec=self.attn_spec, + input_mlir=input_mlir["vae_decode"], + weights_only=weights_only, + ) + return vae_decode_vmfb, vae_external_weight_path + case "prompt_encoder": + _, prompt_encoder_vmfb = sdxl_prompt_encoder.export_prompt_encoder( + self.hf_model_name, + None, + self.max_length, + self.precision, + "vmfb", + self.external_weights, + prompt_encoder_external_weight_path, + self.device, + self.iree_target_triple, + self.ireec_flags["clip"], + exit_on_vmfb=False, + pipeline_dir=self.pipeline_dir, + input_mlir=input_mlir["prompt_encoder"], + attn_spec=self.attn_spec, + weights_only=weights_only, + ) + return prompt_encoder_vmfb, prompt_encoder_external_weight_path + case "pipeline": + pipeline_file = ( + "sdxl_sched_unet_bench_" + "f32" + if self.precision == "fp32" + else "sdxl_sched_unet_bench_" + "f16" + ) + pipeline_vmfb = utils.compile_to_vmfb( + os.path.join( + os.path.realpath(os.path.dirname(__file__)), + pipeline_file + ".mlir", + ), + self.device, + self.iree_target_triple, + self.ireec_flags["pipeline"], + os.path.join(self.pipeline_dir, "pipeline"), + return_path=True, + mlir_source="file", + ) + return pipeline_vmfb, None + case "full_pipeline": + pipeline_file = ( + "sdxl_pipeline_bench_" + "f32" + if self.precision == "fp32" + else "sdxl_pipeline_bench_" + "f16" + ) + pipeline_vmfb = utils.compile_to_vmfb( + os.path.join( + os.path.realpath(os.path.dirname(__file__)), + pipeline_file + ".mlir", + ), + self.device, + self.iree_target_triple, + self.ireec_flags["pipeline"], + os.path.join(self.pipeline_dir, "full_pipeline"), + return_path=True, + mlir_source="file", + ) + return pipeline_vmfb, None + + # LOAD + + def load_pipeline( + self, + vmfbs: dict, + weights: dict, + rt_device: str = "local-task", + compiled_pipeline: bool = True, + ): + self.runners = {} + runners = {} + if compiled_pipeline: + runners["pipe"] = vmfbRunner( + rt_device, + [ + vmfbs["scheduled_unet"], + vmfbs["prompt_encoder"], + vmfbs["vae_decode"], + vmfbs["full_pipeline"], + ], + [ + weights["scheduled_unet"], + weights["prompt_encoder"], + weights["vae_decode"], + None, + ], + ) + else: + runners["pipe"] = vmfbRunner( + rt_device, + [vmfbs["scheduled_unet"], vmfbs["pipeline"]], + [weights["scheduled_unet"], None], + ) + runners["vae_decode"] = vmfbRunner( + rt_device, vmfbs["vae_decode"], weights["vae_decode"] + ) + runners["prompt_encoder"] = vmfbRunner( + rt_device, vmfbs["prompt_encoder"], weights["prompt_encoder"] + ) + runners["tokenizer_1"] = CLIPTokenizer.from_pretrained( + self.hf_model_name, + subfolder="tokenizer", + ) + runners["tokenizer_2"] = CLIPTokenizer.from_pretrained( + self.hf_model_name, + subfolder="tokenizer_2", + ) + self.runners = runners + self.compiled_pipeline = compiled_pipeline + print("Successfully loaded pipeline.") + + # RUN + + def generate_images( + self, + prompt: str, + negative_prompt: str = "", + batch_count: int = 1, + guidance_scale: float = 7.5, + seed: float = -1, + ): + # TODO: implement case where this is false e.g. in SDXL Turbo + # do_classifier_free_guidance = True + + iree_dtype = "float32" if self.precision == "fp32" else "float16" + torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16 + + pipe_start = time.time() + + tokenizers = [self.runners["tokenizer_1"], self.runners["tokenizer_2"]] + + max_length = self.max_length + + samples = [] + numpy_images = [] + + if self.compiled_pipeline and (batch_count > 1): + print( + "Compiled one-shot pipeline only supports 1 image at a time for now. Setting batch count to 1." + ) + batch_count = 1 + + for i in range(batch_count): + generator = torch.random.manual_seed(seed + i) + rand_sample = torch.randn( + ( + self.batch_size, + 4, + self.height // 8, + self.width // 8, + ), + generator=generator, + dtype=torch_dtype, + ) + samples.append( + ireert.asdevicearray( + self.runners["pipe"].config.device, rand_sample, dtype=iree_dtype + ) + ) + + guidance_scale = ireert.asdevicearray( + self.runners["pipe"].config.device, + np.asarray([guidance_scale]), + dtype=iree_dtype, + ) + + text_input_ids_list = [] + uncond_input_ids_list = [] + + tokenize_start = time.time() + + # Tokenize prompt and negative prompt. + for tokenizer in tokenizers: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + uncond_input_ids = uncond_input.input_ids + + text_input_ids_list.extend( + [ + ireert.asdevicearray( + self.runners["pipe"].config.device, text_input_ids + ) + ] + ) + uncond_input_ids_list.extend( + [ + ireert.asdevicearray( + self.runners["pipe"].config.device, uncond_input_ids + ) + ] + ) + if self.compiled_pipeline: + inf_start = time.time() + image = self.runners["pipe"].ctx.modules.sdxl_compiled_pipeline[ + "tokens_to_image" + ](samples[0], guidance_scale, *text_input_ids_list, *uncond_input_ids_list) + inf_end = time.time() + print( + "Total inference time (Tokens to Image): " + + str(inf_end - inf_start) + + "sec" + ) + numpy_images.append(image.to_host()) + else: + encode_prompts_start = time.time() + + prompt_embeds, add_text_embeds = self.runners[ + "prompt_encoder" + ].ctx.modules.compiled_clip["encode_prompts"]( + *text_input_ids_list, *uncond_input_ids_list + ) + + encode_prompts_end = time.time() + + for i in range(batch_count): + unet_start = time.time() + + latents = self.runners["pipe"].ctx.modules.sdxl_compiled_pipeline[ + "produce_image_latents" + ](samples[i], prompt_embeds, add_text_embeds, guidance_scale) + + vae_start = time.time() + vae_out = self.runners["vae_decode"].ctx.modules.compiled_vae["main"]( + latents + ) + + pipe_end = time.time() + + image = vae_out.to_host() + + numpy_images.append(image) + print("Batch #", i + 1, "\n") + print( + "UNet time(", + self.num_inference_steps, + "): ", + vae_start - unet_start, + "sec,", + ) + print( + "Unet average step latency: ", + (vae_start - unet_start) / self.num_inference_steps, + "sec", + ) + print("VAE time: ", pipe_end - vae_start, "sec") + print( + f"\nTotal time (txt2img, batch #{str(i+1)}): ", + (encode_prompts_end - encode_prompts_start) + + (pipe_end - unet_start), + "sec\n", + ) + end = time.time() + print("Total CLIP time:", encode_prompts_end - encode_prompts_start, "sec") + print("Total tokenize time:", encode_prompts_start - tokenize_start, "sec") + print("Loading time: ", encode_prompts_start - pipe_start, "sec") + if batch_count > 1: + print( + f"Total inference time ({batch_count} batch(es)):", + end - encode_prompts_start, + "sec", + ) + timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S") + for idx, image in enumerate(numpy_images): + image = torch.from_numpy(image).cpu().permute(0, 2, 3, 1).float().numpy() + image = numpy_to_pil_image(image) + img_path = "sdxl_output_" + timestamp + "_" + str(idx) + ".png" + image[0].save(img_path) + print(img_path, "saved") + + +def numpy_to_pil_image(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + mlirs = { + "vae_decode": None, + "prompt_encoder": None, + "scheduled_unet": None, + "pipeline": None, + "full_pipeline": None, + } + vmfbs = { + "vae_decode": None, + "prompt_encoder": None, + "scheduled_unet": None, + "pipeline": None, + "full_pipeline": None, + } + weights = { + "vae_decode": None, + "prompt_encoder": None, + "scheduled_unet": None, + "pipeline": None, + "full_pipeline": None, + } + ireec_flags = { + "unet": args.ireec_flags + args.unet_flags, + "vae": args.ireec_flags + args.vae_flags, + "clip": args.ireec_flags + args.clip_flags, + "pipeline": args.ireec_flags, + } + if not args.pipeline_dir: + pipe_id_list = [ + "sdxl_1_0", + str(args.height), + str(args.width), + str(args.max_length), + args.precision, + args.device, + ] + args.pipeline_dir = os.path.join( + ".", + "_".join(pipe_id_list), + ) + if args.input_mlir: + user_mlir_list = args.input_mlir.split(",") + else: + user_mlir_list = [] + for submodel_id, mlir_path in zip(mlirs.keys(), user_mlir_list): + if submodel_id in mlir_path: + mlirs[submodel_id] = mlir_path + if not args.external_weights_dir and args.external_weights: + args.external_weights_dir = args.pipeline_dir + + sdxl_pipe = SharkSDXLPipeline( + args.hf_model_name, + args.scheduler_id, + args.height, + args.width, + args.precision, + args.max_length, + args.batch_size, + args.num_inference_steps, + args.device, + args.iree_target_triple, + ireec_flags, + args.attn_spec, + args.decomp_attn, + args.pipeline_dir, + args.external_weights_dir, + args.external_weights, + args.vae_decomp_attn, + ) + vmfbs, weights = sdxl_pipe.check_prepared(mlirs, vmfbs, weights) + sdxl_pipe.load_pipeline(vmfbs, weights, args.rt_device, args.compiled_pipeline) + sdxl_pipe.generate_images( + args.prompt, + args.negative_prompt, + args.batch_count, + args.guidance_scale, + args.seed, + ) + print("Image generation complete.") diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f16.mlir b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f16.mlir new file mode 100644 index 000000000..523d09fa6 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f16.mlir @@ -0,0 +1,23 @@ +module @sdxl_compiled_pipeline { + func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} + func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf16>, %arg1: tensor<2x64x2048xf16>, %arg2: tensor<2x1280xf16>, %arg3: tensor<2x6xf16>, %arg4: tensor<1xf16>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} + func.func private @compiled_clip.encode_prompts(%arg0: tensor<1x64xi64>, %arg1: tensor<1x64xi64>, %arg2: tensor<1x64xi64>, %arg3: tensor<1x64xi64>) -> (tensor<2x64x2048xf16>, tensor<2x1280xf16>) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} + func.func private @compiled_vae.main(%arg0: tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} + + func.func @tokens_to_image(%sample: tensor<1x4x128x128xf16>, %guidance_scale: tensor<1xf16>, %t_ids_1: tensor<1x64xi64>, %t_ids_2: tensor<1x64xi64>, %u_ids_1: tensor<1x64xi64>, %u_ids_2: tensor<1x64xi64>) -> tensor<1x3x1024x1024xf16> { + %p_embeds, %t_embeds = func.call @compiled_clip.encode_prompts(%t_ids_1, %t_ids_2, %u_ids_1, %u_ids_2) : (tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>) -> (tensor<2x64x2048xf16>, tensor<2x1280xf16>) + %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %steps_int = tensor.extract %steps[] : tensor + %n_steps = arith.index_cast %steps_int: i64 to index + %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<1x4x128x128xf16>) { + %step_64 = arith.index_cast %arg0 : index to i64 + %this_step = tensor.from_elements %step_64 : tensor<1xi64> + %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf16>, tensor<2x64x2048xf16>, tensor<2x1280xf16>, tensor<2x6xf16>, tensor<1xf16>, tensor<1xi64>) -> tensor<1x4x128x128xf16> + scf.yield %inner : tensor<1x4x128x128xf16> + } + %image = func.call @compiled_vae.main(%res): (tensor<1x4x128x128xf16>) -> tensor<1x3x1024x1024xf16> + return %image : tensor<1x3x1024x1024xf16> + } +} \ No newline at end of file diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f32.mlir b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f32.mlir new file mode 100644 index 000000000..669df73b2 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_pipeline_bench_f32.mlir @@ -0,0 +1,23 @@ +module @sdxl_compiled_pipeline { + func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} + func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf32>, %arg1: tensor<2x64x2048xf32>, %arg2: tensor<2x1280xf32>, %arg3: tensor<2x6xf32>, %arg4: tensor<1xf32>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} + func.func private @compiled_clip.encode_prompts(%arg0: tensor<1x64xi64>, %arg1: tensor<1x64xi64>, %arg2: tensor<1x64xi64>, %arg3: tensor<1x64xi64>) -> (tensor<2x64x2048xf32>, tensor<2x1280xf32>) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} + func.func private @compiled_vae.main(%arg0: tensor<1x4x128x128xf32>) -> tensor<1x3x1024x1024xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} + + func.func @tokens_to_image(%sample: tensor<1x4x128x128xf32>, %guidance_scale: tensor<1xf32>, %t_ids_1: tensor<1x64xi64>, %t_ids_2: tensor<1x64xi64>, %u_ids_1: tensor<1x64xi64>, %u_ids_2: tensor<1x64xi64>) -> tensor<1x3x1024x1024xf32> { + %p_embeds, %t_embeds = func.call @compiled_clip.encode_prompts(%t_ids_1, %t_ids_2, %u_ids_1, %u_ids_2) : (tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>, tensor<1x64xi64>) -> (tensor<2x64x2048xf32>, tensor<2x1280xf32>) + %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %steps_int = tensor.extract %steps[] : tensor + %n_steps = arith.index_cast %steps_int: i64 to index + %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<1x4x128x128xf32>) { + %step_64 = arith.index_cast %arg0 : index to i64 + %this_step = tensor.from_elements %step_64 : tensor<1xi64> + %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf32>, tensor<2x64x2048xf32>, tensor<2x1280xf32>, tensor<2x6xf32>, tensor<1xf32>, tensor<1xi64>) -> tensor<1x4x128x128xf32> + scf.yield %inner : tensor<1x4x128x128xf32> + } + %image = func.call @compiled_vae.main(%res): (tensor<1x4x128x128xf32>) -> tensor<1x3x1024x1024xf32> + return %image : tensor<1x3x1024x1024xf32> + } +} \ No newline at end of file diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py new file mode 100644 index 000000000..1c6b6331c --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -0,0 +1,247 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +from iree import runtime as ireert +import iree.compiler as ireec +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + + +class PromptEncoderModule(torch.nn.Module): + def __init__( + self, + hf_model_name, + precision, + hf_auth_token=None, + do_classifier_free_guidance=True, + ): + super().__init__() + self.torch_dtype = torch.float16 if precision == "fp16" else torch.float32 + self.text_encoder_model_1 = CLIPTextModel.from_pretrained( + hf_model_name, + subfolder="text_encoder", + token=hf_auth_token, + ) + self.text_encoder_model_2 = CLIPTextModelWithProjection.from_pretrained( + hf_model_name, + subfolder="text_encoder_2", + token=hf_auth_token, + ) + self.do_classifier_free_guidance = do_classifier_free_guidance + + def forward( + self, text_input_ids_1, text_input_ids_2, uncond_input_ids_1, uncond_input_ids_2 + ): + with torch.no_grad(): + prompt_embeds_1 = self.text_encoder_model_1( + text_input_ids_1, + output_hidden_states=True, + ) + prompt_embeds_2 = self.text_encoder_model_2( + text_input_ids_2, + output_hidden_states=True, + ) + neg_prompt_embeds_1 = self.text_encoder_model_1( + uncond_input_ids_1, + output_hidden_states=True, + ) + neg_prompt_embeds_2 = self.text_encoder_model_2( + uncond_input_ids_2, + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds_2[0] + neg_pooled_prompt_embeds = neg_prompt_embeds_2[0] + + prompt_embeds_list = [ + prompt_embeds_1.hidden_states[-2], + prompt_embeds_2.hidden_states[-2], + ] + neg_prompt_embeds_list = [ + neg_prompt_embeds_1.hidden_states[-2], + neg_prompt_embeds_2.hidden_states[-2], + ] + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + neg_prompt_embeds = torch.concat(neg_prompt_embeds_list, dim=-1) + + bs_embed, seq_len, _ = prompt_embeds.shape + + prompt_embeds = prompt_embeds.repeat(1, 1, 1) + prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view( + bs_embed * 1, -1 + ) + add_text_embeds = pooled_prompt_embeds + if self.do_classifier_free_guidance: + neg_pooled_prompt_embeds = neg_pooled_prompt_embeds.repeat(1, 1).view( + 1, -1 + ) + neg_prompt_embeds = neg_prompt_embeds.repeat(1, 1, 1) + neg_prompt_embeds = neg_prompt_embeds.view(bs_embed * 1, seq_len, -1) + prompt_embeds = torch.cat([neg_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat( + [neg_pooled_prompt_embeds, add_text_embeds], dim=0 + ) + + add_text_embeds = add_text_embeds.to(self.torch_dtype) + prompt_embeds = prompt_embeds.to(self.torch_dtype) + return prompt_embeds, add_text_embeds + + +def export_prompt_encoder( + hf_model_name, + hf_auth_token=None, + max_length=64, + precision="fp16", + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, + exit_on_vmfb=True, + pipeline_dir=None, + input_mlir=None, + attn_spec=None, + weights_only=False, +): + if "turbo" in hf_model_name: + do_classifier_free_guidance = False + else: + do_classifier_free_guidance = True + + if (attn_spec in ["default"]) and ("gfx94" in target_triple): + attn_spec = os.path.join( + os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" + ) + else: + attn_spec = None + + if pipeline_dir not in [None, ""]: + safe_name = os.path.join(pipeline_dir, "prompt_encoder") + else: + safe_name = utils.create_safe_name( + hf_model_name, f"-{str(max_length)}-{precision}-prompt-encoder-{device}" + ) + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return vmfb_path + # Load the tokenizer and text encoder to tokenize and encode the text. + tokenizer_1 = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + model_max_length=max_length, + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer_2", + token=hf_auth_token, + model_max_length=max_length, + ) + tokenizers = [tokenizer_1, tokenizer_2] + prompt_encoder_module = PromptEncoderModule( + hf_model_name, precision, hf_auth_token, do_classifier_free_guidance + ) + if precision == "fp16": + prompt_encoder_module = prompt_encoder_module.half() + mapper = {} + + utils.save_external_weights( + mapper, prompt_encoder_module, external_weights, external_weight_path + ) + + if weights_only: + return None, external_weight_path + + class CompiledClip(CompiledModule): + if external_weights: + params = export_parameters( + prompt_encoder_module, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(prompt_encoder_module) + + def encode_prompts( + self, + t_ids_1=AbstractTensor(1, max_length, dtype=torch.int64), + t_ids_2=AbstractTensor(1, max_length, dtype=torch.int64), + uc_ids_1=AbstractTensor(1, max_length, dtype=torch.int64), + uc_ids_2=AbstractTensor(1, max_length, dtype=torch.int64), + ): + return jittable(prompt_encoder_module.forward)( + t_ids_1, t_ids_2, uc_ids_1, uc_ids_2 + ) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledClip(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str, tokenizers + else: + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + const_expr_hoisting=True, + attn_spec=attn_spec, + ) + return module_str, vmfb_path + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + mod_str, _ = export_prompt_encoder( + args.hf_model_name, + args.hf_auth_token, + args.max_length, + args.precision, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags + args.clip_flags, + exit_on_vmfb=True, + pipeline_dir=args.pipeline_dir, + input_mlir=args.input_mlir, + attn_spec=args.attn_spec, + ) + if args.input_mlir: + exit() + safe_name_1 = safe_name = utils.create_safe_name( + args.hf_model_name, f"_{str(args.max_length)}_{args.precision}_prompt_encoder" + ) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py new file mode 100644 index 000000000..50c01e964 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py @@ -0,0 +1,164 @@ +from turbine_models.model_runner import vmfbRunner +from transformers import CLIPTokenizer +from iree import runtime as ireert +import torch +import numpy as np + + +def run_torch_clip(hf_model_name, hf_auth_token, prompt, max_length=64): + # TODO: Integrate with HFTransformerBuilder + from turbine_models.custom_models.sdxl_inference.clip import ClipModel + + model_1 = ClipModel(hf_model_name, hf_auth_token, index=1) + model_2 = ClipModel(hf_model_name, hf_auth_token, index=2) + tokenizer_1 = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer", + token=hf_auth_token, + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + hf_model_name, + subfolder="tokenizer_2", + token=hf_auth_token, + ) + text_input_1 = tokenizer_1( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + text_input_2 = tokenizer_2( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + example_input_1 = text_input_1.input_ids + example_input_2 = text_input_2.input_ids + + results_1 = model_1.forward(example_input_1) + results_2 = model_2.forward(example_input_2) + np_torch_output_1 = results_1[0].detach().cpu().numpy().astype(np.float16) + np_torch_output_2 = results_2[0].detach().cpu().numpy().astype(np.float16) + return np_torch_output_1, np_torch_output_2 + + +def run_prompt_encoder( + args, + input_ids, + uncond_input_ids, +): + prompt_encoder_runner = vmfbRunner( + args.device, args.vmfb_path, args.external_weight_path + ) + np.save("input0.npy", input_ids[0].numpy()) + np.save("input1.npy", input_ids[1].numpy()) + np.save("input2.npy", uncond_input_ids[0].numpy()) + np.save("input3.npy", uncond_input_ids[1].numpy()) + prompt_encoder_inputs = [ + ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[0]), + ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[1]), + ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[0]), + ireert.asdevicearray(prompt_encoder_runner.config.device, uncond_input_ids[1]), + ] + encoded_outputs = prompt_encoder_runner.ctx.modules.compiled_clip["encode_prompts"]( + *prompt_encoder_inputs + ) + del prompt_encoder_inputs + return encoded_outputs + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + tokenizer_1 = CLIPTokenizer.from_pretrained( + args.hf_model_name, + subfolder="tokenizer", + token=args.hf_auth_token, + ) + tokenizer_2 = CLIPTokenizer.from_pretrained( + args.hf_model_name, + subfolder="tokenizer_2", + token=args.hf_auth_token, + ) + text_input_ids_list = [] + uncond_input_ids_list = [] + + # Tokenize prompt and negative prompt. + tokenizers = [tokenizer_1, tokenizer_2] + for tokenizer in tokenizers: + text_inputs = tokenizer( + args.prompt, + padding="max_length", + max_length=args.max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input = tokenizer( + args.negative_prompt, + padding="max_length", + max_length=args.max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + uncond_input_ids = uncond_input.input_ids + + text_input_ids_list.extend([text_input_ids]) + uncond_input_ids_list.extend([uncond_input_ids]) + + turbine_output1, turbine_output2 = run_prompt_encoder( + args, + text_input_ids_list, + uncond_input_ids_list, + ) + print( + "TURBINE OUTPUT 1:", + turbine_output1.to_host(), + turbine_output1.shape, + turbine_output1.dtype, + ) + + print( + "TURBINE OUTPUT 2:", + turbine_output2.to_host(), + turbine_output2.shape, + turbine_output2.dtype, + ) + + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + from turbine_models.custom_models.sdxl_inference.sdxl_prompt_encoder import ( + PromptEncoderModule, + ) + + torch_encoder_model = PromptEncoderModule( + args.hf_model_name, args.precision, args.hf_auth_token + ) + torch_output1, torch_output2 = torch_encoder_model.forward( + *text_input_ids_list, *uncond_input_ids_list + ) + np.save("torch_output1.npy", torch_output1) + np.save("torch_output2.npy", torch_output2) + print( + "TORCH OUTPUT 1:", torch_output1, torch_output1.shape, torch_output1.dtype + ) + + print( + "TORCH OUTPUT 2:", torch_output2, torch_output2.shape, torch_output2.dtype + ) + rtol = 4e-2 + atol = 4e-2 + + np.testing.assert_allclose( + torch_output1, turbine_output1.to_host(), rtol, atol, verbose=True + ) + np.testing.assert_allclose( + torch_output2, turbine_output2.to_host(), rtol, atol, verbose=True + ) + print("Passed!") + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output1, turbine_output2 = (None, None) diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f16.mlir b/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f16.mlir new file mode 100644 index 000000000..b12fc82b9 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f16.mlir @@ -0,0 +1,19 @@ +module @sdxl_compiled_pipeline { + func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} + func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf16>, %arg1: tensor<2x64x2048xf16>, %arg2: tensor<2x1280xf16>, %arg3: tensor<2x6xf16>, %arg4: tensor<1xf16>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf16> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} + + func.func @produce_image_latents(%sample: tensor<1x4x128x128xf16>, %p_embeds: tensor<2x64x2048xf16>, %t_embeds: tensor<2x1280xf16>, %guidance_scale: tensor<1xf16>) -> tensor<1x4x128x128xf16> { + %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf16>) -> (tensor<1x4x128x128xf16>, tensor<2x6xf16>, tensor) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %steps_int = tensor.extract %steps[] : tensor + %n_steps = arith.index_cast %steps_int: i64 to index + %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg = %noisy_sample) -> (tensor<1x4x128x128xf16>) { + %step_64 = arith.index_cast %arg0 : index to i64 + %this_step = tensor.from_elements %step_64 : tensor<1xi64> + %inner = func.call @compiled_scheduled_unet.run_forward(%arg, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf16>, tensor<2x64x2048xf16>, tensor<2x1280xf16>, tensor<2x6xf16>, tensor<1xf16>, tensor<1xi64>) -> tensor<1x4x128x128xf16> + scf.yield %inner : tensor<1x4x128x128xf16> + } + return %res : tensor<1x4x128x128xf16> + } +} \ No newline at end of file diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f32.mlir b/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f32.mlir new file mode 100644 index 000000000..fbc69f854 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_sched_unet_bench_f32.mlir @@ -0,0 +1,19 @@ +module @sdxl_compiled_pipeline { + func.func private @compiled_scheduled_unet.run_initialize(%arg0: tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor) attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}]"} + func.func private @compiled_scheduled_unet.run_forward(%arg0: tensor<1x4x128x128xf32>, %arg1: tensor<2x64x2048xf32>, %arg2: tensor<2x1280xf32>, %arg3: tensor<2x6xf32>, %arg4: tensor<1xf32>, %arg5: tensor<1xi64>) -> tensor<1x4x128x128xf32> attributes {torch.args_schema = "[1, {\22type\22: \22builtins.tuple\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: \22builtins.list\22, \22context\22: \22null\22, \22children_spec\22: [{\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]}, {\22type\22: \22builtins.dict\22, \22context\22: \22[]\22, \22children_spec\22: []}]}]", torch.return_schema = "[1, {\22type\22: null, \22context\22: null, \22children_spec\22: []}]"} + + func.func @produce_image_latents(%sample: tensor<1x4x128x128xf32>, %p_embeds: tensor<2x64x2048xf32>, %t_embeds: tensor<2x1280xf32>, %guidance_scale: tensor<1xf32>) -> tensor<1x4x128x128xf32> { + %noisy_sample, %time_ids, %steps = func.call @compiled_scheduled_unet.run_initialize(%sample) : (tensor<1x4x128x128xf32>) -> (tensor<1x4x128x128xf32>, tensor<2x6xf32>, tensor) + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %steps_int = tensor.extract %steps[] : tensor + %n_steps = arith.index_cast %steps_int: i64 to index + %res = scf.for %arg0 = %c0 to %n_steps step %c1 iter_args(%arg_s = %noisy_sample) -> (tensor<1x4x128x128xf32>) { + %step_64 = arith.index_cast %arg0 : index to i64 + %this_step = tensor.from_elements %step_64 : tensor<1xi64> + %inner = func.call @compiled_scheduled_unet.run_forward(%arg_s, %p_embeds, %t_embeds, %time_ids, %guidance_scale, %this_step) : (tensor<1x4x128x128xf32>, tensor<2x64x2048xf32>, tensor<2x1280xf32>, tensor<2x6xf32>, tensor<1xf32>, tensor<1xi64>) -> tensor<1x4x128x128xf32> + scf.yield %inner : tensor<1x4x128x128xf32> + } + return %res : tensor<1x4x128x128xf32> + } +} \ No newline at end of file diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py new file mode 100644 index 000000000..f74c707e7 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet.py @@ -0,0 +1,344 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# from @aviator19941's gist : https://gist.github.com/aviator19941/4e7967bd1787c83ee389a22637c6eea7 + +import os +import sys + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import UNet2DConditionModel +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) + + +class SDXLScheduledUnet(torch.nn.Module): + def __init__( + self, + hf_model_name, + scheduler_id, + height, + width, + batch_size, + hf_auth_token=None, + precision="fp32", + num_inference_steps=1, + return_index=False, + ): + super().__init__() + self.dtype = torch.float16 if precision == "fp16" else torch.float32 + self.scheduler = utils.get_schedulers(hf_model_name)[scheduler_id] + if scheduler_id == "PNDM": + num_inference_steps = num_inference_steps - 1 + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.is_scale_input_called = True + self.return_index = return_index + + if precision == "fp16": + try: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + variant="fp16", + ) + except: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + else: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + + def initialize(self, sample): + height = sample.shape[-2] * 8 + width = sample.shape[-1] * 8 + original_size = (height, width) + target_size = (height, width) + crops_coords_top_left = (0, 0) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + add_time_ids = add_time_ids.repeat(sample.shape[0], 1).type(self.dtype) + timesteps = self.scheduler.timesteps + step_indexes = torch.tensor(len(timesteps)) + sample = sample * self.scheduler.init_noise_sigma + return sample.type(self.dtype), add_time_ids, step_indexes + + def forward( + self, sample, prompt_embeds, text_embeds, time_ids, guidance_scale, step_index + ): + with torch.no_grad(): + added_cond_kwargs = { + "text_embeds": text_embeds, + "time_ids": time_ids, + } + t = self.scheduler.timesteps[step_index] + latent_model_input = torch.cat([sample] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + noise_pred = self.unet.forward( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[0] + return sample.type(self.dtype) + + +def export_scheduled_unet_model( + scheduled_unet_model, + scheduler_id, + num_inference_steps, + hf_model_name, + batch_size, + height, + width, + precision, + max_length, + hf_auth_token, + compile_to, + external_weights, + external_weight_path, + device, + iree_target_triple, + ireec_flags=None, + decomp_attn=False, + exit_on_vmfb=False, + pipeline_dir=None, + attn_spec=None, + input_mlir=None, + weights_only=False, +): + if "turbo" in hf_model_name: + do_classifier_free_guidance = False + else: + do_classifier_free_guidance = True + if ( + (attn_spec in ["default"]) + and decomp_attn == False + and ("gfx9" in iree_target_triple) + ): + attn_spec = os.path.join( + os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" + ) + elif decomp_attn: + attn_spec = None + + if pipeline_dir: + safe_name = os.path.join( + pipeline_dir, f"{scheduler_id}_unet_{str(num_inference_steps)}" + ) + else: + safe_name = utils.create_safe_name( + hf_model_name, + f"_{max_length}_{height}x{width}_{precision}_scheduled_unet_{device}", + ) + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + iree_target_triple, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + mapper = {} + + decomp_list = DEFAULT_DECOMPOSITIONS + if decomp_attn == True: + decomp_list.extend( + [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + ] + ) + + dtype = torch.float16 if precision == "fp16" else torch.float32 + + if precision == "fp16": + scheduled_unet_model = scheduled_unet_model.half() + + utils.save_external_weights( + mapper, scheduled_unet_model, external_weights, external_weight_path + ) + + if weights_only: + return external_weight_path + + sample = ( + batch_size, + scheduled_unet_model.unet.config.in_channels, + height // 8, + width // 8, + ) + if do_classifier_free_guidance: + init_batch_dim = 2 + else: + init_batch_dim = 1 + + time_ids_shape = (init_batch_dim * batch_size, 6) + prompt_embeds_shape = (init_batch_dim * batch_size, max_length, 2048) + text_embeds_shape = (init_batch_dim * batch_size, 1280) + + class CompiledScheduledUnet(CompiledModule): + if external_weights: + params = export_parameters( + scheduled_unet_model, + external=True, + external_scope="", + name_mapper=mapper.get, + ) + else: + params = export_parameters(scheduled_unet_model) + + def run_initialize( + self, + sample=AbstractTensor(*sample, dtype=dtype), + ): + return jittable(scheduled_unet_model.initialize)(sample) + + def run_forward( + self, + sample=AbstractTensor(*sample, dtype=dtype), + prompt_embeds=AbstractTensor(*prompt_embeds_shape, dtype=dtype), + text_embeds=AbstractTensor(*text_embeds_shape, dtype=dtype), + time_ids=AbstractTensor(*time_ids_shape, dtype=dtype), + guidance_scale=AbstractTensor(1, dtype=dtype), + step_index=AbstractTensor(1, dtype=torch.int64), + ): + return jittable(scheduled_unet_model.forward, decompose_ops=decomp_list)( + sample, prompt_embeds, text_embeds, time_ids, guidance_scale, step_index + ) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledScheduledUnet(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str + elif compile_to == "vmfb": + vmfb = utils.compile_to_vmfb( + module_str, + device, + iree_target_triple, + ireec_flags, + safe_name, + return_path=True, + attn_spec=attn_spec, + ) + if exit_on_vmfb: + exit() + return vmfb + + +def export_pipeline_module(args): + pipeline_file = ( + "sdxl_sched_unet_bench_" + "f32" + if args.precision == "fp32" + else "sdxl_sched_unet_bench_" + "f16" + ) + if "turbo" in args.hf_model_name: + pipe_prefix = "sdxl_turbo_pipeline_bench_" + else: + pipe_prefix = "sdxl_pipeline_bench_" + full_pipeline_file = ( + pipe_prefix + "f32" if args.precision == "fp32" else pipe_prefix + "f16" + ) + full_pipeline_vmfb_path = utils.compile_to_vmfb( + os.path.join( + os.path.realpath(os.path.dirname(__file__)), full_pipeline_file + ".mlir" + ), + args.device, + args.iree_target_triple, + args.ireec_flags, + "sdxl_full_pipeline_" + args.precision + "_" + args.iree_target_triple, + return_path=True, + const_expr_hoisting=False, + mlir_source="file", + ) + return full_pipeline_vmfb_path + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + if args.input_mlir: + scheduled_unet_model = None + else: + scheduled_unet_model = SDXLScheduledUnet( + args.hf_model_name, + args.scheduler_id, + args.height, + args.width, + args.batch_size, + args.hf_auth_token, + args.precision, + args.num_inference_steps, + args.return_index, + ) + if args.compile_to == "vmfb": + pipeline_vmfb_path = export_pipeline_module(args) + mod_str = export_scheduled_unet_model( + scheduled_unet_model, + args.scheduler_id, + args.num_inference_steps, + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.precision, + args.max_length, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags + args.attn_flags + args.unet_flags, + args.decomp_attn, + args.exit_on_vmfb, + args.pipeline_dir, + args.attn_spec, + args.input_mlir, + ) + if args.input_mlir: + exit() + safe_name = utils.create_safe_name( + args.hf_model_name + "_" + args.scheduler_id, + f"_{args.max_length}_{args.height}x{args.width}_{args.precision}_unet_{str(args.num_inference_steps)}", + ) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py new file mode 100644 index 000000000..8945d274a --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_scheduled_unet_runner.py @@ -0,0 +1,358 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from turbine_models.custom_models.sd_inference import utils +from iree import runtime as ireert +import torch +import numpy as np +from tqdm.auto import tqdm + +torch.random.manual_seed(0) + + +def run_unet_hybrid( + sample, + prompt_embeds, + text_embeds, + args, +): + runner = vmfbRunner(args.device, args.vmfb_path, args.external_weight_path) + init_inp = [ + ireert.asdevicearray(runner.config.device, sample), + ] + sample, time_ids, steps = runner.ctx.modules.compiled_scheduled_unet[ + "run_initialize" + ]( + *init_inp, + ) + dtype = "float16" if args.precision == "fp16" else "float32" + inputs = [ + sample, + ireert.asdevicearray(runner.config.device, prompt_embeds), + ireert.asdevicearray(runner.config.device, text_embeds), + time_ids, + ireert.asdevicearray( + runner.config.device, np.asarray([args.guidance_scale]), dtype=dtype + ), + None, + ] + for i in range(steps.to_host()): + inputs[0] = sample + inputs[5] = ireert.asdevicearray( + runner.config.device, torch.tensor([i]), dtype="int64" + ) + sample = runner.ctx.modules.compiled_scheduled_unet["run_forward"](*inputs) + return sample + + +def run_torch_scheduled_unet( + sample, + prompt_embeds, + text_embeds, + args, +): + from diffusers import UNet2DConditionModel + + class SDXLScheduledUnet(torch.nn.Module): + def __init__( + self, + hf_model_name, + scheduler_id, + height, + width, + batch_size, + hf_auth_token=None, + precision="fp32", + num_inference_steps=1, + return_index=False, + ): + super().__init__() + self.dtype = torch.float16 if precision == "fp16" else torch.float32 + self.scheduler = utils.get_schedulers(hf_model_name)[scheduler_id] + self.scheduler.set_timesteps(num_inference_steps) + self.scheduler.is_scale_input_called = True + self.return_index = return_index + + if precision == "fp16": + try: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + variant="fp16", + ) + except: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + else: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + + def initialize(self, sample): + height = sample.shape[-2] * 8 + width = sample.shape[-1] * 8 + original_size = (height, width) + target_size = (height, width) + crops_coords_top_left = (0, 0) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + add_time_ids = add_time_ids.repeat(sample.shape[0], 1).type(self.dtype) + timesteps = self.scheduler.timesteps + step_indexes = torch.tensor(len(timesteps)) + sample = sample * self.scheduler.init_noise_sigma + return sample.type(self.dtype), add_time_ids, step_indexes + + def forward( + self, + sample, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, + step_index, + ): + with torch.no_grad(): + added_cond_kwargs = { + "text_embeds": text_embeds, + "time_ids": time_ids, + } + t = self.scheduler.timesteps[step_index] + latent_model_input = torch.cat([sample] * 2) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + noise_pred = self.unet.forward( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[ + 0 + ] + if self.return_index: + return sample.type(self.dtype), step_index + else: + return sample.type(self.dtype) + + unet_model = SDXLScheduledUnet( + args.hf_model_name, + args.scheduler_id, + args.height, + args.width, + args.batch_size, + args.hf_auth_token, + args.precision, + args.num_inference_steps, + ) + sample, add_time_ids, steps = unet_model.initialize(sample) + for i in range(steps): + sample = unet_model.forward( + sample.float(), + prompt_embeds.float(), + text_embeds.float(), + add_time_ids.float(), + args.guidance_scale, + i, + ) + return sample + + +def run_scheduled_unet( + sample, + prompt_embeds, + text_embeds, + args, +): + pipe_runner = vmfbRunner( + args.device, + [args.vmfb_path, args.pipeline_vmfb_path], + [args.external_weight_path, None], + ) + dtype = "float16" if args.precision == "fp16" else "float32" + inputs = [ + ireert.asdevicearray(pipe_runner.config.device, sample), + ireert.asdevicearray(pipe_runner.config.device, prompt_embeds), + ireert.asdevicearray(pipe_runner.config.device, text_embeds), + ireert.asdevicearray( + pipe_runner.config.device, np.asarray([args.guidance_scale]), dtype=dtype + ), + ] + print(inputs) + latents = pipe_runner.ctx.modules.sdxl_compiled_pipeline["produce_image_latents"]( + *inputs, + ) + + return latents + + +def run_torch_diffusers_loop( + sample, + prompt_embeds, + text_embeds, + args, +): + from turbine_models.custom_models.sdxl_inference.unet import UnetModel + + unet_model = UnetModel( + args.hf_model_name, + args.hf_auth_token, + precision="fp32", + ) + scheduler = utils.get_schedulers(args.hf_model_name)[args.scheduler_id] + + scheduler.set_timesteps(args.num_inference_steps) + scheduler.is_scale_input_called = True + sample = sample * scheduler.init_noise_sigma + + height = sample.shape[-2] * 8 + width = sample.shape[-1] * 8 + original_size = (height, width) + target_size = (height, width) + crops_coords_top_left = (0, 0) + + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=torch.float32) + add_time_ids = add_time_ids.repeat(args.batch_size * 1, 1) + sample = sample.to(torch.float32) + prompt_embeds = prompt_embeds.to(torch.float32) + text_embeds = text_embeds.to(torch.float32) + + for i in range(args.num_inference_steps): + timestep = scheduler.timesteps[i] + + latent_model_input = scheduler.scale_model_input(sample, timestep) + noise_pred = unet_model.forward( + latent_model_input, + timestep, + prompt_embeds, + text_embeds, + add_time_ids, + args.guidance_scale, + ) + sample = scheduler.step( + noise_pred, + timestep, + sample, + return_dict=False, + )[0] + return sample.detach().cpu().numpy() + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + import numpy as np + + if args.precision == "fp16": + dtype = torch.float16 + else: + dtype = torch.float32 + sample = torch.rand( + args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype + ) + timestep = torch.zeros(1, dtype=torch.int64) + prompt_embeds = torch.rand(2 * args.batch_size, args.max_length, 2048, dtype=dtype) + text_embeds = torch.rand(2 * args.batch_size, 1280, dtype=dtype) + + turbine_output = run_scheduled_unet( + sample, + prompt_embeds, + text_embeds, + args, + ) + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + if args.compare_vs_torch: + from turbine_models.custom_models.sd_inference import utils + + print("generating output with python/torch scheduling unet: ") + hybrid_output = run_unet_hybrid( + sample, + prompt_embeds, + text_embeds, + args, + ) + print("generating torch output: ") + torch_output = run_torch_scheduled_unet( + sample, + prompt_embeds, + text_embeds, + args, + ) + print("generating torch+diffusers output: ") + diff_output = run_torch_diffusers_loop( + sample, + prompt_embeds, + text_embeds, + args, + ) + print( + "diffusers-like OUTPUT:", diff_output, diff_output.shape, diff_output.dtype + ) + print("torch OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + + print( + "HYBRID OUTPUT:", + hybrid_output.to_host(), + hybrid_output.to_host().shape, + hybrid_output.to_host().dtype, + ) + print("Comparing... \n(turbine pipelined unet to torch unet): ") + try: + np.testing.assert_allclose( + turbine_output, torch_output, rtol=4e-2, atol=4e-2 + ) + except AssertionError as err: + print(err) + print("\n(turbine pipelined unet to hybrid unet): ") + try: + np.testing.assert_allclose( + hybrid_output, turbine_output, rtol=4e-2, atol=4e-2 + ) + print("passed!") + except AssertionError as err: + print(err) + print("\n(hybrid unet to diff unet): ") + try: + np.testing.assert_allclose(diff_output, hybrid_output, rtol=4e-2, atol=4e-2) + print("passed!") + except AssertionError as err: + print(err) + print("\n(turbine loop to diffusers loop): ") + try: + np.testing.assert_allclose( + turbine_output, diff_output, rtol=4e-2, atol=4e-2 + ) + print("passed!") + except AssertionError as err: + print(err) + print("\n(torch sched unet loop to diffusers loop): ") + try: + np.testing.assert_allclose(torch_output, diff_output, rtol=4e-2, atol=4e-2) + print("passed!") + except AssertionError as err: + print(err) + + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output = None diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py new file mode 100644 index 000000000..a3ae29595 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_schedulers.py @@ -0,0 +1,197 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# from @aviator19941's gist : https://gist.github.com/aviator19941/4e7967bd1787c83ee389a22637c6eea7 + +import os +import sys + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import UNet2DConditionModel +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) + +import safetensors + + +class SDXLScheduler(torch.nn.Module): + def __init__( + self, + hf_model_name, + num_inference_steps, + scheduler, + hf_auth_token=None, + precision="fp32", + ): + super().__init__() + self.scheduler = scheduler + self.scheduler.set_timesteps(num_inference_steps) + self.guidance_scale = 7.5 + if precision == "fp16": + try: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + variant="fp16", + ) + except: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + else: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + + def forward(self, sample, prompt_embeds, text_embeds, time_ids): + sample = sample * self.scheduler.init_noise_sigma + for t in self.scheduler.timesteps: + with torch.no_grad(): + added_cond_kwargs = { + "text_embeds": text_embeds, + "time_ids": time_ids, + } + latent_model_input = torch.cat([sample] * 2) + t = t.unsqueeze(0) + # print('UNSQUEEZE T:', t) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, timestep=t + ) + noise_pred = self.unet.forward( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + sample = self.scheduler.step(noise_pred, t, sample, return_dict=False)[ + 0 + ] + return sample + + +def export_scheduler( + scheduler, + hf_model_name, + batch_size, + height, + width, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, +): + mapper = {} + utils.save_external_weights( + mapper, scheduler, external_weights, external_weight_path + ) + + decomp_list = DEFAULT_DECOMPOSITIONS + + decomp_list.extend( + [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + ] + ) + # tensor shapes for tracing + sample = (batch_size, 4, height // 8, width // 8) + prompt_embeds = (2, 77, 2048) + text_embeds = (2, 1280) + time_ids = (2, 6) + + class CompiledScheduler(CompiledModule): + if external_weights: + params = export_parameters( + scheduler, external=True, external_scope="", name_mapper=mapper.get + ) + else: + params = export_parameters(scheduler) + + def main( + self, + sample=AbstractTensor(*sample, dtype=torch.float32), + prompt_embeds=AbstractTensor(*prompt_embeds, dtype=torch.float32), + text_embeds=AbstractTensor(*text_embeds, dtype=torch.float32), + time_ids=AbstractTensor(*time_ids, dtype=torch.float32), + ): + return jittable(scheduler.forward, decompose_ops=decomp_list)( + sample, prompt_embeds, text_embeds, time_ids + ) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledScheduler(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + + safe_name = utils.create_safe_name(hf_model_name, "-scheduler") + with open(f"{safe_name}.mlir", "w+") as f: + f.write(module_str) + print("Saved to", safe_name + ".mlir") + + if compile_to != "vmfb": + return module_str + else: + utils.compile_to_vmfb(module_str, device, target_triple, ireec_flags, safe_name) + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + hf_model_name = "stabilityai/stable-diffusion-xl-base-1.0" + schedulers = utils.get_schedulers(args.hf_model_name) + scheduler = schedulers[args.scheduler_id] + scheduler_module = SDXLScheduler( + args.hf_model_name, + args.num_inference_steps, + scheduler, + hf_auth_token=None, + precision=args.precision, + ) + + print("export scheduler begin") + mod_str = export_scheduler( + scheduler_module, + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags, + ) + print("export scheduler complete") + safe_name = utils.create_safe_name(args.hf_model_name, "-scheduler") + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sdxl_inference/unet.py b/models/turbine_models/custom_models/sdxl_inference/unet.py new file mode 100644 index 000000000..e9839ba06 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/unet.py @@ -0,0 +1,253 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import UNet2DConditionModel + + +class UnetModel(torch.nn.Module): + def __init__(self, hf_model_name, hf_auth_token=None, precision="fp32"): + super().__init__() + if precision == "fp16": + try: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + variant="fp16", + ) + except: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + else: + self.unet = UNet2DConditionModel.from_pretrained( + hf_model_name, + subfolder="unet", + auth_token=hf_auth_token, + low_cpu_mem_usage=False, + ) + if "turbo" in hf_model_name: + self.do_classifier_free_guidance = False + else: + self.do_classifier_free_guidance = True + + def forward( + self, sample, timestep, prompt_embeds, text_embeds, time_ids, guidance_scale + ): + with torch.no_grad(): + added_cond_kwargs = { + "text_embeds": text_embeds, + "time_ids": time_ids, + } + if self.do_classifier_free_guidance: + latent_model_input = torch.cat([sample] * 2) + else: + latent_model_input = sample + noise_pred = self.unet.forward( + latent_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=None, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + return noise_pred + + +def export_unet_model( + unet_model, + hf_model_name, + batch_size, + height, + width, + precision="fp32", + max_length=77, + hf_auth_token=None, + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, + decomp_attn=False, + exit_on_vmfb=False, + attn_spec=None, + input_mlir=None, + weights_only=False, +): + if "turbo" in hf_model_name: + do_classifier_free_guidance = False + else: + do_classifier_free_guidance = True + + if ( + (attn_spec in ["default"]) + and decomp_attn == False + and ("gfx9" in target_triple) + ): + attn_spec = os.path.join( + os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" + ) + elif decomp_attn: + attn_spec = None + + safe_name = utils.create_safe_name( + hf_model_name, f"_{max_length}_{height}x{width}_{precision}_unet_{device}" + ) + + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + mapper = {} + decomp_list = DEFAULT_DECOMPOSITIONS + if decomp_attn == True: + decomp_list.extend( + [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + ] + ) + dtype = torch.float16 if precision == "fp16" else torch.float32 + + if precision == "fp16": + unet_model = unet_model.half() + + utils.save_external_weights( + mapper, unet_model, external_weights, external_weight_path + ) + + if weights_only: + return external_weight_path + + sample = ( + batch_size, + unet_model.unet.config.in_channels, + height // 8, + width // 8, + ) + if do_classifier_free_guidance: + init_batch_dim = 2 + else: + init_batch_dim = 1 + + time_ids_shape = (init_batch_dim * batch_size, 6) + prompt_embeds_shape = (init_batch_dim * batch_size, max_length, 2048) + text_embeds_shape = (init_batch_dim * batch_size, 1280) + + class CompiledUnet(CompiledModule): + if external_weights: + params = export_parameters( + unet_model, external=True, external_scope="", name_mapper=mapper.get + ) + else: + params = export_parameters(unet_model) + + def main( + self, + sample=AbstractTensor(*sample, dtype=dtype), + timestep=AbstractTensor(1, dtype=torch.int64), + prompt_embeds=AbstractTensor(*prompt_embeds_shape, dtype=dtype), + text_embeds=AbstractTensor(*text_embeds_shape, dtype=dtype), + time_ids=AbstractTensor(*time_ids_shape, dtype=dtype), + guidance_scale=AbstractTensor(1, dtype=dtype), + ): + return jittable(unet_model.forward, decompose_ops=decomp_list)( + sample, timestep, prompt_embeds, text_embeds, time_ids, guidance_scale + ) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledUnet(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str + else: + utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=False, + attn_spec=attn_spec, + ) + + +if __name__ == "__main__": + import logging + + logging.basicConfig(level=logging.DEBUG) + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + if args.input_mlir: + unet_model = None + else: + unet_model = UnetModel( + args.hf_model_name, + args.hf_auth_token, + args.precision, + ) + mod_str = export_unet_model( + unet_model, + args.hf_model_name, + args.batch_size, + args.height, + args.width, + args.precision, + args.max_length, + args.hf_auth_token, + args.compile_to, + args.external_weights, + args.external_weight_path, + args.device, + args.iree_target_triple, + args.ireec_flags + args.attn_flags + args.unet_flags, + args.decomp_attn, + attn_spec=args.attn_spec, + input_mlir=args.input_mlir, + ) + if args.input_mlir: + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, + f"_{args.max_length}_{args.height}x{args.width}_{args.precision}_unet", + ) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sdxl_inference/unet_runner.py b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py new file mode 100644 index 000000000..197d850a9 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/unet_runner.py @@ -0,0 +1,166 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from iree import runtime as ireert +import torch +import numpy as np +from tqdm.auto import tqdm + +torch.random.manual_seed(0) + + +def run_unet( + device, + sample, + timestep, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, + vmfb_path, + hf_model_name, + hf_auth_token, + external_weight_path, + runner=None, +): + if runner is None: + runner = vmfbRunner(device, vmfb_path, external_weight_path) + + inputs = [ + ireert.asdevicearray(runner.config.device, sample), + ireert.asdevicearray(runner.config.device, timestep), + ireert.asdevicearray(runner.config.device, prompt_embeds), + ireert.asdevicearray(runner.config.device, text_embeds), + ireert.asdevicearray(runner.config.device, time_ids), + ireert.asdevicearray(runner.config.device, guidance_scale), + ] + results = runner.ctx.modules.compiled_unet["main"](*inputs) + + return results + + +def run_unet_steps( + device, + sample, + scheduler, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, + vmfb_path, + external_weight_path, +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + timestep = torch.zeros(1, dtype=torch.int64) + inputs = [ + ireert.asdevicearray(runner.config.device, sample), + ireert.asdevicearray(runner.config.device, timestep), + ireert.asdevicearray(runner.config.device, prompt_embeds), + ireert.asdevicearray(runner.config.device, text_embeds), + ireert.asdevicearray(runner.config.device, time_ids), + ireert.asdevicearray(runner.config.device, (guidance_scale,)), + ] + for i, t in tqdm(enumerate(scheduler.timesteps)): + timestep = t + latent_model_input = scheduler.scale_model_input(sample, timestep) + + inputs[0] = latent_model_input = ireert.asdevicearray( + runner.config.device, latent_model_input + ) + inputs[1] = timestep = ireert.asdevicearray( + runner.config.device, (timestep,), dtype="int64" + ) + noise_pred = runner.ctx.modules.compiled_unet["main"](*inputs).to_host() + sample = scheduler.step( + torch.from_numpy(noise_pred).cpu(), + timestep, + sample, + generator=None, + return_dict=False, + )[0] + return sample + + +def run_torch_unet( + hf_model_name, + hf_auth_token, + sample, + timestep, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, + precision="fp32", +): + from turbine_models.custom_models.sdxl_inference.unet import UnetModel + + unet_model = UnetModel( + hf_model_name, + hf_auth_token, + precision="fp32", + ) + results = unet_model.forward( + sample, timestep, prompt_embeds, text_embeds, time_ids, guidance_scale + ) + np_torch_output = results.detach().cpu().numpy() + return np_torch_output + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + if args.precision == "fp16": + dtype = torch.float16 + else: + dtype = torch.float32 + sample = torch.rand( + args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype + ) + timestep = torch.zeros(1, dtype=torch.int64) + prompt_embeds = torch.rand(2 * args.batch_size, args.max_length, 2048, dtype=dtype) + text_embeds = torch.rand(2 * args.batch_size, 1280, dtype=dtype) + time_ids = torch.zeros(2 * args.batch_size, 6, dtype=dtype) + guidance_scale = torch.tensor([7.5], dtype=dtype) + + turbine_output = run_unet( + args.device, + sample, + timestep, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, + args.vmfb_path, + args.hf_model_name, + args.hf_auth_token, + args.external_weight_path, + ) + print( + "TURBINE OUTPUT:", + turbine_output.to_host(), + turbine_output.to_host().shape, + turbine_output.to_host().dtype, + ) + + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + + # comment out .float for fp16... sorry. + torch_output = run_torch_unet( + args.hf_model_name, + args.hf_auth_token, + sample.float(), + timestep, + prompt_embeds.float(), + text_embeds.float(), + time_ids.float(), + guidance_scale.float(), + # precision="fp16", + ) + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + err = utils.largest_error(torch_output, turbine_output) + print("Largest Error: ", err) + assert err < 9e-3 + + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_output = None diff --git a/models/turbine_models/custom_models/sdxl_inference/vae.py b/models/turbine_models/custom_models/sdxl_inference/vae.py new file mode 100644 index 000000000..7563eed96 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/vae.py @@ -0,0 +1,213 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import os +import sys + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import AutoencoderKL + + +class VaeModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + custom_vae="", + ): + super().__init__() + self.vae = None + if custom_vae in ["", None]: + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + elif not isinstance(custom_vae, dict): + try: + # custom HF repo with no vae subfolder + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + ) + except: + # some larger repo with vae subfolder + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + subfolder="vae", + ) + else: + # custom vae as a HF state dict + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + self.vae.load_state_dict(custom_vae) + + def decode_inp(self, inp): + inp = 1 / 0.13025 * inp + x = self.vae.decode(inp, return_dict=False)[0] + return (x / 2 + 0.5).clamp(0, 1) + + def encode_inp(self, inp): + latents = self.vae.encode(inp).latent_dist.sample() + return 0.13025 * latents + + +def export_vae_model( + vae_model, + hf_model_name, + batch_size, + height, + width, + precision, + compile_to="torch", + external_weights=None, + external_weight_path=None, + device=None, + target_triple=None, + ireec_flags=None, + variant="decode", + decomp_attn=False, + exit_on_vmfb=False, + pipeline_dir=None, + attn_spec=None, + input_mlir=None, + weights_only=False, +): + if ( + (attn_spec in ["default"]) + and decomp_attn == False + and ("gfx9" in target_triple) + ): + attn_spec = os.path.join( + os.path.realpath(os.path.dirname(__file__)), "default_mfma_attn_spec.mlir" + ) + elif decomp_attn: + attn_spec = None + + if pipeline_dir: + safe_name = os.path.join(pipeline_dir, "vae_" + variant) + else: + safe_name = utils.create_safe_name( + hf_model_name, f"_{height}x{width}_{precision}_vae_{variant}_{device}" + ) + if input_mlir: + vmfb_path = utils.compile_to_vmfb( + input_mlir, + device, + target_triple, + ireec_flags, + safe_name, + mlir_source="file", + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + mapper = {} + decomp_list = DEFAULT_DECOMPOSITIONS + if decomp_attn == True: + decomp_list.extend( + [ + torch.ops.aten._scaled_dot_product_flash_attention_for_cpu, + torch.ops.aten._scaled_dot_product_flash_attention.default, + ] + ) + dtype = torch.float16 if precision == "fp16" else torch.float32 + if precision == "fp16": + vae_model = vae_model.half() + utils.save_external_weights( + mapper, vae_model, external_weights, external_weight_path + ) + if weights_only: + return external_weight_path + sample = (batch_size, 4, height // 8, width // 8) + if variant == "encode": + sample = (batch_size, 3, height, width) + + class CompiledVae(CompiledModule): + if external_weights: + params = export_parameters( + vae_model, external=True, external_scope="", name_mapper=mapper.get + ) + else: + params = export_parameters(vae_model) + + def main(self, inp=AbstractTensor(*sample, dtype=dtype)): + if variant == "decode": + return jittable(vae_model.decode_inp, decompose_ops=decomp_list)(inp) + elif variant == "encode": + return jittable(vae_model.encode_inp, decompose_ops=decomp_list)(inp) + + import_to = "INPUT" if compile_to == "linalg" else "IMPORT" + inst = CompiledVae(context=Context(), import_to=import_to) + + module_str = str(CompiledModule.get_mlir_module(inst)) + + if compile_to != "vmfb": + return module_str + else: + vmfb_path = utils.compile_to_vmfb( + module_str, + device, + target_triple, + ireec_flags, + safe_name, + return_path=not exit_on_vmfb, + attn_spec=attn_spec, + ) + return vmfb_path + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + if args.precision == "fp16": + custom_vae = "madebyollin/sdxl-vae-fp16-fix" + else: + custom_vae = "" + + if args.input_mlir: + vae_model = None + else: + vae_model = VaeModel( + args.hf_model_name, + custom_vae=custom_vae, + ) + mod_str = export_vae_model( + vae_model, + args.hf_model_name, + args.batch_size, + height=args.height, + width=args.width, + precision=args.precision, + compile_to=args.compile_to, + external_weights=args.external_weights, + external_weight_path=args.external_weight_path, + device=args.device, + target_triple=args.iree_target_triple, + ireec_flags=args.ireec_flags + args.attn_flags + args.vae_flags, + variant=args.vae_variant, + decomp_attn=args.decomp_attn, + attn_spec=args.attn_spec, + input_mlir=args.input_mlir, + ) + if args.input_mlir or (args.compile_to == "vmfb"): + exit() + safe_name = utils.create_safe_name( + args.hf_model_name, + f"_{args.height}x{args.width}_{args.precision}_vae_{args.vae_variant}", + ) + with open(f"{safe_name}.mlir", "w+") as f: + f.write(mod_str) + print("Saved to", safe_name + ".mlir") diff --git a/models/turbine_models/custom_models/sdxl_inference/vae_runner.py b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py new file mode 100644 index 000000000..539c99868 --- /dev/null +++ b/models/turbine_models/custom_models/sdxl_inference/vae_runner.py @@ -0,0 +1,124 @@ +import argparse +from turbine_models.model_runner import vmfbRunner +from iree import runtime as ireert +import torch + +torch.random.manual_seed(0) + + +def run_vae( + device, + example_input, + vmfb_path, + hf_model_name, + external_weight_path, +): + runner = vmfbRunner(device, vmfb_path, external_weight_path) + inputs = [ireert.asdevicearray(runner.config.device, example_input)] + results = runner.ctx.modules.compiled_vae["main"](*inputs) + + return results + + +def run_torch_vae(hf_model_name, custom_vae, variant, example_input): + from diffusers import AutoencoderKL + + class VaeModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + custom_vae=custom_vae, + ): + super().__init__() + self.vae = None + if custom_vae in ["", None]: + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + elif not isinstance(custom_vae, dict): + try: + # custom HF repo with no vae subfolder + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + ) + except: + # some larger repo with vae subfolder + self.vae = AutoencoderKL.from_pretrained( + custom_vae, + subfolder="vae", + ) + else: + # custom vae as a HF state dict + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + self.vae.load_state_dict(custom_vae) + + def decode_inp(self, inp): + inp = inp / 0.13025 + x = self.vae.decode(inp, return_dict=False)[0] + return (x / 2 + 0.5).clamp(0, 1) + + def encode_inp(self, inp): + latents = self.vae.encode(inp).latent_dist.sample() + return 0.13025 * latents + + vae_model = VaeModel( + hf_model_name, + ) + + if variant == "decode": + results = vae_model.decode_inp(example_input) + elif variant == "encode": + results = vae_model.encode_inp(example_input) + np_torch_output = results.detach().cpu().numpy() + return np_torch_output + + +if __name__ == "__main__": + from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args + + if args.precision == "fp16": + dtype = torch.float16 + custom_vae = "madebyollin/sdxl-vae-fp16-fix" + else: + dtype = torch.float32 + custom_vae = "" + if args.vae_variant == "decode": + example_input = torch.rand( + args.batch_size, 4, args.height // 8, args.width // 8, dtype=dtype + ) + elif args.vae_variant == "encode": + example_input = torch.rand( + args.batch_size, 3, args.height, args.width, dtype=dtype + ) + print("generating turbine output:") + turbine_results = run_vae( + args.device, + example_input, + args.vmfb_path, + args.hf_model_name, + args.external_weight_path, + ) + print( + "TURBINE OUTPUT:", + turbine_results.to_host(), + turbine_results.to_host().shape, + turbine_results.to_host().dtype, + ) + if args.compare_vs_torch: + print("generating torch output: ") + from turbine_models.custom_models.sd_inference import utils + + torch_output = run_torch_vae( + args.hf_model_name, custom_vae, args.vae_variant, example_input.float() + ) + print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype) + err = utils.largest_error(torch_output, turbine_results) + print("Largest Error: ", err) + assert err < 2e-3 + + # TODO: Figure out why we occasionally segfault without unlinking output variables + turbine_results = None diff --git a/models/turbine_models/model_runner.py b/models/turbine_models/model_runner.py index 74dd3dc9a..4afa5eda5 100644 --- a/models/turbine_models/model_runner.py +++ b/models/turbine_models/model_runner.py @@ -5,26 +5,67 @@ class vmfbRunner: def __init__(self, device, vmfb_path, external_weight_path=None): - self.config = ireert.Config(device) + flags = [] + haldriver = ireert.get_driver(device) + if "://" in device: + try: + device_idx = int(device.split("://")[-1]) + device_uri = None + except: + device_idx = None + device_uri = device.split("://")[-1] + else: + device_idx = 0 + device_uri = None + if device_uri: + if not any(x in device for x in ["cpu", "task"]): + allocators = ["caching"] + haldevice = haldriver.create_device_by_uri( + device_uri, allocators=allocators + ) + else: + haldevice = haldriver.create_device_by_uri(device_uri) + else: + hal_device_id = haldriver.query_available_devices()[device_idx]["device_id"] + if not any(x in device for x in ["cpu", "task"]): + allocators = ["caching"] + haldevice = haldriver.create_device( + hal_device_id, allocators=allocators + ) + else: + haldevice = haldriver.create_device(hal_device_id) - # TODO: enable multiple vmfb's - mod = ireert.VmModule.mmap(self.config.vm_instance, vmfb_path) + self.config = ireert.Config(device=haldevice) + mods = [] + if not isinstance(vmfb_path, list): + vmfb_path = [vmfb_path] + for path in vmfb_path: + mods.append(ireert.VmModule.mmap(self.config.vm_instance, path)) vm_modules = [ - mod, + *mods, ireert.create_hal_module(self.config.vm_instance, self.config.device), ] # TODO: Enable multiple weight files if external_weight_path: index = ireert.ParameterIndex() - index.load(external_weight_path) - # TODO: extend scope - param_module = ireert.create_io_parameters_module( - self.config.vm_instance, index.create_provider(scope="model") - ) - vm_modules.insert(0, param_module) + if not isinstance(external_weight_path, list): + external_weight_path = [external_weight_path] + for i, path in enumerate(external_weight_path): + if path in ["", None]: + continue + index.load(path) + # TODO: extend scope + param_module = ireert.create_io_parameters_module( + self.config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.insert(i, param_module) self.ctx = ireert.SystemContext( vm_modules=vm_modules, config=self.config, ) + + def unload(self): + self.ctx = None + self.config = None diff --git a/models/turbine_models/tests/conftest.py b/models/turbine_models/tests/conftest.py new file mode 100644 index 000000000..7a1f55b1a --- /dev/null +++ b/models/turbine_models/tests/conftest.py @@ -0,0 +1,53 @@ +def pytest_addoption(parser): + # Huggingface Options + parser.addoption("--hf_auth_token", action="store", default=None) + parser.addoption( + "--hf_model_name", + action="store", + default="stabilityai/stable-diffusion-xl-base-1.0", + ) + parser.addoption("--scheduler_id", action="store", default="PNDM") + # Inference Options + parser.addoption( + "--prompt", + action="store", + default="a photograph of an astronaut riding a horse", + ) + parser.addoption( + "--negative_prompt", + action="store", + default="blurry, unsaturated, watermark, noisy, grainy, out of focus", + ) + parser.addoption("--num_inference_steps", type=int, action="store", default=5) + parser.addoption("--guidance_scale", type=float, action="store", default=7.5) + parser.addoption("--seed", type=float, action="store", default=0.0) + parser.addoption("--vmfb_path", action="store", default="") + parser.addoption("--external_weight_path", action="store", default="") + parser.addoption("--external_weight_dir", action="store", default="") + parser.addoption("--external_weight_file", action="store", default="") + parser.addoption("--pipeline_dir", action="store", default=".") + # Modelling Options + parser.addoption("--batch_size", type=int, action="store", default=1) + parser.addoption("--height", type=int, action="store", default=1024) + parser.addoption("--width", type=int, action="store", default=1024) + parser.addoption("--precision", action="store", default="fp32") + parser.addoption("--max_length", type=int, action="store", default=64) + parser.addoption("--run_vmfb", action="store", default=True) + # General Options + parser.addoption("--compile_to", action="store", default=None) + parser.addoption("--external_weights", action="store", default="safetensors") + parser.addoption("--decomp_attn", action="store", default=True) + parser.addoption("--attn_spec", action="store", default="") + # Compiler Options + parser.addoption("--device", action="store", default="cpu") + parser.addoption("--rt_device", action="store", default="local-task") + parser.addoption( + "--iree_target_triple", type=str, action="store", default="x86_64-linux-gnu" + ) + parser.addoption("--ireec_flags", action="store", default="") + parser.addoption("--attn_flags", action="store", default="") + # Test Options + parser.addoption("--in_channels", type=int, action="store", default=4) + parser.addoption("--benchmark", action="store_true", default=False) + parser.addoption("--tracy_profile", action="store_true", default=False) + parser.addoption("--compiled_pipeline", type=bool, default=True) diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py index f88b44813..76c11bcba 100644 --- a/models/turbine_models/tests/sd_test.py +++ b/models/turbine_models/tests/sd_test.py @@ -29,18 +29,23 @@ default_arguments = { "hf_auth_token": None, "hf_model_name": "CompVis/stable-diffusion-v1-4", + "safe_model_name": "stable-diffusion_v1_4", "scheduler_id": "PNDM", "num_inference_steps": 5, "batch_size": 1, "height": 512, "width": 512, + "precision": "fp32", + "max_length": 77, + "guidance_scale": 7.5, "run_vmfb": True, "compile_to": None, "external_weight_path": "", "vmfb_path": "", "external_weights": None, - "device": "local-task", - "iree_target_triple": "", + "device": "cpu", + "rt_device": "local-task", + "iree_target_triple": "x86_64-linux-gnu", "vulkan_max_allocation": "4294967296", "prompt": "a photograph of an astronaut riding a horse", "in_channels": 4, @@ -50,14 +55,13 @@ unet_model = unet.UnetModel( # This is a public model, so no auth required - "CompVis/stable-diffusion-v1-4", - None, + default_arguments["hf_model_name"], ) vae_model = vae.VaeModel( # This is a public model, so no auth required - "CompVis/stable-diffusion-v1-4", - None, + default_arguments["hf_model_name"], + custom_vae=None, ) schedulers_dict = utils.get_schedulers( @@ -89,7 +93,7 @@ def testExportT5Model(self): ) current_args["vmfb_path"] = safe_prefix + "_clip.vmfb" turbine = clip_runner.run_clip( - current_args["device"], + current_args["rt_device"], current_args["prompt"], current_args["vmfb_path"], current_args["hf_model_name"], @@ -103,8 +107,6 @@ def testExportT5Model(self): ) err = utils.largest_error(torch_output, turbine[0]) assert err < 9e-4 - if platform.system() != "Windows": - os.remove(current_args["vmfb_path"]) if UPLOAD_IR: new_blob_name = blob_name.split(".") new_blob_name = new_blob_name[0] + "-pass.mlir" @@ -129,7 +131,7 @@ def testExportClipVitLarge14(self): current_args["external_weight_path"] = safe_prefix + ".safetensors" current_args["vmfb_path"] = safe_prefix + "_clip.vmfb" turbine = clip_runner.run_clip( - current_args["device"], + current_args["rt_device"], current_args["prompt"], current_args["vmfb_path"], current_args["hf_model_name"], @@ -167,7 +169,7 @@ def testExportClipModel(self): current_args["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors" current_args["vmfb_path"] = "stable_diffusion_v1_4_clip.vmfb" turbine = clip_runner.run_clip( - current_args["device"], + current_args["rt_device"], current_args["prompt"], current_args["vmfb_path"], current_args["hf_model_name"], @@ -185,18 +187,20 @@ def testExportClipModel(self): new_blob_name = blob_name.split(".") new_blob_name = new_blob_name[0] + "-pass.mlir" turbine_tank.changeBlobName(blob_name, new_blob_name) - os.remove("stable_diffusion_v1_4_clip.safetensors") - os.remove("stable_diffusion_v1_4_clip.vmfb") + if platform.system() != "Windows": + os.remove(current_args["external_weight_path"]) + os.remove(current_args["vmfb_path"]) def testExportUnetModel(self): current_args = copy.deepcopy(default_arguments) blob_name = unet.export_unet_model( unet_model, - # This is a public model, so no auth required "CompVis/stable-diffusion-v1-4", current_args["batch_size"], current_args["height"], current_args["width"], + current_args["precision"], + current_args["max_length"], None, "vmfb", "safetensors", @@ -213,14 +217,22 @@ def testExportUnetModel(self): current_args["width"] // 8, dtype=torch.float32, ) + timestep = torch.zeros(1, dtype=torch.float32) - encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + if current_args["hf_model_name"] == "CompVis/stable-diffusion-v1-4": + encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) + elif current_args["hf_model_name"] == "stabilityai/stable-diffusion-2-1-base": + encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32) + guidance_scale = torch.tensor( + [current_args["guidance_scale"]], dtype=torch.float32 + ) turbine = unet_runner.run_unet( - current_args["device"], + current_args["rt_device"], sample, timestep, encoder_hidden_states, + guidance_scale, current_args["vmfb_path"], current_args["hf_model_name"], current_args["hf_auth_token"], @@ -232,6 +244,7 @@ def testExportUnetModel(self): sample, timestep, encoder_hidden_states, + guidance_scale, ) err = utils.largest_error(torch_output, turbine) assert err < 9e-5 @@ -241,6 +254,8 @@ def testExportUnetModel(self): turbine_tank.changeBlobName(blob_name, new_blob_name) os.remove("stable_diffusion_v1_4_unet.safetensors") os.remove("stable_diffusion_v1_4_unet.vmfb") + del torch_output + del turbine def testExportVaeModelDecode(self): current_args = copy.deepcopy(default_arguments) @@ -269,16 +284,14 @@ def testExportVaeModelDecode(self): dtype=torch.float32, ) turbine = vae_runner.run_vae( - current_args["device"], + current_args["rt_device"], example_input, current_args["vmfb_path"], current_args["hf_model_name"], - current_args["hf_auth_token"], current_args["external_weight_path"], ) torch_output = vae_runner.run_torch_vae( current_args["hf_model_name"], - current_args["hf_auth_token"], "decode", example_input, ) @@ -288,11 +301,11 @@ def testExportVaeModelDecode(self): new_blob_name = blob_name.split(".") new_blob_name = new_blob_name[0] + "-pass.mlir" turbine_tank.changeBlobName(blob_name, new_blob_name) + del torch_output + del turbine os.remove("stable_diffusion_v1_4_vae.safetensors") os.remove("stable_diffusion_v1_4_vae.vmfb") - # https://github.com/nod-ai/SHARK-Turbine/issues/536 - @unittest.expectedFailure def testExportVaeModelEncode(self): current_args = copy.deepcopy(default_arguments) blob_name = vae.export_vae_model( @@ -302,7 +315,7 @@ def testExportVaeModelEncode(self): current_args["batch_size"], current_args["height"], current_args["width"], - None, + current_args["precision"], "vmfb", "safetensors", "stable_diffusion_v1_4_vae.safetensors", @@ -320,16 +333,14 @@ def testExportVaeModelEncode(self): dtype=torch.float32, ) turbine = vae_runner.run_vae( - current_args["device"], + current_args["rt_device"], example_input, current_args["vmfb_path"], current_args["hf_model_name"], - current_args["hf_auth_token"], current_args["external_weight_path"], ) torch_output = vae_runner.run_torch_vae( current_args["hf_model_name"], - current_args["hf_auth_token"], "encode", example_input, ) @@ -371,7 +382,7 @@ def testExportPNDMScheduler(self): ) encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32) turbine = schedulers_runner.run_scheduler( - current_args["device"], + current_args["rt_device"], sample, encoder_hidden_states, current_args["vmfb_path"], @@ -394,6 +405,8 @@ def testExportPNDMScheduler(self): turbine_tank.changeBlobName(blob_name, new_blob_name) os.remove("stable_diffusion_v1_4_scheduler.safetensors") os.remove("stable_diffusion_v1_4_scheduler.vmfb") + del torch_output + del turbine if __name__ == "__main__": diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py new file mode 100644 index 000000000..a45fd7ca4 --- /dev/null +++ b/models/turbine_models/tests/sdxl_test.py @@ -0,0 +1,601 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import pytest +import torch +from turbine_models.custom_models.sd_inference.utils import create_safe_name +from turbine_models.custom_models.sdxl_inference import ( + clip, + clip_runner, + unet, + unet_runner, + vae, + vae_runner, + sdxl_compiled_pipeline, +) +from turbine_models.utils.sdxl_benchmark import run_benchmark +import unittest +from tqdm.auto import tqdm +from PIL import Image +import os +import numpy as np +import time + + +torch.random.manual_seed(0) + +arguments = {} + + +@pytest.fixture(scope="session") +def command_line_args(request): + arguments["hf_auth_token"] = request.config.getoption("--hf_auth_token") + arguments["hf_model_name"] = request.config.getoption("--hf_model_name") + arguments["scheduler_id"] = request.config.getoption("--scheduler_id") + arguments["prompt"] = request.config.getoption("--prompt") + arguments["negative_prompt"] = request.config.getoption("--negative_prompt") + arguments["num_inference_steps"] = int( + request.config.getoption("--num_inference_steps") + ) + arguments["guidance_scale"] = float(request.config.getoption("--guidance_scale")) + arguments["seed"] = float(request.config.getoption("--seed")) + arguments["vmfb_path"] = request.config.getoption("--vmfb_path") + arguments["external_weight_path"] = request.config.getoption( + "--external_weight_path" + ) + arguments["external_weight_dir"] = request.config.getoption("--external_weight_dir") + arguments["external_weight_file"] = request.config.getoption( + "--external_weight_file" + ) + arguments["pipeline_dir"] = request.config.getoption("--pipeline_dir") + arguments["batch_size"] = int(request.config.getoption("--batch_size")) + arguments["height"] = int(request.config.getoption("--height")) + arguments["width"] = int(request.config.getoption("--width")) + arguments["precision"] = request.config.getoption("--precision") + arguments["max_length"] = int(request.config.getoption("--max_length")) + arguments["run_vmfb"] = request.config.getoption("--run_vmfb") + arguments["compile_to"] = request.config.getoption("--compile_to") + arguments["external_weights"] = request.config.getoption("--external_weights") + arguments["decomp_attn"] = request.config.getoption("--decomp_attn") + arguments["attn_spec"] = request.config.getoption("--attn_spec") + arguments["device"] = request.config.getoption("--device") + arguments["rt_device"] = request.config.getoption("--rt_device") + arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple") + arguments["ireec_flags"] = request.config.getoption("--ireec_flags") + arguments["attn_flags"] = request.config.getoption("--attn_flags") + arguments["in_channels"] = int(request.config.getoption("--in_channels")) + arguments["benchmark"] = request.config.getoption("--benchmark") + arguments["tracy_profile"] = request.config.getoption("--tracy_profile") + arguments["compiled_pipeline"] = request.config.getoption("--compiled_pipeline") + + +@pytest.mark.usefixtures("command_line_args") +class StableDiffusionXLTest(unittest.TestCase): + def setUp(self): + self.safe_model_name = create_safe_name(arguments["hf_model_name"], "") + self.unet_model = unet.UnetModel( + # This is a public model, so no auth required + arguments["hf_model_name"], + precision=arguments["precision"], + ) + self.vae_model = vae.VaeModel( + # This is a public model, so no auth required + arguments["hf_model_name"], + custom_vae=( + "madebyollin/sdxl-vae-fp16-fix" + if arguments["precision"] == "fp16" + else None + ), + ) + + def test01_ExportClipModels(self): + if arguments["device"] in ["vulkan", "cuda", "rocm"]: + self.skipTest( + "Compilation error on vulkan; Runtime error on rocm; To be tested on cuda." + ) + clip.export_clip_model( + # This is a public model, so no auth required + hf_model_name=arguments["hf_model_name"], + hf_auth_token=None, + max_length=arguments["max_length"], + precision=arguments["precision"], + compile_to="vmfb", + external_weights=arguments["external_weights"], + external_weight_path=self.safe_model_name + + "_" + + arguments["precision"] + + "_clip", + device=arguments["device"], + target_triple=arguments["iree_target_triple"], + ireec_flags=arguments["ireec_flags"], + index=1, + exit_on_vmfb=True, + ) + clip.export_clip_model( + hf_model_name=arguments["hf_model_name"], + hf_auth_token=None, # This is a public model, so no auth required + max_length=arguments["max_length"], + precision=arguments["precision"], + compile_to="vmfb", + external_weights=arguments["external_weights"], + external_weight_path=self.safe_model_name + + "_" + + arguments["precision"] + + "_clip", + device=arguments["device"], + target_triple=arguments["iree_target_triple"], + ireec_flags=arguments["ireec_flags"], + index=2, + exit_on_vmfb=True, + ) + arguments["external_weight_path_1"] = ( + self.safe_model_name + + "_" + + arguments["precision"] + + "_clip_1." + + arguments["external_weights"] + ) + arguments["external_weight_path_2"] = ( + self.safe_model_name + + "_" + + arguments["precision"] + + "_clip_2." + + arguments["external_weights"] + ) + arguments["vmfb_path_1"] = ( + self.safe_model_name + + "_" + + str(arguments["max_length"]) + + "_" + + arguments["precision"] + + "_clip_1_" + + arguments["device"] + + ".vmfb" + ) + arguments["vmfb_path_2"] = ( + self.safe_model_name + + "_" + + str(arguments["max_length"]) + + "_" + + arguments["precision"] + + "_clip_2_" + + arguments["device"] + + ".vmfb" + ) + turbine_1 = clip_runner.run_clip( + arguments["rt_device"], + arguments["prompt"], + arguments["vmfb_path_1"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path_1"], + arguments["max_length"], + index=1, + ) + turbine_2 = clip_runner.run_clip( + arguments["rt_device"], + arguments["prompt"], + arguments["vmfb_path_2"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path_2"], + arguments["max_length"], + index=2, + ) + torch_output_1, torch_output_2 = clip_runner.run_torch_clip( + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["prompt"], + arguments["max_length"], + ) + if arguments["benchmark"] or arguments["tracy_profile"]: + run_benchmark( + "clip_1", + arguments["vmfb_path_1"], + arguments["external_weight_path_1"], + arguments["rt_device"], + max_length=arguments["max_length"], + tracy_profile=arguments["tracy_profile"], + ) + run_benchmark( + "clip_2", + arguments["vmfb_path_2"], + arguments["external_weight_path_2"], + arguments["rt_device"], + max_length=arguments["max_length"], + tracy_profile=arguments["tracy_profile"], + ) + rtol = 4e-1 + atol = 4e-1 + np.testing.assert_allclose(torch_output_1, turbine_1[0], rtol, atol) + np.testing.assert_allclose(torch_output_2, turbine_2[0], rtol, atol) + + def test02_ExportUnetModel(self): + if arguments["device"] in ["vulkan", "cuda", "rocm"]: + self.skipTest( + "Unknown error on vulkan; Runtime error on rocm; To be tested on cuda." + ) + unet.export_unet_model( + unet_model=self.unet_model, + # This is a public model, so no auth required + hf_model_name=arguments["hf_model_name"], + batch_size=arguments["batch_size"], + height=arguments["height"], + width=arguments["width"], + precision=arguments["precision"], + max_length=arguments["max_length"], + hf_auth_token=None, + compile_to="vmfb", + external_weights=arguments["external_weights"], + external_weight_path=self.safe_model_name + + "_" + + arguments["precision"] + + "_unet." + + arguments["external_weights"], + device=arguments["device"], + target_triple=arguments["iree_target_triple"], + ireec_flags=arguments["ireec_flags"], + decomp_attn=arguments["decomp_attn"], + ) + arguments["external_weight_path"] = ( + self.safe_model_name + + "_" + + arguments["precision"] + + "_unet." + + arguments["external_weights"] + ) + arguments["vmfb_path"] = ( + self.safe_model_name + + "_" + + str(arguments["max_length"]) + + "_" + + str(arguments["height"]) + + "x" + + str(arguments["width"]) + + "_" + + arguments["precision"] + + "_unet_" + + arguments["device"] + + ".vmfb" + ) + dtype = torch.float16 if arguments["precision"] == "fp16" else torch.float32 + sample = torch.rand( + ( + arguments["batch_size"], + arguments["in_channels"], + arguments["height"] // 8, + arguments["width"] // 8, + ), + dtype=dtype, + ) + timestep = torch.zeros(1, dtype=torch.int64) + prompt_embeds = torch.rand( + (2 * arguments["batch_size"], arguments["max_length"], 2048), + dtype=dtype, + ) + text_embeds = torch.rand(2 * arguments["batch_size"], 1280, dtype=dtype) + time_ids = torch.zeros(2 * arguments["batch_size"], 6, dtype=dtype) + guidance_scale = torch.Tensor([arguments["guidance_scale"]]).to(dtype) + + turbine = unet_runner.run_unet( + arguments["rt_device"], + sample, + timestep, + prompt_embeds, + text_embeds, + time_ids, + guidance_scale, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["hf_auth_token"], + arguments["external_weight_path"], + ) + torch_output = unet_runner.run_torch_unet( + arguments["hf_model_name"], + arguments["hf_auth_token"], + sample.float(), + timestep, + prompt_embeds.float(), + text_embeds.float(), + time_ids.float(), + guidance_scale.float(), + precision=arguments["precision"], + ) + if arguments["benchmark"] or arguments["tracy_profile"]: + run_benchmark( + "unet", + arguments["vmfb_path"], + arguments["external_weight_path"], + arguments["rt_device"], + max_length=arguments["max_length"], + height=arguments["height"], + width=arguments["width"], + batch_size=arguments["batch_size"], + in_channels=arguments["in_channels"], + precision=arguments["precision"], + tracy_profile=arguments["tracy_profile"], + ) + rtol = 4e-2 + atol = 4e-2 + + np.testing.assert_allclose(torch_output, turbine, rtol, atol) + + def test03_ExportVaeModelDecode(self): + if arguments["device"] in ["vulkan", "cuda", "rocm"]: + self.skipTest( + "Compilation error on vulkan; Runtime error on rocm; To be tested on cuda." + ) + vae.export_vae_model( + vae_model=self.vae_model, + # This is a public model, so no auth required + hf_model_name=arguments["hf_model_name"], + batch_size=arguments["batch_size"], + height=arguments["height"], + width=arguments["width"], + precision=arguments["precision"], + compile_to="vmfb", + external_weights=arguments["external_weights"], + external_weight_path=self.safe_model_name + + "_" + + arguments["precision"] + + "_vae_decode." + + arguments["external_weights"], + device=arguments["device"], + target_triple=arguments["iree_target_triple"], + ireec_flags=arguments["ireec_flags"], + variant="decode", + decomp_attn=arguments["decomp_attn"], + exit_on_vmfb=True, + ) + arguments["external_weight_path"] = ( + self.safe_model_name + + "_" + + arguments["precision"] + + "_vae_decode." + + arguments["external_weights"] + ) + arguments["vmfb_path"] = ( + self.safe_model_name + + "_" + + str(arguments["height"]) + + "x" + + str(arguments["width"]) + + "_" + + arguments["precision"] + + "_vae_decode_" + + arguments["device"] + + ".vmfb" + ) + example_input = torch.ones( + arguments["batch_size"], + 4, + arguments["height"] // 8, + arguments["width"] // 8, + dtype=torch.float32, + ) + example_input_torch = example_input + if arguments["precision"] == "fp16": + example_input = example_input.half() + turbine = vae_runner.run_vae( + arguments["rt_device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], + ( + "madebyollin/sdxl-vae-fp16-fix" + if arguments["precision"] == "fp16" + else "" + ), + "decode", + example_input_torch, + ) + if arguments["benchmark"] or arguments["tracy_profile"]: + run_benchmark( + "vae_decode", + arguments["vmfb_path"], + arguments["external_weight_path"], + arguments["rt_device"], + height=arguments["height"], + width=arguments["width"], + precision=arguments["precision"], + tracy_profile=arguments["tracy_profile"], + ) + rtol = 4e-2 + atol = 4e-2 + + np.testing.assert_allclose(torch_output, turbine, rtol, atol) + + def test04_ExportVaeModelEncode(self): + if arguments["device"] in ["cpu", "vulkan", "cuda", "rocm"]: + self.skipTest( + "Compilation error on cpu, vulkan and rocm; To be tested on cuda." + ) + vae.export_vae_model( + vae_model=self.vae_model, + # This is a public model, so no auth required + hf_model_name=arguments["hf_model_name"], + batch_size=arguments["batch_size"], + height=arguments["height"], + width=arguments["width"], + precision=arguments["precision"], + compile_to="vmfb", + external_weights=arguments["external_weights"], + external_weight_path=self.safe_model_name + + "_" + + arguments["precision"] + + "_vae_encode." + + arguments["external_weights"], + device=arguments["device"], + target_triple=arguments["iree_target_triple"], + ireec_flags=arguments["ireec_flags"], + variant="encode", + decomp_attn=arguments["decomp_attn"], + exit_on_vmfb=True, + ) + arguments["external_weight_path"] = ( + self.safe_model_name + + "_" + + arguments["precision"] + + "_vae_encode." + + arguments["external_weights"] + ) + arguments["vmfb_path"] = ( + self.safe_model_name + + "_" + + str(arguments["height"]) + + "x" + + str(arguments["width"]) + + "_" + + arguments["precision"] + + "_vae_encode_" + + arguments["device"] + + ".vmfb" + ) + example_input = torch.ones( + arguments["batch_size"], + 3, + arguments["height"], + arguments["width"], + dtype=torch.float32, + ) + example_input_torch = example_input + if arguments["precision"] == "fp16": + example_input = example_input.half() + turbine = vae_runner.run_vae( + arguments["rt_device"], + example_input, + arguments["vmfb_path"], + arguments["hf_model_name"], + arguments["external_weight_path"], + ) + torch_output = vae_runner.run_torch_vae( + arguments["hf_model_name"], + ( + "madebyollin/sdxl-vae-fp16-fix" + if arguments["precision"] == "fp16" + else "" + ), + "encode", + example_input_torch, + ) + if arguments["benchmark"] or arguments["tracy_profile"]: + run_benchmark( + "vae_encode", + arguments["vmfb_path"], + arguments["external_weight_path"], + arguments["rt_device"], + height=arguments["height"], + width=arguments["width"], + precision=arguments["precision"], + tracy_profile=arguments["tracy_profile"], + ) + rtol = 4e-2 + atol = 4e-2 + np.testing.assert_allclose(torch_output, turbine, rtol, atol) + + def test05_t2i_generate_images(self): + if arguments["device"] in ["vulkan", "cuda", "rocm"]: + self.skipTest("Have issues with submodels on these backends") + mlirs = { + "vae_decode": None, + "prompt_encoder": None, + "scheduled_unet": None, + "pipeline": None, + "full_pipeline": None, + } + vmfbs = { + "vae_decode": None, + "prompt_encoder": None, + "scheduled_unet": None, + "pipeline": None, + "full_pipeline": None, + } + weights = { + "vae_decode": None, + "prompt_encoder": None, + "scheduled_unet": None, + "pipeline": None, + "full_pipeline": None, + } + + if not arguments["pipeline_dir"]: + pipe_id_list = [ + "sdxl_1_0", + str(arguments["height"]), + str(arguments["width"]), + str(arguments["max_length"]), + arguments["precision"], + arguments["device"], + ] + arguments["pipeline_dir"] = os.path.join( + ".", + "_".join(pipe_id_list), + ) + ireec_flags = { + "unet": arguments["ireec_flags"], + "vae": arguments["ireec_flags"], + "clip": arguments["ireec_flags"], + "pipeline": arguments["ireec_flags"], + } + user_mlir_list = [] + for submodel_id, mlir_path in zip(mlirs.keys(), user_mlir_list): + if submodel_id in mlir_path: + mlirs[submodel_id] = mlir_path + external_weights_dir = arguments["pipeline_dir"] + sdxl_pipe = sdxl_compiled_pipeline.SharkSDXLPipeline( + arguments["hf_model_name"], + arguments["scheduler_id"], + arguments["height"], + arguments["width"], + arguments["precision"], + arguments["max_length"], + arguments["batch_size"], + arguments["num_inference_steps"], + arguments["device"], + arguments["iree_target_triple"], + ireec_flags, + arguments["attn_spec"], + arguments["decomp_attn"], + arguments["pipeline_dir"], + external_weights_dir, + arguments["external_weights"], + ) + vmfbs, weights = sdxl_pipe.check_prepared( + mlirs, vmfbs, weights, interactive=False + ) + sdxl_pipe.load_pipeline( + vmfbs, weights, arguments["rt_device"], arguments["compiled_pipeline"] + ) + sdxl_pipe.generate_images( + arguments["prompt"], + arguments["negative_prompt"], + 1, + arguments["guidance_scale"], + arguments["seed"], + ) + print("Image generation complete.") + os.remove(os.path.join(arguments["pipeline_dir"], "prompt_encoder.vmfb")) + os.remove( + os.path.join( + arguments["pipeline_dir"], + arguments["scheduler_id"] + + "_unet_" + + str(arguments["num_inference_steps"]) + + ".vmfb", + ) + ) + os.remove(os.path.join(arguments["pipeline_dir"], "vae_decode.vmfb")) + os.remove(os.path.join(arguments["pipeline_dir"], "full_pipeline.vmfb")) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/models/turbine_models/turbine_tank/turbine_tank.py b/models/turbine_models/turbine_tank/turbine_tank.py index ae25978b4..e57f81764 100644 --- a/models/turbine_models/turbine_tank/turbine_tank.py +++ b/models/turbine_models/turbine_tank/turbine_tank.py @@ -29,7 +29,7 @@ os.makedirs(WORKDIR, exist_ok=True) connection_string = os.environ.get("AZURE_CONNECTION_STRING") -container_name = os.environ.get("AZURE_CONTAINER_NAME") +CONTAINER_NAME = os.environ.get("AZURE_CONTAINER_NAME") def get_short_git_sha() -> str: @@ -72,11 +72,11 @@ def uploadToBlobStorage(file_path, file_name): prefix = today + "_" + commit blob_service_client = BlobServiceClient.from_connection_string(connection_string) blob_client = blob_service_client.get_blob_client( - container=container_name, blob=prefix + "/" + file_name + container=CONTAINER_NAME, blob=prefix + "/" + file_name ) blob = blob_client.from_connection_string( conn_str=connection_string, - container_name=container_name, + CONTAINER_NAME=CONTAINER_NAME, blob_name=blob_client.blob_name, ) # we check to see if we already uploaded the blob (don't want to duplicate) @@ -117,7 +117,9 @@ def checkAndRemoveIfDownloadedOld(model_name: str, model_dir: str, prefix: str): return False -def download_public_folder(model_name: str, prefix: str, model_dir: str): +def download_public_folder( + model_name: str, prefix: str, model_dir: str, container_name=CONTAINER_NAME +): """Downloads a folder of blobs in azure container.""" blob_service_client = BlobServiceClient.from_connection_string(connection_string) container_client = blob_service_client.get_container_client( @@ -163,7 +165,7 @@ def compare(item1, item2): return 0 -def downloadModelArtifacts(model_name: str) -> str: +def downloadModelArtifacts(model_name: str, container_name=CONTAINER_NAME) -> str: model_name = model_name.replace("/", "_") container_client = BlobServiceClient.from_connection_string( connection_string diff --git a/models/turbine_models/utils/benchmark.py b/models/turbine_models/utils/benchmark.py new file mode 100644 index 000000000..28b97b9d3 --- /dev/null +++ b/models/turbine_models/utils/benchmark.py @@ -0,0 +1,137 @@ +import subprocess +from collections import namedtuple +import iree.runtime as ireert +import numpy as np + + +BenchmarkResult = namedtuple( + "BenchmarkResult", "benchmark_name time cpu_time iterations user_counters" +) + + +class BenchmarkToolError(Exception): + """Benchmark exception that preserves the command line and error output.""" + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class BenchmarkTimeoutError(Exception): + """Exception raised if the benchmark is cancelled by the user specified timeout.""" + + pass + + +DTYPE_TO_ABI_TYPE = { + np.dtype(np.float32): "f32", + np.dtype(np.float16): "f16", + np.dtype(np.int32): "i32", + np.dtype(np.int64): "i64", + np.dtype(np.float64): "f64", + np.dtype(np.int16): "i16", + np.dtype(np.int8): "i8", + np.dtype(np.bool_): "i1", +} + + +def benchmark_module( + module, + entry_function=None, + vmfbs=[], + inputs=[], + tracy_profile=False, + timeout=None, + **kwargs, +): + funcs = [a for a in module.function_names if a != "__init"] + if entry_function is None: + if len(funcs) > 1: + raise ValueError(f"No function specified with multiple options {funcs}") + entry_function = funcs[0] + if entry_function not in funcs: + raise ValueError( + f"Attempted to benchmark unknown function {entry_function} of options {funcs}" + ) + + args = [] + if tracy_profile: + args.append("TRACY_NO_EXIT=1") + # TODO: run iree-tracy-capture subprocess + args.append(ireert.benchmark_exe()) + args.append(f"--function={entry_function}") + + for inp in inputs: + if isinstance(inp, str): + args.append(f"--input={inp}") + continue + shape = "x".join([str(d) for d in inp.shape]) + abitype = DTYPE_TO_ABI_TYPE[inp.dtype] + values = inp.flatten() + if np.all(values[0] == values): + values = str(values[0]) + else: + values = ",".join([str(v) for v in values]) + input_arg = f"--input={shape}x{abitype}={values}" + if len(input_arg) > 256: + print( + f"Randomizing {input_arg.split('=')[0]} because it is too long for subprocess.run" + ) + input_arg = f"--input={shape}x{abitype}" + args.append(input_arg) + print(args) + + for k in kwargs: + v = kwargs[k] + args.append(f"--{k}={v}") + + for vmfb in vmfbs: + args.append(f"--module={vmfb}") + + try: + benchmark_process = subprocess.run( + args=args, + # input=flatbuffer, + timeout=timeout, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + except subprocess.TimeoutExpired: + raise BenchmarkTimeoutError(f"Benchmark timed out after {timeout} seconds") + out = benchmark_process.stdout + err = benchmark_process.stderr + + err = err.decode() + if "INVALID_ARGUMENT;" in err: + raise ValueError("Invalid inputs specified for benchmarking") + + # In the event benchmarking runs but encounteres an internal error, + # return the internal error instead of benchmark results. + if "INTERNAL; CUDA driver error" in str(out): + raise BenchmarkToolError(str(out)) + + # Grab individual results by line (skip header lines) + bench_lines = out.decode().split("\n")[3:] + benchmark_results = [] + for line in bench_lines: + split = line.split() + if len(split) == 0: + continue + benchmark_name = split[0] + time = " ".join(split[1:3]) + cpu_time = " ".join(split[3:5]) + iterations = split[5] + user_counters = None + if len(split) > 5: + user_counters = split[6] + benchmark_results.append( + BenchmarkResult( + benchmark_name=benchmark_name, + time=time, + cpu_time=cpu_time, + iterations=iterations, + user_counters=user_counters, + ) + ) + + return benchmark_results diff --git a/models/turbine_models/utils/sdxl_benchmark.py b/models/turbine_models/utils/sdxl_benchmark.py new file mode 100644 index 000000000..1c37f93a1 --- /dev/null +++ b/models/turbine_models/utils/sdxl_benchmark.py @@ -0,0 +1,75 @@ +import sys +from iree import runtime as ireert +from turbine_models.utils.benchmark import benchmark_module + + +DTYPE_MAP = { + "fp16": "f16", + "fp32": "f32", +} + + +def run_benchmark( + model, + vmfb_path, + weights_path, + device, + max_length=None, + height=None, + width=None, + batch_size=None, + in_channels=None, + precision=None, + tracy_profile=False, +): + config = ireert.Config(device) + + if not vmfb_path: + sys.exit("no vmfb_path provided, required for run_benchmark") + benchmark_mod = ireert.VmModule.mmap(config.vm_instance, vmfb_path) + + if weights_path: + index = ireert.ParameterIndex() + index.load(weights_path) + + vmfbs = [] + vmfbs.append(vmfb_path) + + inputs = [] + match model: + case "clip_1": + inputs.append(f"1x{max_length}xi64") + case "clip_2": + inputs.append(f"1x{max_length}xi64") + case "unet": + inputs.append( + f"{batch_size}x{in_channels}x{height//8}x{width//8}x{DTYPE_MAP[precision]}" + ) + inputs.append(f"1x{DTYPE_MAP[precision]}") + inputs.append(f"{2*batch_size}x{max_length}x2048x{DTYPE_MAP[precision]}") + inputs.append(f"{2*batch_size}x1280x{DTYPE_MAP[precision]}") + inputs.append(f"{2*batch_size}x6x{DTYPE_MAP[precision]}") + inputs.append(f"1x{DTYPE_MAP[precision]}") + case "vae_decode": + inputs.append(f"1x4x{height//8}x{width//8}x{DTYPE_MAP[precision]}") + case "vae_encode": + inputs.append(f"1x3x{height}x{width}x{DTYPE_MAP[precision]}") + case _: + sys.exit("model name doesn't match for inputs") + + if weights_path: + results = benchmark_module( + benchmark_mod, + "main", + vmfbs, + inputs, + tracy_profile, + parameters=f"model={weights_path}", + ) + else: + results = benchmark_module(benchmark_mod, "main", vmfbs, inputs, tracy_profile) + + for benchmark_result in results: + print( + f"model: {model}, benchmark_name: {benchmark_result.benchmark_name}, time: {benchmark_result.time}, cpu_time: {benchmark_result.cpu_time}, iterations: {benchmark_result.iterations}, user_counters: {benchmark_result.user_counters}" + ) diff --git a/serving/setup.py b/serving/setup.py index 37ad48703..53c9fc4f9 100644 --- a/serving/setup.py +++ b/serving/setup.py @@ -98,6 +98,7 @@ def initialize_options(self): f"iree-compiler{get_version_spec('iree-compiler')}", f"iree-runtime{get_version_spec('iree-runtime')}", f"uvicorn{get_version_spec('uvicorn')}", + f"requests{get_version_spec('requests')}", ], extras_require={ "testing": [