10
10
11
11
12
12
class TfCompileTestCase (CMSMLTestCase ):
13
+
13
14
def __init__ (self , * args , ** kwargs ):
14
15
super (TfCompileTestCase , self ).__init__ (* args , ** kwargs )
15
16
@@ -44,6 +45,7 @@ def create_test_model(self, tf):
44
45
x = tf .concat ([x1 , x2 ], axis = 1 )
45
46
a1 = tf .keras .layers .Dense (10 , activation = "elu" )(x )
46
47
y = tf .keras .layers .Dense (5 , activation = "softmax" )(a1 )
48
+
47
49
model = tf .keras .Model (inputs = (x1 , x2 , x3 ), outputs = y )
48
50
return model
49
51
@@ -55,55 +57,60 @@ def test_compile_tf_graph_static_preparation(self):
55
57
56
58
model = self .create_test_model (tf )
57
59
58
- with tmp_dir (create = False ) as model_path :
60
+ with tmp_dir (create = False ) as model_path , tmp_dir ( create = False ) as static_saved_model_path :
59
61
tf .saved_model .save (model , model_path )
60
62
61
- with tmp_dir (create = False ) as static_saved_model_path :
62
- batch_sizes = [1 , 2 ]
63
-
64
- compile_tf_graph (model_path = model_path ,
65
- output_path = static_saved_model_path ,
66
- batch_sizes = batch_sizes ,
67
- input_serving_key = tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY ,
68
- output_serving_key = None ,
69
- compile_prefix = None ,
70
- compile_class = None )
71
-
72
- # load model and check input shape
73
- loaded_static_model = cmsml .tensorflow .load_model (static_saved_model_path )
74
- for batch_size in batch_sizes :
75
- # first entry is empty, second contains inputs tuple(tensorspecs)
76
- model_static_inputs = loaded_static_model .signatures [f"serving_default__{ batch_size } " ].structured_input_signature [1 ] # noqa
77
-
78
- expected_model_static_inputs = {
79
- f"first__bs{ batch_size } " : tf .TensorSpec (
80
- shape = (batch_size , 2 ),
81
- dtype = tf .float32 ,
82
- name = f"first__bs{ batch_size } " ,
83
- ),
84
- f"second__bs{ batch_size } " : tf .TensorSpec (
85
- shape = (batch_size , 3 ),
86
- dtype = tf .float32 ,
87
- name = f"second__bs{ batch_size } " ,
88
- ),
89
- f"third__bs{ batch_size } " : tf .TensorSpec (
90
- shape = (batch_size , 10 ),
91
- dtype = tf .float32 ,
92
- name = f"third__bs{ batch_size } " ,
93
- ),
94
- }
95
-
96
- self .assertDictEqual (model_static_inputs , expected_model_static_inputs )
97
-
98
- # throw error if compilation happens with illegal batch size
99
- with self .assertRaises (ValueError ):
100
- compile_tf_graph (model_path = model_path ,
101
- output_path = static_saved_model_path ,
102
- batch_sizes = [- 1 ,],
103
- input_serving_key = tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY ,
104
- output_serving_key = None ,
105
- compile_prefix = None ,
106
- compile_class = None )
63
+ # throw error if compilation happens with illegal batch size
64
+ with self .assertRaises (ValueError ):
65
+ compile_tf_graph (
66
+ model_path = model_path ,
67
+ output_path = static_saved_model_path ,
68
+ batch_sizes = [- 1 ],
69
+ input_serving_key = tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY ,
70
+ output_serving_key = None ,
71
+ compile_prefix = None ,
72
+ compile_class = None ,
73
+ )
74
+
75
+ batch_sizes = [1 , 2 ]
76
+ compile_tf_graph (
77
+ model_path = model_path ,
78
+ output_path = static_saved_model_path ,
79
+ batch_sizes = batch_sizes ,
80
+ input_serving_key = tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY ,
81
+ output_serving_key = None ,
82
+ compile_prefix = None ,
83
+ compile_class = None ,
84
+ )
85
+
86
+ # load model
87
+ loaded_static_model = cmsml .tensorflow .load_model (static_saved_model_path )
88
+
89
+ # check input shape
90
+ for batch_size in batch_sizes :
91
+ # first entry is empty, second contains inputs tuple(tensorspecs)
92
+ key = f"serving_default_bs{ batch_size } "
93
+ model_static_inputs = loaded_static_model .signatures [key ].structured_input_signature [1 ]
94
+
95
+ expected_model_static_inputs = {
96
+ f"first_bs{ batch_size } " : tf .TensorSpec (
97
+ shape = (batch_size , 2 ),
98
+ dtype = tf .float32 ,
99
+ name = f"first_bs{ batch_size } " ,
100
+ ),
101
+ f"second_bs{ batch_size } " : tf .TensorSpec (
102
+ shape = (batch_size , 3 ),
103
+ dtype = tf .float32 ,
104
+ name = f"second_bs{ batch_size } " ,
105
+ ),
106
+ f"third_bs{ batch_size } " : tf .TensorSpec (
107
+ shape = (batch_size , 10 ),
108
+ dtype = tf .float32 ,
109
+ name = f"third_bs{ batch_size } " ,
110
+ ),
111
+ }
112
+
113
+ self .assertDictEqual (model_static_inputs , expected_model_static_inputs )
107
114
108
115
def test_compile_tf_graph_static_aot_compilation (self ):
109
116
from cmsml .scripts .compile_tf_graph import compile_tf_graph
@@ -112,23 +119,24 @@ def test_compile_tf_graph_static_aot_compilation(self):
112
119
tf = self .tf
113
120
model = self .create_test_model (tf )
114
121
115
- with tmp_dir (create = False ) as model_path :
122
+ with tmp_dir (create = False ) as model_path , tmp_dir ( create = False ) as static_saved_model_path :
116
123
tf .saved_model .save (model , model_path )
117
124
118
- with tmp_dir (create = False ) as static_saved_model_path :
119
- batch_sizes = [1 , 2 ]
120
- compile_tf_graph (model_path = model_path ,
121
- output_path = static_saved_model_path ,
122
- batch_sizes = batch_sizes ,
123
- input_serving_key = tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY ,
124
- output_serving_key = None ,
125
- compile_prefix = "aot_model_bs_{}" ,
126
- compile_class = "bs_{}" )
127
-
128
- aot_dir = os .path .join (static_saved_model_path , "aot" )
129
- for batch_size in batch_sizes :
130
- aot_model_header = os .path .join (aot_dir , "aot_model_bs_{}.h" .format (batch_size ))
131
- aot_model_object = os .path .join (aot_dir , "aot_model_bs_{}.o" .format (batch_size ))
132
-
133
- self .assertTrue (os .path .exists (aot_model_object ))
134
- self .assertTrue (os .path .exists (aot_model_header ))
125
+ batch_sizes = [1 , 2 ]
126
+ compile_tf_graph (
127
+ model_path = model_path ,
128
+ output_path = static_saved_model_path ,
129
+ batch_sizes = batch_sizes ,
130
+ input_serving_key = tf .saved_model .DEFAULT_SERVING_SIGNATURE_DEF_KEY ,
131
+ output_serving_key = None ,
132
+ compile_prefix = "aot_model_bs{}" ,
133
+ compile_class = "bs_{}" ,
134
+ )
135
+
136
+ aot_dir = os .path .join (static_saved_model_path , "aot" )
137
+ for batch_size in batch_sizes :
138
+ aot_model_header = os .path .join (aot_dir , "aot_model_bs{}.h" .format (batch_size ))
139
+ aot_model_object = os .path .join (aot_dir , "aot_model_bs{}.o" .format (batch_size ))
140
+
141
+ self .assertTrue (os .path .exists (aot_model_object ))
142
+ self .assertTrue (os .path .exists (aot_model_header ))
0 commit comments