diff --git a/models/turbine_models/custom_models/sd_inference/clip.py b/models/turbine_models/custom_models/sd_inference/clip.py
index 4cc5f91dd..17999ab45 100644
--- a/models/turbine_models/custom_models/sd_inference/clip.py
+++ b/models/turbine_models/custom_models/sd_inference/clip.py
@@ -5,17 +5,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import os
-import sys
+import re
 
-from iree import runtime as ireert
-import iree.compiler as ireec
 from iree.compiler.ir import Context
-import numpy as np
 from shark_turbine.aot import *
 from turbine_models.custom_models.sd_inference import utils
 import torch
-import torch._dynamo as dynamo
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor
 from turbine_models.turbine_tank import turbine_tank
 
 import argparse
@@ -60,37 +56,77 @@ def export_clip_model(
     max_alloc=None,
     upload_ir=False,
 ):
-    # Load the tokenizer and text encoder to tokenize and encode the text.
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
+    input_len = 77
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
 
-    text_encoder_model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_encoder_model = T5Model.from_pretrained(hf_model_name)
+        input_len = 512
+
+    else:
+        # TODO: Add better filtering mechanism for things that require CLIPProcessor
+        if "openai" in hf_model_name:
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            input_len = 10
+        else:
+            # Load the tokenizer and text encoder to tokenize and encode the text.
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+            hf_subfolder = "text_encoder"
+
+        text_encoder_model = CLIPTextModel.from_pretrained(
+            hf_model_name,
+            subfolder=hf_subfolder,
+            token=hf_auth_token,
+        )
 
     mapper = {}
     utils.save_external_weights(
         mapper, text_encoder_model, external_weights, external_weight_path
     )
 
-    class CompiledClip(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                text_encoder_model,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(text_encoder_model)
+    if "google/t5" in hf_model_name:
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
+
+            def main(
+                self,
+                inp=AbstractTensor(1, input_len, dtype=torch.int64),
+                decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64),
+            ):
+                return jittable(text_encoder_model.forward)(
+                    input_ids=inp, decoder_input_ids=decoder_input_ids
+                )
+
+    else:
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
 
-        def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
-            return jittable(text_encoder_model.forward)(inp)
+            def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
+                return jittable(text_encoder_model.forward)(input_ids=inp)
 
     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
     inst = CompiledClip(context=Context(), import_to=import_to)
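
Note on the branching above: the tokenizer, encoder, and static sequence length now depend on the model family, and the T5 path exports a two-input main() (encoder ids plus decoder ids). Below is a minimal sketch, assuming only the public transformers API, of how each branch produces the fixed-shape int64 ids the exported module expects; the prompt string and the assertions are illustrative only and not part of the patch.

    import torch
    from transformers import CLIPTokenizer, T5Tokenizer

    prompt = "a photograph of an astronaut riding a horse"

    # SD-style checkpoints keep CLIPTokenizer in a "tokenizer" subfolder; the export pads to 77 tokens.
    clip_tok = CLIPTokenizer.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="tokenizer"
    )
    clip_ids = clip_tok(
        prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
    ).input_ids
    assert clip_ids.shape == (1, 77) and clip_ids.dtype == torch.int64

    # T5 checkpoints keep the tokenizer at the repo root; the export pads to 512 tokens
    # and reuses the same ids as decoder_input_ids.
    t5_tok = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
    t5_ids = t5_tok(
        prompt, padding="max_length", max_length=512, truncation=True, return_tensors="pt"
    ).input_ids
    assert t5_ids.shape == (1, 512)
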
diff --git a/models/turbine_models/custom_models/sd_inference/clip_runner.py b/models/turbine_models/custom_models/sd_inference/clip_runner.py
index b7f046e2e..c72b5e221 100644
--- a/models/turbine_models/custom_models/sd_inference/clip_runner.py
+++ b/models/turbine_models/custom_models/sd_inference/clip_runner.py
@@ -3,6 +3,7 @@
 from transformers import CLIPTokenizer
 from iree import runtime as ireert
 import torch
+from PIL import Image
 
 parser = argparse.ArgumentParser()
 
@@ -52,49 +53,125 @@ def run_clip(
 ):
     runner = vmfbRunner(device, vmfb_path, external_weight_path)
 
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+    # TODO: Integrate with HFTransformerBuilder
+    else:
+        if "openai" in hf_model_name:
+            from transformers import CLIPProcessor
+            import requests
+
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "tokenizer"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
     inp = [ireert.asdevicearray(runner.config.device, example_input)]
+    if "google/t5" in hf_model_name:
+        inp += [ireert.asdevicearray(runner.config.device, example_input)]
     results = runner.ctx.modules.compiled_clip["main"](*inp)
     return results
 
 
 def run_torch_clip(hf_model_name, hf_auth_token, prompt):
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        model = T5Model.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
     # TODO: Integrate with HFTransformerBuilder
-    from transformers import CLIPTextModel
+    else:
+        if hf_model_name == "openai/clip-vit-large-patch14":
+            from transformers import CLIPProcessor
+            import requests
 
-    model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "text_encoder"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
-    results = model.forward(example_input)[0]
+    if "google/t5" in hf_model_name:
+        results = model.forward(example_input, decoder_input_ids=example_input)[0]
+    else:
+        results = model.forward(example_input)[0]
     np_torch_output = results.detach().cpu().numpy()
     return np_torch_output
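
The runner mirrors that split: the T5 export's main() takes (input_ids, decoder_input_ids), so run_clip pushes the same device array twice, while the CLIP variants take a single tensor. A hedged sketch of the invocation path, reusing t5_ids from the sketch above; the vmfbRunner import path is assumed to match the one clip_runner.py already uses, and the vmfb file name and "cpu" device string are placeholders.

    from iree import runtime as ireert
    from turbine_models.model_runner import vmfbRunner  # assumed import, as in clip_runner.py

    # Run a previously exported T5 text-encoder module on CPU.
    runner = vmfbRunner("cpu", "t5_v1_1_small_clip.vmfb", None)
    ids = ireert.asdevicearray(runner.config.device, t5_ids)
    inputs = [ids, ids]  # encoder ids + decoder ids for the T5 export; CLIP exports take just [ids]
    results = runner.ctx.modules.compiled_clip["main"](*inputs)
    turbine_output = results[0]  # compared against the torch reference in sd_test.py
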
diff --git a/models/turbine_models/tests/sd_test.py b/models/turbine_models/tests/sd_test.py
index 9d00fb9e5..bce03a455 100644
--- a/models/turbine_models/tests/sd_test.py
+++ b/models/turbine_models/tests/sd_test.py
@@ -21,9 +21,11 @@
 import torch
 import unittest
 import os
+import copy
+import platform
 
-arguments = {
+default_arguments = {
     "hf_auth_token": None,
     "hf_model_name": "CompVis/stable-diffusion-v1-4",
     "scheduler_id": "PNDM",
@@ -42,6 +44,7 @@
     "prompt": "a photograph of an astronaut riding a horse",
     "in_channels": 4,
 }
+UPLOAD_IR = os.environ.get("TURBINE_TANK_ACTION", "not_upload") == "upload"
 
 
 unet_model = unet.UnetModel(
@@ -60,15 +63,92 @@
     # This is a public model, so no auth required
     "CompVis/stable-diffusion-v1-4",
 )
-scheduler = schedulers_dict[arguments["scheduler_id"]]
+scheduler = schedulers_dict[default_arguments["scheduler_id"]]
 scheduler_module = schedulers.Scheduler(
-    "CompVis/stable-diffusion-v1-4", arguments["num_inference_steps"], scheduler
+    "CompVis/stable-diffusion-v1-4", default_arguments["num_inference_steps"], scheduler
 )
 
 
+# TODO: this is a mess, don't share args across tests, create a copy for each test
 class StableDiffusionTest(unittest.TestCase):
+    def testExportT5Model(self):
+        current_args = copy.deepcopy(default_arguments)
+        current_args["hf_model_name"] = "google/t5-v1_1-small"
+        safe_prefix = "t5_v1_1_small"
+        with self.assertRaises(SystemExit) as cm:
+            clip.export_clip_model(
+                hf_model_name=current_args["hf_model_name"],
+                hf_auth_token=None,
+                compile_to="vmfb",
+                external_weights=None,
+                external_weight_path=None,
+                device="cpu",
+                target_triple=None,
+                max_alloc=None,
+                upload_ir=UPLOAD_IR,
+            )
+        self.assertEqual(cm.exception.code, None)
+        current_args["vmfb_path"] = safe_prefix + "_clip.vmfb"
+        turbine = clip_runner.run_clip(
+            current_args["device"],
+            current_args["prompt"],
+            current_args["vmfb_path"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            None,
+        )
+        torch_output = clip_runner.run_torch_clip(
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["prompt"],
+        )
+        err = utils.largest_error(torch_output, turbine[0])
+        assert err < 9e-4
+        if platform.system() != "Windows":
+            os.remove(current_args["vmfb_path"])
+        del current_args
+
+    def testExportClipVitLarge14(self):
+        current_args = copy.deepcopy(default_arguments)
+        current_args["hf_model_name"] = "openai/clip-vit-large-patch14"
+        safe_prefix = "clip_vit_large_patch14"
+        with self.assertRaises(SystemExit) as cm:
+            clip.export_clip_model(
+                hf_model_name=current_args["hf_model_name"],
+                hf_auth_token=None,
+                compile_to="vmfb",
+                external_weights="safetensors",
+                external_weight_path=safe_prefix + ".safetensors",
+                device="cpu",
+                target_triple=None,
+                max_alloc=None,
+                upload_ir=UPLOAD_IR,
+            )
+        self.assertEqual(cm.exception.code, None)
+        current_args["external_weight_path"] = safe_prefix + ".safetensors"
+        current_args["vmfb_path"] = safe_prefix + "_clip.vmfb"
+        turbine = clip_runner.run_clip(
+            current_args["device"],
+            current_args["prompt"],
+            current_args["vmfb_path"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["external_weight_path"],
+        )
+        torch_output = clip_runner.run_torch_clip(
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["prompt"],
+        )
+        err = utils.largest_error(torch_output, turbine[0])
+        assert err < 9e-5
+        if platform.system() != "Windows":
+            os.remove(current_args["external_weight_path"])
+            os.remove(current_args["vmfb_path"])
+
     def testExportClipModel(self):
-        upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
+        current_args = copy.deepcopy(default_arguments)
+        current_args["hf_model_name"] = "CompVis/stable-diffusion-v1-4"
         with self.assertRaises(SystemExit) as cm:
             clip.export_clip_model(
                 # This is a public model, so no auth required
@@ -78,21 +158,23 @@ def testExportClipModel(self):
                 "safetensors",
                 "stable_diffusion_v1_4_clip.safetensors",
                 "cpu",
-                upload_ir=upload_ir_var == "upload",
+                upload_ir=UPLOAD_IR,
             )
         self.assertEqual(cm.exception.code, None)
-        arguments["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors"
-        arguments["vmfb_path"] = "stable_diffusion_v1_4_clip.vmfb"
+        current_args["external_weight_path"] = "stable_diffusion_v1_4_clip.safetensors"
+        current_args["vmfb_path"] = "stable_diffusion_v1_4_clip.vmfb"
         turbine = clip_runner.run_clip(
-            arguments["device"],
-            arguments["prompt"],
-            arguments["vmfb_path"],
-            arguments["hf_model_name"],
-            arguments["hf_auth_token"],
-            arguments["external_weight_path"],
+            current_args["device"],
+            current_args["prompt"],
+            current_args["vmfb_path"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["external_weight_path"],
         )
         torch_output = clip_runner.run_torch_clip(
-            arguments["hf_model_name"], arguments["hf_auth_token"], arguments["prompt"]
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["prompt"],
         )
         err = utils.largest_error(torch_output, turbine[0])
         assert err < 9e-5
@@ -100,48 +182,48 @@ def testExportClipModel(self):
         os.remove("stable_diffusion_v1_4_clip.vmfb")
 
     def testExportUnetModel(self):
-        upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
+        current_args = copy.deepcopy(default_arguments)
         with self.assertRaises(SystemExit) as cm:
             unet.export_unet_model(
                 unet_model,
                 # This is a public model, so no auth required
                 "CompVis/stable-diffusion-v1-4",
-                arguments["batch_size"],
-                arguments["height"],
-                arguments["width"],
+                current_args["batch_size"],
+                current_args["height"],
+                current_args["width"],
                 None,
                 "vmfb",
                 "safetensors",
                 "stable_diffusion_v1_4_unet.safetensors",
                 "cpu",
-                upload_ir=upload_ir_var == "upload",
+                upload_ir=UPLOAD_IR,
             )
         self.assertEqual(cm.exception.code, None)
-        arguments["external_weight_path"] = "stable_diffusion_v1_4_unet.safetensors"
-        arguments["vmfb_path"] = "stable_diffusion_v1_4_unet.vmfb"
+        current_args["external_weight_path"] = "stable_diffusion_v1_4_unet.safetensors"
+        current_args["vmfb_path"] = "stable_diffusion_v1_4_unet.vmfb"
         sample = torch.rand(
-            arguments["batch_size"],
-            arguments["in_channels"],
-            arguments["height"] // 8,
-            arguments["width"] // 8,
+            current_args["batch_size"],
+            current_args["in_channels"],
+            current_args["height"] // 8,
+            current_args["width"] // 8,
             dtype=torch.float32,
         )
         timestep = torch.zeros(1, dtype=torch.float32)
         encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
         turbine = unet_runner.run_unet(
-            arguments["device"],
+            current_args["device"],
             sample,
             timestep,
             encoder_hidden_states,
-            arguments["vmfb_path"],
-            arguments["hf_model_name"],
-            arguments["hf_auth_token"],
-            arguments["external_weight_path"],
+            current_args["vmfb_path"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["external_weight_path"],
         )
         torch_output = unet_runner.run_torch_unet(
-            arguments["hf_model_name"],
-            arguments["hf_auth_token"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
             sample,
             timestep,
             encoder_hidden_states,
@@ -152,44 +234,44 @@ def testExportUnetModel(self):
         os.remove("stable_diffusion_v1_4_unet.vmfb")
 
     def testExportVaeModelDecode(self):
-        upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
+        current_args = copy.deepcopy(default_arguments)
         with self.assertRaises(SystemExit) as cm:
             vae.export_vae_model(
                 vae_model,
                 # This is a public model, so no auth required
                 "CompVis/stable-diffusion-v1-4",
-                arguments["batch_size"],
-                arguments["height"],
-                arguments["width"],
+                current_args["batch_size"],
+                current_args["height"],
+                current_args["width"],
                 None,
                 "vmfb",
                 "safetensors",
                 "stable_diffusion_v1_4_vae.safetensors",
                 "cpu",
                 variant="decode",
-                upload_ir=upload_ir_var == "upload",
+                upload_ir=UPLOAD_IR,
             )
         self.assertEqual(cm.exception.code, None)
-        arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors"
-        arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb"
+        current_args["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors"
+        current_args["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb"
         example_input = torch.rand(
-            arguments["batch_size"],
+            current_args["batch_size"],
             4,
-            arguments["height"] // 8,
-            arguments["width"] // 8,
+            current_args["height"] // 8,
+            current_args["width"] // 8,
             dtype=torch.float32,
         )
         turbine = vae_runner.run_vae(
-            arguments["device"],
+            current_args["device"],
             example_input,
-            arguments["vmfb_path"],
-            arguments["hf_model_name"],
-            arguments["hf_auth_token"],
-            arguments["external_weight_path"],
+            current_args["vmfb_path"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["external_weight_path"],
        )
         torch_output = vae_runner.run_torch_vae(
-            arguments["hf_model_name"],
-            arguments["hf_auth_token"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
             "decode",
             example_input,
         )
@@ -198,45 +280,47 @@ def testExportVaeModelDecode(self):
         os.remove("stable_diffusion_v1_4_vae.safetensors")
         os.remove("stable_diffusion_v1_4_vae.vmfb")
 
+    # https://github.com/nod-ai/SHARK-Turbine/issues/536
+    @unittest.expectedFailure
     def testExportVaeModelEncode(self):
-        upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
+        current_args = copy.deepcopy(default_arguments)
         with self.assertRaises(SystemExit) as cm:
             vae.export_vae_model(
                 vae_model,
                 # This is a public model, so no auth required
                 "CompVis/stable-diffusion-v1-4",
-                arguments["batch_size"],
-                arguments["height"],
-                arguments["width"],
+                current_args["batch_size"],
+                current_args["height"],
+                current_args["width"],
                 None,
                 "vmfb",
                 "safetensors",
                 "stable_diffusion_v1_4_vae.safetensors",
                 "cpu",
                 variant="encode",
-                upload_ir=upload_ir_var == "upload",
+                upload_ir=UPLOAD_IR,
             )
         self.assertEqual(cm.exception.code, None)
-        arguments["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors"
-        arguments["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb"
+        current_args["external_weight_path"] = "stable_diffusion_v1_4_vae.safetensors"
+        current_args["vmfb_path"] = "stable_diffusion_v1_4_vae.vmfb"
         example_input = torch.rand(
-            arguments["batch_size"],
+            current_args["batch_size"],
             3,
-            arguments["height"],
-            arguments["width"],
+            current_args["height"],
+            current_args["width"],
             dtype=torch.float32,
         )
         turbine = vae_runner.run_vae(
-            arguments["device"],
+            current_args["device"],
             example_input,
-            arguments["vmfb_path"],
-            arguments["hf_model_name"],
-            arguments["hf_auth_token"],
-            arguments["external_weight_path"],
+            current_args["vmfb_path"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["external_weight_path"],
         )
         torch_output = vae_runner.run_torch_vae(
-            arguments["hf_model_name"],
-            arguments["hf_auth_token"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
             "encode",
             example_input,
         )
@@ -247,48 +331,47 @@
     @unittest.expectedFailure
     def testExportPNDMScheduler(self):
-        upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
+        current_args = copy.deepcopy(default_arguments)
+        safe_name = "stable_diffusion_v1_4_scheduler"
         with self.assertRaises(SystemExit) as cm:
             schedulers.export_scheduler(
                 scheduler_module,
                 # This is a public model, so no auth required
                 "CompVis/stable-diffusion-v1-4",
-                arguments["batch_size"],
-                arguments["height"],
-                arguments["width"],
+                current_args["batch_size"],
+                current_args["height"],
+                current_args["width"],
                 None,
                 "vmfb",
                 "safetensors",
                 "stable_diffusion_v1_4_scheduler.safetensors",
                 "cpu",
-                upload_ir=upload_ir_var == "upload",
+                upload_ir=UPLOAD_IR,
             )
         self.assertEqual(cm.exception.code, None)
-        arguments[
-            "external_weight_path"
-        ] = "stable_diffusion_v1_4_scheduler.safetensors"
-        arguments["vmfb_path"] = "stable_diffusion_v1_4_scheduler.vmfb"
+        current_args["external_weight_path"] = safe_name + ".safetensors"
+        current_args["vmfb_path"] = safe_name + ".vmfb"
         sample = torch.rand(
-            arguments["batch_size"],
+            current_args["batch_size"],
             4,
-            arguments["height"] // 8,
-            arguments["width"] // 8,
+            current_args["height"] // 8,
+            current_args["width"] // 8,
             dtype=torch.float32,
         )
         encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
         turbine = schedulers_runner.run_scheduler(
-            arguments["device"],
+            current_args["device"],
             sample,
             encoder_hidden_states,
-            arguments["vmfb_path"],
-            arguments["hf_model_name"],
-            arguments["hf_auth_token"],
-            arguments["external_weight_path"],
+            current_args["vmfb_path"],
+            current_args["hf_model_name"],
+            current_args["hf_auth_token"],
+            current_args["external_weight_path"],
         )
         torch_output = schedulers_runner.run_torch_scheduler(
-            arguments["hf_model_name"],
+            current_args["hf_model_name"],
             scheduler,
-            arguments["num_inference_steps"],
+            current_args["num_inference_steps"],
             sample,
             encoder_hidden_states,
         )
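
The test changes all follow one pattern: each test deep-copies default_arguments so per-test overrides (model name, weight and vmfb paths) never leak into the shared dict, and the tank-upload toggle is computed once from TURBINE_TANK_ACTION. An illustrative sketch of that isolation, using only the standard library; the make_test_args helper is hypothetical and not part of the patch.

    import copy
    import os

    UPLOAD_IR = os.environ.get("TURBINE_TANK_ACTION", "not_upload") == "upload"

    default_arguments = {"hf_model_name": "CompVis/stable-diffusion-v1-4", "device": "cpu"}

    def make_test_args(**overrides):
        # Hypothetical helper: every test works on its own copy, so overrides never leak.
        current_args = copy.deepcopy(default_arguments)
        current_args.update(overrides)
        return current_args

    t5_args = make_test_args(hf_model_name="google/t5-v1_1-small")
    assert t5_args["hf_model_name"] == "google/t5-v1_1-small"
    assert default_arguments["hf_model_name"] == "CompVis/stable-diffusion-v1-4"  # shared dict untouched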