@@ -25,20 +25,30 @@ limitations under the License.
25
25
26
26
namespace tensorflow {
27
27
28
- class MklQuantizeV2OpTest : public OpsTestBase {};
28
+ class MklQuantizeV2OpTest : public OpsTestBase ,
29
+ public ::testing::WithParamInterface<DataType> {};
29
30
30
- TEST_F (MklQuantizeV2OpTest, small_uint8) {
31
+ TEST_P (MklQuantizeV2OpTest, small_uint8) {
32
+ const auto dtype = GetParam ();
31
33
TF_ASSERT_OK (NodeDefBuilder (" quantize_op" , " _MklQuantizeV2" )
32
- .Input (FakeInput (DT_FLOAT ))
34
+ .Input (FakeInput (dtype ))
33
35
.Input (FakeInput (DT_FLOAT))
34
36
.Input (FakeInput (DT_FLOAT))
35
37
.Attr (" T" , DataTypeToEnum<quint8>::v ())
36
38
.Attr (" mode" , " SCALED" )
37
39
.Attr (" _kernel" , " QuantizedMklOp" )
38
40
.Finalize (node_def ()));
39
41
TF_ASSERT_OK (InitOp ());
40
- AddInputFromArray<float >(TensorShape ({8 }),
41
- {0.0 , 1.0 , 1.25 , 1.75 , 127.0 , 255.0 , 500.0 , 2.0 });
42
+ switch (dtype) {
43
+ case DT_BFLOAT16:
44
+ AddInputFromList<bfloat16>(
45
+ TensorShape ({8 }), {0.0 , 1.0 , 1.25 , 1.75 , 127.0 , 255.0 , 500.0 , 2.0 });
46
+ break ;
47
+
48
+ default :
49
+ AddInputFromArray<float >(
50
+ TensorShape ({8 }), {0.0 , 1.0 , 1.25 , 1.75 , 127.0 , 255.0 , 500.0 , 2.0 });
51
+ }
42
52
// min_range = 0
43
53
AddInputFromArray<float >(TensorShape ({}), {0 });
44
54
// max_range = 255
@@ -56,20 +66,30 @@ TEST_F(MklQuantizeV2OpTest, small_uint8) {
56
66
test::ExpectTensorEqual<float >(expected_min, *GetOutput (1 ));
57
67
test::ExpectTensorEqual<float >(expected_max, *GetOutput (2 ));
58
68
}
59
- TEST_F (MklQuantizeV2OpTest, small_int8) {
69
+
70
+ TEST_P (MklQuantizeV2OpTest, small_int8) {
71
+ const auto dtype = GetParam ();
60
72
TF_ASSERT_OK (NodeDefBuilder (" quantize_op" , " _MklQuantizeV2" )
61
- .Input (FakeInput (DT_FLOAT ))
73
+ .Input (FakeInput (dtype ))
62
74
.Input (FakeInput (DT_FLOAT))
63
75
.Input (FakeInput (DT_FLOAT))
64
76
.Attr (" T" , DataTypeToEnum<qint8>::v ())
65
77
.Attr (" mode" , " SCALED" )
66
78
.Attr (" _kernel" , " QuantizedMklOp" )
67
79
.Finalize (node_def ()));
68
80
TF_ASSERT_OK (InitOp ());
69
- AddInputFromArray<float >(TensorShape ({8 }), {0.0 , -1.0 , 1.25 , -1.75 , -24.5 ,
70
- -255.0 , -80.315 , 256.0 });
71
- AddInputFromArray<float >(TensorShape ({}), {-50.0 });
72
- AddInputFromArray<float >(TensorShape ({}), {127.0 });
81
+ switch (dtype) {
82
+ case DT_BFLOAT16:
83
+ AddInputFromList<bfloat16>(
84
+ TensorShape ({8 }),
85
+ {0.0 , -1.0 , 1.25 , -1.75 , -24.5 , -255.0 , -80.315 , 256.0 });
86
+ break ;
87
+ default :
88
+ AddInputFromArray<float >(TensorShape ({8 }), {0.0 , -1.0 , 1.25 , -1.75 , -24.5 ,
89
+ -255.0 , -80.315 , 256.0 });
90
+ }
91
+ AddInputFromArray<float >(TensorShape ({1 }), {-50.0 });
92
+ AddInputFromArray<float >(TensorShape ({1 }), {127.0 });
73
93
TF_ASSERT_OK (RunOpKernel ());
74
94
Tensor expected (allocator (), DT_QINT8, TensorShape ({8 }));
75
95
Tensor expected_min (allocator (), DT_FLOAT, TensorShape ({}));
@@ -82,20 +102,28 @@ TEST_F(MklQuantizeV2OpTest, small_int8) {
82
102
test::ExpectTensorEqual<float >(expected_max, *GetOutput (2 ));
83
103
}
84
104
85
- TEST_F (MklQuantizeV2OpTest, small_minfirst) {
105
+ TEST_P (MklQuantizeV2OpTest, small_minfirst) {
106
+ const auto dtype = GetParam ();
86
107
TF_ASSERT_OK (NodeDefBuilder (" quantize_op" , " _MklQuantizeV2" )
87
- .Input (FakeInput (DT_FLOAT ))
108
+ .Input (FakeInput (dtype ))
88
109
.Input (FakeInput (DT_FLOAT))
89
110
.Input (FakeInput (DT_FLOAT))
90
111
.Attr (" T" , DataTypeToEnum<quint8>::v ())
91
112
.Attr (" mode" , " MIN_FIRST" )
92
113
.Attr (" _kernel" , " QuantizedMklOp" )
93
114
.Finalize (node_def ()));
94
115
TF_ASSERT_OK (InitOp ());
95
- AddInputFromArray<float >(TensorShape ({8 }),
96
- {1.0 , 1.25 , 1.75 , 2 , 3.15 , 127.0 , 255.0 , 500.0 });
97
- AddInputFromArray<float >(TensorShape ({}), {0 });
98
- AddInputFromArray<float >(TensorShape ({}), {255 .0f });
116
+ switch (dtype) {
117
+ case DT_BFLOAT16:
118
+ AddInputFromList<bfloat16>(
119
+ TensorShape ({8 }), {1.0 , 1.25 , 1.75 , 2.0 , 3.15 , 127.0 , 255.0 , 500.0 });
120
+ break ;
121
+ default :
122
+ AddInputFromArray<float >(
123
+ TensorShape ({8 }), {1.0 , 1.25 , 1.75 , 2.0 , 3.15 , 127.0 , 255.0 , 500.0 });
124
+ }
125
+ AddInputFromArray<float >(TensorShape ({1 }), {0 });
126
+ AddInputFromArray<float >(TensorShape ({1 }), {255 .0f });
99
127
TF_ASSERT_OK (RunOpKernel ());
100
128
Tensor expected (allocator (), DT_QUINT8, TensorShape ({8 }));
101
129
test::FillValues<quint8>(&expected, {1 , 1 , 2 , 2 , 3 , 127 , 255 , 255 });
@@ -106,20 +134,28 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst) {
106
134
EXPECT_NEAR (255 .0f , output_max, 1e-5f );
107
135
}
108
136
109
- TEST_F (MklQuantizeV2OpTest, small_minfirst_uint) {
137
+ TEST_P (MklQuantizeV2OpTest, small_minfirst_uint) {
138
+ const auto dtype = GetParam ();
110
139
TF_ASSERT_OK (NodeDefBuilder (" quantize_op" , " _MklQuantizeV2" )
111
- .Input (FakeInput (DT_FLOAT ))
140
+ .Input (FakeInput (dtype ))
112
141
.Input (FakeInput (DT_FLOAT))
113
142
.Input (FakeInput (DT_FLOAT))
114
143
.Attr (" T" , DataTypeToEnum<quint8>::v ())
115
144
.Attr (" mode" , " MIN_FIRST" )
116
145
.Attr (" _kernel" , " QuantizedMklOp" )
117
146
.Finalize (node_def ()));
118
147
TF_ASSERT_OK (InitOp ());
119
- AddInputFromArray<float >(TensorShape ({8 }),
120
- {0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 });
121
- AddInputFromArray<float >(TensorShape ({}), {0.1 });
122
- AddInputFromArray<float >(TensorShape ({}), {0.8 });
148
+ switch (dtype) {
149
+ case DT_BFLOAT16:
150
+ AddInputFromList<bfloat16>(TensorShape ({8 }),
151
+ {0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.599 , 0.7 , 0.8 });
152
+ break ;
153
+ default :
154
+ AddInputFromArray<float >(TensorShape ({8 }),
155
+ {0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 });
156
+ }
157
+ AddInputFromArray<float >(TensorShape ({1 }), {0.1 });
158
+ AddInputFromArray<float >(TensorShape ({1 }), {0.8 });
123
159
TF_ASSERT_OK (RunOpKernel ());
124
160
Tensor expected (allocator (), DT_QUINT8, TensorShape ({8 }));
125
161
test::FillValues<quint8>(&expected, {32 , 64 , 96 , 128 , 159 , 191 , 223 , 255 });
@@ -130,20 +166,29 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) {
130
166
EXPECT_NEAR (0 .8f , output_max, 1e-5f );
131
167
}
132
168
133
- TEST_F (MklQuantizeV2OpTest, small_minfirst_int) {
169
+ TEST_P (MklQuantizeV2OpTest, small_minfirst_int) {
170
+ const auto dtype = GetParam ();
134
171
TF_ASSERT_OK (NodeDefBuilder (" quantize_op" , " _MklQuantizeV2" )
135
- .Input (FakeInput (DT_FLOAT ))
172
+ .Input (FakeInput (dtype ))
136
173
.Input (FakeInput (DT_FLOAT))
137
174
.Input (FakeInput (DT_FLOAT))
138
175
.Attr (" T" , DataTypeToEnum<quint8>::v ())
139
176
.Attr (" mode" , " MIN_FIRST" )
140
177
.Attr (" _kernel" , " QuantizedMklOp" )
141
178
.Finalize (node_def ()));
142
179
TF_ASSERT_OK (InitOp ());
143
- AddInputFromArray<float >(TensorShape ({8 }),
144
- {-0.1 , -0.2 , -0.3 , -0.4 , -0.5 , -0.6 , -0.7 , -0.8 });
145
- AddInputFromArray<float >(TensorShape ({}), {-0.8 });
146
- AddInputFromArray<float >(TensorShape ({}), {-0.1 });
180
+ switch (dtype) {
181
+ case DT_BFLOAT16:
182
+ AddInputFromList<bfloat16>(
183
+ TensorShape ({8 }), {-0.1 , -0.2 , -0.3 , -0.4 , -0.5 , -0.6 , -0.7 , -0.8 });
184
+
185
+ break ;
186
+ default :
187
+ AddInputFromArray<float >(
188
+ TensorShape ({8 }), {-0.1 , -0.2 , -0.3 , -0.4 , -0.5 , -0.6 , -0.7 , -0.8 });
189
+ }
190
+ AddInputFromArray<float >(TensorShape ({1 }), {-0.8 });
191
+ AddInputFromArray<float >(TensorShape ({1 }), {-0.1 });
147
192
TF_ASSERT_OK (RunOpKernel ());
148
193
Tensor expected (allocator (), DT_QUINT8, TensorShape ({8 }));
149
194
test::FillValues<quint8>(&expected, {223 , 191 , 159 , 128 , 96 , 64 , 32 , 0 });
@@ -154,5 +199,8 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_int) {
154
199
EXPECT_NEAR (0 .0f , output_max, 1e-5f );
155
200
}
156
201
202
+ INSTANTIATE_TEST_SUITE_P (All, MklQuantizeV2OpTest,
203
+ ::testing::Values (DT_FLOAT, DT_BFLOAT16));
204
+
157
205
} // end namespace tensorflow
158
206
#endif // INTEL_MKL
0 commit comments