|
1 | 1 | /-
|
2 | 2 | Copyright (c) 2023 Anand Rao Tadipatri. All rights reserved.
|
3 | 3 | Released under Apache 2.0 license as described in the file LICENSE.
|
4 |
| -Authors: Anand Rao Tadipatri |
| 4 | +Authors: Anand Rao Tadipatri, Jovan Gerbscheid |
5 | 5 | -/
|
6 |
| -import Lean.Elab.Term |
7 |
| -import Lean.Elab.Tactic |
8 |
| -import Lean.SubExpr |
9 |
| -import Lean.Meta.ExprLens |
10 |
| -import Lean.Meta.KAbstract |
11 |
| -import Lean.HeadIndex |
| 6 | +import Lean.Meta |
| 7 | + |
| 8 | +namespace Std.Tactic.Pattern |
12 | 9 |
|
13 | 10 | open Lean Meta Elab Tactic
|
14 | 11 |
|
15 | 12 | /-!
|
16 | 13 |
|
17 |
| -Basic utilities for tactics that target goal locations through patterns and their occurrences. |
18 |
| -
|
19 |
| -The code here include: |
20 |
| -- Functions for expanding syntax for patterns and occurrences into their corresponding expressions |
21 |
| -- Code for generating and finding the occurrences of patterns in expressions |
22 |
| -
|
23 |
| -The idea of referring to sub-expressions via patterns and occurrences is due to Yaël Dillies. |
| 14 | +We define the `patternLocation` syntax, which specifies one or more subexpressions |
| 15 | +in the goal, using a pattern and an optional `occs` argument. |
24 | 16 |
|
25 | 17 | -/
|
26 | 18 | open Parser.Tactic Conv
|
27 | 19 |
|
28 |
| -/-- Refer to a set of subexpression by specifying a pattern. |
29 |
| -
|
30 |
| -For example, if hypothesis `h` says that `1 + (2 + 3) = 1 + (2 + 3)`, then |
31 |
| -`(occs := 2 3) _ + _ at h` refers to the two expression `2 + 3`, |
32 |
| -because it first skips `1 + (2 + 3)`, and matches with `2 + 3`, |
33 |
| -which instantiates the pattern to be `2 + 3`, so the next match is the |
34 |
| -second instance of `2 + 3`. -/ |
35 |
| -syntax patternLocation := optional(occs) term optional(location) |
36 |
| - |
37 |
| -/-- |
38 |
| -Elaborate a pattern as an `AbstractMVarsResult`. |
39 |
| -This follows code from `Lean/Elab/Tactic/Conv/Pattern.lean`. -/ |
| 20 | +/-- Refer to a set of subexpression by specifying a pattern and occurrences. |
| 21 | +
|
| 22 | +For example, if hypothesis `h` says that `a + (b + c) = a + (b + c)`, then |
| 23 | +`(occs := 2 3) _ + _ at h` refers to the two occurrences of `b + c`, |
| 24 | +because it first skips `a + (b + c)`, and then matches with `b + c`, |
| 25 | +which instantiates the pattern to be `b + c`, so the next match is the |
| 26 | +second occurrence of `b + c`. -/ |
| 27 | +syntax patternLocation := (occs)? term (location)? |
| 28 | + |
| 29 | + |
| 30 | +/-- A structure containing the information provided by the `patternLocation` syntax. -/ |
| 31 | +structure PatternLocation where |
| 32 | + /-- The occurences of the pattern in the target. -/ |
| 33 | + occs : Option (Array Nat) |
| 34 | + /-- The pattern itself. -/ |
| 35 | + pattern : AbstractMVarsResult |
| 36 | + /-- The location in the goal. -/ |
| 37 | + loc : Location |
| 38 | + |
| 39 | +/-- Get the pattern occurrences as a `Occurrences`. -/ |
| 40 | +def PatternLocation.occurrences (p : PatternLocation) : Occurrences := |
| 41 | + match p.occs with |
| 42 | + | none => .all |
| 43 | + | some arr => .pos arr.toList |
| 44 | + |
| 45 | +/-- Elaborate a pattern expression. |
| 46 | +See elaboration of `Lean.Parser.Tactic.Conv.pattern`. -/ |
40 | 47 | def expandPattern (p : Syntax) : TermElabM AbstractMVarsResult :=
|
41 | 48 | withReader (fun ctx => { ctx with ignoreTCFailures := true, errToSorry := false }) <|
|
42 | 49 | Term.withoutModifyingElabMetaStateWithInfo <| withRef p do
|
43 | 50 | abstractMVars (← Term.elabTerm p none)
|
44 | 51 |
|
45 |
| -/-- Elaborate `occs` syntax as `Occurrences`. -/ |
46 |
| -def expandOptOccs (stx : Syntax) : TermElabM Occurrences := do |
47 |
| - if stx.isNone then |
48 |
| - return .all |
49 |
| - match stx[0] with |
50 |
| - | `(occs| (occs := *)) => return .all |
| 52 | +/-- Elaborate `occs` syntax. -/ |
| 53 | +def expandOptOccs (stx : Option (TSyntax ``occs)) : TermElabM (Option (Array Nat)) := do |
| 54 | + let some stx := stx | return none |
| 55 | + match stx with |
| 56 | + | `(occs| (occs := *)) => return none |
51 | 57 | | `(occs| (occs := $ids*)) =>
|
52 |
| - return .pos <| Array.toList <| ← ids.mapM fun id => |
| 58 | + some <$> ids.mapM fun id => |
53 | 59 | let n := id.toNat
|
54 | 60 | if n == 0 then
|
55 | 61 | throwErrorAt id "positive integer expected"
|
56 | 62 | else return n
|
57 |
| - | _ => throwError m! "{stx}" |
58 |
| - |
59 |
| -/-- Elaborate the occurrences, pattern and location in a `patternLocation`. -/ |
60 |
| -def expandPatternLocation (stx : Syntax) : TermElabM (Occurrences × AbstractMVarsResult × Location) := do |
61 |
| - let occs ← expandOptOccs stx[0] |
62 |
| - let pattern ← expandPattern stx[1] |
63 |
| - let loc := expandOptLocation stx[2] |
64 |
| - return (occs, pattern, loc) |
65 |
| -section Expand |
66 |
| - |
67 |
| - |
68 |
| - |
69 |
| -end Expand |
70 |
| - |
71 |
| -section PatternsAndOccurrences |
72 |
| - |
73 |
| -/-- The pattern at a given position in an expression. |
74 |
| - Variables under binders are turned into meta-variables in the pattern. -/ |
75 |
| -def SubExpr.patternAt (p : SubExpr.Pos) (root : Expr) : MetaM Expr := do |
76 |
| - let e ← Core.viewSubexpr p root |
77 |
| - let binders ← Core.viewBinders p root |
78 |
| - let mvars ← binders.mapM fun (name, type) => |
79 |
| - mkFreshExprMVar type (userName := name) |
80 |
| - return e.instantiateRev mvars |
81 |
| - |
82 |
| -/-- Finds the occurrence number of the pattern in the expression |
83 |
| - that matches the sub-expression at the specified position. |
84 |
| - This follows the code of `kabstract` from Lean core. -/ |
85 |
| -def findMatchingOccurrence (position : SubExpr.Pos) (root : Expr) (pattern : Expr) : MetaM Nat := do |
86 |
| - let root ← instantiateMVars root |
87 |
| - unless ← isDefEq pattern (← SubExpr.patternAt position root) do |
88 |
| - throwError s!"The specified pattern does not match the pattern at position {position}." |
89 |
| - let pattern ← instantiateMVars pattern |
90 |
| - let pHeadIdx := pattern.toHeadIndex |
91 |
| - let pNumArgs := pattern.headNumArgs |
92 |
| - let rec |
93 |
| - /-- The recursive step in the expression traversal to search for matching occurrences. -/ |
94 |
| - visit (e : Expr) (p : SubExpr.Pos) (offset : Nat) := do |
95 |
| - let visitChildren : Unit → StateRefT Nat (OptionT MetaM) Unit := fun _ => do |
| 63 | + | _ => throwUnsupportedSyntax |
| 64 | + |
| 65 | +/-- Elaborate `patternLocation` syntax. -/ |
| 66 | +def expandPatternLocation (stx : Syntax) : TacticM PatternLocation := |
| 67 | + withMainContext do |
| 68 | + match stx with |
| 69 | + | `(patternLocation| $[$a]? $pat $[$loc]?) => |
| 70 | + let occs ← expandOptOccs a |
| 71 | + let pattern ← expandPattern pat |
| 72 | + let loc := match loc with |
| 73 | + | some loc => expandLocation loc |
| 74 | + | none => Location.targets #[] true |
| 75 | + return { occs, pattern, loc } |
| 76 | + | _ => throwUnsupportedSyntax |
| 77 | + |
| 78 | + |
| 79 | +/-- return the subexpression positions that `kabstract` can abstract -/ |
| 80 | +def kabstractPositions (p e : Expr) : MetaM (Array SubExpr.Pos) := do |
| 81 | + let pHeadIdx := p.toHeadIndex |
| 82 | + let pNumArgs := p.headNumArgs |
| 83 | + let rec visit (e : Expr) (pos : SubExpr.Pos) (positions : Array SubExpr.Pos) : |
| 84 | + MetaM (Array SubExpr.Pos) := do |
| 85 | + let visitChildren : Array SubExpr.Pos → MetaM (Array SubExpr.Pos) := |
96 | 86 | match e with
|
97 |
| - | .app f a => do |
98 |
| - visit f p.pushAppFn offset <|> |
99 |
| - visit a p.pushAppArg offset |
100 |
| - | .mdata _ b => visit b p offset |
101 |
| - | .proj _ _ b => visit b p.pushProj offset |
102 |
| - | .letE _ t v b _ => do |
103 |
| - visit t p.pushLetVarType offset <|> |
104 |
| - visit v p.pushLetValue offset <|> |
105 |
| - visit b p.pushLetBody (offset+1) |
106 |
| - | .lam _ d b _ => do |
107 |
| - visit d p.pushBindingDomain offset <|> |
108 |
| - visit b p.pushBindingBody (offset+1) |
109 |
| - | .forallE _ d b _ => do |
110 |
| - visit d p.pushBindingDomain offset <|> |
111 |
| - visit b p.pushBindingBody (offset+1) |
112 |
| - | _ => failure |
113 |
| - if e.hasLooseBVars then |
114 |
| - visitChildren () |
115 |
| - else if e.toHeadIndex != pHeadIdx || e.headNumArgs != pNumArgs then |
116 |
| - visitChildren () |
117 |
| - else if (← isDefEq e pattern) then |
118 |
| - let i ← get |
119 |
| - set (i+1) |
120 |
| - if p = position then |
121 |
| - return () |
122 |
| - else |
123 |
| - visitChildren () |
| 87 | + | .app f a => visit f pos.pushAppFn |
| 88 | + >=> visit a pos.pushAppArg |
| 89 | + | .mdata _ b => visit b pos |
| 90 | + | .proj _ _ b => visit b pos.pushProj |
| 91 | + | .letE _ t v b _ => visit t pos.pushLetVarType |
| 92 | + >=> visit v pos.pushLetValue |
| 93 | + >=> visit b pos.pushLetBody |
| 94 | + | .lam _ d b _ => visit d pos.pushBindingDomain |
| 95 | + >=> visit b pos.pushBindingBody |
| 96 | + | .forallE _ d b _ => visit d pos.pushBindingDomain |
| 97 | + >=> visit b pos.pushBindingBody |
| 98 | + | _ => pure |
| 99 | + if e.hasLooseBVars || e.toHeadIndex != pHeadIdx || e.headNumArgs != pNumArgs then |
| 100 | + visitChildren positions |
124 | 101 | else
|
125 |
| - visitChildren () |
126 |
| - let .some (_, occ) ← visit root .root 0 |>.run 0 | |
127 |
| - throwError s!"Could not find pattern at specified position {position}." |
128 |
| - return occ |
129 |
| - |
130 |
| -/-- Finds the occurrence number of the pattern at |
131 |
| - the specified position in the whole expression. -/ |
132 |
| -def findOccurrence (position : SubExpr.Pos) (root : Expr) : MetaM Nat := do |
133 |
| - let pattern ← SubExpr.patternAt position root |
134 |
| - findMatchingOccurrence position root pattern |
135 |
| - |
136 |
| -end PatternsAndOccurrences |
137 |
| - |
138 |
| -/-- Substitute occurrences of a pattern in an expression with the result of `replacement`. -/ |
139 |
| -def substitute (e : Expr) (pattern : AbstractMVarsResult) (occs : Occurrences) |
140 |
| - (replacement : Expr → MetaM Expr) (withoutErr : Bool := true) : MetaM Expr := do |
141 |
| - let (_, _, p) ← openAbstractMVarsResult pattern |
142 |
| - let eAbst ← kabstract e p occs |
143 |
| - unless eAbst.hasLooseBVars || withoutErr do |
144 |
| - throwError m!"Failed to find instance of pattern {indentExpr p} in {indentExpr e}." |
145 |
| - instantiateMVars <| Expr.instantiate1 eAbst (← replacement p) |
146 |
| - |
147 |
| -/-- Replace a pattern at the specified locations with the value of `replacement`, |
148 |
| - which is assumed to be definitionally equal to the original pattern. -/ |
149 |
| -def replaceOccurrencesDefEq (tacticName : Name) (location : Location) (occurrences : Occurrences) |
150 |
| - (pattern : AbstractMVarsResult) (replacement : Expr → MetaM Expr) : TacticM Unit := do |
151 |
| - let goal ← getMainGoal |
152 |
| - goal.withContext do |
153 |
| - withLocation location |
154 |
| - (atLocal := fun fvarId => do |
155 |
| - let hypType ← fvarId.getType |
156 |
| - let newGoal ← goal.replaceLocalDeclDefEq fvarId <| ← |
157 |
| - substitute hypType pattern occurrences replacement |
158 |
| - replaceMainGoal [newGoal]) |
159 |
| - (atTarget := do |
160 |
| - let newGoal ← goal.replaceTargetDefEq <| ← |
161 |
| - substitute (← goal.getType) pattern occurrences replacement |
162 |
| - replaceMainGoal [newGoal]) |
163 |
| - (failed := (throwTacticEx tacticName · m!"Failed to run tactic {tacticName}.")) |
| 102 | + let mctx ← getMCtx |
| 103 | + if (← isDefEq e p) then |
| 104 | + setMCtx mctx |
| 105 | + visitChildren (positions.push pos) |
| 106 | + else |
| 107 | + visitChildren positions |
| 108 | + visit e .root #[] |
| 109 | + |
| 110 | +/-- return the pattern and occurrences specifying position `pos` in target `e`. -/ |
| 111 | +def patternAndIndex (pos : SubExpr.Pos) (e : Expr) : MetaM (Expr × Option Nat) := do |
| 112 | + let e ← instantiateMVars e |
| 113 | + let pattern ← Core.viewSubexpr pos e |
| 114 | + if pattern.hasLooseBVars then |
| 115 | + throwError "the subexpression contains loose bound variables" |
| 116 | + let positions ← kabstractPositions pattern e |
| 117 | + if positions.size == 1 then |
| 118 | + return (pattern, none) |
| 119 | + let some index := positions.findIdx? (· == pos) | unreachable! |
| 120 | + return (pattern, some (index + 1)) |
0 commit comments