Skip to content

Commit 0a45925

Browse files
committed
refactor a bunch
1 parent c28c803 commit 0a45925

File tree

2 files changed

+133
-149
lines changed

2 files changed

+133
-149
lines changed

Std/Tactic/Pattern/Utils.lean

Lines changed: 95 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -1,163 +1,120 @@
11
/-
22
Copyright (c) 2023 Anand Rao Tadipatri. All rights reserved.
33
Released under Apache 2.0 license as described in the file LICENSE.
4-
Authors: Anand Rao Tadipatri
4+
Authors: Anand Rao Tadipatri, Jovan Gerbscheid
55
-/
6-
import Lean.Elab.Term
7-
import Lean.Elab.Tactic
8-
import Lean.SubExpr
9-
import Lean.Meta.ExprLens
10-
import Lean.Meta.KAbstract
11-
import Lean.HeadIndex
6+
import Lean.Meta
7+
8+
namespace Std.Tactic.Pattern
129

1310
open Lean Meta Elab Tactic
1411

1512
/-!
1613
17-
Basic utilities for tactics that target goal locations through patterns and their occurrences.
18-
19-
The code here include:
20-
- Functions for expanding syntax for patterns and occurrences into their corresponding expressions
21-
- Code for generating and finding the occurrences of patterns in expressions
22-
23-
The idea of referring to sub-expressions via patterns and occurrences is due to Yaël Dillies.
14+
We define the `patternLocation` syntax, which specifies one or more subexpressions
15+
in the goal, using a pattern and an optional `occs` argument.
2416
2517
-/
2618
open Parser.Tactic Conv
2719

28-
/-- Refer to a set of subexpression by specifying a pattern.
29-
30-
For example, if hypothesis `h` says that `1 + (2 + 3) = 1 + (2 + 3)`, then
31-
`(occs := 2 3) _ + _ at h` refers to the two expression `2 + 3`,
32-
because it first skips `1 + (2 + 3)`, and matches with `2 + 3`,
33-
which instantiates the pattern to be `2 + 3`, so the next match is the
34-
second instance of `2 + 3`. -/
35-
syntax patternLocation := optional(occs) term optional(location)
36-
37-
/--
38-
Elaborate a pattern as an `AbstractMVarsResult`.
39-
This follows code from `Lean/Elab/Tactic/Conv/Pattern.lean`. -/
20+
/-- Refer to a set of subexpression by specifying a pattern and occurrences.
21+
22+
For example, if hypothesis `h` says that `a + (b + c) = a + (b + c)`, then
23+
`(occs := 2 3) _ + _ at h` refers to the two occurrences of `b + c`,
24+
because it first skips `a + (b + c)`, and then matches with `b + c`,
25+
which instantiates the pattern to be `b + c`, so the next match is the
26+
second occurrence of `b + c`. -/
27+
syntax patternLocation := (occs)? term (location)?
28+
29+
30+
/-- A structure containing the information provided by the `patternLocation` syntax. -/
31+
structure PatternLocation where
32+
/-- The occurences of the pattern in the target. -/
33+
occs : Option (Array Nat)
34+
/-- The pattern itself. -/
35+
pattern : AbstractMVarsResult
36+
/-- The location in the goal. -/
37+
loc : Location
38+
39+
/-- Get the pattern occurrences as a `Occurrences`. -/
40+
def PatternLocation.occurrences (p : PatternLocation) : Occurrences :=
41+
match p.occs with
42+
| none => .all
43+
| some arr => .pos arr.toList
44+
45+
/-- Elaborate a pattern expression.
46+
See elaboration of `Lean.Parser.Tactic.Conv.pattern`. -/
4047
def expandPattern (p : Syntax) : TermElabM AbstractMVarsResult :=
4148
withReader (fun ctx => { ctx with ignoreTCFailures := true, errToSorry := false }) <|
4249
Term.withoutModifyingElabMetaStateWithInfo <| withRef p do
4350
abstractMVars (← Term.elabTerm p none)
4451

45-
/-- Elaborate `occs` syntax as `Occurrences`. -/
46-
def expandOptOccs (stx : Syntax) : TermElabM Occurrences := do
47-
if stx.isNone then
48-
return .all
49-
match stx[0] with
50-
| `(occs| (occs := *)) => return .all
52+
/-- Elaborate `occs` syntax. -/
53+
def expandOptOccs (stx : Option (TSyntax ``occs)) : TermElabM (Option (Array Nat)) := do
54+
let some stx := stx | return none
55+
match stx with
56+
| `(occs| (occs := *)) => return none
5157
| `(occs| (occs := $ids*)) =>
52-
return .pos <| Array.toList <| ← ids.mapM fun id =>
58+
some <$> ids.mapM fun id =>
5359
let n := id.toNat
5460
if n == 0 then
5561
throwErrorAt id "positive integer expected"
5662
else return n
57-
| _ => throwError m! "{stx}"
58-
59-
/-- Elaborate the occurrences, pattern and location in a `patternLocation`. -/
60-
def expandPatternLocation (stx : Syntax) : TermElabM (Occurrences × AbstractMVarsResult × Location) := do
61-
let occs ← expandOptOccs stx[0]
62-
let pattern ← expandPattern stx[1]
63-
let loc := expandOptLocation stx[2]
64-
return (occs, pattern, loc)
65-
section Expand
66-
67-
68-
69-
end Expand
70-
71-
section PatternsAndOccurrences
72-
73-
/-- The pattern at a given position in an expression.
74-
Variables under binders are turned into meta-variables in the pattern. -/
75-
def SubExpr.patternAt (p : SubExpr.Pos) (root : Expr) : MetaM Expr := do
76-
let e ← Core.viewSubexpr p root
77-
let binders ← Core.viewBinders p root
78-
let mvars ← binders.mapM fun (name, type) =>
79-
mkFreshExprMVar type (userName := name)
80-
return e.instantiateRev mvars
81-
82-
/-- Finds the occurrence number of the pattern in the expression
83-
that matches the sub-expression at the specified position.
84-
This follows the code of `kabstract` from Lean core. -/
85-
def findMatchingOccurrence (position : SubExpr.Pos) (root : Expr) (pattern : Expr) : MetaM Nat := do
86-
let root ← instantiateMVars root
87-
unless ← isDefEq pattern (← SubExpr.patternAt position root) do
88-
throwError s!"The specified pattern does not match the pattern at position {position}."
89-
let pattern ← instantiateMVars pattern
90-
let pHeadIdx := pattern.toHeadIndex
91-
let pNumArgs := pattern.headNumArgs
92-
let rec
93-
/-- The recursive step in the expression traversal to search for matching occurrences. -/
94-
visit (e : Expr) (p : SubExpr.Pos) (offset : Nat) := do
95-
let visitChildren : Unit → StateRefT Nat (OptionT MetaM) Unit := fun _ => do
63+
| _ => throwUnsupportedSyntax
64+
65+
/-- Elaborate `patternLocation` syntax. -/
66+
def expandPatternLocation (stx : Syntax) : TacticM PatternLocation :=
67+
withMainContext do
68+
match stx with
69+
| `(patternLocation| $[$a]? $pat $[$loc]?) =>
70+
let occs ← expandOptOccs a
71+
let pattern ← expandPattern pat
72+
let loc := match loc with
73+
| some loc => expandLocation loc
74+
| none => Location.targets #[] true
75+
return { occs, pattern, loc }
76+
| _ => throwUnsupportedSyntax
77+
78+
79+
/-- return the subexpression positions that `kabstract` can abstract -/
80+
def kabstractPositions (p e : Expr) : MetaM (Array SubExpr.Pos) := do
81+
let pHeadIdx := p.toHeadIndex
82+
let pNumArgs := p.headNumArgs
83+
let rec visit (e : Expr) (pos : SubExpr.Pos) (positions : Array SubExpr.Pos) :
84+
MetaM (Array SubExpr.Pos) := do
85+
let visitChildren : Array SubExpr.Pos → MetaM (Array SubExpr.Pos) :=
9686
match e with
97-
| .app f a => do
98-
visit f p.pushAppFn offset <|>
99-
visit a p.pushAppArg offset
100-
| .mdata _ b => visit b p offset
101-
| .proj _ _ b => visit b p.pushProj offset
102-
| .letE _ t v b _ => do
103-
visit t p.pushLetVarType offset <|>
104-
visit v p.pushLetValue offset <|>
105-
visit b p.pushLetBody (offset+1)
106-
| .lam _ d b _ => do
107-
visit d p.pushBindingDomain offset <|>
108-
visit b p.pushBindingBody (offset+1)
109-
| .forallE _ d b _ => do
110-
visit d p.pushBindingDomain offset <|>
111-
visit b p.pushBindingBody (offset+1)
112-
| _ => failure
113-
if e.hasLooseBVars then
114-
visitChildren ()
115-
else if e.toHeadIndex != pHeadIdx || e.headNumArgs != pNumArgs then
116-
visitChildren ()
117-
else if (← isDefEq e pattern) then
118-
let i ← get
119-
set (i+1)
120-
if p = position then
121-
return ()
122-
else
123-
visitChildren ()
87+
| .app f a => visit f pos.pushAppFn
88+
>=> visit a pos.pushAppArg
89+
| .mdata _ b => visit b pos
90+
| .proj _ _ b => visit b pos.pushProj
91+
| .letE _ t v b _ => visit t pos.pushLetVarType
92+
>=> visit v pos.pushLetValue
93+
>=> visit b pos.pushLetBody
94+
| .lam _ d b _ => visit d pos.pushBindingDomain
95+
>=> visit b pos.pushBindingBody
96+
| .forallE _ d b _ => visit d pos.pushBindingDomain
97+
>=> visit b pos.pushBindingBody
98+
| _ => pure
99+
if e.hasLooseBVars || e.toHeadIndex != pHeadIdx || e.headNumArgs != pNumArgs then
100+
visitChildren positions
124101
else
125-
visitChildren ()
126-
let .some (_, occ) ← visit root .root 0 |>.run 0 |
127-
throwError s!"Could not find pattern at specified position {position}."
128-
return occ
129-
130-
/-- Finds the occurrence number of the pattern at
131-
the specified position in the whole expression. -/
132-
def findOccurrence (position : SubExpr.Pos) (root : Expr) : MetaM Nat := do
133-
let pattern ← SubExpr.patternAt position root
134-
findMatchingOccurrence position root pattern
135-
136-
end PatternsAndOccurrences
137-
138-
/-- Substitute occurrences of a pattern in an expression with the result of `replacement`. -/
139-
def substitute (e : Expr) (pattern : AbstractMVarsResult) (occs : Occurrences)
140-
(replacement : Expr → MetaM Expr) (withoutErr : Bool := true) : MetaM Expr := do
141-
let (_, _, p) ← openAbstractMVarsResult pattern
142-
let eAbst ← kabstract e p occs
143-
unless eAbst.hasLooseBVars || withoutErr do
144-
throwError m!"Failed to find instance of pattern {indentExpr p} in {indentExpr e}."
145-
instantiateMVars <| Expr.instantiate1 eAbst (← replacement p)
146-
147-
/-- Replace a pattern at the specified locations with the value of `replacement`,
148-
which is assumed to be definitionally equal to the original pattern. -/
149-
def replaceOccurrencesDefEq (tacticName : Name) (location : Location) (occurrences : Occurrences)
150-
(pattern : AbstractMVarsResult) (replacement : Expr → MetaM Expr) : TacticM Unit := do
151-
let goal ← getMainGoal
152-
goal.withContext do
153-
withLocation location
154-
(atLocal := fun fvarId => do
155-
let hypType ← fvarId.getType
156-
let newGoal ← goal.replaceLocalDeclDefEq fvarId <| ←
157-
substitute hypType pattern occurrences replacement
158-
replaceMainGoal [newGoal])
159-
(atTarget := do
160-
let newGoal ← goal.replaceTargetDefEq <| ←
161-
substitute (← goal.getType) pattern occurrences replacement
162-
replaceMainGoal [newGoal])
163-
(failed := (throwTacticEx tacticName · m!"Failed to run tactic {tacticName}."))
102+
let mctx ← getMCtx
103+
if (← isDefEq e p) then
104+
setMCtx mctx
105+
visitChildren (positions.push pos)
106+
else
107+
visitChildren positions
108+
visit e .root #[]
109+
110+
/-- return the pattern and occurrences specifying position `pos` in target `e`. -/
111+
def patternAndIndex (pos : SubExpr.Pos) (e : Expr) : MetaM (Expr × Option Nat) := do
112+
let e ← instantiateMVars e
113+
let pattern ← Core.viewSubexpr pos e
114+
if pattern.hasLooseBVars then
115+
throwError "the subexpression contains loose bound variables"
116+
let positions ← kabstractPositions pattern e
117+
if positions.size == 1 then
118+
return (pattern, none)
119+
let some index := positions.findIdx? (· == pos) | unreachable!
120+
return (pattern, some (index + 1))

Std/Tactic/PatternTactics/Unfold.lean

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,32 @@
11
/-
2-
Copyright (c) 2023 J. W. Gerbscheid, Anand Rao Tadipatri. All rights reserved.
2+
Copyright (c) 2023 Jovan Gerbscheid, Anand Rao Tadipatri. All rights reserved.
33
Released under Apache 2.0 license as described in the file LICENSE.
4-
Authors: J. W. Gerbscheid, Anand Rao Tadipatri
4+
Authors: Jovan Gerbscheid, Anand Rao Tadipatri
55
-/
66
import Std.Tactic.Pattern.Utils
77
import Lean.Parser.Tactic
88

9-
open Lean Meta
9+
namespace Std.Tactic.Unfold
10+
open Lean Meta Elab.Tactic Pattern
1011

1112
/-!
12-
1313
# Targeted unfolding
1414
1515
A tactic for definitionally unfolding expressions.
1616
The targeted sub-expression is selected using a pattern.
1717
18-
example use case:
18+
example use cases:
1919
```
2020
@[irreducible] def f (n : Nat) := n + 1
21-
example : ∀ n : Nat, n + 1 = f n := by
22-
unfold' f _ at ⊢ --∀ (n : Nat), n + 1 = n + 1
23-
intro n
21+
example (n : Nat) : n + 1 = f n := by
22+
unfold' f n
23+
rfl
24+
25+
example (n m : Nat) : f n + f m = f n + (m+1) := by
26+
unfold' (occs := 2) f _
2427
rfl
2528
```
29+
2630
-/
2731

2832
/-- If the head of the expression is a projection, reduce the projection. -/
@@ -58,7 +62,30 @@ def replaceByDef (e : Expr) : MetaM Expr :=
5862

5963
throwError m! "Could not find a definition for {e}."
6064

61-
open Elab Tactic Parser Tactic Conv
65+
/-- Replace a pattern at the specified locations with the value of `replacement`,
66+
which is assumed to be definitionally equal to the original pattern. -/
67+
def replaceOccurrencesDefEq (tacticName : Name) (pattern : Pattern.PatternLocation)
68+
(replacement : Expr → MetaM Expr) : TacticM Unit := do
69+
let goal ← getMainGoal
70+
goal.withContext do
71+
withLocation pattern.loc
72+
(atLocal := fun fvarId => do
73+
let hypType ← fvarId.getType
74+
let newGoal ← goal.replaceLocalDeclDefEq fvarId (← substitute hypType)
75+
replaceMainGoal [newGoal])
76+
(atTarget := do
77+
let newGoal ← goal.replaceTargetDefEq (← substitute (← goal.getType))
78+
replaceMainGoal [newGoal])
79+
(failed := (throwTacticEx tacticName · ""))
80+
where
81+
/-- Substitute occurrences of a pattern in an expression with the result of `replacement`. -/
82+
substitute (e : Expr) : MetaM Expr := do
83+
let (_, _, p) ← openAbstractMVarsResult pattern.pattern
84+
let eAbst ← kabstract e p pattern.occurrences
85+
unless eAbst.hasLooseBVars do
86+
throwError m! "did not find instance of pattern {p} in target {indentExpr e}"
87+
return eAbst.instantiate1 (← replacement (← instantiateMVars p))
88+
6289

6390
/-- Unfold the selected expression in one of the following ways:
6491
@@ -72,5 +99,5 @@ Note that we always reduce a projection after unfolding a constant,
7299
so that `@Add.add ℕ instAddNat a b` gives `Nat.add a b` instead of `instAddNat.1 a b`.
73100
-/
74101
elab "unfold'" loc:patternLocation : tactic => withMainContext do
75-
let (occs, pattern, loc) ← expandPatternLocation loc
76-
replaceOccurrencesDefEq `unfold' loc occs pattern replaceByDef
102+
let pattern ← expandPatternLocation loc
103+
replaceOccurrencesDefEq `unfold' pattern replaceByDef

0 commit comments

Comments
 (0)