From e12afaf13a899899cf841b9bdcf181325fae51ab Mon Sep 17 00:00:00 2001
From: Leo Dong
Date: Mon, 23 Dec 2024 13:53:06 -0800
Subject: [PATCH] Only remove cast pairs with fp32 input types.

---
 onnxconverter_common/float16.py | 30 ++++++++++++++++++++++++++----
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/onnxconverter_common/float16.py b/onnxconverter_common/float16.py
index 6513fb3..3389872 100644
--- a/onnxconverter_common/float16.py
+++ b/onnxconverter_common/float16.py
@@ -6,14 +6,14 @@
 import itertools
 import uuid
 import warnings
+from typing import Optional
+
 import numpy as np
 import onnx
 import packaging.version as pv
-import warnings
 from onnx import helper, numpy_helper
 from onnx import onnx_pb as onnx_proto
 
-
 FLOAT32 = 1
 FLOAT16 = 10
 
@@ -522,10 +522,31 @@ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
         if upstream_node.op_type == 'Constant':
             cast_node_list.remove(cast_node)
 
-    # 4. find the cast(to16) node which downstream is Cast(to32)
+    # 4. find (cast_to_fp16, cast_to_fp32) pairs where --fp32--> cast_to_fp16 --fp16--> cast_to_fp32.
     remove_candidate = []
+
+    name_to_value_info = {
+        value_info.name: value_info for value_info in itertools.chain(graph_proto.value_info, graph_proto.input)
+    }
+
+    def get_type(name: str) -> Optional[int]:
+        """Return the tensor element type (a TensorProto.DataType value) of `name`, or None."""
+        value_info = name_to_value_info.get(name)
+        if value_info is None:  # `name` has no value info.
+            return None
+        return value_info.type.tensor_type.elem_type  # `.type` alone is a TypeProto, never an int.
+
     for cast_node_name, downstream_node in cast_node_downstream_dict.items():
         cast_node = name_to_node_dict[cast_node_name]
+        if len(cast_node.input) != 1:
+            raise RuntimeError(
+                f"Cast node {cast_node_name} should have only one input, but has {len(cast_node.input)}."
+            )
+
+        input_type = get_type(cast_node.input[0])
+        if input_type != onnx_proto.TensorProto.FLOAT:
+            continue
+
         if isinstance(downstream_node, list):
             for dn in downstream_node:
                 if dn.op_type == 'Cast' and \
@@ -542,7 +563,8 @@ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
                     cast_node in cast_node_list:
                 remove_candidate.append((cast_node, downstream_node))
 
-    # 5. change the connection of "upstream->cast16->cast32->downstream" to "upstream->downstream"
+    # 5. change "upstream --fp32--> cast_to_fp16 --fp16--> cast_to_fp32 --fp32--> downstream" to
+    #    "upstream --fp32--> downstream".
     for cast_node_pair in remove_candidate:
         first_cast_node = cast_node_pair[0]
         second_cast_node = cast_node_pair[1]