Skip to content

Commit

Permalink
attempt number 2
Browse files Browse the repository at this point in the history
  • Loading branch information
JovanGerb committed Nov 6, 2024
1 parent 1218825 commit 73af511
Showing 1 changed file with 39 additions and 44 deletions.
83 changes: 39 additions & 44 deletions src/Lean/Meta/ExprDefEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ Recall that the method `checkAssignment` ensures `v` does not contain offending
This method assumes that for any `xs[i]` and `xs[j]` where `i < j`, we have that `index of xs[i]` < `index of xs[j]`.
where the index is the position in the local context.
-/
private partial def mkLambdaFVarsWithLetDeps (xs : Array Expr) (v : Expr) : MetaM (Option Expr) := do
private partial def mkLambdaFVarsWithLetDeps (xs : Array Expr) (v : Expr) (mvarLCtx : LocalContext) : MetaM (Option Expr) := do
if !(← hasLetDeclsInBetween) then
mkLambdaFVars xs v (etaReduce := true)
else
Expand All @@ -428,25 +428,23 @@ where
/-- Return true if there are let-declarions between `xs[0]` and `xs[xs.size-1]`.
We use it a quick-check to avoid the more expensive collection procedure. -/
hasLetDeclsInBetween : MetaM Bool := do
let check (lctx : LocalContext) : Bool := Id.run do
let start := lctx.getFVar! xs[0]! |>.index
let stop := lctx.getFVar! xs.back! |>.index
for i in [start+1:stop] do
match lctx.getAt? i with
| some localDecl =>
if localDecl.isLet then
return true
| _ => pure ()
return false
if xs.size <= 1 then
return false
else
return check (← getLCtx)
let rec check (i : Nat) (lctx : LocalContext) : Bool := Id.run do
if let some localDecl := lctx.getAt? i then
if mvarLCtx.contains localDecl.fvarId then
return false
if localDecl.isLet then
return true
match i with
| 0 => return false
| i+1 => check i lctx
let lctx ← getLCtx
let stop := lctx.getFVar! xs.back! |>.index
return check stop lctx

/-- Traverse `e` and stores in the state `NameHashSet` any let-declaration with index greater than `(← read)`.
The context `Nat` is the position of `xs[0]` in the local context. -/
collectLetDeclsFrom (e : Expr) : ReaderT Nat (StateRefT FVarIdHashSet MetaM) Unit := do
let rec visit (e : Expr) : MonadCacheT Expr Unit (ReaderT Nat (StateRefT FVarIdHashSet MetaM)) Unit :=
collectLetDeclsFrom (e : Expr) : (StateRefT FVarIdHashSet MetaM) Unit := do
let rec visit (e : Expr) : MonadCacheT Expr Unit ((StateRefT FVarIdHashSet MetaM)) Unit :=
checkCache e fun _ => do
match e with
| .forallE _ d b _ => visit d; visit b
Expand All @@ -457,7 +455,7 @@ where
| .proj _ _ b => visit b
| .fvar fvarId =>
let localDecl ← fvarId.getDecl
if localDecl.isLet && localDecl.index > (← read) then
if localDecl.isLet && !mvarLCtx.contains localDecl.fvarId then
modify fun s => s.insert localDecl.fvarId
| _ => pure ()
visit (← instantiateMVars e) |>.run
Expand All @@ -468,29 +466,27 @@ where
or equal to the position of `xs.back` in the local context.
The `Nat` context `(← read)` is the position of `xs[0]` in the local context.
-/
collectLetDepsAux : Nat → ReaderT Nat (StateRefT FVarIdHashSet MetaM) Unit
| 0 => return ()
collectLetDepsAux : Nat → StateRefT FVarIdHashSet MetaM Unit
| 0 => return
| i+1 => do
if i+1 == (← read) then
return ()
else
match (← getLCtx).getAt? (i+1) with
| none => collectLetDepsAux i
| some localDecl =>
if (← get).contains localDecl.fvarId then
collectLetDeclsFrom localDecl.type
match localDecl.value? with
| some val => collectLetDeclsFrom val
| _ => pure ()
collectLetDepsAux i
match (← getLCtx).getAt? (i+1) with
| none => collectLetDepsAux i
| some localDecl =>
if mvarLCtx.contains localDecl.fvarId then
return
else if (← get).contains localDecl.fvarId then
collectLetDeclsFrom localDecl.type
match localDecl.value? with
| some val => collectLetDeclsFrom val
| _ => pure ()
collectLetDepsAux i

/-- Computes the set `ys`. It is a set of `FVarId`s, -/
collectLetDeps : MetaM FVarIdHashSet := do
let lctx ← getLCtx
let start := lctx.getFVar! xs[0]! |>.index
let stop := lctx.getFVar! xs.back! |>.index
let s := xs.foldl (init := {}) fun s x => s.insert x.fvarId!
let (_, s) ← collectLetDepsAux stop |>.run start |>.run s
let (_, s) ← collectLetDepsAux stop |>.run s
return s

/-- Computes the array `ys` containing let-decls between `xs[0]` and `xs.back` that
Expand All @@ -499,10 +495,9 @@ where
let lctx ← getLCtx
let s ← collectLetDeps
/- Convert `s` into the array `ys` -/
let start := lctx.getFVar! xs[0]! |>.index
let stop := lctx.getFVar! xs.back! |>.index
let mut ys := #[]
for i in [start:stop+1] do
for i in [:stop+1] do
match lctx.getAt? i with
| none => pure ()
| some localDecl =>
Expand Down Expand Up @@ -801,10 +796,10 @@ mutual
If `ctxApprox` is true, then we solve this case by creating a fresh metavariable ?n with the correct scope,
an assigning `?m := fun _ ... _ => ?n` -/
partial def assignToConstFun (mvar : Expr) (numArgs : Nat) (newMVar : Expr) : MetaM Bool := do
let mvarTypeinferType mvar
forallBoundedTelescope mvarType numArgs fun xs _ => do
let mvarDecl ← mvar.mvarId!.getDecl
forallBoundedTelescope mvarDecl.type numArgs fun xs _ => do
if xs.size != numArgs then return false
let some v ← mkLambdaFVarsWithLetDeps xs newMVar | return false
let some v ← mkLambdaFVarsWithLetDeps xs newMVar mvarDecl.lctx | return false
let some v ← checkAssignmentAux mvar.mvarId! #[] false v | return false
checkTypesAndAssign mvar v

Expand Down Expand Up @@ -1130,7 +1125,7 @@ private def assignConst (mvar : Expr) (numArgs : Nat) (v : Expr) : MetaM Bool :=
if xs.size != numArgs then
pure false
else
let some v ← mkLambdaFVarsWithLetDeps xs v | pure false
let some v ← mkLambdaFVarsWithLetDeps xs v mvarDecl.lctx | pure false
match (← checkAssignment mvar.mvarId! #[] v) with
| none => pure false
| some v =>
Expand Down Expand Up @@ -1171,19 +1166,19 @@ private partial def processConstApprox (mvar : Expr) (args : Array Expr) (patter
if xs.size != suffixSize then
defaultCase
else
let some v ← mkLambdaFVarsWithLetDeps xs v | defaultCase
let some v ← mkLambdaFVarsWithLetDeps xs v mvarDecl.lctx | defaultCase
let rec go (argsPrefix : Array Expr) (v : Expr) : MetaM Bool := do
trace[Meta.isDefEq] "processConstApprox.go {mvar} {argsPrefix} := {v}"
let rec cont : MetaM Bool := do
if argsPrefix.isEmpty then
defaultCase
else
let some v ← mkLambdaFVarsWithLetDeps #[argsPrefix.back!] v | defaultCase
let some v ← mkLambdaFVarsWithLetDeps #[argsPrefix.back!] v mvarDecl.lctx | defaultCase
go argsPrefix.pop v
match (← checkAssignment mvarId argsPrefix v) with
| none => cont
| some vNew =>
let some vNew ← mkLambdaFVarsWithLetDeps argsPrefix vNew | cont
let some vNew ← mkLambdaFVarsWithLetDeps argsPrefix vNew mvarDecl.lctx | cont
if argsPrefix.any (fun arg => mvarDecl.lctx.containsFVar arg) then
/- We need to type check `vNew` because abstraction using `mkLambdaFVars` may have produced
a type incorrect term. See discussion at A2 -/
Expand Down Expand Up @@ -1227,7 +1222,7 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool :
| none => useFOApprox args
| some v => do
trace[Meta.isDefEq.assign.beforeMkLambda] "{mvar} {args} := {v}"
let some v ← mkLambdaFVarsWithLetDeps args v | return false
let some v ← mkLambdaFVarsWithLetDeps args v mvarDecl.lctx | return false
if args.any (fun arg => mvarDecl.lctx.containsFVar arg) then
/- We need to type check `v` because abstraction using `mkLambdaFVars` may have produced
a type incorrect term. See discussion at A2 -/
Expand Down

0 comments on commit 73af511

Please sign in to comment.