Skip to content

Commit a0f73c3

Browse files
committed
feat: theorem patterns for heuristic instantiation in grind
This PR implements the command `grind_pattern`. The new command allows users to associate patterns with theorems. These patterns are used for performing heuristic instantiation with e-matching. In the future, we will add the attributes `@[grind_eq]`, `@[grind_fwd]`, and `@[grind_bwd]` to compute the patterns automatically for theorems.
1 parent 536c6a8 commit a0f73c3

File tree

4 files changed

+213
-0
lines changed

4 files changed

+213
-0
lines changed

src/Lean/Elab/Tactic/Grind.lean

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,29 @@ Authors: Leonardo de Moura
66
prelude
77
import Init.Grind.Tactics
88
import Lean.Meta.Tactic.Grind
9+
import Lean.Elab.Command
910
import Lean.Elab.Tactic.Basic
1011

12+
1113
namespace Lean.Elab.Tactic
1214
open Meta
1315

16+
open Command Term in
17+
@[builtin_command_elab Lean.Parser.Command.grindPattern]
18+
def elabGrindPattern : CommandElab := fun stx => do
19+
match stx with
20+
| `(grind_pattern $thmName:ident => $terms,*) => do
21+
liftTermElabM do
22+
let declName ← resolveGlobalConstNoOverload thmName
23+
let info ← getConstInfo declName
24+
forallTelescope info.type fun xs _ => do
25+
let patterns ← terms.getElems.mapM fun term => do
26+
let pattern ← instantiateMVars (← elabTerm term none)
27+
let pattern ← Grind.unfoldReducible pattern
28+
return pattern.abstract xs
29+
Grind.addTheoremPattern declName xs.size patterns.toList
30+
| _ => throwUnsupportedSyntax
31+
1432
def grind (mvarId : MVarId) (mainDeclName : Name) : MetaM Unit := do
1533
let mvarIds ← Grind.main mvarId mainDeclName
1634
unless mvarIds.isEmpty do

src/Lean/Meta/Tactic/Grind.lean

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import Lean.Meta.Tactic.Grind.PP
2121
import Lean.Meta.Tactic.Grind.Simp
2222
import Lean.Meta.Tactic.Grind.Ctor
2323
import Lean.Meta.Tactic.Grind.Parser
24+
import Lean.Meta.Tactic.Grind.TheoremPatterns
2425

2526
namespace Lean
2627

@@ -35,5 +36,6 @@ builtin_initialize registerTraceClass `grind.simp
3536
builtin_initialize registerTraceClass `grind.congr
3637
builtin_initialize registerTraceClass `grind.proof
3738
builtin_initialize registerTraceClass `grind.proof.detail
39+
builtin_initialize registerTraceClass `grind.pattern
3840

3941
end Lean
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/-
2+
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
Released under Apache 2.0 license as described in the file LICENSE.
4+
Authors: Leonardo de Moura
5+
-/
6+
prelude
7+
import Lean.HeadIndex
8+
import Lean.Util.FoldConsts
9+
import Lean.Meta.Basic
10+
import Lean.Meta.InferType
11+
12+
namespace Lean.Meta.Grind
13+
14+
inductive Origin where
15+
/-- A global declaration in the environment. -/
16+
| decl (declName : Name)
17+
/-- A local hypothesis. -/
18+
| fvar (fvarId : FVarId)
19+
/--
20+
A proof term provided directly to a call to `grind` where `ref`
21+
is the provided grind argument. The `id` is a unique identifier for the call.
22+
-/
23+
| stx (id : Name) (ref : Syntax)
24+
| other
25+
deriving Inhabited, Repr
26+
27+
structure TheoremPattern where
28+
proof : Expr
29+
numParams : Nat
30+
patterns : List Expr
31+
/-- Contains all symbols used in `pattterns`. -/
32+
symbols : List HeadIndex
33+
origin : Origin
34+
deriving Inhabited
35+
36+
abbrev TheoremPatterns := SMap Name (List TheoremPattern)
37+
38+
builtin_initialize theoremPatternsExt : SimpleScopedEnvExtension TheoremPattern TheoremPatterns ←
39+
registerSimpleScopedEnvExtension {
40+
addEntry := fun s t => Id.run do
41+
let .const declName :: _ := t.symbols | unreachable!
42+
if let some ts := s.find? declName then
43+
s.insert declName (t::ts)
44+
else
45+
s.insert declName [t]
46+
initial := .empty
47+
}
48+
49+
-- TODO: create attribute?
50+
private def forbiddenDeclNames := #[``Eq, ``HEq, ``Iff, ``And, ``Or, ``Not]
51+
52+
private def isForbidden (declName : Name) := forbiddenDeclNames.contains declName
53+
54+
private def dontCare := mkConst (Name.mkSimple "[grind_dontcare]")
55+
56+
private def mkGroundPattern (e : Expr) : Expr :=
57+
mkAnnotation `grind.ground_pat e
58+
59+
private def groundPattern? (e : Expr) : Option Expr :=
60+
annotation? `grind.ground_pat e
61+
62+
private def isGroundPattern (e : Expr) : Bool :=
63+
groundPattern? e |>.isSome
64+
65+
private def isAtomicPattern (e : Expr) : Bool :=
66+
e.isBVar || e == dontCare || isGroundPattern e
67+
68+
partial def ppPattern (pattern : Expr) : MessageData := Id.run do
69+
if let some e := groundPattern? pattern then
70+
return m!"`[{e}]"
71+
else if pattern == dontCare then
72+
return m!"?"
73+
else match pattern with
74+
| .bvar idx => return m!"#{idx}"
75+
| _ =>
76+
let mut r := m!"{pattern.getAppFn}"
77+
for arg in pattern.getAppArgs do
78+
let mut argFmt ← ppPattern arg
79+
if !isAtomicPattern arg then
80+
argFmt := MessageData.paren argFmt
81+
r := r ++ " " ++ argFmt
82+
return r
83+
84+
namespace NormalizePattern
85+
86+
structure State where
87+
symbols : Array HeadIndex := #[]
88+
symbolSet : Std.HashSet HeadIndex := {}
89+
bvarsFound : Std.HashSet Nat := {}
90+
91+
abbrev M := StateRefT State MetaM
92+
93+
private def saveSymbol (h : HeadIndex) : M Unit := do
94+
unless (← get).symbolSet.contains h do
95+
modify fun s => { s with symbols := s.symbols.push h, symbolSet := s.symbolSet.insert h }
96+
97+
private def foundBVar (idx : Nat) : M Bool :=
98+
return (← get).bvarsFound.contains idx
99+
100+
private def saveBVar (idx : Nat) : M Unit := do
101+
modify fun s => { s with bvarsFound := s.bvarsFound.insert idx }
102+
103+
private def getPatternFn? (pattern : Expr) : Option Expr :=
104+
if !pattern.isApp then
105+
none
106+
else match pattern.getAppFn with
107+
| f@(.const declName _) => if isForbidden declName then none else some f
108+
| f@(.fvar _) => some f
109+
| _ => none
110+
111+
private structure PatternFunInfo where
112+
instImplicitMask : Array Bool
113+
typeMask : Array Bool
114+
115+
private def getPatternFunInfo (f : Expr) (numArgs : Nat) : MetaM PatternFunInfo := do
116+
forallBoundedTelescope (← inferType f) numArgs fun xs _ => do
117+
let typeMask ← xs.mapM fun x => isTypeFormer x
118+
let instImplicitMask ← xs.mapM fun x => return (← x.fvarId!.getDecl).binderInfo matches .instImplicit
119+
return { typeMask, instImplicitMask }
120+
121+
private partial def go (pattern : Expr) (root := false) : M Expr := do
122+
if root && !pattern.hasLooseBVars then
123+
throwError "invalid pattern, it does not have pattern variables"
124+
let some f := getPatternFn? pattern
125+
| throwError "invalid pattern, (non-forbidden) application expected"
126+
assert! f.isConst || f.isFVar
127+
saveSymbol f.toHeadIndex
128+
let mut args := pattern.getAppArgs
129+
let { instImplicitMask, typeMask } ← getPatternFunInfo f args.size
130+
for i in [:args.size] do
131+
let arg := args[i]!
132+
let isType := typeMask[i]?.getD false
133+
let isInstImplicit := instImplicitMask[i]?.getD false
134+
let arg ← if !arg.hasLooseBVars then
135+
if arg.hasMVar then
136+
pure dontCare
137+
else
138+
pure <| mkGroundPattern arg
139+
else match arg with
140+
| .bvar idx =>
141+
if (isType || isInstImplicit) && (← foundBVar idx) then
142+
pure dontCare
143+
else
144+
saveBVar idx
145+
pure arg
146+
| _ =>
147+
if isType || isInstImplicit then
148+
pure dontCare
149+
else if let some _ := getPatternFn? arg then
150+
go arg
151+
else
152+
pure dontCare
153+
args := args.set! i arg
154+
return mkAppN f args
155+
156+
def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex) := do
157+
let (patterns, s) ← patterns.mapM go |>.run {}
158+
return (patterns, s.symbols.toList)
159+
160+
end NormalizePattern
161+
162+
def addTheoremPattern (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
163+
let cinfo ← getConstInfo declName
164+
let us := cinfo.levelParams.map mkLevelParam
165+
let proof := mkConst declName us
166+
let (patterns, symbols) ← NormalizePattern.main patterns
167+
trace[grind.pattern] "{declName}: {patterns.map ppPattern}"
168+
theoremPatternsExt.add {
169+
proof, patterns, numParams, symbols
170+
origin := .decl declName
171+
}
172+
173+
end Lean.Meta.Grind

tests/lean/run/grind_pattern1.lean

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
set_option trace.grind.pattern true
2+
3+
/--
4+
info: [grind.pattern] Array.getElem_push_lt: [@getElem ? `[Nat] #4 ? ? (@Array.push ? #3 #2) #1 ?]
5+
-/
6+
#guard_msgs in
7+
grind_pattern Array.getElem_push_lt => (a.push x)[i]
8+
9+
10+
/--
11+
info: [grind.pattern] List.getElem_attach: [@getElem ? `[Nat] ? ? ? (@List.attach #3 #2) #1 ?]
12+
-/
13+
#guard_msgs in
14+
grind_pattern List.getElem_attach => xs.attach[i]
15+
16+
/--
17+
info: [grind.pattern] List.mem_concat_self: [@Membership.mem #2 ? ? (@HAppend.hAppend ? ? ? ? #1 (@List.cons ? #0 (@List.nil ?))) #0]
18+
-/
19+
#guard_msgs in
20+
grind_pattern List.mem_concat_self => a ∈ xs ++ [a]

0 commit comments

Comments
 (0)