Skip to content

Commit

Permalink
Add OpenVINO qwen2vl support (#1042)
Browse files Browse the repository at this point in the history
* qwen2vl support

* fix code style

* add test case

* Added compression tests for qwen2-vl

* Remove trust_remote_code

* Apply suggestions from code review

Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>

* revert changes in notebook

* apply review comments

* add comments for patching

* reuse original methods if possile

* Update optimum/intel/openvino/modeling_visual_language.py

* fix typings in patchers

---------

Co-authored-by: Nikita Savelyev <nikita.savelyev@intel.com>
  • Loading branch information
eaidova and nikita-savelyevv authored Dec 19, 2024
1 parent b17d1e0 commit 93777ec
Show file tree
Hide file tree
Showing 8 changed files with 621 additions and 10 deletions.
232 changes: 226 additions & 6 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@
PersimmonModelPatcher,
Phi3ModelPatcher,
Phi3VisionImageEmbeddingsPatcher,
Qwen2VLLanguageModelPatcher,
Qwen2VLVisionEmbMergerPatcher,
QwenModelPatcher,
RotaryEmbPatcher,
UpdateCausalMaskModelPatcher,
Expand All @@ -106,6 +108,10 @@ def init_model_configs():
"transformers",
"LlavaNextForConditionalGeneration",
)
TasksManager._CUSTOM_CLASSES[("pt", "qwen2-vl", "image-text-to-text")] = (
"transformers",
"Qwen2VLForConditionalGeneration",
)
TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS[
"image-text-to-text"
] = TasksManager._TRANSFORMERS_TASKS_TO_MODEL_LOADERS["text-generation"]
Expand Down Expand Up @@ -1288,18 +1294,26 @@ def patch_model_for_export(


class LMInputEmbedsConfigHelper(TextDecoderWithPositionIdsOnnxConfig):
def __init__(self, export_config):
def __init__(self, export_config, patcher_cls=None, dummy_input_generator=None, inputs_update=None):
self.orig_export_config = export_config
if dummy_input_generator is not None:
export_config.DUMMY_INPUT_GENERATOR_CLASSES = (
dummy_input_generator,
) + export_config.DUMMY_INPUT_GENERATOR_CLASSES
self.DUMMY_INPUT_GENERATOR_CLASSES = export_config.DUMMY_INPUT_GENERATOR_CLASSES
self.DEFAULT_ONNX_OPSET = export_config.DEFAULT_ONNX_OPSET
self.DUMMY_PKV_GENERATOR_CLASS = export_config.DUMMY_PKV_GENERATOR_CLASS
self._config = export_config._config
self._normalized_config = export_config._normalized_config
self.use_past = export_config.use_past
self.patcher_cls = patcher_cls
self.input_info_upd = inputs_update

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
if self.patcher_cls is not None:
return self.patcher_cls(self, model, model_kwargs=model_kwargs)
# Refer to DecoderModelPatcher.
return self.orig_export_config.patch_model_for_export(model, model_kwargs=model_kwargs)

Expand All @@ -1312,6 +1326,8 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
orig_inputs = self.orig_export_config.inputs
input_ids_config = orig_inputs.pop("input_ids")
orig_inputs["inputs_embeds"] = input_ids_config
if self.input_info_upd is not None:
orig_inputs.update(self.input_info_upd)
return orig_inputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
Expand Down Expand Up @@ -1383,9 +1399,22 @@ def get_vlm_text_embeddings_config(model_type, model_config, int_dtype, float_dt
return export_config


def get_vlm_text_generation_config(model_type, model_config, int_dtype, float_dtype):
def get_vlm_text_generation_config(
model_type,
model_config,
int_dtype,
float_dtype,
model_patcher=None,
dummy_input_generator=None,
inputs_update=None,
):
internal_export_config = get_vlm_internal_text_generation_config(model_type, model_config, int_dtype, float_dtype)
export_config = LMInputEmbedsConfigHelper(internal_export_config)
export_config = LMInputEmbedsConfigHelper(
internal_export_config,
patcher_cls=model_patcher,
dummy_input_generator=dummy_input_generator,
inputs_update=inputs_update,
)
export_config._normalized_config = internal_export_config._normalized_config
return export_config

Expand Down Expand Up @@ -1821,9 +1850,11 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
img_ids_height = self.height // 2
img_ids_width = self.width // 2
return self.random_int_tensor(
[self.batch_size, img_ids_height * img_ids_width, 3]
if is_diffusers_version("<", "0.31.0")
else [img_ids_height * img_ids_width, 3],
(
[self.batch_size, img_ids_height * img_ids_width, 3]
if is_diffusers_version("<", "0.31.0")
else [img_ids_height * img_ids_width, 3]
),
min_value=0,
max_value=min(img_ids_height, img_ids_width),
framework=framework,
Expand Down Expand Up @@ -2260,3 +2291,192 @@ def patch_model_for_export(
if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
return Phi3VisionImageEmbeddingsPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)


class DummyQwen2VLLMInputGenerator(DummyTextInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
generated_input = super().generate(input_name, framework, int_dtype, float_dtype)
if input_name == "position_ids":
return generated_input.unsqueeze(0).expand(3, -1, -1)
return generated_input


class DummyQwen2VLVisionEmbedInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("hidden_states", "attention_mask", "rotary_pos_emb")

def __init__(
self,
task: str,
normalized_config: NormalizedVisionConfig,
batch_size: int = 1,
num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
width: int = 420,
height: int = 420,
**kwargs,
):
self.batch_size = batch_size
self.height = height
self.width = width
self.num_channels = num_channels
self.temporal_patch_size = normalized_config.config.temporal_patch_size
self.patch_size = normalized_config.config.patch_size
if normalized_config.use_embed_dim:
self.embed_dim = normalized_config.config.embed_dim
else:
self.embed_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
self.num_heads = normalized_config.config.num_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
grid_h, grid_w = self.height // self.patch_size, self.width // self.patch_size
grid_t = self.batch_size

if input_name == "hidden_states":
return self.random_float_tensor(
[grid_t * grid_h * grid_w, self.embed_dim], framework=framework, dtype=float_dtype
)

if input_name == "attention_mask":
return self.random_mask_tensor(
[1, grid_t * grid_h * grid_w, grid_t * grid_h * grid_w], framework=framework, dtype=float_dtype
)

if input_name == "rotary_pos_emb":
dim = self.embed_dim // self.num_heads // 2
return self.random_float_tensor([grid_h * grid_t * grid_w, dim], framework=framework, dtype=float_dtype)


class Qwen2VLConfigBehavior(str, enum.Enum):
LANGUAGE = "language"
VISION_EMBEDDINGS = "vision_embeddings"
VISION_EMBEDDINGS_MERGER = "vision_embeddings_merger"
TEXT_EMBEDDINGS = "text_embeddings"


@register_in_tasks_manager("qwen2-vl", *["image-text-to-text"], library_name="transformers")
class Qwen2VLOpenVINOConfig(OnnxConfig):
SUPPORTED_BEHAVIORS = [model_type.value for model_type in Qwen2VLConfigBehavior]
NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen2VLVisionEmbedInputGenerator,)
MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")

def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
behavior: Qwen2VLConfigBehavior = Qwen2VLConfigBehavior.VISION_EMBEDDINGS,
preprocessors: Optional[List[Any]] = None,
):
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
)
self._behavior = behavior
self._orig_config = config
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
self._config = config.vision_config
self._config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self._normalized_config.use_embed_dim = False
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER and hasattr(config, "vision_config"):
self._config = config.vision_config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
self._normalized_config.use_embed_dim = True

@staticmethod
def get_model_for_behavior(model, behavior: Union[str, Qwen2VLConfigBehavior]):
if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
behavior = Qwen2VLConfigBehavior(behavior)

if behavior == Qwen2VLConfigBehavior.LANGUAGE:
return model

if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
vision_embeddings = model.visual.patch_embed
vision_embeddings.config = model.config.vision_config
return vision_embeddings

if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
vision_emb_merger = model.visual
vision_emb_merger.config = model.config.vision_config
return vision_emb_merger

if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
text_embedding = model.model.embed_tokens
text_embedding.config = model.config
return text_embedding

def with_behavior(
self,
behavior: Union[str, Qwen2VLConfigBehavior],
):
"""
Creates a config for different behaviour.
Args:
behavior ([`ConfigBehavior`]):
The behavior to use for the new instance.
"""
if isinstance(behavior, str) and not isinstance(behavior, Qwen2VLConfigBehavior):
behavior = Qwen2VLConfigBehavior(behavior)

if behavior == Qwen2VLConfigBehavior.TEXT_EMBEDDINGS:
return get_vlm_text_embeddings_config("qwen2", self._orig_config, self.int_dtype, self.float_dtype)

if behavior == Qwen2VLConfigBehavior.LANGUAGE:
return get_vlm_text_generation_config(
"qwen2",
self._orig_config,
self.int_dtype,
self.float_dtype,
model_patcher=Qwen2VLLanguageModelPatcher,
dummy_input_generator=DummyQwen2VLLMInputGenerator,
inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}},
)

if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS:
return self.__class__(
self._orig_config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
behavior=behavior,
preprocessors=self._preprocessors,
)
if behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
return self.__class__(
self._orig_config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
behavior=behavior,
preprocessors=self._preprocessors,
)

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
):
model_kwargs = model_kwargs or {}
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
return Qwen2VLVisionEmbMergerPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior == Phi3VisionConfigBehavior.VISION_EMBEDDINGS:
return {"hidden_states": {0: "patch_thw_grid", 1: "patch_temporal_channels"}}
if self._behavior == Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER:
return {
"hidden_states": {0: "sequence_length"},
"attention_mask": {1: "sequence_length", 2: "sequence_length"},
"rotary_pos_emb": {0: "sequence_length"},
}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior in [Qwen2VLConfigBehavior.VISION_EMBEDDINGS, Qwen2VLConfigBehavior.VISION_EMBEDDINGS_MERGER]:
return {"last_hidden_state": {0: "seq_len"}}
return {}
106 changes: 106 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3378,3 +3378,109 @@ def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
for block in self._model.model.layers:
block.self_attn.forward = block.self_attn._orig_forward


class Qwen2VLLanguageModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any] = None,
):
model.__orig_forward = model.forward

def forward_wrap(
self,
attention_mask,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
input_ids=None,
):
from transformers.cache_utils import DynamicCache

new_past_key_values = DynamicCache.from_legacy_cache(past_key_values)
result = self.__orig_forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=new_past_key_values,
inputs_embeds=inputs_embeds,
)
if past_key_values is not None:
result["past_key_values"] = result["past_key_values"].to_legacy_cache()
return result

model.forward = types.MethodType(forward_wrap, model)
super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward


class Qwen2VLVisionEmbMergerPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Dict[str, Any] = None,
):
model.__orig_forward = model.forward

# Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1118
# added attention_mask input instead cu_lens for its internal calculation model (unsupported by tracing due to cycle with dynamic len)
# separated patch_embed and rot_pos_emb calls for performing as part of another model
def image_embed_forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor
) -> torch.Tensor:
for blk in self.blocks:
hidden_states = blk(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
return self.merger(hidden_states)

model.forward = types.MethodType(image_embed_forward, model)
super().__init__(config, model, model_kwargs)

def __enter__(self):
# Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L390
# added attention_mask input instead of internal calculation (unsupported by tracing due to cycle with dynamic len)
def sdpa_attn_forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor = None
) -> torch.Tensor:
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_rotary_pos_emb_vision

seq_length = hidden_states.shape[0]
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output

# Modified from https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L430
# added attention_mask input propagation to self.attn
def block_forward(self, hidden_states, attention_mask, rotary_pos_emb) -> torch.Tensor:
hidden_states = hidden_states + self.attn(
self.norm1(hidden_states), attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb
)
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
return hidden_states

for block in self._model.blocks:
block._orig_forward = block.forward
block.forward = types.MethodType(block_forward, block)
block.attn._orig_forward = block.attn.forward
block.attn.forward = types.MethodType(sdpa_attn_forward, block.attn)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward
for block in self._model.blocks:
block.forward = block._orig_forward
block.attn.forward = block.attn._orig_forward
Loading

0 comments on commit 93777ec

Please sign in to comment.