 from transformers import CLIPTokenizer
 from iree import runtime as ireert
 import torch
+from PIL import Image
 
 parser = argparse.ArgumentParser()
@@ -52,21 +53,54 @@ def run_clip(
 ):
     runner = vmfbRunner(device, vmfb_path, external_weight_path)
 
-    tokenizer = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+    if "google/t5" in hf_model_name:
+        from transformers import T5Tokenizer, T5Model
+
+        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+        # TODO: Integrate with HFTransformerBuilder
+    else:
+        if "openai" in hf_model_name:
+            from transformers import CLIPProcessor
+            import requests
+
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
+            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
+        else:
+            hf_subfolder = "tokenizer"
+
+            tokenizer = CLIPTokenizer.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
     inp = [ireert.asdevicearray(runner.config.device, example_input)]
 
+    if "google/t5" in hf_model_name:
+        inp += [ireert.asdevicearray(runner.config.device, example_input)]
     results = runner.ctx.modules.compiled_clip["main"](*inp)
     return results
@@ -77,13 +111,38 @@ def run_torch_clip(hf_model_name, hf_auth_token, prompt):
         tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
         model = T5Model.from_pretrained(hf_model_name)
+        text_input = tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
         # TODO: Integrate with HFTransformerBuilder
     else:
         if hf_model_name == "openai/clip-vit-large-patch14":
             from transformers import CLIPProcessor
+            import requests
+
+            url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+            image = Image.open(requests.get(url, stream=True).raw)
 
             tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
             hf_subfolder = ""  # CLIPProcessor does not have a subfolder
+            from transformers import CLIPTextModel
+
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                text=prompt,
+                images=image,
+                truncation=True,
+                padding=True,
+                return_tensors="pt",
+            )
         else:
             hf_subfolder = "text_encoder"
@@ -93,20 +152,20 @@ def run_torch_clip(hf_model_name, hf_auth_token, prompt):
                 token=hf_auth_token,
             )
 
-    from transformers import CLIPTextModel
+            from transformers import CLIPTextModel
 
-    model = CLIPTextModel.from_pretrained(
-        hf_model_name,
-        subfolder=hf_subfolder,
-        token=hf_auth_token,
-    )
-    text_input = tokenizer(
-        prompt,
-        padding="max_length",
-        max_length=tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
+            model = CLIPTextModel.from_pretrained(
+                hf_model_name,
+                subfolder=hf_subfolder,
+                token=hf_auth_token,
+            )
+            text_input = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
     example_input = text_input.input_ids
 
     if "google/t5" in hf_model_name: