forked from bqFirst/convert_onnx_float16_to_float
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: convert_onnx_float16_to_float.py
116 lines (103 loc) · 3.73 KB
/
convert_onnx_float16_to_float.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import onnx
from onnx import helper as h
from onnx import checker as ch
from onnx import TensorProto, GraphProto
from onnx import numpy_helper as nph
import numpy as np
from collections import OrderedDict
from logger import log
import typer
def make_param_dictionary(initializer):
    """Map each initializer tensor to its name, preserving graph order.

    Args:
        initializer: iterable of TensorProto entries (graph.initializer).

    Returns:
        OrderedDict: tensor name -> TensorProto, in iteration order.
    """
    return OrderedDict((tensor.name, tensor) for tensor in initializer)
def convert_params_to_fp32(params_dict):
    """Convert every FLOAT16 initializer tensor to FLOAT32.

    Args:
        params_dict: mapping of tensor name -> TensorProto (see
            make_param_dictionary).

    Returns:
        list: tensors in the original order; FLOAT16 entries are replaced
        by FLOAT32 copies, all other entries are passed through unchanged.
    """
    converted = []
    for tensor in params_dict.values():
        if tensor.data_type == TensorProto.FLOAT16:
            # Round-trip through numpy to widen the raw data to float32.
            as_fp32 = nph.to_array(tensor).astype(np.float32)
            tensor = nph.from_array(as_fp32, tensor.name)
        converted.append(tensor)
    return converted
def convert_constant_nodes_to_fp32(nodes):
    """Rebuild FLOAT16 Constant nodes as FLOAT32 Constant nodes.

    Any node that is not a Constant, or whose tensor attribute is not
    FLOAT16, is passed through untouched.

    Args:
        nodes (list): nodes of the source graph.

    Returns:
        list: nodes with every FLOAT16 constant replaced by a FLOAT32 one.
    """
    result = []
    for node in nodes:
        is_fp16_constant = (
            node.op_type == "Constant"
            and node.attribute[0].t.data_type == TensorProto.FLOAT16
        )
        if not is_fp16_constant:
            result.append(node)
            continue
        # Widen the constant payload and emit an equivalent FLOAT32 node.
        fp32_values = nph.to_array(node.attribute[0].t).astype(np.float32)
        result.append(
            h.make_node(
                "Constant",
                inputs=[],
                outputs=node.output,
                name=node.name,
                value=nph.from_array(fp32_values),
            )
        )
    return result
def convert_model_to_fp32(model_path: str, out_path: str):
    """
    convert_model_to_fp32 Converts ONNX model with FLOAT16 params to FLOAT32 params.\n
    Args:\n
        model_path (str): path to original ONNX model.\n
        out_path (str): path to save converted model.
    """
    log.info("ONNX FLOAT16 --> FLOAT32 Converter")
    log.info(f"Loading Model: {model_path}")
    # * load model.
    model = onnx.load_model(model_path)
    graph = model.graph
    # * retype all FLOAT16 graph inputs/outputs (and intermediate
    # * value_info entries) to FLOAT32.
    for input in graph.input:
        input.type.tensor_type.elem_type = TensorProto.FLOAT
    for output in graph.output:
        output.type.tensor_type.elem_type = TensorProto.FLOAT
    for value_info in graph.value_info:
        if value_info.type.tensor_type.elem_type == TensorProto.FLOAT16:
            value_info.type.tensor_type.elem_type = TensorProto.FLOAT
    # * rewrite Cast nodes that target FLOAT16 so they target FLOAT32.
    for n in graph.node:
        if n.op_type == 'Cast':
            for attr in n.attribute:
                if attr.name == 'to' and attr.i == TensorProto.FLOAT16:
                    attr.i = TensorProto.FLOAT
    # * The initializer holds all non-constant weights.
    init = graph.initializer
    # * collect model params in a dictionary.
    params_dict = make_param_dictionary(init)
    log.info("Converting FLOAT16 model params to FLOAT32...")
    # * convert all FLOAT16 params to FLOAT32.
    converted_params = convert_params_to_fp32(params_dict)
    log.info("Converting constant FLOAT16 nodes to FLOAT32...")
    new_nodes = convert_constant_nodes_to_fp32(graph.node)
    graph_name = f"{graph.name}-fp32"
    log.info("Creating new graph...")
    # * create a new graph with converted params and new nodes; keep the
    # * (already retyped) value_info so shape information survives.
    graph_fp32 = h.make_graph(
        new_nodes,
        graph_name,
        graph.input,
        graph.output,
        initializer=converted_params,
        value_info=graph.value_info,
    )
    log.info("Creating new float32 model...")
    # * carry over every opset import, not just the first one.
    model_fp32 = h.make_model(
        graph_fp32,
        producer_name="onnx-typecast",
        opset_imports=model.opset_import,
    )
    log.info(f"Saving converted model as: {out_path}")
    onnx.save_model(model_fp32, out_path)
    log.info("Done Done London. 🎉")
    return
# CLI entry point: typer maps command-line arguments onto
# convert_model_to_fp32(model_path, out_path).
if __name__ == "__main__":
    typer.run(convert_model_to_fp32)