Skip to content

Commit

Permalink
Only remove cast pairs with fp32 input types.
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoZDong committed Dec 23, 2024
1 parent efcfd73 commit e12afaf
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions onnxconverter_common/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import itertools
import uuid
import warnings
from typing import Optional

import numpy as np
import onnx
import packaging.version as pv
import warnings
from onnx import helper, numpy_helper
from onnx import onnx_pb as onnx_proto


FLOAT32 = 1
FLOAT16 = 10

Expand Down Expand Up @@ -522,10 +522,31 @@ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
if upstream_node.op_type == 'Constant':
cast_node_list.remove(cast_node)

# 4. find the cast(to16) node which downstream is Cast(to32)
# 4. find (cast_to_fp16, cast_to_fp32) pairs where --fp32--> cast_to_fp16 --fp16--> cast_to_fp32.
remove_candidate = []

name_to_value_info = {
value_info.name: value_info for value_info in itertools.chain(graph_proto.value_info, graph_proto.input)
}

def get_type(name: str) -> Optional[int]:
if name in name_to_value_info:
return name_to_value_info[name].type
else:
# `name` has no value info.
return None

for cast_node_name, downstream_node in cast_node_downstream_dict.items():
cast_node = name_to_node_dict[cast_node_name]
if len(cast_node.input) != 1:
raise RuntimeError(
f"Cast node {cast_node_name} should have only one input, but has {len(cast_node.input)}."
)

input_type = get_type(cast_node.input[0])
if input_type != onnx_proto.TensorProto.FLOAT:
continue

if isinstance(downstream_node, list):
for dn in downstream_node:
if dn.op_type == 'Cast' and \
Expand All @@ -542,7 +563,8 @@ def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
cast_node in cast_node_list:
remove_candidate.append((cast_node, downstream_node))

# 5. change the connection of "upstream->cast16->cast32->downstream" to "upstream->downstream"
# 5. change "upstream --fp32--> cast_to_fp16 --fp16--> cast_to_fp32 --fp32--> downstream" to
# "upstream --fp32--> downstream".
for cast_node_pair in remove_candidate:
first_cast_node = cast_node_pair[0]
second_cast_node = cast_node_pair[1]
Expand Down

0 comments on commit e12afaf

Please sign in to comment.