From afe0b5a01326a99fb18049d29d7f62a5ffc83c15 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 30 Jul 2024 20:35:45 +0200 Subject: [PATCH] perf: precise cache for `foldConsts` (#4871) It addresses a performance issue at https://github.com/leanprover/LNSym/blob/proof_size_expt/Proofs/SHA512/Experiments/Sym20.lean --- src/Lean/Util/FindExpr.lean | 1 - src/Lean/Util/FoldConsts.lean | 63 +++++++++++++---------------------- 2 files changed, 23 insertions(+), 41 deletions(-) diff --git a/src/Lean/Util/FindExpr.lean b/src/Lean/Util/FindExpr.lean index 32795ec30ca5..82bde909eaf0 100644 --- a/src/Lean/Util/FindExpr.lean +++ b/src/Lean/Util/FindExpr.lean @@ -5,7 +5,6 @@ Authors: Leonardo de Moura -/ prelude import Lean.Expr -import Lean.Util.PtrSet namespace Lean namespace Expr diff --git a/src/Lean/Util/FoldConsts.lean b/src/Lean/Util/FoldConsts.lean index e01c0d9a8bc0..cd1608bf652e 100644 --- a/src/Lean/Util/FoldConsts.lean +++ b/src/Lean/Util/FoldConsts.lean @@ -11,52 +11,35 @@ namespace Lean namespace Expr namespace FoldConstsImpl -abbrev cacheSize : USize := 8192 - 1 +unsafe structure State where + visited : PtrSet Expr := mkPtrSet + visitedConsts : NameHashSet := {} -structure State where - visitedTerms : Array Expr -- Remark: cache based on pointer address. Our "unsafe" implementation relies on the fact that `()` is not a valid Expr - visitedConsts : NameHashSet -- cache based on structural equality +unsafe abbrev FoldM := StateM State -abbrev FoldM := StateM State - -unsafe def visited (e : Expr) (size : USize) : FoldM Bool := do - let s ← get - let h := ptrAddrUnsafe e - let i := h % size - let k := s.visitedTerms.uget i lcProof - if ptrAddrUnsafe k == h then pure true - else do - modify fun s => { s with visitedTerms := s.visitedTerms.uset i e lcProof } - pure false - -unsafe def fold {α : Type} (f : Name → α → α) (size : USize) (e : Expr) (acc : α) : FoldM α := +unsafe def fold {α : Type} (f : Name → α → α) (e : Expr) (acc : α) : FoldM α := let rec visit (e : Expr) (acc : α) : FoldM α := do - if (← visited e size) then - pure acc - else - match e with - | Expr.forallE _ d b _ => visit b (← visit d acc) - | Expr.lam _ d b _ => visit b (← visit d acc) - | Expr.mdata _ b => visit b acc - | Expr.letE _ t v b _ => visit b (← visit v (← visit t acc)) - | Expr.app f a => visit a (← visit f acc) - | Expr.proj _ _ b => visit b acc - | Expr.const c _ => - let s ← get - if s.visitedConsts.contains c then - pure acc - else do - modify fun s => { s with visitedConsts := s.visitedConsts.insert c }; - pure $ f c acc - | _ => pure acc + if (← get).visited.contains e then + return acc + modify fun s => { s with visited := s.visited.insert e } + match e with + | .forallE _ d b _ => visit b (← visit d acc) + | .lam _ d b _ => visit b (← visit d acc) + | .mdata _ b => visit b acc + | .letE _ t v b _ => visit b (← visit v (← visit t acc)) + | .app f a => visit a (← visit f acc) + | .proj _ _ b => visit b acc + | .const c _ => + if (← get).visitedConsts.contains c then + return acc + else + modify fun s => { s with visitedConsts := s.visitedConsts.insert c }; + return f c acc + | _ => return acc visit e acc -unsafe def initCache : State := - { visitedTerms := mkArray cacheSize.toNat (cast lcProof ()), - visitedConsts := {} } - @[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name → α → α) : α := - (fold f cacheSize e init).run' initCache + (fold f e init).run' {} end FoldConstsImpl