From 9e4a2469b17c594a3c8544a602d313ab94e18319 Mon Sep 17 00:00:00 2001 From: aviator19941 Date: Mon, 4 Dec 2023 19:29:01 +0000 Subject: [PATCH] Move reusable functions to utils --- .../custom_models/sd_inference/clip_test.py | 24 +------------------ .../custom_models/sd_inference/unet_test.py | 23 +----------------- .../custom_models/sd_inference/utils.py | 23 ++++++++++++++++++ .../custom_models/sd_inference/vae_test.py | 23 +----------------- 4 files changed, 26 insertions(+), 67 deletions(-) create mode 100644 python/turbine_models/custom_models/sd_inference/utils.py diff --git a/python/turbine_models/custom_models/sd_inference/clip_test.py b/python/turbine_models/custom_models/sd_inference/clip_test.py index a89610b6a..de0a2d6e9 100644 --- a/python/turbine_models/custom_models/sd_inference/clip_test.py +++ b/python/turbine_models/custom_models/sd_inference/clip_test.py @@ -13,11 +13,11 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from utils import * import torch import torch._dynamo as dynamo from transformers import CLIPTextModel, CLIPTokenizer -import safetensors import argparse parser = argparse.ArgumentParser() @@ -44,22 +44,6 @@ prompt = ["a photograph of an astronaut riding a horse"] -def save_external_weights( - mapper, - model, - external_weights=None, - external_weight_file=None, -): - if external_weights is not None: - if external_weights == "safetensors": - mod_params = dict(model.named_parameters()) - for name in mod_params: - mapper["params." + name] = name - if external_weight_file: - safetensors.torch.save_file(mod_params, external_weight_file) - print("Saved params to", external_weight_file) - - def export_clip_model(args): # Load the tokenizer and text encoder to tokenize and encode the text. tokenizer = CLIPTokenizer.from_pretrained( @@ -131,12 +115,6 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)): exit() -def largest_error(array1, array2): - absolute_diff = np.abs(array1 - array2) - max_error = np.max(absolute_diff) - return max_error - - def run_clip_vmfb_comparison(args): config = ireert.Config("local-task") diff --git a/python/turbine_models/custom_models/sd_inference/unet_test.py b/python/turbine_models/custom_models/sd_inference/unet_test.py index 0c8ee7358..d349db04b 100644 --- a/python/turbine_models/custom_models/sd_inference/unet_test.py +++ b/python/turbine_models/custom_models/sd_inference/unet_test.py @@ -13,6 +13,7 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from utils import * import torch import torch._dynamo as dynamo from diffusers import UNet2DConditionModel @@ -64,22 +65,6 @@ def forward(self, sample, timestep, encoder_hidden_states): return noise_pred -def save_external_weights( - mapper, - model, - external_weights=None, - external_weight_file=None, -): - if external_weights is not None: - if external_weights == "safetensors": - mod_params = dict(model.named_parameters()) - for name in mod_params: - mapper["params." + name] = name - if external_weight_file: - safetensors.torch.save_file(mod_params, external_weight_file) - print("Saved params to", external_weight_file) - - def export_unet_model(args, unet_model): mapper = {} save_external_weights( @@ -227,12 +212,6 @@ def run_unet_vmfb_comparison(args): assert err < 9e-5 -def largest_error(array1, array2): - absolute_diff = np.abs(array1 - array2) - max_error = np.max(absolute_diff) - return max_error - - if __name__ == "__main__": args = parser.parse_args() unet_model = UnetModel(args) diff --git a/python/turbine_models/custom_models/sd_inference/utils.py b/python/turbine_models/custom_models/sd_inference/utils.py new file mode 100644 index 000000000..610331584 --- /dev/null +++ b/python/turbine_models/custom_models/sd_inference/utils.py @@ -0,0 +1,23 @@ +import numpy as np +import safetensors + +def save_external_weights( + mapper, + model, + external_weights=None, + external_weight_file=None, +): + if external_weights is not None: + if external_weights == "safetensors": + mod_params = dict(model.named_parameters()) + for name in mod_params: + mapper["params." + name] = name + if external_weight_file: + safetensors.torch.save_file(mod_params, external_weight_file) + print("Saved params to", external_weight_file) + + +def largest_error(array1, array2): + absolute_diff = np.abs(array1 - array2) + max_error = np.max(absolute_diff) + return max_error \ No newline at end of file diff --git a/python/turbine_models/custom_models/sd_inference/vae_test.py b/python/turbine_models/custom_models/sd_inference/vae_test.py index 09b6d739f..fc5f8eb95 100644 --- a/python/turbine_models/custom_models/sd_inference/vae_test.py +++ b/python/turbine_models/custom_models/sd_inference/vae_test.py @@ -13,6 +13,7 @@ from iree.compiler.ir import Context import numpy as np from shark_turbine.aot import * +from utils import * import torch import torch._dynamo as dynamo from diffusers import AutoencoderKL @@ -57,22 +58,6 @@ def forward(self, inp): return x -def save_external_weights( - mapper, - model, - external_weights=None, - external_weight_file=None, -): - if external_weights is not None: - if external_weights == "safetensors": - mod_params = dict(model.named_parameters()) - for name in mod_params: - mapper["params." + name] = name - if external_weight_file: - safetensors.torch.save_file(mod_params, external_weight_file) - print("Saved params to", external_weight_file) - - def export_vae_model(args, vae_model): mapper = {} save_external_weights( @@ -124,12 +109,6 @@ def main(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)): exit() -def largest_error(array1, array2): - absolute_diff = np.abs(array1 - array2) - max_error = np.max(absolute_diff) - return max_error - - def run_vae_vmfb_comparison(args, vae_model): config = ireert.Config("local-task")