Skip to content

Commit a08671b

Browse files
dansuh17tensorflower-gardener
authored andcommitted
Add pywrap_quantization_lib.h/cc to provide a middle layer for exposing symbols to shared libraries.
PiperOrigin-RevId: 610248507
1 parent 859b7cb commit a08671b

File tree

12 files changed

+161
-59
lines changed

12 files changed

+161
-59
lines changed

tensorflow/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1377,7 +1377,7 @@ tf_cc_shared_library(
13771377
"//tensorflow/compiler/mlir/lite/quantization/lite:quantize_model",
13781378
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
13791379
"//tensorflow/compiler/mlir/lite/sparsity:sparsify_model",
1380-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl",
1380+
"//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization_lib_impl",
13811381
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:custom_aggregator_op",
13821382
"//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",
13831383
"//tensorflow/compiler/mlir/quantization/tensorflow:passes",

tensorflow/compiler/mlir/lite/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1383,7 +1383,6 @@ cc_library(
13831383
"//tensorflow/compiler/mlir/lite/stablehlo:transforms",
13841384
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
13851385
"//tensorflow/compiler/mlir/quantization/stablehlo:quantize_passes",
1386-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl", # buildcleaner: keep; prevents undefined reference
13871386
"//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc",
13881387
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_passes",
13891388
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess",

tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ tf_cc_test(
3737
"//tensorflow/cc/saved_model:loader",
3838
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
3939
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:io",
40-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl", # buildcleaner: keep; prevents undefined reference
4140
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl", # buildcleaner: keep; prevents undefined reference
4241
"@com_google_absl//absl/status",
4342
"@com_google_absl//absl/status:statusor",

tensorflow/compiler/mlir/lite/stablehlo/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,6 @@ tf_cc_binary(
647647
"//tensorflow/compiler/mlir/lite:flatbuffer_export",
648648
"//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer",
649649
"//tensorflow/compiler/mlir/lite/stablehlo/serializer:flatbuffer_export",
650-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl",
651650
"//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess",
652651
"//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops",
653652
"//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibrator_singleton_impl",

tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,8 @@ cc_library(
3333
],
3434
)
3535

36-
# OSS: This is a header-only target. Do NOT directly depend on `config_impl` unless it is necessary
37-
# (e.g. undefined symbol error), to avoid ODR violation.
3836
cc_library(
3937
name = "config",
40-
hdrs = ["config.h"],
41-
compatible_with = get_compatible_with_portable(),
42-
deps = [
43-
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
44-
],
45-
)
46-
47-
# OSS: This is a impl target corresponding to `config`. Do NOT directly depend on `config_impl`
48-
# unless it is necessary (e.g. undefined symbol error), to avoid ODR violation.
49-
cc_library(
50-
name = "config_impl",
5138
srcs = ["config.cc"],
5239
hdrs = ["config.h"],
5340
compatible_with = get_compatible_with_portable(),
@@ -303,44 +290,11 @@ cc_library(
303290
],
304291
)
305292

306-
# OSS: This is a header-only target. Do NOT directly depend on `static_range_ptq_impl` unless it is
307-
# necessary (e.g. undefined symbol error), to avoid ODR violation.
308293
cc_library(
309294
name = "static_range_ptq",
310-
hdrs = ["static_range_ptq.h"],
311-
compatible_with = get_compatible_with_portable(),
312-
# Must be header-only or unique (i.e. the same target is not a dependency of
313-
# tensorflow_framework, tensorflow_cc, or pywrap_tensorflow_internal) dependencies.
314-
deps = [
315-
":component",
316-
":types",
317-
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
318-
"//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc",
319-
"//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib",
320-
"//tensorflow/core:protos_all_cc",
321-
"@com_google_absl//absl/base:nullability",
322-
"@com_google_absl//absl/container:flat_hash_map",
323-
"@com_google_absl//absl/status",
324-
"@com_google_absl//absl/status:statusor",
325-
"@com_google_absl//absl/strings:string_view",
326-
"@llvm-project//mlir:IR",
327-
],
328-
)
329-
330-
cc_library(
331-
name = "static_range_ptq_impl",
332295
srcs = ["static_range_ptq.cc"],
333296
hdrs = ["static_range_ptq.h"],
334297
compatible_with = get_compatible_with_portable(),
335-
visibility = [
336-
"//tensorflow:__pkg__",
337-
"//tensorflow/compiler/mlir/lite:__pkg__", # For tf_tfl_translate binary.
338-
# For odml_to_stablehlo binary.
339-
"//tensorflow/compiler/mlir/lite/stablehlo:__pkg__",
340-
# For StableHLO Quantizer adapter functionalities within TFLite. Testonly.
341-
"//tensorflow/compiler/mlir/lite/quantization/stablehlo:__pkg__",
342-
"//tensorflow/python:__pkg__",
343-
],
344298
deps = [
345299
":component",
346300
":context",

tensorflow/compiler/mlir/quantization/stablehlo/python/BUILD

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
load("//tensorflow:pytype.default.bzl", "pytype_strict_library")
22
load(
33
"//tensorflow:tensorflow.default.bzl",
4+
"get_compatible_with_portable",
45
"tf_py_strict_test",
56
"tf_python_pybind_extension",
67
)
@@ -91,13 +92,56 @@ tf_py_strict_test(
9192
],
9293
)
9394

95+
# This is a header-only target. The purpose of `pywrap_quantization_lib_*` targets is to expose only
96+
# the symbols that are required by `pywrap_quantization` that translates them to python functions.
97+
# The only intended use case of this library is by `pywrap_quantization`. Not letting
98+
# `pywrap_quantization` directly depend on sub-libraries like `static_range_srq` and instead haiving
99+
# a consolidated impl library `pywrap_quantization_lib_impl` allows the maintainers to avoid
100+
# declaring multiple impl libraries to `libtensorflow_cc` and `lib_pywrap_tensorflow_internal`,
101+
# which is required to avoid ODR violations.
102+
cc_library(
103+
name = "pywrap_quantization_lib_header_only",
104+
srcs = [],
105+
hdrs = ["pywrap_quantization_lib.h"],
106+
compatible_with = get_compatible_with_portable(),
107+
visibility = ["//visibility:private"], # ONLY for `pywrap_quantization`.
108+
deps = [
109+
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
110+
"//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib",
111+
"@com_google_absl//absl/status",
112+
"@com_google_absl//absl/strings:string_view",
113+
],
114+
)
115+
116+
# See the comments for `pywrap_quantization_lib_header_only`.
117+
cc_library(
118+
name = "pywrap_quantization_lib_impl",
119+
srcs = ["pywrap_quantization_lib.cc"],
120+
hdrs = ["pywrap_quantization_lib.h"],
121+
compatible_with = get_compatible_with_portable(),
122+
visibility = [
123+
"//tensorflow:__pkg__", # For libtensorflow_cc.so.
124+
"//tensorflow/python:__pkg__", # For lib_pywrap_tensorflow_internal.so.
125+
],
126+
deps = [
127+
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
128+
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:config",
129+
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq",
130+
"//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib",
131+
"//tensorflow/core/protobuf:for_core_protos_cc",
132+
"@com_google_absl//absl/container:flat_hash_map",
133+
"@com_google_absl//absl/status",
134+
"@com_google_absl//absl/strings:string_view",
135+
],
136+
)
137+
94138
tf_python_pybind_extension(
95139
name = "pywrap_quantization",
96140
srcs = ["pywrap_quantization.cc"],
97141
pytype_srcs = ["pywrap_quantization.pyi"],
142+
# Each dependency MUST be either header-only or exclusive.
98143
deps = [
99-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:config_impl",
100-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq",
144+
":pywrap_quantization_lib_header_only",
101145
"//tensorflow/compiler/mlir/quantization/tensorflow/python:type_casters",
102146
"@pybind11",
103147
"@pybind11_abseil//pybind11_abseil:absl_casters",

tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@ limitations under the License.
2020
#include "pybind11_abseil/absl_casters.h" // from @pybind11_abseil // IWYU pragma: keep
2121
#include "pybind11_abseil/import_status_module.h" // from @pybind11_abseil
2222
#include "pybind11_abseil/status_casters.h" // from @pybind11_abseil // IWYU pragma: keep
23-
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h"
24-
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h"
23+
#include "tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h"
2524
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/type_casters.h" // IWYU pragma: keep
2625

2726
namespace py = pybind11;
2827

2928
namespace {
3029

31-
using ::mlir::quant::stablehlo::QuantizeStaticRangePtq;
32-
using ::stablehlo::quantization::PopulateDefaults;
30+
using ::stablehlo::quantization::pywrap::PywrapPopulateDefaults;
31+
using ::stablehlo::quantization::pywrap::PywrapQuantizeStaticRangePtq;
3332

3433
} // namespace
3534

@@ -42,7 +41,7 @@ PYBIND11_MODULE(pywrap_quantization, m) {
4241
// If the function signature changes, likely its corresponding .pyi type
4342
// hinting should also change.
4443
// LINT.IfChange(static_range_ptq)
45-
m.def("static_range_ptq", &QuantizeStaticRangePtq,
44+
m.def("static_range_ptq", &PywrapQuantizeStaticRangePtq,
4645
R"pbdoc(
4746
Runs static-range post-training quantization (PTQ) on a SavedModel at
4847
`src_saved_model_path` and saves the resulting model to
@@ -68,7 +67,7 @@ PYBIND11_MODULE(pywrap_quantization, m) {
6867
// If the function signature changes, likely its corresponding .pyi type
6968
// hinting should also change.
7069
// LINT.IfChange(populate_default_configs)
71-
m.def("populate_default_configs", &PopulateDefaults,
70+
m.def("populate_default_configs", &PywrapPopulateDefaults,
7271
R"pbdoc(
7372
Populates `QuantizationConfig` with default values.
7473
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#include "tensorflow/compiler/mlir/quantization/stablehlo/python/pywrap_quantization_lib.h"
16+
17+
#include <string>
18+
#include <vector>
19+
20+
#include "absl/container/flat_hash_map.h"
21+
#include "absl/status/status.h"
22+
#include "absl/strings/string_view.h"
23+
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h"
24+
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h"
25+
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
26+
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h"
27+
28+
namespace stablehlo::quantization::pywrap {
29+
30+
using ::mlir::quant::stablehlo::QuantizeStaticRangePtq;
31+
using ::tensorflow::SignatureDef;
32+
using ::tensorflow::quantization::PyFunctionLibrary;
33+
34+
// Note for maintainers: the definitions should ONLY mirror existing functions
35+
// defined in different targets. Do not include any extra business logic that
36+
// causes divergence from the semantics of mirrored functions.
37+
38+
absl::Status PywrapQuantizeStaticRangePtq(
39+
absl::string_view src_saved_model_path,
40+
absl::string_view dst_saved_model_path, const QuantizationConfig& config,
41+
const std::vector<std::string>& signature_keys,
42+
const absl::flat_hash_map<std::string, SignatureDef>& signature_def_map,
43+
const absl::flat_hash_map<std::string, std::string>& function_aliases,
44+
const PyFunctionLibrary& py_function_library) {
45+
return QuantizeStaticRangePtq(src_saved_model_path, dst_saved_model_path,
46+
config, signature_keys, signature_def_map,
47+
function_aliases, py_function_library);
48+
}
49+
50+
QuantizationConfig PywrapPopulateDefaults(
51+
const QuantizationConfig& user_provided_config) {
52+
return PopulateDefaults(user_provided_config);
53+
}
54+
55+
} // namespace stablehlo::quantization::pywrap
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PYTHON_PYWRAP_QUANTIZATION_LIB_H_
16+
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PYTHON_PYWRAP_QUANTIZATION_LIB_H_
17+
18+
// Contains mirror functions from StableHLO Quantizer to be exposed to python
19+
// via `pywrap_quantization`.
20+
21+
#include <string>
22+
#include <vector>
23+
24+
#include "absl/container/flat_hash_map.h"
25+
#include "absl/status/status.h"
26+
#include "absl/strings/string_view.h"
27+
#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
28+
#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h"
29+
#include "tensorflow/core/protobuf/meta_graph.pb.h"
30+
31+
namespace stablehlo::quantization::pywrap {
32+
33+
// Function used by the pywrap_quantization module to mirror
34+
// `::mlir::quant::stablehlo::QuantizeStaticRangePtq`.
35+
absl::Status PywrapQuantizeStaticRangePtq(
36+
absl::string_view src_saved_model_path,
37+
absl::string_view dst_saved_model_path, const QuantizationConfig& config,
38+
const std::vector<std::string>& signature_keys,
39+
const absl::flat_hash_map<std::string, tensorflow::SignatureDef>&
40+
signature_def_map,
41+
const absl::flat_hash_map<std::string, std::string>& function_aliases,
42+
const tensorflow::quantization::PyFunctionLibrary& py_function_library);
43+
44+
// Function used by the pywrap_quantization module to mirror
45+
// `::stablehlo::quantization::PopulateDefaults`.
46+
QuantizationConfig PywrapPopulateDefaults(
47+
const QuantizationConfig& user_provided_config);
48+
49+
} // namespace stablehlo::quantization::pywrap
50+
51+
#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PYTHON_PYWRAP_QUANTIZATION_LIB_H_

tensorflow/python/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ pywrap_tensorflow_macro(
765765
"//tensorflow/cc/saved_model:fingerprinting_impl",
766766
"//tensorflow/cc/saved_model:loader_lite_impl",
767767
"//tensorflow/cc/saved_model:metrics_impl",
768-
"//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq_impl",
768+
"//tensorflow/compiler/mlir/quantization/stablehlo/python:pywrap_quantization_lib_impl",
769769
"//tensorflow/compiler/mlir/quantization/tensorflow/python:quantize_model_cc_impl",
770770
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
771771
"//tensorflow/compiler/tf2tensorrt:op_converter_registry_impl",

tensorflow/tf_exported_symbols.lds

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
*tsl*
1616
*lite*
1717
*TFL*
18+
*quantization*

tensorflow/tf_version_script.lds

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ tensorflow {
1717
*tsl*;
1818
*lite*;
1919
*TFL*;
20+
*quantization*;
2021
local:
2122
*;
2223
};

0 commit comments

Comments
 (0)