From 32f526516be87a3b549949d770bf562910f92d5f Mon Sep 17 00:00:00 2001
From: Chi Zhang
Date: Thu, 12 Sep 2024 18:40:05 -0700
Subject: [PATCH] Add scripts to export MMDiT and VAE to ONNX format

---
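Two standalone export scripts: sd3_mmdit_onnx.py wraps SD3's
SD3Transformer2DModel and exports it with torch.onnx.export, and
sd3_vae_onnx.py does the same for the VAE decoder. Note that both
scripts hard-code a local Windows output directory (file_prefix), and
the __main__ blocks pin batch size, resolution, precision, and max
length instead of reading them from sd3_cmd_opts; adjust those before
running elsewhere.

A quick smoke test for the exported MMDiT model. This is a minimal
sketch, not part of the patch: it assumes onnxruntime is installed,
and "sd3_mmdit.onnx" stands in for the path the script prints. Shapes,
dtypes, and input/output names follow the export call below; the batch
dim is doubled to 2 for classifier-free guidance:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(
        "sd3_mmdit.onnx", providers=["CPUExecutionProvider"]
    )
    feeds = {
        "hidden_states": np.zeros((2, 16, 64, 64), dtype=np.float16),
        "encoder_hidden_states": np.zeros((2, 154, 4096), dtype=np.float16),
        "pooled_projections": np.zeros((2, 2048), dtype=np.float16),
        "timestep": np.zeros((2,), dtype=np.float16),
    }
    noise_pred = sess.run(["sample_out"], feeds)[0]
    print(noise_pred.shape)  # expected: (2, 16, 64, 64)
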
 .../sd3_inference/sd3_mmdit_onnx.py           | 130 ++++++++++++++++++
 .../sd3_inference/sd3_vae_onnx.py             | 130 ++++++++++++++++++
 2 files changed, 260 insertions(+)
 create mode 100644 models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py
 create mode 100644 models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py

diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py
new file mode 100644
index 000000000..fc3a53c06
--- /dev/null
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py
@@ -0,0 +1,130 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import copy
+import os
+import sys
+import math
+
+import numpy as np
+from shark_turbine.aot import *
+
+from turbine_models.custom_models.sd_inference import utils
+import torch
+import torch._dynamo as dynamo
+from diffusers import SD3Transformer2DModel
+
+
+class MMDiTModel(torch.nn.Module):
+    def __init__(
+        self,
+        hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers",
+        dtype=torch.float16,
+    ):
+        super().__init__()
+        self.mmdit = SD3Transformer2DModel.from_pretrained(
+            hf_model_name,
+            subfolder="transformer",
+            torch_dtype=dtype,
+            low_cpu_mem_usage=False,
+        )
+
+    def forward(
+        self,
+        hidden_states,
+        encoder_hidden_states,
+        pooled_projections,
+        timestep,
+    ):
+        # timestep.expand(hidden_states.shape[0])
+        noise_pred = self.mmdit(
+            hidden_states,
+            encoder_hidden_states,
+            pooled_projections,
+            timestep,
+            return_dict=False,
+        )[0]
+        return noise_pred
+
+
+@torch.no_grad()
+def export_mmdit_model(
+    hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers",
+    batch_size=1,
+    height=512,
+    width=512,
+    precision="fp16",
+    max_length=77,
+):
+    dtype = torch.float16 if precision == "fp16" else torch.float32
+    mmdit_model = MMDiTModel(
+        dtype=dtype,
+    )
+    file_prefix = "C:/Users/chiz/work/sd3/mmdit/exported/"
+    safe_name = file_prefix + utils.create_safe_name(
+        hf_model_name,
+        f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_mmdit",
+    ) + ".onnx"
+    print(safe_name)
+
+    do_classifier_free_guidance = True
+    init_batch_dim = 2 if do_classifier_free_guidance else 1
+    batch_size = batch_size * init_batch_dim
+    hidden_states_shape = (
+        batch_size,
+        16,
+        height // 8,
+        width // 8,
+    )
+    encoder_hidden_states_shape = (batch_size, 154, 4096)
+    pooled_projections_shape = (batch_size, 2048)
+    hidden_states = torch.empty(hidden_states_shape, dtype=dtype)
+    encoder_hidden_states = torch.empty(encoder_hidden_states_shape, dtype=dtype)
+    pooled_projections = torch.empty(pooled_projections_shape, dtype=dtype)
+    timestep = torch.empty(batch_size, dtype=dtype)
+    # mmdit_model(hidden_states, encoder_hidden_states, pooled_projections, timestep)
+
+    torch.onnx.export(
+        mmdit_model,  # model being run
+        (
+            hidden_states,
+            encoder_hidden_states,
+            pooled_projections,
+            timestep,
+        ),  # model input (or a tuple for multiple inputs)
+        safe_name,  # where to save the model (can be a file or file-like object)
+        export_params=True,  # store the trained parameter weights inside the model file
+        opset_version=17,  # the ONNX version to export the model to
+        do_constant_folding=True,  # whether to execute constant folding for optimization
+        input_names=[
+            "hidden_states",
+            "encoder_hidden_states",
+            "pooled_projections",
+            "timestep",
+        ],  # the model's input names
+        output_names=[
+            "sample_out",
+        ],  # the model's output names
+    )
+    return safe_name
+
+
+if __name__ == "__main__":
+    import logging
+
+    logging.basicConfig(level=logging.DEBUG)
+    from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
+
+    onnx_model_name = export_mmdit_model(
+        args.hf_model_name,
+        1,  # args.batch_size,
+        512,  # args.height,
+        512,  # args.width,
+        "fp16",  # args.precision,
+        77,  # args.max_length,
+    )
+
+    print("Saved to", onnx_model_name)
diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py
new file mode 100644
index 000000000..5d97c623c
--- /dev/null
+++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py
@@ -0,0 +1,130 @@
+# Copyright 2023 Nod Labs, Inc
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import copy
+import os
+import sys
+
+from iree import runtime as ireert
+from iree.compiler.ir import Context
+import numpy as np
+from shark_turbine.aot import *
+from shark_turbine.dynamo.passes import (
+    DEFAULT_DECOMPOSITIONS,
+)
+from turbine_models.custom_models.sd_inference import utils
+import torch
+import torch._dynamo as dynamo
+from diffusers import AutoencoderKL
+
+
+class VaeModel(torch.nn.Module):
+    def __init__(
+        self,
+        hf_model_name,
+    ):
+        super().__init__()
+        self.vae = AutoencoderKL.from_pretrained(
+            hf_model_name,
+            subfolder="vae",
+        )
+
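+    # Decode latents to an image in [0, 1]. SD3's VAE config carries
+    # scaling_factor and shift_factor, so latents are un-scaled and
+    # un-shifted here before calling decode().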
+    def forward(self, inp):
+        inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+        image = self.vae.decode(inp, return_dict=False)[0]
+        image = image.float()
+        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
+        return image
+
+    # def decode(self, inp):
+    #     inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+    #     image = self.vae.decode(inp, return_dict=False)[0]
+    #     image = image.float()
+    #     image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
+    #     return image
+
+    # def encode(self, inp):
+    #     image_np = inp / 255.0
+    #     image_np = np.moveaxis(image_np, 2, 0)
+    #     batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0)
+    #     image_torch = torch.from_numpy(batch_images)
+    #     image_torch = 2.0 * image_torch - 1.0
+    #     image_torch = image_torch
+    #     latent = self.vae.encode(image_torch)
+    #     return latent
+
+
+def export_vae_model(
+    vae_model,
+    hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers",
+    batch_size=1,
+    height=512,
+    width=512,
+    precision="fp32",
+):
+    dtype = torch.float16 if precision == "fp16" else torch.float32
+    file_prefix = "C:/Users/chiz/work/sd3/vae_decoder/exported/"
+    safe_name = file_prefix + utils.create_safe_name(
+        hf_model_name,
+        f"_bs{batch_size}_{height}x{width}_{precision}_vae",
+    ) + ".onnx"
+    print(safe_name)
+
+    if dtype == torch.float16:
+        vae_model = vae_model.half()
+
+    # input_image_shape = (height, width, 3)
+    input_latents_shape = (batch_size, 16, height // 8, width // 8)
+    input_latents = torch.empty(input_latents_shape, dtype=dtype)
+    # encode_args = [
+    #     torch.empty(
+    #         input_image_shape,
+    #         dtype=torch.float32,
+    #     )
+    # ]
+    # decode_args = [
+    #     torch.empty(
+    #         input_latents_shape,
+    #         dtype=dtype,
+    #     )
+    # ]
+
+    torch.onnx.export(
+        vae_model,  # model being run
+        (
+            input_latents,
+        ),  # model input (or a tuple for multiple inputs)
+        safe_name,  # where to save the model (can be a file or file-like object)
+        export_params=True,  # store the trained parameter weights inside the model file
+        opset_version=17,  # the ONNX version to export the model to
+        do_constant_folding=True,  # whether to execute constant folding for optimization
+        input_names=[
+            "input_latents",
+        ],  # the model's input names
+        output_names=[
+            "sample_out",
+        ],  # the model's output names
+    )
+    return safe_name
+
+
+if __name__ == "__main__":
+    from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
+
+    vae_model = VaeModel(
+        args.hf_model_name,
+    )
+    onnx_model_name = export_vae_model(
+        vae_model,
+        args.hf_model_name,
+        1,  # args.batch_size,
+        512,  # height=args.height,
+        512,  # width=args.width,
+        "fp32",  # precision=args.precision
+    )
+    print("Saved to", onnx_model_name)
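
A matching smoke test for the exported VAE decoder, again a sketch and
not part of the patch: it assumes onnxruntime is installed, uses the
fp32/512x512 defaults from __main__ above, and "sd3_vae.onnx" stands in
for the printed output path:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession(
        "sd3_vae.onnx", providers=["CPUExecutionProvider"]
    )
    latents = np.zeros((1, 16, 64, 64), dtype=np.float32)
    image = sess.run(["sample_out"], {"input_latents": latents})[0]
    # forward() clamps to [0, 1] and drops the batch dim: expect (3, 512, 512)
    print(image.shape, image.min(), image.max())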