Skip to content

Commit b06cc0d

Browse files
committed
fix: do not look for split-tactic candidates in proof terms
We also skip candidates in implicit arguments and binders of lambda expressions See new test.
1 parent e28f8fa commit b06cc0d

File tree

4 files changed

+184
-48
lines changed

4 files changed

+184
-48
lines changed

src/Lean/Meta/Basic.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ structure ParamInfo where
129129
hasFwdDeps : Bool := false
130130
/-- `backDeps` contains the backwards dependencies. That is, the (0-indexed) position of previous parameters that this one depends on. -/
131131
backDeps : Array Nat := #[]
132-
/-- `isProp` is true if the parameter is always a proposition. -/
132+
/-- `isProp` is true if the parameter type is always a proposition. -/
133133
isProp : Bool := false
134134
/--
135135
`isDecInst` is true if the parameter's type is of the form `Decidable ...`.

src/Lean/Meta/Tactic/Split.lean

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def getSimpMatchContext : MetaM Simp.Context :=
1717
return {
1818
simpTheorems := {}
1919
congrTheorems := (← getSimpCongrTheorems)
20-
config := { Simp.neutralConfig with dsimp := false, implicitDefEqProofs := true }
20+
config := { Simp.neutralConfig with dsimp := false }
2121
}
2222

2323
def simpMatch (e : Expr) : MetaM Simp.Result := do
@@ -270,7 +270,7 @@ def mkDiscrGenErrorMsg (e : Expr) : MessageData :=
270270
def throwDiscrGenError (e : Expr) : MetaM α :=
271271
throwError (mkDiscrGenErrorMsg e)
272272

273-
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do
273+
def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := mvarId.withContext do
274274
let some app ← matchMatcherApp? e | throwError "internal error in `split` tactic: match application expected{indentExpr e}\nthis error typically occurs when the `split` tactic internal functions have been used in a new meta-program"
275275
let matchEqns ← Match.getEquationsFor app.matcherName
276276
let mvarIds ← applyMatchSplitter mvarId app.matcherName app.matcherLevels app.params app.discrs
@@ -279,43 +279,14 @@ def splitMatch (mvarId : MVarId) (e : Expr) : MetaM (List MVarId) := do
279279
return (i+1, mvarId::mvarIds)
280280
return mvarIds.reverse
281281

282-
/-- Return an `if-then-else` or `match-expr` to split. -/
283-
partial def findSplit? (env : Environment) (e : Expr) (splitIte := true) (exceptionSet : ExprSet := {}) : Option Expr :=
284-
go e
285-
where
286-
go (e : Expr) : Option Expr :=
287-
if let some target := e.find? isCandidate then
288-
if e.isIte || e.isDIte then
289-
let cond := target.getArg! 1 5
290-
-- Try to find a nested `if` in `cond`
291-
go cond |>.getD target
292-
else
293-
some target
294-
else
295-
none
296-
297-
isCandidate (e : Expr) : Bool := Id.run do
298-
if exceptionSet.contains e then
299-
false
300-
else if splitIte && (e.isIte || e.isDIte) then
301-
!(e.getArg! 1 5).hasLooseBVars
302-
else if let some info := isMatcherAppCore? env e then
303-
let args := e.getAppArgs
304-
for i in [info.getFirstDiscrPos : info.getFirstDiscrPos + info.numDiscrs] do
305-
if args[i]!.hasLooseBVars then
306-
return false
307-
return true
308-
else
309-
false
310-
311282
end Split
312283

313284
open Split
314285

315-
partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do
286+
partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (List MVarId)) := commitWhenSome? do mvarId.withContext do
316287
let target ← instantiateMVars (← mvarId.getType)
317288
let rec go (badCases : ExprSet) : MetaM (Option (List MVarId)) := do
318-
if let some e := findSplit? (← getEnv) target splitIte badCases then
289+
if let some e findSplit? target (if splitIte then .both else .match) badCases then
319290
if e.isIte || e.isDIte then
320291
return (← splitIfTarget? mvarId).map fun (s₁, s₂) => [s₁.mvarId, s₂.mvarId]
321292
else
@@ -334,7 +305,7 @@ partial def splitTarget? (mvarId : MVarId) (splitIte := true) : MetaM (Option (L
334305

335306
def splitLocalDecl? (mvarId : MVarId) (fvarId : FVarId) : MetaM (Option (List MVarId)) := commitWhenSome? do
336307
mvarId.withContext do
337-
if let some e := findSplit? (← getEnv) (← instantiateMVars (← inferType (mkFVar fvarId))) then
308+
if let some e findSplit? (← instantiateMVars (← inferType (mkFVar fvarId))) then
338309
if e.isIte || e.isDIte then
339310
return (← splitIfLocalDecl? mvarId fvarId).map fun (mvarId₁, mvarId₂) => [mvarId₁, mvarId₂]
340311
else

src/Lean/Meta/Tactic/SplitIf.lean

Lines changed: 107 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,110 @@ import Lean.Meta.Tactic.Cases
99
import Lean.Meta.Tactic.Simp.Main
1010

1111
namespace Lean.Meta
12+
13+
inductive SplitKind where
14+
| ite | match | both
15+
16+
def SplitKind.considerIte : SplitKind → Bool
17+
| .ite | .both => true
18+
| _ => false
19+
20+
def SplitKind.considerMatch : SplitKind → Bool
21+
| .match | .both => true
22+
| _ => false
23+
24+
namespace FindSplitImpl
25+
26+
structure Context where
27+
exceptionSet : ExprSet := {}
28+
kind : SplitKind := .both
29+
30+
unsafe abbrev FindM := ReaderT Context $ StateT (PtrSet Expr) MetaM
31+
32+
private def isCandidate (env : Environment) (ctx : Context) (e : Expr) : Bool := Id.run do
33+
if ctx.exceptionSet.contains e then
34+
return false
35+
if ctx.kind.considerIte && (e.isIte || e.isDIte) then
36+
return !(e.getArg! 1 5).hasLooseBVars
37+
if ctx.kind.considerMatch then
38+
if let some info := isMatcherAppCore? env e then
39+
let args := e.getAppArgs
40+
for i in [info.getFirstDiscrPos : info.getFirstDiscrPos + info.numDiscrs] do
41+
if args[i]!.hasLooseBVars then
42+
return false
43+
return true
44+
return false
45+
46+
@[inline] unsafe def checkVisited (e : Expr) : OptionT FindM Unit := do
47+
if (← get).contains e then
48+
failure
49+
modify fun s => s.insert e
50+
51+
unsafe def visit (e : Expr) : OptionT FindM Expr := do
52+
checkVisited e
53+
if isCandidate (← getEnv) (← read) e then
54+
return e
55+
else
56+
-- We do not look for split candidates in proofs.
57+
unless e.hasLooseBVars do
58+
if (← isProof e) then
59+
failure
60+
match e with
61+
| .lam _ _ b _ | .proj _ _ b -- We do not look for split candidates in the binder of lambdas.
62+
| .mdata _ b => visit b
63+
| .forallE _ d b _ => visit d <|> visit b -- We want to look for candidates at `A → B`
64+
| .letE _ _ v b _ => visit v <|> visit b
65+
| .app .. => visitApp? e
66+
| _ => failure
67+
where
68+
visitApp? (e : Expr) : FindM (Option Expr) :=
69+
e.withApp fun f args => do
70+
let info ← getFunInfo f
71+
for u : i in [0:args.size] do
72+
let arg := args[i]
73+
if h : i < info.paramInfo.size then
74+
let info := info.paramInfo[i]
75+
unless info.isProp do
76+
if info.isExplicit then
77+
let some found ← visit arg | pure ()
78+
return found
79+
else
80+
let some found ← visit arg | pure ()
81+
return found
82+
visit f
83+
84+
end FindSplitImpl
85+
86+
/-- Return an `if-then-else` or `match-expr` to split. -/
87+
partial def findSplit? (e : Expr) (kind : SplitKind := .both) (exceptionSet : ExprSet := {}) : MetaM (Option Expr) := do
88+
go (← instantiateMVars e)
89+
where
90+
go (e : Expr) : MetaM (Option Expr) := do
91+
if let some target ← find? e then
92+
if e.isIte || e.isDIte then
93+
let cond := target.getArg! 1 5
94+
-- Try to find a nested `if` in `cond`
95+
return (← go cond).getD target
96+
else
97+
return some target
98+
else
99+
return none
100+
101+
find? (e : Expr) : MetaM (Option Expr) := do
102+
let some candidate ← unsafe FindSplitImpl.visit e { kind, exceptionSet } |>.run' mkPtrSet
103+
| return none
104+
trace[split.debug] "candidate:{indentExpr candidate}"
105+
return some candidate
106+
107+
/-- Return the condition and decidable instance of an `if` expression to case split. -/
108+
private partial def findIfToSplit? (e : Expr) : MetaM (Option (Expr × Expr)) := do
109+
if let some iteApp ← findSplit? e .ite then
110+
let cond := iteApp.getArg! 1 5
111+
let dec := iteApp.getArg! 2 5
112+
return (cond, dec)
113+
else
114+
return none
115+
12116
namespace SplitIf
13117

14118
builtin_initialize ext : LazyInitExtension MetaM Simp.Context ←
@@ -21,7 +125,7 @@ builtin_initialize ext : LazyInitExtension MetaM Simp.Context ←
21125
return {
22126
simpTheorems := #[s]
23127
congrTheorems := (← getSimpCongrTheorems)
24-
config := { Simp.neutralConfig with dsimp := false, implicitDefEqProofs := true }
128+
config := { Simp.neutralConfig with dsimp := false }
25129
}
26130

27131
/--
@@ -68,19 +172,9 @@ private def discharge? (numIndices : Nat) (useDecide : Bool) : Simp.Discharge :=
68172
def mkDischarge? (useDecide := false) : MetaM Simp.Discharge :=
69173
return discharge? (← getLCtx).numIndices useDecide
70174

71-
/-- Return the condition and decidable instance of an `if` expression to case split. -/
72-
private partial def findIfToSplit? (e : Expr) : Option (Expr × Expr) :=
73-
if let some iteApp := e.find? fun e => (e.isIte || e.isDIte) && !(e.getArg! 1 5).hasLooseBVars then
74-
let cond := iteApp.getArg! 1 5
75-
let dec := iteApp.getArg! 2 5
76-
-- Try to find a nested `if` in `cond`
77-
findIfToSplit? cond |>.getD (cond, dec)
78-
else
79-
none
80-
81-
def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := do
175+
def splitIfAt? (mvarId : MVarId) (e : Expr) (hName? : Option Name) : MetaM (Option (ByCasesSubgoal × ByCasesSubgoal)) := mvarId.withContext do
82176
let e ← instantiateMVars e
83-
if let some (cond, decInst) := findIfToSplit? e then
177+
if let some (cond, decInst) findIfToSplit? e then
84178
let hName ← match hName? with
85179
| none => mkFreshUserName `h
86180
| some hName => pure hName

tests/lean/run/splitIssue2.lean

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
namespace Batteries
2+
3+
/-- Union-find node type -/
4+
structure UFNode where
5+
/-- Parent of node -/
6+
parent : Nat
7+
8+
namespace UnionFind
9+
10+
/-- Parent of a union-find node, defaults to self when the node is a root -/
11+
def parentD (arr : Array UFNode) (i : Nat) : Nat :=
12+
if h : i < arr.size then (arr.get ⟨i, h⟩).parent else i
13+
14+
/-- Rank of a union-find node, defaults to 0 when the node is a root -/
15+
def rankD (arr : Array UFNode) (i : Nat) : Nat := 0
16+
17+
theorem parentD_of_not_lt : ¬i < arr.size → parentD arr i = i := sorry
18+
19+
theorem parentD_set {arr : Array UFNode} {x v i} :
20+
parentD (arr.set x v) i = if x.1 = i then v.parent else parentD arr i := by
21+
rw [parentD]
22+
sorry
23+
24+
end UnionFind
25+
26+
open UnionFind
27+
28+
structure UnionFind where
29+
arr : Array UFNode
30+
31+
namespace UnionFind
32+
33+
/-- Size of union-find structure. -/
34+
@[inline] abbrev size (self : UnionFind) := self.arr.size
35+
36+
/-- Parent of union-find node -/
37+
abbrev parent (self : UnionFind) (i : Nat) : Nat := parentD self.arr i
38+
39+
theorem parent_lt (self : UnionFind) (i : Nat) : self.parent i < self.size ↔ i < self.size :=
40+
sorry
41+
42+
/-- Rank of union-find node -/
43+
abbrev rank (self : UnionFind) (i : Nat) : Nat := rankD self.arr i
44+
45+
/-- Maximum rank of nodes in a union-find structure -/
46+
noncomputable def rankMax (self : UnionFind) := 0
47+
48+
/-- Root of a union-find node. -/
49+
def root (self : UnionFind) (x : Fin self.size) : Fin self.size :=
50+
let y := (self.arr.get x).parent
51+
if h : y = x then
52+
x
53+
else
54+
have : self.rankMax - self.rank (self.arr.get x).parent < self.rankMax - self.rank x :=
55+
sorry
56+
self.root ⟨y, sorry
57+
termination_by self.rankMax - self.rank x
58+
59+
/-- Root of a union-find node. Returns input if index is out of bounds. -/
60+
def rootD (self : UnionFind) (x : Nat) : Nat :=
61+
if h : x < self.size then self.root ⟨x, h⟩ else x
62+
63+
theorem rootD_parent (self : UnionFind) (x : Nat) : self.rootD (self.parent x) = self.rootD x := by
64+
simp only [rootD, Array.data_length, parent_lt]
65+
split
66+
· simp only [parentD, ↓reduceDIte, *]
67+
conv => rhs; rw [root]
68+
split
69+
· rw [root, dif_pos] <;> simp_all
70+
· simp
71+
· simp only [not_false_eq_true, parentD_of_not_lt, *]

0 commit comments

Comments
 (0)