diff --git a/.gitignore b/.gitignore
index 1fa2656..365ae40 100644
--- a/.gitignore
+++ b/.gitignore
@@ -105,3 +105,5 @@ venv.bak/
 
 # PyCharm
 .idea/
+tests/data/image_classifier16.onnx
+tests/data/fp16_tensor.data
diff --git a/onnxconverter_common/auto_mixed_precision.py b/onnxconverter_common/auto_mixed_precision.py
index da422be..e9797b7 100644
--- a/onnxconverter_common/auto_mixed_precision.py
+++ b/onnxconverter_common/auto_mixed_precision.py
@@ -68,7 +68,7 @@ def validate(res1, res2):
     def run_attempt(node_block_list, return_model=False):
         print(node_block_list)
         model = float16.convert_float_to_float16(copy.deepcopy(model0), node_block_list=node_block_list,
-                                                 keep_io_types=keep_io_types, disable_shape_infer=True)
+                                                 keep_io_types=keep_io_types, disable_shape_infer=False)
         res1 = get_tensor_values_using_ort(model, feed_dict)
         if return_model:
             return validate(res0, res1), model
@@ -129,7 +129,7 @@ def get_tensor_values_using_ort(model, input_feed, output_names=None, sess_optio
         # Below code is for debug only, keep it for next time use
         # sess_options = ort.SessionOptions()
         # sess_options.optimized_model_filepath = "d:/optimized_model.onnx"
-        sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
+        sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CPUExecutionProvider'])
         return sess.run(None, input_feed)
     original_outputs = list(model.graph.output)
     while len(model.graph.output) > 0:
@@ -137,7 +137,8 @@ def get_tensor_values_using_ort(model, input_feed, output_names=None, sess_optio
     for n in output_names:
         out = model.graph.output.add()
         out.name = n
-    sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
+    # running with 'CUDAExecutionProvider' currently fails; needs further investigation
+    sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CPUExecutionProvider'])
     try:
         return sess.run(output_names, input_feed)
     finally:
diff --git a/onnxconverter_common/float16.py b/onnxconverter_common/float16.py
index b2291d2..38b3a4e 100644
--- a/onnxconverter_common/float16.py
+++ b/onnxconverter_common/float16.py
@@ -12,6 +12,10 @@
 from onnx import onnx_pb as onnx_proto
 
 
+FLOAT32 = 1
+FLOAT16 = 10
+
+
 def _npfloat16_to_int(np_list):
     '''
     Convert numpy float16 to python int.
@@ -108,33 +112,7 @@ def make_value_info_from_tensor(tensor):
                          'RoiAlign', 'Resize', 'Range', 'CumSum', 'Min', 'Max', 'Upsample']
 
 
-def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4,
-                             keep_io_types=False, disable_shape_infer=False,
-                             op_block_list=None, node_block_list=None):
-    '''
-    Convert tensor float type in the ONNX ModelProto input to tensor float16.
-
-    :param model: ONNX ModelProto object
-    :param disable_shape_infer: Type/shape information is needed for conversion to work.
-                                Set to True only if the model already has type/shape information for all tensors.
-    :return: converted ONNX ModelProto object
-
-    Examples:
-
-    ::
-
-        Example 1: Convert ONNX ModelProto object:
-        from onnxmltools.utils.float16_converter import convert_float_to_float16
-        new_onnx_model = convert_float_to_float16(onnx_model)
-
-        Example 2: Convert ONNX model binary file:
-        from onnxmltools.utils.float16_converter import convert_float_to_float16
-        from onnxmltools.utils import load_model, save_model
-        onnx_model = load_model('model.onnx')
-        new_onnx_model = convert_float_to_float16(onnx_model)
-        save_model(new_onnx_model, 'new_model.onnx')
-
-    '''
+def initial_checking(model, disable_shape_infer):
     func_infer_shape = None
     if not disable_shape_infer and pv.Version(onnx.__version__) >= pv.Version('1.2'):
         try:
@@ -146,6 +124,16 @@ def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4,
     if not isinstance(model, onnx_proto.ModelProto):
         raise ValueError('Expected model type is an ONNX ModelProto but got %s' % type(model))
 
+    if func_infer_shape is not None:
+        model = func_infer_shape(model)
+
+    return model, func_infer_shape
+
+# New implementation by Xiaowu to fix many bugs caused by ONNX Runtime changes
+def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4,
+                             keep_io_types=False, disable_shape_infer=False,
+                             op_block_list=None, node_block_list=None):
+    # create blocklists
     if op_block_list is None:
         op_block_list = DEFAULT_OP_BLOCK_LIST
@@ -153,158 +141,251 @@ def convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4,
     if node_block_list is None:
         node_block_list = []
     op_block_list = set(op_block_list)
     node_block_list = set(node_block_list)
-    # create a queue for BFS
-    queue = []
-    value_info_list = []
-    node_list = []
-    # key = node, value = graph, used to distinguish global with sub-graph
-    node_dict = {}
-    # type inference on input model
+
+    global_input_name_dict = {}  # key: original input name, value: new output name after the Cast node
+    # basic checking, including shape inference
+    model, func_infer_shape = initial_checking(model, disable_shape_infer)
+    graph_stack = [model.graph]
+
+    is_top_level = True
+    while graph_stack:
+        next_level = []
+        for curr_graph in graph_stack:
+            process_graph_input(curr_graph, is_top_level, keep_io_types, global_input_name_dict)
+            value_info_block_list = process_tensor_in_node(curr_graph, op_block_list, node_block_list, min_positive_val, max_finite_val)
+            process_value_info(curr_graph, value_info_block_list)
+            process_node_in_block_list(curr_graph, global_input_name_dict, op_block_list, node_block_list)
+            process_initializers(curr_graph, op_block_list, node_block_list, min_positive_val, max_finite_val)
+            process_graph_output(curr_graph, is_top_level, keep_io_types)
+            sub_graph_list = get_next_level_graph(curr_graph, op_block_list, node_block_list)
+            if len(sub_graph_list) > 0:
+                next_level.extend(sub_graph_list)
+
+            if not is_top_level:
+                process_node_input_output(curr_graph, global_input_name_dict)
+        is_top_level = False  # every graph after the first pass is a sub-graph
+        graph_stack = next_level
+
+    # run shape inference again to fill in shape and size for the new nodes and edges
+    # so that edge info can be shown in Netron
     if func_infer_shape is not None:
+        # infer_shapes changes the memory addresses of the model's components,
+        # so don't rely on object references held across this call
         model = func_infer_shape(model)
-    queue.append(model)
-    name_mapping = {}
-    graph_io_to_skip = set()
-    io_casts = set()
-    if keep_io_types:
-        for i, n in enumerate(model.graph.input):
-            if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
-
output_name = 'graph_input_cast_' + str(i) - name_mapping[n.name] = output_name - graph_io_to_skip.add(n.name) - - node_name = 'graph_input_cast' + str(i) - new_value_info = model.graph.value_info.add() - new_value_info.CopyFrom(n) - new_value_info.name = output_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 - # add Cast node (from tensor(float) to tensor(float16) after graph input - new_node = [helper.make_node('Cast', [n.name], [output_name], to=10, name=node_name)] - model.graph.node.extend(new_node) - value_info_list.append(new_value_info) - io_casts.add(node_name) - - for i, n in enumerate(model.graph.output): - if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: - input_name = 'graph_output_cast_' + str(i) - name_mapping[n.name] = input_name - graph_io_to_skip.add(n.name) - - node_name = 'graph_output_cast' + str(i) - # add Cast node (from tensor(float16) to tensor(float) before graph output - new_value_info = model.graph.value_info.add() - new_value_info.CopyFrom(n) - new_value_info.name = input_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 - new_node = [helper.make_node('Cast', [input_name], [n.name], to=1, name=node_name)] - model.graph.node.extend(new_node) - value_info_list.append(new_value_info) - io_casts.add(node_name) - - while queue: - next_level = [] - for q in queue: - # if q is model, push q.graph (GraphProto) - if isinstance(q, onnx_proto.ModelProto): - next_level.append(q.graph) - # if q is model.graph, push q.node.attribute (AttributeProto) - if isinstance(q, onnx_proto.GraphProto): - for n in q.node: - # if n is in the block list (doesn't support float16), no conversion for the node, - # and save the node for further processing - if n.name in io_casts: - continue - for i in range(len(n.input)): - if n.input[i] in name_mapping: - n.input[i] = name_mapping[n.input[i]] - for i in range(len(n.output)): - if n.output[i] in name_mapping: - n.output[i] = name_mapping[n.output[i]] - # don't add the attr into next_level for the node in node_keep_data_type_list - # so it will not be converted to float16 - if n.op_type in op_block_list or n.name in node_block_list: - node_list.append(n) - node_dict[n.name] = q - else: - if n.op_type == 'Cast': - for attr in n.attribute: - if attr.name == 'to' and attr.i == 1: - attr.i = 10 - break - for attr in n.attribute: - next_level.append(attr) - # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto) - # and process node.attribute.t and node.attribute.tensors (TensorProto) - if isinstance(q, onnx_proto.AttributeProto): - next_level.append(q.g) - for n in q.graphs: - next_level.append(n) - q.t.CopyFrom(convert_tensor_float_to_float16(q.t, min_positive_val, max_finite_val)) - for n in q.tensors: - n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val) - # if q is graph, process graph.initializer(TensorProto), input, output and value_info (ValueInfoProto) - if isinstance(q, onnx_proto.GraphProto): - for n in q.initializer: # TensorProto type - if n.data_type == onnx_proto.TensorProto.FLOAT: - n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val) - value_info_list.append(make_value_info_from_tensor(n)) - # for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to - # tensor(float16) except map and seq(map). 
And save them in value_info_list for further processing - for n in itertools.chain(q.input, q.output, q.value_info): - if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: - if n.name not in graph_io_to_skip: - n.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 - value_info_list.append(n) - queue = next_level - - # process the nodes in block list that doesn't support tensor(float16) - for node in node_list: - # if input's name is in the value_info_list meaning input is tensor(float16) type, - # insert a float16 to float Cast node before the node, - # change current node's input name and create new value_info for the new name - for i in range(len(node.input)): - input = node.input[i] - for value_info in value_info_list: - if input == value_info.name: - # create new value_info for current node's new input name - graph = node_dict[node.name] # get the correct graph instead of the global graph - new_value_info = graph.value_info.add() - new_value_info.CopyFrom(value_info) - output_name = node.name + '_input_cast_' + str(i) - new_value_info.name = output_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT - # add Cast node (from tensor(float16) to tensor(float) before current node - node_name = node.name + '_input_cast' + str(i) - new_node = [helper.make_node('Cast', [input], [output_name], to=1, name=node_name)] - graph.node.extend(new_node) - # change current node's input name - node.input[i] = output_name - break - # if output's name is in the value_info_list meaning output is tensor(float16) type, insert a float to - # float16 Cast node after the node, change current node's output name and create new value_info for the new name - for i in range(len(node.output)): - output = node.output[i] - for value_info in value_info_list: - if output == value_info.name: - # create new value_info for current node's new output - graph = node_dict[node.name] # get the correct graph instead of the global graph - new_value_info = graph.value_info.add() - new_value_info.CopyFrom(value_info) - input_name = node.name + '_output_cast_' + str(i) - new_value_info.name = input_name - new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT - # add Cast node (from tensor(float) to tensor(float16) after current node - node_name = node.name + '_output_cast' + str(i) - new_node = [helper.make_node('Cast', [input_name], [output], to=10, name=node_name)] - graph.node.extend(new_node) - # change current node's input name - node.output[i] = input_name - break sort_topology(model.graph) remove_unnecessary_cast_node(model.graph) + return model +# Change the input/output of the node to the new output name after Cast node for sub-graph +# Because there have NO value_info start from +def process_node_input_output(graph: onnx_proto.GraphProto, global_input_name_dict: dict): + for node in graph.node: + for i, input_name in enumerate(node.input): + if input_name in global_input_name_dict: + node.input[i] = global_input_name_dict[input_name] + for i, output_name in enumerate(node.output): + if output_name in global_input_name_dict: + node.output[i] = global_input_name_dict[output_name] + + +def process_graph_input(graph: onnx_proto.GraphProto, is_top_level: bool, is_io_fp32: bool, global_input_name_dict: dict): + # The input dtype is float32, need to cast to fp16 + if is_top_level and is_io_fp32: + for graph_input in graph.input: # n_input is ValueInfoProto + if graph_input.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: + downstream_nodes = 
find_donwstream_node_by_input_name(graph, graph_input.name)
+                for d_node in downstream_nodes:
+                    cast_node_name = graph_input.name + "_cast_to_" + d_node.name
+                    cast_node_output_name = graph_input.name + "_cast_to_" + d_node.name
+                    add_cast_node(graph, [graph_input.name], [cast_node_output_name], cast_node_name, FLOAT16)
+                    add_new_value_info(graph, graph_input, cast_node_output_name, onnx_proto.TensorProto.FLOAT16)
+                    for i, input_name in enumerate(d_node.input):
+                        if input_name == graph_input.name:
+                            d_node.input[i] = cast_node_output_name  # redirect the downstream node's input to the Cast output
+                global_input_name_dict[graph_input.name] = cast_node_output_name
+
+    # For sub-graphs, or when the graph I/O is not kept as fp32, don't insert a Cast
+    else:  # change the input dtype to fp16 directly, without any Cast
+        for graph_input in graph.input:
+            if graph_input.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
+                graph_input.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+
+
+def process_graph_output(graph: onnx_proto.GraphProto, is_top_level: bool, is_io_fp32: bool):
+    if is_top_level and is_io_fp32:  # the graph output stays float32, so cast the fp16 result back to fp32
+        for i, graph_output in enumerate(graph.output):
+            if graph_output.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
+                upstream_nodes = find_upstream_node_by_output_name(graph, graph_output.name)
+                for u_node in upstream_nodes:
+                    cast_node_name = u_node.name + "_cast_to_" + graph_output.name
+                    cast_node_output_name = u_node.name + "_cast_to_" + graph_output.name
+                    add_cast_node(graph, [cast_node_output_name], [graph_output.name], cast_node_name, FLOAT32)
+                    add_new_value_info(graph, graph_output, cast_node_output_name, onnx_proto.TensorProto.FLOAT16)
+                    for i, output_name in enumerate(u_node.output):
+                        if output_name == graph_output.name:
+                            u_node.output[i] = cast_node_output_name
+    else:  # change the output dtype to fp16 directly
+        for graph_output in graph.output:
+            if graph_output.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
+                graph_output.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+
+
+def process_node_in_block_list(graph: onnx_proto.GraphProto, global_input_name_dict: dict, op_block_list: list, node_block_list: list):
+    for node in graph.node:
+        if (node.op_type in op_block_list) or (node.name in node_block_list):
+            insert_cast32_before_node(graph, node, global_input_name_dict)
+            insert_cast16_after_node(graph, node, global_input_name_dict)
+
+
+# TODO: global_input_name_dict is not populated here yet
+def insert_cast32_before_node(graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict):
+    for i in range(len(node.input)):
+        input_name = node.input[i]
+        for value_info in itertools.chain(graph.value_info, graph.input):
+            if input_name == value_info.name:
+                if value_info.type.tensor_type.elem_type != onnx_proto.TensorProto.FLOAT16:
+                    break
+                cast_output_name = node.name + "_input_cast_" + str(i)
+                add_new_value_info(graph, value_info, cast_output_name, onnx_proto.TensorProto.FLOAT)
+                cast_node_name = node.name + "_input_cast" + str(i)
+                add_cast_node(graph, [input_name], [cast_output_name], cast_node_name, onnx_proto.TensorProto.FLOAT)
+                node.input[i] = cast_output_name
+                break
+
+
+# TODO: global_input_name_dict is not populated here yet
+def insert_cast16_after_node(graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict):
+    for i in range(len(node.output)):
+        output_name = node.output[i]
+        for value_info in itertools.chain(graph.value_info, graph.output):
+            if output_name == value_info.name:
+                if
value_info.type.tensor_type.elem_type != onnx_proto.TensorProto.FLOAT: + break + cast_input_name = node.name + "_output_cast_" + str(i) + add_new_value_info(graph, value_info, cast_input_name, onnx_proto.TensorProto.FLOAT) + value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 + cast_node_name = node.name + "_output_cast" + str(i) + add_cast_node(graph, [cast_input_name], [output_name], cast_node_name, onnx_proto.TensorProto.FLOAT16) + node.output[i] = cast_input_name + break + + +# Process tensor data in attribute of the node +def process_tensor_in_node(graph: onnx_proto.GraphProto, op_block_list: list, node_block_list: list, min_positive_val, max_finite_val): + value_info_block_list = set() # This is for later use, not in this step + for node in graph.node: + if (node.op_type in op_block_list) or (node.name in node_block_list): + # Only need to block the output value_info changing + for output_name in node.output: + value_info_block_list.add(output_name) + else: + for attr in node.attribute: + # one tensor + if attr.t.data_type == onnx_proto.TensorProto.FLOAT: + attr.t.CopyFrom(convert_tensor_float_to_float16(attr.t, min_positive_val, max_finite_val)) + # list of tensor + for t in attr.tensors: + if t.data_type == onnx_proto.TensorProto.FLOAT: + t.CopyFrom(convert_tensor_float_to_float16(t, min_positive_val, max_finite_val)) + return value_info_block_list + + +# Change all the value info type from float32 to float16 if not in block list +def process_value_info(graph: onnx_proto.GraphProto, value_info_block_list: list): + for value_info in graph.value_info: + if value_info.name in value_info_block_list: + continue + else: + if value_info.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: + value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 + + +# Initializer is 'edge' type, so doesn't have value_info +def process_initializers(graph: onnx_proto.GraphProto, op_block_list, node_block_list, min_positive_val, max_finite_val): + # Find the input of the block node, don't need to change this kind of initializer + initializer_block_list = set() + for node in graph.node: + if (node.op_type in op_block_list) or (node.name in node_block_list): + for input_name in node.input: # some is initializer, some is value_info, can't distinguish but doesn't matter + initializer_block_list.add(input_name) + # Process initializers + for initializer in graph.initializer: + if initializer.name not in initializer_block_list: + if initializer.data_type == onnx_proto.TensorProto.FLOAT: + convert_tensor_float_to_float16(initializer, min_positive_val, max_finite_val) + + +def get_next_level_graph(graph: onnx_proto.GraphProto, op_block_list: list, node_block_list: list): + sub_graph_list = [] + for node in graph.node: + if node.op_type in op_block_list or node.name in node_block_list: + continue + for attr in node.attribute: + # Check if sub-graph exist + if len(attr.g.node) > 0: # single sub-graph + sub_graph_list.append(attr.g) + for g in attr.graphs: + if len(g.node) > 0: # multiple sub-graphs + sub_graph_list.append(g) + return sub_graph_list + + +def add_cast_node(graph: onnx_proto.GraphProto, inputs: list, outputs: list, node_name: str, to_type: int): + new_node = [helper.make_node('Cast', inputs, outputs, to=to_type, name=node_name)] + graph.node.extend(new_node) + + +def add_new_value_info(graph: onnx_proto.GraphProto, exist_value_info: onnx_proto.ValueInfoProto, name: str, dtype: int): + new_value_info = graph.value_info.add() + 
new_value_info.CopyFrom(exist_value_info)
+    new_value_info.name = name
+    new_value_info.type.tensor_type.elem_type = dtype
+
+
+# Find the node that produces the specified output name
+def find_upstream_node_by_output_name(graph: onnx_proto.GraphProto, output_name: str):
+    nodes = []
+    for node in graph.node:
+        if output_name in node.output:
+            nodes.append(node)
+    assert len(nodes) <= 1  # expect at most one producer for a given output name
+    return nodes
+
+
+# Find the nodes that consume the specified input name
+def find_donwstream_node_by_input_name(graph: onnx_proto.GraphProto, input_name: str):
+    nodes = []
+    for node in graph.node:
+        if input_name in node.input:
+            nodes.append(node)
+    return nodes
+
+
+# Remove Identity nodes from the model, then re-run shape inference
+def remove_identity_node_from_model(model: onnx_proto.ModelProto):
+    remove_identity_node_from_graph(model.graph)
+    try:
+        from onnx.shape_inference import infer_shapes
+        func_infer_shape = infer_shapes
+        model = func_infer_shape(model)
+        return model
+    finally:
+        pass
+
+
+# Remove Identity nodes by reconnecting their upstream producers to their outputs
+def remove_identity_node_from_graph(graph: onnx_proto.GraphProto):
+    for curr_node in list(graph.node):  # iterate over a copy because nodes are removed below
+        if curr_node.op_type == 'Identity':
+            for input_name in curr_node.input:
+                upstream_nodes = find_upstream_node_by_output_name(graph, input_name)
+                for u_node in upstream_nodes:
+                    if u_node is not None:
+                        u_node.output[0] = curr_node.output[0]
+                        graph.node.remove(curr_node)
+
 
 def convert_float_to_float16_model_path(model_path, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False):
     '''
@@ -391,17 +472,17 @@ def sort_topology(graph_proto):
             sort_topology(g)  # sort sub-graph
 
 
-def remove_unnecessary_cast_node(graph_proto):
+def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
     # 1. find all cast nodes in the graph
     cast_node_list = []
     input_name_to_cast_node_dict = {}
     output_name_to_cast_node_dict = {}
-    # using name as key to point to a node. because node cannot be key
+    # using name as key to point to a node.
because node object cannot be key name_to_node_dict = {} for node in graph_proto.node: if node.op_type == 'Cast': - if node.name not in ["graph_input_cast0", "graph_output_cast0"]: - cast_node_list.append(node) + # if node.name not in ["graph_input_cast0", "graph_output_cast0"]: + cast_node_list.append(node) name_to_node_dict[node.name] = node for input_name in node.input: @@ -464,20 +545,31 @@ def remove_unnecessary_cast_node(graph_proto): for cast_node_pair in remove_candidate: first_cast_node = cast_node_pair[0] second_cast_node = cast_node_pair[1] - upstream_node = cast_node_upstream_dict[first_cast_node.name] - downstream_node = cast_node_downstream_dict[second_cast_node.name] - # find the upstream node's output to first_cast_node - out = None - for output_name in upstream_node.output: - if output_name == first_cast_node.input[0]: - out = output_name - break - # find the downstream node's input as second_cast_node's output - for i, input_name in enumerate(downstream_node.input): - for output_name in second_cast_node.output: - if input_name == output_name: - # change the input as the upstream node's output - downstream_node.input[i] = out + upstream_node = cast_node_upstream_dict.get(first_cast_node.name) + downstream_node = cast_node_downstream_dict.get(second_cast_node.name) + if upstream_node is None and downstream_node is not None: + # The upstream_node should be graph input + out = first_cast_node.input[0] + for i, input_name in enumerate(downstream_node.input): + for output_name in second_cast_node.output: + if input_name == output_name: + # change the input as the upstream node's output + downstream_node.input[i] = out + elif upstream_node is not None and downstream_node is None: + raise ValueError("The downstream node of the second cast node should be graph output") + else: + # find the upstream node's output to first_cast_node + out = None + for output_name in upstream_node.output: + if output_name == first_cast_node.input[0]: + out = output_name + break + # find the downstream node's input as second_cast_node's output + for i, input_name in enumerate(downstream_node.input): + for output_name in second_cast_node.output: + if input_name == output_name: + # change the input as the upstream node's output + downstream_node.input[i] = out # 6. 
remove the cast node pair for cast_node_pair in remove_candidate: diff --git a/tests/test_auto_mixed_precision.py b/tests/test_auto_mixed_precision.py index 67665d4..600ce07 100644 --- a/tests/test_auto_mixed_precision.py +++ b/tests/test_auto_mixed_precision.py @@ -2,6 +2,7 @@ import numpy as np import onnxruntime as _ort import onnx +import os import copy from onnxconverter_common.onnx_fx import Graph, OnnxOperatorBuilderX from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty @@ -47,15 +48,28 @@ def validate_fn(res, fp16res): return np.allclose(res[0], fp16res[0], rtol=0.01) f16model = auto_convert_mixed_precision(copy.deepcopy(model), {'x': m1}, validate_fn, keep_io_types=True) - actual = _ort_inference(f16model, {'x': m1}) self.assertTrue(np.allclose(expected, actual, rtol=0.01)) f16model2 = auto_convert_mixed_precision(copy.deepcopy(model), {'x': m1}, rtol=0.01, keep_io_types=False) - actual = _ort_inference(f16model2, {'x': m1.astype(np.float16)}) self.assertTrue(np.allclose(expected, actual, rtol=0.01)) + def test_auto_mixed_precision_rtol_atol(self): + model32_name = "image_classifier32.onnx" + working_path = os.path.abspath(os.path.dirname(__file__)) + data_path = os.path.join(working_path, 'data') + model32_path = os.path.join(data_path, model32_name) + model32 = onnx.load(model32_path) + np.random.seed(1) + input_x = np.random.rand(32, 3, 32, 32).astype(np.float32) + expected = _ort_inference(model32, {'modelInput': input_x}) + + model16 = auto_convert_mixed_precision(model32, {'modelInput': input_x}, rtol=0.01, keep_io_types=True) + actual = _ort_inference(model16, {'modelInput': input_x.astype(np.float32)}) + self.assertTrue(np.allclose(expected, actual, rtol=1e-2, atol=1e-2)) + + if __name__ == '__main__': suite = unittest.TestLoader().loadTestsFromTestCase(AutoFloat16Test) diff --git a/tests/test_float16.py b/tests/test_float16.py index ca77907..491110e 100644 --- a/tests/test_float16.py +++ b/tests/test_float16.py @@ -10,7 +10,7 @@ from onnxconverter_common.onnx_fx import GraphFunctionType as _Ty from onnxconverter_common.onnx_ex import get_maximum_opset_supported from onnxconverter_common.optimizer import optimize_onnx_model -from onnxconverter_common.float16 import convert_float_to_float16 +from onnxconverter_common.float16 import convert_float_to_float16, remove_identity_node_from_model from onnxconverter_common.float16 import convert_np_to_float16 @@ -36,18 +36,24 @@ def transpose_n_matmul(x): b = ox.constant(value=wm) a = ox.transpose(x, perm=[0, 1, 3, 2]) c = ox.transpose(b, perm=[1, 0]) - return ox.matmul([a, c]) + d = ox.matmul([a, c]) + return ox.min(d) m1 = np.array([[2, 3], [4, 5], [6, 7]]).astype(np.float32).reshape([1, 1, 6, 1]) expected = transpose_n_matmul(m1) model = transpose_n_matmul.to_model() - f16model = convert_float_to_float16(copy.deepcopy(model)) + # This is optional, is a new feature test case + model = remove_identity_node_from_model(model) + actual = _ort_inference(model, {'x': m1}) + self.assertTrue(np.allclose(expected, actual)) + + f16model = convert_float_to_float16(copy.deepcopy(model), keep_io_types=False) actual = _ort_inference(f16model, {'x': m1.astype(np.float16)}) self.assertTrue(np.allclose(expected, actual)) f16model2 = convert_float_to_float16(copy.deepcopy(model), keep_io_types=True) - actual2 = _ort_inference(f16model2, {'x': m1}) - self.assertTrue(np.allclose(expected, actual2)) + actual = _ort_inference(f16model2, {'x': m1}) + self.assertTrue(np.allclose(expected, actual)) def test_float16_with_loop(self): 
@onnx_function(outputs=['y1', 'y2'], @@ -79,15 +85,15 @@ def range_body(iter_n, cond, total): expected_res = loop_test(m1) model = loop_test.to_model() - f16model = convert_float_to_float16(copy.deepcopy(model)) - actual_res = _ort_inference(f16model, {'data': m1.astype(np.float16)}) - for expected, actual in zip(expected_res, actual_res): + f16model = convert_float_to_float16(copy.deepcopy(model), keep_io_types=False) + actual = _ort_inference(f16model, {'data': m1.astype(np.float16)}) + for expected, actual in zip(expected_res, actual): self.assertTrue(np.allclose(expected, actual)) self.assertTrue(actual.dtype == np.float16) - f16model2 = convert_float_to_float16(copy.deepcopy(model), keep_io_types=True) - actual_res2 = _ort_inference(f16model2, {'data': m1}) - for expected, actual2 in zip(expected_res, actual_res2): + f16model = convert_float_to_float16(copy.deepcopy(model), keep_io_types=True) + actual = _ort_inference(f16model, {'data': m1}) + for expected, actual2 in zip(expected_res, actual): self.assertTrue(np.allclose(expected, actual2)) self.assertTrue(actual2.dtype == np.float32) @@ -100,10 +106,14 @@ def test_convert_to_float16(self): input_x = np.random.rand(1, 3, 32, 32).astype(np.float32) output_32 = _ort_inference(onnx_model32, {'modelInput': input_x}) - onnx_model16 = convert_float_to_float16(onnx_model32) + onnx_model16 = convert_float_to_float16(onnx_model32, keep_io_types=False) output_16 = _ort_inference(onnx_model16, {'modelInput': input_x.astype(np.float16)}) self.assertTrue(np.allclose(output_16, output_32, atol=1e-2)) + onnx_model16 = convert_float_to_float16(onnx_model32, keep_io_types=True) + output_16 = _ort_inference(onnx_model16, {'modelInput': input_x}) + self.assertTrue(np.allclose(output_16, output_32, atol=1e-2)) + def test_convert_to_float16_with_truncated(self): np_array = np.array([1e-10, -2.0, 15, -1e-9, 65536.1, -100000]) convert_np_to_float16(np_array) @@ -120,9 +130,14 @@ def test_convert_to_float16_with_subgraph(self): output_32 = _ort_inference(onnx_model32, {"x":x, "y":y}) onnx_model16 = convert_float_to_float16(onnx_model32, keep_io_types=True) - output_16 = _ort_inference(onnx_model16, {"x":x, "y":y}) - self.assertTrue(np.allclose(output_16, output_32, atol=1e-2)) - + actual = _ort_inference(onnx_model16, {"x":x, "y":y}) + self.assertTrue(np.allclose(actual, output_32, atol=1e-2)) + self.assertTrue(actual[0].dtype == np.float32) + + onnx_model16 = convert_float_to_float16(onnx_model32, keep_io_types=False) + actual = _ort_inference(onnx_model16, {"x": x.astype(np.float16), "y": y.astype(np.float16)}) + self.assertTrue(np.allclose(actual, output_32, atol=1e-2)) + self.assertTrue(actual[0].dtype == np.float16) if __name__ == '__main__':
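
Usage note (not part of the patch): the following is a minimal sketch of how the reworked converter is exercised, mirroring the tests above. The model path and the 'modelInput' input name come from the test fixtures and are assumptions for illustration; adjust them for other models.

import numpy as np
import onnx
import onnxruntime as ort
from onnxconverter_common.float16 import convert_float_to_float16

# Load an FP32 model (path and input name mirror tests/data/image_classifier32.onnx).
model32 = onnx.load("tests/data/image_classifier32.onnx")
x = np.random.rand(1, 3, 32, 32).astype(np.float32)

# keep_io_types=True leaves the graph inputs/outputs as float32 and inserts
# boundary Cast nodes, so the original float32 feed still works.
model16 = convert_float_to_float16(model32, keep_io_types=True)
sess = ort.InferenceSession(model16.SerializeToString(), providers=['CPUExecutionProvider'])
out = sess.run(None, {'modelInput': x})

# keep_io_types=False converts the graph I/O to float16 as well, so inputs
# must be fed as float16 and outputs come back as float16.
model16_io = convert_float_to_float16(onnx.load("tests/data/image_classifier32.onnx"), keep_io_types=False)
sess = ort.InferenceSession(model16_io.SerializeToString(), providers=['CPUExecutionProvider'])
out_fp16 = sess.run(None, {'modelInput': x.astype(np.float16)})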