Skip to content

Commit 4a135e3

Browse files
ybaturinatf-text-github-robot
authored andcommitted
Create tf_text wheel build rule.
PiperOrigin-RevId: 816530468
1 parent aa839b1 commit 4a135e3

File tree

8 files changed

+363
-22
lines changed

8 files changed

+363
-22
lines changed

WORKSPACE

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,28 @@ workspace(name = "org_tensorflow_text")
22

33
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
44

5+
# Toolchains for ML projects hermetic builds.
6+
# Details: https://github.com/google-ml-infra/rules_ml_toolchain
7+
http_archive(
8+
name = "rules_ml_toolchain",
9+
sha256 = "de3b14418657eeacd8afc2aa89608be6ec8d66cd6a5de81c4f693e77bc41bee1",
10+
strip_prefix = "rules_ml_toolchain-5653e5a0ca87c1272069b4b24864e55ce7f129a1",
11+
urls = [
12+
"https://github.com/google-ml-infra/rules_ml_toolchain/archive/5653e5a0ca87c1272069b4b24864e55ce7f129a1.tar.gz",
13+
],
14+
)
15+
16+
load(
17+
"@rules_ml_toolchain//cc_toolchain/deps:cc_toolchain_deps.bzl",
18+
"cc_toolchain_deps",
19+
)
20+
21+
cc_toolchain_deps()
22+
23+
register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64")
24+
25+
register_toolchains("@rules_ml_toolchain//cc_toolchain:lx64_lx64_cuda")
26+
527
http_archive(
628
name = "icu",
729
strip_prefix = "icu-release-64-2",
@@ -56,10 +78,10 @@ http_archive(
5678

5779
http_archive(
5880
name = "org_tensorflow",
59-
strip_prefix = "tensorflow-40998f44c0c500ce0f6e3b1658dfbc54f838a82a",
60-
sha256 = "5a5bc4599964c71277dcac0d687435291e5810d2ac2f6283cc96736febf73aaf",
81+
sha256 = "1a25308b15036bf8006ada5c9955ddc9a217792e6fc24deee04626ec07013f2c",
82+
strip_prefix = "tensorflow-72fbba3d20f4616d7312b5e2b7f79daf6e82f2fa",
6183
urls = [
62-
"https://github.com/tensorflow/tensorflow/archive/40998f44c0c500ce0f6e3b1658dfbc54f838a82a.zip"
84+
"https://github.com/tensorflow/tensorflow/archive/72fbba3d20f4616d7312b5e2b7f79daf6e82f2fa.zip",
6385
],
6486
)
6587

@@ -134,6 +156,14 @@ load("@pypi//:requirements.bzl", "install_deps")
134156

135157
install_deps()
136158

159+
load("//oss_scripts/pip_package:tensorflow_text_python_wheel.bzl", "tensorflow_text_python_wheel_repository")
160+
161+
tensorflow_text_python_wheel_repository(
162+
name = "tensorflow_text_wheel",
163+
version_key = "__version__",
164+
version_source = "//tensorflow_text:__init__.py",
165+
)
166+
137167
# Initialize TensorFlow dependencies.
138168
load("@org_tensorflow//tensorflow:workspace3.bzl", "tf_workspace3")
139169
tf_workspace3()
@@ -151,14 +181,16 @@ load("@local_config_android//:android.bzl", "android_workspace")
151181
android_workspace()
152182

153183
load(
154-
"@local_xla//third_party/py:python_wheel.bzl",
184+
"@org_tensorflow//third_party/xla/third_party/py:python_wheel.bzl",
155185
"python_wheel_version_suffix_repository",
156186
)
157187

158-
python_wheel_version_suffix_repository(name = "tf_wheel_version_suffix")
188+
python_wheel_version_suffix_repository(
189+
name = "tf_wheel_version_suffix",
190+
)
159191

160192
load(
161-
"@local_xla//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
193+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl",
162194
"cuda_json_init_repository",
163195
)
164196

@@ -170,7 +202,7 @@ load(
170202
"CUDNN_REDISTRIBUTIONS",
171203
)
172204
load(
173-
"@local_xla//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
205+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl",
174206
"cuda_redist_init_repositories",
175207
"cudnn_redist_init_repository",
176208
)
@@ -184,21 +216,21 @@ cudnn_redist_init_repository(
184216
)
185217

186218
load(
187-
"@local_xla//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
219+
"@rules_ml_toolchain//third_party/gpus/cuda/hermetic:cuda_configure.bzl",
188220
"cuda_configure",
189221
)
190222

191223
cuda_configure(name = "local_config_cuda")
192224

193225
load(
194-
"@local_xla//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
226+
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_redist_init_repository.bzl",
195227
"nccl_redist_init_repository",
196228
)
197229

198230
nccl_redist_init_repository()
199231

200232
load(
201-
"@local_xla//third_party/nccl/hermetic:nccl_configure.bzl",
233+
"@rules_ml_toolchain//third_party/nccl/hermetic:nccl_configure.bzl",
202234
"nccl_configure",
203235
)
204236

oss_scripts/configure.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ else
4141
if [[ "$IS_NIGHTLY" == "nightly" ]]; then
4242
pip install tf-nightly
4343
else
44-
pip install tensorflow==2.18.0
44+
pip install tensorflow==2.20.0
4545
fi
4646
fi
4747

oss_scripts/pip_package/BUILD

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
2+
load("@org_tensorflow//third_party/xla/third_party/py:python_wheel.bzl", "collect_data_files", "transitive_py_deps")
3+
14
# Tools for building the TF.Text pip package.
25
load("@python//:defs.bzl", "compile_pip_requirements")
36
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
7+
load("//oss_scripts/pip_package:wheel.bzl", "tensorflow_text_wheel")
48

59
package(default_visibility = ["//visibility:private"])
610

@@ -27,14 +31,58 @@ py_binary(
2731
],
2832
)
2933

30-
sh_binary(
31-
name = "build_pip_package",
32-
srcs = ["build_pip_package.sh"],
33-
data = [
34+
string_flag(
35+
name = "output_path",
36+
build_setting_default = "dist",
37+
)
38+
39+
py_binary(
40+
name = "build_wheel_py",
41+
srcs = ["build_wheel.py"],
42+
main = "build_wheel.py",
43+
deps = [
44+
#":build_utils",
45+
#"@bazel_tools//tools/python/runfiles",
46+
#"@pypi//build",
47+
#"@pypi//setuptools",
48+
#"@pypi//wheel",
49+
],
50+
)
51+
52+
filegroup(
53+
name = "wheel_sources",
54+
srcs = [
3455
"LICENSE",
3556
"MANIFEST.in",
3657
"setup.nightly.py",
37-
"setup.py",
38-
"//tensorflow_text",
58+
":transitive_data_deps",
59+
":transitive_py_deps",
3960
],
4061
)
62+
63+
transitive_py_deps(
64+
name = "transitive_py_deps",
65+
deps = ["//tensorflow_text"],
66+
)
67+
68+
collect_data_files(
69+
name = "transitive_data_deps",
70+
deps = ["//tensorflow_text"],
71+
)
72+
73+
tensorflow_text_wheel(
74+
name = "tensorflow_text_wheel",
75+
srcs = [":wheel_sources"],
76+
)
77+
78+
#sh_binary(
79+
# name = "build_pip_package",
80+
# srcs = ["build_pip_package.sh"],
81+
# data = [
82+
# "LICENSE",
83+
# "MANIFEST.in",
84+
# "setup.nightly.py",
85+
# "setup.py",
86+
# "//tensorflow_text",
87+
# ],
88+
#)
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# coding=utf-8
2+
# Copyright 2025 TF.Text Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
17+
#
18+
# Licensed under the Apache License, Version 2.0 (the "License");
19+
# you may not use this file except in compliance with the License.
20+
# You may obtain a copy of the License at
21+
#
22+
# http://www.apache.org/licenses/LICENSE-2.0
23+
#
24+
# Unless required by applicable law or agreed to in writing, software
25+
# distributed under the License is distributed on an "AS IS" BASIS,
26+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27+
# See the License for the specific language governing permissions and
28+
# limitations under the License.
29+
# ==============================================================================
30+
"""Script that builds a tf text wheel, intended to be run via bazel."""
31+
32+
import argparse
33+
import os
34+
import pathlib
35+
import shutil
36+
import subprocess
37+
import sys
38+
import tempfile
39+
40+
parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
41+
parser.add_argument(
42+
"--output_path",
43+
default=None,
44+
required=True,
45+
help="Path to which the output wheel should be written. Required.",
46+
)
47+
parser.add_argument(
48+
"--srcs", help="source files for the wheel", action="append"
49+
)
50+
parser.add_argument(
51+
"--platform",
52+
default="",
53+
required=False,
54+
help="Platform name to be passed to setup.py",
55+
)
56+
args = parser.parse_args()
57+
58+
59+
def copy_file(
60+
src_file: str,
61+
dst_dir: str,
62+
) -> None:
63+
"""Copy a file to the destination directory.
64+
65+
Args:
66+
src_file: file to be copied
67+
dst_dir: destination directory
68+
"""
69+
70+
dest_dir_path = os.path.join(dst_dir, os.path.dirname(src_file))
71+
os.makedirs(dest_dir_path, exist_ok=True)
72+
shutil.copy(src_file, dest_dir_path)
73+
os.chmod(os.path.join(dst_dir, src_file), 0o644)
74+
75+
76+
def prepare_srcs(deps: list[str], srcs_dir: str) -> None:
77+
"""Filter the sources and copy them to the destination directory.
78+
79+
Args:
80+
deps: a list of paths to files.
81+
srcs_dir: target directory where files are copied to.
82+
"""
83+
84+
for file in deps:
85+
print(file)
86+
if not (file.startswith("bazel-out") or file.startswith("external")):
87+
copy_file(file, srcs_dir)
88+
89+
90+
def build_wheel(
91+
dir_path: str,
92+
cwd: str,
93+
platform: str,
94+
) -> None:
95+
"""Build the wheel in the target directory.
96+
97+
Args:
98+
dir_path: directory where the wheel will be stored
99+
cwd: path to directory with wheel source files
100+
platform: platform name to pass to setup.py.
101+
"""
102+
103+
subprocess.run(
104+
[
105+
sys.executable,
106+
"setup.nightly.py",
107+
"bdist_wheel",
108+
f"--dist-dir={dir_path}",
109+
f"--plat-name={platform}",
110+
],
111+
check=True,
112+
cwd=cwd,
113+
)
114+
115+
116+
tmpdir = tempfile.TemporaryDirectory(prefix="tensorflow_text")
117+
sources_path = tmpdir.name
118+
119+
try:
120+
os.makedirs(args.output_path, exist_ok=True)
121+
prepare_srcs(args.srcs, pathlib.Path(sources_path))
122+
build_wheel(
123+
os.path.join(os.getcwd(), args.output_path),
124+
tmpdir.path,
125+
args.platform,
126+
)
127+
finally:
128+
if tmpdir:
129+
tmpdir.cleanup()
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
setuptools==70.0.0
22
dm-tree==0.1.8 # Limit for macos support.
33
numpy
4-
protobuf==4.25.3 # b/397977335 - Fix crash on python 3.9, 3.10.
5-
tensorflow
4+
#protobuf==4.25.3 # b/397977335 - Fix crash on python 3.9, 3.10.
5+
tensorflow==2.20.0
66
tf-keras
7-
tensorflow-datasets
8-
tensorflow-metadata
7+
#tensorflow-datasets
8+
#tensorflow-metadata
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
# Repository rule to generate a file with TF text wheel version.
16+
def _tensorflow_text_python_wheel_repository_impl(repository_ctx):
17+
version_source = repository_ctx.attr.version_source
18+
version_key = repository_ctx.attr.version_key
19+
version_file_content = repository_ctx.read(
20+
repository_ctx.path(version_source),
21+
)
22+
version_start_index = version_file_content.find(version_key)
23+
version_end_index = version_start_index + version_file_content[version_start_index:].find("\n")
24+
wheel_version = version_file_content[version_start_index:version_end_index].replace(
25+
version_key,
26+
"WHEEL_VERSION",
27+
)
28+
repository_ctx.file(
29+
"wheel.bzl",
30+
wheel_version,
31+
)
32+
repository_ctx.file("BUILD", "")
33+
34+
tensorflow_text_python_wheel_repository = repository_rule(
35+
implementation = _tensorflow_text_python_wheel_repository_impl,
36+
attrs = {
37+
"version_source": attr.label(mandatory = True, allow_single_file = True),
38+
"version_key": attr.string(mandatory = True),
39+
},
40+
)

0 commit comments

Comments
 (0)