Commit c1f30ed

Merge branch 'main' into ean-sd-fp16

2 parents df7ac8c + 36a7091

3 files changed: +352 −179 lines

models/turbine_models/custom_models/sd_inference/clip.py

Lines changed: 65 additions & 29 deletions
@@ -5,17 +5,13 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import os
-import sys
+import re
 
-from iree import runtime as ireert
-import iree.compiler as ireec
 from iree.compiler.ir import Context
-import numpy as np
 from shark_turbine.aot import *
 from turbine_models.custom_models.sd_inference import utils
 import torch
-import torch._dynamo as dynamo
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor
 from turbine_models.turbine_tank import turbine_tank
 
 import argparse
@@ -60,37 +56,77 @@ def export_clip_model(
     max_alloc=None,
     upload_ir=False,
 ):
-    # Load the tokenizer and text encoder to tokenize and encode the text.
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
+    input_len = 77
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
 
-    text_encoder_model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_encoder_model = T5Model.from_pretrained(hf_model_name)
+        input_len = 512
+
+    else:
+        # TODO: Add better filtering mechanism for things that require CLIPProcessor
+        if "openai" in hf_model_name:
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            input_len = 10
+        else:
+            # Load the tokenizer and text encoder to tokenize and encode the text.
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+            hf_subfolder = "text_encoder"
+
+        text_encoder_model = CLIPTextModel.from_pretrained(
+            hf_model_name,
+            subfolder=hf_subfolder,
+            token=hf_auth_token,
+        )
 
     mapper = {}
     utils.save_external_weights(
         mapper, text_encoder_model, external_weights, external_weight_path
     )
 
-    class CompiledClip(CompiledModule):
-        if external_weights:
-            params = export_parameters(
-                text_encoder_model,
-                external=True,
-                external_scope="",
-                name_mapper=mapper.get,
-            )
-        else:
-            params = export_parameters(text_encoder_model)
+    if "google/t5" in hf_model_name:
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
+
+            def main(
+                self,
+                inp=AbstractTensor(1, input_len, dtype=torch.int64),
+                decoder_input_ids=AbstractTensor(1, input_len, dtype=torch.int64),
+            ):
+                return jittable(text_encoder_model.forward)(
+                    input_ids=inp, decoder_input_ids=decoder_input_ids
+                )
+
+    else:
+
+        class CompiledClip(CompiledModule):
+            if external_weights:
+                params = export_parameters(
+                    text_encoder_model,
+                    external=True,
+                    external_scope="",
+                    name_mapper=mapper.get,
+                )
+            else:
+                params = export_parameters(text_encoder_model)
 
-        def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64)):
-            return jittable(text_encoder_model.forward)(inp)
+            def main(self, inp=AbstractTensor(1, input_len, dtype=torch.int64)):
+                return jittable(text_encoder_model.forward)(input_ids=inp)
 
     import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
     inst = CompiledClip(context=Context(), import_to=import_to)
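
Side note on the T5 branch above: unlike CLIPTextModel, whose forward() takes input_ids alone, T5Model is an encoder-decoder whose forward() also requires decoder_input_ids. That is why the exported main() takes two (1, input_len) int64 tensors and the commit simply reuses the encoder ids for the decoder. A minimal eager-mode sketch of the same call; the model name and prompt are illustrative placeholders, not part of this commit:

# Eager-mode sketch of what the exported T5 main() computes; assumes the
# transformers package is installed and Hugging Face is reachable.
# "google/t5-v1_1-small" and the prompt are illustrative placeholders.
import torch
from transformers import T5Tokenizer, T5Model

tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
model = T5Model.from_pretrained("google/t5-v1_1-small")

ids = tokenizer(
    "a photograph of an astronaut riding a horse",
    padding="max_length",
    max_length=512,  # matches input_len = 512 in export_clip_model
    truncation=True,
    return_tensors="pt",
).input_ids  # shape (1, 512), dtype int64

# T5Model.forward needs decoder_input_ids as well; reusing the encoder ids
# mirrors the exported main() signature above.
with torch.no_grad():
    out = model(input_ids=ids, decoder_input_ids=ids)
print(out.last_hidden_state.shape)  # (1, 512, d_model) for this model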

models/turbine_models/custom_models/sd_inference/clip_runner.py

Lines changed: 99 additions & 31 deletions
@@ -52,49 +52,117 @@ def run_clip(
 ):
     runner = vmfbRunner(device, vmfb_path, external_weight_path)
 
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+    # TODO: Integrate with HFTransformerBuilder
+    else:
+        if "openai" in hf_model_name:
+            from transformers import CLIPProcessor
+            import requests
+
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            text_input = tokenizer(
+                text=prompt,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "tokenizer"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
     inp = [ireert.asdevicearray(runner.config.device, example_input)]
 
+    if "google/t5" in hf_model_name:
+        inp += [ireert.asdevicearray(runner.config.device, example_input)]
     results = runner.ctx.modules.compiled_clip["main"](*inp)
     return results
 
 
 def run_torch_clip(hf_model_name, hf_auth_token, prompt):
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        model = T5Model.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
     # TODO: Integrate with HFTransformerBuilder
-    from transformers import CLIPTextModel
+    else:
+        if hf_model_name == "openai/clip-vit-large-patch14":
+            from transformers import CLIPProcessor
 
-    model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder="text_encoder",
-        token=hf_auth_token,
-    )
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                text=prompt,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "text_encoder"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder="tokenizer",
+                token=hf_auth_token,
+            )
+
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
 
-    results = model.forward(example_input)[0]
+    if "google/t5" in hf_model_name:
+        results = model.forward(example_input, decoder_input_ids=example_input)[0]
+    else:
+        results = model.forward(example_input)[0]
     np_torch_output = results.detach().cpu().numpy()
     return np_torch_output
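
For a quick sanity check of the torch reference path, run_torch_clip can be called directly; its full signature is visible in this diff, while the model name and prompt below are illustrative. The output of the compiled run_clip path should be numerically close to this reference (e.g. under numpy.allclose with a loose tolerance):

# Hedged usage sketch for run_torch_clip as defined above; the model name
# and prompt are illustrative, not part of this commit.
from turbine_models.custom_models.sd_inference.clip_runner import run_torch_clip

out = run_torch_clip(
    "openai/clip-vit-large-patch14",  # routed to the CLIPProcessor branch
    None,  # hf_auth_token; this model is public
    "a photograph of an astronaut riding a horse",
)
# run_torch_clip returns last_hidden_state as a numpy array,
# e.g. shape (1, seq_len, 768) for this CLIP text encoder.
print(out.shape)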