diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 9a6e1b807782..1888e7608cdb 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -417,7 +417,7 @@ Recall that the method `checkAssignment` ensures `v` does not contain offending This method assumes that for any `xs[i]` and `xs[j]` where `i < j`, we have that `index of xs[i]` < `index of xs[j]`. where the index is the position in the local context. -/ -private partial def mkLambdaFVarsWithLetDeps (xs : Array Expr) (v : Expr) : MetaM (Option Expr) := do +private partial def mkLambdaFVarsWithLetDeps (xs : Array Expr) (v : Expr) (mvarLCtx : LocalContext) : MetaM (Option Expr) := do if !(← hasLetDeclsInBetween) then mkLambdaFVars xs v (etaReduce := true) else @@ -428,25 +428,23 @@ where /-- Return true if there are let-declarions between `xs[0]` and `xs[xs.size-1]`. We use it a quick-check to avoid the more expensive collection procedure. -/ hasLetDeclsInBetween : MetaM Bool := do - let check (lctx : LocalContext) : Bool := Id.run do - let start := lctx.getFVar! xs[0]! |>.index - let stop := lctx.getFVar! xs.back! |>.index - for i in [start+1:stop] do - match lctx.getAt? i with - | some localDecl => - if localDecl.isLet then - return true - | _ => pure () - return false - if xs.size <= 1 then - return false - else - return check (← getLCtx) + let rec check (i : Nat) (lctx : LocalContext) : Bool := Id.run do + if let some localDecl := lctx.getAt? i then + if mvarLCtx.contains localDecl.fvarId then + return false + if localDecl.isLet then + return true + match i with + | 0 => return false + | i+1 => check i lctx + let lctx ← getLCtx + let stop := lctx.getFVar! xs.back! |>.index + return check stop lctx /-- Traverse `e` and stores in the state `NameHashSet` any let-declaration with index greater than `(← read)`. The context `Nat` is the position of `xs[0]` in the local context. -/ - collectLetDeclsFrom (e : Expr) : ReaderT Nat (StateRefT FVarIdHashSet MetaM) Unit := do - let rec visit (e : Expr) : MonadCacheT Expr Unit (ReaderT Nat (StateRefT FVarIdHashSet MetaM)) Unit := + collectLetDeclsFrom (e : Expr) : (StateRefT FVarIdHashSet MetaM) Unit := do + let rec visit (e : Expr) : MonadCacheT Expr Unit ((StateRefT FVarIdHashSet MetaM)) Unit := checkCache e fun _ => do match e with | .forallE _ d b _ => visit d; visit b @@ -457,7 +455,7 @@ where | .proj _ _ b => visit b | .fvar fvarId => let localDecl ← fvarId.getDecl - if localDecl.isLet && localDecl.index > (← read) then + if localDecl.isLet && !mvarLCtx.contains localDecl.fvarId then modify fun s => s.insert localDecl.fvarId | _ => pure () visit (← instantiateMVars e) |>.run @@ -468,29 +466,27 @@ where or equal to the position of `xs.back` in the local context. The `Nat` context `(← read)` is the position of `xs[0]` in the local context. -/ - collectLetDepsAux : Nat → ReaderT Nat (StateRefT FVarIdHashSet MetaM) Unit - | 0 => return () + collectLetDepsAux : Nat → StateRefT FVarIdHashSet MetaM Unit + | 0 => return | i+1 => do - if i+1 == (← read) then - return () - else - match (← getLCtx).getAt? (i+1) with - | none => collectLetDepsAux i - | some localDecl => - if (← get).contains localDecl.fvarId then - collectLetDeclsFrom localDecl.type - match localDecl.value? with - | some val => collectLetDeclsFrom val - | _ => pure () - collectLetDepsAux i + match (← getLCtx).getAt? (i+1) with + | none => collectLetDepsAux i + | some localDecl => + if mvarLCtx.contains localDecl.fvarId then + return + else if (← get).contains localDecl.fvarId then + collectLetDeclsFrom localDecl.type + match localDecl.value? with + | some val => collectLetDeclsFrom val + | _ => pure () + collectLetDepsAux i /-- Computes the set `ys`. It is a set of `FVarId`s, -/ collectLetDeps : MetaM FVarIdHashSet := do let lctx ← getLCtx - let start := lctx.getFVar! xs[0]! |>.index let stop := lctx.getFVar! xs.back! |>.index let s := xs.foldl (init := {}) fun s x => s.insert x.fvarId! - let (_, s) ← collectLetDepsAux stop |>.run start |>.run s + let (_, s) ← collectLetDepsAux stop |>.run s return s /-- Computes the array `ys` containing let-decls between `xs[0]` and `xs.back` that @@ -499,10 +495,9 @@ where let lctx ← getLCtx let s ← collectLetDeps /- Convert `s` into the array `ys` -/ - let start := lctx.getFVar! xs[0]! |>.index let stop := lctx.getFVar! xs.back! |>.index let mut ys := #[] - for i in [start:stop+1] do + for i in [:stop+1] do match lctx.getAt? i with | none => pure () | some localDecl => @@ -801,10 +796,10 @@ mutual If `ctxApprox` is true, then we solve this case by creating a fresh metavariable ?n with the correct scope, an assigning `?m := fun _ ... _ => ?n` -/ partial def assignToConstFun (mvar : Expr) (numArgs : Nat) (newMVar : Expr) : MetaM Bool := do - let mvarType ← inferType mvar - forallBoundedTelescope mvarType numArgs fun xs _ => do + let mvarDecl ← mvar.mvarId!.getDecl + forallBoundedTelescope mvarDecl.type numArgs fun xs _ => do if xs.size != numArgs then return false - let some v ← mkLambdaFVarsWithLetDeps xs newMVar | return false + let some v ← mkLambdaFVarsWithLetDeps xs newMVar mvarDecl.lctx | return false let some v ← checkAssignmentAux mvar.mvarId! #[] false v | return false checkTypesAndAssign mvar v @@ -1130,7 +1125,7 @@ private def assignConst (mvar : Expr) (numArgs : Nat) (v : Expr) : MetaM Bool := if xs.size != numArgs then pure false else - let some v ← mkLambdaFVarsWithLetDeps xs v | pure false + let some v ← mkLambdaFVarsWithLetDeps xs v mvarDecl.lctx | pure false match (← checkAssignment mvar.mvarId! #[] v) with | none => pure false | some v => @@ -1171,19 +1166,19 @@ private partial def processConstApprox (mvar : Expr) (args : Array Expr) (patter if xs.size != suffixSize then defaultCase else - let some v ← mkLambdaFVarsWithLetDeps xs v | defaultCase + let some v ← mkLambdaFVarsWithLetDeps xs v mvarDecl.lctx | defaultCase let rec go (argsPrefix : Array Expr) (v : Expr) : MetaM Bool := do trace[Meta.isDefEq] "processConstApprox.go {mvar} {argsPrefix} := {v}" let rec cont : MetaM Bool := do if argsPrefix.isEmpty then defaultCase else - let some v ← mkLambdaFVarsWithLetDeps #[argsPrefix.back!] v | defaultCase + let some v ← mkLambdaFVarsWithLetDeps #[argsPrefix.back!] v mvarDecl.lctx | defaultCase go argsPrefix.pop v match (← checkAssignment mvarId argsPrefix v) with | none => cont | some vNew => - let some vNew ← mkLambdaFVarsWithLetDeps argsPrefix vNew | cont + let some vNew ← mkLambdaFVarsWithLetDeps argsPrefix vNew mvarDecl.lctx | cont if argsPrefix.any (fun arg => mvarDecl.lctx.containsFVar arg) then /- We need to type check `vNew` because abstraction using `mkLambdaFVars` may have produced a type incorrect term. See discussion at A2 -/ @@ -1227,7 +1222,7 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool : | none => useFOApprox args | some v => do trace[Meta.isDefEq.assign.beforeMkLambda] "{mvar} {args} := {v}" - let some v ← mkLambdaFVarsWithLetDeps args v | return false + let some v ← mkLambdaFVarsWithLetDeps args v mvarDecl.lctx | return false if args.any (fun arg => mvarDecl.lctx.containsFVar arg) then /- We need to type check `v` because abstraction using `mkLambdaFVars` may have produced a type incorrect term. See discussion at A2 -/