From 0313c96849b466a1e3df1d3e499a790834ce4950 Mon Sep 17 00:00:00 2001 From: nithinsubbiah Date: Thu, 25 Jul 2024 14:40:12 -0500 Subject: [PATCH] Remove download links for tk kernels and instead specify kernel directory as an arg --- .../custom_models/sd_inference/sd_pipeline.py | 3 ++ .../custom_models/sd_inference/utils.py | 51 ++++++++----------- .../sdxl_inference/sdxl_cmd_opts.py | 7 +++ 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index 97369c53..a322cb08 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -238,6 +238,7 @@ def __init__( vae_weight_path: str = None, vae_harness: bool = True, add_tk_kernels: bool = False, + tk_kernels_dir: str | dict[str] = None, save_outputs: bool | dict[bool] = False, ): common_export_args = { @@ -320,6 +321,7 @@ def __init__( self.split_scheduler = True self.add_tk_kernels = add_tk_kernels + self.tk_kernels_dir = tk_kernels_dir self.base_model_name = ( hf_model_name @@ -373,6 +375,7 @@ def setup_punet(self): if self.use_i8_punet: if self.add_tk_kernels: self.map["unet"]["export_args"]["add_tk_kernels"] = self.add_tk_kernels + self.map["unet"]["export_args"]["tk_kernels_dir"] = self.tk_kernels_dir self.map["unet"]["export_args"]["precision"] = "i8" self.map["unet"]["export_args"]["external_weight_path"] = ( utils.create_safe_name(self.base_model_name) + "_punet_dataset_i8.irpa" diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 5bee5a09..84d9cb3b 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -5,6 +5,7 @@ import safetensors import safetensors.numpy as safe_numpy import re +import glob from diffusers import ( 
PNDMScheduler, EulerDiscreteScheduler, @@ -155,16 +156,8 @@ def iree_backend_map(device): return iree_device -def replace_with_tk_kernels(flow_dialect_ir, batch_size): - if batch_size == 8: - kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs8/tk_gemm_fused_16x1024x10240x1280.mlir" - ] - if batch_size == 1: - kernels = [ - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir", - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir", - ] +def replace_with_tk_kernels(tk_kernels_dir, flow_dialect_ir, batch_size): + kernels = glob.glob(tk_kernels_dir + "/bs" + str(batch_size) + "/*") # Replace all calls to old kernel with new kernel print("Inserting kernels and updating calls to kernels...") @@ -178,25 +171,21 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size): new_base = [] for line in base: for kernel in kernels: - suffix = kernel.split("/")[-1].split(".")[0].split("_")[-1] - # Uncomment/rework when a kernel with bias comes in - # bias_explicit = False - # if "bias" in suffix: - # bias_explicit = True - # kernel_args = 3 + int(suffix[4:]) - # suffix = kernel.split(".")[0].split("_")[-2] + suffix = kernel.split(".")[0].split("_")[-1] + if "bias" in suffix: + suffix = kernel.split(".")[0].split("_")[-2] B, M, N, K = suffix.split("x") old_kernel = f"matmul_like_{B}x{M}x{N}x{K}" if not old_kernel in line: continue if old_kernel in line and "func.func" in line: - data = urlopen(kernel).read().decode("utf-8") - data = data.split("\n") + num_args = line.count("arg") + with open(kernel, "r") as f: + data = f.readlines() idx_with_kernel_args = [ idx for idx, s in enumerate(data) if "func.func" in s ][0] kernel_args = data[idx_with_kernel_args].count("arg") - num_args = line.count("arg") if num_args != kernel_args: continue kernel_map[kernel] = line.strip().split(" ")[1][1:-7] @@ 
-218,11 +207,12 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
                     and "flow.executable" in line
                     and "private" in line
                 ):
-                    data = urlopen(kernel).read().decode("utf-8")
-                    data = data.split("\n")
+                    with open(kernel, "r") as f:
+                        data = f.readlines()
                     translation_info = data[0].split("#translation = ")[1].strip()
-                    data[10] = data[10].replace("#translation", translation_info)
-                    final_ir.append("\n".join(data[2:-3]))
+                    extract = "".join(data[2:-2])
+                    extract = extract.replace("#translation", translation_info)
+                    final_ir.append(extract)
                 final_ir.append(line)
 
     print("tk kernels added")
@@ -245,6 +235,7 @@ def compile_to_vmfb(
     flagset_keywords=[],
     debug=False,
     add_tk_kernels=False,
+    tk_kernels_dir=None,
     batch_size=1,
 ):
     if batch_size != 1 and batch_size != 8:
@@ -406,7 +397,7 @@ def compile_to_vmfb(
 
         flow_ir = flatbuffer_blob.decode("utf-8")
 
-        flow_ir_tk = replace_with_tk_kernels(flow_ir, batch_size)
+        flow_ir_tk = replace_with_tk_kernels(tk_kernels_dir, flow_ir, batch_size)
         module_str = "\n".join(flow_ir_tk)
         flags.pop()
         flags.extend(["--compile-from=flow"])
@@ -554,11 +545,11 @@ def get_schedulers(model_id):
         model_id,
         subfolder="scheduler",
     )
-    schedulers[
-        "EulerAncestralDiscrete"
-    ] = EulerAncestralDiscreteScheduler.from_pretrained(
-        model_id,
-        subfolder="scheduler",
+    schedulers["EulerAncestralDiscrete"] = (
+        EulerAncestralDiscreteScheduler.from_pretrained(
+            model_id,
+            subfolder="scheduler",
+        )
     )
     # schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained(
     #     model_id,
diff --git a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
index 626df59c..017244f6 100644
--- a/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
+++ b/models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py
@@ -376,4 +376,11 @@ def is_valid_file(arg):
     help="Flag to add compiled tk kernels.",
 )
 
+p.add_argument(
+    "--tk_kernels_dir",
+    default=None,
+    type=str,
+    help="Path to directory containing tk kernels sorted by batch size.",
+)
+
 args, unknown = p.parse_known_args()