@@ -49,6 +49,38 @@ class MklTestingUtil {
     *tensor_min = min();
     *tensor_max = max();
   }
+
+  // This utility function mimics quantization of a float/bfloat16 tensor with
+  // the oneDNN backend QuantizeV2 operation. Since the op signature requires
+  // min and max values to be of float type, min_tensor and max_tensor should
+  // have their dtype set to DT_FLOAT.
+  template <typename T>
+  static Status GetQuantizationTensors(const Tensor& input, Tensor* output,
+                                       DataType out_type, const string mode,
+                                       Tensor* min_tensor, Tensor* max_tensor) {
+    if (min_tensor->dtype() != DT_FLOAT || max_tensor->dtype() != DT_FLOAT) {
+      return absl::UnimplementedError("Tensor must be float32.");
+    }
+    T min;
+    T max;
+    ComputeMinMax<T>(input, &min, &max);
+
+    float adjusted_min = static_cast<float>(min);
+    float adjusted_max = static_cast<float>(max);
+    if (mode == "SCALED") {
+      if (output->dtype() != DT_QINT8) {
+        return absl::UnimplementedError("Tensor must be QInt8 in SCALED mode.");
+      }
+      float range = std::max(std::abs(adjusted_min), std::abs(adjusted_max));
+      adjusted_min = -range;
+      adjusted_max = range;
+    }
+    RunMklQuantizeOp(input, adjusted_min, adjusted_max, out_type, mode, output);
+    min_tensor->flat<float>()(0) = adjusted_min;
+    max_tensor->flat<float>()(0) = adjusted_max;
+
+    return OkStatus();
+  }
 };
 
 #ifdef ENABLE_ONEDNN_V3
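For context, here is a hedged usage sketch of the new helper (hypothetical test code, not part of this change; test::FillValues and TF_ASSERT_OK are TensorFlow test utilities). In SCALED mode the helper symmetrizes the range around zero, so an input spanning [-1.0, 3.0] is quantized against [-3.0, 3.0], and both min_tensor and max_tensor must be scalar DT_FLOAT tensors that receive the adjusted range:

// Hypothetical usage sketch; shapes and values are illustrative only.
Tensor input(DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&input, {-1.0f, 0.5f, 1.5f, 3.0f});

Tensor output(DT_QINT8, TensorShape({2, 2}));
Tensor min_tensor(DT_FLOAT, TensorShape({}));  // Receives adjusted min, -3.0f.
Tensor max_tensor(DT_FLOAT, TensorShape({}));  // Receives adjusted max, 3.0f.

// SCALED mode requires a QInt8 output tensor; other modes such as MIN_FIRST
// would pass the raw min/max through unchanged.
TF_ASSERT_OK(MklTestingUtil::GetQuantizationTensors<float>(
    input, &output, DT_QINT8, "SCALED", &min_tensor, &max_tensor));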