Skip to content

Commit

Permalink
feat: check pattern coverage in grind_pattern
Browse files Browse the repository at this point in the history
This PR adds pattern validation to the `grind_pattern` command.
The new `checkCoverage` function will also be used to implement the attributes
`@[grind_eq]`, `@[grind_fwd]`, and `@[grind_bwd]`.
  • Loading branch information
leodemoura committed Dec 30, 2024
1 parent 4e2b0f5 commit 1338a86
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 3 deletions.
132 changes: 129 additions & 3 deletions src/Lean/Meta/Tactic/Grind/TheoremPatterns.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.HeadIndex
import Lean.Util.FoldConsts
import Lean.Util.CollectFVars
import Lean.Meta.Basic
import Lean.Meta.InferType

Expand Down Expand Up @@ -153,19 +154,144 @@ private partial def go (pattern : Expr) (root := false) : M Expr := do
args := args.set! i arg
return mkAppN f args

def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex) := do
def main (patterns : List Expr) : MetaM (List Expr × List HeadIndex × Std.HashSet Nat) := do
let (patterns, s) ← patterns.mapM go |>.run {}
return (patterns, s.symbols.toList)
return (patterns, s.symbols.toList, s.bvarsFound)

end NormalizePattern

/--
Returns `true` if free variables in `type` are not in `thmVars` or are in `fvarsFound`.
We use this function to check whether `type` is fully instantiated.
-/
private def checkTypeFVars (thmVars : FVarIdSet) (fvarsFound : FVarIdSet) (type : Expr) : Bool :=
let typeFVars := (collectFVars {} type).fvarIds
typeFVars.all fun fvarId => !thmVars.contains fvarId || fvarsFound.contains fvarId

/--
Given an type class instance type `instType`, returns true if free variables in input parameters
1- are not in `thmVars`, or
2- are in `fvarsFound`.
Remark: `fvarsFound` is a subset of `thmVars`
-/
private def canBeSynthesized (thmVars : FVarIdSet) (fvarsFound : FVarIdSet) (instType : Expr) : MetaM Bool := do
forallTelescopeReducing instType fun xs type => type.withApp fun classFn classArgs => do
for x in xs do
unless checkTypeFVars thmVars fvarsFound (← inferType x) do return false
forallBoundedTelescope (← inferType classFn) type.getAppNumArgs fun params _ => do
for param in params, classArg in classArgs do
let paramType ← inferType param
if !paramType.isAppOf ``semiOutParam && !paramType.isAppOf ``outParam then
unless checkTypeFVars thmVars fvarsFound classArg do
return false
return true

/--
Auxiliary type for the `checkCoverage` function.
-/
inductive CheckCoverageResult where
| /-- `checkCoverage` succeeded -/
ok
| /--
`checkCoverage` failed because some of the theorem parameters are missing,
`pos` contains their positions
-/
missing (pos : List Nat)

/--
After we process a set of patterns, we obtain the set of de Bruijn indices in these patterns.
We say they are pattern variables. This function checks whether the set of pattern variables is sufficient for
instantiating the theorem with proof `thmProof`. The theorem has `numParams` parameters.
The missing parameters:
1- we may be able to infer them using type inference or type class synthesis, or
2- they are propositions, and may become hypotheses of the instantiated theorem.
For type class instance parameters, we must check whether the free variables in class input parameters are available.
-/
private def checkCoverage (thmProof : Expr) (numParams : Nat) (bvarsFound : Std.HashSet Nat) : MetaM CheckCoverageResult := do
if bvarsFound.size == numParams then return .ok
forallBoundedTelescope (← inferType thmProof) numParams fun xs _ => do
assert! numParams == xs.size
let patternVars := bvarsFound.toList.map fun bidx => xs[numParams - bidx - 1]!.fvarId!
-- `xs` as a `FVarIdSet`.
let thmVars : FVarIdSet := RBTree.ofList <| xs.toList.map (·.fvarId!)
-- Collect free variables occurring in `e`, and insert the ones that are in `thmVars` into `fvarsFound`
let update (fvarsFound : FVarIdSet) (e : Expr) : FVarIdSet :=
(collectFVars {} e).fvarIds.foldl (init := fvarsFound) fun s fvarId =>
if thmVars.contains fvarId then s.insert fvarId else s
-- Theorem variables found so far. We initialize with the variables occurring in patterns
-- Remark: fvarsFound is a subset of thmVars
let mut fvarsFound : FVarIdSet := RBTree.ofList patternVars
for patternVar in patternVars do
let type ← patternVar.getType
fvarsFound := update fvarsFound type
if fvarsFound.size == numParams then return .ok
-- Now, we keep traversing remaining variables and collecting
-- `processed` contains the variables we have already processed.
let mut processed : FVarIdSet := RBTree.ofList patternVars
let mut modified := false
repeat
modified := false
for x in xs do
let fvarId := x.fvarId!
unless processed.contains fvarId do
let xType ← inferType x
if fvarsFound.contains fvarId then
-- Collect free vars in `x`s type and mark as processed
fvarsFound := update fvarsFound xType
processed := processed.insert fvarId
modified := true
else if (← isProp xType) then
-- If `x` is a proposition, and all theorem variables in `x`s type have already been found
-- add it to `fvarsFound` and mark it as processed.
if checkTypeFVars thmVars fvarsFound xType then
fvarsFound := fvarsFound.insert fvarId
processed := processed.insert fvarId
modified := true
else if (← fvarId.getDecl).binderInfo matches .instImplicit then
-- If `x` is instance implicit, check whether
-- we have found all free variables needed to synthesize instance
if (← canBeSynthesized thmVars fvarsFound xType) then
fvarsFound := fvarsFound.insert fvarId
fvarsFound := update fvarsFound xType
processed := processed.insert fvarId
modified := true
if fvarsFound.size == numParams then
return .ok
if !modified then
break
let mut pos := #[]
for h : i in [:xs.size] do
let fvarId := xs[i].fvarId!
unless fvarsFound.contains fvarId do
pos := pos.push i
return .missing pos.toList

/--
Given a theorem with proof `proof` and `numParams` parameters, returns a message
containing the parameters at positions `paramPos`.
-/
private def ppParamsAt (proof : Expr) (numParms : Nat) (paramPos : List Nat) : MetaM MessageData := do
forallBoundedTelescope (← inferType proof) numParms fun xs _ => do
let mut msg := m!""
let mut first := true
for h : i in [:xs.size] do
if paramPos.contains i then
let x := xs[i]
if first then first := false else msg := msg ++ "\n"
msg := msg ++ m!"{x} : {← inferType x}"
addMessageContextFull msg

def addTheoremPattern (declName : Name) (numParams : Nat) (patterns : List Expr) : MetaM Unit := do
let .thmInfo info ← getConstInfo declName
| throwError "`{declName}` is not a theorem, you cannot assign patterns to non-theorems for the `grind` tactic"
let us := info.levelParams.map mkLevelParam
let proof := mkConst declName us
let (patterns, symbols) ← NormalizePattern.main patterns
let (patterns, symbols, bvarFound) ← NormalizePattern.main patterns
trace[grind.pattern] "{declName}: {patterns.map ppPattern}"
if let .missing pos ← checkCoverage proof numParams bvarFound then
let pats : MessageData := m!"{patterns.map ppPattern}"
throwError "invalid pattern(s) for `{declName}`{indentD pats}\nthe following theorem parameters cannot be instantiated:{indentD (← ppParamsAt proof numParams pos)}"
theoremPatternsExt.add {
proof, patterns, numParams, symbols
origin := .decl declName
Expand Down
88 changes: 88 additions & 0 deletions tests/lean/run/grind_pattern1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,91 @@ error: `foo` is not a theorem, you cannot assign patterns to non-theorems for th
-/
#guard_msgs in
grind_pattern foo => x + x

/--
error: invalid pattern(s) for `Array.getElem_push_lt`
[@Array.push #4 #3 #2]
the following theorem parameters cannot be instantiated:
i : Nat
h : i < a.size
---
info: [grind.pattern] Array.getElem_push_lt: [@Array.push #4 #3 #2]
-/
#guard_msgs in
grind_pattern Array.getElem_push_lt => (a.push x)

class Foo (α : Type) (β : outParam Type) where
a : Unit

class Boo (α : Type) (β : Type) where
b : β

def f [Foo α β] [Boo α β] (a : α) : (α × β) :=
(a, Boo.b α)

instance [Foo α β] : Foo (List α) (Array β) where
a := ()

instance [Boo α β] : Boo (List α) (Array β) where
b := #[Boo.b α]

theorem fEq [Foo α β] [Boo α β] (a : List α) : (f a).1 = a := rfl

/-- info: [grind.pattern] fEq: [@f ? ? ? ? #0] -/
#guard_msgs in
grind_pattern fEq => f a

theorem fEq2 [Foo α β] [Boo α β] (a : List α) (_h : a.length > 5) : (f a).1 = a := rfl

/-- info: [grind.pattern] fEq2: [@f ? ? ? ? #1] -/
#guard_msgs in
grind_pattern fEq2 => f a

def g [Boo α β] (a : α) : (α × β) :=
(a, Boo.b α)

theorem gEq [Boo α β] (a : List α) : (g (β := Array β) a).1 = a := rfl

/--
error: invalid pattern(s) for `gEq`
[@g ? ? ? #0]
the following theorem parameters cannot be instantiated:
β : Type
inst✝ : Boo α β
---
info: [grind.pattern] gEq: [@g ? ? ? #0]
-/
#guard_msgs in
grind_pattern gEq => g a

def plus (a : Nat) (b : Nat) := a + b

theorem hThm1 (h : b > 10) : plus a b + plus a c > 10 := by
unfold plus; omega

/--
error: invalid pattern(s) for `hThm1`
[plus #2 #3]
the following theorem parameters cannot be instantiated:
c : Nat
---
info: [grind.pattern] hThm1: [plus #2 #3]
-/
#guard_msgs in
grind_pattern hThm1 => plus a b

/--
error: invalid pattern(s) for `hThm1`
[plus #2 #1]
the following theorem parameters cannot be instantiated:
b : Nat
h : b > 10
---
info: [grind.pattern] hThm1: [plus #2 #1]
-/
#guard_msgs in
grind_pattern hThm1 => plus a c

/-- info: [grind.pattern] hThm1: [plus #2 #1, plus #2 #3] -/
#guard_msgs in
grind_pattern hThm1 => plus a c, plus a b

0 comments on commit 1338a86

Please sign in to comment.