
Commit 1ee4535: add test

eaidova committed Apr 30, 2024 (1 parent: 878921b)

Showing 2 changed files with 95 additions and 14 deletions.
optimum/intel/openvino/modeling_decoder.py (28 changes: 14 additions & 14 deletions)
@@ -1407,22 +1407,22 @@ def _group_beam_search(
                 beam_next_tokens = beam_outputs["next_beam_tokens"]
                 beam_idx = beam_outputs["next_beam_indices"]
 
-            if return_dict_in_generate and output_scores:
-                beam_indices[beam_group_idx] = tuple(
-                    beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
-                )
+                if return_dict_in_generate and output_scores:
+                    beam_indices[beam_group_idx] = tuple(
+                        beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
+                    )
 
-            input_ids[batch_group_indices] = group_input_ids[beam_idx]
-            group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
-            current_tokens[batch_group_indices] = group_input_ids[:, -1]
+                input_ids[batch_group_indices] = group_input_ids[beam_idx]
+                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
+                current_tokens[batch_group_indices] = group_input_ids[:, -1]
 
-            # (beam_idx // group_size) -> batch_idx
-            # (beam_idx % group_size) -> offset of idx inside the group
-            reordering_indices[batch_group_indices] = (
-                num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
-                + group_start_idx
-                + (beam_idx % group_size)
-            )
+                # (beam_idx // group_size) -> batch_idx
+                # (beam_idx % group_size) -> offset of idx inside the group
+                reordering_indices[batch_group_indices] = (
+                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
+                    + group_start_idx
+                    + (beam_idx % group_size)
+                )
 
             # Store scores, attentions and hidden_states when required
             if return_dict_in_generate:
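
The index arithmetic in this hunk converts group-local beam choices into rows of the full num_beams-wide KV cache, so a stateful model can reorder its cache in a single gather. The following is a small self-contained sketch of just that mapping; the sizes and the beam_idx values are made up for illustration, assuming beam_idx indexes the flattened rows of the current diversity group as in the code above:

import torch

# Illustrative setup: batch of 1, num_beams=4 split into 2 diversity groups of 2.
num_beams, group_size = 4, 2
group_start_idx = 2              # the second group occupies beams 2..3
beam_idx = torch.tensor([1, 0])  # scorer-chosen rows within the flattened group

# (beam_idx // group_size) recovers the batch index,
# (beam_idx % group_size) the offset inside the group.
reordering = (
    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
    + group_start_idx
    + (beam_idx % group_size)
)
print(reordering)  # tensor([3, 2]): cache rows to gather for this group

With a batch of one sequence the division term is zero, so the result is just group_start_idx plus the in-group offset: rows 2 and 3 of the four-beam cache.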
tests/openvino/test_modeling.py (81 changes: 81 additions & 0 deletions)

@@ -757,6 +757,87 @@ def test_default_filling_attention_mask_and_position_ids(self):
         del model_with_cache
         gc.collect()
 
+    def test_beam_search(self):
+        model_id = MODEL_NAMES["llama"]
+        ov_model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
+        ov_model_stateless = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer.pad_token = tokenizer.eos_token
+        tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)
+        # Disable EOS everywhere so all models generate exactly min_new_tokens tokens.
+        ov_model_stateful.generation_config.eos_token_id = None
+        ov_model_stateless.generation_config.eos_token_id = None
+        transformers_model.generation_config.eos_token_id = None
+        ov_model_stateful.config.eos_token_id = None
+        ov_model_stateless.config.eos_token_id = None
+        transformers_model.config.eos_token_id = None
+
+        # beam search
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+        # beam sample (top_k=1 makes sampling deterministic, keeping outputs comparable)
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=True,
+            eos_token_id=None,
+            top_k=1,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
+        # group beam search
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            num_beam_groups=2,
+            diversity_penalty=0.0000001,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
+        # constrained beam search
+        force_word = "cat"
+        force_words_ids = [tokenizer([force_word], add_special_tokens=False).input_ids]
+        gen_config = GenerationConfig(
+            max_new_tokens=10,
+            min_new_tokens=10,
+            num_beams=4,
+            do_sample=False,
+            eos_token_id=None,
+            force_words_ids=force_words_ids,
+        )
+
+        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config)
+        ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateful_outputs, transformers_outputs))
+        ov_stateless_outputs = ov_model_stateless.generate(**tokens, generation_config=gen_config)
+        self.assertTrue(torch.allclose(ov_stateless_outputs, transformers_outputs))
+
 
 class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
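
The new test compares stateful and stateless OpenVINO models against the reference transformers implementation across four beam-decoding modes. For reproducing the comparison outside the test suite, here is a minimal sketch of the group beam search case, the mode exercised by the indentation fix above; the checkpoint name is a stand-in, since MODEL_NAMES["llama"] in the test resolves to a small Llama-style model:

import torch
from optimum.intel import OVModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # stand-in tiny checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokens = tokenizer(["Today is a nice day and I am longer", "This is me"], return_tensors="pt", padding=True)

ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)

# Group beam search with a near-zero diversity penalty, as in the test.
gen_config = GenerationConfig(
    max_new_tokens=10,
    min_new_tokens=10,
    num_beams=4,
    do_sample=False,
    eos_token_id=None,
    num_beam_groups=2,
    diversity_penalty=0.0000001,
)

ov_out = ov_model.generate(**tokens, generation_config=gen_config)
ref_out = ref_model.generate(**tokens, generation_config=gen_config)
assert torch.equal(ov_out, ref_out)  # token ids should match exactly

As in the test, eos_token_id=None together with min_new_tokens pins the generated length, so any mismatch points at the beam bookkeeping rather than at early stopping.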
