Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: kernel replace with precise cache #4796

Merged
merged 3 commits into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Lean/Util/ReplaceExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def replaceNoCache (f? : Expr → Option Expr) (e : Expr) : Expr :=
| .proj _ _ b => let b := replaceNoCache f? b; e.updateProj! b
| e => e


@[extern "lean_replace_expr"]
opaque replaceImpl (f? : @& (Expr → Option Expr)) (e : @& Expr) : Expr

@[implemented_by ReplaceImpl.replaceUnsafe]
partial def replace (f? : Expr → Option Expr) (e : Expr) : Expr :=
def replace (f? : Expr → Option Expr) (e : Expr) : Expr :=
e.replaceNoCache f?
41 changes: 0 additions & 41 deletions src/kernel/cache_stack.h

This file was deleted.

131 changes: 80 additions & 51 deletions src/kernel/replace_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,75 +6,35 @@ Author: Leonardo de Moura
*/
#include <vector>
#include <memory>
#include <unordered_map>
#include "kernel/replace_fn.h"
#include "kernel/cache_stack.h"

#ifndef LEAN_DEFAULT_REPLACE_CACHE_CAPACITY
#define LEAN_DEFAULT_REPLACE_CACHE_CAPACITY 1024*8
#endif

namespace lean {
struct replace_cache {
struct entry {
object * m_cell;
unsigned m_offset;
expr m_result;
entry():m_cell(nullptr) {}
};
unsigned m_capacity;
std::vector<entry> m_cache;
std::vector<unsigned> m_used;
replace_cache(unsigned c):m_capacity(c), m_cache(c) {}

expr * find(expr const & e, unsigned offset) {
unsigned i = hash(hash(e), offset) % m_capacity;
if (m_cache[i].m_cell == e.raw() && m_cache[i].m_offset == offset)
return &m_cache[i].m_result;
else
return nullptr;
}

void insert(expr const & e, unsigned offset, expr const & v) {
unsigned i = hash(hash(e), offset) % m_capacity;
if (m_cache[i].m_cell == nullptr)
m_used.push_back(i);
m_cache[i].m_cell = e.raw();
m_cache[i].m_offset = offset;
m_cache[i].m_result = v;
}

void clear() {
for (unsigned i : m_used) {
m_cache[i].m_cell = nullptr;
m_cache[i].m_result = expr();
}
m_used.clear();
}
};

/* CACHE_RESET: NO */
MK_CACHE_STACK(replace_cache, LEAN_DEFAULT_REPLACE_CACHE_CAPACITY)

class replace_rec_fn {
replace_cache_ref m_cache;
struct key_hasher {
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
return hash((size_t)p.first, p.second);
}
};
std::unordered_map<std::pair<lean_object *, unsigned>, expr, key_hasher> m_cache;
std::function<optional<expr>(expr const &, unsigned)> m_f;
bool m_use_cache;

expr save_result(expr const & e, unsigned offset, expr const & r, bool shared) {
if (shared)
m_cache->insert(e, offset, r);
m_cache.insert(mk_pair(mk_pair(e.raw(), offset), r));
return r;
}

expr apply(expr const & e, unsigned offset) {
bool shared = false;
if (m_use_cache && is_shared(e)) {
if (auto r = m_cache->find(e, offset))
return *r;
auto it = m_cache.find(mk_pair(e.raw(), offset));
if (it != m_cache.end())
return it->second;
shared = true;
}
check_system("replace");

if (optional<expr> r = m_f(e, offset)) {
return save_result(e, offset, *r, shared);
} else {
Expand Down Expand Up @@ -121,4 +81,73 @@ class replace_rec_fn {
expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned)> const & f, bool use_cache) {
return replace_rec_fn(f, use_cache)(e);
}

class replace_fn {
std::unordered_map<lean_object *, expr> m_cache;
lean_object * m_f;

expr save_result(expr const & e, expr const & r, bool shared) {
if (shared)
m_cache.insert(mk_pair(e.raw(), r));
return r;
}

expr apply(expr const & e) {
bool shared = false;
if (is_shared(e)) {
auto it = m_cache.find(e.raw());
if (it != m_cache.end())
return it->second;
shared = true;
}

lean_inc(e.raw());
lean_inc_ref(m_f);
lean_object * r = lean_apply_1(m_f, e.raw());
if (!lean_is_scalar(r)) {
expr e_new(lean_ctor_get(r, 0));
lean_dec_ref(r);
return save_result(e, e_new, shared);
}

switch (e.kind()) {
case expr_kind::Const: case expr_kind::Sort:
case expr_kind::BVar: case expr_kind::Lit:
case expr_kind::MVar: case expr_kind::FVar:
return save_result(e, e, shared);
case expr_kind::MData: {
expr new_e = apply(mdata_expr(e));
return save_result(e, update_mdata(e, new_e), shared);
}
case expr_kind::Proj: {
expr new_e = apply(proj_expr(e));
return save_result(e, update_proj(e, new_e), shared);
}
case expr_kind::App: {
expr new_f = apply(app_fn(e));
expr new_a = apply(app_arg(e));
return save_result(e, update_app(e, new_f, new_a), shared);
}
case expr_kind::Pi: case expr_kind::Lambda: {
expr new_d = apply(binding_domain(e));
expr new_b = apply(binding_body(e));
return save_result(e, update_binding(e, new_d, new_b), shared);
}
case expr_kind::Let: {
expr new_t = apply(let_type(e));
expr new_v = apply(let_value(e));
expr new_b = apply(let_body(e));
return save_result(e, update_let(e, new_t, new_v, new_b), shared);
}}
lean_unreachable();
}
public:
replace_fn(lean_object * f):m_f(f) {}
expr operator()(expr const & e) { return apply(e); }
};

extern "C" LEAN_EXPORT obj_res lean_replace_expr(b_obj_arg f, b_obj_arg e) {
expr r = replace_fn(f)(TO_REF(expr, e));
return r.steal();
}
}
Loading