Skip to content

Commit

Permalink
Prepare shape with symbols before conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
jane-intel committed Jul 16, 2024
1 parent cb2f2ec commit 097031e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 52 deletions.
35 changes: 10 additions & 25 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from openvino.runtime import Model, PartialShape, save_model
from openvino.runtime.exceptions import OVTypeError
from openvino.runtime.utils.types import get_element_type
from openvino.tools.ovc import convert_model
from optimum.exporters import TasksManager
from optimum.exporters.onnx.base import OnnxConfig
Expand All @@ -51,8 +50,7 @@
_MAX_UNCOMPRESSED_SIZE,
OV_XML_FILE_NAME,
clear_class_registry,
flattenize_inputs,
get_input_shapes,
get_input_info,
remove_none_from_dummy_inputs,
)

Expand Down Expand Up @@ -358,13 +356,10 @@ def ts_patched_forward(*args, **kwargs):

with patcher:
check_dummy_inputs_are_allowed(model, dummy_inputs)
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
output_names = list(config.outputs.keys())
input_info = get_input_shapes(dummy_inputs, inputs)

ov_model = convert_model(model, example_input=dummy_inputs, input=input_info)
input_info = get_input_info(model, config, dummy_inputs)
ov_model = convert_model(model,
example_input=dummy_inputs,
input=list(map(lambda item: (item.shape, item.type), input_info)))

except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")
Expand All @@ -388,27 +383,17 @@ def ts_patched_forward(*args, **kwargs):
ov_config=ov_config,
)

ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
if not ordered_dummy_inputs:
ordered_dummy_inputs = dummy_inputs
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
ov_model.validate_nodes_and_infer_types()
ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?

output_names = list(config.outputs.keys())
for idx, out_tensor in enumerate(ov_model.outputs):
if idx < len(output_names):
out_tensor.get_tensor().set_names({output_names[idx]})

input_names = list(map(lambda item: item.name, input_info))
for idx, inp_tensor in enumerate(ov_model.inputs):
input_name = ordered_input_names[idx]
input_name = input_names[idx]
inp_tensor.get_tensor().set_names({input_name})
inp_data = flatten_inputs[idx]
static_shape = PartialShape(inp_data.shape)
dims = inputs.get(input_name, [])
for dim in dims:
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
inp_tensor.get_node().set_element_type(get_element_type(inp_data.cpu().numpy().dtype))
ov_model.validate_nodes_and_infer_types()

if stateful:
patch_stateful(model.config, ov_model)
Expand Down
67 changes: 40 additions & 27 deletions optimum/exporters/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,16 @@
# limitations under the License.

from typing import Any, Dict, List, Tuple, Union

import inspect
from transformers.utils import is_torch_available

from openvino.runtime import PartialShape
from openvino.runtime import Dimension, PartialShape, Symbol
from openvino.runtime.utils.types import get_element_type
from optimum.utils import is_diffusers_available
from optimum.exporters.onnx.base import OnnxConfig
from collections import namedtuple

InputInfo = namedtuple('InputInfo', ['name', 'shape', 'type', 'example'])


if is_torch_available():
Expand Down Expand Up @@ -69,6 +74,39 @@ def flattenize_inputs(inputs: List[Any]):
return flatten_inputs


def get_input_info(model: Union["PreTrainedModel", "ModelMixin"], config: OnnxConfig, dummy_inputs: Dict[str, Any]) -> List[InputInfo]:
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
inputs = config.ordered_inputs(model)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
if not ordered_dummy_inputs:
ordered_dummy_inputs = dummy_inputs
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
input_info = list()

name_to_symbol = dict()

for i in range(len(ordered_input_names)):
name = ordered_input_names[i]
example = flatten_inputs[i]
type = get_element_type(example.cpu().numpy().dtype)
shape = PartialShape(example.shape)
if name in inputs:
named_dims = inputs[name]
for idx, dim_name in named_dims.items():
if dim_name in name_to_symbol:
symbol = name_to_symbol[dim_name]
else:
symbol = Symbol()
name_to_symbol[name] = symbol
dim = Dimension(-1)
dim.set_symbol(symbol)
shape[idx] = dim
info = InputInfo(name=name, shape=shape, type=type, example=example)
input_info.append(info)
return input_info


def remove_none_from_dummy_inputs(dummy_inputs: Dict[str, Any]):
"""
Removes None values from the dictionary.
Expand Down Expand Up @@ -109,31 +147,6 @@ def remove_none_from_list_tuple(item: Union[List[Any], Tuple[Any]]):
return upd_dummy, dict_dummy


def get_input_shapes(dummy_inputs: Dict[str, Any], inputs: Dict[str, Any]):
"""
Resolves input shapes based on dynamic axes from input config and dummy input shapes
Args:
dummy_inputs (Dict[str, Any]): A dictionary of dummy inputs.
inputs (Dict[str, Any]): A dictionary of input tensors.
Returns:
input_info: List of input info for conversion
"""
input_info = []
for input_name, data in dummy_inputs.items():
if isinstance(data, (tuple, list, dict)):
return None
static_shape = PartialShape(data.shape)
if input_name in inputs:
dynamic_dims = inputs[input_name]
for dim in dynamic_dims:
static_shape[dim] = -1
input_info.append((input_name, static_shape))
return input_info


def clear_class_registry():
"""
Removes Torchscript cached modules
Expand Down

0 comments on commit 097031e

Please sign in to comment.