diff --git a/python/tvm/driver/tvmc/autotuner.py b/python/tvm/driver/tvmc/autotuner.py index 82b5cc1598fd..ad4f8ae6169b 100644 --- a/python/tvm/driver/tvmc/autotuner.py +++ b/python/tvm/driver/tvmc/autotuner.py @@ -672,7 +672,7 @@ def autotvm_get_tuning_tasks( """ target, target_host = Target.canon_target_and_host(target, target_host) - mod = apply_graph_transforms(mod, transform_args) + mod = apply_graph_transforms(mod, transform_args, params) tasks = autotvm.task.extract_from_program( mod["main"], @@ -718,7 +718,7 @@ def autoscheduler_get_tuning_tasks( """ target, target_host = Target.canon_target_and_host(target, target_host) - mod = apply_graph_transforms(mod, transform_args) + mod = apply_graph_transforms(mod, transform_args, params) # Extract the tasks tasks, task_weights = auto_scheduler.extract_tasks( diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 09ba8909e3e7..78c0c04eb844 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -402,7 +402,7 @@ def compile_model( instruments=instruments, ): transform_args = parse_graph_transform_args(locals()) - mod = apply_graph_transforms(mod, transform_args) + mod = apply_graph_transforms(mod, transform_args, params) for partition_function, opts in zip(partition_functions, partition_opts): mod = partition_function(mod, params, mod_name=mod_name, **opts) diff --git a/python/tvm/driver/tvmc/transform.py b/python/tvm/driver/tvmc/transform.py index 30d9bfa639b1..253c624e6ed4 100644 --- a/python/tvm/driver/tvmc/transform.py +++ b/python/tvm/driver/tvmc/transform.py @@ -162,7 +162,7 @@ def layout_helper(layout): raise TVMCException("Error converting layouts: {}".format(str(err))) -def apply_graph_transforms(mod, args): +def apply_graph_transforms(mod, args, params=None): """Alter the layout of the input graph. Parameters @@ -171,6 +171,8 @@ def apply_graph_transforms(mod, args): The relay module to convert. args : dict The transform arguments. + params: dict + Module params Returns ------- @@ -188,6 +190,7 @@ def apply_graph_transforms(mod, args): # ToMixedPrecision if args.get("mixed_precision", False): + mod = relay.quantize.prerequisite_optimize(mod, params) mod = convert_to_mixed_precision( mod, args.get("mixed_precision_ops"), diff --git a/tests/python/driver/tvmc/test_transform.py b/tests/python/driver/tvmc/test_transform.py index 06af3cb156c1..ebf067990d0f 100644 --- a/tests/python/driver/tvmc/test_transform.py +++ b/tests/python/driver/tvmc/test_transform.py @@ -226,6 +226,7 @@ def check(self, func): "mixed_precision_calculation_type": "float16", "mixed_precision_acc_type": "float16", }, + params, ) ret = CheckOpMutator("float16", "float16", "nn.conv2d").check(mod["main"]) assert ret @@ -240,6 +241,7 @@ def check(self, func): "mixed_precision_calculation_type": "float16", "mixed_precision_acc_type": "float32", }, + params, ) ret = CheckOpMutator("float16", "float32", "nn.conv2d").check(mod["main"]) assert ret diff --git a/tests/python/relay/opencl_texture/test_network.py b/tests/python/relay/opencl_texture/test_network.py index 2b2f3741cba2..66c88ebbe294 100644 --- a/tests/python/relay/opencl_texture/test_network.py +++ b/tests/python/relay/opencl_texture/test_network.py @@ -47,6 +47,7 @@ def _test_mobilenet_v1(remote, target, calc_dtype, executor_type, acc_dtype): "mixed_precision_calculation_type": calc_dtype, "mixed_precision_acc_type": acc_dtype, }, + params, ) if executor_type == "ge":