Skip to content

Commit 0a86d2a

Browse files
committed
perf: add replaceImpl
It uses the kernel implementation. We will replace `Expr.replace` with it after update-stage0
1 parent 411cb48 commit 0a86d2a

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

src/Lean/Util/ReplaceExpr.lean

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@ def replaceNoCache (f? : Expr → Option Expr) (e : Expr) : Expr :=
7777
| .proj _ _ b => let b := replaceNoCache f? b; e.updateProj! b
7878
| e => e
7979

80+
81+
@[extern "lean_replace_expr"]
82+
opaque replaceImpl (f? : @& (Expr → Option Expr)) (e : @& Expr) : Expr
83+
8084
@[implemented_by ReplaceImpl.replaceUnsafe]
81-
partial def replace (f? : Expr → Option Expr) (e : Expr) : Expr :=
85+
def replace (f? : Expr → Option Expr) (e : Expr) : Expr :=
8286
e.replaceNoCache f?

src/kernel/replace_fn.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,73 @@ class replace_rec_fn {
8181
expr replace(expr const & e, std::function<optional<expr>(expr const &, unsigned)> const & f, bool use_cache) {
8282
return replace_rec_fn(f, use_cache)(e);
8383
}
84+
85+
class replace_fn {
86+
std::unordered_map<lean_object *, expr> m_cache;
87+
lean_object * m_f;
88+
89+
expr save_result(expr const & e, expr const & r, bool shared) {
90+
if (shared)
91+
m_cache.insert(mk_pair(e.raw(), r));
92+
return r;
93+
}
94+
95+
expr apply(expr const & e) {
96+
bool shared = false;
97+
if (is_shared(e)) {
98+
auto it = m_cache.find(e.raw());
99+
if (it != m_cache.end())
100+
return it->second;
101+
shared = true;
102+
}
103+
104+
lean_inc(e.raw());
105+
lean_inc_ref(m_f);
106+
lean_object * r = lean_apply_1(m_f, e.raw());
107+
if (!lean_is_scalar(r)) {
108+
expr e_new(lean_ctor_get(r, 0));
109+
lean_dec_ref(r);
110+
return save_result(e, e_new, shared);
111+
}
112+
113+
switch (e.kind()) {
114+
case expr_kind::Const: case expr_kind::Sort:
115+
case expr_kind::BVar: case expr_kind::Lit:
116+
case expr_kind::MVar: case expr_kind::FVar:
117+
return save_result(e, e, shared);
118+
case expr_kind::MData: {
119+
expr new_e = apply(mdata_expr(e));
120+
return save_result(e, update_mdata(e, new_e), shared);
121+
}
122+
case expr_kind::Proj: {
123+
expr new_e = apply(proj_expr(e));
124+
return save_result(e, update_proj(e, new_e), shared);
125+
}
126+
case expr_kind::App: {
127+
expr new_f = apply(app_fn(e));
128+
expr new_a = apply(app_arg(e));
129+
return save_result(e, update_app(e, new_f, new_a), shared);
130+
}
131+
case expr_kind::Pi: case expr_kind::Lambda: {
132+
expr new_d = apply(binding_domain(e));
133+
expr new_b = apply(binding_body(e));
134+
return save_result(e, update_binding(e, new_d, new_b), shared);
135+
}
136+
case expr_kind::Let: {
137+
expr new_t = apply(let_type(e));
138+
expr new_v = apply(let_value(e));
139+
expr new_b = apply(let_body(e));
140+
return save_result(e, update_let(e, new_t, new_v, new_b), shared);
141+
}}
142+
lean_unreachable();
143+
}
144+
public:
145+
replace_fn(lean_object * f):m_f(f) {}
146+
expr operator()(expr const & e) { return apply(e); }
147+
};
148+
149+
extern "C" LEAN_EXPORT obj_res lean_replace_expr(b_obj_arg f, b_obj_arg e) {
150+
expr r = replace_fn(f)(TO_REF(expr, e));
151+
return r.steal();
152+
}
84153
}

0 commit comments

Comments
 (0)