diff --git a/models/README.md b/models/README.md
index c917d03e..96f3dadd 100644
--- a/models/README.md
+++ b/models/README.md
@@ -2,6 +2,46 @@
 
 For private/gated models, make sure you have run `huggingface-cli login`.
 
+For AMD Instinct (MI series) GPUs, the following setup is validated:
+```bash
+#!/bin/bash
+sudo apt install -y git
+
+# Clone and build IREE at the shared/sdxl_quantized branch
+git clone https://github.com/iree-org/iree && cd iree
+git checkout shared/sdxl_quantized
+git submodule update --init
+python -m venv iree.venv
+pip install pybind11 numpy nanobind
+cmake -S . -B build-release \
+  -G Ninja -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_C_COMPILER=`which clang` -DCMAKE_CXX_COMPILER=`which clang++` \
+  -DIREE_HAL_DRIVER_CUDA=OFF \
+  -DIREE_BUILD_PYTHON_BINDINGS=ON \
+  -DPython3_EXECUTABLE="$(which python3)" && \
+  cmake --build build-release/
+
+export PYTHONPATH=/path/to/iree/build-release/compiler/bindings/python:/path/to/iree/build-release/runtime/bindings/python
+
+# Clone and set up turbine-models
+cd ..
+git clone https://github.com/nod-ai/SHARK-Turbine.git && cd SHARK-Turbine
+git checkout merge_punet_sdxl
+pip install torch==2.5.0.dev20240801 torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
+pip install -r models/requirements.txt
+pip uninstall -y iree-compiler iree-runtime
+
+pip install -e models
+
+# Run the SDXL tests.
+python models/turbine_models/tests/sdxl_test.py pytest --device=rocm --rt_device=hip --iree_target_triple=gfx942 --external_weights=safetensors --precision=fp16 --clip_spec=mfma --unet_spec=mfma --vae_spec=mfma
+
+# Generate an image.
+# To reuse test artifacts/weights, add: --pipeline_dir=./test_vmfbs --external_weights_dir=./test_weights
+python models/turbine_models/custom_models/sd_inference/sd_pipeline.py --hf_model_name=stabilityai/stable-diffusion-xl-base-1.0 --device=hip://0 --precision=fp16 --external_weights=safetensors --iree_target_triple=gfx942 --vae_decomp_attn --clip_decomp_attn --use_i8_punet --width=1024 --height=1024 --num_inference_steps=20 --benchmark=all --verbose
+
+```
+For GFX11 (RDNA3 Discrete GPUs/Ryzen laptops) the following setup is validated:
 
 ```bash
 #!/bin/bash
diff --git a/models/requirements.txt b/models/requirements.txt
index 06283efd..ec153a02 100644
--- a/models/requirements.txt
+++ b/models/requirements.txt
@@ -4,6 +4,7 @@ transformers==4.43.3
 torchsde
 accelerate
 peft
+safetensors==0.4.0
 diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
 brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
 # turbine tank downloading/uploading
diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
index 03c4f64f..e37c095c 100644
--- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
+++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py
@@ -388,6 +388,11 @@ def __init__(
             self.map["unet"]["function_name"] = "run_forward"
 
     def setup_punet(self):
+        self.map["unet"]["mlir"] = None
+        self.map["unet"]["vmfb"] = None
+        self.map["unet"]["weights"] = None
+        self.map["unet"]["keywords"] = [i for i in self.map["unet"]["keywords"] if i != "!punet"]
+        self.map["unet"]["keywords"] += ["punet"]
         if self.use_i8_punet:
             if self.add_tk_kernels:
                 self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels
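The `setup_punet` hunk above resets any cached unet artifacts and swaps the `"!punet"` exclusion keyword for `"punet"` so the quantized UNet gets re-exported with punet-specific flags. A minimal sketch of that bookkeeping on a stand-in dictionary (the field names mirror the hunk; the surrounding pipeline class is not shown):

```python
# Stand-in for the pipeline's self.map["unet"] entry; only the fields that
# setup_punet touches are shown here.
unet_entry = {
    "mlir": "unet.mlir",
    "vmfb": "unet.vmfb",
    "weights": "unet.safetensors",
    "keywords": ["unet", "!punet", "fp16"],
}

# Drop cached artifacts so the punet variant is re-exported and re-compiled.
for key in ("mlir", "vmfb", "weights"):
    unet_entry[key] = None

# Swap the "!punet" exclusion keyword for "punet". Note the list wrapper:
# `keywords += "punet"` would splice in the individual characters instead.
unet_entry["keywords"] = [k for k in unet_entry["keywords"] if k != "!punet"]
unet_entry["keywords"] += ["punet"]

print(unet_entry["keywords"])  # ['unet', 'fp16', 'punet']
```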
diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py
index 8db8806f..07e60998 100644
--- a/models/turbine_models/custom_models/sd_inference/utils.py
+++ b/models/turbine_models/custom_models/sd_inference/utils.py
@@ -22,16 +22,16 @@
         "--iree-execution-model=async-external",
     ],
     "masked_attention": [
-        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics, util.func(iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
+        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
     ],
     "punet": [
-        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
+        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics, iree-preprocessing-generalize-linalg-matmul-experimental))"
     ],
     "vae_preprocess": [
-        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
+        "--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics, iree-preprocessing-generalize-linalg-matmul-experimental))"
     ],
     "preprocess_default": [
-        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)",
+        "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
     ],
     "unet": [
         "--iree-flow-enable-aggressive-fusion",
@@ -52,7 +52,7 @@
     ],
     "vae": [
         "--iree-flow-enable-aggressive-fusion",
-        "--iree-global-opt-enable-fuse-horizontal-contractions",
+        "--iree-flow-enable-fuse-horizontal-contractions",
         "--iree-opt-aggressively-propagate-transposes=true",
         "--iree-codegen-llvmgpu-use-vector-distribution=true",
         "--iree-opt-data-tiling=false",
@@ -350,15 +350,15 @@ def compile_to_vmfb(
 
     # the TD spec is implemented in C++.
     if attn_spec in ["default", "mfma", "punet"]:
-        if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False:
-            use_punet = True if attn_spec in ["punet", "i8"] else False
-            attn_spec = get_mfma_spec_path(
-                target_triple,
-                os.path.dirname(safe_name),
-                use_punet=use_punet,
-                masked_attention=masked_attention,
-            )
-            flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
+        # if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False:
+        use_punet = True if attn_spec in ["punet", "i8"] else False
+        attn_spec = get_mfma_spec_path(
+            target_triple,
+            os.path.dirname(safe_name),
+            use_punet=use_punet,
+            masked_attention=masked_attention,
+        )
+        flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
 
     elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec):
         attn_spec = get_wmma_spec_path(
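The flag changes above nest `iree-preprocessing-pad-to-intrinsics` inside a `util.func(...)` scope in each preprocessing pipeline, and the MFMA transform-dialect spec is now attached for every submodel (the clip/prompt_encoder guard is commented out). A rough sketch of how such flag groups compose into an `iree-compile` invocation; the input/output file names and the spec path are placeholders, not paths produced by this repo:

```python
import shlex

# Flag strings copied from the diff above; file names below are placeholders.
preprocess_default = [
    "--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
]
unet_flags = ["--iree-flow-enable-aggressive-fusion"]
td_spec_flag = "--iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir"  # placeholder spec path

# Compose the full command line the same way compile_to_vmfb assembles its flag list.
cmd = (
    ["iree-compile", "unet.mlir", "-o", "unet.vmfb", "--iree-hal-target-backends=rocm"]
    + preprocess_default
    + unet_flags
    + [td_spec_flag]
)
print(shlex.join(cmd))
```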
if attn_spec in ["default", "mfma", "punet"]: - if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: - use_punet = True if attn_spec in ["punet", "i8"] else False - attn_spec = get_mfma_spec_path( - target_triple, - os.path.dirname(safe_name), - use_punet=use_punet, - masked_attention=masked_attention, - ) - flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) +# if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + use_punet = True if attn_spec in ["punet", "i8"] else False + attn_spec = get_mfma_spec_path( + target_triple, + os.path.dirname(safe_name), + use_punet=use_punet, + masked_attention=masked_attention, + ) + flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec]) elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec): attn_spec = get_wmma_spec_path( diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py index 40ce6c2e..d547cadf 100644 --- a/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py +++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder.py @@ -177,7 +177,7 @@ def export_prompt_encoder( safe_name = utils.create_safe_name( hf_model_name, - f"_bs{batch_size}_{str(max_length)}-{precision}-prompt-encoder-{device}", + f"_bs{batch_size}_{str(max_length)}-{precision}-text-encoder-{device}", ) if pipeline_dir not in [None, ""]: safe_name = os.path.join(pipeline_dir, safe_name) @@ -275,7 +275,6 @@ def encode_prompts_turbo( } module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run() module_str = str(module) - if compile_to != "vmfb": return module_str else: diff --git a/models/turbine_models/tests/conftest.py b/models/turbine_models/tests/conftest.py index 4292c739..bcd8ba91 100644 --- a/models/turbine_models/tests/conftest.py +++ b/models/turbine_models/tests/conftest.py @@ -39,6 +39,9 @@ def pytest_addoption(parser): parser.addoption("--decomp_attn", action="store", default=False) parser.addoption("--vae_decomp_attn", action="store", default=False) parser.addoption("--attn_spec", action="store", default="") + parser.addoption("--clip_spec", action="store", default="") + parser.addoption("--unet_spec", action="store", default="") + parser.addoption("--vae_spec", action="store", default="") # Compiler Options parser.addoption("--device", action="store", default="cpu") parser.addoption("--rt_device", action="store", default="local-task") diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 96b90b55..7cec4a66 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -24,6 +24,7 @@ import os import numpy as np import time +import gc torch.random.manual_seed(0) @@ -61,7 +62,11 @@ def command_line_args(request): arguments["compile_to"] = request.config.getoption("--compile_to") arguments["external_weights"] = request.config.getoption("--external_weights") arguments["decomp_attn"] = request.config.getoption("--decomp_attn") - arguments["attn_spec"] = request.config.getoption("--attn_spec") + arguments["attn_spec"] = request.config.getoption("--attn_spec") if request.config.getoption("attn_spec") else { + "text_encoder": request.config.getoption("clip_spec"), + "unet": request.config.getoption("unet_spec"), + "vae": request.config.getoption("vae_spec"), + } arguments["device"] = request.config.getoption("--device") 
arguments["rt_device"] = request.config.getoption("--rt_device") arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple") @@ -111,9 +116,9 @@ def setUp(self): self.pipe.prepare_all() def test01_PromptEncoder(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: + if arguments["device"] in ["vulkan", "cuda"]: self.skipTest( - "Compilation error on vulkan; recent numerics regression (nans) on hip driver, To be tested on cuda." + "Compilation error on vulkan; To be tested on cuda." ) arguments["vmfb_path"] = self.pipe.map["text_encoder"]["vmfb"] arguments["external_weight_path"] = self.pipe.map["text_encoder"]["weights"] @@ -235,7 +240,6 @@ def test02_ExportUnetModel(self): ) rtol = 4e-2 atol = 4e-1 - np.testing.assert_allclose(torch_output, turbine, rtol, atol) def test03_ExportVaeModelDecode(self): @@ -279,7 +283,6 @@ def test03_ExportVaeModelDecode(self): ) rtol = 4e-2 atol = 4e-1 - np.testing.assert_allclose(torch_output, turbine, rtol, atol) @pytest.mark.xfail(reason="NaN output on rocm, needs triage and file") @@ -345,13 +348,13 @@ def test05_t2i_generate_images(self): ) assert output is not None - @pytest.mark.xfail(reason="compilation issue on gfx90a") def test06_t2i_generate_images_punet(self): - if arguments["device"] in ["vulkan", "cuda", "rocm"]: + if arguments["device"] in ["vulkan", "cuda"]: self.skipTest( - "Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working." + "Have issues with submodels on vulkan, cuda" ) - self.pipe.unload_submodel("unet") + if getattr(self.pipe, "unet"): + self.pipe.unload_submodel("unet") self.pipe.use_punet = True self.pipe.use_i8_punet = True self.pipe.setup_punet() @@ -369,6 +372,11 @@ def test06_t2i_generate_images_punet(self): True, # return_img ) assert output is not None + + def tearDown(self): + del self.pipe + gc.collect() + if __name__ == "__main__":