Commit 441077b

fix e2e tests

1 parent d24fee2

3 files changed: 221 additions & 145 deletions

models/turbine_models/custom_models/sd_inference/clip.py

Lines changed: 7 additions & 4 deletions
@@ -56,17 +56,20 @@ def export_clip_model(
     max_alloc=None,
     upload_ir=False,
 ):
+    input_len = 77
     if "google/t5" in hf_model_name:
         from transformers import T5Tokenizer, T5Model
 
         tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
         text_encoder_model = T5Model.from_pretrained(hf_model_name)
+        input_len = 512
 
     else:
         # TODO: Add better filtering mechanism for things that require CLIPProcessor
-        if hf_model_name == "openai/clip-vit-large-patch14":
+        if "openai" in hf_model_name:
             tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
             hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            input_len = 10
         else:
             # Load the tokenizer and text encoder to tokenize and encode the text.
             tokenizer = CLIPTokenizer.from_pretrained(
@@ -102,8 +105,8 @@ class CompiledClip(CompiledModule):
 
         def main(
             self,
-            inp=AbstractTensor(1, 77, dtype=torch.int64),
-            decoder_input_ids=AbstractTensor(1, 77, dtype=torch.int64),
+            inp=AbstractTensor(1, input_len, dtype=torch.int64),
+            decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64),
         ):
             return jittable(text_encoder_model.forward)(
                 input_ids=inp, decoder_input_ids=decoder_input_ids
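Note on this hunk: the AbstractTensor defaults in main() are evaluated when the CompiledClip class body runs, so the input_len assigned earlier in export_clip_model is baked into the exported signature as a static shape. A plain-Python sketch of that capture mechanism (illustrative names, nothing turbine-specific):

def make_module(input_len):
    # Default argument values are evaluated once, when the class body
    # executes, so input_len from the enclosing scope is frozen into
    # the method signature.
    class M:
        def main(self, shape=(1, input_len)):
            return shape
    return M()

assert make_module(512).main() == (1, 512)
assert make_module(77).main() == (1, 77)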
@@ -122,7 +125,7 @@ class CompiledClip(CompiledModule):
         else:
             params = export_parameters(text_encoder_model)
 
-        def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
+        def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
             return jittable(text_encoder_model.forward)(input_ids=inp)
 
     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
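For reference, the net effect in this file: the export no longer hard-codes a sequence length of 77 but picks it per model family. A minimal standalone sketch of the dispatch introduced above (the helper name is hypothetical; the three lengths are the ones the diff sets):

def pick_input_len(hf_model_name: str) -> int:
    if "google/t5" in hf_model_name:
        return 512  # T5's longer max sequence length
    if "openai" in hf_model_name:
        return 10  # short length used for the CLIPProcessor path
    return 77  # default CLIP tokenizer max length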

models/turbine_models/custom_models/sd_inference/clip_runner.py

Lines changed: 84 additions & 25 deletions
@@ -3,6 +3,7 @@
 from transformers import CLIPTokenizer
 from iree import runtime as ireert
 import torch
+from PIL import Image
 
 parser = argparse.ArgumentParser()
 
@@ -52,21 +53,54 @@ def run_clip(
 ):
     runner = vmfbRunner(device, vmfb_path, external_weight_path)
 
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+    # TODO: Integrate with HFTransformerBuilder
+    else:
+        if "openai" in hf_model_name:
+            from transformers import CLIPProcessor
+            import requests
+
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "tokenizer"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
     inp = [ireert.asdevicearray(runner.config.device, example_input)]
 
+    if "google/t5" in hf_model_name:
+        inp += [ireert.asdevicearray(runner.config.device, example_input)]
     results = runner.ctx.modules.compiled_clip["main"](*inp)
     return results
 
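Note on the duplicated device array above: T5Model is an encoder-decoder, so its forward() requires decoder_input_ids in addition to input_ids, and the runner reuses the same tokenized prompt for both. A hedged torch-side sketch of the equivalent call (the checkpoint name is illustrative; any "google/t5*" model fits the branch):

import torch
from transformers import T5Tokenizer, T5Model

tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
model = T5Model.from_pretrained("google/t5-v1_1-small")
ids = tokenizer(
    "an example prompt",
    padding="max_length",
    max_length=512,  # matches input_len = 512 on the export side
    truncation=True,
    return_tensors="pt",
).input_ids
with torch.no_grad():
    # The same tokenized prompt is fed to both the encoder and decoder,
    # mirroring what the compiled module's main() receives.
    out = model(input_ids=ids, decoder_input_ids=ids)
print(out.last_hidden_state.shape)  # (1, 512, d_model)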

@@ -77,13 +111,38 @@ def run_torch_clip(hf_model_name, hf_auth_token, prompt):
 
         tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
         model = T5Model.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
     # TODO: Integrate with HFTransformerBuilder
     else:
         if hf_model_name == "openai/clip-vit-large-patch14":
             from transformers import CLIPProcessor
+            import requests
+
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
 
             tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
             hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
         else:
             hf_subfolder = "text_encoder"
 

@@ -93,20 +152,20 @@ def run_torch_clip(hf_model_name, hf_auth_token, prompt):
                 token=hf_auth_token,
             )
 
-        from transformers import CLIPTextModel
+            from transformers import CLIPTextModel
 
-        model = CLIPTextModel.from_pretrained(
-            hf_model_name,
-            subfolder=hf_subfolder,
-            token=hf_auth_token,
-        )
-        text_input = tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=tokenizer.model_max_length,
-            truncation=True,
-            return_tensors="pt",
-        )
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
 
     if "google/t5" in hf_model_name:
