-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathntt.h
94 lines (82 loc) · 2.41 KB
/
ntt.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#pragma once
#include "primitive_root.h"
#include <array>
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>
#include <stdexcept>
/// Number-theoretic transform helper over the modular type `Mod`.
///
/// The object owns three kinds of precomputed state, all sized by `reserve()`:
///   - `twiddles[i]`          = omega^i for i in [0, max_n], where
///                              omega = G^((Mod::mod() - 1) / max_n);
///   - `power_of_two_invs[k]` = (1/2)^k for k in [0, log2(max_n)];
///   - `NUMBER_OF_BUFFER` scratch buffers of `max_n` elements each.
///
/// `dit()` and `dif()` are in-place butterfly passes. Neither performs a
/// bit-reversal permutation nor a 1/n scaling; callers combine them with
/// `power_of_two_inv()` as needed.
///
/// @tparam Mod               Modular integer type. The code requires
///                           `Mod::mod()`, `Mod{int}`, `inv()`, and the
///                           operators `*`, `+=`, `-=`.
/// @tparam NUMBER_OF_BUFFER  Number of scratch buffers exposed through
///                           `raw_buffer<i>()`.
template <typename Mod, int NUMBER_OF_BUFFER = 5> struct NttT {
  /// Validates that `n` is a power of two.
  ///
  /// Zero is accepted (it has no set bits), which lets zero-length transforms
  /// pass through as no-ops.
  ///
  /// @throws std::invalid_argument if `n` is negative or has more than one
  ///         bit set.
  static void assert_power_of_two(int n) {
    // Test the sign first so `n - 1` is never evaluated for INT_MIN, where it
    // would be signed overflow.
    if (n < 0 || (n & (n - 1))) {
      throw std::invalid_argument(std::to_string(n) + " is not a power of two");
    }
  }

  /// Returns the constant `G` used to derive the twiddle factors.
  static constexpr Mod get_primitive_root() { return G; }

  /// Grows the precomputed tables and scratch buffers to support transforms
  /// of length up to `n`. Does nothing when `n <= max_n`.
  ///
  /// Growing may reallocate the scratch buffers, so pointers previously
  /// obtained from `raw_buffer()` must be re-fetched afterwards.
  ///
  /// @param n  Requested capacity; must be a power of two dividing
  ///           `Mod::mod() - 1`.
  /// @throws std::invalid_argument if `n` is not a power of two or does not
  ///         divide `Mod::mod() - 1`.
  void reserve(int n) {
    if (n <= max_n) {
      return;
    }
    assert_power_of_two(n);
    if ((Mod::mod() - 1) % n != 0) {
      throw std::invalid_argument(std::to_string(n) +
                                  " is not a divisor of (Mod::mod() - 1)");
    }
    const int log_n = __builtin_ctz(n);

    // Perform every allocation before committing `max_n`, so a failed
    // allocation cannot leave `max_n` larger than the tables really are.
    power_of_two_invs.resize(log_n + 1);
    twiddles.resize(n + 1);
    for (auto &buffer : buffers) {
      // resize (not reserve): raw_buffer() hands out data() for element
      // access, which is only valid for indices below size().
      buffer.resize(n);
    }
    max_n = n;

    // (1/2)^k table; the inverse of two is loop-invariant, so compute it once.
    const auto inv_two = Mod{2}.inv();
    power_of_two_invs[0] = Mod{1};
    for (int i = 1; i <= log_n; ++i) {
      power_of_two_invs[i] = power_of_two_invs[i - 1] * inv_two;
    }

    // omega^i table. Entry `max_n` is kept as well because dif() starts
    // indexing from `twiddles[max_n]` and walks downwards.
    auto omega = binpow(G, (Mod::mod() - 1) / n);
    twiddles[0] = Mod{1};
    for (int i = 1; i <= n; ++i) {
      twiddles[i] = twiddles[i - 1] * omega;
    }
  }

  /// Returns a pointer to the first element of scratch buffer `i`.
  ///
  /// The buffer holds `max_n` elements after `reserve()`; the pointer is
  /// invalidated by any later `reserve()` call that grows the capacity.
  template <int i> Mod *raw_buffer() {
    static_assert(0 <= i && i < NUMBER_OF_BUFFER,
                  "raw_buffer index is out of range");
    return buffers[i].data();
  }

  /// Returns (1/2)^k where k is the number of trailing zero bits of `n`;
  /// for a power of two `n` this is 1/n.
  ///
  /// @throws std::invalid_argument if `n` is not positive or the required
  ///         table entry has not been precomputed by `reserve()`.
  Mod power_of_two_inv(int n) const {
    if (n <= 0) {
      // __builtin_ctz(0) is undefined, so reject it before the call.
      throw std::invalid_argument(std::to_string(n) +
                                  " is not a positive power of two");
    }
    const int log_n = __builtin_ctz(n);
    if (log_n >= static_cast<int>(power_of_two_invs.size())) {
      throw std::invalid_argument(std::to_string(n) +
                                  " exceeds the reserved size " +
                                  std::to_string(max_n));
    }
    return power_of_two_invs[log_n];
  }

  /// In-place butterfly pass over `a[0, n)` with half-widths m = 1, 2, 4, ...
  ///
  /// Each butterfly multiplies the upper element by `twiddles[j * step]`
  /// (forward powers of omega) and then forms the sum and the difference.
  ///
  /// @param n  Transform length; a power of two not exceeding the capacity
  ///           set by `reserve()`.
  /// @param a  Pointer to at least `n` elements, modified in place.
  /// @throws std::invalid_argument if `n` is not a power of two or exceeds
  ///         the reserved capacity.
  void dit(int n, Mod *a) {
    assert_power_of_two(n);
    assert_within_capacity(n);
    for (int m = 1; m < n; m <<= 1) {
      // Stride through the max_n-sized table that yields the (2m)-th roots.
      const int step = max_n / (m << 1);
      for (int i = 0; i < n; i += m << 1) {
        int tid = 0;
        for (int r = i; r < i + m; ++r) {
          auto tmp = twiddles[tid] * a[r + m];
          a[r + m] = a[r];
          a[r + m] -= tmp;
          a[r] += tmp;
          tid += step;
        }
      }
    }
  }

  /// In-place butterfly pass over `a[0, n)` with half-widths
  /// m = n/2, n/4, ..., 1.
  ///
  /// Each butterfly forms the sum and the difference first and then
  /// multiplies the difference by `twiddles[max_n - j * step]`, i.e. the
  /// table is read backwards from its last entry.
  ///
  /// @param n  Transform length; a power of two not exceeding the capacity
  ///           set by `reserve()`.
  /// @param a  Pointer to at least `n` elements, modified in place.
  /// @throws std::invalid_argument if `n` is not a power of two or exceeds
  ///         the reserved capacity.
  void dif(int n, Mod *a) {
    assert_power_of_two(n);
    assert_within_capacity(n);
    for (int m = n; m >>= 1;) {
      const int step = max_n / (m << 1);
      for (int i = 0; i < n; i += m << 1) {
        int tid = max_n;
        for (int r = i; r < i + m; ++r) {
          auto tmp = a[r];
          tmp -= a[r + m];
          a[r] += a[r + m];
          a[r + m] = twiddles[tid] * tmp;
          tid -= step;
        }
      }
    }
  }

private:
  /// Rejects transform lengths whose twiddles have not been precomputed.
  ///
  /// Lengths 0 and 1 never read the tables (both loops are empty), so they
  /// remain valid even before the first `reserve()` call.
  void assert_within_capacity(int n) const {
    if (n > 1 && n > max_n) {
      throw std::invalid_argument(std::to_string(n) +
                                  " exceeds the reserved size " +
                                  std::to_string(max_n));
    }
  }

  static constexpr Mod G = FiniteField<Mod>::primitive_root();

  // Largest transform length the tables currently support (0 = none yet).
  int max_n = 0;
  std::vector<Mod> power_of_two_invs, twiddles;
  std::array<std::vector<Mod>, NUMBER_OF_BUFFER> buffers;
};