Skip to content

Commit

Permalink
fix export
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix committed Feb 14, 2024
1 parent fa97edd commit 7fa0f6e
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import functools
import gc
import inspect
import logging
import os
from pathlib import Path
Expand Down Expand Up @@ -370,9 +371,22 @@ def export_pytorch(
)

dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)

try:
with config.patch_model_for_export(model, model_kwargs=model_kwargs):
# TorchScript used behind OpenVINO conversion. Optimum supports only return_dict=True models for patching,
# while TorchScript do not support dictionary with values of mixed types (e.g. Tensor and None) in model input/output
# To handle it, additional wrapper on patcher forward applied.
# model.config.torchscript = True can not be used for patching, because it overrides return_dict to False
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
patched_forward = patcher.patched_forward

@functools.wraps(patched_forward)
def ts_patched_forward(*args, **kwargs):
outputs = patched_forward(*args, **kwargs)
return tuple(outputs.values())

patcher.patched_forward = ts_patched_forward

with patcher:
check_dummy_inputs_are_allowed(model, dummy_inputs)
inputs = config.ordered_inputs(model)
input_names = list(inputs.keys())
Expand Down Expand Up @@ -404,7 +418,8 @@ def export_pytorch(
compression_ratio=compression_ratio,
)

ordered_dummy_inputs = {param: dummy_inputs[param] for param in inputs if param in dummy_inputs}
sig = inspect.signature(model.forward) if hasattr(model, "forward") else inspect.signature(model.call)
ordered_dummy_inputs = {param: dummy_inputs[param] for param in sig.parameters if param in dummy_inputs}
ordered_input_names = list(inputs)
flatten_inputs = flattenize_inputs(ordered_dummy_inputs.values())
ov_model.validate_nodes_and_infer_types()
Expand All @@ -418,7 +433,6 @@ def export_pytorch(
inp_data = flatten_inputs[idx]
static_shape = PartialShape(inp_data.shape)
dims = inputs[input_name]

for dim in dims:
static_shape[dim] = -1
inp_tensor.get_node().set_partial_shape(static_shape)
Expand Down

0 comments on commit 7fa0f6e

Please sign in to comment.