Skip to content

Commit

Permalink
Remove download links for tk kernels and instead specify kernel directory as an arg
Browse files Browse the repository at this point in the history
  • Loading branch information
nithinsubbiah committed Jul 25, 2024
1 parent 0c02652 commit 0313c96
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def __init__(
vae_weight_path: str = None,
vae_harness: bool = True,
add_tk_kernels: bool = False,
tk_kernels_dir: str | dict[str] = None,
save_outputs: bool | dict[bool] = False,
):
common_export_args = {
Expand Down Expand Up @@ -320,6 +321,7 @@ def __init__(

self.split_scheduler = True
self.add_tk_kernels = add_tk_kernels
self.tk_kernels_dir = tk_kernels_dir

self.base_model_name = (
hf_model_name
Expand Down Expand Up @@ -373,6 +375,7 @@ def setup_punet(self):
if self.use_i8_punet:
if self.add_tk_kernels:
self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels
self.map["unet"]["export_args"]["tk_kernels_dir"] = self.tk_kernels_dir
self.map["unet"]["export_args"]["precision"] = "i8"
self.map["unet"]["export_args"]["external_weight_path"] = (
utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa"
Expand Down
51 changes: 21 additions & 30 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import safetensors
import safetensors.numpy as safe_numpy
import re
import glob
from diffusers import (
PNDMScheduler,
EulerDiscreteScheduler,
Expand Down Expand Up @@ -155,16 +156,8 @@ def iree_backend_map(device):
return iree_device


def replace_with_tk_kernels(flow_dialect_ir, batch_size):
if batch_size == 8:
kernels = [
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs8/tk_gemm_fused_16x1024x10240x1280.mlir"
]
if batch_size == 1:
kernels = [
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir",
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir",
]
def replace_with_tk_kernels(tk_kernels_dir, flow_dialect_ir, batch_size):
kernels = glob.glob(tk_kernels_dir + "/bs" + str(batch_size) + "/*")

# Replace all calls to old kernel with new kernel
print("Inserting kernels and updating calls to kernels...")
Expand All @@ -178,25 +171,21 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
new_base = []
for line in base:
for kernel in kernels:
suffix = kernel.split("/")[-1].split(".")[0].split("_")[-1]
# Uncomment/rework when a kernel with bias comes in
# bias_explicit = False
# if "bias" in suffix:
# bias_explicit = True
# kernel_args = 3 + int(suffix[4:])
# suffix = kernel.split(".")[0].split("_")[-2]
suffix = kernel.split(".")[0].split("_")[-1]
if "bias" in suffix:
suffix = kernel.split(".")[0].split("_")[-2]
B, M, N, K = suffix.split("x")
old_kernel = f"matmul_like_{B}x{M}x{N}x{K}"
if not old_kernel in line:
continue
if old_kernel in line and "func.func" in line:
data = urlopen(kernel).read().decode("utf-8")
data = data.split("\n")
num_args = line.count("arg")
with open(kernel, "r") as f:
data = f.readlines()
idx_with_kernel_args = [
idx for idx, s in enumerate(data) if "func.func" in s
][0]
kernel_args = data[idx_with_kernel_args].count("arg")
num_args = line.count("arg")
if num_args != kernel_args:
continue
kernel_map[kernel] = line.strip().split(" ")[1][1:-7]
Expand All @@ -218,11 +207,12 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
and "flow.executable" in line
and "private" in line
):
data = urlopen(kernel).read().decode("utf-8")
data = data.split("\n")
with open(kernel, "r") as f:
data = f.readlines()
translation_info = data[0].split("#translation = ")[1].strip()
data[10] = data[10].replace("#translation", translation_info)
final_ir.append("\n".join(data[2:-3]))
extract = "".join(data[2:-2])
extract = extract.replace("#translation", translation_info)
final_ir += extract
final_ir.append(line)

print("tk kernels added")
Expand All @@ -245,6 +235,7 @@ def compile_to_vmfb(
flagset_keywords=[],
debug=False,
add_tk_kernels=False,
tk_kernels_dir=None,
batch_size=1,
):
if batch_size != 1 and batch_size != 8:
Expand Down Expand Up @@ -406,7 +397,7 @@ def compile_to_vmfb(

flow_ir = flatbuffer_blob.decode("utf-8")

flow_ir_tk = replace_with_tk_kernels(flow_ir, batch_size)
flow_ir_tk = replace_with_tk_kernels(tk_kernels_dir, flow_ir, batch_size)
module_str = "\n".join(flow_ir_tk)
flags.pop()
flags.extend(["--compile-from=flow"])
Expand Down Expand Up @@ -554,11 +545,11 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
schedulers["EulerAncestralDiscrete"] = (
EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
)
# schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained(
# model_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,11 @@ def is_valid_file(arg):
help="Flag to add compiled tk kernels.",
)

# --tk_kernels_dir is a path, not a boolean flag: downstream code joins it
# with "/bs<batch_size>" and globs for .mlir kernel files, so it must be a
# string (or None when tk kernels are disabled via --add_tk_kernels).
# The previous default=False + action="store_true" made the value True,
# which would crash string concatenation in replace_with_tk_kernels.
p.add_argument(
    "--tk_kernels_dir",
    type=str,
    default=None,
    help="Path to directory containing tk kernels sorted by batch size.",
)

args, unknown = p.parse_known_args()

0 comments on commit 0313c96

Please sign in to comment.