Skip to content

Commit 9f40059

Browse files
committed
[Model Optimizer] Add rule to remove Resize nodes with unity Resize scale (s=1's)
1 parent 1ffc27c commit 9f40059

File tree

1 file changed

+40
-0
lines changed
  • scripts/osrt_model_tools/onnx_tools/tidl-onnx-model-optimizer/tidl_onnx_model_optimizer/src

1 file changed

+40
-0
lines changed

scripts/osrt_model_tools/onnx_tools/tidl-onnx-model-optimizer/tidl_onnx_model_optimizer/src/resize.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import onnx_graphsurgeon as gs
6464
import onnx
6565
import numpy as np
66+
from tidl_onnx_model_optimizer.src.common import find_out_layers, remove_node
6667

6768

6869
def tidl_convert_resize_params_size_to_scale(graph: gs.Graph,
@@ -102,3 +103,42 @@ def tidl_convert_resize_params_size_to_scale(graph: gs.Graph,
102103
# endif
103104
# endif
104105
# endfor
106+
107+
108+
def tidl_remove_unity_resize(graph: gs.Graph,
109+
onnx_graph: onnx.GraphProto):
110+
'''
111+
Some models have an effectively null resize node that scales by a factor of 1 in all dimensions
112+
Such a node is often an export artifact -- a layer added by a model format converter
113+
This is node effectively unity, but it will be processed nonetheless. It should therefore be removed
114+
'''
115+
116+
tensors = graph.tensors()
117+
nodes_to_remove = []
118+
for node in graph.nodes:
119+
120+
if node.op == "Resize":
121+
inputs = node.inputs
122+
if len(inputs) >= 3:
123+
X, roi, scales = inputs[0:3]
124+
else:
125+
continue
126+
Y = node.outputs[0]
127+
attrs = node.attrs
128+
129+
if X.shape == Y.shape and all(map(lambda x: x==1, scales.values)):
130+
#ensure it's not using ROI, which is only with crop-and-resize mode
131+
if node.attrs['coordinate_transformation_mode'] == 'tf_crop_and_resize':
132+
logging.warning("Detected Resize node as using ROI... skipping")
133+
continue
134+
135+
logging.debug("Removing unity Resize node %s" % node.name)
136+
137+
out_nodes = find_out_layers(node)
138+
139+
for o_node in out_nodes:
140+
for i, net in enumerate(o_node.inputs):
141+
if net == Y:
142+
o_node.inputs[i] = X
143+
144+
#node will be removed by cleanup since it has only unused outputs

0 commit comments

Comments
 (0)