diff --git a/src/Lean/Util/ReplaceExpr.lean b/src/Lean/Util/ReplaceExpr.lean index d7a364e5ca65..deca88cc73ab 100644 --- a/src/Lean/Util/ReplaceExpr.lean +++ b/src/Lean/Util/ReplaceExpr.lean @@ -77,6 +77,10 @@ def replaceNoCache (f? : Expr → Option Expr) (e : Expr) : Expr := | .proj _ _ b => let b := replaceNoCache f? b; e.updateProj! b | e => e + +@[extern "lean_replace_expr"] +opaque replaceImpl (f? : @& (Expr → Option Expr)) (e : @& Expr) : Expr + @[implemented_by ReplaceImpl.replaceUnsafe] -partial def replace (f? : Expr → Option Expr) (e : Expr) : Expr := +def replace (f? : Expr → Option Expr) (e : Expr) : Expr := e.replaceNoCache f? diff --git a/src/kernel/cache_stack.h b/src/kernel/cache_stack.h deleted file mode 100644 index 1c16f29103f9..000000000000 --- a/src/kernel/cache_stack.h +++ /dev/null @@ -1,41 +0,0 @@ -/* -Copyright (c) 2013 Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. - -Author: Leonardo de Moura -*/ -#pragma once -#include -#include -#include "runtime/debug.h" - -/** \brief Macro for creating a stack of objects of type Cache in thread local storage. - The argument \c Arg is provided to every new instance of Cache. - The macro provides the helper class Cache_ref that "reuses" cache objects from the stack. -*/ -#define MK_CACHE_STACK(Cache, Arg) \ -struct Cache ## _stack { \ - unsigned m_top; \ - std::vector> m_cache_stack; \ - Cache ## _stack():m_top(0) {} \ -}; \ -MK_THREAD_LOCAL_GET_DEF(Cache ## _stack, get_ ## Cache ## _stack); \ -class Cache ## _ref { \ - Cache * m_cache; \ -public: \ - Cache ## _ref() { \ - Cache ## _stack & s = get_ ## Cache ## _stack(); \ - lean_assert(s.m_top <= s.m_cache_stack.size()); \ - if (s.m_top == s.m_cache_stack.size()) \ - s.m_cache_stack.push_back(std::unique_ptr(new Cache(Arg))); \ - m_cache = s.m_cache_stack[s.m_top].get(); \ - s.m_top++; \ - } \ - ~Cache ## _ref() { \ - Cache ## _stack & s = get_ ## Cache ## _stack(); \ - lean_assert(s.m_top > 0); \ - s.m_top--; \ - m_cache->clear(); \ - } \ - Cache * operator->() const { return m_cache; } \ -}; diff --git a/src/kernel/replace_fn.cpp b/src/kernel/replace_fn.cpp index f0b995d6d1f9..23883f1d8bba 100644 --- a/src/kernel/replace_fn.cpp +++ b/src/kernel/replace_fn.cpp @@ -6,75 +6,35 @@ Author: Leonardo de Moura */ #include #include +#include #include "kernel/replace_fn.h" -#include "kernel/cache_stack.h" - -#ifndef LEAN_DEFAULT_REPLACE_CACHE_CAPACITY -#define LEAN_DEFAULT_REPLACE_CACHE_CAPACITY 1024*8 -#endif namespace lean { -struct replace_cache { - struct entry { - object * m_cell; - unsigned m_offset; - expr m_result; - entry():m_cell(nullptr) {} - }; - unsigned m_capacity; - std::vector m_cache; - std::vector m_used; - replace_cache(unsigned c):m_capacity(c), m_cache(c) {} - - expr * find(expr const & e, unsigned offset) { - unsigned i = hash(hash(e), offset) % m_capacity; - if (m_cache[i].m_cell == e.raw() && m_cache[i].m_offset == offset) - return &m_cache[i].m_result; - else - return nullptr; - } - - void insert(expr const & e, unsigned offset, expr const & v) { - unsigned i = hash(hash(e), offset) % m_capacity; - if (m_cache[i].m_cell == nullptr) - m_used.push_back(i); - m_cache[i].m_cell = e.raw(); - m_cache[i].m_offset = offset; - m_cache[i].m_result = v; - } - - void clear() { - for (unsigned i : m_used) { - m_cache[i].m_cell = nullptr; - m_cache[i].m_result = expr(); - } - m_used.clear(); - } -}; - -/* CACHE_RESET: NO */ -MK_CACHE_STACK(replace_cache, LEAN_DEFAULT_REPLACE_CACHE_CAPACITY) class replace_rec_fn { - replace_cache_ref m_cache; + struct key_hasher { + std::size_t operator()(std::pair const & p) const { + return hash((size_t)p.first, p.second); + } + }; + std::unordered_map, expr, key_hasher> m_cache; std::function(expr const &, unsigned)> m_f; bool m_use_cache; expr save_result(expr const & e, unsigned offset, expr const & r, bool shared) { if (shared) - m_cache->insert(e, offset, r); + m_cache.insert(mk_pair(mk_pair(e.raw(), offset), r)); return r; } expr apply(expr const & e, unsigned offset) { bool shared = false; if (m_use_cache && is_shared(e)) { - if (auto r = m_cache->find(e, offset)) - return *r; + auto it = m_cache.find(mk_pair(e.raw(), offset)); + if (it != m_cache.end()) + return it->second; shared = true; } - check_system("replace"); - if (optional r = m_f(e, offset)) { return save_result(e, offset, *r, shared); } else { @@ -121,4 +81,73 @@ class replace_rec_fn { expr replace(expr const & e, std::function(expr const &, unsigned)> const & f, bool use_cache) { return replace_rec_fn(f, use_cache)(e); } + +class replace_fn { + std::unordered_map m_cache; + lean_object * m_f; + + expr save_result(expr const & e, expr const & r, bool shared) { + if (shared) + m_cache.insert(mk_pair(e.raw(), r)); + return r; + } + + expr apply(expr const & e) { + bool shared = false; + if (is_shared(e)) { + auto it = m_cache.find(e.raw()); + if (it != m_cache.end()) + return it->second; + shared = true; + } + + lean_inc(e.raw()); + lean_inc_ref(m_f); + lean_object * r = lean_apply_1(m_f, e.raw()); + if (!lean_is_scalar(r)) { + expr e_new(lean_ctor_get(r, 0)); + lean_dec_ref(r); + return save_result(e, e_new, shared); + } + + switch (e.kind()) { + case expr_kind::Const: case expr_kind::Sort: + case expr_kind::BVar: case expr_kind::Lit: + case expr_kind::MVar: case expr_kind::FVar: + return save_result(e, e, shared); + case expr_kind::MData: { + expr new_e = apply(mdata_expr(e)); + return save_result(e, update_mdata(e, new_e), shared); + } + case expr_kind::Proj: { + expr new_e = apply(proj_expr(e)); + return save_result(e, update_proj(e, new_e), shared); + } + case expr_kind::App: { + expr new_f = apply(app_fn(e)); + expr new_a = apply(app_arg(e)); + return save_result(e, update_app(e, new_f, new_a), shared); + } + case expr_kind::Pi: case expr_kind::Lambda: { + expr new_d = apply(binding_domain(e)); + expr new_b = apply(binding_body(e)); + return save_result(e, update_binding(e, new_d, new_b), shared); + } + case expr_kind::Let: { + expr new_t = apply(let_type(e)); + expr new_v = apply(let_value(e)); + expr new_b = apply(let_body(e)); + return save_result(e, update_let(e, new_t, new_v, new_b), shared); + }} + lean_unreachable(); + } +public: + replace_fn(lean_object * f):m_f(f) {} + expr operator()(expr const & e) { return apply(e); } +}; + +extern "C" LEAN_EXPORT obj_res lean_replace_expr(b_obj_arg f, b_obj_arg e) { + expr r = replace_fn(f)(TO_REF(expr, e)); + return r.steal(); +} }