Skip to content

Commit

Permalink
Fix lazy loader tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Nov 15, 2023
1 parent e1aa4e0 commit cc3cdf8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
13 changes: 11 additions & 2 deletions tests/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

import cmsml
from cmsml.util import tmp_dir, tmp_file
import cmsml.tensorflow.tools as cmsml_tools
from cmsml.tensorflow.aot import get_graph_ops, OpsData
# from cmsml.tensorflow.aot import get_graph_ops, OpsData

from . import CMSMLTestCase

Expand Down Expand Up @@ -57,6 +56,8 @@ def tf_version(self):
return self._tf_version

def create_graph_def(self, create="saved_model", **kwargs):
import cmsml.tensorflow.tools as cmsml_tools

# helper function to create GraphDef from SavedModel or Graph
tf = self.tf

Expand Down Expand Up @@ -91,6 +92,8 @@ def create_graph_def(self, create="saved_model", **kwargs):

@skip_if_no_tf2xla_supported_ops
def test_get_graph_ops_saved_model(self):
from cmsml.tensorflow.aot import get_graph_ops

tf_graph_def, keras_graph_def = self.create_graph_def(create="saved_model")

graph_ops = set(get_graph_ops(tf_graph_def, node_def_number=0))
Expand All @@ -105,6 +108,8 @@ def test_get_graph_ops_saved_model(self):

@skip_if_no_tf2xla_supported_ops
def test_get_graph_ops_graph(self):
from cmsml.tensorflow.aot import get_graph_ops

concrete_function_graph_def = self.create_graph_def(create="graph")
graph_ops = set(get_graph_ops(concrete_function_graph_def, node_def_number=0))

Expand Down Expand Up @@ -147,6 +152,8 @@ def tf_version(self):

@skip_if_no_tf2xla_supported_ops
def test_parse_ops_table(self):
from cmsml.tensorflow.aot import OpsData

ops_dict = OpsData.parse_ops_table(device="cpu")
expected_ops = ("Abs", "Acosh", "Add", "Atan", "BatchMatMul", "Conv2D")

Expand All @@ -158,6 +165,8 @@ def test_parse_ops_table(self):

@skip_if_no_tf2xla_supported_ops
def test_determine_ops(self):
from cmsml.tensorflow.aot import OpsData

# function to merge multiple tables
devices = ("cpu", "gpu")

Expand Down
5 changes: 4 additions & 1 deletion tests/test_compile_tf_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import cmsml
from cmsml.util import tmp_dir
from cmsml.scripts.compile_tf_graph import compile_tf_graph

from . import CMSMLTestCase

Expand Down Expand Up @@ -49,6 +48,8 @@ def create_test_model(self, tf):
return model

def test_compile_tf_graph_static_preparation(self):
from cmsml.scripts.compile_tf_graph import compile_tf_graph

# check only preparation process for aot, but do not aot compile
tf = self.tf

Expand Down Expand Up @@ -105,6 +106,8 @@ def test_compile_tf_graph_static_preparation(self):
compile_class=None)

def test_compile_tf_graph_static_aot_compilation(self):
from cmsml.scripts.compile_tf_graph import compile_tf_graph

# check aot compilation
tf = self.tf
model = self.create_test_model(tf)
Expand Down

0 comments on commit cc3cdf8

Please sign in to comment.