From cc3cdf8932bd068baf2bb7826eae0fa4ca3735dc Mon Sep 17 00:00:00 2001 From: Marcel R Date: Wed, 15 Nov 2023 15:59:05 +0100 Subject: [PATCH] Fix lazy loader tests. --- tests/test_aot.py | 13 +++++++++++-- tests/test_compile_tf_graph.py | 5 ++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_aot.py b/tests/test_aot.py index ea50e8b..6982275 100644 --- a/tests/test_aot.py +++ b/tests/test_aot.py @@ -9,8 +9,7 @@ import cmsml from cmsml.util import tmp_dir, tmp_file -import cmsml.tensorflow.tools as cmsml_tools -from cmsml.tensorflow.aot import get_graph_ops, OpsData +# from cmsml.tensorflow.aot import get_graph_ops, OpsData from . import CMSMLTestCase @@ -57,6 +56,8 @@ def tf_version(self): return self._tf_version def create_graph_def(self, create="saved_model", **kwargs): + import cmsml.tensorflow.tools as cmsml_tools + # helper function to create GraphDef from SavedModel or Graph tf = self.tf @@ -91,6 +92,8 @@ def create_graph_def(self, create="saved_model", **kwargs): @skip_if_no_tf2xla_supported_ops def test_get_graph_ops_saved_model(self): + from cmsml.tensorflow.aot import get_graph_ops + tf_graph_def, keras_graph_def = self.create_graph_def(create="saved_model") graph_ops = set(get_graph_ops(tf_graph_def, node_def_number=0)) @@ -105,6 +108,8 @@ def test_get_graph_ops_saved_model(self): @skip_if_no_tf2xla_supported_ops def test_get_graph_ops_graph(self): + from cmsml.tensorflow.aot import get_graph_ops + concrete_function_graph_def = self.create_graph_def(create="graph") graph_ops = set(get_graph_ops(concrete_function_graph_def, node_def_number=0)) @@ -147,6 +152,8 @@ def tf_version(self): @skip_if_no_tf2xla_supported_ops def test_parse_ops_table(self): + from cmsml.tensorflow.aot import OpsData + ops_dict = OpsData.parse_ops_table(device="cpu") expected_ops = ("Abs", "Acosh", "Add", "Atan", "BatchMatMul", "Conv2D") @@ -158,6 +165,8 @@ def test_parse_ops_table(self): @skip_if_no_tf2xla_supported_ops def test_determine_ops(self): + from cmsml.tensorflow.aot import OpsData + # function to merge multiple tables devices = ("cpu", "gpu") diff --git a/tests/test_compile_tf_graph.py b/tests/test_compile_tf_graph.py index e3603ff..9f06673 100644 --- a/tests/test_compile_tf_graph.py +++ b/tests/test_compile_tf_graph.py @@ -5,7 +5,6 @@ import cmsml from cmsml.util import tmp_dir -from cmsml.scripts.compile_tf_graph import compile_tf_graph from . import CMSMLTestCase @@ -49,6 +48,8 @@ def create_test_model(self, tf): return model def test_compile_tf_graph_static_preparation(self): + from cmsml.scripts.compile_tf_graph import compile_tf_graph + # check only preparation process for aot, but do not aot compile tf = self.tf @@ -105,6 +106,8 @@ def test_compile_tf_graph_static_preparation(self): compile_class=None) def test_compile_tf_graph_static_aot_compilation(self): + from cmsml.scripts.compile_tf_graph import compile_tf_graph + # check aot compilation tf = self.tf model = self.create_test_model(tf)