
Commit ff46131

Address more review comments.
1 parent 9fbe607 commit ff46131

File tree

2 files changed: +39 -1 lines changed

tensorflow/core/kernels/mkl/mkl_kernel_util.h (+32)

@@ -49,6 +49,38 @@ class MklTestingUtil {
     *tensor_min = min();
     *tensor_max = max();
   }
+
+  // This utility function mimics quantization of a float/bfloat16 tensor with
+  // the oneDNN backend QuantizeV2 operation. Since the op signature requires
+  // the min and max values to be of float type, min_tensor and max_tensor
+  // should have their dtype set to DT_FLOAT.
+  template <typename T>
+  static Status GetQuantizationTensors(const Tensor& input, Tensor* output,
+                                       DataType out_type, const string mode,
+                                       Tensor* min_tensor, Tensor* max_tensor) {
+    if (min_tensor->dtype() != DT_FLOAT || max_tensor->dtype() != DT_FLOAT) {
+      return absl::UnimplementedError("Tensor must be float32.");
+    }
+    T min;
+    T max;
+    ComputeMinMax<T>(input, &min, &max);
+
+    float adjusted_min = static_cast<float>(min);
+    float adjusted_max = static_cast<float>(max);
+    if (mode == "SCALED") {
+      if (output->dtype() != DT_QINT8) {
+        return absl::UnimplementedError("Tensor must be QInt8 in SCALED mode.");
+      }
+      float range = std::max(std::abs(adjusted_min), std::abs(adjusted_max));
+      adjusted_min = -range;
+      adjusted_max = range;
+    }
+    RunMklQuantizeOp(input, adjusted_min, adjusted_max, out_type, mode, output);
+    min_tensor->flat<float>()(0) = adjusted_min;
+    max_tensor->flat<float>()(0) = adjusted_max;
+
+    return OkStatus();
+  }
 };

 #ifdef ENABLE_ONEDNN_V3
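
For reference, a minimal usage sketch of the new helper in SCALED mode follows. It is not part of the commit: the tensor values, shapes, and test assertions are illustrative. With inputs spanning [-1.5, 3.0], SCALED mode symmetrizes the range to [-3.0, 3.0], since range = max(|-1.5|, |3.0|) = 3.0.

  Tensor input(DT_FLOAT, TensorShape({2, 2}));
  test::FillValues<float>(&input, {-1.5f, 0.0f, 2.0f, 3.0f});

  Tensor output(DT_QINT8, TensorShape({2, 2}));
  Tensor min_tensor(DT_FLOAT, TensorShape({}));
  Tensor max_tensor(DT_FLOAT, TensorShape({}));

  // SCALED mode requires a DT_QINT8 output tensor; the returned min/max
  // are symmetric around zero rather than the raw input min/max.
  TF_ASSERT_OK(MklTestingUtil::GetQuantizationTensors<float>(
      input, &output, DT_QINT8, "SCALED", &min_tensor, &max_tensor));
  EXPECT_EQ(min_tensor.flat<float>()(0), -3.0f);
  EXPECT_EQ(max_tensor.flat<float>()(0), 3.0f);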

tensorflow/core/kernels/mkl/onednn_fused_matmul_ops_test.cc (+7 -1)

@@ -103,6 +103,12 @@ class FusedMatMulOpsTest : public OpsTestBase {
       string output_quant_mode, bool is_bias_quantized, bool is_perchannel,
       bool requantize, float output_min, float output_max)>;

+  bool HasQuantizationSupport() {
+    return TestCPUFeature(tensorflow::port::CPUFeature::AVX_VNNI_INT8) ||
+           TestCPUFeature(tensorflow::port::CPUFeature::AVX512_VNNI) ||
+           TestCPUFeature(port::CPUFeature::AMX_INT8);
+  }
+
   // Runs a TensorFlow graph defined by the root scope and fetches the result
   // of the 'fetch' node into the outputs. The optional `add_nodes` parameter
   // allows nodes to be defined directly using NodeDefBuilder.
@@ -617,7 +623,7 @@ class FusedMatMulOpsTest : public OpsTestBase {
   //      true: requantized
   // (5) weight matrix is transposed : {false, true}
   void VerifyQuantizedMatMul(std::vector<string> fused_ops) {
-    if (!IsMKLEnabled()) {
+    if (!HasQuantizationSupport()) {
       GTEST_SKIP() << "oneDNN based Quantized ops are not enabled on this CPU.";
     }
     const GraphRunner run_default =
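
The switch from IsMKLEnabled() to HasQuantizationSupport() narrows the skip condition: IsMKLEnabled() only reports that oneDNN is available at all, while the new helper additionally requires one of the x86 int8 acceleration features (AVX_VNNI_INT8, AVX512_VNNI, or AMX_INT8). A sketch of the resulting pattern from a caller's perspective; the test name and fusion list are hypothetical, not from this commit:

  TEST_F(FusedMatMulOpsTest, QuantizedFusionExample) {
    // The skip happens inside VerifyQuantizedMatMul itself, so the test
    // skips (rather than fails) on CPUs without int8 VNNI/AMX support.
    VerifyQuantizedMatMul({"BiasAdd"});
  }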
