diff --git a/src/Lean/LocalContext.lean b/src/Lean/LocalContext.lean index 3db35ad9da50..4cf7889d894a 100644 --- a/src/Lean/LocalContext.lean +++ b/src/Lean/LocalContext.lean @@ -384,7 +384,7 @@ def size (lctx : LocalContext) : Nat := @[inline] def findDeclRev? (lctx : LocalContext) (f : LocalDecl → Option β) : Option β := Id.run <| lctx.findDeclRevM? f -partial def isSubPrefixOfAux (a₁ a₂ : PArray (Option LocalDecl)) (exceptFVars : Array Expr) (i j : Nat) : Bool := +partial def isSubPrefixOfAux (a₁ a₂ : PArray (Option LocalDecl)) (exceptFVars : Subarray Expr) (i j : Nat) : Bool := if h : i < a₁.size then match a₁[i] with | none => isSubPrefixOfAux a₁ a₂ exceptFVars (i+1) j @@ -401,7 +401,7 @@ partial def isSubPrefixOfAux (a₁ a₂ : PArray (Option LocalDecl)) (exceptFVar /-- Given `lctx₁ - exceptFVars` of the form `(x_1 : A_1) ... (x_n : A_n)`, then return true iff there is a local context `B_1* (x_1 : A_1) ... B_n* (x_n : A_n)` which is a prefix of `lctx₂` where `B_i`'s are (possibly empty) sequences of local declarations. -/ -def isSubPrefixOf (lctx₁ lctx₂ : LocalContext) (exceptFVars : Array Expr := #[]) : Bool := +def isSubPrefixOf (lctx₁ lctx₂ : LocalContext) (exceptFVars : Subarray Expr := {}) : Bool := isSubPrefixOfAux lctx₁.decls lctx₂.decls exceptFVars 0 0 @[inline] def mkBinding (isLambda : Bool) (lctx : LocalContext) (xs : Array Expr) (b : Expr) : Expr := diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 9a6e1b807782..28db16c52560 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -387,129 +387,6 @@ private def checkTypesAndAssign (mvar : Expr) (v : Expr) : MetaM Bool := else pure false -/-- -Auxiliary method for solving constraints of the form `?m xs := v`. -It creates a lambda using `mkLambdaFVars ys v`, where `ys` is a superset of `xs`. -`ys` is often equal to `xs`. It is a bigger when there are let-declaration dependencies in `xs`. -For example, suppose we have `xs` of the form `#[a, c]` where -``` -a : Nat -b : Nat := f a -c : b = a -``` -In this scenario, the type of `?m` is `(x1 : Nat) -> (x2 : f x1 = x1) -> C[x1, x2]`, -and type of `v` is `C[a, c]`. Note that, `?m a c` is type correct since `f a = a` is definitionally equal -to the type of `c : b = a`, and the type of `?m a c` is equal to the type of `v`. -Note that `fun xs => v` is the term `fun (x1 : Nat) (x2 : b = x1) => v` which has type -`(x1 : Nat) -> (x2 : b = x1) -> C[x1, x2]` which is not definitionally equal to the type of `?m`, -and may not even be type correct. -The issue here is that we are not capturing the `let`-declarations. - -This method collects let-declarations `y` occurring between `xs[0]` and `xs.back` s.t. -some `x` in `xs` depends on `y`. -`ys` is the `xs` with these extra let-declarations included. - -In the example above, `ys` is `#[a, b, c]`, and `mkLambdaFVars ys v` produces -`fun a => let b := f a; fun (c : b = a) => v` which has a type definitionally equal to the type of `?m`. - -Recall that the method `checkAssignment` ensures `v` does not contain offending `let`-declarations. - -This method assumes that for any `xs[i]` and `xs[j]` where `i < j`, we have that `index of xs[i]` < `index of xs[j]`. -where the index is the position in the local context. --/ -private partial def mkLambdaFVarsWithLetDeps (xs : Array Expr) (v : Expr) : MetaM (Option Expr) := do - if !(← hasLetDeclsInBetween) then - mkLambdaFVars xs v (etaReduce := true) - else - let ys ← addLetDeps - mkLambdaFVars ys v (etaReduce := true) - -where - /-- Return true if there are let-declarions between `xs[0]` and `xs[xs.size-1]`. - We use it a quick-check to avoid the more expensive collection procedure. -/ - hasLetDeclsInBetween : MetaM Bool := do - let check (lctx : LocalContext) : Bool := Id.run do - let start := lctx.getFVar! xs[0]! |>.index - let stop := lctx.getFVar! xs.back! |>.index - for i in [start+1:stop] do - match lctx.getAt? i with - | some localDecl => - if localDecl.isLet then - return true - | _ => pure () - return false - if xs.size <= 1 then - return false - else - return check (← getLCtx) - - /-- Traverse `e` and stores in the state `NameHashSet` any let-declaration with index greater than `(← read)`. - The context `Nat` is the position of `xs[0]` in the local context. -/ - collectLetDeclsFrom (e : Expr) : ReaderT Nat (StateRefT FVarIdHashSet MetaM) Unit := do - let rec visit (e : Expr) : MonadCacheT Expr Unit (ReaderT Nat (StateRefT FVarIdHashSet MetaM)) Unit := - checkCache e fun _ => do - match e with - | .forallE _ d b _ => visit d; visit b - | .lam _ d b _ => visit d; visit b - | .letE _ t v b _ => visit t; visit v; visit b - | .app f a => visit f; visit a - | .mdata _ b => visit b - | .proj _ _ b => visit b - | .fvar fvarId => - let localDecl ← fvarId.getDecl - if localDecl.isLet && localDecl.index > (← read) then - modify fun s => s.insert localDecl.fvarId - | _ => pure () - visit (← instantiateMVars e) |>.run - - /-- - Auxiliary definition for traversing all declarations between `xs[0]` ... `xs.back` backwards. - The `Nat` argument is the current position in the local context being visited, and it is less than - or equal to the position of `xs.back` in the local context. - The `Nat` context `(← read)` is the position of `xs[0]` in the local context. - -/ - collectLetDepsAux : Nat → ReaderT Nat (StateRefT FVarIdHashSet MetaM) Unit - | 0 => return () - | i+1 => do - if i+1 == (← read) then - return () - else - match (← getLCtx).getAt? (i+1) with - | none => collectLetDepsAux i - | some localDecl => - if (← get).contains localDecl.fvarId then - collectLetDeclsFrom localDecl.type - match localDecl.value? with - | some val => collectLetDeclsFrom val - | _ => pure () - collectLetDepsAux i - - /-- Computes the set `ys`. It is a set of `FVarId`s, -/ - collectLetDeps : MetaM FVarIdHashSet := do - let lctx ← getLCtx - let start := lctx.getFVar! xs[0]! |>.index - let stop := lctx.getFVar! xs.back! |>.index - let s := xs.foldl (init := {}) fun s x => s.insert x.fvarId! - let (_, s) ← collectLetDepsAux stop |>.run start |>.run s - return s - - /-- Computes the array `ys` containing let-decls between `xs[0]` and `xs.back` that - some `x` in `xs` depends on. -/ - addLetDeps : MetaM (Array Expr) := do - let lctx ← getLCtx - let s ← collectLetDeps - /- Convert `s` into the array `ys` -/ - let start := lctx.getFVar! xs[0]! |>.index - let stop := lctx.getFVar! xs.back! |>.index - let mut ys := #[] - for i in [start:stop+1] do - match lctx.getAt? i with - | none => pure () - | some localDecl => - if s.contains localDecl.fvarId then - ys := ys.push localDecl.toExpr - return ys - /-! Each metavariable is declared in a particular local context. We use the notation `C |- ?m : t` to denote a metavariable `?m` that @@ -682,7 +559,7 @@ structure State where structure Context where mvarId : MVarId mvarDecl : MetavarDecl - fvars : Array Expr + fvars : Subarray Expr hasCtxLocals : Bool rhs : Expr @@ -708,7 +585,7 @@ private def addAssignmentInfo (msg : MessageData) : CheckAssignmentM MessageData let ctx ← read return m!"{msg} @ {mkMVar ctx.mvarId} {ctx.fvars} := {ctx.rhs}" -@[inline] def run (x : CheckAssignmentM Expr) (mvarId : MVarId) (fvars : Array Expr) (hasCtxLocals : Bool) (v : Expr) : MetaM (Option Expr) := do +@[inline] def run (x : CheckAssignmentM Expr) (mvarId : MVarId) (fvars : Subarray Expr) (hasCtxLocals : Bool) (v : Expr) : MetaM (Option Expr) := do let mvarDecl ← mvarId.getDecl let ctx := { mvarId := mvarId, mvarDecl := mvarDecl, fvars := fvars, hasCtxLocals := hasCtxLocals, rhs := v : Context } let x : CheckAssignmentM (Option Expr) := @@ -729,7 +606,7 @@ mutual match lctx.findFVar? fvar with | some (.ldecl (value := v) ..) => check v | _ => - if ctx.fvars.contains fvar then pure fvar + if ctx.fvars.any (· == fvar) then pure fvar else traceM `Meta.isDefEq.assign.outOfScopeFVar do addAssignmentInfo fvar throwOutOfScopeFVar @@ -767,30 +644,28 @@ mutual a metavariable that we also need to reduce the context. We remove from `ctx.mvarDecl.lctx` any variable that is not in `mvarDecl.lctx` - or in `ctx.fvars`. We don't need to remove the ones in `ctx.fvars` because - `elimMVarDeps` will take care of them. + or in `ctx.fvars[:fvarsInScope]`. We don't need to remove the ones in `ctx.fvars[:fvarsInScope]` + because `elimMVarDeps` will take care of them. - First, we collect `toErase` the variables that need to be erased. - Notat that if a variable is `ctx.fvars`, but it depends on variable at `toErase`, - we must also erase it. + Note that if a variable is in `ctx.fvars[:fvarsInScope]`, then its type is already checked, + so by replacing the type with the checked type, we avoid creating an illegal local context. + + We collect `toErase`, the variables that are erased, in order to filter the local instances. -/ - let toErase ← mvarDecl.lctx.foldlM (init := #[]) fun toErase localDecl => do - if ctx.mvarDecl.lctx.contains localDecl.fvarId then - return toErase - else if ctx.fvars.any fun fvar => fvar.fvarId! == localDecl.fvarId then - if (← findLocalDeclDependsOn localDecl fun fvarId => toErase.contains fvarId) then - -- localDecl depends on a variable that will be erased. So, we must add it to `toErase` too - return toErase.push localDecl.fvarId + let mut toErase := #[] + let mut newLCtx := mvarDecl.lctx + for localDecl in mvarDecl.lctx do + let fvarId := localDecl.fvarId + unless ctx.mvarDecl.lctx.contains fvarId do + if ctx.fvars.any (·.fvarId! == fvarId) then + let fvarType ← fvarId.getType -- get the type from the current local context + newLCtx := newLCtx.modifyLocalDecl fvarId (·.setType fvarType) else - return toErase - else - return toErase.push localDecl.fvarId - let lctx := toErase.foldl (init := mvarDecl.lctx) fun lctx toEraseFVar => - lctx.erase toEraseFVar - /- Compute new set of local instances. -/ + newLCtx := newLCtx.erase fvarId + toErase := toErase.push fvarId let localInsts := mvarDecl.localInstances.filter fun localInst => !toErase.contains localInst.fvar.fvarId! let mvarType ← check mvarDecl.type - let newMVar ← mkAuxMVar lctx localInsts mvarType mvarDecl.numScopeArgs + let newMVar ← mkAuxMVar newLCtx localInsts mvarType mvarDecl.numScopeArgs mvarId.assign newMVar return newMVar @@ -804,12 +679,12 @@ mutual let mvarType ← inferType mvar forallBoundedTelescope mvarType numArgs fun xs _ => do if xs.size != numArgs then return false - let some v ← mkLambdaFVarsWithLetDeps xs newMVar | return false - let some v ← checkAssignmentAux mvar.mvarId! #[] false v | return false + let v ← mkLambdaFVars xs newMVar + let some v ← checkAssignmentAux mvar.mvarId! {} false v | return false checkTypesAndAssign mvar v -- See checkAssignment - partial def checkAssignmentAux (mvarId : MVarId) (fvars : Array Expr) (hasCtxLocals : Bool) (v : Expr) : MetaM (Option Expr) := do + partial def checkAssignmentAux (mvarId : MVarId) (fvars : Subarray Expr) (hasCtxLocals : Bool) (v : Expr) : MetaM (Option Expr) := do run (check v) mvarId fvars hasCtxLocals v partial def checkApp (e : Expr) : CheckAssignmentM Expr := @@ -897,7 +772,7 @@ namespace CheckAssignmentQuick unsafe def checkImpl (hasCtxLocals : Bool) - (mctx : MetavarContext) (lctx : LocalContext) (mvarDecl : MetavarDecl) (mvarId : MVarId) (fvars : Array Expr) (e : Expr) : Bool := + (mctx : MetavarContext) (lctx : LocalContext) (mvarDecl : MetavarDecl) (mvarId : MVarId) (fvars : Subarray Expr) (e : Expr) : Bool := let rec visit (e : Expr) : StateM (PtrSet Expr) Bool := do if !e.hasExprMVar && !e.hasFVar then return true @@ -934,7 +809,7 @@ unsafe def checkImpl else visit e |>.run' mkPtrSet -def check (hasCtxLocals : Bool) (mctx : MetavarContext) (lctx : LocalContext) (mvarDecl : MetavarDecl) (mvarId : MVarId) (fvars : Array Expr) (e : Expr) : Bool := +def check (hasCtxLocals : Bool) (mctx : MetavarContext) (lctx : LocalContext) (mvarDecl : MetavarDecl) (mvarId : MVarId) (fvars : Subarray Expr) (e : Expr) : Bool := unsafe checkImpl hasCtxLocals mctx lctx mvarDecl mvarId fvars e end CheckAssignmentQuick @@ -1025,41 +900,69 @@ private def typeOccursCheck (mctx : MetavarContext) (mvarId : MVarId) (v : Expr) unsafe typeOccursCheckImp mctx mvarId v /-- - Auxiliary function for handling constraints of the form `?m a₁ ... aₙ =?= v`. - It will check whether we can perform the assignment - ``` - ?m := fun fvars => v - ``` - The result is `none` if the assignment can't be performed. - The result is `some newV` where `newV` is a possibly updated `v`. This method may need - to unfold let-declarations. -/ -def checkAssignment (mvarId : MVarId) (fvars : Array Expr) (v : Expr) : MetaM (Option Expr) := do - /- Check whether `mvarId` occurs in the type of `fvars` or not. If it does, return `none` - to prevent us from creating the cyclic assignment `?m := fun fvars => v` -/ - for fvar in fvars do - unless (← occursCheck mvarId (← inferType fvar)) do - return none - if !v.hasExprMVar && !v.hasFVar then - pure (some v) - else - let mvarDecl ← mvarId.getDecl - let hasCtxLocals := fvars.any fun fvar => mvarDecl.lctx.containsFVar fvar - let ctx ← read - let mctx ← getMCtx - let v ← if CheckAssignmentQuick.check hasCtxLocals mctx ctx.lctx mvarDecl mvarId fvars v then - pure v - else if let some v ← CheckAssignment.checkAssignmentAux mvarId fvars hasCtxLocals (← instantiateMVars v) then - pure v +Auxiliary function for handling constraints of the form `?m a₁ ... aₙ =?= v`. +It will check whether we can perform the assignment +``` +?m := fun fvars => v +``` +The result is `none` if the assignment can't be performed. +The result is `some (newV, newLCtx)` where `newV` is a possibly updated `v` and `newLCtx` a possibly updated +local context. This method may need to unfold let-declarations. + +The following things are checked: + +1) occurs check, that is, check that `?m` doesn't appear in the assignment +2) all free variables in the assignment must be present in the local context of `?m`. + If a let-varaible isn't in the local context, then this can be resolved by instantiating it with its let-value. +3) all metavariables in the assignment must have local contexts that are at most as large as the local context of `?m`. + If a metavariable has a local context with more free variables, this can be resolved by + assigning it to a new fresh metavariable with a restricted local context. + +checks 1-3 are done in both `v` and in the types of `a₁ ... aₙ`. +Additionally, we do an occurs check at the type of `v` (see `typeOccursCheck`) + +For check 3, it is important to do the checks in the order `a₁, ... aₙ, v`. +For example, if `?m` has an empty local context, and `a₁ : N`, for some let-variable `N := Nat`. +We should first replace the type of `a₁` with `Nat`, because if we find +a metavariable in `a₂, ... aₙ, v` which has `a₁` in its context, there would be a problem, +as the metavariable will be restricted to not have `N` in its context, and therefore it can't depend on `a₁`. +By replacing the type of `a₁` by `Nat`, and also doing this replacement in this new metavariable local context, +we can now allow metavariables in `a₂, ... aₙ, v` to depend on `a₁`. +-/ +def checkAssignment (mvarId : MVarId) (fvars : Array Expr) (v : Expr) : MetaM (Option (Expr × LocalContext)) := do + let mvarDecl ← mvarId.getDecl + let hasCtxLocals := fvars.any fun fvar => mvarDecl.lctx.containsFVar fvar + let rec checkFVars (i : Nat) : MetaM (Option (Expr × LocalContext)) := do + if h : i < fvars.size then + let fvar := fvars[i] + let fvarType ← inferType fvar + if CheckAssignmentQuick.check hasCtxLocals (← getMCtx) (← getLCtx) mvarDecl mvarId fvars[:i] fvarType then + checkFVars (i+1) + else if let some fvarType ← CheckAssignment.checkAssignmentAux mvarId fvars[:i] hasCtxLocals fvarType then + withReader (fun ctx => { ctx with lctx := ctx.lctx.modifyLocalDecl fvar.fvarId! (·.setType fvarType) }) do + checkFVars (i+1) + else + return none else - return none - unless typeOccursCheck (← getMCtx) mvarId v do - return none - return some v + let lctx ← getLCtx + if !v.hasExprMVar && !v.hasFVar then + return some (v, lctx) + else + let v ← if CheckAssignmentQuick.check hasCtxLocals (← getMCtx) lctx mvarDecl mvarId fvars[:fvars.size] v then + pure v + else if let some v ← CheckAssignment.checkAssignmentAux mvarId fvars[:fvars.size] hasCtxLocals (← instantiateMVars v) then + pure v + else + return none + unless typeOccursCheck (← getMCtx) mvarId v do + return none + return some (v, lctx) + checkFVars 0 -- Implementation for `_root_.Lean.MVarId.checkedAssign` @[export lean_checked_assign] def checkedAssignImpl (mvarId : MVarId) (val : Expr) : MetaM Bool := do - if let some val ← checkAssignment mvarId #[] val then + if let some (val, _) ← checkAssignment mvarId #[] val then mvarId.assign val return true else @@ -1130,10 +1033,10 @@ private def assignConst (mvar : Expr) (numArgs : Nat) (v : Expr) : MetaM Bool := if xs.size != numArgs then pure false else - let some v ← mkLambdaFVarsWithLetDeps xs v | pure false + let v ← mkLambdaFVars xs v match (← checkAssignment mvar.mvarId! #[] v) with | none => pure false - | some v => + | some (v, _) => trace[Meta.isDefEq.constApprox] "{mvar} := {v}" checkTypesAndAssign mvar v @@ -1171,19 +1074,19 @@ private partial def processConstApprox (mvar : Expr) (args : Array Expr) (patter if xs.size != suffixSize then defaultCase else - let some v ← mkLambdaFVarsWithLetDeps xs v | defaultCase + let v ← mkLambdaFVars xs v let rec go (argsPrefix : Array Expr) (v : Expr) : MetaM Bool := do trace[Meta.isDefEq] "processConstApprox.go {mvar} {argsPrefix} := {v}" let rec cont : MetaM Bool := do if argsPrefix.isEmpty then defaultCase else - let some v ← mkLambdaFVarsWithLetDeps #[argsPrefix.back!] v | defaultCase + let v ← mkLambdaFVars #[argsPrefix.back!] v (etaReduce := true) go argsPrefix.pop v match (← checkAssignment mvarId argsPrefix v) with | none => cont - | some vNew => - let some vNew ← mkLambdaFVarsWithLetDeps argsPrefix vNew | cont + | some (vNew, lctx) => + let vNew ← withReader ({ · with lctx }) do mkLambdaFVars argsPrefix vNew (etaReduce := true) if argsPrefix.any (fun arg => mvarDecl.lctx.containsFVar arg) then /- We need to type check `vNew` because abstraction using `mkLambdaFVars` may have produced a type incorrect term. See discussion at A2 -/ @@ -1225,9 +1128,9 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool : let mvarId := mvar.mvarId! match (← checkAssignment mvarId args v) with | none => useFOApprox args - | some v => do + | some (v, lctx) => do trace[Meta.isDefEq.assign.beforeMkLambda] "{mvar} {args} := {v}" - let some v ← mkLambdaFVarsWithLetDeps args v | return false + let v ← withReader ({· with lctx}) do mkLambdaFVars args v (etaReduce := true) if args.any (fun arg => mvarDecl.lctx.containsFVar arg) then /- We need to type check `v` because abstraction using `mkLambdaFVars` may have produced a type incorrect term. See discussion at A2 -/ diff --git a/tests/lean/run/isDefEqIssue.lean b/tests/lean/run/isDefEqIssue.lean index 932a85953944..ea00fc9af05d 100644 --- a/tests/lean/run/isDefEqIssue.lean +++ b/tests/lean/run/isDefEqIssue.lean @@ -7,3 +7,67 @@ private def resolveLValAux (s : String) (i : Nat) : Nat := i - 1 else i + + +/-- +This used to give +(kernel) declaration has free variables '_example' +-/ +example : Unit := + let x : Nat → Unit := _ + let N := Nat; + (fun a : N => + have : x a = () := rfl + ()) Nat.zero + + +/-- +This used to give +(kernel) declaration has free variables '_example' +-/ +example : IO Unit := do + pure () + match some () with + | some u => do + let pair := match () with | _ => ((),()) + let i := () + if h : i = pair.1 then + let k := 0 + | _ => return + + +/- +This used to give +(kernel) declaration has free variables '_example' +-/ +/-- +error: type mismatch + rfl +has type + x b a = x b a : Prop +but is expected to have type + x b a = Nat.zero : Prop +-/ +#guard_msgs in +example : Unit := + let x : Nat → Nat → Nat := _ + (fun (a : Nat) (b : let _ := a; Nat) => + have : x b a = Nat.zero := rfl + () + ) Nat.zero Nat.zero + +class Foo (a b : Nat) (h : a = b) (β : Nat → Type) where + val : β a + +@[default_instance] +instance (a b : Nat) (h : a = b) : Foo a b h Fin where + val := sorry + +/-- +This used to give +typeclass instance problem is stuck, it is often due to metavariables + Foo a Nat.zero h (?m.734 h) +-/ +example := + let a : Nat := Nat.zero + fun (h : a = Nat.zero) => Foo.val h