Skip to content

Commit

Permalink
perf: precise cache for foldConsts (#4871)
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura authored Jul 30, 2024
1 parent 90dab5e commit afe0b5a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 41 deletions.
1 change: 0 additions & 1 deletion src/Lean/Util/FindExpr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
import Lean.Util.PtrSet

namespace Lean
namespace Expr
Expand Down
63 changes: 23 additions & 40 deletions src/Lean/Util/FoldConsts.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,52 +11,35 @@ namespace Lean
namespace Expr
namespace FoldConstsImpl

abbrev cacheSize : USize := 8192 - 1
unsafe structure State where
visited : PtrSet Expr := mkPtrSet
visitedConsts : NameHashSet := {}

structure State where
visitedTerms : Array Expr -- Remark: cache based on pointer address. Our "unsafe" implementation relies on the fact that `()` is not a valid Expr
visitedConsts : NameHashSet -- cache based on structural equality
unsafe abbrev FoldM := StateM State

abbrev FoldM := StateM State

unsafe def visited (e : Expr) (size : USize) : FoldM Bool := do
let s ← get
let h := ptrAddrUnsafe e
let i := h % size
let k := s.visitedTerms.uget i lcProof
if ptrAddrUnsafe k == h then pure true
else do
modify fun s => { s with visitedTerms := s.visitedTerms.uset i e lcProof }
pure false

unsafe def fold {α : Type} (f : Name → α → α) (size : USize) (e : Expr) (acc : α) : FoldM α :=
unsafe def fold {α : Type} (f : Name → α → α) (e : Expr) (acc : α) : FoldM α :=
let rec visit (e : Expr) (acc : α) : FoldM α := do
if (← visited e size) then
pure acc
else
match e with
| Expr.forallE _ d b _ => visit b (← visit d acc)
| Expr.lam _ d b _ => visit b (← visit d acc)
| Expr.mdata _ b => visit b acc
| Expr.letE _ t v b _ => visit b (← visit v (← visit t acc))
| Expr.app f a => visit a (← visit f acc)
| Expr.proj _ _ b => visit b acc
| Expr.const c _ =>
let s ← get
if s.visitedConsts.contains c then
pure acc
else do
modify fun s => { s with visitedConsts := s.visitedConsts.insert c };
pure $ f c acc
| _ => pure acc
if (← get).visited.contains e then
return acc
modify fun s => { s with visited := s.visited.insert e }
match e with
| .forallE _ d b _ => visit b (← visit d acc)
| .lam _ d b _ => visit b (← visit d acc)
| .mdata _ b => visit b acc
| .letE _ t v b _ => visit b (← visit v (← visit t acc))
| .app f a => visit a (← visit f acc)
| .proj _ _ b => visit b acc
| .const c _ =>
if (← get).visitedConsts.contains c then
return acc
else
modify fun s => { s with visitedConsts := s.visitedConsts.insert c };
return f c acc
| _ => return acc
visit e acc

unsafe def initCache : State :=
{ visitedTerms := mkArray cacheSize.toNat (cast lcProof ()),
visitedConsts := {} }

@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → α) : α :=
(fold f cacheSize e init).run' initCache
(fold f e init).run' {}

end FoldConstsImpl

Expand Down

0 comments on commit afe0b5a

Please sign in to comment.