
Commit

add nanollava in tests
eaidova committed Oct 24, 2024
1 parent 9a83fe5 commit e545b48
Showing 3 changed files with 49 additions and 23 deletions.
3 changes: 1 addition & 2 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -1433,8 +1433,7 @@ def get_multimodal_embeddings(
vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs)
if vision_embeds is None:
inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
if kwargs.get("past_key_values") is not None:
past_len = self.language_model._get_past_length(kwargs.get("past_key_values"))
past_len = self.language_model._get_past_length(kwargs.get("past_key_values"))
if attention_mask is not None and attention_mask.shape[1] < past_len + input_ids.shape[1]:
attention_mask = torch.cat(
[
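For context on the hunk above, here is a minimal standalone sketch of the attention-mask extension it feeds into; the tensor values and the position of the padding are illustrative assumptions, not taken from the commit:

import torch

# hypothetical state: 4 cached tokens from past_key_values, 1 newly generated token
past_len = 4
input_ids = torch.tensor([[42]])
attention_mask = torch.ones((1, 1), dtype=torch.long)  # currently covers only the new token

# extend the mask so its length equals past_len + new tokens, mirroring the check in the diff
if attention_mask.shape[1] < past_len + input_ids.shape[1]:
    pad = torch.ones(
        (attention_mask.shape[0], past_len + input_ids.shape[1] - attention_mask.shape[1]),
        dtype=attention_mask.dtype,
    )
    attention_mask = torch.cat([attention_mask, pad], dim=-1)  # pad appended here as an assumption

print(attention_mask.shape)  # torch.Size([1, 5])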
68 changes: 47 additions & 21 deletions tests/openvino/test_modeling.py
@@ -1870,7 +1870,7 @@ def test_compare_with_and_without_past_key_values(self):
class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ["llava"]

REMOTE_CODE_MODELS = ["minicpmv"]
REMOTE_CODE_MODELS = ["minicpmv", "nanollava"]

if is_transformers_version(">=", "4.40.0"):
SUPPORTED_ARCHITECTURES += ["llava_next"]
@@ -1894,22 +1894,47 @@ def get_transformer_model_class(self, model_arch):
from transformers import LlavaNextForConditionalGeneration

return LlavaNextForConditionalGeneration
if model_arch == "minicpmv":
return AutoModelForCausalLM
return None
return AutoModelForCausalLM

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
def gen_inputs(self, model_arch, base_text_prompt, image=None):
model_id = MODEL_NAMES[model_arch]
if "llava" in model_arch:
prompt = "<image>\n What is shown in this image?"
prompt = f"<image>\n {base_text_prompt}"
elif "minicpmv" in model_arch:
prompt = "<|im_start|>user\n(<image>./</image>)\n What is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
prompt = "<|im_start|>user\n(<image>./</image>)\n {base_text_prompt}<|im_end|>\n<|im_start|>assistant\n"
if model_arch != "nanollava":
processor = AutoProcessor.from_pretrained(
model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
inputs = processor(images=[self.IMAGE.resize((600, 600))], text=[prompt], return_tensors="pt")
else:
config = AutoConfig.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
processor = AutoProcessor.from_pretrained(
config.mm_vision_tower, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
tokenizer = AutoTokenizer.from_pretrained(
model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
image_input = None
if image is not None:
image_input = processor(images=image, return_tensors="pt")["pixel_values"]
text_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

input_ids = torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long).unsqueeze(0)
attention_mask = torch.ones_like(input_ids, dtype=torch.int64)
inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "images": image_input}
return inputs

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch):
model_id = MODEL_NAMES[model_arch]
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
transformers_model = self.get_transformer_model_class(model_arch).from_pretrained(
model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
inputs = processor(images=[self.IMAGE.resize((600, 600))], text=[prompt], return_tensors="pt")
if "nanollava" in model_arch:
transformers_model.get_vision_tower().load_model()
inputs = self.gen_inputs(model_arch, "What is shown on this image?", self.IMAGE)

ov_model = OVModelForVisualCausalLM.from_pretrained(
model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
@@ -1920,6 +1945,7 @@ def test_compare_to_transformers(self, model_arch):
self.assertTrue(hasattr(ov_model, additional_part))
self.assertIsInstance(getattr(ov_model, additional_part), MODEL_PARTS_CLS_MAPPING[additional_part])
self.assertIsInstance(ov_model.config, PretrainedConfig)
# minicpmv is not designed to be used via forward
if "minicpmv" not in model_arch:
set_seed(SEED)
with torch.no_grad():
@@ -1941,6 +1967,7 @@ def test_compare_to_transformers(self, model_arch):
ov_outputs = ov_model.generate(**inputs, generation_config=gen_config)
set_seed(SEED)
transformers_outputs = transformers_model.generate(**inputs, generation_config=gen_config)
# the original minicpmv implementation always skips input tokens in its generation results, while the transformers-based approach keeps them
if model_arch == "minicpmv":
ov_outputs = ov_outputs[:, inputs["input_ids"].shape[1] :]
self.assertTrue(
@@ -1959,23 +1986,22 @@ def test_generate_utils(self, model_arch):
model = OVModelForVisualCausalLM.from_pretrained(
model_id, export=True, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
)
preprocessor = AutoProcessor.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
if "llava" in model_arch:
question = "<image>\nDescribe image"
elif "minicpmv" in model_arch:
question = "(<image>./</image>)\n What is shown in this image?"

inputs = preprocessor(images=[self.IMAGE.resize((600, 600))], text=[question], return_tensors="pt")

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
inputs = self.gen_inputs(model_arch, "What is shown on this image?", self.IMAGE)
# General case
outputs = model.generate(**inputs, max_new_tokens=10)
outputs = preprocessor.batch_decode(outputs, skip_special_tokens=True)
# filter out the original prompt because it may contain tokens outside the tokenizer vocabulary, e.g. the nanollava image separator token -200
outputs = outputs[:, inputs["input_ids"].shape[1] :]
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertIsInstance(outputs[0], str)

# No input image case
question = "Hi, how are you?"
inputs = preprocessor(images=None, text=question, return_tensors="pt")
inputs = self.gen_inputs(model_arch, question, None)
outputs = model.generate(**inputs, max_new_tokens=10)
outputs = preprocessor.batch_decode(outputs, skip_special_tokens=True)
# filter out the original prompt because it may contain tokens outside the tokenizer vocabulary, e.g. the nanollava image separator token -200
outputs = outputs[:, inputs["input_ids"].shape[1] :]
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
self.assertIsInstance(outputs[0], str)
del model

1 change: 1 addition & 0 deletions tests/openvino/utils_tests.py
@@ -94,6 +94,7 @@
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
"mpnet": "hf-internal-testing/tiny-random-MPNetModel",
"mt5": "stas/mt5-tiny-random",
"nanollava": "katuni4ka/tiny-random-nanollava",
"nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
"olmo": "katuni4ka/tiny-random-olmo-hf",
"orion": "katuni4ka/tiny-random-orion",

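As a quick usage sketch (not part of the commit), the new tiny nanollava checkpoint registered in utils_tests.py can be exported and loaded the same way the tests do; the import path is assumed from the package layout shown above:

from optimum.intel import OVModelForVisualCausalLM

# export the tiny test checkpoint to OpenVINO on the fly, as the tests do with export=True
model = OVModelForVisualCausalLM.from_pretrained(
    "katuni4ka/tiny-random-nanollava", export=True, trust_remote_code=True
)
print(type(model).__name__)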