diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py
index a48cdf5c92..0b88a8dfbf 100644
--- a/optimum/intel/openvino/modeling_base.py
+++ b/optimum/intel/openvino/modeling_base.py
@@ -87,6 +87,10 @@ def __init__(
         self.model = model
         self.request = None
+        self.compiled_model = None
+        if enable_compilation:
+            self.compile()
+
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
 
         self._openvino_config = None
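The base class now keeps the compiled model on the instance so that callers can create additional infer requests from it (see the decoder changes below). A minimal sketch of the underlying OpenVINO pattern, using plain `openvino` APIs and a placeholder IR path:

```python
import openvino as ov

core = ov.Core()
model = core.read_model("model.xml")  # placeholder path to an exported OpenVINO IR
compiled_model = core.compile_model(model, "CPU")  # the compiled model can be shared across threads

# One infer request per concurrent caller; requests hold the mutable state and are not shared.
request_a = compiled_model.create_infer_request()
request_b = compiled_model.create_infer_request()
```
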
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 39a7bee9a2..942c09f68e 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -14,9 +14,10 @@
 
 import logging
 import os
+from dataclasses import dataclass
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import openvino
@@ -26,7 +27,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
 from transformers.generation import GenerationMixin
-from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils import ModelOutput
 
 from optimum.utils.normalized_config import NormalizedConfigManager
@@ -44,6 +45,25 @@
 core = Core()
 
 
+@dataclass
+class OVCausalLMOutputWithPast(ModelOutput):
+    """
+    Base class for causal language model (or autoregressive) outputs.
+
+    Args:
+        infer_request (`openvino.runtime.InferRequest`, *optional*):
+            Inference request to be reused in the following generation cycles.
+        past_length (`int`, *optional*):
+            Number of tokens already processed by this inference request, used when the model is stateful.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    logits: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
+    infer_request: Optional[openvino.runtime.InferRequest] = None
+    past_length: Optional[int] = None
+
+
 TEXT_GENERATION_EXAMPLE = r"""
     Example of text generation:
     ```python
@@ -118,8 +138,7 @@ def __init__(
         self.key_value_output_names = [key for key in self.output_names if "present" in key]
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
-        self.next_beam_idx = None
-        self._past_length = 0
+        self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -197,6 +216,7 @@ def update_pkv_precision(self, force_fp32=False):
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
         self.request = None
+        self.compiled_model = None
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
         """
@@ -322,6 +342,7 @@ def normalized_config(self):
     def compile(self):
         if self.request is None:
             super().compile()
+            self.compiled_model = self.request
             self.request = self.request.create_infer_request()
 
     def _make_stateful(self):
@@ -354,12 +375,12 @@ def prepare_inputs(
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         position_ids: Optional[torch.LongTensor] = None,
+        past_length: Optional[int] = 0,
         **kwargs,
     ) -> Dict:
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads
-
         inputs = {}
         if not self.stateful:
             if past_key_values is not None:
@@ -395,17 +416,6 @@ def prepare_inputs(
                         else:
                             shape[1] = 0
                     inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
-        else:
-            # past_key_values are not used explicitly, instead they are handled inside the model
-            if past_key_values is None:
-                # This is the first iteration in a sequence, reset all states
-                if self.request is not None:
-                    self.request.reset_state()
-                # Set initial value for the next beam_idx input that will be used at the current iteration
-                # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
-                self.next_beam_idx = np.arange(batch_size, dtype=int)
-                self._past_length = 0
-        past_len = self._get_past_length(past_key_values)
 
         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -414,7 +424,7 @@ def prepare_inputs(
                 attention_mask = np.array(attention_mask)
             else:
                 attention_mask = np.ones(
-                    (input_ids.shape[0], input_ids.shape[1] + past_len), dtype=inputs["input_ids"].dtype
+                    (input_ids.shape[0], input_ids.shape[1] + past_length), dtype=inputs["input_ids"].dtype
                 )
 
         if "attention_mask" in self.input_names:
@@ -432,9 +442,11 @@ def prepare_inputs(
             inputs["position_ids"] = position_ids
 
         if "beam_idx" in self.input_names:
-            inputs["beam_idx"] = (
-                self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
-            )
+            if past_key_values is not None:
+                if len(past_key_values[0]) > 0:
+                    inputs["beam_idx"] = past_key_values[0]
+                    return inputs
+            inputs["beam_idx"] = np.arange(batch_size, dtype=int)
 
         return inputs
@@ -444,33 +456,39 @@ def forward(
         attention_mask: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         position_ids: Optional[torch.LongTensor] = None,
+        infer_request: Optional[openvino.runtime.InferRequest] = None,
+        past_length: Optional[int] = 0,
         **kwargs,
-    ) -> CausalLMOutputWithPast:
+    ) -> OVCausalLMOutputWithPast:
         self.compile()
-
         inputs = self.prepare_inputs(
             input_ids=input_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             position_ids=position_ids,
+            past_length=past_length,
             **kwargs,
         )
 
         # Run inference
-        self.request.start_async(inputs, share_inputs=True)
-        self.request.wait()
-        logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
+        if infer_request is None:
+            self.compile()
+            infer_request = self.compiled_model.create_infer_request()
+
+        infer_request.start_async(inputs, share_inputs=True)
+        infer_request.wait()
+        logits = torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device)
 
         if self.stateful:
             # Need a marker to differentiate the first generate iteration from the others in
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
-            self._past_length += input_ids.shape[1]
+            past_length += input_ids.shape[1]
 
         if not self.stateful:
             if self.use_cache:
                 # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
-                past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
+                past_key_values = tuple(infer_request.get_tensor(key).data for key in self.key_value_output_names)
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
                     # Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
                     past_key_values = tuple(
@@ -479,7 +497,28 @@ def forward(
             else:
                 past_key_values = None
 
-        return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
+        return OVCausalLMOutputWithPast(
+            logits=logits, past_key_values=past_key_values, infer_request=infer_request, past_length=past_length
+        )
+
+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: OVCausalLMOutputWithPast,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        model_kwargs = super()._update_model_kwargs_for_generation(
+            outputs=outputs,
+            model_kwargs=model_kwargs,
+            is_encoder_decoder=is_encoder_decoder,
+            standardize_cache_format=standardize_cache_format,
+        )
+        if "infer_request" in outputs:
+            model_kwargs["infer_request"] = outputs["infer_request"]
+        if "past_length" in outputs:
+            model_kwargs["past_length"] = outputs["past_length"]
+        return model_kwargs
 
     # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
@@ -487,19 +526,23 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)
 
+        infer_request = kwargs.get("infer_request", None)
+        past_length = kwargs.get("past_length", 0)
+
         if past_key_values is not None:
-            past_len = self._get_past_length(past_key_values)
+            past_length = self._get_past_length(past_key_values, past_length=past_length)
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
-            elif past_len < input_ids.shape[1]:
-                input_ids = input_ids[:, past_len:]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
@@ -512,15 +555,17 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "input_ids": input_ids,
             "past_key_values": past_key_values,
             "use_cache": use_cache,
+            "infer_request": infer_request,
             "position_ids": position_ids,
             "attention_mask": attention_mask,
+            "past_length": past_length,
         }
 
-    def _get_past_length(self, past_key_values=None):
+    def _get_past_length(self, past_key_values=None, past_length=0):
         if past_key_values is None:
             return 0
         if self.stateful:
-            return self._past_length
+            return past_length
         if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
             return past_key_values[0].shape[-2]
         seq_length_dim = -2
@@ -546,8 +591,10 @@ def _reorder_cache(
         if self.stateful:
             # TODO: Apply it differently based on model type
             # TODO: At least for bloom we need to replicate values for each attention head
-            self.next_beam_idx = np.array(beam_idx)  # save beam_idx to be used as an input in the next iteration
-            return past_key_values
+            # Pass beam_idx to the next iteration through past_key_values (the reused
+            # infer_request is forwarded separately by _update_model_kwargs_for_generation).
+
+            return ((beam_idx),)
         else:
             return tuple(
                 tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past) for layer_past in past_key_values
@@ -675,8 +722,8 @@ def _reorder_cache(
             batch_size = beam_idx.shape[0]
             indices = np.array(range(batch_size * self.config.num_attention_heads))
             indices = indices.reshape([batch_size, self.config.num_attention_heads])
-            self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
-            return past_key_values
+            next_beam_idx = np.take(indices, beam_idx, 0).flatten()
+            return ((next_beam_idx),)
         else:
             standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
             reordered_past = tuple(
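With these decoder changes, `forward` creates (or reuses) a per-call `InferRequest` from `self.compiled_model`, and `infer_request`/`past_length` travel through the generation kwargs instead of being stored on the model instance. A minimal sketch of the concurrent-generation pattern this is meant to enable (the model id is a placeholder; the prompts mirror the tests below):

```python
import threading

from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "hf-internal-testing/tiny-random-gpt2"  # placeholder model id
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tokenize in the main thread; the tokenizer is not meant to be shared between threads.
prompts = ["This is a cat", "This is a dog", "Yet another test"]
tokens_list = [tokenizer(prompt, return_tensors="pt") for prompt in prompts]
results = [None] * len(tokens_list)

def generate(idx, tokens):
    # Each generate() call obtains its own InferRequest through the `infer_request` model kwarg,
    # so the threads do not share mutable per-model state.
    results[idx] = model.generate(**tokens, max_new_tokens=10)

threads = [threading.Thread(target=generate, args=(i, tok)) for i, tok in enumerate(tokens_list)]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

print(tokenizer.batch_decode([result[0] for result in results], skip_special_tokens=True))
```
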
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index f84cac8161..c18e13a62b 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -55,7 +55,7 @@
 )
 from transformers.onnx.utils import get_preprocessor
 from transformers.testing_utils import slow
-from utils_tests import MODEL_NAMES
+from utils_tests import MODEL_NAMES, run_on_multiple_threads
 
 from optimum.intel import (
     OVModelForAudioClassification,
@@ -618,6 +618,61 @@ def test_compare_to_transformers(self, model_arch):
         del ov_model
         gc.collect()
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_compare_to_transformers_multithreading(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        not_stateful = ["gpt_bigcode"]
+        if is_openvino_version("<", "2024.0"):
+            not_stateful.append("mixtral")
+
+        if is_openvino_version("<", "2024.1"):
+            not_stateful.extend(["llama", "gemma"])
+
+        if "gptq" in model_arch:
+            self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")
+
+        set_seed(SEED)
+        model_kwargs = {}
+        if model_arch in self.REMOTE_CODE_MODELS:
+            model_kwargs = {"trust_remote_code": True}
+
+        ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG, **model_kwargs)
+        self.assertIsInstance(ov_model.config, PretrainedConfig)
+        self.assertTrue(ov_model.use_cache)
+        self.assertEqual(ov_model.stateful, ov_model.config.model_type not in not_stateful)
+
+        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
+        if model_arch == "qwen":
+            transformers_model.to(torch.float32)
+        inputs_list = ["This is a cat", "This is a dog", "Yet another test"]
+        tokens_list = [
+            tokenizer(inputs, return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None)
+            for inputs in inputs_list
+        ]
+
+        def run_ov_model(tokens, transformers_model, ov_model):
+            set_seed(SEED)
+            ov_outputs = ov_model(**tokens)
+            self.assertTrue("logits" in ov_outputs)
+            self.assertIsInstance(ov_outputs.logits, torch.Tensor)
+            self.assertTrue("past_key_values" in ov_outputs)
+            self.assertIsInstance(ov_outputs.past_key_values, tuple)
+            is_stateful = ov_model.config.model_type not in not_stateful
+            self.assertEqual(ov_model.stateful, is_stateful)
+            if is_stateful:
+                self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)
+            with torch.no_grad():
+                transformers_outputs = transformers_model(**tokens)
+            # Compare tensor outputs
+            self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4))
+
+        run_on_multiple_threads(run_ov_model, tokens_list, (transformers_model, ov_model))
+
+        del transformers_model
+        del ov_model
+        gc.collect()
+
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @pytest.mark.run_slow
     @slow
@@ -650,6 +705,95 @@ def test_pipeline(self, model_arch):
         del model
         gc.collect()
 
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_pipeline_multithreading(self, model_arch):
+        model_kwargs = {}
+        model_id = MODEL_NAMES[model_arch]
+        if model_arch in self.REMOTE_CODE_MODELS:
+            model_kwargs = {
+                "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
+                "trust_remote_code": True,
+            }
+
+        model = OVModelForCausalLM.from_pretrained(
+            model_id, export=True, use_cache=False, compile=False, **model_kwargs
+        )
+        model.eval()
+        model.config.encoder_no_repeat_ngram_size = 0
+        model.to("cpu")
+        model.half()
+        model.compile()
+
+        def run_ov_model(input_text, model):
+            # Tokenizer is not supposed to be shared by multiple threads
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS
+            )
+            pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
+            outputs = pipe(input_text, max_length=30)
+            self.assertEqual(pipe.device, model.device)
+            for i in range(len(outputs)):
+                self.assertTrue(all(input_text[i] in item["generated_text"] for item in outputs[i]))
+            del pipe
+
+        inputs_list = [["This is a sample"], ["This is a second sample"], ["This is a third sample"]]
+        run_on_multiple_threads(run_ov_model, inputs_list, [model])
+        del model
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_multiple_inputs(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        if model_arch == "qwen":
+            self.skipTest("Qwen tokenizer does not support padding")
+        model_kwargs = {}
+        if model_arch in self.REMOTE_CODE_MODELS:
+            model_kwargs = {
+                "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
+                "trust_remote_code": True,
+            }
+        model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False, **model_kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
+        tokenizer.pad_token = tokenizer.eos_token
+        texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"]
+        tokens = tokenizer(texts, padding=True, return_tensors="pt")
+        generation_config = GenerationConfig(encoder_no_repeat_ngram_size=0, max_new_tokens=20, num_beams=2)
+        outputs = model.generate(**tokens, generation_config=generation_config)
+        self.assertIsInstance(outputs, torch.Tensor)
+        self.assertEqual(outputs.shape[0], 3)
+        del model
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_multiple_inputs_multithreading(self, model_arch):
+        model_id = MODEL_NAMES[model_arch]
+        set_seed(SEED)
+        if model_arch == "qwen":
+            self.skipTest("Qwen tokenizer does not support padding")
+        model_kwargs = {}
+        if model_arch in self.REMOTE_CODE_MODELS:
+            model_kwargs = {
+                "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True),
+                "trust_remote_code": True,
+            }
+        model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False, **model_kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS)
+        tokenizer.pad_token = tokenizer.eos_token
+        texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"]
+        tokens = tokenizer(texts, padding=True, return_tensors="pt")
+        generation_config = GenerationConfig(encoder_no_repeat_ngram_size=0, max_new_tokens=20, num_beams=2)
+
+        def run_ov_model(tokens, model):
+            outputs = model.generate(**tokens, generation_config=generation_config)
+            self.assertIsInstance(outputs, torch.Tensor)
+            self.assertEqual(outputs.shape[0], 3)
+
+        tokens_list = [tokens, tokens, tokens, tokens]  # running in 4 threads
+        run_on_multiple_threads(run_ov_model, tokens_list, [model])
+        del model
+        gc.collect()
+
     def test_model_and_decoder_same_device(self):
         model_id = MODEL_NAMES["gpt2"]
         model = OVModelForCausalLM.from_pretrained(model_id, export=True)
@@ -751,6 +895,7 @@ def test_default_filling_attention_mask_and_position_ids(self):
         )
         outs_without_attn_mask_step2 = model_with_cache(input_ids=input_ids, past_key_values=past_key_values)
         self.assertTrue(torch.allclose(outs_step2.logits, outs_without_attn_mask_step2.logits))
+        del model_with_cache
         gc.collect()
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index ca56f6d552..3d98575f68 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import threading
+
 import numpy as np
 import torch
@@ -152,3 +154,37 @@ def get_num_quantized_nodes(ov_model):
             if elem.get_output_element_type(i).get_type_name() in ["i4", "u4"]:
                 num_int4 += 1
     return num_fake_quantize, num_int8, num_int4
+
+
+### Multithreading
+
+
+class OVThread(threading.Thread):
+    def __init__(self, target, args):
+        super().__init__()
+        self.target = target
+        self.args = args
+
+    def run(self):
+        self.exception = None
+        try:
+            self.target(*self.args)
+        except Exception as e:
+            self.exception = e
+
+    def join(self):
+        super().join()
+        if self.exception:
+            raise self.exception
+
+
+# Each set of args is run in a separate thread.
+# The number of such sets determines how many threads are spawned.
+def run_on_multiple_threads(target, inputs_list, extra_args):
+    threads = []
+    for inputs in inputs_list:
+        threads.append(OVThread(target=target, args=(inputs, *extra_args)))
+    for thread in threads:
+        thread.start()
+    for thread in threads:
+        thread.join()
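A quick calling-convention sketch for this helper (the worker below is a hypothetical stand-in for the bound test methods used in `test_modeling.py`): each element of `inputs_list` becomes the first argument of one thread's call, `extra_args` is appended to every call, and `join()` re-raises any exception captured inside a worker thread.

```python
def worker(value, offset):
    # Hypothetical worker: `value` comes from inputs_list, `offset` from extra_args.
    assert value + offset > 0

run_on_multiple_threads(worker, [1, 2, 3], (10,))  # spawns three threads, one per input
```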