
Commit c91dca1

QuantizeV2 enabled for bfloat16 input tensor.
1 parent 9d32123 commit c91dca1
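
This change adds a `dtype` attr to QuantizeV2 so the op accepts either a float or a bfloat16 input tensor (defaulting to float for backward compatibility). The generic CPU kernel widens a bfloat16 input to float before quantizing; the MKL kernel (_MklQuantizeV2) registers bfloat16 variants and restricts them to the MIN_FIRST and SCALED modes. The MKL unit tests are parameterized to run each case with both input types.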

File tree

7 files changed: +218 -60 lines changed

tensorflow/core/api_def/base_api/api_def_QuantizeV2.pbtxt (+7 -1)

@@ -42,7 +42,13 @@ If the `axis` attribute is specified, this will be a 1-D tensor whose size
 matches the `axis` dimension of the input and output tensors.
 END
 }
-summary: "Quantize the \'input\' tensor of type float to \'output\' tensor of type \'T\'."
+attr {
+  name: "dtype"
+  description: <<END
+Type of the input tensor. Currently QuantizeV2 supports float and bfloat16.
+END
+}
+summary: "Quantize the \'input\' tensor of types float and bfloat16 to \'output\' tensor of type \'T\'."
 description: <<END
 [min_range, max_range] are scalar floats that specify the range for
 the 'input' data. The 'mode' attribute controls exactly which calculations are
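
Since `dtype` is presumably declared as the type of the first input in the op def (not shown in this page), it is inferred when the node is built rather than set by hand. A minimal sketch of wiring a bfloat16 input through the public op, borrowing the NodeDefBuilder/OpsTestBase pattern used by the tests in this commit (node name and attr choices here are illustrative):

  TF_ASSERT_OK(NodeDefBuilder("quantize_op", "QuantizeV2")
                   .Input(FakeInput(DT_BFLOAT16))  // "dtype" inferred as DT_BFLOAT16
                   .Input(FakeInput(DT_FLOAT))     // min_range
                   .Input(FakeInput(DT_FLOAT))     // max_range
                   .Attr("T", DataTypeToEnum<quint8>::v())
                   .Attr("mode", "MIN_FIRST")
                   .Finalize(node_def()));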

tensorflow/core/kernels/mkl/mkl_quantize_op.cc (+18 -0)

@@ -311,6 +311,20 @@ class MklQuantizeV2Op : public OpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
     OP_REQUIRES_OK(
         ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_));
+    if (ctx->HasAttr("dtype")) {
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+      if (dtype_ == DT_BFLOAT16) {
+        OP_REQUIRES(
+            ctx,
+            ctx->input_type(0) == DT_BFLOAT16 &&
+                (mode_ == QUANTIZE_MODE_MIN_FIRST ||
+                 mode_ == QUANTIZE_MODE_SCALED),
+            errors::InvalidArgument("Input type bfloat16 is supported only "
+                                    "with MIN_FIRST and SCALED modes"));
+      }
+    } else {
+      dtype_ = DT_FLOAT;
+    }
   }

   void ComputeScalar(OpKernelContext* ctx, float min_range, float max_range) {
@@ -605,18 +619,22 @@ class MklQuantizeV2Op : public OpKernel {
   int round_mode_;
   int axis_;
   bool narrow_range_;
+  DataType dtype_;
 };

 #define REGISTER_QUANTIZE(src_type, dst_type)               \
   REGISTER_KERNEL_BUILDER(                                  \
       Name("_MklQuantizeV2")                                \
           .Device(DEVICE_CPU)                               \
+          .TypeConstraint<src_type>("dtype")                \
           .TypeConstraint<dst_type>("T")                    \
           .Label(mkl_op_registry::kMklQuantizedOpLabel),    \
       MklQuantizeV2Op<CPUDevice, dst_type, src_type, true>)

 REGISTER_QUANTIZE(float, qint8);
 REGISTER_QUANTIZE(float, quint8);
+REGISTER_QUANTIZE(bfloat16, qint8);
+REGISTER_QUANTIZE(bfloat16, quint8);

 #undef SET_MKL_LAYOUT
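
For reference, writing out REGISTER_QUANTIZE(bfloat16, qint8) by hand shows what one of the new registrations expands to (plain macro substitution; note the macro passes dst_type before src_type to the template, and the trailing bool is as in the macro above):

  REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2")
                              .Device(DEVICE_CPU)
                              .TypeConstraint<bfloat16>("dtype")
                              .TypeConstraint<qint8>("T")
                              .Label(mkl_op_registry::kMklQuantizedOpLabel),
                          MklQuantizeV2Op<CPUDevice, qint8, bfloat16, true>);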

tensorflow/core/kernels/mkl/mkl_quantize_op_test.cc (+77 -29)

@@ -25,20 +25,30 @@ limitations under the License.

 namespace tensorflow {

-class MklQuantizeV2OpTest : public OpsTestBase {};
+class MklQuantizeV2OpTest : public OpsTestBase,
+                            public ::testing::WithParamInterface<DataType> {};

-TEST_F(MklQuantizeV2OpTest, small_uint8) {
+TEST_P(MklQuantizeV2OpTest, small_uint8) {
+  const auto dtype = GetParam();
   TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
-                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(dtype))
                    .Input(FakeInput(DT_FLOAT))
                    .Input(FakeInput(DT_FLOAT))
                    .Attr("T", DataTypeToEnum<quint8>::v())
                    .Attr("mode", "SCALED")
                    .Attr("_kernel", "QuantizedMklOp")
                    .Finalize(node_def()));
   TF_ASSERT_OK(InitOp());
-  AddInputFromArray<float>(TensorShape({8}),
-                           {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0});
+  switch (dtype) {
+    case DT_BFLOAT16:
+      AddInputFromList<bfloat16>(
+          TensorShape({8}), {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0});
+      break;
+
+    default:
+      AddInputFromArray<float>(
+          TensorShape({8}), {0.0, 1.0, 1.25, 1.75, 127.0, 255.0, 500.0, 2.0});
+  }
   // min_range = 0
   AddInputFromArray<float>(TensorShape({}), {0});
   // max_range = 255
@@ -56,20 +66,30 @@ TEST_F(MklQuantizeV2OpTest, small_uint8) {
   test::ExpectTensorEqual<float>(expected_min, *GetOutput(1));
   test::ExpectTensorEqual<float>(expected_max, *GetOutput(2));
 }
-TEST_F(MklQuantizeV2OpTest, small_int8) {
+
+TEST_P(MklQuantizeV2OpTest, small_int8) {
+  const auto dtype = GetParam();
   TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
-                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(dtype))
                    .Input(FakeInput(DT_FLOAT))
                    .Input(FakeInput(DT_FLOAT))
                    .Attr("T", DataTypeToEnum<qint8>::v())
                    .Attr("mode", "SCALED")
                    .Attr("_kernel", "QuantizedMklOp")
                    .Finalize(node_def()));
   TF_ASSERT_OK(InitOp());
-  AddInputFromArray<float>(TensorShape({8}), {0.0, -1.0, 1.25, -1.75, -24.5,
-                                              -255.0, -80.315, 256.0});
-  AddInputFromArray<float>(TensorShape({}), {-50.0});
-  AddInputFromArray<float>(TensorShape({}), {127.0});
+  switch (dtype) {
+    case DT_BFLOAT16:
+      AddInputFromList<bfloat16>(
+          TensorShape({8}),
+          {0.0, -1.0, 1.25, -1.75, -24.5, -255.0, -80.315, 256.0});
+      break;
+    default:
+      AddInputFromArray<float>(TensorShape({8}), {0.0, -1.0, 1.25, -1.75, -24.5,
+                                                  -255.0, -80.315, 256.0});
+  }
+  AddInputFromArray<float>(TensorShape({1}), {-50.0});
+  AddInputFromArray<float>(TensorShape({1}), {127.0});
   TF_ASSERT_OK(RunOpKernel());
   Tensor expected(allocator(), DT_QINT8, TensorShape({8}));
   Tensor expected_min(allocator(), DT_FLOAT, TensorShape({}));
@@ -82,20 +102,28 @@ TEST_F(MklQuantizeV2OpTest, small_int8) {
   test::ExpectTensorEqual<float>(expected_max, *GetOutput(2));
 }

-TEST_F(MklQuantizeV2OpTest, small_minfirst) {
+TEST_P(MklQuantizeV2OpTest, small_minfirst) {
+  const auto dtype = GetParam();
   TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
-                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(dtype))
                    .Input(FakeInput(DT_FLOAT))
                    .Input(FakeInput(DT_FLOAT))
                    .Attr("T", DataTypeToEnum<quint8>::v())
                    .Attr("mode", "MIN_FIRST")
                    .Attr("_kernel", "QuantizedMklOp")
                    .Finalize(node_def()));
   TF_ASSERT_OK(InitOp());
-  AddInputFromArray<float>(TensorShape({8}),
-                           {1.0, 1.25, 1.75, 2, 3.15, 127.0, 255.0, 500.0});
-  AddInputFromArray<float>(TensorShape({}), {0});
-  AddInputFromArray<float>(TensorShape({}), {255.0f});
+  switch (dtype) {
+    case DT_BFLOAT16:
+      AddInputFromList<bfloat16>(
+          TensorShape({8}), {1.0, 1.25, 1.75, 2.0, 3.15, 127.0, 255.0, 500.0});
+      break;
+    default:
+      AddInputFromArray<float>(
+          TensorShape({8}), {1.0, 1.25, 1.75, 2.0, 3.15, 127.0, 255.0, 500.0});
+  }
+  AddInputFromArray<float>(TensorShape({1}), {0});
+  AddInputFromArray<float>(TensorShape({1}), {255.0f});
   TF_ASSERT_OK(RunOpKernel());
   Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
   test::FillValues<quint8>(&expected, {1, 1, 2, 2, 3, 127, 255, 255});
@@ -106,20 +134,28 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst) {
   EXPECT_NEAR(255.0f, output_max, 1e-5f);
 }

-TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) {
+TEST_P(MklQuantizeV2OpTest, small_minfirst_uint) {
+  const auto dtype = GetParam();
   TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
-                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(dtype))
                    .Input(FakeInput(DT_FLOAT))
                    .Input(FakeInput(DT_FLOAT))
                    .Attr("T", DataTypeToEnum<quint8>::v())
                    .Attr("mode", "MIN_FIRST")
                    .Attr("_kernel", "QuantizedMklOp")
                    .Finalize(node_def()));
   TF_ASSERT_OK(InitOp());
-  AddInputFromArray<float>(TensorShape({8}),
-                           {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
-  AddInputFromArray<float>(TensorShape({}), {0.1});
-  AddInputFromArray<float>(TensorShape({}), {0.8});
+  switch (dtype) {
+    case DT_BFLOAT16:
+      AddInputFromList<bfloat16>(TensorShape({8}),
+                                 {0.1, 0.2, 0.3, 0.4, 0.5, 0.599, 0.7, 0.8});
+      break;
+    default:
+      AddInputFromArray<float>(TensorShape({8}),
+                               {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
+  }
+  AddInputFromArray<float>(TensorShape({1}), {0.1});
+  AddInputFromArray<float>(TensorShape({1}), {0.8});
   TF_ASSERT_OK(RunOpKernel());
   Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
   test::FillValues<quint8>(&expected, {32, 64, 96, 128, 159, 191, 223, 255});
@@ -130,20 +166,29 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_uint) {
   EXPECT_NEAR(0.8f, output_max, 1e-5f);
 }

-TEST_F(MklQuantizeV2OpTest, small_minfirst_int) {
+TEST_P(MklQuantizeV2OpTest, small_minfirst_int) {
+  const auto dtype = GetParam();
   TF_ASSERT_OK(NodeDefBuilder("quantize_op", "_MklQuantizeV2")
-                   .Input(FakeInput(DT_FLOAT))
+                   .Input(FakeInput(dtype))
                    .Input(FakeInput(DT_FLOAT))
                    .Input(FakeInput(DT_FLOAT))
                    .Attr("T", DataTypeToEnum<quint8>::v())
                    .Attr("mode", "MIN_FIRST")
                    .Attr("_kernel", "QuantizedMklOp")
                    .Finalize(node_def()));
   TF_ASSERT_OK(InitOp());
-  AddInputFromArray<float>(TensorShape({8}),
-                           {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8});
-  AddInputFromArray<float>(TensorShape({}), {-0.8});
-  AddInputFromArray<float>(TensorShape({}), {-0.1});
+  switch (dtype) {
+    case DT_BFLOAT16:
+      AddInputFromList<bfloat16>(
+          TensorShape({8}), {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8});
+
+      break;
+    default:
+      AddInputFromArray<float>(
+          TensorShape({8}), {-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7, -0.8});
+  }
+  AddInputFromArray<float>(TensorShape({1}), {-0.8});
+  AddInputFromArray<float>(TensorShape({1}), {-0.1});
   TF_ASSERT_OK(RunOpKernel());
   Tensor expected(allocator(), DT_QUINT8, TensorShape({8}));
   test::FillValues<quint8>(&expected, {223, 191, 159, 128, 96, 64, 32, 0});
@@ -154,5 +199,8 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_int) {
   EXPECT_NEAR(0.0f, output_max, 1e-5f);
 }

+INSTANTIATE_TEST_SUITE_P(All, MklQuantizeV2OpTest,
+                         ::testing::Values(DT_FLOAT, DT_BFLOAT16));
+
 }  // end namespace tensorflow
 #endif  // INTEL_MKL
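
The TEST_F to TEST_P conversion above is googletest's standard value-parameterized pattern: each test body runs once per value passed to INSTANTIATE_TEST_SUITE_P, so every case now runs once with DT_FLOAT and once with DT_BFLOAT16. A minimal self-contained sketch of the mechanism (illustrative names; plain ints stand in for the DataType enum values, where DT_FLOAT = 1 and DT_BFLOAT16 = 14; link against gtest_main):

  #include <gtest/gtest.h>

  class DTypeParamTest : public ::testing::TestWithParam<int> {};

  TEST_P(DTypeParamTest, RunsOncePerValue) {
    const int dtype = GetParam();  // 1 on the first run, 14 on the second
    EXPECT_TRUE(dtype == 1 || dtype == 14);
  }

  INSTANTIATE_TEST_SUITE_P(All, DTypeParamTest, ::testing::Values(1, 14));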

tensorflow/core/kernels/quantize_op.cc (+52 -17)

@@ -17,8 +17,11 @@ limitations under the License.

 #define EIGEN_USE_THREADS

+#include <type_traits>
+
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/type_traits.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/cwise_ops.h"
@@ -55,7 +58,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 // max_range.
 // TODO(xbing): Add a new QuantizeOp just taking scale,
 // rather than min_range and max_range.
-template <typename Device, typename T>
+template <typename Device, typename S, typename T>
 class QuantizeV2Op : public OpKernel {
  public:
   explicit QuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -106,8 +109,26 @@ class QuantizeV2Op : public OpKernel {
         ctx, ctx->GetAttr("ensure_minimum_range", &ensure_minimum_range_));
   }

+  void MaybeConvertToFloat(OpKernelContext* ctx, const int idx,
+                           Tensor* converted_tensor) {
+    if (std::is_same<S, float>::value) return;
+    // Convert the input tensor of type S to a float tensor.
+    const Tensor& input_tensor = ctx->input(idx);
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, input_tensor.shape(),
+                                           converted_tensor));
+    auto flat_input = input_tensor.flat<S>();
+    auto d = ctx->eigen_device<Device>();
+    auto flat_output = converted_tensor->flat<float>();
+    flat_output.device(d) = flat_input.template cast<float>();
+  }
+
   void Compute(OpKernelContext* ctx) override {
-    const Tensor& input = ctx->input(0);
+    // To process a bfloat16 input tensor, it is first converted to float.
+    Tensor converted_tensor;
+    MaybeConvertToFloat(ctx, 0,
+                        &converted_tensor);  // Does nothing for float input.
+    const Tensor& input =
+        (std::is_same<S, float>::value) ? ctx->input(0) : converted_tensor;
     const Tensor& input_min_range = ctx->input(1);
     const Tensor& input_max_range = ctx->input(2);

@@ -345,19 +366,33 @@ class QuantizeV2Op : public OpKernel {
   bool narrow_range_;
 };

-REGISTER_KERNEL_BUILDER(
-    Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<quint8>("T"),
-    QuantizeV2Op<CPUDevice, quint8>);
-REGISTER_KERNEL_BUILDER(
-    Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint8>("T"),
-    QuantizeV2Op<CPUDevice, qint8>);
-REGISTER_KERNEL_BUILDER(
-    Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<quint16>("T"),
-    QuantizeV2Op<CPUDevice, quint16>);
-REGISTER_KERNEL_BUILDER(
-    Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint16>("T"),
-    QuantizeV2Op<CPUDevice, qint16>);
-REGISTER_KERNEL_BUILDER(
-    Name("QuantizeV2").Device(DEVICE_CPU).TypeConstraint<qint32>("T"),
-    QuantizeV2Op<CPUDevice, qint32>);
+#define REGISTER_CPU(S)                                         \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeV2")                    \
+                              .Device(DEVICE_CPU)               \
+                              .TypeConstraint<S>("dtype")       \
+                              .TypeConstraint<quint8>("T"),     \
+                          QuantizeV2Op<CPUDevice, S, quint8>);  \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeV2")                    \
+                              .Device(DEVICE_CPU)               \
+                              .TypeConstraint<S>("dtype")       \
+                              .TypeConstraint<qint8>("T"),      \
+                          QuantizeV2Op<CPUDevice, S, qint8>);   \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeV2")                    \
+                              .Device(DEVICE_CPU)               \
+                              .TypeConstraint<S>("dtype")       \
+                              .TypeConstraint<quint16>("T"),    \
+                          QuantizeV2Op<CPUDevice, S, quint16>); \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeV2")                    \
+                              .Device(DEVICE_CPU)               \
+                              .TypeConstraint<S>("dtype")       \
+                              .TypeConstraint<qint16>("T"),     \
+                          QuantizeV2Op<CPUDevice, S, qint16>);  \
+  REGISTER_KERNEL_BUILDER(Name("QuantizeV2")                    \
+                              .Device(DEVICE_CPU)               \
+                              .TypeConstraint<S>("dtype")       \
+                              .TypeConstraint<qint32>("T"),     \
+                          QuantizeV2Op<CPUDevice, S, qint32>);
+
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
 }  // namespace tensorflow
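
The cast in MaybeConvertToFloat is cheap because bfloat16 is simply the upper 16 bits of an IEEE-754 float32. A standalone sketch of that widening in plain C++ (independent of TensorFlow's bfloat16 type; the function name is illustrative):

  #include <cstdint>
  #include <cstdio>
  #include <cstring>

  // Widening is a 16-bit left shift into a float32 bit pattern; the low
  // 16 mantissa bits become zero, so the conversion is exact.
  float BfloatBitsToFloat(uint16_t b) {
    uint32_t bits = static_cast<uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
  }

  int main() {
    // 0x3FC0 is the bfloat16 bit pattern for 1.5.
    std::printf("%f\n", BfloatBitsToFloat(0x3FC0));  // prints 1.500000
    return 0;
  }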
