Skip to content

Commit 1e9d96b

Browse files
authored
perf: add lean_instantiate_level_mvars (#4910)
The new code is not active yet because of bootstrapping issues. It requires an `update_stage0`.
1 parent 647a5e9 commit 1e9d96b

File tree

6 files changed

+132
-3
lines changed

6 files changed

+132
-3
lines changed

src/Lean/MetavarContext.lean

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ structure MetavarContext where
336336
For more information about delayed abstraction, see the docstring for `DelayedMetavarAssignment`. -/
337337
dAssignment : PersistentHashMap MVarId DelayedMetavarAssignment := {}
338338

339+
instance : Inhabited MetavarContext := ⟨{}⟩
340+
339341
/-- A monad with a stateful metavariable context, defining `getMCtx` and `modifyMCtx`. -/
340342
class MonadMCtx (m : TypeType) where
341343
getMCtx : m MetavarContext
@@ -358,15 +360,27 @@ abbrev setMCtx [MonadMCtx m] (mctx : MetavarContext) : m Unit :=
358360
abbrev getLevelMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : LMVarId) : m (Option Level) :=
359361
return (← getMCtx).lAssignment.find? mvarId
360362

363+
@[export lean_get_lmvar_assignment]
364+
def getLevelMVarAssignmentExp (m : MetavarContext) (mvarId : LMVarId) : Option Level :=
365+
m.lAssignment.find? mvarId
366+
361367
def MetavarContext.getExprAssignmentCore? (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
362368
m.eAssignment.find? mvarId
363369

370+
@[export lean_get_mvar_assignment]
371+
def MetavarContext.getExprAssignmentExp (m : MetavarContext) (mvarId : MVarId) : Option Expr :=
372+
m.eAssignment.find? mvarId
373+
364374
def getExprMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option Expr) :=
365375
return (← getMCtx).getExprAssignmentCore? mvarId
366376

367377
def MetavarContext.getDelayedMVarAssignmentCore? (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
368378
mctx.dAssignment.find? mvarId
369379

380+
@[export lean_get_delayed_mvar_assignment]
381+
def MetavarContext.getDelayedMVarAssignmentExp (mctx : MetavarContext) (mvarId : MVarId) : Option DelayedMetavarAssignment :=
382+
mctx.dAssignment.find? mvarId
383+
370384
def getDelayedMVarAssignment? [Monad m] [MonadMCtx m] (mvarId : MVarId) : m (Option DelayedMetavarAssignment) :=
371385
return (← getMCtx).getDelayedMVarAssignmentCore? mvarId
372386

@@ -478,6 +492,10 @@ def hasAssignableMVar [Monad m] [MonadMCtx m] : Expr → m Bool
478492
def assignLevelMVar [MonadMCtx m] (mvarId : LMVarId) (val : Level) : m Unit :=
479493
modifyMCtx fun m => { m with lAssignment := m.lAssignment.insert mvarId val }
480494

495+
@[export lean_assign_lmvar]
496+
def assignLevelMVarExp (m : MetavarContext) (mvarId : LMVarId) (val : Level) : MetavarContext :=
497+
{ m with lAssignment := m.lAssignment.insert mvarId val }
498+
481499
/--
482500
Add `mvarId := x` to the metavariable assignment.
483501
This method does not check whether `mvarId` is already assigned, nor it checks whether
@@ -487,6 +505,10 @@ This is a low-level API, and it is safer to use `isDefEq (mkMVar mvarId) x`.
487505
def _root_.Lean.MVarId.assign [MonadMCtx m] (mvarId : MVarId) (val : Expr) : m Unit :=
488506
modifyMCtx fun m => { m with eAssignment := m.eAssignment.insert mvarId val }
489507

508+
@[export lean_assign_mvar]
509+
def assignExp (m : MetavarContext) (mvarId : MVarId) (val : Expr) : MetavarContext :=
510+
{ m with eAssignment := m.eAssignment.insert mvarId val }
511+
490512
/--
491513
Add a delayed assignment for the given metavariable. You must make sure that
492514
the metavariable is not already assigned or delayed-assigned.
@@ -516,6 +538,9 @@ To avoid this term eta-expanded term, we apply beta-reduction when instantiating
516538
This operation is performed at `instantiateExprMVars`, `elimMVarDeps`, and `levelMVarToParam`.
517539
-/
518540

541+
@[extern "lean_instantiate_level_mvars"]
542+
opaque instantiateLevelMVarsImp (mctx : MetavarContext) (l : Level) : MetavarContext × Level
543+
519544
partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
520545
| lvl@(Level.succ lvl₁) => return Level.updateSucc! lvl (← instantiateLevelMVars lvl₁)
521546
| lvl@(Level.max lvl₁ lvl₂) => return Level.updateMax! lvl (← instantiateLevelMVars lvl₁) (← instantiateLevelMVars lvl₂)
@@ -531,6 +556,9 @@ partial def instantiateLevelMVars [Monad m] [MonadMCtx m] : Level → m Level
531556
| none => pure lvl
532557
| lvl => pure lvl
533558

559+
@[extern "lean_instantiate_expr_mvars"]
560+
opaque instantiateExprMVarsImp (mctx : MetavarContext) (e : Expr) : MetavarContext × Expr
561+
534562
/-- instantiateExprMVars main function -/
535563
partial def instantiateExprMVars [Monad m] [MonadMCtx m] [STWorld ω m] [MonadLiftT (ST ω) m] (e : Expr) : MonadCacheT ExprStructEq Expr m Expr :=
536564
if !e.hasMVar then
@@ -792,8 +820,6 @@ def localDeclDependsOnPred [Monad m] [MonadMCtx m] (localDecl : LocalDecl) (pf :
792820

793821
namespace MetavarContext
794822

795-
instance : Inhabited MetavarContext := ⟨{}⟩
796-
797823
@[export lean_mk_metavar_ctx]
798824
def mkMetavarContext : Unit → MetavarContext := fun _ => {}
799825

src/kernel/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ add_library(kernel OBJECT level.cpp expr.cpp expr_eq_fn.cpp
22
for_each_fn.cpp replace_fn.cpp abstract.cpp instantiate.cpp
33
local_ctx.cpp declaration.cpp environment.cpp type_checker.cpp
44
init_module.cpp expr_cache.cpp equiv_manager.cpp quot.cpp
5-
inductive.cpp trace.cpp)
5+
inductive.cpp trace.cpp instantiate_mvars.cpp)

src/kernel/instantiate_mvars.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
5+
Authors: Leonardo de Moura
6+
*/
7+
#include <unordered_map>
8+
#include "runtime/option_ref.h"
9+
#include "kernel/instantiate.h"
10+
#include "kernel/abstract.h"
11+
12+
/*
13+
This module is not used by the kernel. It just provides an efficient implementation of
14+
`instantiateExprMVars`
15+
*/
16+
17+
namespace lean {
18+
19+
extern "C" object * lean_get_lmvar_assignment(obj_arg mctx, obj_arg mid);
20+
extern "C" object * lean_assign_lmvar(obj_arg mctx, obj_arg mid, obj_arg val);
21+
22+
typedef object_ref metavar_ctx;
23+
void assign_lmvar(metavar_ctx & mctx, name const & mid, level const & l) {
24+
object * r = lean_assign_lmvar(mctx.steal(), mid.to_obj_arg(), l.to_obj_arg());
25+
mctx.set_box(r);
26+
}
27+
28+
option_ref<level> get_lmvar_assignment(metavar_ctx & mctx, name const & mid) {
29+
return option_ref<level>(lean_get_lmvar_assignment(mctx.to_obj_arg(), mid.to_obj_arg()));
30+
}
31+
32+
class instantiate_lmvar_fn {
33+
metavar_ctx & m_mctx;
34+
std::unordered_map<lean_object *, lean_object *> m_cache;
35+
36+
inline level cache(level const & l, level && r, bool shared) {
37+
if (shared) {
38+
m_cache.insert(mk_pair(l.raw(), r.raw()));
39+
}
40+
return r;
41+
}
42+
public:
43+
instantiate_lmvar_fn(metavar_ctx & mctx):m_mctx(mctx) {}
44+
level visit(level const & l) {
45+
if (!has_mvar(l))
46+
return l;
47+
bool shared = false;
48+
if (is_shared(l)) {
49+
auto it = m_cache.find(l.raw());
50+
if (it != m_cache.end()) {
51+
return level(it->second, true);
52+
}
53+
shared = true;
54+
}
55+
switch (l.kind()) {
56+
case level_kind::Succ:
57+
return cache(l, update_succ(l, visit(succ_of(l))), shared);
58+
case level_kind::Max: case level_kind::IMax:
59+
return cache(l, update_max(l, visit(level_lhs(l)), visit(level_rhs(l))), shared);
60+
case level_kind::Zero: case level_kind::Param:
61+
lean_unreachable();
62+
case level_kind::MVar: {
63+
option_ref<level> r = get_lmvar_assignment(m_mctx, mvar_id(l));
64+
if (!r) {
65+
return l;
66+
} else {
67+
level a(r.get_val());
68+
if (!has_mvar(a)) {
69+
return a;
70+
} else {
71+
level a_new = visit(a);
72+
if (!is_eqp(a, a_new)) {
73+
assign_lmvar(m_mctx, mvar_id(l), a_new);
74+
}
75+
return a_new;
76+
}
77+
}
78+
}}
79+
}
80+
level operator()(level const & l) { return visit(l); }
81+
};
82+
83+
extern "C" LEAN_EXPORT object * lean_instantiate_level_mvars(object * m, object * l) {
84+
metavar_ctx mctx(m);
85+
level l_new = instantiate_lmvar_fn(mctx)(level(l));
86+
object * r = alloc_cnstr(0, 2, 0);
87+
cnstr_set(r, 0, mctx.steal());
88+
cnstr_set(r, 1, l_new.steal());
89+
return r;
90+
}
91+
92+
extern "C" LEAN_EXPORT object * lean_instantiate_expr_mvars(object *, object *) {
93+
lean_internal_panic("not implemented yet");
94+
}
95+
}

src/kernel/level.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ inline bool operator!=(level const & l1, level const & l2) { return !operator==(
8282
struct level_hash { unsigned operator()(level const & n) const { return n.hash(); } };
8383
struct level_eq { bool operator()(level const & n1, level const & n2) const { return n1 == n2; } };
8484

85+
inline bool is_shared(level const & l) { return !is_exclusive(l.raw()); }
86+
8587
inline optional<level> none_level() { return optional<level>(); }
8688
inline optional<level> some_level(level const & e) { return optional<level>(e); }
8789
inline optional<level> some_level(level && e) { return optional<level>(std::forward<level>(e)); }

src/runtime/object_ref.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ class object_ref {
3535
s.m_obj = box(0);
3636
return *this;
3737
}
38+
void set_box(object * o) {
39+
lean_assert(is_scalar(m_obj));
40+
m_obj = o;
41+
}
3842
object * raw() const { return m_obj; }
3943
object * steal() { object * r = m_obj; m_obj = box(0); return r; }
4044
object * to_obj_arg() const { inc(m_obj); return m_obj; }

src/runtime/option_ref.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class option_ref : public object_ref {
2828
explicit operator bool() const { return !is_scalar(raw()); }
2929
optional<T> get() const { return *this ? some(static_cast<T const &>(cnstr_get_ref(*this, 0))) : optional<T>(); }
3030

31+
T get_val() const { lean_assert(*this); return static_cast<T const &>(cnstr_get_ref(*this, 0)); }
32+
3133
/** \brief Structural equality. */
3234
friend bool operator==(option_ref const & o1, option_ref const & o2) {
3335
return o1.get() == o2.get();

0 commit comments

Comments
 (0)