Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: precise cache for foldConsts #4871

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/Lean/Util/FindExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
import Lean.Util.PtrSet

namespace Lean
namespace Expr
Expand Down
63 changes: 23 additions & 40 deletions src/Lean/Util/FoldConsts.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,52 +11,35 @@ namespace Lean
namespace Expr
namespace FoldConstsImpl

abbrev cacheSize : USize := 8192 - 1
unsafe structure State where
visited : PtrSet Expr := mkPtrSet
visitedConsts : NameHashSet := {}

structure State where
visitedTerms : Array Expr -- Remark: cache based on pointer address. Our "unsafe" implementation relies on the fact that `()` is not a valid Expr
visitedConsts : NameHashSet -- cache based on structural equality
unsafe abbrev FoldM := StateM State

abbrev FoldM := StateM State

unsafe def visited (e : Expr) (size : USize) : FoldM Bool := do
let s ← get
let h := ptrAddrUnsafe e
let i := h % size
let k := s.visitedTerms.uget i lcProof
if ptrAddrUnsafe k == h then pure true
else do
modify fun s => { s with visitedTerms := s.visitedTerms.uset i e lcProof }
pure false

unsafe def fold {α : Type} (f : Name → α → α) (size : USize) (e : Expr) (acc : α) : FoldM α :=
unsafe def fold {α : Type} (f : Name → α → α) (e : Expr) (acc : α) : FoldM α :=
let rec visit (e : Expr) (acc : α) : FoldM α := do
if (← visited e size) then
pure acc
else
match e with
| Expr.forallE _ d b _ => visit b (← visit d acc)
| Expr.lam _ d b _ => visit b (← visit d acc)
| Expr.mdata _ b => visit b acc
| Expr.letE _ t v b _ => visit b (← visit v (← visit t acc))
| Expr.app f a => visit a (← visit f acc)
| Expr.proj _ _ b => visit b acc
| Expr.const c _ =>
let s ← get
if s.visitedConsts.contains c then
pure acc
else do
modify fun s => { s with visitedConsts := s.visitedConsts.insert c };
pure $ f c acc
| _ => pure acc
if (← get).visited.contains e then
return acc
modify fun s => { s with visited := s.visited.insert e }
match e with
| .forallE _ d b _ => visit b (← visit d acc)
| .lam _ d b _ => visit b (← visit d acc)
| .mdata _ b => visit b acc
| .letE _ t v b _ => visit b (← visit v (← visit t acc))
| .app f a => visit a (← visit f acc)
| .proj _ _ b => visit b acc
| .const c _ =>
if (← get).visitedConsts.contains c then
return acc
else
modify fun s => { s with visitedConsts := s.visitedConsts.insert c };
return f c acc
| _ => return acc
visit e acc

unsafe def initCache : State :=
{ visitedTerms := mkArray cacheSize.toNat (cast lcProof ()),
visitedConsts := {} }

@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → α) : α :=
(fold f cacheSize e init).run' initCache
(fold f e init).run' {}

end FoldConstsImpl

Expand Down
Loading