|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <assert.h> |
| 4 | +#include <memory> |
| 5 | +#include <utility> |
| 6 | + |
| 7 | +/** |
| 8 | + * @brief 符号なし整数の多重集合を管理する |
| 9 | + * @tparam `d` 扱う整数値のビット幅。64以下であることを要請 |
| 10 | + */ |
| 11 | +template <unsigned int d> class BinaryTrie { |
| 12 | + static_assert(d <= 64, "d must be 64 or less"); |
| 13 | + struct BinaryTrieNode { |
| 14 | + std::shared_ptr<BinaryTrieNode> children[2] = {nullptr, nullptr}; |
| 15 | + unsigned int level, subcnt = 0; |
| 16 | + unsigned long long xval = 0; |
| 17 | + |
| 18 | + BinaryTrieNode(int lvl) : level(lvl) {} |
| 19 | + bool get_bit(unsigned long long v) const { return (v >> (level - 1)) & 1; } |
| 20 | + bool is_leaf() const { return level == 0; } |
| 21 | + // 子の状態がvalidかどうか |
| 22 | + // 0: xx, 1: xo, 2: ox, 3: oo |
| 23 | + int state_children() const { |
| 24 | + return ((bool)children[1] << 1) | (bool)children[0]; |
| 25 | + } |
| 26 | + // xorの値を子ノードに伝播する |
| 27 | + void affect_xor() { |
| 28 | + if (get_bit(xval)) { |
| 29 | + std::swap(children[0], children[1]); |
| 30 | + } |
| 31 | + if (children[0]) |
| 32 | + children[0]->xval ^= xval; |
| 33 | + if (children[1]) |
| 34 | + children[1]->xval ^= xval; |
| 35 | + xval = 0; |
| 36 | + } |
| 37 | + }; |
| 38 | + using NodePtr = std::shared_ptr<BinaryTrieNode>; |
| 39 | + NodePtr root_ptr = std::make_shared<BinaryTrieNode>(d); |
| 40 | + |
| 41 | + public: |
| 42 | + /** |
| 43 | + * @brief 集合にnを追加 (O(d)) |
| 44 | + */ |
| 45 | + void insert(unsigned long long n) { |
| 46 | + NodePtr cur_ptr = root_ptr; |
| 47 | + while (!cur_ptr->is_leaf()) { |
| 48 | + cur_ptr->affect_xor(); |
| 49 | + cur_ptr->subcnt += 1; |
| 50 | + NodePtr &nxt_ptr = cur_ptr->children[cur_ptr->get_bit(n)]; |
| 51 | + if (!nxt_ptr) { |
| 52 | + nxt_ptr = std::make_shared<BinaryTrieNode>(cur_ptr->level - 1); |
| 53 | + } |
| 54 | + cur_ptr = nxt_ptr; |
| 55 | + } |
| 56 | + assert(cur_ptr->is_leaf()); |
| 57 | + cur_ptr->subcnt += 1; |
| 58 | + } |
| 59 | + |
| 60 | + int size() const { return root_ptr->subcnt; } |
| 61 | + |
| 62 | + /** |
| 63 | + * @brief 集合からnを検索し、見つかった数を求める (O(d)) |
| 64 | + */ |
| 65 | + int count(unsigned long long n) const { |
| 66 | + NodePtr cur_ptr = root_ptr; |
| 67 | + while (!cur_ptr->is_leaf()) { |
| 68 | + cur_ptr->affect_xor(); |
| 69 | + NodePtr &nxt_ptr = cur_ptr->children[cur_ptr->get_bit(n)]; |
| 70 | + if (!nxt_ptr) { |
| 71 | + return nullptr; |
| 72 | + } |
| 73 | + cur_ptr = nxt_ptr; |
| 74 | + } |
| 75 | + return (cur_ptr ? cur_ptr->subcnt : 0); |
| 76 | + } |
| 77 | + |
| 78 | + /** |
| 79 | + * @brief 集合からnを削除 (O(d)) |
| 80 | + * @note 存在しない要素を指定したとき、何も起こらない |
| 81 | + */ |
| 82 | + void erase(unsigned long long n) const { |
| 83 | + int cnt = count(n); |
| 84 | + if (cnt == 0) |
| 85 | + return; |
| 86 | + NodePtr cur_ptr = root_ptr; |
| 87 | + while (true) { |
| 88 | + cur_ptr->affect_xor(); |
| 89 | + cur_ptr->subcnt -= cnt; |
| 90 | + NodePtr &nxt_ptr = cur_ptr->children[cur_ptr->get_bit(n)]; |
| 91 | + if (nxt_ptr->subcnt == cnt) { |
| 92 | + nxt_ptr = nullptr; |
| 93 | + return; |
| 94 | + } |
| 95 | + cur_ptr = nxt_ptr; |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + /** |
| 100 | + * @brief 集合からnを一つだけ削除 (O(d)) |
| 101 | + * @note 存在しない要素を指定したとき、何も起こらない |
| 102 | + */ |
| 103 | + void erase_one_element(unsigned long long n) const { |
| 104 | + if (count(n) == 0) |
| 105 | + return; |
| 106 | + NodePtr cur_ptr = root_ptr; |
| 107 | + while (!cur_ptr->is_leaf()) { |
| 108 | + cur_ptr->affect_xor(); |
| 109 | + cur_ptr->subcnt -= 1; |
| 110 | + NodePtr &nxt_ptr = cur_ptr->children[cur_ptr->get_bit(n)]; |
| 111 | + if (nxt_ptr->subcnt == 1) { |
| 112 | + nxt_ptr = nullptr; |
| 113 | + return; |
| 114 | + } |
| 115 | + cur_ptr = nxt_ptr; |
| 116 | + } |
| 117 | + cur_ptr->subcnt -= 1; |
| 118 | + } |
| 119 | + |
| 120 | + /** |
| 121 | + * @brief 昇順でn番目の要素を探索 (O(d)) |
| 122 | + * @note nがtrie木のサイズ以上な場合、assert |
| 123 | + */ |
| 124 | + unsigned long long nth_element(int n) const { |
| 125 | + assert(n < size()); |
| 126 | + unsigned long long ret = 0; |
| 127 | + NodePtr cur_ptr = root_ptr; |
| 128 | + while (!cur_ptr->is_leaf()) { |
| 129 | + cur_ptr->affect_xor(); |
| 130 | + ret <<= 1; |
| 131 | + int state = cur_ptr->state_children(); |
| 132 | + NodePtr &z_ptr = cur_ptr->children[0]; |
| 133 | + NodePtr &o_ptr = cur_ptr->children[1]; |
| 134 | + assert(state > 0); |
| 135 | + if (state == 1 || (state == 3 && n < z_ptr->subcnt)) { |
| 136 | + cur_ptr = z_ptr; |
| 137 | + } else { |
| 138 | + n -= (state & 1 ? z_ptr->subcnt : 0); |
| 139 | + ret |= 1; |
| 140 | + cur_ptr = o_ptr; |
| 141 | + } |
| 142 | + } |
| 143 | + return ret; |
| 144 | + } |
| 145 | + |
| 146 | + /** |
| 147 | + * @brief n以上の要素を探索 (O(d)) |
| 148 | + * @return 探索した値が昇順で何番目か (0-indexed)。該当する要素がなければtrie木のサイズが返る |
| 149 | + */ |
| 150 | + int lower_bound(unsigned long long n) const { |
| 151 | + int ret = 0; |
| 152 | + NodePtr cur_ptr = root_ptr; |
| 153 | + while (!cur_ptr->is_leaf()) { |
| 154 | + cur_ptr->affect_xor(); |
| 155 | + bool b = cur_ptr->get_bit(n); |
| 156 | + NodePtr &nxt_ptr = cur_ptr->children[b]; |
| 157 | + NodePtr &z_ptr = cur_ptr->children[0]; |
| 158 | + if (b && z_ptr) { |
| 159 | + ret += z_ptr->subcnt; |
| 160 | + } |
| 161 | + if (!nxt_ptr) { |
| 162 | + break; |
| 163 | + } |
| 164 | + cur_ptr = nxt_ptr; |
| 165 | + } |
| 166 | + return ret; |
| 167 | + } |
| 168 | + |
| 169 | + /** |
| 170 | + * @brief nより大きな要素を探索 (O(d)) |
| 171 | + * @return 探索した値が昇順で何番目か (0-indexed)。該当する要素がなければtrie木のサイズが返る |
| 172 | + */ |
| 173 | + int upper_bound(unsigned long long n) const { |
| 174 | + return (n < UINT64_MAX ? lower_bound(n + 1) : size()); |
| 175 | + } |
| 176 | + |
| 177 | + /** |
| 178 | + * @brief 集合のすべての要素にxorを作用 |
| 179 | + */ |
| 180 | + void apply_xor(unsigned long long n) { root_ptr->xval ^= n; } |
| 181 | + |
| 182 | + /** |
| 183 | + * @brief 要素をすべて削除する。確保したメモリ領域も削除される |
| 184 | + */ |
| 185 | + void clear() { |
| 186 | + root_ptr = std::make_shared<BinaryTrieNode>(d); |
| 187 | + } |
| 188 | +}; |
0 commit comments