Skip to content

Commit

Permalink
fix export
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Feb 14, 2024
1 parent fa97edd commit fb1910e
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import functools
import gc
import inspect
import logging
import os
from pathlib import Path
Expand Down Expand Up @@ -43,6 +44,7 @@
clear_class_registry,
flattenize_inputs,
get_input_shapes,
remove_none_from_dummy_inputs,
)


Expand Down Expand Up @@ -370,9 +372,29 @@ def export_pytorch(
)

dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)

try:
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
# TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
# while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
# To handle it, additional wrapper on patcher forward applied.
# model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
patched_forward = patcher.patched_forward

@functools.wraps(patched_forward)
def ts_patched_forward(*args, **kwargs):
for i in range(len(dict_inputs)):
input_name, keys = dict_inputs[i]
tuple_input = kwargs[input_name]
input_dict = dict(zip(keys, tuple_input))
kwargs[input_name] = input_dict
outputs = patched_forward(*args, **kwargs)
return tuple(outputs.values())

patcher.patched_forward = ts_patched_forward

with patcher:
check_dummy_inputs_are_allowed(model, dummy_inputs)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
Expand Down Expand Up @@ -404,7 +426,8 @@ def export_pytorch(
compression_ratio=compression_ratio,
)

ordered_dummy_inputs = {param: dummy_inputs[param] for param in inputs if param in dummy_inputs}
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
ov_model.validate_nodes_and_infer_types()
Expand All @@ -418,7 +441,6 @@ def export_pytorch(
inp_data = flatten_inputs[idx]
static_shape = PartialShape(inp_data.shape)
dims = inputs[input_name]

for dim in dims:
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
Expand Down

0 comments on commit fb1910e

Please sign in to comment.