Move reusable functions to utils
aviator19941 committed Dec 4, 2023
1 parent f4e046a commit 9e4a246
Showing 4 changed files with 26 additions and 67 deletions.
24 changes: 1 addition & 23 deletions python/turbine_models/custom_models/sd_inference/clip_test.py
@@ -13,11 +13,11 @@
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from utils import *
import torch
import torch._dynamo as dynamo
from transformers import CLIPTextModel, CLIPTokenizer

import safetensors
import argparse

parser = argparse.ArgumentParser()
@@ -44,22 +44,6 @@
prompt = ["a photograph of an astronaut riding a horse"]


def save_external_weights(
mapper,
model,
external_weights=None,
external_weight_file=None,
):
if external_weights is not None:
if external_weights == "safetensors":
mod_params = dict(model.named_parameters())
for name in mod_params:
mapper["params." + name] = name
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)
print("Saved params to", external_weight_file)


def export_clip_model(args):
# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained(
@@ -131,12 +115,6 @@ def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
exit()


def largest_error(array1, array2):
absolute_diff = np.abs(array1 - array2)
max_error = np.max(absolute_diff)
return max_error


def run_clip_vmfb_comparison(args):
config = ireert.Config("local-task")

23 changes: 1 addition & 22 deletions python/turbine_models/custom_models/sd_inference/unet_test.py
@@ -13,6 +13,7 @@
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from utils import *
import torch
import torch._dynamo as dynamo
from diffusers import UNet2DConditionModel
@@ -64,22 +65,6 @@ def forward(self, sample, timestep, encoder_hidden_states):
return noise_pred


def save_external_weights(
mapper,
model,
external_weights=None,
external_weight_file=None,
):
if external_weights is not None:
if external_weights == "safetensors":
mod_params = dict(model.named_parameters())
for name in mod_params:
mapper["params." + name] = name
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)
print("Saved params to", external_weight_file)


def export_unet_model(args, unet_model):
mapper = {}
save_external_weights(
@@ -227,12 +212,6 @@ def run_unet_vmfb_comparison(args):
assert err < 9e-5


def largest_error(array1, array2):
absolute_diff = np.abs(array1 - array2)
max_error = np.max(absolute_diff)
return max_error


if __name__ == "__main__":
args = parser.parse_args()
unet_model = UnetModel(args)
23 changes: 23 additions & 0 deletions python/turbine_models/custom_models/sd_inference/utils.py
@@ -0,0 +1,23 @@
import numpy as np
import safetensors

def save_external_weights(
mapper,
model,
external_weights=None,
external_weight_file=None,
):
if external_weights is not None:
if external_weights == "safetensors":
mod_params = dict(model.named_parameters())
for name in mod_params:
mapper["params." + name] = name
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)
print("Saved params to", external_weight_file)


def largest_error(array1, array2):
absolute_diff = np.abs(array1 - array2)
max_error = np.max(absolute_diff)
return max_error
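For reference, below is a minimal sketch of how a test module consumes the consolidated helpers after this change; it assumes it is run from the sd_inference directory so utils.py is importable. The save_external_weights and largest_error signatures match the new utils.py above, while the Linear model and the example_weights.safetensors path are hypothetical stand-ins (the real tests pass the CLIP/UNet/VAE modules, and unet_test.py asserts the error stays below 9e-5).

import numpy as np
import torch
import safetensors.torch  # load the torch submodule so the safetensors.torch.save_file call in utils resolves
from utils import save_external_weights, largest_error

# Hypothetical stand-in model; the real tests use the CLIP/UNet/VAE torch modules.
model = torch.nn.Linear(4, 4)

# Record parameter names in the mapper and write the weights out as a
# .safetensors file, just as the removed per-file copies did.
mapper = {}
save_external_weights(
    mapper,
    model,
    external_weights="safetensors",
    external_weight_file="example_weights.safetensors",  # hypothetical path
)

# Compare two arrays by their maximum absolute difference; unet_test.py
# uses this to compare torch and vmfb outputs against a 9e-5 tolerance.
a = np.ones((2, 2), dtype=np.float32)
b = a + 1e-6
err = largest_error(a, b)
print("largest error:", err)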
23 changes: 1 addition & 22 deletions python/turbine_models/custom_models/sd_inference/vae_test.py
@@ -13,6 +13,7 @@
from iree.compiler.ir import Context
import numpy as np
from shark_turbine.aot import *
from utils import *
import torch
import torch._dynamo as dynamo
from diffusers import AutoencoderKL
@@ -57,22 +58,6 @@ def forward(self, inp):
return x


def save_external_weights(
mapper,
model,
external_weights=None,
external_weight_file=None,
):
if external_weights is not None:
if external_weights == "safetensors":
mod_params = dict(model.named_parameters())
for name in mod_params:
mapper["params." + name] = name
if external_weight_file:
safetensors.torch.save_file(mod_params, external_weight_file)
print("Saved params to", external_weight_file)


def export_vae_model(args, vae_model):
mapper = {}
save_external_weights(
@@ -124,12 +109,6 @@ def main(self, inp=AbstractTensor(1, 4, 64, 64, dtype=torch.float32)):
exit()


def largest_error(array1, array2):
absolute_diff = np.abs(array1 - array2)
max_error = np.max(absolute_diff)
return max_error


def run_vae_vmfb_comparison(args, vae_model):
config = ireert.Config("local-task")

