Skip to content

Commit

Permalink
fix: IndPred: track function's motive in a let binding, use withoutPr…
Browse files Browse the repository at this point in the history
…oofIrrelevance, no chaining (#4839)

this improves support for structural recursion over inductive
*predicates* when there are reflexive arguments.

Consider
```lean
inductive F: Prop where
  | base
  | step (fn: Nat → F)

-- set_option trace.Meta.IndPredBelow.search true
set_option pp.proofs true

def F.asdf1 : (f : F) → True
  | base => trivial
  | step f => F.asdf1 (f 0)
termination_by structural f => f`
```

Previously the search for the right induction hypothesis would fail with
```
could not solve using backwards chaining x✝¹ : F
x✝ : x✝¹.below
f : Nat → F
a✝¹ : ∀ (a : Nat), (f a).below
a✝ : Nat → True
⊢ True
```

The backchaining process will try to use `a✝ : Nat → True`, but then has
no idea what to use for `Nat`.

There are three steps here to fix this.

1. We let-bind the function's type before the whole process. Now the
   goal is

   ```
   funType : F → Prop := fun x => True
   x✝ : x✝¹.below
   f : Nat → F
   a✝¹ : ∀ (a : Nat), (f a).below
   a✝ : ∀ (a : Nat), funType (f a)
   ⊢ funType (f 0)
   ```
2. Instead of using the general purpose backchaining proof search, which
is more
powerful than we need here (we need on recursive search and no
backtracking),
   we have a custom search that looks for local assumptions that 
   provide evidence of `funType`, and extracts the arguments from that
   “type” application to construct the recursive call.

   Above, it will thus unify `f a =?= f 0`.

3. In order to make progress here, we also turn on use
`withoutProofIrrelevance`,
because else `isDefEq` is happy to say “they are equal” without actually
looking
   at the terms and thus assigning `?a := 0`.

This idea of let-binding the function's motive may also be useful for
the other recursion compilers, as it may simplify the FunInd
construction. This is to be investigated.

fixes #4751
  • Loading branch information
nomeata authored Jul 28, 2024
1 parent 87c92a3 commit 671ce7a
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 73 deletions.
97 changes: 61 additions & 36 deletions src/Lean/Elab/PreDefinition/Structural/IndPred.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,35 @@ import Lean.Elab.PreDefinition.Structural.RecArgInfo
namespace Lean.Elab.Structural
open Meta

private partial def replaceIndPredRecApps (recArgInfo : RecArgInfo) (motive : Expr) (e : Expr) : M Expr := do
let maxDepth := IndPredBelow.maxBackwardChainingDepth.get (← getOptions)
private def replaceIndPredRecApp (numFixed : Nat) (funType : Expr) (e : Expr) : M Expr := do
withoutProofIrrelevance do
withTraceNode `Elab.definition.structural (fun _ => pure m!"eliminating recursive call {e}") do
-- We want to replace `e` with an expression of the same type
let main ← mkFreshExprSyntheticOpaqueMVar (← inferType e)
let args : Array Expr := e.getAppArgs[numFixed:]
let lctx ← getLCtx
let r ← lctx.anyM fun localDecl => do
if localDecl.isAuxDecl then return false
let (mvars, _, t) ← forallMetaTelescope localDecl.type -- NB: do not reduce, we want to see the `funType`
unless t.getAppFn == funType do return false
withTraceNodeBefore `Elab.definition.structural (do pure m!"trying {mkFVar localDecl.fvarId} : {localDecl.type}") do
if args.size < t.getAppNumArgs then
trace[Elab.definition.structural] "too few arguments. Underapplied recursive call?"
return false
if (← (t.getAppArgs.zip args).allM (fun (t,s) => isDefEq t s)) then
main.mvarId!.assign (mkAppN (mkAppN localDecl.toExpr mvars) args[t.getAppNumArgs:])
return ← mvars.allM fun v => do
unless (← v.mvarId!.isAssigned) do
trace[Elab.definition.structural] "Cannot use {mkFVar localDecl.fvarId}: parameter {v} remains unassigned"
return false
return true
trace[Elab.definition.structural] "Arguments do not match"
return false
unless r do
throwError "Could not eliminate recursive call {e}"
instantiateMVars main

private partial def replaceIndPredRecApps (recArgInfo : RecArgInfo) (funType : Expr) (motive : Expr) (e : Expr) : M Expr := do
let rec loop (e : Expr) : M Expr := do
match e with
| Expr.lam n d b c =>
Expand All @@ -35,12 +62,7 @@ private partial def replaceIndPredRecApps (recArgInfo : RecArgInfo) (motive : Ex
let processApp (e : Expr) : M Expr := do
e.withApp fun f args => do
if f.isConstOf recArgInfo.fnName then
let ty ← inferType e
let main ← mkFreshExprSyntheticOpaqueMVar ty
if (← IndPredBelow.backwardsChaining main.mvarId! maxDepth) then
pure main
else
throwError "could not solve using backwards chaining {MessageData.ofGoal main.mvarId!}"
replaceIndPredRecApp recArgInfo.numFixed funType e
else
return mkAppN (← loop f) (← args.mapM loop)
match (← matchMatcherApp? e) with
Expand Down Expand Up @@ -79,33 +101,36 @@ def mkIndPredBRecOn (recArgInfo : RecArgInfo) (value : Expr) : M Expr := do
let type := (← inferType value).headBeta
let (indexMajorArgs, otherArgs) := recArgInfo.pickIndicesMajor ys
trace[Elab.definition.structural] "numFixed: {recArgInfo.numFixed}, indexMajorArgs: {indexMajorArgs}, otherArgs: {otherArgs}"
let motive ← mkForallFVars otherArgs type
let motive ← mkLambdaFVars indexMajorArgs motive
trace[Elab.definition.structural] "brecOn motive: {motive}"
let brecOn := Lean.mkConst (mkBRecOnName recArgInfo.indName!) recArgInfo.indGroupInst.levels
let brecOn := mkAppN brecOn recArgInfo.indGroupInst.params
let brecOn := mkApp brecOn motive
let brecOn := mkAppN brecOn indexMajorArgs
check brecOn
let brecOnType ← inferType brecOn
trace[Elab.definition.structural] "brecOn {brecOn}"
trace[Elab.definition.structural] "brecOnType {brecOnType}"
-- we need to close the telescope here, because the local context is used:
-- The root cause was, that this copied code puts an ih : FType into the
-- local context and later, when we use the local context to build the recursive
-- call, it uses this ih. But that ih doesn't exist in the actual brecOn call.
-- That's why it must go.
let FType ← forallBoundedTelescope brecOnType (some 1) fun F _ => do
let F := F[0]!
let FType ← inferType F
trace[Elab.definition.structural] "FType: {FType}"
instantiateForall FType indexMajorArgs
forallBoundedTelescope FType (some 1) fun below _ => do
let below := below[0]!
let valueNew ← replaceIndPredRecApps recArgInfo motive value
let Farg ← mkLambdaFVars (indexMajorArgs ++ #[below] ++ otherArgs) valueNew
let brecOn := mkApp brecOn Farg
let brecOn := mkAppN brecOn otherArgs
mkLambdaFVars ys brecOn
let funType ← mkLambdaFVars ys type
withLetDecl `funType (← inferType funType) funType fun funType => do
let motive ← mkForallFVars otherArgs (mkAppN funType ys)
let motive ← mkLambdaFVars indexMajorArgs motive
trace[Elab.definition.structural] "brecOn motive: {motive}"
let brecOn := Lean.mkConst (mkBRecOnName recArgInfo.indName!) recArgInfo.indGroupInst.levels
let brecOn := mkAppN brecOn recArgInfo.indGroupInst.params
let brecOn := mkApp brecOn motive
let brecOn := mkAppN brecOn indexMajorArgs
check brecOn
let brecOnType ← inferType brecOn
trace[Elab.definition.structural] "brecOn {brecOn}"
trace[Elab.definition.structural] "brecOnType {brecOnType}"
-- we need to close the telescope here, because the local context is used:
-- The root cause was, that this copied code puts an ih : FType into the
-- local context and later, when we use the local context to build the recursive
-- call, it uses this ih. But that ih doesn't exist in the actual brecOn call.
-- That's why it must go.
let FType ← forallBoundedTelescope brecOnType (some 1) fun F _ => do
let F := F[0]!
let FType ← inferType F
trace[Elab.definition.structural] "FType: {FType}"
instantiateForall FType indexMajorArgs
forallBoundedTelescope FType (some 1) fun below _ => do
let below := below[0]!
let valueNew ← replaceIndPredRecApps recArgInfo funType motive value
let Farg ← mkLambdaFVars (indexMajorArgs ++ #[below] ++ otherArgs) valueNew
let brecOn := mkApp brecOn Farg
let brecOn := mkAppN brecOn otherArgs
let brecOn ← mkLetFVars #[funType] brecOn
mkLambdaFVars ys brecOn

end Lean.Elab.Structural
63 changes: 26 additions & 37 deletions tests/lean/run/4751.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,59 +3,48 @@ inductive F: Prop where
| step (fn: Nat → F)

-- set_option trace.Meta.IndPredBelow.search true
-- set_option trace.Elab.definition.structural true
set_option pp.proofs true

/--
error: failed to infer structural recursion:
Cannot use parameter #1:
could not solve using backwards chaining x✝¹ : F
x✝ : x✝¹.below
f : Nat → F
a✝¹ : ∀ (a : Nat), (f a).below
a✝ : Nat → True
⊢ True
-/
#guard_msgs in
def F.asdf1 : (f : F) → True
| base => trivial
| step f => F.asdf1 (f 0)
| step g => match g 1 with
| base => trivial
| step h => F.asdf1 (h 1)
termination_by structural f => f


def TTrue (_f : F) := True

/--
error: failed to infer structural recursion:
Cannot use parameter #1:
could not solve using backwards chaining x✝¹ : F
x✝ : x✝¹.below
f : Nat → F
a✝¹ : ∀ (a : Nat), (f a).below
a✝ : ∀ (a : Nat), TTrue (f a)
⊢ TTrue (f 0)
-/
#guard_msgs in
def F.asdf2 : (f : F) → TTrue f
| base => trivial
| step f => F.asdf2 (f 0)
termination_by structural f => f



inductive ITrue (f : F) : Prop where | trivial

/--
error: failed to infer structural recursion:
Cannot use parameter #1:
could not solve using backwards chaining x✝¹ : F
x✝ : x✝¹.below
f : Nat → F
a✝¹ : ∀ (a : Nat), (f a).below
a✝ : ∀ (a : Nat), ITrue (f a)
⊢ ITrue (f 0)
-/
#guard_msgs in
def F.asdf3 : (f : F) → ITrue f
| base => .trivial
| step f => F.asdf3 (f 0)
termination_by structural f => f

-- Variants with extra arguments

inductive T : PropProp where
| base : T True
| step (fn: Nat → T (True → p)) : T p

def T.foo {P : Prop} : (f : T P) → P
| base => True.intro
| step f => foo (f 0) True.intro
termination_by structural f => f

-- The same, but as a non-reflexive data type

inductive T' : PropProp where
| base : T' True
| step (t : T' (True → p)) : T' p

def T'.foo {P : Prop} : (f : T' P) → P
| base => True.intro
| step t => foo t True.intro
termination_by structural f => f
5 changes: 5 additions & 0 deletions tests/lean/run/structuralRec1.lean
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
set_option linter.unusedVariables false

inductive PList (α : Type) : Prop
| nil
| cons : α → PList α → PList α
Expand Down Expand Up @@ -85,6 +87,7 @@ else
match ys with
| PList.nil => PList.nil
| y:::ys => (y + x/2 + 1) ::: pbla (x/2) ys
termination_by structural ys

theorem blaEq (y : Nat) (ys : List Nat) : bla 4 (y::ys) = (y+2) :: bla 2 ys :=
rfl
Expand Down Expand Up @@ -181,11 +184,13 @@ match n, m, hn with
| _, _, is_nat_T.S is_nat_T.Z => TF1
| _, m, is_nat_T.S (is_nat_T.S h) => TFS («reordered discriminants, type» _ h m)


theorem «reordered discriminants» : ∀ n, is_nat n → Nat → P n := fun n hn m =>
match n, m, hn with
| _, _, is_nat.Z => F0
| _, _, is_nat.S is_nat.Z => F1
| _, m, is_nat.S (is_nat.S h) => FS («reordered discriminants» _ h m)
termination_by structural _ n => n

/-- known unsupported case for types, just here for reference. -/
-- def «unsupported nesting» (xs : List Nat) : True :=
Expand Down

0 comments on commit 671ce7a

Please sign in to comment.