Skip to content

Commit

Permalink
Fix switching between the legacy and new input-processing paths for LLaVA
Browse files — browse the repository at this point in the history
  • Loading branch information
eaidova committed Oct 24, 2024
1 parent 86598a6 commit 31ff7f9
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions optimum/intel/openvino/modeling_visual_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,33 @@ def can_generate(self):


class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
def __init__(
    self,
    language_model: ov.Model,
    text_embeddings: ov.Model,
    vision_embeddings: ov.Model,
    config: PretrainedConfig = None,
    device: str = "CPU",
    dynamic_shapes: bool = True,
    ov_config: Optional[Dict[str, str]] = None,
    model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
    quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
    **kwargs,
):
    """Initialize the OpenVINO LLaVA wrapper.

    Forwards all constructor arguments unchanged to the base visual
    causal-LM class, then records whether the loaded config predates the
    new image-token processing scheme.
    """
    base_inputs = dict(
        language_model=language_model,
        text_embeddings=text_embeddings,
        vision_embeddings=vision_embeddings,
        config=config,
        device=device,
        dynamic_shapes=dynamic_shapes,
        ov_config=ov_config,
        model_save_dir=model_save_dir,
        quantization_config=quantization_config,
    )
    # Duplicate keys arriving via **kwargs still raise TypeError, exactly
    # as they would if every argument were passed inline.
    super().__init__(**base_inputs, **kwargs)
    # Configs without ``image_seq_length`` predate the new processing
    # path, so start out on the legacy merging behavior.
    config_has_seq_length = hasattr(self.config, "image_seq_length")
    self._legacy_processing = not config_has_seq_length

def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
Expand Down Expand Up @@ -696,10 +723,8 @@ def merge_vision_text_embeddings(
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
if legacy_processing is None:
legacy_processing = (
not hasattr(self.config, "image_seq_length")
or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
or (input_ids.shape[-1] == 1)
legacy_processing = not (hasattr(self.config, "image_seq_length") and (input_ids.shape[-1] == 1)) or (
(input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
)

if legacy_processing:
Expand Down Expand Up @@ -780,11 +805,15 @@ def merge_vision_text_embeddings(
def get_multimodal_embeddings(
self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, past_key_values=None, **kwargs
):
legacy_processing = (
not hasattr(self.config, "image_seq_length")
or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
or (input_ids.shape[-1] == 1 and pixel_values is not None)
)
legacy_processing = self._legacy_processing
inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)

if pixel_values is not None and not legacy_processing and past_key_values is None:
legacy_processing = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
self._legacy_processing = legacy_processing

inputs_embeds, attention_mask, position_ids = super().get_multimodal_embeddings(
input_ids, pixel_values, attention_mask, position_ids, legacy_processing=legacy_processing, **kwargs
)
Expand Down Expand Up @@ -902,12 +931,14 @@ def get_multimodal_embeddings(
from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches

inputs_embeds = self.get_text_embeddings(input_ids, **kwargs)
legacy_processing = self._legacy_processing

if pixel_values is not None and not legacy_processing and past_key_values is None:
legacy_processing = (input_ids == self.config.image_token_index).sum(
1
).max() < self.config.image_seq_length
self._legacy_processing = legacy_processing

legacy_processing = (
not hasattr(self.config, "image_seq_length")
or ((input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length)
or (input_ids.shape[-1] == 1 and pixel_values is not None)
)
if pixel_values is not None and pixel_values.size(0) > 0:
# ! infer image_num_patches from image_sizes
image_num_patches = [
Expand Down

0 comments on commit 31ff7f9

Please sign in to comment.