Skip to content

Commit 522d6a5

Browse files
committed
binary-trie
1 parent 4b8388d commit 522d6a5

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed

cpp/binary-trie.hpp

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
#pragma once
2+
3+
#include <assert.h>
4+
#include <memory>
5+
#include <utility>
6+
7+
/**
8+
* @brief 符号なし整数の多重集合を管理する
9+
* @tparam `d` 扱う整数値のビット幅。64以下であることを要請
10+
*/
11+
template <unsigned int d> class BinaryTrie {
12+
static_assert(d <= 64, "d must be 64 or less");
13+
struct BinaryTrieNode {
14+
std::shared_ptr<BinaryTrieNode> children[2] = {nullptr, nullptr};
15+
unsigned int level, subcnt = 0;
16+
unsigned long long xval = 0;
17+
18+
BinaryTrieNode(int lvl) : level(lvl) {}
19+
bool get_bit(unsigned long long v) const { return (v >> (level - 1)) & 1; }
20+
bool is_leaf() const { return level == 0; }
21+
// 子の状態がvalidかどうか
22+
// 0: xx, 1: xo, 2: ox, 3: oo
23+
int state_children() const {
24+
return ((bool)children[1] << 1) | (bool)children[0];
25+
}
26+
// xorの値を子ノードに伝播する
27+
void affect_xor() {
28+
if (get_bit(xval)) {
29+
std::swap(children[0], children[1]);
30+
}
31+
if (children[0])
32+
children[0]->xval ^= xval;
33+
if (children[1])
34+
children[1]->xval ^= xval;
35+
xval = 0;
36+
}
37+
};
38+
using NodePtr = std::shared_ptr<BinaryTrieNode>;
39+
NodePtr root_ptr = std::make_shared<BinaryTrieNode>(d);
40+
41+
public:
42+
/**
43+
* @brief 集合にnを追加 (O(d))
44+
*/
45+
void insert(unsigned long long n) {
46+
NodePtr cur_ptr = root_ptr;
47+
while (!cur_ptr->is_leaf()) {
48+
cur_ptr->affect_xor();
49+
cur_ptr->subcnt += 1;
50+
NodePtr &nxt_ptr = cur_ptr->children[cur_ptr->get_bit(n)];
51+
if (!nxt_ptr) {
52+
nxt_ptr = std::make_shared<BinaryTrieNode>(cur_ptr->level - 1);
53+
}
54+
cur_ptr = nxt_ptr;
55+
}
56+
assert(cur_ptr->is_leaf());
57+
cur_ptr->subcnt += 1;
58+
}
59+
60+
int size() const { return root_ptr->subcnt; }
61+
62+
/**
63+
* @brief 集合からnを検索し、見つかった数を求める (O(d))
64+
*/
65+
int count(unsigned long long n) const {
66+
NodePtr cur_ptr = root_ptr;
67+
while (!cur_ptr->is_leaf()) {
68+
cur_ptr->affect_xor();
69+
NodePtr &nxt_ptr = cur_ptr->children[cur_ptr->get_bit(n)];
70+
if (!nxt_ptr) {
71+
return nullptr;
72+
}
73+
cur_ptr = nxt_ptr;
74+
}
75+
return (cur_ptr ? cur_ptr->subcnt : 0);
76+
}
77+
78+
/**
79+
* @brief 集合からnを削除 (O(d))
80+
* @note 存在しない要素を指定したとき、何も起こらない
81+
*/
82+
void erase(unsigned long long n) const {
83+
int cnt = count(n);
84+
if (cnt == 0)
85+
return;
86+
NodePtr cur_ptr = root_ptr;
87+
while (true) {
88+
cur_ptr->affect_xor();
89+
cur_ptr->subcnt -= cnt;
90+
NodePtr &nxt_ptr = cur_ptr->children[cur_ptr->get_bit(n)];
91+
if (nxt_ptr->subcnt == cnt) {
92+
nxt_ptr = nullptr;
93+
return;
94+
}
95+
cur_ptr = nxt_ptr;
96+
}
97+
}
98+
99+
/**
100+
* @brief 集合からnを一つだけ削除 (O(d))
101+
* @note 存在しない要素を指定したとき、何も起こらない
102+
*/
103+
void erase_one_element(unsigned long long n) const {
104+
if (count(n) == 0)
105+
return;
106+
NodePtr cur_ptr = root_ptr;
107+
while (!cur_ptr->is_leaf()) {
108+
cur_ptr->affect_xor();
109+
cur_ptr->subcnt -= 1;
110+
NodePtr &nxt_ptr = cur_ptr->children[cur_ptr->get_bit(n)];
111+
if (nxt_ptr->subcnt == 1) {
112+
nxt_ptr = nullptr;
113+
return;
114+
}
115+
cur_ptr = nxt_ptr;
116+
}
117+
cur_ptr->subcnt -= 1;
118+
}
119+
120+
/**
121+
* @brief 昇順でn番目の要素を探索 (O(d))
122+
* @note nがtrie木のサイズ以上な場合、assert
123+
*/
124+
unsigned long long nth_element(int n) const {
125+
assert(n < size());
126+
unsigned long long ret = 0;
127+
NodePtr cur_ptr = root_ptr;
128+
while (!cur_ptr->is_leaf()) {
129+
cur_ptr->affect_xor();
130+
ret <<= 1;
131+
int state = cur_ptr->state_children();
132+
NodePtr &z_ptr = cur_ptr->children[0];
133+
NodePtr &o_ptr = cur_ptr->children[1];
134+
assert(state > 0);
135+
if (state == 1 || (state == 3 && n < z_ptr->subcnt)) {
136+
cur_ptr = z_ptr;
137+
} else {
138+
n -= (state & 1 ? z_ptr->subcnt : 0);
139+
ret |= 1;
140+
cur_ptr = o_ptr;
141+
}
142+
}
143+
return ret;
144+
}
145+
146+
/**
147+
* @brief n以上の要素を探索 (O(d))
148+
* @return 探索した値が昇順で何番目か (0-indexed)。該当する要素がなければtrie木のサイズが返る
149+
*/
150+
int lower_bound(unsigned long long n) const {
151+
int ret = 0;
152+
NodePtr cur_ptr = root_ptr;
153+
while (!cur_ptr->is_leaf()) {
154+
cur_ptr->affect_xor();
155+
bool b = cur_ptr->get_bit(n);
156+
NodePtr &nxt_ptr = cur_ptr->children[b];
157+
NodePtr &z_ptr = cur_ptr->children[0];
158+
if (b && z_ptr) {
159+
ret += z_ptr->subcnt;
160+
}
161+
if (!nxt_ptr) {
162+
break;
163+
}
164+
cur_ptr = nxt_ptr;
165+
}
166+
return ret;
167+
}
168+
169+
/**
170+
* @brief nより大きな要素を探索 (O(d))
171+
* @return 探索した値が昇順で何番目か (0-indexed)。該当する要素がなければtrie木のサイズが返る
172+
*/
173+
int upper_bound(unsigned long long n) const {
174+
return (n < UINT64_MAX ? lower_bound(n + 1) : size());
175+
}
176+
177+
/**
178+
* @brief 集合のすべての要素にxorを作用
179+
*/
180+
void apply_xor(unsigned long long n) { root_ptr->xval ^= n; }
181+
182+
/**
183+
* @brief 要素をすべて削除する。確保したメモリ領域も削除される
184+
*/
185+
void clear() {
186+
root_ptr = std::make_shared<BinaryTrieNode>(d);
187+
}
188+
};

test/yosupo-set-xor-min.test.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#define PROBLEM "https://judge.yosupo.jp/problem/set_xor_min"
2+
3+
#include <iostream>
4+
5+
#include "../cpp/binary-trie.hpp"
6+
7+
int main(void) {
8+
9+
int Q;
10+
std::cin >> Q;
11+
BinaryTrie<30> A;
12+
while (Q--) {
13+
int q, x;
14+
std::cin >> q >> x;
15+
switch (q) {
16+
case 0:
17+
if (A.count(x) == 0) {
18+
A.insert(x);
19+
}
20+
break;
21+
case 1:
22+
A.erase(x);
23+
break;
24+
case 2:
25+
A.apply_xor(x);
26+
std::cout << A.nth_element(0) << std::endl;
27+
A.apply_xor(x);
28+
break;
29+
}
30+
}
31+
32+
}

0 commit comments

Comments
 (0)