+ import copy
import logging
import os
import warnings
+ from abc import abstractmethod
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
- from transformers import AutoConfig, GenerationConfig, GenerationMixin, PretrainedConfig
+ from PIL.Image import Image
+ from transformers import (
+     AutoConfig,
+     GenerationConfig,
+     GenerationMixin,
+     PretrainedConfig,
+     PreTrainedTokenizer,
+ )
from transformers.modeling_outputs import BaseModelOutputWithPooling

from ...exporters.openvino import main_export
from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
+ from .. import OVQuantizer
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel, OVModelPart
from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
@@ -181,6 +191,7 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
        self._main_input = "images" if model_has_input_output_name(self.model, "images") else "pixel_values"

    def forward(self, pixel_values, **kwargs):
+         self._compile()
        inputs = {self._main_input: pixel_values}
        if len(self.input_names) > 1:
            for name in self.input_names:
@@ -210,6 +221,7 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
        self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}

    def forward(self, image_feature, pos_embed, key_padding_mask):
+         self._compile()
        result = self.request(
            {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
        )[0]
@@ -244,7 +256,7 @@ def __init__(
        self.ov_config = {} if ov_config is None else {**ov_config}
        self.preprocessors = kwargs.get("preprocessors", [])
        self.lm_model = language_model
-         self.text_embdings_model = text_embeddings
+         self.text_embeddings_model = text_embeddings
        self.vision_embeddings_model = vision_embeddings
        self._supports_cache_class = False
        self.main_input_name = "input_ids"
@@ -261,13 +273,13 @@ def __init__(
        self._set_ov_config_parameters()
        self.language_model = OVModelWithEmbedForCausalLM(
            self.lm_model,
-             self.text_embdings_model,
+             self.text_embeddings_model,
            config=config,
            deivce=device,
            ov_config=ov_config,
            model_save_dir=model_save_dir,
            quantization_config=quantization_config,
-             compile=not self._compile_only,
+             compile=not self._compile_only and enable_compilation,
            compile_only=self._compile_only,
        )
        self.vision_embeddings = OVVisionEmbedding(self.vision_embeddings_model, self)
@@ -287,6 +299,18 @@ def __init__(
        except AttributeError:
            pass

+     def clear_requests(self):
+         if self._compile_only:
+             raise ValueError(
+                 "`clear_requests()` is not supported with `compile_only` mode, please initialize model without this option"
+             )
+ 
+         self.language_model.clear_requests()
+         components = [self.vision_embeddings] + [getattr(self, part) for part in self.additional_parts]
+         for component in components:
+             if component is not None:
+                 component.request = None
+ 
    def compile(self):
        self.language_model.compile()
        self.vision_embeddings._compile()
@@ -304,11 +328,11 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
            save_directory (`str` or `Path`):
                The directory where to save the model files.
        """
-         src_files = [self.lm_model, self.text_embdings_model, self.vision_embeddings_model]
+         src_files = [self.lm_model, self.text_embeddings_model, self.vision_embeddings_model]
        dst_file_names = [
            "openvino_language_model.xml",
            "openvino_text_embeddings_model.xml",
-             "openvino_vision_embeddings.xml",
+             "openvino_vision_embeddings_model.xml",
        ]
        for part in self.additional_parts:
            model = getattr(self, f"{part}_model", None)
@@ -387,26 +411,18 @@ def _from_pretrained(
                raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
            token = use_auth_token

-         model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
- 
-         quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
-         compile_only = kwargs.get("compile_only", False)
- 
-         # Load model from a local directory
-         if os.path.isdir(model_id):
-             model_save_dir = Path(model_id)
        model_file_names = {
            "language_model": "openvino_language_model.xml",
            "text_embeddings": "openvino_text_embeddings_model.xml",
            "vision_embeddings": "openvino_vision_embeddings_model.xml",
        }

+         model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
        for part in model_cls.additional_parts:
            model_file_names[part] = f"openvino_{part}_model.xml"
-         model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
-         quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
        compile_only = kwargs.get("compile_only", False)
        if os.path.isdir(model_id):
+             # Load model from a local directory
            model_save_dir = Path(model_id)
            file_names = {k: os.path.join(model_id, model_file_names[k]) for k in model_file_names}
        else:
@@ -424,11 +440,11 @@ def _from_pretrained(
                file_names[name] = model_cache_path
            model_save_dir = Path(model_cache_path).parent
        if not compile_only:
-             language_model = model_cls.load_model(file_names["language_model"], quantization_config)
-             text_embeddings = model_cls.load_model(file_names["text_embeddings"], quantization_config)
-             vision_embeddings = model_cls.load_model(file_names["vision_embeddings"], quantization_config)
+             language_model = model_cls.load_model(file_names["language_model"])
+             text_embeddings = model_cls.load_model(file_names["text_embeddings"])
+             vision_embeddings = model_cls.load_model(file_names["vision_embeddings"])
            for part in model_cls.additional_parts:
-                 kwargs[part] = model_cls.load_model(file_names[part], quantization_config)
+                 kwargs[part] = model_cls.load_model(file_names[part])
        else:
            language_model = model_cls._compile_model(
                file_names["language_model"],
@@ -468,7 +484,12 @@ def _from_pretrained(
        except Exception:
            pass

-         return model_cls(
+         quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
+         to_quantize = not compile_only and quantization_config is not None
+         if to_quantize:
+             kwargs["compile"] = False
+ 
+         model = model_cls(
            language_model=language_model,
            text_embeddings=text_embeddings,
            vision_embeddings=vision_embeddings,
@@ -478,6 +499,15 @@ def _from_pretrained(
            **kwargs,
        )

+         if to_quantize:
+             quantization_config_copy = copy.deepcopy(quantization_config)
+             quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id
+             potential_processor_id = config.mm_vision_tower if isinstance(model, _OVNanoLlavaForCausalLM) else model_id
+             quantization_config_copy.processor = quantization_config.processor or potential_processor_id
+             OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
+ 
+         return model
+ 
    @classmethod
    def _from_transformers(
        cls,
@@ -556,8 +586,8 @@ def half(self):
        """
        apply_moc_transformations(self.lm_model, cf=False)
        compress_model_transformation(self.lm_model)
-         apply_moc_transformations(self.text_embdings_model, cf=False)
-         compress_model_transformation(self.text_embdings_model)
+         apply_moc_transformations(self.text_embeddings_model, cf=False)
+         compress_model_transformation(self.text_embeddings_model)
        apply_moc_transformations(self.vision_embeddings_model, cf=False)
        compress_model_transformation(self.vision_embeddings_model)
        for part in self.additional_parts:
@@ -695,6 +725,18 @@ def can_generate(self):
        """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
        return True

+     @staticmethod
+     @abstractmethod
+     def preprocess_inputs(
+         processor,
+         text: str,
+         image: Optional[Image] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+     ):
+         """
+         Preprocess input instruction and an image.
+         """
+ 

class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
    def __init__(
@@ -858,6 +900,20 @@ def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
        position_ids[attention_mask == 0] = 1
        return attention_mask, position_ids

+     @staticmethod
+     def preprocess_inputs(
+         processor,
+         text: str,
+         image: Optional[Image] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+     ):
+         if image is None:
+             raise ValueError("Image is required.")
+         chat_template = [{"role": "user", "content": [{"type": "text", "text": text}, {"type": "image"}]}]
+         prompt = processor.apply_chat_template(chat_template, add_generation_prompt=True)
+         inputs = processor(images=image, text=prompt, return_tensors="pt")
+         return inputs
+ 

class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
    # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
@@ -1372,6 +1428,19 @@ def merge_vision_text_embeddings(
        )
        return vllm_embedding, attention_mask, position_ids

+     @staticmethod
+     def preprocess_inputs(
+         processor,
+         text: str,
+         image: Optional[Image] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+     ):
+         if image is None:
+             raise ValueError("Image is required.")
+         prompt = f"<|im_start|>user\n(<image>./</image>)\n{text}<|im_end|>\n<|im_start|>assistant\n"
+         inputs = processor([prompt], [image], return_tensors="pt")
+         return inputs
+ 

class _OVNanoLlavaForCausalLM(OVModelForVisualCausalLM):
    def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
@@ -1544,6 +1613,25 @@ def get_multimodal_embeddings(
        return new_input_embeds, attention_mask, position_ids

+     @staticmethod
+     def preprocess_inputs(
+         processor,
+         text: str,
+         image: Optional[Image] = None,
+         tokenizer: Optional[PreTrainedTokenizer] = None,
+     ):
+         if tokenizer is None:
+             raise ValueError("Tokenizer is required.")
+         messages = [{"role": "user", "content": f"<image>\n{text}"}]
+         text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
+         input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
+         attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
+         result = {"input_ids": input_ids, "attention_mask": attention_mask}
+         if image is not None:
+             result["images"] = torch.unsqueeze(processor(images=image, return_tensors="pt")["pixel_values"][0], 0)
+         return result
+ 

MODEL_TYPE_TO_CLS_MAPPING = {
    "llava": _OVLlavaForCausalLM,
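For context, a minimal usage sketch of the `preprocess_inputs` helper introduced by this change, together with the weight-quantization path now applied inside `_from_pretrained`. The checkpoint ID, image URL, and prompt below are illustrative assumptions, not part of the change, and the public import path is assumed to be `optimum.intel`:

    # Illustrative sketch only: assumes a LLaVA checkpoint that optimum-intel can export to OpenVINO.
    from PIL import Image
    import requests
    from transformers import AutoProcessor
    from optimum.intel import OVModelForVisualCausalLM, OVWeightQuantizationConfig

    model_id = "llava-hf/llava-1.5-7b-hf"  # assumed example checkpoint
    processor = AutoProcessor.from_pretrained(model_id)

    # Per this change, the weight-only quantization config is applied via OVQuantizer
    # after the language model, text embeddings, and vision embeddings parts are loaded.
    model = OVModelForVisualCausalLM.from_pretrained(
        model_id,
        export=True,
        quantization_config=OVWeightQuantizationConfig(bits=8),
    )

    # The new static helper builds model-specific inputs from a processor, a prompt, and an image.
    image = Image.open(requests.get("https://example.com/cat.png", stream=True).raw)  # placeholder URL
    inputs = model.preprocess_inputs(processor, "What is shown in this image?", image=image)
    generated = model.generate(**inputs, max_new_tokens=50)
    print(processor.batch_decode(generated, skip_special_tokens=True)[0])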