Skip to content

Commit 9a4f082

Browse files
Split several shared functions into utility modules.
PiperOrigin-RevId: 718030153
1 parent 4178fc1 commit 9a4f082

File tree

10 files changed

+326
-161
lines changed

10 files changed

+326
-161
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension")
14+
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension", "pybind_library")
1515
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
1616
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library")
1717

@@ -72,17 +72,41 @@ cc_test(
7272
],
7373
)
7474

75+
pybind_library(
76+
name = "input_preprocessing_py_util",
77+
srcs = [
78+
"input_preprocessing_py_util.cc",
79+
],
80+
hdrs = [
81+
"input_preprocessing_py_util.h",
82+
],
83+
deps = [
84+
":input_preprocessing_util",
85+
"@com_google_absl//absl/base:core_headers",
86+
"@com_google_absl//absl/container:flat_hash_map",
87+
"@com_google_absl//absl/log:check",
88+
"@tsl//tsl/profiler/lib:traceme",
89+
],
90+
)
91+
7592
pybind_extension(
7693
name = "input_preprocessing_cc",
77-
srcs = ["input_preprocessing.cc"],
94+
srcs = [
95+
"input_preprocessing.cc",
96+
],
7897
deps = [
98+
":input_preprocessing_py_util",
7999
":input_preprocessing_threads",
80100
":input_preprocessing_util",
101+
"@com_google_absl//absl/base:core_headers",
81102
"@com_google_absl//absl/container:flat_hash_map",
103+
"@com_google_absl//absl/log",
82104
"@com_google_absl//absl/log:check",
83105
"@com_google_absl//absl/strings",
84106
"@com_google_absl//absl/synchronization",
85107
"@com_google_absl//absl/types:span",
108+
"@highway//:hwy",
109+
"@highway//hwy/contrib/sort:vqsort",
86110
"@tsl//tsl/profiler/lib:connected_traceme",
87111
"@tsl//tsl/profiler/lib:traceme",
88112
],

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#include <algorithm>
15-
#include <cmath>
1615
#include <optional>
1716
#include <string>
1817
#include <utility>
@@ -24,6 +23,7 @@
2423
#include "absl/strings/string_view.h" // from @com_google_absl
2524
#include "absl/synchronization/blocking_counter.h" // from @com_google_absl
2625
#include "absl/types/span.h" // from @com_google_absl
26+
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h"
2727
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
2828
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
2929
#include "pybind11/cast.h" // from @pybind11
@@ -148,48 +148,6 @@ int ExtractCooTensors(const py::array& features,
148148
global_device_count, coo_tensors);
149149
}
150150

151-
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
152-
GetStackedTableMetadata(py::list feature_specs, py::list features) {
153-
tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
154-
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
155-
stacked_table_metadata;
156-
for (int i = 0; i < feature_specs.size(); ++i) {
157-
const py::object& feature_spec = feature_specs[i];
158-
const py::array& feature = features[i].cast<py::array>();
159-
const py::object& feature_transformation =
160-
feature_spec.attr("_id_transformation");
161-
const py::object& table_spec = feature_spec.attr("table_spec");
162-
const py::object& stacked_table_spec =
163-
table_spec.attr("stacked_table_spec");
164-
const std::string stacked_table_name = py::cast<std::string>(
165-
table_spec.attr("_setting_in_stack").attr("stack_name"));
166-
int col_shift = 0;
167-
int col_offset = 0;
168-
int row_offset = 0;
169-
const int max_ids_per_partition =
170-
py::cast<int>(stacked_table_spec.attr("max_ids_per_partition"));
171-
const int max_unique_ids_per_partition =
172-
py::cast<int>(stacked_table_spec.attr("max_unique_ids_per_partition"));
173-
if (!feature_transformation.is_none()) {
174-
row_offset = py::cast<int>(feature_transformation.attr("row_offset"));
175-
col_shift = py::cast<int>(feature_transformation.attr("col_shift"));
176-
col_offset = py::cast<int>(feature_transformation.attr("col_offset"));
177-
}
178-
stacked_table_metadata[stacked_table_name].emplace_back(
179-
i, max_ids_per_partition, max_unique_ids_per_partition, row_offset,
180-
col_offset, col_shift,
181-
/*batch_size=*/feature.shape(0));
182-
}
183-
// Sort the stacked tables by row_offset.
184-
for (auto& [_, t] : stacked_table_metadata) {
185-
std::sort(t.begin(), t.end(),
186-
[](const StackedTableMetadata& a, const StackedTableMetadata& b) {
187-
return a.row_offset < b.row_offset;
188-
});
189-
}
190-
return stacked_table_metadata;
191-
}
192-
193151
// Preprocess inputs for a single table. Stacked table here refers to a
194152
// table that has no parent in the table stacking hierarchy. So in the case
195153
// of table stacking, the stacked table is the top level table and in the case
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright 2024 The JAX SC Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_py_util.h"
15+
16+
#include <algorithm>
17+
#include <cmath>
18+
#include <string>
19+
#include <utility>
20+
#include <vector>
21+
22+
#include "absl/container/flat_hash_map.h" // from @com_google_absl
23+
#include "absl/log/check.h" // from @com_google_absl
24+
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
25+
#include "pybind11/cast.h" // from @pybind11
26+
#include "pybind11/gil.h" // from @pybind11
27+
#include "pybind11/numpy.h" // from @pybind11
28+
#include "pybind11/pybind11.h" // from @pybind11
29+
#include "pybind11/pytypes.h" // from @pybind11
30+
#include "tsl/profiler/lib/traceme.h" // from @tsl
31+
32+
namespace jax_sc_embedding {
33+
34+
namespace py = ::pybind11;
35+
36+
// Copies the per-feature settings found in `feature_specs` into
// StackedTableMetadata records, grouped by stacked-table name.
//
// For each entry of `feature_specs` (in list order) the following Python
// attributes are read:
//   * `_id_transformation` (may be None): `row_offset`, `col_shift`,
//     `col_offset`. When it is None all three default to 0.
//   * `table_spec.stacked_table_spec`: `max_ids_per_partition`,
//     `max_unique_ids_per_partition`, `stack_vocab_size`.
//   * `table_spec._setting_in_stack.stack_name`: the grouping key.
//
// `batch_size` is stored unchanged in every record. The position of the spec
// in `feature_specs` becomes the record's `feature_index`.
//
// Returns a map from stack name to its records, each vector sorted by
// ascending `row_offset`. Attribute lookups and casts go through pybind11, so
// a missing attribute or a non-integer value propagates as a pybind11
// exception.
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs, const int batch_size) {
  tsl::profiler::TraceMe t([] { return "GetStackedTableMetadata"; });
  absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
      stacked_table_metadata;
  // py::list::size() is unsigned. Convert it once so the loop index, which is
  // also stored as the int `feature_index`, is never compared across
  // signedness.
  const int num_features = static_cast<int>(feature_specs.size());
  for (int i = 0; i < num_features; ++i) {
    const py::object& feature_spec = feature_specs[i];

    const py::object& feature_transformation =
        feature_spec.attr("_id_transformation");
    const py::object& table_spec = feature_spec.attr("table_spec");
    const py::object& stacked_table_spec =
        table_spec.attr("stacked_table_spec");
    const std::string stacked_table_name = py::cast<std::string>(
        table_spec.attr("_setting_in_stack").attr("stack_name"));
    // Offsets stay at zero when the feature carries no id transformation.
    int col_shift = 0;
    int col_offset = 0;
    int row_offset = 0;
    const int max_ids_per_partition =
        py::cast<int>(stacked_table_spec.attr("max_ids_per_partition"));
    const int max_unique_ids_per_partition =
        py::cast<int>(stacked_table_spec.attr("max_unique_ids_per_partition"));
    const int vocab_size =
        py::cast<int>(stacked_table_spec.attr("stack_vocab_size"));
    if (!feature_transformation.is_none()) {
      row_offset = py::cast<int>(feature_transformation.attr("row_offset"));
      col_shift = py::cast<int>(feature_transformation.attr("col_shift"));
      col_offset = py::cast<int>(feature_transformation.attr("col_offset"));
    }
    stacked_table_metadata[stacked_table_name].emplace_back(
        i, max_ids_per_partition, max_unique_ids_per_partition, row_offset,
        col_offset, col_shift,
        /*batch_size=*/batch_size, vocab_size);
  }
  // Sort the stacked tables by row_offset.
  for (auto& [_, t] : stacked_table_metadata) {
    std::sort(t.begin(), t.end(),
              [](const StackedTableMetadata& a, const StackedTableMetadata& b) {
                return a.row_offset < b.row_offset;
              });
  }
  return stacked_table_metadata;
}
79+
80+
// Overload that takes the feature arrays instead of an explicit batch size.
// The batch size is the leading dimension of the first array in `features`;
// everything else is delegated to the (feature_specs, batch_size) overload.
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs,
                        const py::list& features) {
  tsl::profiler::TraceMe trace([] { return "GetStackedTableMetadata"; });
  const py::array first_feature = features[0].cast<py::array>();
  const int leading_dim = static_cast<int>(first_feature.shape(0));
  return GetStackedTableMetadata(feature_specs, leading_dim);
}
87+
88+
} // namespace jax_sc_embedding
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright 2024 The JAX SC Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
// Include guard named after this header (input_preprocessing_py_util.h) so it
// cannot collide with a guard belonging to another header in this directory.
#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_
#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"  // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "pybind11/numpy.h"  // from @pybind11
#include "pybind11/pytypes.h"  // from @pybind11

namespace jax_sc_embedding {

namespace py = ::pybind11;

// Copy information from feature_specs to StackedTableMetadata, keyed by
// stacked-table name.
// The features argument is only used to get the batch size.
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs,
                        const py::list& features);

// Copy information from feature_specs to StackedTableMetadata, keyed by
// stacked-table name. `batch_size` is recorded on every returned entry.
absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
GetStackedTableMetadata(const py::list& feature_specs, int batch_size);

}  // namespace jax_sc_embedding

#endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_INPUT_PREPROCESSING_PY_UTIL_H_

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct CooFormat {
3535

3636
// Get adjusted col_id based on shift and offset.
3737
int GetColId(int col_id, int col_shift, int col_offset, int num_scs_mod,
38-
int num_scs_mod_inv);
38+
int num_scs_mod_inv);
3939

4040
inline unsigned int CeilOfRatio(unsigned int numerator,
4141
unsigned int denominator) {
@@ -50,14 +50,16 @@ struct StackedTableMetadata {
5050
StackedTableMetadata() = delete;
5151
StackedTableMetadata(int feature_index, int max_ids_per_partition,
5252
int max_unique_ids_per_partition, int row_offset,
53-
int col_offset, int col_shift, int batch_size)
53+
int col_offset, int col_shift, int batch_size,
54+
int stacked_table_vocab_size = 0)
5455
: feature_index(feature_index),
5556
max_ids_per_partition(max_ids_per_partition),
5657
max_unique_ids_per_partition(max_unique_ids_per_partition),
5758
row_offset(row_offset),
5859
col_offset(col_offset),
5960
col_shift(col_shift),
60-
batch_size(batch_size) {}
61+
batch_size(batch_size),
62+
stacked_table_vocab_size(stacked_table_vocab_size) {}
6163
// The batch is given as a list of features (numpy arrays). `feature_index`
6264
// represents the index of the feature in the list.
6365
int feature_index;
@@ -70,6 +72,8 @@ struct StackedTableMetadata {
7072

7173
// Process local batch size of the feature.
7274
int batch_size;
75+
76+
int stacked_table_vocab_size;
7377
};
7478

7579
void SortAndGroupCooTensors(

jax_tpu_embedding/sparsecore/lib/core/primitives/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ pytype_strict_library(
4242
"//jax_tpu_embedding/sparsecore/lib/core:constants",
4343
pypi_requirement("jax"),
4444
pypi_requirement("jax/_src/lib"),
45+
pypi_requirement("jax/extend"),
4546
pypi_requirement("numpy"),
4647
],
4748
)

jax_tpu_embedding/sparsecore/lib/core/primitives/sparse_dense_matmul_csr_with_mini_batching.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,14 @@
2020
from jax._src import dispatch
2121
from jax._src.lib.mlir import ir
2222
from jax._src.lib.mlir.dialects import hlo
23+
import jax.extend as jex
2324
from jax.interpreters import mlir
2425
import jax.numpy as jnp
2526
from jax_tpu_embedding.sparsecore.lib.core import constants
2627
import numpy as np
2728

2829
# Define the sparse dense matmul primitive.
29-
tpu_sparse_dense_matmul_csr_with_mini_batching_primitive = core.Primitive(
30+
tpu_sparse_dense_matmul_csr_with_mini_batching_primitive = jex.core.Primitive(
3031
"sparse_dense_matmul_csr_with_mini_batching"
3132
)
3233

jax_tpu_embedding/sparsecore/lib/nn/BUILD

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,27 @@ pytype_strict_library(
3232
],
3333
)
3434

35+
pytype_strict_library(
36+
name = "embedding_utils",
37+
srcs = ["embedding_utils.py"],
38+
visibility = ["//jax_tpu_embedding/sparsecore/lib/nn:__subpackages__"],
39+
deps = [
40+
":embedding_spec",
41+
":table_stacking",
42+
pypi_requirement("absl/logging"),
43+
pypi_requirement("jax"),
44+
pypi_requirement("jax:experimental"),
45+
pypi_requirement("numpy"),
46+
pypi_requirement("tree"),
47+
],
48+
)
49+
3550
pytype_strict_library(
3651
name = "embedding",
3752
srcs = ["embedding.py"],
3853
deps = [
3954
":embedding_spec",
55+
":embedding_utils",
4056
":table_stacking",
4157
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_cc",
4258
"//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr",

0 commit comments

Comments
 (0)