diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 2044863d0..5bee5a097 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -163,7 +163,7 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size): if batch_size == 1: kernels = [ "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x10240x1280.mlir", - "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir" + "https://raw.githubusercontent.com/nod-ai/sdxl-scripts/main/int8-model/tk_kernels/bs1/tk_gemm_fused_2x1024x1280x5120.mlir", ] # Replace all calls to old kernel with new kernel @@ -192,8 +192,10 @@ def replace_with_tk_kernels(flow_dialect_ir, batch_size): if old_kernel in line and "func.func" in line: data = urlopen(kernel).read().decode("utf-8") data = data.split("\n") - idx_with_kernel_args = [idx for idx, s in enumerate(data) if 'func.func' in s][0] - kernel_args = data[idx_with_kernel_args].count('arg') + idx_with_kernel_args = [ + idx for idx, s in enumerate(data) if "func.func" in s + ][0] + kernel_args = data[idx_with_kernel_args].count("arg") num_args = line.count("arg") if num_args != kernel_args: continue @@ -552,11 +554,11 @@ def get_schedulers(model_id): model_id, subfolder="scheduler", ) - schedulers["EulerAncestralDiscrete"] = ( - EulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) + schedulers[ + "EulerAncestralDiscrete" + ] = EulerAncestralDiscreteScheduler.from_pretrained( + model_id, + subfolder="scheduler", ) # schedulers["DPMSolverSDE"] = DPMSolverSDEScheduler.from_pretrained( # model_id,