Skip to content

Commit

Permalink
perf: for_each with precise cache (#4794)
Browse files Browse the repository at this point in the history
This commit also adds support for `find?` and `findExt?` using kernel
`for_each`.
We need to perform `update-stage0`.
  • Loading branch information
leodemoura authored Jul 20, 2024
1 parent d907771 commit 6c33b9c
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 98 deletions.
9 changes: 9 additions & 0 deletions src/Lean/Util/FindExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ unsafe def findUnsafe? (p : Expr → Bool) (e : Expr) : Option Expr :=

end FindImpl

-- TODO: replace `find?` with this one after update-stage0
@[extern "lean_find_expr"]
opaque findImpl? (p : @& (Expr → Bool)) (e : @& Expr) : Option Expr

@[implemented_by FindImpl.findUnsafe?]
def find? (p : Expr → Bool) (e : Expr) : Option Expr :=
/- This is a reference implementation for the unsafe one above -/
Expand All @@ -52,6 +56,7 @@ def find? (p : Expr → Bool) (e : Expr) : Option Expr :=
| .proj _ _ b => find? p b
| _ => none


/-- Return true if `e` occurs in `t` -/
def occurs (e : Expr) (t : Expr) : Bool :=
(t.find? fun s => s == e).isSome
Expand All @@ -64,6 +69,10 @@ inductive FindStep where
/-- Search subterms -/ | visit
/-- Do not search subterms -/ | done

-- TODO: replace `findExt?` with this one after update-stage0
@[extern "lean_find_ext_expr"]
opaque findExtImpl? (p : @& (Expr → FindStep)) (e : @& Expr) : Option Expr

namespace FindExtImpl

unsafe def findM? (p : Expr → FindStep) (e : Expr) : OptionT FindImpl.FindM Expr :=
Expand Down
4 changes: 2 additions & 2 deletions src/kernel/declaration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ bool declaration::is_unsafe() const {

bool use_unsafe(environment const & env, expr const & e) {
bool found = false;
for_each(e, [&](expr const & e, unsigned) {
for_each(e, [&](expr const & e) {
if (found) return false;
if (is_constant(e)) {
if (auto info = env.find(const_name(e))) {
Expand All @@ -181,7 +181,7 @@ declaration::declaration():declaration(*g_dummy) {}

static unsigned get_max_height(environment const & env, expr const & v) {
unsigned h = 0;
for_each(v, [&](expr const & e, unsigned) {
for_each(v, [&](expr const & e) {
if (is_constant(e)) {
auto d = env.find(const_name(e));
if (d && d->get_hints().get_height() > h)
Expand Down
2 changes: 1 addition & 1 deletion src/kernel/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ optional<expr> has_expr_metavar_strict(expr const & e) {
if (!has_expr_metavar(e))
return none_expr();
optional<expr> r;
for_each(e, [&](expr const & e, unsigned) {
for_each(e, [&](expr const & e) {
if (r || !has_expr_metavar(e)) return false;
if (is_metavar_app(e)) { r = e; return false; }
return true;
Expand Down
278 changes: 190 additions & 88 deletions src/kernel/for_each_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,119 +5,221 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <vector>
#include <unordered_map>
#include <utility>
#include "runtime/memory.h"
#include "runtime/interrupt.h"
#include "runtime/flet.h"
#include "kernel/for_each_fn.h"
#include "kernel/cache_stack.h"

#ifndef LEAN_DEFAULT_FOR_EACH_CACHE_CAPACITY
#define LEAN_DEFAULT_FOR_EACH_CACHE_CAPACITY 1024*8
#endif

namespace lean {
struct for_each_cache {
struct entry {
object const * m_cell;
unsigned m_offset;
entry():m_cell(nullptr) {}
};
unsigned m_capacity;
std::vector<entry> m_cache;
std::vector<unsigned> m_used;
for_each_cache(unsigned c):m_capacity(c), m_cache(c) {}

bool visited(expr const & e, unsigned offset) {
unsigned i = hash(hash(e), offset) % m_capacity;
if (m_cache[i].m_cell == e.raw() && m_cache[i].m_offset == offset) {
return true;
/*
If `partial_apps = true`, then given a term `g a b`, we also apply the function `m_f` to `g a`,
and not only to `g`, `a`, and `b`.
*/
template<bool partial_apps> class for_each_fn {
std::unordered_set<lean_object *> m_cache;
std::function<bool(expr const &)> m_f; // NOLINT

bool visited(expr const & e) {
if (!is_shared(e)) return false;
if (m_cache.find(e.raw()) != m_cache.end()) return true;
m_cache.insert(e.raw());
return false;
}

void apply_fn(expr const & e) {
if (is_app(e)) {
apply_fn(app_fn(e));
apply(app_arg(e));
} else {
if (m_cache[i].m_cell == nullptr)
m_used.push_back(i);
m_cache[i].m_cell = e.raw();
m_cache[i].m_offset = offset;
return false;
apply(e);
}
}

void clear() {
for (unsigned i : m_used)
m_cache[i].m_cell = nullptr;
m_used.clear();
void apply(expr const & e) {
switch (e.kind()) {
case expr_kind::Const: case expr_kind::BVar: case expr_kind::Sort:
m_f(e);
return;
default:
break;
}

if (visited(e))
return;

if (!m_f(e))
return;

switch (e.kind()) {
case expr_kind::Const: case expr_kind::BVar:
case expr_kind::Sort: case expr_kind::Lit:
case expr_kind::MVar: case expr_kind::FVar:
return;
case expr_kind::MData:
apply(mdata_expr(e));
return;
case expr_kind::Proj:
apply(proj_expr(e));
return;
case expr_kind::App:
if (partial_apps)
apply(app_fn(e));
else
apply_fn(e);
apply(app_arg(e));
return;
case expr_kind::Lambda: case expr_kind::Pi:
apply(binding_domain(e));
apply(binding_body(e));
return;
case expr_kind::Let:
apply(let_type(e));
apply(let_value(e));
apply(let_body(e));
return;
}
}
};

/* CACHE_RESET: NO */
MK_CACHE_STACK(for_each_cache, LEAN_DEFAULT_FOR_EACH_CACHE_CAPACITY)
public:
for_each_fn(std::function<bool(expr const &)> && f):m_f(f) {} // NOLINT
for_each_fn(std::function<bool(expr const &)> const & f):m_f(f) {} // NOLINT
void operator()(expr const & e) { apply(e); }
};

class for_each_fn {
for_each_cache_ref m_cache;
class for_each_offset_fn {
struct key_hasher {
std::size_t operator()(std::pair<lean_object *, unsigned> const & p) const {
return hash((size_t)p.first, p.second);
}
};
std::unordered_set<std::pair<lean_object *, unsigned>, key_hasher> m_cache;
std::function<bool(expr const &, unsigned)> m_f; // NOLINT

bool visited(expr const & e, unsigned offset) {
if (!is_shared(e)) return false;
if (m_cache.find(std::make_pair(e.raw(), offset)) != m_cache.end()) return true;
m_cache.insert(std::make_pair(e.raw(), offset));
return false;
}

void apply(expr const & e, unsigned offset) {
buffer<pair<expr const &, unsigned>> todo;
todo.emplace_back(e, offset);
while (true) {
begin_loop:
if (todo.empty())
break;
check_memory("expression traversal");
auto p = todo.back();
todo.pop_back();
expr const & e = p.first;
unsigned offset = p.second;

switch (e.kind()) {
case expr_kind::Const: case expr_kind::BVar:
case expr_kind::Sort:
m_f(e, offset);
goto begin_loop;
default:
break;
}

if (is_shared(e) && m_cache->visited(e, offset))
goto begin_loop;

if (!m_f(e, offset))
goto begin_loop;

switch (e.kind()) {
case expr_kind::Const: case expr_kind::BVar:
case expr_kind::Sort: case expr_kind::Lit:
case expr_kind::MVar: case expr_kind::FVar:
goto begin_loop;
case expr_kind::MData:
todo.emplace_back(mdata_expr(e), offset);
goto begin_loop;
case expr_kind::Proj:
todo.emplace_back(proj_expr(e), offset);
goto begin_loop;
case expr_kind::App:
todo.emplace_back(app_arg(e), offset);
todo.emplace_back(app_fn(e), offset);
goto begin_loop;
case expr_kind::Lambda: case expr_kind::Pi:
todo.emplace_back(binding_body(e), offset + 1);
todo.emplace_back(binding_domain(e), offset);
goto begin_loop;
case expr_kind::Let:
todo.emplace_back(let_body(e), offset + 1);
todo.emplace_back(let_value(e), offset);
todo.emplace_back(let_type(e), offset);
goto begin_loop;
}
switch (e.kind()) {
case expr_kind::Const: case expr_kind::BVar: case expr_kind::Sort:
m_f(e, offset);
return;
default:
break;
}

if (visited(e, offset))
return;

if (!m_f(e, offset))
return;

switch (e.kind()) {
case expr_kind::Const: case expr_kind::BVar:
case expr_kind::Sort: case expr_kind::Lit:
case expr_kind::MVar: case expr_kind::FVar:
return;
case expr_kind::MData:
apply(mdata_expr(e), offset);
return;
case expr_kind::Proj:
apply(proj_expr(e), offset);
return;
case expr_kind::App:
apply(app_fn(e), offset);
apply(app_arg(e), offset);
return;
case expr_kind::Lambda: case expr_kind::Pi:
apply(binding_domain(e), offset);
apply(binding_body(e), offset+1);
return;
case expr_kind::Let:
apply(let_type(e), offset);
apply(let_value(e), offset);
apply(let_body(e), offset+1);
return;
}
}

public:
for_each_fn(std::function<bool(expr const &, unsigned)> && f):m_f(f) {} // NOLINT
for_each_fn(std::function<bool(expr const &, unsigned)> const & f):m_f(f) {} // NOLINT
for_each_offset_fn(std::function<bool(expr const &, unsigned)> && f):m_f(f) {} // NOLINT
for_each_offset_fn(std::function<bool(expr const &, unsigned)> const & f):m_f(f) {} // NOLINT
void operator()(expr const & e) { apply(e, 0); }
};

void for_each(expr const & e, std::function<bool(expr const &)> && f) { // NOLINT
return for_each_fn<true>(f)(e);
}

void for_each(expr const & e, std::function<bool(expr const &, unsigned)> && f) { // NOLINT
return for_each_fn(f)(e);
return for_each_offset_fn(f)(e);
}

extern "C" LEAN_EXPORT obj_res lean_find_expr(b_obj_arg p, b_obj_arg e_) {
lean_object * found = nullptr;
expr const & e = TO_REF(expr, e_);
for_each_fn<true>([&](expr const & e) {
if (found != nullptr) return false;
lean_inc(p);
lean_inc(e.raw());
if (lean_unbox(lean_apply_1(p, e.raw()))) {
found = e.raw();
return false;
}
return true;
})(e);
if (found) {
lean_inc(found);
lean_object * r = lean_alloc_ctor(1, 1, 0);
lean_ctor_set(r, 0, found);
return r;
} else {
return lean_box(0);
}
}

/*
Similar to `lean_find_expr`, but `p` returns
```
inductive FindStep where
/-- Found desired subterm -/ | found
/-- Search subterms -/ | visit
/-- Do not search subterms -/ | done
```
*/
extern "C" LEAN_EXPORT obj_res lean_find_ext_expr(b_obj_arg p, b_obj_arg e_) {
lean_object * found = nullptr;
expr const & e = TO_REF(expr, e_);
// Recall that `findExt?` skips partial applications.
for_each_fn<false>([&](expr const & e) {
if (found != nullptr) return false;
lean_inc(p);
lean_inc(e.raw());
switch(lean_unbox(lean_apply_1(p, e.raw()))) {
case 0: // found
found = e.raw();
return false;
case 1: // visit
return true;
case 2: // done
return false;
default:
lean_unreachable();
}
})(e);
if (found) {
lean_inc(found);
lean_object * r = lean_alloc_ctor(1, 1, 0);
lean_ctor_set(r, 0, found);
return r;
} else {
return lean_box(0);
}
}
}
15 changes: 9 additions & 6 deletions src/kernel/for_each_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,18 @@ Author: Leonardo de Moura
#include "kernel/expr_sets.h"

namespace lean {
/** \brief Expression visitor.
/**
\brief Expression visitor.
The argument \c f must be a lambda (function object) containing the method
The argument \c f must be a lambda (function object) containing the method
<code>
bool operator()(expr const & e, unsigned offset)
</code>
<code>
bool operator()(expr const & e, unsigned offset)
</code>
The \c offset is the number of binders under which \c e occurs.
The \c offset is the number of binders under which \c e occurs.
*/
void for_each(expr const & e, std::function<bool(expr const &, unsigned)> && f); // NOLINT

void for_each(expr const & e, std::function<bool(expr const &)> && f); // NOLINT
}
Loading

0 comments on commit 6c33b9c

Please sign in to comment.