|
| 1 | +/- |
| 2 | +Copyright (c) 2024 Siddharth Bhat. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Siddharth Bhat |
| 5 | +
|
| 6 | +This file implements lazy ackermannization [1, 2] |
| 7 | +
|
| 8 | +[1] https://lara.epfl.ch/w/_media/model-based.pdf |
| 9 | +[2] https://leodemoura.github.io/files/oregon08.pdf |
| 10 | +-/ |
| 11 | +prelude |
| 12 | +import Std.Tactic.BVDecide.Bitblast |
| 13 | +import Std.Tactic.BVAckermannize.Syntax |
| 14 | + |
| 15 | +structure Result where |
| 16 | + |
| 17 | +namespace Ack |
| 18 | + |
| 19 | +structure Config where |
| 20 | + |
| 21 | +structure Context where |
| 22 | + config : Config |
| 23 | + |
| 24 | +structure Argument where |
| 25 | + /-- The expression corresponding to the argument -/ |
| 26 | + x : Expr |
| 27 | + /-- The cached type of the expression x -/ |
| 28 | + xTy : Expr |
| 29 | +deriving Hashable, BEq, Inhabited |
| 30 | + |
| 31 | +/-- |
| 32 | +A lazily unfolded applied call to a function. |
| 33 | +-/ |
| 34 | +structure Call where |
| 35 | + -- | TODO: replace with `Array Argument`. |
| 36 | + -- -- The name of the call (?) |
| 37 | + -- name : Name |
| 38 | + /-- the expression for the function argument -/ |
| 39 | + x : Expr |
| 40 | + /-- the free variable for for `(f x)`. -/ |
| 41 | + fx : FVarId |
| 42 | + /- Cached type of domain of f, which is also the type of the argument `x` -/ |
| 43 | + xTy : Expr |
| 44 | + /-- Cached type of codomain of f, which is the also the type of the result `fx`. -/ |
| 45 | + fxTy : Expr |
| 46 | + /-- heqProof : The proof that the fvar `fx` eauals the function application `f x` -/ |
| 47 | + heqProof : Expr |
| 48 | +deriving Hashable, BEq, Inhabited |
| 49 | + |
| 50 | +instance : ToMessageData Call where |
| 51 | + toMessageData c := m!"{Expr.fvar c.fx} : {c.fxTy} = f ({c.x} : {c.xTy}) with proof {c.heqProof}" |
| 52 | + |
| 53 | +structure State where |
| 54 | + /-- |
| 55 | + A maping from a function `f` to all calls of the function `{fx₁, fx₂, ...}`. |
| 56 | + This is used to generate equations of the form `x₁ = x₂ → fx₁ = fx₂` on-demand. |
| 57 | + -/ |
| 58 | + fn2apps : HashMap Expr (Std.HashSet Call) := {} |
| 59 | + /-- A counter for generating fresh names. -/ |
| 60 | + gensymCounter : Nat := 0 |
| 61 | + |
| 62 | + |
| 63 | +def State.init (_cfg : Config) : State where |
| 64 | + |
| 65 | +abbrev AckM := StateRefT State (ReaderT Context TacticM) |
| 66 | + |
| 67 | +def run (m : AckM α) (ctx : Context) : TacticM α := |
| 68 | + m.run' (State.init ctx.config) |>.run ctx |
| 69 | + |
| 70 | +/-- Generate a fresh name. -/ |
| 71 | +def gensym : AckM Name := do |
| 72 | + modify fun s => { s with gensymCounter := s.gensymCounter + 1 } |
| 73 | + return Name.mkNum `ack (← get).gensymCounter |
| 74 | + |
| 75 | +def withMainContext (ma : AckM α) : AckM α := (← getMainGoal).withContext ma |
| 76 | + |
| 77 | +def withContext (g : MVarId) (ma : AckM α) : AckM α := g.withContext ma |
| 78 | + |
| 79 | +/-- Get the calls to a function `f`. -/ |
| 80 | +def getCalls (f : Expr) : AckM (Std.HashSet Call) := do |
| 81 | + return (← get).fn2apps.findD fn {} |
| 82 | + |
| 83 | +/-- Track a call to the function `f` -/ |
| 84 | +-- TODO: do we need the `fn` argument? Isn't this already in `Call`? |
| 85 | +def addCall (fn : Expr) (call : Call) : AckM Unit := do |
| 86 | + let calls ← getCalls fn |
| 87 | + modify fun s => { s with fn2apps := s.fn2apps.insert fn (calls.insert call) } |
| 88 | + |
| 89 | +/-- create a trace node in trace class (i.e. `set_option traceClass true`), |
| 90 | +with header `header`, whose default collapsed state is `collapsed`. -/ |
| 91 | +def withTraceNode (header : MessageData) (k : AckM α) |
| 92 | + (collapsed : Bool := true) |
| 93 | + (traceClass : Name := `ack) : AckM α := |
| 94 | + Lean.withTraceNode traceClass (fun _ => return header) k (collapsed := collapsed) |
| 95 | + |
| 96 | +/-- An emoji used to report intemediate states where the tactic is processing hypotheses. -/ |
| 97 | +def processingEmoji : String := "⚙️" |
| 98 | + |
| 99 | +/-- |
| 100 | +Create a trace note that folds `header` with `(NOTE: can be large)`, |
| 101 | +and prints `msg` under such a trace node. |
| 102 | +Used to print goal states, which can be quite noisy in the trace. |
| 103 | +-/ |
| 104 | +def traceLargeMsg (header : MessageData) (msg : MessageData) : AckM Unit := |
| 105 | + withTraceNode m!"{header} (NOTE: can be large)" do |
| 106 | + trace[ack] msg |
| 107 | + |
| 108 | + |
| 109 | +/-- The proof of correctness of the Ackermannization transform. -/ |
| 110 | +theorem ackermannize_proof (A : Type _) (B : Type _) |
| 111 | + (f : A → B) |
| 112 | + (x y : A) |
| 113 | + (fx fy : B) |
| 114 | + (hfx : f x = fx) -- In the same order that `generalize h : f x = fx` would produce. |
| 115 | + (hfy : f y = fy) : |
| 116 | + x = y → fx = fy := by |
| 117 | + intros h |
| 118 | + subst h |
| 119 | + simp [← hfx, ← hfy] |
| 120 | + |
| 121 | +/-- Returns `True` if the type is a function type that is understood by the bitblaster. -/ |
| 122 | +def isBitblastTy (e : Expr) : Bool := |
| 123 | + match_expr e with |
| 124 | + | BitVec _ => true |
| 125 | + | Bool => true |
| 126 | + | _ => false |
| 127 | + |
| 128 | +/- |
| 129 | +Introduce a new definition into the local context, |
| 130 | +and return the FVarId of the new definition in the goal. |
| 131 | +-/ |
| 132 | +def introDef (name : Name) (hdefVal : Expr) : AckM FVarId := do |
| 133 | + withMainContext do |
| 134 | + let goal ← getMainGoal |
| 135 | + let hdefTy ← inferType hdefVal |
| 136 | + |
| 137 | + let goal ← goal.assert name hdefTy hdefVal |
| 138 | + let (fvar, goal) ← goal.intro1P |
| 139 | + replaceMainGoal [goal] |
| 140 | + return fvar |
| 141 | + |
| 142 | +def doAck (eorig : Expr) : AckM Unit := do |
| 143 | + withMainContext do |
| 144 | + traceLargeMsg m!"🔝 TOPLEVEL '{eorig}'" m!"{toString eorig}" |
| 145 | + match eorig with |
| 146 | + | .mdata _ e => doAck e |
| 147 | + | .bvar .. | .fvar .. | .mvar .. | .sort .. | .const .. | .proj .. | .lit .. => return () |
| 148 | + | .app f args => do |
| 149 | + withTraceNode m!"processing '{eorig}'..." do |
| 150 | + doAck f |
| 151 | + doAck args |
| 152 | + withMainContext do |
| 153 | + let e := Expr.app f args |
| 154 | + let args := e.getAppArgs |
| 155 | + |
| 156 | + let ety ← inferType e |
| 157 | + if ! isBitblastTy ety then |
| 158 | + trace[ack] "{crossEmoji} '{eorig}' : '{ety}' not bitblastable.." |
| 159 | + return () |
| 160 | + trace[ack] "{checkEmoji} found bitblastable call ('{f}' '{args}') : '{ety}'." |
| 161 | + |
| 162 | + let newName : Name ← gensym |
| 163 | + -- TODO: build the larger application... |
| 164 | + if h : args.size ≠ 1 then |
| 165 | + trace[ack] "{crossEmoji} Expected fn app ('{f}' '{args}'). to have exactly one argument. Skipping..." |
| 166 | + return () |
| 167 | + else |
| 168 | + let arg := args[0] |
| 169 | + -- let fxName : Name := name.appendAfter s!"App" |
| 170 | + -- let fx ← introDef fxName e -- this changes the main context. |
| 171 | + -- Implementation modeled after `Lean.MVarId.generalizeHyp`. |
| 172 | + let transparency := TransparencyMode.reducible |
| 173 | + let hyps := (← getLCtx).getFVarIds |
| 174 | + let hyps ← hyps.filterM fun h => do |
| 175 | + let type ← instantiateMVars (← h.getType) |
| 176 | + return (← withTransparency transparency <| kabstract type e).hasLooseBVars |
| 177 | + |
| 178 | + let goal ← getMainGoal |
| 179 | + let (reverted, goal) ← goal.revert hyps true |
| 180 | + let garg : GeneralizeArg := { |
| 181 | + expr := e, |
| 182 | + xName? := .some newName, |
| 183 | + hName? := newName.appendAfter "h" |
| 184 | + } |
| 185 | + let (fxs, goal) ← goal.generalize #[garg] |
| 186 | + let (reintros, goal) ← goal.introNP reverted.size |
| 187 | + replaceMainGoal [goal] |
| 188 | + |
| 189 | + withMainContext do |
| 190 | + let mut i := 0 |
| 191 | + for r in reintros do |
| 192 | + trace[ack] "REINTROS[{i}]: {← r.getUserName} : {← r.getType}" |
| 193 | + i := i + 1 |
| 194 | + |
| 195 | + withMainContext do |
| 196 | + let mut i := 0 |
| 197 | + for r in fxs do |
| 198 | + trace[ack] "FXS[{i}]: {← r.getUserName} : {← r.getType}" |
| 199 | + i := i + 1 |
| 200 | + |
| 201 | + let .some fx := fxs[0]? |
| 202 | + | throwTacticEx `ack goal m!"expected generalized variable" |
| 203 | + let .some f_x_eq_fx := fxs[1]? |
| 204 | + | throwTacticEx `ack goal m!"expected proof of generalized variable" |
| 205 | + |
| 206 | + withMainContext do |
| 207 | + trace[ack] "{processingEmoji} introduced new defn {Expr.fvar fx} := {e}." |
| 208 | + |
| 209 | + let calls ← getCalls f |
| 210 | + let call : Call := { |
| 211 | + name := newName |
| 212 | + x := arg, |
| 213 | + xTy := ← inferType arg, |
| 214 | + fx := fx, |
| 215 | + fxTy := ← fx.getType, |
| 216 | + heqProof := Expr.fvar f_x_eq_fx |
| 217 | + } |
| 218 | + |
| 219 | + trace[ack] "built ackermannization: {call} • {Expr.fvar fx}" |
| 220 | + |
| 221 | + for otherCall in calls do |
| 222 | + trace[ack] "building interference: {call.x} = {otherCall.x} => {call.fx.name} = {otherCall.fx.name}" |
| 223 | + let eqName := (otherCall.name).appendAfter s!"Sim" |>.append call.name |
| 224 | + let ackEq ← mkAppM ``Ack.ackermannize_proof |
| 225 | + #[call.xTy, ety, -- A B |
| 226 | + f, -- f |
| 227 | + call.x, otherCall.x, -- x y |
| 228 | + .fvar call.fx, .fvar otherCall.fx, -- fx fy |
| 229 | + call.heqProof, otherCall.heqProof -- hfx hfy |
| 230 | + ] |
| 231 | + let _ ← introDef eqName ackEq |
| 232 | + -- make a call of ackermannize_proof. |
| 233 | + addCall f call |
| 234 | + -- the application is now this fvar. |
| 235 | + | .lam .. | .letE .. => return () |
| 236 | + | .forallE .. => return () |
| 237 | + |
| 238 | + |
| 239 | +/- |
| 240 | +For every bitvector (x : BitVec w), for every function `(f : BitVec w → BitVec w')`, |
| 241 | +walk every function application (f x), and add a new fvar (fx : BitVec w'). |
| 242 | +- Keep an equality that says `fx = f x`. |
| 243 | +- For function application of f, for each pair of bitvectors x, y, |
| 244 | + add a hypothesis that says `x = y => fx = fy, with proof. |
| 245 | +-/ |
| 246 | +def ack (g : MVarId) : AckM Unit := do |
| 247 | + withContext g do |
| 248 | + for hyp in (← getLocalHyps) do |
| 249 | + doAck (← inferType hyp) |
| 250 | + doAck (← getMainTarget) |
| 251 | + |
| 252 | +/-- Entry point for programmatic usage of `bv_ackermannize` -/ |
| 253 | +def ackTac (g : MVarId) (ctx : Context) : TacticM Unit := do |
| 254 | + run ack ctx |
| 255 | + |
| 256 | + |
| 257 | +end Ack |
| 258 | + |
| 259 | +@[builtin_tactic Lean.Parser.Tactic.bvAckermannize] |
| 260 | +def evalBvAckermannize : Tactic := fun |
| 261 | + | `(tactic| bv_ackermannize) => do |
| 262 | + liftMetaFinishingTactic fun g => do |
| 263 | + discard <| ackTac g cfg |
| 264 | + | _ => throwUnsupportedSyntax |
0 commit comments