@@ -52,49 +52,117 @@ def run_clip(
52
52
):
53
53
runner = vmfbRunner (device , vmfb_path , external_weight_path )
54
54
55
- tokenizer = CLIPTokenizer .from_pretrained (
56
- hf_model_name ,
57
- subfolder = "tokenizer" ,
58
- token = hf_auth_token ,
59
- )
60
- text_input = tokenizer (
61
- prompt ,
62
- padding = "max_length" ,
63
- max_length = tokenizer .model_max_length ,
64
- truncation = True ,
65
- return_tensors = "pt" ,
66
- )
55
+ if "google/t5" in hf_model_name :
56
+ from transformers import T5Tokenizer , T5Model
57
+
58
+ tokenizer = T5Tokenizer .from_pretrained (hf_model_name )
59
+ text_input = tokenizer (
60
+ prompt ,
61
+ padding = "max_length" ,
62
+ max_length = tokenizer .model_max_length ,
63
+ truncation = True ,
64
+ return_tensors = "pt" ,
65
+ )
66
+ # TODO: Integrate with HFTransformerBuilder
67
+ else :
68
+ if "openai" in hf_model_name :
69
+ from transformers import CLIPProcessor
70
+ import requests
71
+
72
+ tokenizer = CLIPProcessor .from_pretrained ("openai/clip-vit-large-patch14" )
73
+ text_input = tokenizer (
74
+ text = prompt ,
75
+ truncation = True ,
76
+ padding = True ,
77
+ return_tensors = "pt" ,
78
+ )
79
+ else :
80
+ hf_subfolder = "tokenizer"
81
+
82
+ tokenizer = CLIPTokenizer .from_pretrained (
83
+ hf_model_name ,
84
+ subfolder = hf_subfolder ,
85
+ token = hf_auth_token ,
86
+ )
87
+
88
+ text_input = tokenizer (
89
+ prompt ,
90
+ padding = "max_length" ,
91
+ max_length = tokenizer .model_max_length ,
92
+ truncation = True ,
93
+ return_tensors = "pt" ,
94
+ )
67
95
example_input = text_input .input_ids
68
96
inp = [ireert .asdevicearray (runner .config .device , example_input )]
69
97
98
+ if "google/t5" in hf_model_name :
99
+ inp += [ireert .asdevicearray (runner .config .device , example_input )]
70
100
results = runner .ctx .modules .compiled_clip ["main" ](* inp )
71
101
return results
72
102
73
103
74
104
def run_torch_clip(hf_model_name, hf_auth_token, prompt):
    """Run the reference PyTorch text encoder on ``prompt``.

    The model family is chosen from ``hf_model_name`` using the same
    substring checks as ``run_clip`` so that both code paths tokenize the
    prompt identically:

    * contains ``"google/t5"``: ``T5Tokenizer`` and ``T5Model`` are loaded
      from ``hf_model_name``; the token ids are passed both as the encoder
      input and as ``decoder_input_ids``.
    * contains ``"openai"``: the prompt is tokenized with the
      ``"openai/clip-vit-large-patch14"`` ``CLIPProcessor`` and encoded with
      a ``CLIPTextModel`` loaded from ``hf_model_name`` with an empty
      subfolder.
    * anything else: ``CLIPTokenizer`` is loaded from the ``"tokenizer"``
      subfolder and ``CLIPTextModel`` from the ``"text_encoder"`` subfolder
      of ``hf_model_name``.

    Args:
        hf_model_name: Hugging Face model identifier.
        hf_auth_token: Token forwarded to the CLIP ``from_pretrained`` calls
            (not used for the T5 branch).
        prompt: Text to encode.

    Returns:
        Element 0 of the model output, detached, moved to CPU and converted
        to a NumPy array.
    """
    is_t5 = "google/t5" in hf_model_name

    if is_t5:
        from transformers import T5Tokenizer, T5Model

        tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
        model = T5Model.from_pretrained(hf_model_name)
        text_input = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
    else:
        # TODO: Integrate with HFTransformerBuilder
        from transformers import CLIPTextModel

        # Same membership test as run_clip; an exact-name comparison here
        # would send other "openai/..." checkpoints down a different
        # tokenization path than the compiled runner uses.
        if "openai" in hf_model_name:
            from transformers import CLIPProcessor

            # Kept identical to run_clip, which always tokenizes with this
            # processor on the OpenAI path.
            tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
            text_input = tokenizer(
                text=prompt,
                truncation=True,
                padding=True,
                return_tensors="pt",
            )
            hf_subfolder = ""  # CLIPProcessor does not have a subfolder
        else:
            tokenizer = CLIPTokenizer.from_pretrained(
                hf_model_name,
                subfolder="tokenizer",
                token=hf_auth_token,
            )
            text_input = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            hf_subfolder = "text_encoder"

        model = CLIPTextModel.from_pretrained(
            hf_model_name,
            subfolder=hf_subfolder,
            token=hf_auth_token,
        )

    example_input = text_input.input_ids

    # Call the module itself rather than .forward() so registered hooks run.
    if is_t5:
        # T5Model is an encoder-decoder; the same ids are supplied as the
        # decoder input.
        results = model(example_input, decoder_input_ids=example_input)[0]
    else:
        results = model(example_input)[0]

    return results.detach().cpu().numpy()
100
168
0 commit comments