Skip to content

Commit bbbabd0

Browse files
author
Ryan Kim
authored
Merge pull request #441 from kroma-network/feat/impl-radix2ditparallel
feat: impl `Radix2DitParallel`
2 parents 5de7501 + fa4c33f commit bbbabd0

21 files changed

+605
-54
lines changed

tachyon/math/finite_fields/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ tachyon_cc_library(
147147
deps = [
148148
":finite_field",
149149
":legendre_symbol",
150+
":packed_prime_field_traits_forward",
150151
":prime_field_util",
151152
"//tachyon/base:bits",
152153
"//tachyon/base/json",

tachyon/math/finite_fields/baby_bear/BUILD.bazel

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1+
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
12
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
23
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
34
load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp4s")
4-
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")
5+
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "SUBGROUP_GENERATOR", "generate_fft_prime_fields")
56

67
package(default_visibility = ["//visibility:public"])
78

8-
generate_prime_fields(
9+
string_flag(
10+
name = SUBGROUP_GENERATOR,
11+
build_setting_default = "31",
12+
)
13+
14+
generate_fft_prime_fields(
915
name = "baby_bear",
1016
class_name = "BabyBear",
1117
# 2³¹ - 2²⁷ + 1
1218
# Hex: 0x78000001
1319
modulus = "2013265921",
1420
namespace = "tachyon::math",
21+
subgroup_generator = ":" + SUBGROUP_GENERATOR,
1522
use_montgomery = True,
1623
)
1724

tachyon/math/finite_fields/baby_bear/packed_baby_bear.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedBabyBear> {
3737
using Config = BabyBear::Config;
3838
};
3939

40+
template <>
41+
struct PackedPrimeFieldTraits<BabyBear> {
42+
using PackedPrimeField = PackedBabyBear;
43+
};
44+
4045
} // namespace tachyon::math
4146

4247
namespace Eigen {

tachyon/math/finite_fields/koala_bear/BUILD.bazel

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1+
load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
12
load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
23
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
34
load("//tachyon/math/finite_fields/generator/ext_prime_field_generator:build_defs.bzl", "generate_fp2s", "generate_fp4s")
4-
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")
5+
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "SUBGROUP_GENERATOR", "generate_fft_prime_fields")
56

67
package(default_visibility = ["//visibility:public"])
78

8-
generate_prime_fields(
9+
string_flag(
10+
name = SUBGROUP_GENERATOR,
11+
build_setting_default = "3",
12+
)
13+
14+
generate_fft_prime_fields(
915
name = "koala_bear",
1016
class_name = "KoalaBear",
1117
# 2³¹ - 2²⁴ + 1
1218
# Hex: 0x7f000001
1319
modulus = "2130706433",
1420
namespace = "tachyon::math",
21+
subgroup_generator = ":" + SUBGROUP_GENERATOR,
1522
use_montgomery = True,
1623
)
1724

tachyon/math/finite_fields/koala_bear/packed_koala_bear.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedKoalaBear> {
3737
using Config = KoalaBear::Config;
3838
};
3939

40+
template <>
41+
struct PackedPrimeFieldTraits<KoalaBear> {
42+
using PackedPrimeField = PackedKoalaBear;
43+
};
44+
4045
} // namespace tachyon::math
4146

4247
namespace Eigen {

tachyon/math/finite_fields/mersenne31/packed_mersenne31.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct FiniteFieldTraits<PackedMersenne31> {
3737
using Config = Mersenne31::Config;
3838
};
3939

40+
template <>
41+
struct PackedPrimeFieldTraits<Mersenne31> {
42+
using PackedPrimeField = PackedMersenne31;
43+
};
44+
4045
} // namespace tachyon::math
4146

4247
namespace Eigen {

tachyon/math/finite_fields/packed_prime_field_traits_forward.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
namespace tachyon::math {
55

6-
template <typename T>
6+
template <typename T, typename SFINAE = void>
77
struct PackedPrimeFieldTraits;
88

99
} // namespace tachyon::math

tachyon/math/finite_fields/prime_field_base.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "tachyon/math/base/gmp/gmp_util.h"
2323
#include "tachyon/math/finite_fields/finite_field.h"
2424
#include "tachyon/math/finite_fields/legendre_symbol.h"
25+
#include "tachyon/math/finite_fields/packed_prime_field_traits_forward.h"
2526
#include "tachyon/math/finite_fields/prime_field_util.h"
2627

2728
namespace tachyon {
@@ -160,6 +161,12 @@ H AbslHashValue(H h, const F& prime_field) {
160161
return h;
161162
}
162163

164+
template <typename T>
165+
struct PackedPrimeFieldTraits<
166+
T, std::enable_if_t<std::is_base_of_v<math::PrimeFieldBase<T>, T>>> {
167+
using PackedPrimeField = T;
168+
};
169+
163170
} // namespace math
164171

165172
namespace base {

tachyon/math/matrix/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ tachyon_cc_library(
2626
name = "matrix_utils",
2727
hdrs = ["matrix_utils.h"],
2828
deps = [
29+
"//tachyon/base:bits",
30+
"//tachyon/base:openmp_util",
2931
"//tachyon/base/containers:container_util",
3032
"//tachyon/math/finite_fields:packed_prime_field_traits_forward",
3133
"@eigen_archive//:eigen3",

tachyon/math/matrix/matrix_utils.h

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
#ifndef TACHYON_MATH_MATRIX_MATRIX_UTILS_H_
22
#define TACHYON_MATH_MATRIX_MATRIX_UTILS_H_
33

4+
#include <utility>
45
#include <vector>
56

67
#include "third_party/eigen3/Eigen/Core"
78

9+
#include "tachyon/base/bits.h"
810
#include "tachyon/base/containers/container_util.h"
11+
#include "tachyon/base/openmp_util.h"
912
#include "tachyon/math/finite_fields/packed_prime_field_traits_forward.h"
1013

1114
namespace tachyon::math {
@@ -45,22 +48,25 @@ MakeCirculant(const Eigen::MatrixBase<ArgType>& arg) {
4548
CirculantFunctor<ArgType>(arg.derived()));
4649
}
4750

51+
// NOTE(ashjeong): Important! |matrix| must have the same number of columns as
52+
// the parent matrix it is a block of (rows are addressed by |matrix.cols()|).
53+
// |PackRowHorizontally| currently only supports row-major matrices.
4854
template <typename PackedPrimeField, typename Derived, typename PrimeField>
49-
std::vector<PackedPrimeField> PackRowHorizontally(
50-
const Eigen::MatrixBase<Derived>& matrix, size_t row,
51-
std::vector<PrimeField>& remaining_values) {
55+
std::vector<PackedPrimeField*> PackRowHorizontally(
56+
Eigen::Block<Derived>& matrix, size_t row,
57+
std::vector<PrimeField*>& remaining_values) {
58+
static_assert(Derived::Options & Eigen::RowMajorBit);
5259
size_t num_packed = matrix.cols() / PackedPrimeField::N;
5360
size_t remaining_start_idx = num_packed * PackedPrimeField::N;
54-
remaining_values =
55-
base::CreateVector(matrix.cols() - remaining_start_idx,
56-
[row, remaining_start_idx, &matrix](size_t col) {
57-
return matrix(row, remaining_start_idx + col);
58-
});
59-
61+
remaining_values = base::CreateVector(
62+
matrix.cols() - remaining_start_idx,
63+
[row, remaining_start_idx, &matrix](size_t col) {
64+
return reinterpret_cast<PrimeField*>(
65+
matrix.data() + row * matrix.cols() + remaining_start_idx + col);
66+
});
6067
return base::CreateVector(num_packed, [row, &matrix](size_t col) {
61-
return PackedPrimeField::From([row, col, &matrix](size_t i) {
62-
return matrix(row, PackedPrimeField::N * col + i);
63-
});
68+
return reinterpret_cast<PackedPrimeField*>(
69+
matrix.data() + row * matrix.cols() + PackedPrimeField::N * col);
6470
});
6571
}
6672

@@ -74,6 +80,50 @@ std::vector<PackedPrimeField> PackRowVertically(
7480
});
7581
}
7682

83+
// Expands a |Eigen::MatrixBase|'s rows from |rows| to |rows| << |added_bits|,
84+
// moving values from row |i| to row |i| << |added_bits|. All new entries are
85+
// set to zero (via |Derived::Zero()|).
86+
template <typename Derived>
87+
void ExpandInPlaceWithZeroPad(Eigen::MatrixBase<Derived>& mat,
88+
size_t added_bits) {
89+
if (added_bits == 0) {
90+
return;
91+
}
92+
93+
Eigen::Index original_rows = mat.rows();
94+
Eigen::Index new_rows = mat.rows() << added_bits;
95+
Eigen::Index cols = mat.cols();
96+
97+
Derived padded = Derived::Zero(new_rows, cols);
98+
99+
OPENMP_PARALLEL_FOR(Eigen::Index row = 0; row < original_rows; ++row) {
100+
Eigen::Index padded_row_index = row << added_bits;
101+
// TODO(ashjeong): Check if moved properly
102+
padded.row(padded_row_index) = std::move(mat.row(row));
103+
}
104+
mat = std::move(padded);
105+
}
106+
107+
// Swaps rows of a |Eigen::MatrixBase| such that each row is changed to the row
108+
// accessed with the reversed bits of the current index. Crashes if the number
109+
// of rows is not a power of two.
110+
template <typename Derived>
111+
void ReverseMatrixIndexBits(Eigen::MatrixBase<Derived>& mat) {
112+
size_t rows = static_cast<size_t>(mat.rows());
113+
if (rows == 0) {
114+
return;
115+
}
116+
CHECK(base::bits::IsPowerOfTwo(rows));
117+
size_t log_n = base::bits::Log2Ceiling(rows);
118+
119+
OPENMP_PARALLEL_FOR(size_t row = 1; row < rows; ++row) {
120+
size_t ridx = base::bits::BitRev(row) >> (sizeof(size_t) * 8 - log_n);
121+
if (row < ridx) {
122+
mat.row(row).swap(mat.row(ridx));
123+
}
124+
}
125+
}
126+
77127
} // namespace tachyon::math
78128

79129
#endif // TACHYON_MATH_MATRIX_MATRIX_UTILS_H_

tachyon/math/matrix/matrix_utils_unittest.cc

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include "tachyon/math/matrix/matrix_utils.h"
22

33
#include "tachyon/base/strings/string_util.h"
4-
#include "tachyon/build/build_config.h"
54
#include "tachyon/math/finite_fields/baby_bear/packed_baby_bear.h"
65
#include "tachyon/math/finite_fields/test/finite_field_test.h"
76
#include "tachyon/math/finite_fields/test/gf7.h"
@@ -27,30 +26,39 @@ TEST_F(MatrixPackingTest, PackRowHorizontally) {
2726
constexpr size_t N = PackedBabyBear::N;
2827
constexpr size_t R = 3;
2928

30-
Matrix<BabyBear> matrix = Matrix<BabyBear>::Random(2 * N, 2 * N);
31-
std::vector<BabyBear> remaining_values;
32-
std::vector<PackedBabyBear> packed_values =
33-
PackRowHorizontally<PackedBabyBear>(matrix, R, remaining_values);
34-
ASSERT_TRUE(remaining_values.empty());
35-
ASSERT_EQ(packed_values.size(), 2);
36-
for (size_t i = 0; i < packed_values.size(); ++i) {
37-
for (size_t j = 0; j < N; ++j) {
38-
EXPECT_EQ(packed_values[i][j], matrix(R, i * N + j));
29+
{
30+
RowMajorMatrix<BabyBear> matrix =
31+
RowMajorMatrix<BabyBear>::Random(2 * N, 2 * N);
32+
Eigen::Block<RowMajorMatrix<BabyBear>> mat =
33+
matrix.block(0, 0, matrix.rows(), matrix.cols());
34+
std::vector<BabyBear*> remaining_values;
35+
std::vector<PackedBabyBear*> packed_values =
36+
PackRowHorizontally<PackedBabyBear>(mat, R, remaining_values);
37+
ASSERT_TRUE(remaining_values.empty());
38+
ASSERT_EQ(packed_values.size(), 2);
39+
for (size_t i = 0; i < packed_values.size(); ++i) {
40+
for (size_t j = 0; j < N; ++j) {
41+
EXPECT_EQ((*packed_values[i])[j], matrix(R, i * N + j));
42+
}
3943
}
4044
}
41-
42-
matrix = Matrix<BabyBear>::Random(2 * N - 1, 2 * N - 1);
43-
remaining_values.clear();
44-
packed_values =
45-
PackRowHorizontally<PackedBabyBear>(matrix, R, remaining_values);
46-
ASSERT_EQ(remaining_values.size(), N - 1);
47-
ASSERT_EQ(packed_values.size(), 1);
48-
for (size_t i = 0; i < remaining_values.size(); ++i) {
49-
EXPECT_EQ(remaining_values[i], matrix(R, packed_values.size() * N + i));
50-
}
51-
for (size_t i = 0; i < packed_values.size(); ++i) {
52-
for (size_t j = 0; j < N; ++j) {
53-
EXPECT_EQ(packed_values[i][j], matrix(R, i * N + j));
45+
{
46+
RowMajorMatrix<BabyBear> matrix =
47+
RowMajorMatrix<BabyBear>::Random(2 * N - 1, 2 * N - 1);
48+
Eigen::Block<RowMajorMatrix<BabyBear>> mat =
49+
matrix.block(0, 0, matrix.rows(), matrix.cols());
50+
std::vector<BabyBear*> remaining_values;
51+
std::vector<PackedBabyBear*> packed_values =
52+
PackRowHorizontally<PackedBabyBear>(mat, R, remaining_values);
53+
ASSERT_EQ(remaining_values.size(), N - 1);
54+
ASSERT_EQ(packed_values.size(), 1);
55+
for (size_t i = 0; i < remaining_values.size(); ++i) {
56+
EXPECT_EQ(*remaining_values[i], matrix(R, packed_values.size() * N + i));
57+
}
58+
for (size_t i = 0; i < packed_values.size(); ++i) {
59+
for (size_t j = 0; j < N; ++j) {
60+
EXPECT_EQ((*packed_values[i])[j], matrix(R, i * N + j));
61+
}
5462
}
5563
}
5664
}

tachyon/math/polynomials/univariate/BUILD.bazel

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,41 @@ tachyon_cc_library(
2424
],
2525
)
2626

27+
tachyon_cc_library(
28+
name = "naive_batch_fft",
29+
hdrs = ["naive_batch_fft.h"],
30+
deps = [
31+
":two_adic_subgroup",
32+
"//tachyon/base:bits",
33+
],
34+
)
35+
2736
tachyon_cc_library(
2837
name = "radix2_evaluation_domain",
2938
hdrs = ["radix2_evaluation_domain.h"],
3039
deps = [
40+
":two_adic_subgroup",
3141
":univariate_evaluation_domain",
42+
"//tachyon/base:bits",
43+
"//tachyon/base:openmp_util",
3244
"//tachyon/base:parallelize",
3345
"//tachyon/base/containers:container_util",
46+
"//tachyon/math/finite_fields:packed_prime_field_base",
47+
"//tachyon/math/matrix:matrix_types",
48+
"//tachyon/math/matrix:matrix_utils",
3449
"@com_google_absl//absl/memory",
3550
"@com_google_absl//absl/types:span",
3651
"@com_google_googletest//:gtest_prod",
52+
"@eigen_archive//:eigen3",
53+
],
54+
)
55+
56+
tachyon_cc_library(
57+
name = "two_adic_subgroup",
58+
hdrs = ["two_adic_subgroup.h"],
59+
deps = [
60+
"//tachyon/base:optional",
61+
"//tachyon/math/matrix:matrix_types",
3762
],
3863
)
3964

@@ -118,6 +143,7 @@ tachyon_cc_unittest(
118143
name = "univariate_unittests",
119144
srcs = [
120145
"lagrange_interpolation_unittest.cc",
146+
"radix2_evaluation_domain_unittest.cc",
121147
"univariate_dense_polynomial_unittest.cc",
122148
"univariate_evaluation_domain_unittest.cc",
123149
"univariate_evaluations_unittest.cc",
@@ -126,6 +152,7 @@ tachyon_cc_unittest(
126152
deps = [
127153
":lagrange_interpolation",
128154
":mixed_radix_evaluation_domain",
155+
":naive_batch_fft",
129156
":radix2_evaluation_domain",
130157
":univariate_polynomial",
131158
"//tachyon/base:optional",
@@ -136,6 +163,8 @@ tachyon_cc_unittest(
136163
"//tachyon/math/elliptic_curves/bls12/bls12_381:fr",
137164
"//tachyon/math/elliptic_curves/bn/bn254:fr",
138165
"//tachyon/math/elliptic_curves/bn/bn384_small_two_adicity:fq",
166+
"//tachyon/math/finite_fields/baby_bear:packed_baby_bear",
167+
"//tachyon/math/finite_fields/koala_bear:packed_koala_bear",
139168
"//tachyon/math/finite_fields/test:finite_field_test",
140169
"//tachyon/math/finite_fields/test:gf7",
141170
"@com_google_absl//absl/hash:hash_testing",

tachyon/math/polynomials/univariate/mixed_radix_evaluation_domain.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ class MixedRadixEvaluationDomain
270270
for (size_t k = 0; k < n; k += 2 * m) {
271271
F w = F::One();
272272
for (size_t j = 0; j < m; ++j) {
273-
UnivariateEvaluationDomain<F, MaxDegree>::ButterflyFnOutIn(
273+
UnivariateEvaluationDomain<F, MaxDegree>::template ButterflyFnOutIn(
274274
a.at(k + j), a.at((k + m) + j), w);
275275
w *= w_m;
276276
}

0 commit comments

Comments
 (0)