Skip to content

Commit

Permalink
[tk kernel] Add support to match kernel with number of arguments and …
Browse files Browse the repository at this point in the history
…update kernel links
  • Loading branch information
nithinsubbiah committed Jul 25, 2024
1 parent 6f16731 commit 5876436
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions models/turbine_models/custom_models/sd_inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,12 @@ def iree_backend_map(device):
def replace_with_tk_kernels(flow_dialect_ir, batch_size):
if batch_size == 8:
kernels = [
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_16x1024x10240x1280.mlir"
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs8/tk_gemm_fused_16x1024x10240x1280.mlir"
]
if batch_size == 1:
kernels = [
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/tk_gemm_fused_2x1024x10240x1280.mlir"
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir",
"https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir"
]

# Replace all calls to old kernel with new kernel
Expand All @@ -178,20 +179,24 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size):
for line in base:
for kernel in kernels:
suffix = kernel.split("/")[-1].split(".")[0].split("_")[-1]
bias_explicit = False
if "bias" in suffix:
bias_explicit = True
kernel_args = 3 + int(suffix[4:])
suffix = kernel.split(".")[0].split("_")[-2]
# Uncomment/rework when a kernel with bias comes in
# bias_explicit = False
# if "bias" in suffix:
# bias_explicit = True
# kernel_args = 3 + int(suffix[4:])
# suffix = kernel.split(".")[0].split("_")[-2]
B, M, N, K = suffix.split("x")
old_kernel = f"matmul_like_{B}x{M}x{N}x{K}"
if not old_kernel in line:
continue
if old_kernel in line and "func.func" in line:
if bias_explicit:
num_args = line.count("arg")
if num_args != kernel_args:
continue
data = urlopen(kernel).read().decode("utf-8")
data = data.split("\n")
idx_with_kernel_args = [idx for idx, s in enumerate(data) if 'func.func' in s][0]
kernel_args = data[idx_with_kernel_args].count('arg')
num_args = line.count("arg")
if num_args != kernel_args:
continue
kernel_map[kernel] = line.strip().split(" ")[1][1:-7]
prefix_map[kernel] = kernel_map[kernel].split(old_kernel)[0][:-1]
if (
Expand Down

0 comments on commit 5876436

Please sign in to comment.