Skip to content

Commit 726e162

Browse files
authored
perf: kernel replace with precise cache (#4796)
Changes: - We avoid thread-local storage. - We use a hash map to ensure that cached values are not lost. - We remove `check_system`. If this becomes an issue in the future, we should precompute the remaining amount of stack space and use a cheaper check. - We add `Expr.replaceImpl`, and will use it to implement `Expr.replace` after update-stage0.
1 parent de5e07c commit 726e162

File tree

3 files changed

+85
-93
lines changed

3 files changed

+85
-93
lines changed

src/Lean/Util/ReplaceExpr.lean

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def replaceNoCache (f? : Expr → Option Expr) (e : Expr) : Expr :=
7777
| .proj _ _ b => let b := replaceNoCache f? b; e.updateProj! b
7878
| e => e
7979

80+
81+
@[extern "lean_replace_expr"]
82+
opaque replaceImpl (f? : @& (Expr → Option Expr)) (e : @& Expr) : Expr
83+
8084
@[implemented_by ReplaceImpl.replaceUnsafe]
81-
partial def replace (f? : Expr → Option Expr) (e : Expr) : Expr :=
85+
def replace (f? : Expr → Option Expr) (e : Expr) : Expr :=
8286
e.replaceNoCache f?

src/kernel/cache_stack.h

Lines changed: 0 additions & 41 deletions
This file was deleted.

src/kernel/replace_fn.cpp

Lines changed: 80 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -6,75 +6,35 @@ Author: Leonardo de Moura
66
*/
77
#include <vector>
88
#include <memory>
9+
#include <unordered_map>
910
#include "kernel/replace_fn.h"
10-
#include "kernel/cache_stack.h"
11-
12-
#ifndef LEAN_DEFAULT_REPLACE_CACHE_CAPACITY
13-
#define LEAN_DEFAULT_REPLACE_CACHE_CAPACITY 1024*8
14-
#endif
1511

1612
namespace lean {
17-
struct replace_cache {
18-
struct entry {
19-
object * m_cell;
20-
unsigned m_offset;
21-
expr m_result;
22-
entry():m_cell(nullptr) {}
23-
};
24-
unsigned m_capacity;
25-
std::vector<entry> m_cache;
26-
std::vector<unsigned> m_used;
27-
replace_cache(unsigned c):m_capacity(c), m_cache(c) {}
28-
29-
expr * find(expr const & e, unsigned offset) {
30-
unsigned i = hash(hash(e), offset) % m_capacity;
31-
if (m_cache[i].m_cell == e.raw() && m_cache[i].m_offset == offset)
32-
return &m_cache[i].m_result;
33-
else
34-
return nullptr;
35-
}
36-
37-
void insert(expr const & e, unsigned offset, expr const & v) {
38-
unsigned i = hash(hash(e), offset) % m_capacity;
39-
if (m_cache[i].m_cell == nullptr)
40-
m_used.push_back(i);
41-
m_cache[i].m_cell = e.raw();
42-
m_cache[i].m_offset = offset;
43-
m_cache[i].m_result = v;
44-
}
45-
46-
void clear() {
47-
for (unsigned i : m_used) {
48-
m_cache[i].m_cell = nullptr;
49-
m_cache[i].m_result = expr();
50-
}
51-
m_used.clear();
52-
}
53-
};
54-
55-
/* CACHE_RESET: NO */
56-
MK_CACHE_STACK(replace_cache, LEAN_DEFAULT_REPLACE_CACHE_CAPACITY)
5713

5814
class replace_rec_fn {
59-
replace_cache_ref m_cache;
15+
struct key_hasher {
16+
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
17+
return hash((size_t)p.first, p.second);
18+
}
19+
};
20+
std::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
6021
std::function<optional<expr>(expr const &, unsigned)> m_f;
6122
bool m_use_cache;
6223

6324
expr save_result(expr const & e, unsigned offset, expr const & r, bool shared) {
6425
if (shared)
65-
m_cache->insert(e, offset, r);
26+
m_cache.insert(mk_pair(mk_pair(e.raw(), offset), r));
6627
return r;
6728
}
6829

6930
expr apply(expr const & e, unsigned offset) {
7031
bool shared = false;
7132
if (m_use_cache && is_shared(e)) {
72-
if (auto r = m_cache->find(e, offset))
73-
return *r;
33+
auto it = m_cache.find(mk_pair(e.raw(), offset));
34+
if (it != m_cache.end())
35+
return it->second;
7436
shared = true;
7537
}
76-
check_system("replace");
77-
7838
if (optional<expr> r = m_f(e, offset)) {
7939
return save_result(e, offset, *r, shared);
8040
} else {
@@ -121,4 +81,73 @@ class replace_rec_fn {
12181
/* Applies `f` throughout `e` by building a `replace_rec_fn` from `f` and
   `use_cache` and invoking it on `e`. */
expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned)> const & f, bool use_cache) {
    replace_rec_fn replacer(f, use_cache);
    return replacer(e);
}
84+
85+
class replace_fn {
86+
std::unordered_map<lean_object *, expr> m_cache;
87+
lean_object * m_f;
88+
89+
expr save_result(expr const & e, expr const & r, bool shared) {
90+
if (shared)
91+
m_cache.insert(mk_pair(e.raw(), r));
92+
return r;
93+
}
94+
95+
expr apply(expr const & e) {
96+
bool shared = false;
97+
if (is_shared(e)) {
98+
auto it = m_cache.find(e.raw());
99+
if (it != m_cache.end())
100+
return it->second;
101+
shared = true;
102+
}
103+
104+
lean_inc(e.raw());
105+
lean_inc_ref(m_f);
106+
lean_object * r = lean_apply_1(m_f, e.raw());
107+
if (!lean_is_scalar(r)) {
108+
expr e_new(lean_ctor_get(r, 0));
109+
lean_dec_ref(r);
110+
return save_result(e, e_new, shared);
111+
}
112+
113+
switch (e.kind()) {
114+
case expr_kind::Const: case expr_kind::Sort:
115+
case expr_kind::BVar: case expr_kind::Lit:
116+
case expr_kind::MVar: case expr_kind::FVar:
117+
return save_result(e, e, shared);
118+
case expr_kind::MData: {
119+
expr new_e = apply(mdata_expr(e));
120+
return save_result(e, update_mdata(e, new_e), shared);
121+
}
122+
case expr_kind::Proj: {
123+
expr new_e = apply(proj_expr(e));
124+
return save_result(e, update_proj(e, new_e), shared);
125+
}
126+
case expr_kind::App: {
127+
expr new_f = apply(app_fn(e));
128+
expr new_a = apply(app_arg(e));
129+
return save_result(e, update_app(e, new_f, new_a), shared);
130+
}
131+
case expr_kind::Pi: case expr_kind::Lambda: {
132+
expr new_d = apply(binding_domain(e));
133+
expr new_b = apply(binding_body(e));
134+
return save_result(e, update_binding(e, new_d, new_b), shared);
135+
}
136+
case expr_kind::Let: {
137+
expr new_t = apply(let_type(e));
138+
expr new_v = apply(let_value(e));
139+
expr new_b = apply(let_body(e));
140+
return save_result(e, update_let(e, new_t, new_v, new_b), shared);
141+
}}
142+
lean_unreachable();
143+
}
144+
public:
145+
replace_fn(lean_object * f):m_f(f) {}
146+
expr operator()(expr const & e) { return apply(e); }
147+
};
148+
149+
/* C entry point exported under the name `lean_replace_expr`.
   Runs a fresh `replace_fn` built from `f` over the expression `e` and hands
   the resulting object back to the caller via `steal()`. Both parameters are
   declared `b_obj_arg`. */
extern "C" LEAN_EXPORT obj_res lean_replace_expr(b_obj_arg f, b_obj_arg e) {
    replace_fn replacer(f);
    expr const & input = TO_REF(expr, e);
    return replacer(input).steal();
}
124153
}

0 commit comments

Comments
 (0)