Merge branch 'merge_punet_sdxl' of https://github.com/nod-ai/SHARK-Turbine into merge_punet_sdxl
eagarvey-amd committed Aug 17, 2024
2 parents 18bffdb + 2d7a92e commit df85dca
Showing 7 changed files with 81 additions and 25 deletions.
40 changes: 40 additions & 0 deletions models/README.md
@@ -2,6 +2,46 @@

For private/gated models, make sure you have run `huggingface-cli login`.
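If you have not logged in on the machine yet, a minimal sketch of the login step looks like the following (the `HF_TOKEN` variable is just an illustration, not something the scripts below set for you):
```bash
# Interactive login: paste a read-access token from your Hugging Face account settings.
huggingface-cli login
# Non-interactive variant (assumes you have exported HF_TOKEN yourself and your
# huggingface_hub version supports --token):
huggingface-cli login --token "$HF_TOKEN"
```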

For MI Instinct:
```bash
#!/bin/bash
sudo apt install -y git

# Clone and build IREE at the shared/sdxl_quantized branch
git clone https://github.com/iree-org/iree && cd iree
git checkout shared/sdxl_quantized
git submodule update --init
# Create and activate a Python venv so the build and pip installs stay isolated.
python -m venv iree.venv
source iree.venv/bin/activate
pip install pybind11 numpy nanobind
cmake -S . -B build-release \
-G Ninja -DCMAKE_BUILD_TYPE=Release \
-DCMAKE_C_COMPILER=`which clang` -DCMAKE_CXX_COMPILER=`which clang++` \
-DIREE_HAL_DRIVER_CUDA=OFF \
-DIREE_BUILD_PYTHON_BINDINGS=ON \
-DPython3_EXECUTABLE="$(which python3)" && \
cmake --build build-release/

# Point PYTHONPATH at the freshly built Python bindings (replace /path/to/iree with your checkout path).
export PYTHONPATH=/path/to/iree/build-release/compiler/bindings/python:/path/to/iree/build-release/runtime/bindings/python

# Clone and setup turbine-models
cd ..
git clone https://github.com/nod-ai/SHARK-Turbine.git && cd SHARK-Turbine
git checkout merge_punet_sdxl
pip install torch==2.5.0.dev20240801 torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install -r models/requirements.txt
# Remove any pip-installed IREE packages so the source-built bindings on PYTHONPATH are used instead.
pip uninstall -y iree-compiler iree-runtime

pip install -e models

# Run the SDXL tests.
python -m pytest models/turbine_models/tests/sdxl_test.py --device=rocm --rt_device=hip --iree_target_triple=gfx942 --external_weights=safetensors --precision=fp16 --clip_spec=mfma --unet_spec=mfma --vae_spec=mfma

# Generate an image.
# To reuse test artifacts/weights, add: --pipeline_dir=./test_vmfbs --external_weights_dir=./test_weights
python models/turbine_models/custom_models/sd_inference/sd_pipeline.py --hf_model_name=stabilityai/stable-diffusion-xl-base-1.0 --device=hip://0 --precision=fp16 --external_weights=safetensors --iree_target_triple=gfx942 --vae_decomp_attn --clip_decomp_attn --use_i8_punet --width=1024 --height=1024 --num_inference_steps=20 --benchmark=all --verbose

```
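After the build, a quick sanity check (a minimal sketch, assuming the `PYTHONPATH` export above points at your build tree) confirms that the source-built Python bindings are the ones being imported rather than pip wheels:
```bash
# Both printed paths should point into build-release/, not into site-packages.
python -c "import iree.compiler, iree.runtime; print(iree.compiler.__file__); print(iree.runtime.__file__)"
```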
For GFX11 (RDNA3 discrete GPUs / Ryzen laptops), the following setup is validated:
```bash
#!/bin/bash

1 change: 1 addition & 0 deletions models/requirements.txt
@@ -4,6 +4,7 @@ transformers==4.43.3
torchsde
accelerate
peft
safetensors==0.4.0
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
# turbine tank downloading/uploading
@@ -388,6 +388,11 @@ def __init__(
self.map["unet"]["function_name"] = "run_forward"

def setup_punet(self):
self.map["unet"]["mlir"] = None
self.map["unet"]["vmfb"] = None
self.map["unet"]["weights"] = None
self.map["unet"]["keywords"] = [i for i in self.map["unet"]["keywords"] if i != "!punet"]
self.map["unet"]["keywords"] += "punet"
if self.use_i8_punet:
if self.add_tk_kernels:
self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels
28 changes: 14 additions & 14 deletions models/turbine_models/custom_models/sd_inference/utils.py
@@ -22,16 +22,16 @@
"--iree-execution-model=async-external",
],
"masked_attention": [
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, iree-preprocessing-pad-to-intrinsics, util.func(iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-global-opt-raise-special-ops, util.func(iree-preprocessing-pad-to-intrinsics, iree-linalg-ext-pad-attention{pad-to-multiple-of=0,64,0,32,0}))",
],
"punet": [
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics, iree-preprocessing-generalize-linalg-matmul-experimental))"
],
"vae_preprocess": [
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics, util.func(iree-preprocessing-generalize-linalg-matmul-experimental))"
"--iree-preprocessing-pass-pipeline=builtin.module(util.func(iree-global-opt-raise-special-ops, iree-flow-canonicalize), iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics, iree-preprocessing-generalize-linalg-matmul-experimental))"
],
"preprocess_default": [
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, iree-preprocessing-pad-to-intrinsics)",
"--iree-preprocessing-pass-pipeline=builtin.module(iree-preprocessing-transpose-convolution-pipeline, util.func(iree-preprocessing-pad-to-intrinsics))",
],
"unet": [
"--iree-flow-enable-aggressive-fusion",
@@ -52,7 +52,7 @@
],
"vae": [
"--iree-flow-enable-aggressive-fusion",
"--iree-global-opt-enable-fuse-horizontal-contractions",
"--iree-flow-enable-fuse-horizontal-contractions",
"--iree-opt-aggressively-propagate-transposes=true",
"--iree-codegen-llvmgpu-use-vector-distribution=true",
"--iree-opt-data-tiling=false",
@@ -350,15 +350,15 @@ def compile_to_vmfb(
# the TD spec is implemented in C++.

if attn_spec in ["default", "mfma", "punet"]:
if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False:
use_punet = True if attn_spec in ["punet", "i8"] else False
attn_spec = get_mfma_spec_path(
target_triple,
os.path.dirname(safe_name),
use_punet=use_punet,
masked_attention=masked_attention,
)
flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])
# if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False:
use_punet = True if attn_spec in ["punet", "i8"] else False
attn_spec = get_mfma_spec_path(
target_triple,
os.path.dirname(safe_name),
use_punet=use_punet,
masked_attention=masked_attention,
)
flags.extend(["--iree-codegen-transform-dialect-library=" + attn_spec])

elif attn_spec in ["wmma"] or ("gfx11" in target_triple and not attn_spec):
attn_spec = get_wmma_spec_path(
@@ -177,7 +177,7 @@ def export_prompt_encoder(

safe_name = utils.create_safe_name(
hf_model_name,
f"_bs{batch_size}_{str(max_length)}-{precision}-prompt-encoder-{device}",
f"_bs{batch_size}_{str(max_length)}-{precision}-text-encoder-{device}",
)
if pipeline_dir not in [None, ""]:
safe_name = os.path.join(pipeline_dir, safe_name)
@@ -275,7 +275,6 @@ def encode_prompts_turbo(
}
module = AddMetadataPass(module, model_metadata_encode, "encode_prompts").run()
module_str = str(module)

if compile_to != "vmfb":
return module_str
else:
3 changes: 3 additions & 0 deletions models/turbine_models/tests/conftest.py
@@ -39,6 +39,9 @@ def pytest_addoption(parser):
parser.addoption("--decomp_attn", action="store", default=False)
parser.addoption("--vae_decomp_attn", action="store", default=False)
parser.addoption("--attn_spec", action="store", default="")
parser.addoption("--clip_spec", action="store", default="")
parser.addoption("--unet_spec", action="store", default="")
parser.addoption("--vae_spec", action="store", default="")
# Compiler Options
parser.addoption("--device", action="store", default="cpu")
parser.addoption("--rt_device", action="store", default="local-task")
26 changes: 17 additions & 9 deletions models/turbine_models/tests/sdxl_test.py
@@ -24,6 +24,7 @@
import os
import numpy as np
import time
import gc


torch.random.manual_seed(0)
@@ -61,7 +62,11 @@ def command_line_args(request):
arguments["compile_to"] = request.config.getoption("--compile_to")
arguments["external_weights"] = request.config.getoption("--external_weights")
arguments["decomp_attn"] = request.config.getoption("--decomp_attn")
arguments["attn_spec"] = request.config.getoption("--attn_spec")
arguments["attn_spec"] = request.config.getoption("--attn_spec") if request.config.getoption("attn_spec") else {
"text_encoder": request.config.getoption("clip_spec"),
"unet": request.config.getoption("unet_spec"),
"vae": request.config.getoption("vae_spec"),
}
arguments["device"] = request.config.getoption("--device")
arguments["rt_device"] = request.config.getoption("--rt_device")
arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple")
@@ -111,9 +116,9 @@ def setUp(self):
self.pipe.prepare_all()

def test01_PromptEncoder(self):
if arguments["device"] in ["vulkan", "cuda", "rocm"]:
if arguments["device"] in ["vulkan", "cuda"]:
self.skipTest(
"Compilation error on vulkan; recent numerics regression (nans) on hip driver, To be tested on cuda."
"Compilation error on vulkan; To be tested on cuda."
)
arguments["vmfb_path"] = self.pipe.map["text_encoder"]["vmfb"]
arguments["external_weight_path"] = self.pipe.map["text_encoder"]["weights"]
@@ -235,7 +240,6 @@ def test02_ExportUnetModel(self):
)
rtol = 4e-2
atol = 4e-1

np.testing.assert_allclose(torch_output, turbine, rtol, atol)

def test03_ExportVaeModelDecode(self):
@@ -279,7 +283,6 @@ def test03_ExportVaeModelDecode(self):
)
rtol = 4e-2
atol = 4e-1

np.testing.assert_allclose(torch_output, turbine, rtol, atol)

@pytest.mark.xfail(reason="NaN output on rocm, needs triage and file")
@@ -345,13 +348,13 @@ def test05_t2i_generate_images(self):
)
assert output is not None

@pytest.mark.xfail(reason="compilation issue on gfx90a")
def test06_t2i_generate_images_punet(self):
if arguments["device"] in ["vulkan", "cuda", "rocm"]:
if arguments["device"] in ["vulkan", "cuda"]:
self.skipTest(
"Have issues with submodels on vulkan, cuda; ROCM hangs on mi250 despite submodels working."
"Have issues with submodels on vulkan, cuda"
)
self.pipe.unload_submodel("unet")
if getattr(self.pipe, "unet", None):
self.pipe.unload_submodel("unet")
self.pipe.use_punet = True
self.pipe.use_i8_punet = True
self.pipe.setup_punet()
@@ -369,6 +372,11 @@
True, # return_img
)
assert output is not None

def tearDown(self):
del self.pipe
gc.collect()



if __name__ == "__main__":
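For reference, a hedged sketch of running just the new i8 punet test path with the per-submodel spec options added in `conftest.py` above (flag values mirror the MI Instinct section of the README; not validated here):
```bash
# Select only the punet image-generation test; --clip_spec/--unet_spec/--vae_spec map to
# text_encoder/unet/vae via the attn_spec dictionary built in sdxl_test.py.
python -m pytest models/turbine_models/tests/sdxl_test.py -k test06_t2i_generate_images_punet \
  --device=rocm --rt_device=hip --iree_target_triple=gfx942 \
  --external_weights=safetensors --precision=fp16 \
  --clip_spec=mfma --unet_spec=mfma --vae_spec=mfma
```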