Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clip_vit_large14 and t5 models #535

Merged
merged 11 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 62 additions & 24 deletions models/turbine_models/custom_models/sd_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import os
import sys
import re
IanNod marked this conversation as resolved.
Show resolved Hide resolved

from iree import runtime as ireert
import iree.compiler as ireec
Expand All @@ -15,7 +16,7 @@
from turbine_models.custom_models.sd_inference import utils
import torch
import torch._dynamo as dynamo
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor
from turbine_models.turbine_tank import turbine_tank

import argparse
Expand Down Expand Up @@ -60,37 +61,74 @@ def export_clip_model(
max_alloc=None,
upload_ir=False,
):
# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained(
hf_model_name,
subfolder="tokenizer",
token=hf_auth_token,
)
if "google/t5" in hf_model_name:
from transformers import T5Tokenizer, T5Model

text_encoder_model = CLIPTextModel.from_pretrained(
hf_model_name,
subfolder="text_encoder",
token=hf_auth_token,
)
tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
text_encoder_model = T5Model.from_pretrained(hf_model_name)

else:
# TODO: Add better filtering mechanism for things that require CLIPProcessor
if hf_model_name == "openai/clip-vit-large-patch14":
tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
hf_subfolder = "" # CLIPProcessor does not have a subfolder
else:
# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained(
hf_model_name,
subfolder="tokenizer",
token=hf_auth_token,
)
hf_subfolder = "text_encoder"

text_encoder_model = CLIPTextModel.from_pretrained(
hf_model_name,
subfolder=hf_subfolder,
token=hf_auth_token,
)

mapper = {}
utils.save_external_weights(
mapper, text_encoder_model, external_weights, external_weight_path
)

class CompiledClip(CompiledModule):
if external_weights:
params = export_parameters(
text_encoder_model,
external=True,
external_scope="",
name_mapper=mapper.get,
)
else:
params = export_parameters(text_encoder_model)
if "google/t5" in hf_model_name:

class CompiledClip(CompiledModule):
if external_weights:
params = export_parameters(
text_encoder_model,
external=True,
external_scope="",
name_mapper=mapper.get,
)
else:
params = export_parameters(text_encoder_model)

def main(
self,
inp=AbstractTensor(1, 77, dtype=torch.int64),
decoder_input_ids=AbstractTensor(1, 77, dtype=torch.int64),
):
return jittable(text_encoder_model.forward)(
input_ids=inp, decoder_input_ids=decoder_input_ids
)

else:

class CompiledClip(CompiledModule):
if external_weights:
params = export_parameters(
text_encoder_model,
external=True,
external_scope="",
name_mapper=mapper.get,
)
else:
params = export_parameters(text_encoder_model)

def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
return jittable(text_encoder_model.forward)(inp)
def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
return jittable(text_encoder_model.forward)(input_ids=inp)

import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
inst = CompiledClip(context=Context(), import_to=import_to)
Expand Down
34 changes: 34 additions & 0 deletions models/turbine_models/tests/sd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,38 @@


class StableDiffusionTest(unittest.TestCase):
dan-garvey marked this conversation as resolved.
Show resolved Hide resolved
def testExportT5Model(self):
upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
with self.assertRaises(SystemExit) as cm:
clip.export_clip_model(
hf_model_name="google/t5-v1_1-small",
hf_auth_token=None,
compile_to="vmfb",
external_weights=None,
external_weight_path=None,
device="cpu",
target_triple=None,
max_alloc=None,
upload_ir=upload_ir_var == "upload",
)
self.assertEqual(cm.exception.code, None)

def testExportClipVitLarge14(self):
upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
with self.assertRaises(SystemExit) as cm:
clip.export_clip_model(
hf_model_name="openai/clip-vit-large-patch14",
hf_auth_token=None,
compile_to="vmfb",
external_weights="safetensors",
external_weight_path="openai_clip-vit-large-patch14.safetensors",
device="cpu",
target_triple=None,
max_alloc=None,
upload_ir=upload_ir_var == "upload",
)
self.assertEqual(cm.exception.code, None)

def testExportClipModel(self):
upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
with self.assertRaises(SystemExit) as cm:
Expand Down Expand Up @@ -198,6 +230,8 @@ def testExportVaeModelDecode(self):
os.remove("stable_diffusion_v1_4_vae.safetensors")
os.remove("stable_diffusion_v1_4_vae.vmfb")

# https://github.com/nod-ai/SHARK-Turbine/issues/536
@unittest.expectedFailure
def testExportVaeModelEncode(self):
upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
with self.assertRaises(SystemExit) as cm:
Expand Down
Loading