Enable running mmdit in onnx for the sd3 pipeline
ZchiPitt committed Sep 13, 2024
1 parent 32f5265 commit 0b70f6a
Showing 5 changed files with 203 additions and 5 deletions.
97 changes: 96 additions & 1 deletion models/turbine_models/custom_models/pipeline_base.py
@@ -16,6 +16,8 @@
)
from turbine_models.utils.sdxl_benchmark import run_benchmark
from turbine_models.model_runner import vmfbRunner
import onnxruntime

from PIL import Image
import gc
@@ -74,6 +76,59 @@ def merge_export_arg(model_map, arg, arg_name):
# item = ast.literal_eval(item)
# return out

class OnnxPipelineComponent:
def __init__(
self,
printer,
dest_type="numpy",
dest_dtype="fp16",
):
self.ort_session = None
self.onnx_file_path = None
self.ep = None
self.dest_type = dest_type
self.dest_dtype = dest_dtype
self.printer = printer
self.supported_dtypes = ["fp32"]
self.default_dtype = "fp32"
self.used_dtype = dest_dtype if dest_dtype in self.supported_dtypes else self.default_dtype

def load(
self,
onnx_file_path: str,
ep="CPUExecutionProvider"
):
self.onnx_file_path = onnx_file_path
self.ep = ep

self.ort_session = onnxruntime.InferenceSession(onnx_file_path, providers=[ep])
self.printer.print(
f"Loading {onnx_file_path} into onnxruntime with {ep}."
)

def unload(self):
self.ort_session = None
gc.collect()

# Convert inputs to host numpy arrays in the dtype the ONNX session expects.
def _convert_inputs(self, inputs):
for iname in inputs.keys():
inp = inputs[iname]
if isinstance(inp, ireert.DeviceArray):
inputs[iname] = inp.to_host()
inputs[iname] = inputs[iname].astype(np_dtypes[self.used_dtype])
return inputs

def _convert_output(self, output):
return output.astype(np_dtypes[self.dest_dtype])

def __call__(self, inputs: dict):
converted_inputs = self._convert_inputs(inputs)
out = self.ort_session.run(
None,
converted_inputs,
)[0]
return self._convert_output(out)
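
For orientation, a minimal usage sketch of the component above; the file path and tensor shapes are illustrative placeholders, not values from this commit, and it assumes this module's Printer and np_dtypes are in scope.

import time
import numpy as np

printer = Printer(verbose=True, start_time=time.time(), print_time=True)
runner = OnnxPipelineComponent(printer, dest_type="numpy", dest_dtype="fp16")
runner.load("mmdit_optimized.onnx", ep="CPUExecutionProvider")

# "fp16" is not in supported_dtypes, so inputs are cast to the fp32 fallback
# before ort_session.run(); the first output is cast back to dest_dtype (fp16).
noise_pred = runner(
    {
        "hidden_states": np.zeros((2, 16, 64, 64), dtype=np.float16),  # placeholder shapes
        "encoder_hidden_states": np.zeros((2, 154, 4096), dtype=np.float16),
        "pooled_projections": np.zeros((2, 2048), dtype=np.float16),
        "timestep": np.zeros((2,), dtype=np.float16),
    }
)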



class PipelineComponent:
"""
@@ -268,6 +323,16 @@ def __call__(self, function_name, inputs: list):

# def _run_and_validate(self, iree_fn, torch_fn, inputs: list)

class Bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
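
These are raw ANSI escape sequences; a quick sketch of the wrap-and-reset pattern Printer uses below:

# Prints bold yellow text, then restores the terminal's default style.
print(Bcolors.BOLD + Bcolors.WARNING + "resetting printer clock" + Bcolors.ENDC)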

class Printer:
def __init__(self, verbose, start_time, print_time):
@@ -284,24 +349,31 @@ def __init__(self, verbose, start_time, print_time):

def reset(self):
if self.print_time:
print(Bcolors.BOLD + Bcolors.WARNING)
if self.verbose:
self.print("Will now reset clock for printer to 0.0 [s].")
self.last_print = time.time()
self.start_time = time.time()
if self.verbose:
self.print("Clock for printer reset to t = 0.0 [s].")
print(Bcolors.ENDC, end='')

def print(self, message):
if self.verbose:
# Print something like "[ts=0.123s] 'message'"
print(Bcolors.BOLD + Bcolors.OKCYAN)
if self.print_time:
time_now = time.time()
print(
f"[ts={time_now - self.start_time:.3f}s] {message}"
)
self.last_print = time_now
else:
print(f"{message}")
print(Bcolors.ENDC, end='')
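
With the new format, verbose output looks like this (a sketch; the timing value is illustrative, and time is imported at module scope):

p = Printer(verbose=True, start_time=time.time(), print_time=True)
p.print("Loaded mmdit_optimized.onnx into onnxruntime with CPUExecutionProvider.")
# -> [ts=0.002s] Loaded mmdit_optimized.onnx into onnxruntime with CPUExecutionProvider.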


class TurbinePipelineBase:
@@ -359,6 +431,8 @@ def __init__(
ireec_flags: str | dict[str] = None,
precision: str | dict[str] = "fp16",
attn_spec: str | dict[str] = None,
onnx_model_path: str | dict[str] = None,
run_onnx_mmdit: bool = False,
decomp_attn: bool | dict[bool] = False,
external_weights: str | dict[str] = None,
pipeline_dir: str = "./shark_vmfbs",
@@ -372,6 +446,7 @@ def __init__(
self.map = model_map
self.verbose = verbose
self.printer = Printer(self.verbose, time.time(), True)
self.run_onnx_mmdit = run_onnx_mmdit
if isinstance(device, dict):
assert isinstance(
target, dict
@@ -396,6 +471,7 @@ def __init__(
map_arguments = {
"ireec_flags": ireec_flags,
"precision": precision,
"onnx_model_path": onnx_model_path,
"attn_spec": attn_spec,
"decomp_attn": decomp_attn,
"external_weights": external_weights,
@@ -412,6 +488,7 @@ def __init__(
self.map = merge_arg_into_map(
self.map, torch_dtypes[self.map[submodel]["precision"]], "torch_dtype"
)
for arg in common_export_args.keys():
for submodel in self.map.keys():
self.map[submodel].get("export_args", {})[arg] = self.map[submodel].get(
@@ -761,6 +838,8 @@ def load_map(self):
self.load_submodel(submodel)

def load_submodel(self, submodel):
if not self.map[submodel].get("vmfb"):
raise ValueError(f"VMFB not found for {submodel}.")
if not self.map[submodel].get("weights") and self.map[submodel].get(
@@ -783,6 +862,22 @@ def load_submodel(self, submodel):
)
setattr(self, submodel, self.map[submodel]["runner"])

# Attach an ONNX runner when mmdit should run through onnxruntime.
if self.run_onnx_mmdit and submodel == "mmdit":
dest_type = "numpy"
dest_dtype = self.map[submodel]["precision"]
onnx_runner = OnnxPipelineComponent(
printer=self.printer,
dest_type=dest_type,
dest_dtype=dest_dtype
)
ep = "CPUExecutionProvider"
onnx_runner.load(
onnx_file_path=self.map[submodel]["onnx_model_path"],
ep=ep
)
setattr(self, submodel + "_onnx", onnx_runner)

def unload_submodel(self, submodel):
self.map[submodel]["runner"].unload()
self.map[submodel]["vmfb"] = None
@@ -159,6 +159,7 @@ def export_mmdit_model(
attn_spec=None,
input_mlir=None,
weights_only=False,
onnx_model_path=None,
):
dtype = torch.float16 if precision == "fp16" else torch.float32
mmdit_model = MMDiTModel(
16 changes: 16 additions & 0 deletions models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py
@@ -445,4 +445,20 @@ def is_valid_file(arg):
)


##############################################################################
# ONNX Options
##############################################################################
p.add_argument(
"--mmdit_onnx_model_path",
type=str,
default="C:/Users/chiz/work/sd3/mmdit/fp32/mmdit_optimized.onnx",
help="Path to mmdit onnx model",
)

p.add_argument(
"--run_onnx_mmdit",
action="store_true",
help="Run MMDiT in onnx",
)

args, unknown = p.parse_known_args()
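
A quick sketch of how the two new flags surface on the shared parser; the argv values and the args_demo name are placeholders:

args_demo, _ = p.parse_known_args(
    ["--run_onnx_mmdit", "--mmdit_onnx_model_path=/tmp/mmdit_optimized.onnx"]
)
assert args_demo.run_onnx_mmdit is True
assert args_demo.mmdit_onnx_model_path.endswith("mmdit_optimized.onnx")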
40 changes: 36 additions & 4 deletions models/turbine_models/custom_models/sd_inference/sd_pipeline.py
@@ -41,6 +41,7 @@
import numpy as np
import time
from datetime import datetime as dt

# These are arguments common among submodel exports.
# They are expected to be populated in two steps:
@@ -227,6 +228,8 @@ def __init__(
target: str | dict[str],
ireec_flags: str | dict[str] = None,
attn_spec: str | dict[str] = None,
onnx_model_path: str | dict[str] = None,
run_onnx_mmdit: bool = False,
decomp_attn: bool | dict[bool] = False,
pipeline_dir: str = "./shark_vmfbs",
external_weights_dir: str = "./shark_weights",
@@ -287,6 +290,8 @@ def __init__(
ireec_flags,
precision,
attn_spec,
onnx_model_path,
run_onnx_mmdit,
decomp_attn,
external_weights,
pipeline_dir,
@@ -419,6 +424,7 @@ def load_scheduler(
scheduler_id: str = None,
steps: int = 30,
):
if not self.cpu_scheduling:
if self.is_sd3:
export_fn = sd3_schedulers.export_scheduler_model
@@ -460,6 +466,7 @@ def load_scheduler(
self.pipeline_dir,
utils.create_safe_name(self.base_model_name, scheduler_uid) + ".vmfb",
)
if not os.path.exists(scheduler_path):
self.export_submodel("scheduler")
else:
@@ -720,10 +727,26 @@ def _produce_latents_sd3(
pooled_prompt_embeds,
t,
]
if hasattr(self, 'mmdit_onnx'):
latent_model_input = latent_model_input.to_host()
batch = latent_model_input.shape[0]
batched_t = np.repeat(t.to_host(), batch)
noise_pred = self.mmdit_onnx(
{
"hidden_states": latent_model_input,
"encoder_hidden_states": prompt_embeds,
"pooled_projections" : pooled_prompt_embeds,
"timestep": batched_t,

}
)
else:
noise_pred = self.mmdit(
"run_forward",
mmdit_inputs,
)
latents = self.scheduler(
"run_step", [noise_pred, t, latents, guidance_scale, steps_list_gpu[i]]
)
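
One detail worth a worked example: before calling the ONNX session, the scalar timestep is broadcast across the batch (cond + uncond latents under classifier-free guidance). A self-contained illustration of that np.repeat call:

import numpy as np

t = np.asarray(981.0, dtype=np.float32)  # scalar timestep from the scheduler
batch = 2                                # latent_model_input.shape[0]
batched_t = np.repeat(t, batch)          # -> array([981., 981.], dtype=float32)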
@@ -754,6 +777,7 @@ def generate_images(
prompt = ""

self.cpu_scheduling = cpu_scheduling
if steps and needs_new_scheduler:
self.num_inference_steps = steps
self.load_scheduler(scheduler_id, steps)
@@ -884,6 +908,9 @@ def numpy_to_pil_image(images):
"mmdit": args.mmdit_spec if args.mmdit_spec else args.attn_spec,
"vae": args.vae_spec if args.vae_spec else args.attn_spec,
}
onnx_model_paths = {
"mmdit": args.mmdit_onnx_model_path
}
if not args.pipeline_dir:
args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "")
benchmark = {}
Expand Down Expand Up @@ -913,6 +940,7 @@ def numpy_to_pil_image(images):
),
"vae": args.vae_decomp_attn if args.vae_decomp_attn else args.decomp_attn,
}
sd_pipe = SharkSDPipeline(
args.hf_model_name,
args.height,
@@ -924,6 +952,8 @@ def numpy_to_pil_image(images):
targets,
ireec_flags,
specs,
onnx_model_paths,
args.run_onnx_mmdit,
args.decomp_attn,
args.pipeline_dir,
args.external_weights_dir,
@@ -937,8 +967,10 @@ def numpy_to_pil_image(images):
args.verbose,
save_outputs=save_outputs,
)
sd_pipe.prepare_all()
sd_pipe.load_map()
sd_pipe.generate_images(
args.prompt,
args.negative_prompt,
54 changes: 54 additions & 0 deletions run.py
@@ -0,0 +1,54 @@
import os

class bcolors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKCYAN = '\033[96m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'

def print_cmd(cmd, pipeline, flags):
print(bcolors.BOLD + bcolors.OKGREEN)
print(cmd, pipeline)
for f in flags:
print("\t", f)
print(bcolors.ENDC)

cmd = "python"
pipeline = "models/turbine_models/custom_models/sd_inference/sd_pipeline.py"
prompt = "Photo of a ultra realistic sailing ship, dramatic light, pale sunrise, cinematic lighting, battered, low angle, trending on artstation, 4k, hyper realistic, focused, extreme details, unreal engine 5, cinematic, masterpiece, art by studio ghibli, intricate artwork by john william turner"
height = 512
width = 512
mmdit_onnx_model_path = "C:/Users/chiz/work/sd3/mmdit/fp32/mmdit_optimized.onnx"
flags = [
"--hf_model_name=stabilityai/stable-diffusion-3-medium-diffusers",
f"--height={height}",
f"--width={width}",
"--clip_device=local-task",
"--clip_precision=fp16",
"--clip_target=znver4",
"--clip_decomp_attn",
"--mmdit_precision=fp16",
"--mmdit_device=rocm-legacy://0",
"--mmdit_target=gfx1150",
'''--mmdit_flags="masked_attention" ''',
"--run_onnx_mmdit",
f'''--mmdit_onnx_model_path="{mmdit_onnx_model_path}" ''',
"--vae_device=rocm-legacy://0",
"--vae_precision=fp16",
"--vae_target=gfx1150",
'''--vae_flags="masked_attention" ''',
"--external_weights=safetensors",
"--num_inference_steps=28",
"--verbose",
f'''--prompt="{prompt}" '''
]

print_cmd(cmd, pipeline, flags)

final_cmd = ' '.join([cmd, pipeline]+flags)
os.system(final_cmd)
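
os.system with a hand-joined string works here, but it relies on the embedded quotes and trailing spaces baked into the flags. A hedged alternative sketch: pass argv as a list via subprocess so paths and the prompt need no manual quoting (the flags would then drop their embedded quote characters).

import subprocess

# check=True raises CalledProcessError if the pipeline exits nonzero.
subprocess.run([cmd, pipeline, *flags], check=True)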
