Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaowuhu committed Jun 5, 2024
1 parent d11af42 commit b8013d8
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 321 deletions.
8 changes: 5 additions & 3 deletions onnxconverter_common/auto_mixed_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def run_attempt(node_block_list, return_model=False):
print(node_block_list)
# compare new and old model
model = float16.convert_float_to_float16(copy.deepcopy(model0), node_block_list=node_block_list,
is_io_fp32=keep_io_types, disable_shape_infer=False)
keep_io_types=keep_io_types, disable_shape_infer=False)
#onnx.save_model(model, "d:/new_fp16.onnx")
res1 = get_tensor_values_using_ort(model, feed_dict)
if return_model:
Expand Down Expand Up @@ -131,15 +131,17 @@ def get_tensor_values_using_ort(model, input_feed, output_names=None, sess_optio
# Below code is for debug only, keep it for next time use
# sess_options = ort.SessionOptions()
# sess_options.optimized_model_filepath = "d:/optimized_model.onnx"
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
#sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CPUExecutionProvider'])
return sess.run(None, input_feed)
original_outputs = list(model.graph.output)
while len(model.graph.output) > 0:
model.graph.output.pop()
for n in output_names:
out = model.graph.output.add()
out.name = n
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
#sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CUDAExecutionProvider'])
sess = ort.InferenceSession(model.SerializeToString(), sess_options, providers=['CPUExecutionProvider'])
try:
return sess.run(output_names, input_feed)
finally:
Expand Down
2 changes: 1 addition & 1 deletion onnxconverter_common/auto_mixed_precision_model_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _convert_and_check_inference_result(**kwargs):
_print_node_block_list(node_block_list)
model_16 = float16.convert_float_to_float16(
copy.deepcopy(model_32), node_block_list=node_block_list,
is_io_fp32=keep_io_types, disable_shape_infer=True)
keep_io_types=keep_io_types, disable_shape_infer=True)

if is_final_model:
location = kwargs.get("location") # using the specified external data file name
Expand Down
Loading

0 comments on commit b8013d8

Please sign in to comment.