Commit a8e69a3

Quantization support for CausalVisualLMs (#951)
* Quantization support for CausalVisualLMs
* Tweaks
* Add tests
* Fix test
* Added a data-aware compression test for llava-next
* Add assemble_inputs() method to OVModelForVisualCausalLM
* Add support for minicpmv
* Add support for nanollava
* Add group size
* Fix test
* Added support for cli compression
* Tweak refs
* Fix test
* Rename assemble_input; fix tests
* Addressed suggested changes
1 parent c887610 commit a8e69a3

8 files changed: +391, -76 lines

optimum/commands/export/openvino.py

Lines changed: 11 additions & 4 deletions
@@ -329,11 +329,18 @@ def run(self):
             model.save_pretrained(self.args.output)
             if not self.args.disable_convert_tokenizer:
                 maybe_convert_tokenizers(library_name, self.args.output, model, task=task)
-        elif task.startswith("text-generation") and quantize_with_dataset:
-            from optimum.intel import OVModelForCausalLM
+        elif (task.startswith("text-generation") or task == "image-text-to-text") and quantize_with_dataset:
+            if task.startswith("text-generation"):
+                from optimum.intel import OVModelForCausalLM
 
-            # To quantize a text-generation model with a dataset, an instantiated OVModelForCausalLM is required
-            model = OVModelForCausalLM.from_pretrained(
+                model_cls = OVModelForCausalLM
+            else:
+                from optimum.intel import OVModelForVisualCausalLM
+
+                model_cls = OVModelForVisualCausalLM
+
+            # To quantize a model with a dataset, an instance of a model class is required
+            model = model_cls.from_pretrained(
                 self.args.model,
                 export=True,
                 quantization_config=quantization_config,
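
For reference, a minimal Python sketch of what this CLI branch now does for an image-text-to-text model. The model id, output directory, and the "contextual" dataset name are illustrative assumptions, not taken from the diff; the config values mirror typical CLI arguments.

# Sketch only: mirrors the class selection added above (names marked below are assumptions).
from optimum.intel import OVModelForCausalLM, OVModelForVisualCausalLM, OVWeightQuantizationConfig

task = "image-text-to-text"  # e.g. the value parsed from --task
model_cls = OVModelForCausalLM if task.startswith("text-generation") else OVModelForVisualCausalLM

# To quantize with a dataset, an instantiated model is required.
model = model_cls.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # hypothetical model id
    export=True,
    quantization_config=OVWeightQuantizationConfig(
        bits=4,
        dataset="contextual",  # assumed predefined visual LM dataset name
        num_samples=32,
    ),
)
model.save_pretrained("llava_int4_ov")  # hypothetical output directory
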

optimum/intel/openvino/configuration.py

Lines changed: 17 additions & 8 deletions
@@ -26,6 +26,7 @@
 from optimum.configuration_utils import BaseConfig
 
 from ..utils.import_utils import is_nncf_available
+from .utils import PREDEFINED_SD_DATASETS, PREDEFINED_VISUAL_LM_DATASETS
 
 
 if is_nncf_available():
@@ -350,6 +351,11 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
         gptq (`bool`, *optional*):
             Whether to apply GPTQ algorithm. GPTQ optimizes compressed weights in a layer-wise fashion to minimize the
             difference between activations of a compressed and original layer. Dataset is required to run GPTQ.
+        processor (`str`, *optional*):
+            A transformers processor used to process inputs for multi-modal models. You can pass either:
+                - A string, the *model id* of a predefined processor hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing files required by the processor, for instance saved
+                    using the [`~AutoProcessor.save_pretrained`] method, e.g., `./my_model_directory/`.
     """
 
     def __init__(
@@ -369,6 +375,7 @@ def __init__(
         scale_estimation: bool = None,
         weight_format: Optional[str] = None,
         gptq: bool = None,
+        processor: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(bits=bits, sym=sym, ignored_scope=ignored_scope, num_samples=num_samples)
@@ -383,6 +390,7 @@ def __init__(
         self.scale_estimation = scale_estimation
         self.weight_format = weight_format
         self.gptq = gptq
+        self.processor = processor
         self.post_init()
 
     def post_init(self):
@@ -400,16 +408,14 @@ def post_init(self):
                 f"If you wish to provide a custom dataset, please use the `OVQuantizer` instead."
             )
         if self.dataset is not None and isinstance(self.dataset, str):
-            llm_datasets = ["wikitext2", "c4", "c4-new"]
-            stable_diffusion_datasets = [
-                "conceptual_captions",
-                "laion/220k-GPT4Vision-captions-from-LIVIS",
-                "laion/filtered-wit",
-            ]
-            if self.dataset not in llm_datasets + stable_diffusion_datasets:
+            lm_datasets = ["wikitext2", "c4", "c4-new"]
+            visual_lm_datasets = list(PREDEFINED_VISUAL_LM_DATASETS.keys())
+            stable_diffusion_datasets = list(PREDEFINED_SD_DATASETS.keys())
+            if self.dataset not in lm_datasets + visual_lm_datasets + stable_diffusion_datasets:
                 raise ValueError(
                     f"""You have entered a string value for dataset. You can only choose between
-                    {llm_datasets} for LLMs or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}"""
+                    {lm_datasets} for LLMs, {visual_lm_datasets} for visual LLMs
+                    or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}"""
                 )
 
         if self.bits not in [4, 8]:
@@ -444,6 +450,9 @@ def post_init(self):
         if self.tokenizer is not None and not isinstance(self.tokenizer, str):
             raise ValueError(f"Tokenizer is expected to be a string, but found {self.tokenizer}")
 
+        if self.processor is not None and not isinstance(self.processor, str):
+            raise ValueError(f"Processor is expected to be a string, but found {self.processor}")
+
         if self.weight_format is None:
             self.weight_format = "int4" if self.bits == 4 else "int8"
         if self.weight_format not in ["int4", "int8", "mxfp4"]:
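
A minimal sketch of the new `processor` option and its validation, assuming a hypothetical processor id and the "contextual" dataset name (both illustrative; the dataset must be one of the predefined visual LM dataset keys):

# Sketch only: the ids and dataset name below are assumptions, not from the diff.
from optimum.intel import OVWeightQuantizationConfig

config = OVWeightQuantizationConfig(
    bits=4,
    dataset="contextual",                  # assumed predefined visual LM dataset name
    processor="llava-hf/llava-1.5-7b-hf",  # hypothetical processor id; must be a string
    num_samples=32,
)

# post_init() rejects non-string processors, mirroring the tokenizer check above.
try:
    OVWeightQuantizationConfig(bits=4, processor=123)
except ValueError as err:
    print(err)  # Processor is expected to be a string, but found 123
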

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 111 additions & 23 deletions
@@ -1,6 +1,8 @@
+import copy
 import logging
 import os
 import warnings
+from abc import abstractmethod
 from pathlib import Path
 from typing import Dict, Optional, Tuple, Union
 
@@ -10,11 +12,19 @@
 from huggingface_hub import hf_hub_download
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from openvino._offline_transformations import apply_moc_transformations, compress_model_transformation
-from transformers import AutoConfig, GenerationConfig, GenerationMixin, PretrainedConfig
+from PIL.Image import Image
+from transformers import (
+    AutoConfig,
+    GenerationConfig,
+    GenerationMixin,
+    PretrainedConfig,
+    PreTrainedTokenizer,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 
 from ...exporters.openvino import main_export
 from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
+from .. import OVQuantizer
 from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel, OVModelPart
 from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
@@ -181,6 +191,7 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
         self._main_input = "images" if model_has_input_output_name(self.model, "images") else "pixel_values"
 
     def forward(self, pixel_values, **kwargs):
+        self._compile()
         inputs = {self._main_input: pixel_values}
         if len(self.input_names) > 1:
             for name in self.input_names:
@@ -210,6 +221,7 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
         self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)}
 
     def forward(self, image_feature, pos_embed, key_padding_mask):
+        self._compile()
         result = self.request(
             {"image_feature": image_feature, "pos_embed": pos_embed, "key_padding_mask": key_padding_mask}
         )[0]
@@ -244,7 +256,7 @@ def __init__(
         self.ov_config = {} if ov_config is None else {**ov_config}
         self.preprocessors = kwargs.get("preprocessors", [])
         self.lm_model = language_model
-        self.text_embdings_model = text_embeddings
+        self.text_embeddings_model = text_embeddings
         self.vision_embeddings_model = vision_embeddings
         self._supports_cache_class = False
         self.main_input_name = "input_ids"
@@ -261,13 +273,13 @@ def __init__(
         self._set_ov_config_parameters()
         self.language_model = OVModelWithEmbedForCausalLM(
             self.lm_model,
-            self.text_embdings_model,
+            self.text_embeddings_model,
             config=config,
             deivce=device,
             ov_config=ov_config,
             model_save_dir=model_save_dir,
             quantization_config=quantization_config,
-            compile=not self._compile_only,
+            compile=not self._compile_only and enable_compilation,
             compile_only=self._compile_only,
         )
         self.vision_embeddings = OVVisionEmbedding(self.vision_embeddings_model, self)
@@ -287,6 +299,18 @@ def __init__(
         except AttributeError:
             pass
 
+    def clear_requests(self):
+        if self._compile_only:
+            raise ValueError(
+                "`clear_requests()` is not supported with `compile_only` mode, please intialize model without this option"
+            )
+
+        self.language_model.clear_requests()
+        components = [self.vision_embeddings] + [getattr(self, part) for part in self.additional_parts]
+        for component in components:
+            if component is not None:
+                component.request = None
+
     def compile(self):
         self.language_model.compile()
         self.vision_embeddings._compile()
@@ -304,11 +328,11 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
             save_directory (`str` or `Path`):
                 The directory where to save the model files.
         """
-        src_files = [self.lm_model, self.text_embdings_model, self.vision_embeddings_model]
+        src_files = [self.lm_model, self.text_embeddings_model, self.vision_embeddings_model]
         dst_file_names = [
             "openvino_language_model.xml",
             "openvino_text_embeddings_model.xml",
-            "openvino_vision_embeddings.xml",
+            "openvino_vision_embeddings_model.xml",
         ]
         for part in self.additional_parts:
             model = getattr(self, f"{part}_model", None)
@@ -387,26 +411,18 @@ def _from_pretrained(
                 raise ValueError("You cannot use both `use_auth_token` and `token` arguments at the same time.")
             token = use_auth_token
 
-        model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
-
-        quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
-        compile_only = kwargs.get("compile_only", False)
-
-        # Load model from a local directory
-        if os.path.isdir(model_id):
-            model_save_dir = Path(model_id)
         model_file_names = {
             "language_model": "openvino_language_model.xml",
             "text_embeddings": "openvino_text_embeddings_model.xml",
            "vision_embeddings": "openvino_vision_embeddings_model.xml",
         }
 
+        model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
         for part in model_cls.additional_parts:
             model_file_names[part] = f"openvino_{part}_model.xml"
-        model_cls = MODEL_TYPE_TO_CLS_MAPPING[config.model_type]
-        quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
         compile_only = kwargs.get("compile_only", False)
         if os.path.isdir(model_id):
+            # Load model from a local directory
            model_save_dir = Path(model_id)
            file_names = {k: os.path.join(model_id, model_file_names[k]) for k in model_file_names}
         else:
@@ -424,11 +440,11 @@ def _from_pretrained(
                 file_names[name] = model_cache_path
             model_save_dir = Path(model_cache_path).parent
         if not compile_only:
-            language_model = model_cls.load_model(file_names["language_model"], quantization_config)
-            text_embeddings = model_cls.load_model(file_names["text_embeddings"], quantization_config)
-            vision_embeddings = model_cls.load_model(file_names["vision_embeddings"], quantization_config)
+            language_model = model_cls.load_model(file_names["language_model"])
+            text_embeddings = model_cls.load_model(file_names["text_embeddings"])
+            vision_embeddings = model_cls.load_model(file_names["vision_embeddings"])
             for part in model_cls.additional_parts:
-                kwargs[part] = model_cls.load_model(file_names[part], quantization_config)
+                kwargs[part] = model_cls.load_model(file_names[part])
         else:
             language_model = model_cls._compile_model(
                 file_names["language_model"],
@@ -468,7 +484,12 @@ def _from_pretrained(
         except Exception:
             pass
 
-        return model_cls(
+        quantization_config = model_cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
+        to_quantize = not compile_only and quantization_config is not None
+        if to_quantize:
+            kwargs["compile"] = False
+
+        model = model_cls(
             language_model=language_model,
             text_embeddings=text_embeddings,
             vision_embeddings=vision_embeddings,
@@ -478,6 +499,15 @@ def _from_pretrained(
             **kwargs,
         )
 
+        if to_quantize:
+            quantization_config_copy = copy.deepcopy(quantization_config)
+            quantization_config_copy.tokenizer = quantization_config.tokenizer or model_id
+            potential_processor_id = config.mm_vision_tower if isinstance(model, _OVNanoLlavaForCausalLM) else model_id
+            quantization_config_copy.processor = quantization_config.processor or potential_processor_id
+            OVQuantizer(model).quantize(ov_config=OVConfig(quantization_config=quantization_config_copy))
+
+        return model
+
     @classmethod
     def _from_transformers(
         cls,
@@ -556,8 +586,8 @@ def half(self):
         """
         apply_moc_transformations(self.lm_model, cf=False)
         compress_model_transformation(self.lm_model)
-        apply_moc_transformations(self.text_embdings_model, cf=False)
-        compress_model_transformation(self.text_embdings_model)
+        apply_moc_transformations(self.text_embeddings_model, cf=False)
+        compress_model_transformation(self.text_embeddings_model)
         apply_moc_transformations(self.vision_embeddings_model, cf=False)
         compress_model_transformation(self.vision_embeddings_model)
         for part in self.additional_parts:
@@ -695,6 +725,18 @@ def can_generate(self):
         """Returns True to validate the check that the model using `GenerationMixin.generate()` can indeed generate."""
         return True
 
+    @staticmethod
+    @abstractmethod
+    def preprocess_inputs(
+        processor,
+        text: str,
+        image: Optional[Image] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ):
+        """
+        Preprocess input instruction and an image.
+        """
+
 
 class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
     def __init__(
@@ -858,6 +900,20 @@ def _filter_unattended_tokens(self, input_ids, attention_mask, past_key_values):
         position_ids[attention_mask == 0] = 1
         return attention_mask, position_ids
 
+    @staticmethod
+    def preprocess_inputs(
+        processor,
+        text: str,
+        image: Optional[Image] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ):
+        if image is None:
+            raise ValueError("Image is required.")
+        chat_template = [{"role": "user", "content": [{"type": "text", "text": text}, {"type": "image"}]}]
+        prompt = processor.apply_chat_template(chat_template, add_generation_prompt=True)
+        inputs = processor(images=image, text=prompt, return_tensors="pt")
+        return inputs
+
 
 class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
     # Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
@@ -1372,6 +1428,19 @@ def merge_vision_text_embeddings(
         )
         return vllm_embedding, attention_mask, position_ids
 
+    @staticmethod
+    def preprocess_inputs(
+        processor,
+        text: str,
+        image: Optional[Image] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ):
+        if image is None:
+            raise ValueError("Image is required.")
+        prompt = f"<|im_start|>user\n(<image>./</image>)\n{text}<|im_end|>\n<|im_start|>assistant\n"
+        inputs = processor([prompt], [image], return_tensors="pt")
+        return inputs
+
 
 class _OVNanoLlavaForCausalLM(OVModelForVisualCausalLM):
     def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
@@ -1544,6 +1613,25 @@ def get_multimodal_embeddings(
 
         return new_input_embeds, attention_mask, position_ids
 
+    @staticmethod
+    def preprocess_inputs(
+        processor,
+        text: str,
+        image: Optional[Image] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ):
+        if tokenizer is None:
+            raise ValueError("Tokenizer is required.")
+        messages = [{"role": "user", "content": f"<image>\n{text}"}]
+        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        text_chunks = [tokenizer(chunk).input_ids for chunk in text.split("<image>")]
+        input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
+        attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
+        result = {"input_ids": input_ids, "attention_mask": attention_mask}
+        if image is not None:
+            result["images"] = torch.unsqueeze(processor(images=image, return_tensors="pt")["pixel_values"][0], 0)
+        return result
+
 
 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,
