Skip to content

Commit cc3cdf8

Browse files
committed
Fix lazy loader tests.
1 parent e1aa4e0 commit cc3cdf8

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

tests/test_aot.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010
import cmsml
1111
from cmsml.util import tmp_dir, tmp_file
12-
import cmsml.tensorflow.tools as cmsml_tools
13-
from cmsml.tensorflow.aot import get_graph_ops, OpsData
12+
# from cmsml.tensorflow.aot import get_graph_ops, OpsData
1413

1514
from . import CMSMLTestCase
1615

@@ -57,6 +56,8 @@ def tf_version(self):
5756
return self._tf_version
5857

5958
def create_graph_def(self, create="saved_model", **kwargs):
59+
import cmsml.tensorflow.tools as cmsml_tools
60+
6061
# helper function to create GraphDef from SavedModel or Graph
6162
tf = self.tf
6263

@@ -91,6 +92,8 @@ def create_graph_def(self, create="saved_model", **kwargs):
9192

9293
@skip_if_no_tf2xla_supported_ops
9394
def test_get_graph_ops_saved_model(self):
95+
from cmsml.tensorflow.aot import get_graph_ops
96+
9497
tf_graph_def, keras_graph_def = self.create_graph_def(create="saved_model")
9598

9699
graph_ops = set(get_graph_ops(tf_graph_def, node_def_number=0))
@@ -105,6 +108,8 @@ def test_get_graph_ops_saved_model(self):
105108

106109
@skip_if_no_tf2xla_supported_ops
107110
def test_get_graph_ops_graph(self):
111+
from cmsml.tensorflow.aot import get_graph_ops
112+
108113
concrete_function_graph_def = self.create_graph_def(create="graph")
109114
graph_ops = set(get_graph_ops(concrete_function_graph_def, node_def_number=0))
110115

@@ -147,6 +152,8 @@ def tf_version(self):
147152

148153
@skip_if_no_tf2xla_supported_ops
149154
def test_parse_ops_table(self):
155+
from cmsml.tensorflow.aot import OpsData
156+
150157
ops_dict = OpsData.parse_ops_table(device="cpu")
151158
expected_ops = ("Abs", "Acosh", "Add", "Atan", "BatchMatMul", "Conv2D")
152159

@@ -158,6 +165,8 @@ def test_parse_ops_table(self):
158165

159166
@skip_if_no_tf2xla_supported_ops
160167
def test_determine_ops(self):
168+
from cmsml.tensorflow.aot import OpsData
169+
161170
# function to merge multiple tables
162171
devices = ("cpu", "gpu")
163172

tests/test_compile_tf_graph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import cmsml
77
from cmsml.util import tmp_dir
8-
from cmsml.scripts.compile_tf_graph import compile_tf_graph
98

109
from . import CMSMLTestCase
1110

@@ -49,6 +48,8 @@ def create_test_model(self, tf):
4948
return model
5049

5150
def test_compile_tf_graph_static_preparation(self):
51+
from cmsml.scripts.compile_tf_graph import compile_tf_graph
52+
5253
# check only preparation process for aot, but do not aot compile
5354
tf = self.tf
5455

@@ -105,6 +106,8 @@ def test_compile_tf_graph_static_preparation(self):
105106
compile_class=None)
106107

107108
def test_compile_tf_graph_static_aot_compilation(self):
109+
from cmsml.scripts.compile_tf_graph import compile_tf_graph
110+
108111
# check aot compilation
109112
tf = self.tf
110113
model = self.create_test_model(tf)

0 commit comments

Comments
 (0)