diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py
index 4d859be5ea..e7f10fb7dd 100644
--- a/optimum/intel/openvino/modeling.py
+++ b/optimum/intel/openvino/modeling.py
@@ -129,8 +129,8 @@ def to(self, device: str):
         Use the specified `device` for inference. For example: "cpu" or "gpu". `device` can be in upper or lower case.
         To speed up first inference, call `.compile()` after `.to()`.
         """
+        self.compiled_model = None
         self._device = str(device).upper()
-        self.request = None
         return self
 
     def forward(self, *args, **kwargs):
@@ -197,8 +197,14 @@ def forward(
             inputs["token_type_ids"] = token_type_ids
 
         # Run inference
-        outputs = self.request(inputs)
-        logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs)
+        infer_request.wait()
+        logits = (
+            torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device)
+            if not np_inputs
+            else infer_request.get_tensor("logits").data
+        )
 
         return SequenceClassifierOutput(logits=logits)
 
@@ -263,12 +269,18 @@ def forward(
             inputs["token_type_ids"] = token_type_ids
 
         # Run inference
-        outputs = self.request(inputs)
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs)
+        infer_request.wait()
         start_logits = (
-            torch.from_numpy(outputs["start_logits"]).to(self.device) if not np_inputs else outputs["start_logits"]
+            torch.from_numpy(infer_request.get_tensor("start_logits").data).to(self.device)
+            if not np_inputs
+            else infer_request.get_tensor("start_logits").data
         )
         end_logits = (
-            torch.from_numpy(outputs["end_logits"]).to(self.device) if not np_inputs else outputs["end_logits"]
+            torch.from_numpy(infer_request.get_tensor("end_logits").data).to(self.device)
+            if not np_inputs
+            else infer_request.get_tensor("end_logits").data
         )
 
         return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -333,8 +345,14 @@ def forward(
             inputs["token_type_ids"] = token_type_ids
 
         # Run inference
-        outputs = self.request(inputs)
-        logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs)
+        infer_request.wait()
+        logits = (
+            torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device)
+            if not np_inputs
+            else infer_request.get_tensor("logits").data
+        )
 
         return TokenClassifierOutput(logits=logits)
 
@@ -398,11 +416,13 @@ def forward(
             inputs["token_type_ids"] = token_type_ids
 
         # Run inference
-        outputs = self.request(inputs)
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs)
+        infer_request.wait()
         last_hidden_state = (
-            torch.from_numpy(outputs["last_hidden_state"]).to(self.device)
+            torch.from_numpy(infer_request.get_tensor("last_hidden_state").data).to(self.device)
             if not np_inputs
-            else outputs["last_hidden_state"]
+            else infer_request.get_tensor("last_hidden_state").data
         )
 
         return BaseModelOutput(last_hidden_state=last_hidden_state)
@@ -468,8 +488,14 @@ def forward(
             inputs["token_type_ids"] = token_type_ids
 
         # Run inference
-        outputs = self.request(inputs)
-        logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"]
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs)
+        infer_request.wait()
+        logits = (
torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device) + if not np_inputs + else infer_request.get_tensor("logits").data + ) return MaskedLMOutput(logits=logits) @@ -595,8 +621,14 @@ def forward( } # Run inference - outputs = self.request(inputs) - logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"] + infer_request = self.compiled_model.create_infer_request() + infer_request.start_async(inputs) + infer_request.wait() + logits = ( + torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device) + if not np_inputs + else infer_request.get_tensor("logits").data + ) return ImageClassifierOutput(logits=logits) @@ -660,8 +692,14 @@ def forward( inputs["attention_mask"] = attention_mask # Run inference - outputs = self.request(inputs) - logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"] + infer_request = self.compiled_model.create_infer_request() + infer_request.start_async(inputs) + infer_request.wait() + logits = ( + torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device) + if not np_inputs + else infer_request.get_tensor("logits").data + ) return SequenceClassifierOutput(logits=logits) @@ -732,8 +770,14 @@ def forward( inputs["attention_mask"] = attention_mask # Run inference - outputs = self.request(inputs) - logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"] + infer_request = self.compiled_model.create_infer_request() + infer_request.start_async(inputs) + infer_request.wait() + logits = ( + torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device) + if not np_inputs + else infer_request.get_tensor("logits").data + ) return CausalLMOutput(logits=logits) @@ -813,12 +857,19 @@ def forward( inputs["attention_mask"] = attention_mask # Run inference - outputs = self.request(inputs) - logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"] + infer_request = self.compiled_model.create_infer_request() + infer_request.start_async(inputs) + infer_request.wait() + logits = ( + torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device) + if not np_inputs + else infer_request.get_tensor("logits").data + ) embeddings = ( - torch.from_numpy(outputs["embeddings"]).to(self.device) if not np_inputs else outputs["embeddings"] + torch.from_numpy(infer_request.get_tensor("embeddings").data).to(self.device) + if not np_inputs + else infer_request.get_tensor("embeddings").data ) - return XVectorOutput(logits=logits, embeddings=embeddings) @@ -890,7 +941,13 @@ def forward( inputs["attention_mask"] = attention_mask # Run inference - outputs = self.request(inputs) - logits = torch.from_numpy(outputs["logits"]).to(self.device) if not np_inputs else outputs["logits"] + infer_request = self.compiled_model.create_infer_request() + infer_request.start_async(inputs) + infer_request.wait() + logits = ( + torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device) + if not np_inputs + else infer_request.get_tensor("logits").data + ) return TokenClassifierOutput(logits=logits) diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py index 05dc3af9b5..deff686e38 100644 --- a/optimum/intel/openvino/modeling_base.py +++ b/optimum/intel/openvino/modeling_base.py @@ -89,6 +89,7 @@ def __init__( self.model = model self.request = None + self.compiled_model = None if enable_compilation: self.compile() @@ -343,7 
         )
 
     def compile(self):
-        if self.request is None:
+        if self.compiled_model is None:
             logger.info(f"Compiling the model to {self._device} ...")
             ov_config = {**self.ov_config}
             if "CACHE_DIR" not in self.ov_config.keys() and not str(self.model_save_dir).startswith(gettempdir()):
@@ -351,7 +352,7 @@ def compile(self):
                 cache_dir = Path(self.model_save_dir).joinpath("model_cache")
                 ov_config["CACHE_DIR"] = str(cache_dir)
                 logger.info(f"Setting OpenVINO CACHE_DIR to {str(cache_dir)}")
-            self.request = core.compile_model(self.model, self._device, ov_config)
+            self.compiled_model = core.compile_model(self.model, self._device, ov_config)
             # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html
             if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2:
                 logger.info(f"{self._device} SUPPORTED_PROPERTIES:")
@@ -403,6 +404,7 @@ def half(self):
         apply_moc_transformations(self.model, cf=False)
         compress_model_transformation(self.model)
         self.request = None
+        self.compiled_model = None
         return self
 
     def forward(self, *args, **kwargs):
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 8a2167eae4..1bd57042c8 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -111,7 +111,6 @@ def __init__(
                 "`dynamic_shapes` was set to `False` but static shapes are not supported for causal language model. Please set `dynamic_shapes=True`."
             )
 
-        enable_compilation = kwargs.get("compile", True)
         kwargs["compile"] = False  # avoid extra compilation in the base class
 
         super().__init__(
@@ -135,21 +134,13 @@ def __init__(
         self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
         self.key_value_input_names = [key for key in self.input_names if "key_values" in key]
         self.key_value_output_names = [key for key in self.output_names if "present" in key]
-        self._original_model = self.model.clone()  # keep original model for serialization
-        self._pkv_precision = Type.f32
-        self.next_beam_idx = None
-        self.update_pkv_precision()
-        if self.is_dynamic:
-            self.model = self._reshape(self.model, -1, -1)
         is_stateful_supported = ensure_stateful_is_available(warn=False)
-
         if self.use_cache and not self.stateful:
             logger.warn(
                 "Provided model does not contain state. It may lead to sub-optimal performance."
"Please reexport model with updated OpenVINO version >= 2023.3.0 calling the `from_pretrained` method with original model " "and `export=True` parameter" ) - if self.stateful: if stateful is None: stateful = is_stateful_supported @@ -176,7 +167,13 @@ def raise_error(model_prop, user_prop, name): if use_cache ^ self.use_cache: raise_error(self.use_cache, use_cache, "use_cache") - if enable_compilation: + def init_ov_model(self, compile=True): + self._pkv_precision = Type.f32 + self.update_pkv_precision(force_fp32=False) + if self.is_dynamic: + self.model = self._reshape(self.model, -1, -1) + self._original_model = self.model.clone() # keep original model for serialization + if compile: self.compile() def update_pkv_precision(self, force_fp32=False): @@ -282,9 +279,10 @@ def _from_transformers( config.is_decoder = True config.is_encoder_decoder = False config.save_pretrained(save_dir_path) - return cls._from_pretrained( + model_instance = cls._from_pretrained( model_id=save_dir_path, config=config, use_cache=use_cache, load_in_8bit=False, stateful=None, **kwargs ) + return model_instance def _reshape( self, @@ -322,14 +320,19 @@ def reshape(self, batch_size: int, sequence_length: int): return self def compile(self): - if self.request is None: + if self.compiled_model is None: super().compile() - self.request = self.request.create_infer_request() def _make_stateful(self): patch_stateful(self.config, self.model) self.stateful = True + def create_infer_request(self): + if self.compiled_model is None: + self.compile() + if self.request is None: + self.request = self.compiled_model.create_infer_request() + @add_start_docstrings( """ @@ -359,6 +362,8 @@ def forward( **kwargs, ) -> CausalLMOutputWithPast: self.compile() + self.create_infer_request() + if self.use_cache and past_key_values is not None: input_ids = input_ids[:, -1:] @@ -556,7 +561,17 @@ def _from_pretrained( else: init_cls = cls - return init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs) + model_instance = init_cls(model=model, config=config, model_save_dir=model_cache_path.parent, **kwargs) + model_instance.init_ov_model(compile=kwargs.get("compile", True)) + model_instance.request = None + return model_instance + + def clone(self): + model_instance = self.__class__(model=self.model, config=self.config, compile=False) + model_instance.compiled_model = self.compiled_model + model_instance._pkv_precision = self._pkv_precision + model_instance.request = None + return model_instance class OVBloomForCausalLM(OVModelForCausalLM): diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py index 4733ca9e9b..e919541e62 100644 --- a/optimum/intel/openvino/modeling_diffusion.py +++ b/optimum/intel/openvino/modeling_diffusion.py @@ -532,7 +532,7 @@ def __init__( for inputs in self.model.inputs } self.ov_config = ov_config or {**self.parent_model.ov_config} - self.request = None + self.compiled_model = None self._model_name = model_name self._model_dir = Path(model_dir or parent_model._model_save_dir) config_path = self._model_dir / model_name / self.CONFIG_NAME @@ -541,13 +541,13 @@ def __init__( self.ov_config["CACHE_DIR"] = os.path.join(self._model_dir, self._model_name, "model_cache") def _compile(self): - if self.request is None: + if self.compiled_model is None: logger.info(f"Compiling the {self._model_name} to {self.device} ...") - self.request = core.compile_model(self.model, self.device, self.ov_config) + self.compiled_model = 
             # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html
             if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2:
                 logger.info(f"{self.device} SUPPORTED_PROPERTIES:")
-                _print_compiled_model_properties(self.request)
+                _print_compiled_model_properties(self.compiled_model)
 
     @property
     def device(self):
@@ -570,8 +570,11 @@ def __call__(self, input_ids: np.ndarray):
         inputs = {
             "input_ids": input_ids,
         }
-        outputs = self.request(inputs, share_inputs=True)
-        return list(outputs.values())
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs, share_inputs=True)
+        infer_request.wait()
+        outputs = [infer_request.get_tensor(output).data for output in infer_request.results]
+        return outputs
 
 
 class OVModelUnet(OVModelPart):
@@ -604,8 +607,11 @@ def __call__(
         if timestep_cond is not None:
             inputs["timestep_cond"] = timestep_cond
 
-        outputs = self.request(inputs, share_inputs=True)
-        return list(outputs.values())
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs, share_inputs=True)
+        infer_request.wait()
+        outputs = [infer_request.get_tensor(output).data for output in infer_request.results]
+        return outputs
 
 
 class OVModelVaeDecoder(OVModelPart):
@@ -620,8 +626,11 @@ def __call__(self, latent_sample: np.ndarray):
         inputs = {
             "latent_sample": latent_sample,
         }
-        outputs = self.request(inputs, share_inputs=True)
-        return list(outputs.values())
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs, share_inputs=True)
+        infer_request.wait()
+        outputs = [infer_request.get_tensor(output).data for output in infer_request.results]
+        return outputs
 
     def _compile(self):
         if "GPU" in self.device:
@@ -641,8 +650,11 @@ def __call__(self, sample: np.ndarray):
         inputs = {
             "sample": sample,
         }
-        outputs = self.request(inputs, share_inputs=True)
-        return list(outputs.values())
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs, share_inputs=True)
+        infer_request.wait()
+        outputs = [infer_request.get_tensor(output).data for output in infer_request.results]
+        return outputs
 
     def _compile(self):
         if "GPU" in self.device:
diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py
index 1466681729..737d0218a0 100644
--- a/optimum/intel/openvino/modeling_seq2seq.py
+++ b/optimum/intel/openvino/modeling_seq2seq.py
@@ -442,7 +442,7 @@ def __init__(self, model: openvino.runtime.Model, device: str, ov_config: Dict,
         self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
         self.main_input_name = main_input_name
         self.ov_config = ov_config
-        self.request = None
+        self.compiled_model = None
 
     @add_start_docstrings_to_model_forward(ENCODER_INPUTS_DOCSTRING)
     def forward(
@@ -461,9 +461,10 @@ def forward(
         inputs["attention_mask"] = attention_mask
 
         # Run inference
-        last_hidden_state = torch.from_numpy(
-            self.request(inputs, share_inputs=True, share_outputs=True)["last_hidden_state"]
-        ).to(self.device)
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs, share_inputs=True)
+        infer_request.wait()
+        last_hidden_state = torch.from_numpy(infer_request.get_tensor("last_hidden_state").data).to(self.device)
 
         return BaseModelOutput(last_hidden_state=last_hidden_state)
 
@@ -471,9 +472,9 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
     def _compile(self):
-        if self.request is None:
+        if self.compiled_model is None:
             logger.info(f"Compiling the encoder to {self._device} ...")
-            self.request = core.compile_model(self.model, self._device, self.ov_config)
+            self.compiled_model = core.compile_model(self.model, self._device, self.ov_config)
             # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html
             if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2:
                 logger.info(f"{self._device} SUPPORTED_PROPERTIES:")
@@ -509,7 +510,7 @@ def __init__(self, model: openvino.runtime.Model, device: str, ov_config: Dict):
 
         self.num_pkv = 4
         self.ov_config = ov_config
-        self.request = None
+        self.compiled_model = None
 
     @add_start_docstrings_to_model_forward(DECODER_INPUTS_DOCSTRING)
     def forward(
@@ -546,13 +547,14 @@ def forward(
         if "decoder_attention_mask" in self.input_names and decoder_attention_mask is not None:
             inputs["decoder_attention_mask"] = decoder_attention_mask
         # Run inference
-        self.request.start_async(inputs, share_inputs=True)
-        self.request.wait()
-        logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
+        infer_request = self.compiled_model.create_infer_request()
+        infer_request.start_async(inputs, share_inputs=True)
+        infer_request.wait()
+        logits = torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device)
 
         # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
         # self-attention layer and 2 to the cross-attention layer)
-        out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
+        out_past_key_values = tuple(infer_request.get_tensor(key).data for key in self.key_value_output_names)
 
         # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
         # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)
@@ -574,14 +576,13 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
     def _compile(self):
-        if self.request is None:
+        if self.compiled_model is None:
             logger.info(f"Compiling the decoder to {self._device} ...")
-            compiled_model = core.compile_model(self.model, self._device, self.ov_config)
-            self.request = compiled_model.create_infer_request()
+            self.compiled_model = core.compile_model(self.model, self._device, self.ov_config)
             # OPENVINO_LOG_LEVEL can be found in https://docs.openvino.ai/2023.2/openvino_docs_OV_UG_supported_plugins_AUTO_debugging.html
             if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2:
                 logger.info(f"{self._device} SUPPORTED_PROPERTIES:")
-                _print_compiled_model_properties(compiled_model)
+                _print_compiled_model_properties(self.compiled_model)
 
 
 @add_start_docstrings(
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 9af0b9c9a6..ed6b0a66f7 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -294,6 +294,8 @@ def _quantize_ovcausallm(
         # Prefeth past_key_values
         self.model.update_pkv_precision(True)
         self.model.compile()
+        self.model.create_infer_request()
+
         subset_size = kwargs.get("subset_size", 300)
 
         data_cache = []
diff --git a/tests/openvino/gen_batch.py b/tests/openvino/gen_batch.py
new file mode 100644
index 0000000000..38ecd6d6b2
--- /dev/null
+++ b/tests/openvino/gen_batch.py
@@ -0,0 +1,82 @@
+import threading
+from datetime import datetime
+
+from transformers import AutoConfig, AutoTokenizer, set_seed
+
+from optimum.intel import OVModelForCausalLM
+
+
+set_seed(10)
+model_path = "/model"
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+tokenizer.pad_token = "[PAD]"
+tokenizer.padding_side = "left"
+NUM_THREADS = 3
+prompt1 = [" The weather is "]
+prompt2 = [" Openvino is a ", " What is the relativity theory "]
+prompt3 = [
+    " Are cats smarter than dogs ",
+    " How big is an elephant ",
+    " the water in the ocean is much hotter than before ",
+]
+prompts = [prompt1, prompt2, prompt3]
+
+OV_CONFIG = {"PERFORMANCE_HINT": "LATENCY", "CACHE_DIR": "", "NUM_STREAMS": "1"}
+model = OVModelForCausalLM.from_pretrained(
+    model_path,
+    config=AutoConfig.from_pretrained(model_path, trust_remote_code=True),
+    ov_config=OV_CONFIG,
+    compile=True,
+)
+
+threads = [None] * NUM_THREADS
+results = [None] * NUM_THREADS
+
+
+def print_response(t, p, r):
+    print("THREAD", t)
+    print("PROMPT:", p)
+    for answer in r:
+        print("Answer:")
+        print(tokenizer.decode(answer, skip_special_tokens=True))
+
+
+def gen_thread(prompt, results, i):
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
+    generate_kwargs = {
+        "input_ids": inputs.input_ids,
+        "max_new_tokens": 200,
+        "temperature": 1.0,
+        "do_sample": True,
+        "top_p": 1.0,
+        "top_k": 50,
+        "num_beams": 5,
+        "repetition_penalty": 1.1,
+    }
+    start = datetime.now()
+    model_exec = model.clone()
+    end = datetime.now()
+    print("cloning model duration", (end - start).total_seconds() * 1000000, "us")
+    outputs = model_exec.generate(**generate_kwargs)
+    num_tok = 0
+    for x in range(len(prompt)):
+        num_tok += outputs[x].numel() - inputs.get("input_ids")[x].numel()
+    results[i] = outputs, num_tok
+
+
+start = datetime.now()
+for i in range(len(threads)):
+    threads[i] = threading.Thread(target=gen_thread, args=(prompts[i], results, i))
+    threads[i].start()
+
+total_tok = 0
+for i in range(len(threads)):
+    threads[i].join()
+    total_tok += results[i][1]
+end = datetime.now()
+
+for i in range(len(threads)):
+    print_response(i, prompts[i], results[i][0])
+
+print("Generation time [s]", ((end - start).total_seconds()), "tokens:", total_tok)
+print("Throughput:", total_tok * 60 / ((end - start).total_seconds()), "tokens/min")
diff --git a/tests/openvino/gen_img.py b/tests/openvino/gen_img.py
new file mode 100644
index 0000000000..6e06132ba5
--- /dev/null
+++ b/tests/openvino/gen_img.py
@@ -0,0 +1,55 @@
+import datetime
+import threading
+
+from diffusers import DDIMScheduler
+
+from optimum.intel.openvino import OVStableDiffusionPipeline
+
+
+MODEL_PATH = "/model"
+OV_CONFIG = {"PERFORMANCE_HINT": "LATENCY", "NUM_STREAMS": "1"}
+
+
+pipe = OVStableDiffusionPipeline.from_pretrained(MODEL_PATH, device="CPU", ov_config=OV_CONFIG)
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+
+prompt1 = [" Zebras in space "]
+prompt2 = [" The statue of liberty in New York", " Big Ben in London "]
+prompt3 = [" pigs on the grass field", "beach in the storm", "sail yacht on the ocean"]
+prompts = [prompt1, prompt2, prompt3]
+
+NUM_THREADS = 3
+
+threads = [None] * NUM_THREADS
+results = [None] * NUM_THREADS
+
+
+def save_response(t, p, r):
+    print("THREAD", t)
+    print("PROMPT:", p)
+    for i in range(len(r)):
+        print("IMG:", i)
+        r[i].save("img_" + str(t) + "_" + str(i) + ".png", format="PNG")
+
+
+def gen_thread(prompt, results, i):
+    text = prompt
+    images = pipe(text).images
+    results[i] = images
+
+
+start = datetime.datetime.now()
+for i in range(len(threads)):
+    threads[i] = threading.Thread(target=gen_thread, args=(prompts[i], results, i))
+    threads[i].start()
+nu_img = 0
+for i in range(len(threads)):
+    threads[i].join()
+    nu_img += len(results[i])
+end = datetime.datetime.now()
+
+for i in range(len(threads)):
+    save_response(i, prompts[i], results[i])
+
+print("Generation time [s]", ((end - start).total_seconds()), "images:", nu_img)
+print("Throughput:", nu_img * 60 / ((end - start).total_seconds()), "images/min")
diff --git a/tests/openvino/gen_seq2seq.py b/tests/openvino/gen_seq2seq.py
new file mode 100644
index 0000000000..27d3ed2a45
--- /dev/null
+++ b/tests/openvino/gen_seq2seq.py
@@ -0,0 +1,51 @@
+import datetime
+import threading
+
+from transformers import AutoTokenizer, pipeline
+
+from optimum.intel import OVModelForSeq2SeqLM
+
+
+model_id = "echarlaix/t5-small-openvino"
+model = OVModelForSeq2SeqLM.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer)
+
+prompt1 = ["I live in Europe"]
+prompt2 = ["What is your name?", "The dog is very happy"]
+prompt3 = ["It's a beautiful weather today", "Yes", "Good morning"]
+prompts = [prompt1, prompt2, prompt3]
+
+NUM_THREADS = 3
+
+threads = [None] * NUM_THREADS
+results = [None] * NUM_THREADS
+
+
+def print_response(t, p, r):
+    print("THREAD", t)
+    print("PROMPT:", p)
+    for i in range(len(r)):
+        print("TRANSLATION", i, ":", r[i]["translation_text"])
+
+
+def gen_thread(prompt, results, i):
+    translations = pipe(prompt)
+    results[i] = translations
+
+
+start = datetime.datetime.now()
+for i in range(len(threads)):
+    threads[i] = threading.Thread(target=gen_thread, args=(prompts[i], results, i))
+    threads[i].start()
+nu_trans = 0
+for i in range(len(threads)):
+    threads[i].join()
+    nu_trans += len(results[i])
+end = datetime.datetime.now()
+
+for i in range(len(threads)):
+    print_response(i, prompts[i], results[i])
+
+print("Generation time [s]", ((end - start).total_seconds()), "translations:", nu_trans)
+print("Throughput:", nu_trans / ((end - start).total_seconds()), "translations/s")
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index e51b50a5b2..c1746f5538 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -129,9 +129,9 @@ def test_load_from_hub_and_save_model(self):
         self.assertTrue(manual_openvino_cache_dir.is_dir())
         self.assertGreaterEqual(len(list(manual_openvino_cache_dir.glob("*.blob"))), 1)
         if is_openvino_version("<", "2023.3"):
-            self.assertEqual(loaded_model.request.get_property("PERFORMANCE_HINT").name, "THROUGHPUT")
+            self.assertEqual(loaded_model.compiled_model.get_property("PERFORMANCE_HINT").name, "THROUGHPUT")
         else:
-            self.assertEqual(loaded_model.request.get_property("PERFORMANCE_HINT"), "THROUGHPUT")
+            self.assertEqual(loaded_model.compiled_model.get_property("PERFORMANCE_HINT"), "THROUGHPUT")
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             loaded_model.save_pretrained(tmpdirname)
@@ -748,7 +748,7 @@ def test_pipeline(self, model_arch):
     @parameterized.expand(TIMM_MODELS)
     def test_compare_to_timm(self, model_id):
         ov_model = OVModelForImageClassification.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
-        self.assertEqual(ov_model.request.get_property("INFERENCE_PRECISION_HINT").to_string(), "f32")
+        self.assertEqual(ov_model.compiled_model.get_property("INFERENCE_PRECISION_HINT").to_string(), "f32")
         self.assertIsInstance(ov_model.config, PretrainedConfig)
         timm_model = timm.create_model(model_id, pretrained=True)
         preprocessor = TimmImageProcessor.from_pretrained(model_id)
@@ -886,20 +886,25 @@ def test_compare_with_and_without_past_key_values(self):
         text = "This is a sample input"
         tokens = tokenizer(text, return_tensors="pt")
 
-        model_with_pkv = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, use_cache=True)
+        model_with_pkv = OVModelForSeq2SeqLM.from_pretrained(
+            model_id, export=True, use_cache=True, ov_config=F32_CONFIG
+        )
         _ = model_with_pkv.generate(**tokens)  # warmup
         with Timer() as with_pkv_timer:
             outputs_model_with_pkv = model_with_pkv.generate(
                 **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
             )
 
-        model_without_pkv = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, use_cache=False)
+        model_without_pkv = OVModelForSeq2SeqLM.from_pretrained(
+            model_id, export=True, use_cache=False, ov_config=F32_CONFIG
+        )
         _ = model_without_pkv.generate(**tokens)  # warmup
         with Timer() as without_pkv_timer:
             outputs_model_without_pkv = model_without_pkv.generate(
                 **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
             )
-
+        print(outputs_model_with_pkv)
+        print(outputs_model_without_pkv)
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
         self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
         self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
@@ -1201,20 +1206,23 @@ def test_compare_with_and_without_past_key_values(self):
         question = "Who am I?"
         inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt")
 
-        model_with_pkv = OVModelForPix2Struct.from_pretrained(model_id, export=True, use_cache=True)
+        model_with_pkv = OVModelForPix2Struct.from_pretrained(
+            model_id, export=True, use_cache=True, ov_config=F32_CONFIG
+        )
         _ = model_with_pkv.generate(**inputs)  # warmup
         with Timer() as with_pkv_timer:
             outputs_model_with_pkv = model_with_pkv.generate(
                 **inputs, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
             )
 
-        model_without_pkv = OVModelForPix2Struct.from_pretrained(model_id, export=True, use_cache=False)
+        model_without_pkv = OVModelForPix2Struct.from_pretrained(
+            model_id, export=True, use_cache=False, ov_config=F32_CONFIG
+        )
         _ = model_without_pkv.generate(**inputs)  # warmup
         with Timer() as without_pkv_timer:
            outputs_model_without_pkv = model_without_pkv.generate(
                 **inputs, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
             )
-
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
         self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
         self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py
index d8cef2e027..36b1794015 100644
--- a/tests/openvino/test_stable_diffusion.py
+++ b/tests/openvino/test_stable_diffusion.py
@@ -410,7 +410,7 @@ def test_image_reproducibility(self, model_arch: str):
 
         # Verify every subcomponent is compiled by default
         for component in {"unet", "vae_encoder", "vae_decoder", "text_encoder", "text_encoder_2"}:
-            self.assertIsInstance(getattr(pipeline, component).request, CompiledModel)
+            self.assertIsInstance(getattr(pipeline, component).compiled_model, CompiledModel)
 
         batch_size, num_images_per_prompt, height, width = 2, 3, 64, 128
         inputs = _generate_inputs(batch_size)
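
Note (not part of the patch): taken together, these changes keep the shared ov.CompiledModel on each wrapper and obtain InferRequest objects on demand, per forward call for the encoder-style models and per instance for the causal LM, while the new OVModelForCausalLM.clone() gives each thread a lightweight copy that reuses the already compiled network. The following is a minimal usage sketch assuming the patch is applied; the model directory and prompts are placeholders, not values taken from the PR.

import threading

from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

MODEL_DIR = "./ov_model"  # assumption: a locally exported OpenVINO causal LM lives here

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = OVModelForCausalLM.from_pretrained(MODEL_DIR, compile=True)


def worker(prompt, results, idx):
    # clone() shares the compiled model but resets the cached infer request,
    # so each thread generates with its own request object
    local_model = model.clone()
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = local_model.generate(**inputs, max_new_tokens=32)
    results[idx] = tokenizer.decode(outputs[0], skip_special_tokens=True)


prompts = ["The weather is", "OpenVINO is a toolkit for"]
results = [None] * len(prompts)
threads = [threading.Thread(target=worker, args=(p, results, i)) for i, p in enumerate(prompts)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)

Because clone() only copies references to the compiled model and the PKV precision, the per-thread copy is cheap; the expensive compilation still happens once in the parent instance.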