From 0ac14c86b7f44ca8c3d474daa24e3c9e411118a3 Mon Sep 17 00:00:00 2001 From: Alex Keizer Date: Thu, 10 Oct 2024 08:55:51 -0500 Subject: [PATCH] refactor: extract out `MemoryEffects` structure (#222) ### Description: Extracted from #179, stacked on #220. We extract out memory-effects related code from AxEffects into a new MemoryEffects structure. This PR is purely a non-functional change, but will serve as the starting point of integrating simp_mem with sym_n. The current simplification is effectively a no-op, since the proof state is not massaged to the way `simp_mem` wants it to be. Subsequent PRs will focus on massaging the goal state to be as `simp_mem` likes, and then trying to symbolically simplify the memory expression we see. ### Testing: What tests have been run? Did `make all` succeed for your changes? Was conformance testing successful on an Aarch64 machine? yes ### License: By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. Co-authored-by @bollu --------- Co-authored-by: Shilpi Goel --- Tactics/Common.lean | 36 ++++++++++ Tactics/Sym/AxEffects.lean | 96 ++++++++++--------------- Tactics/Sym/Common.lean | 7 ++ Tactics/Sym/MemoryEffects.lean | 127 +++++++++++++++++++++++++++++++++ 4 files changed, 206 insertions(+), 60 deletions(-) create mode 100644 Tactics/Sym/MemoryEffects.lean diff --git a/Tactics/Common.lean b/Tactics/Common.lean index 50212ce9..9763af33 100644 --- a/Tactics/Common.lean +++ b/Tactics/Common.lean @@ -277,6 +277,42 @@ def Lean.Expr.eqReadField? (e : Expr) : Option (Expr × Expr × Expr) := do | none some (field, state, value) +/-- Return the expression for `Memory` -/ +def mkMemory : Expr := mkConst ``Memory + +/-! ## Expr Helpers -/ + +/-- Throw an error if `e` is not of type `expectedType` -/ +def assertHasType (e expectedType : Expr) : MetaM Unit := do + let eType ← inferType e + if !(←isDefEq eType expectedType) then + throwError "{e} {← mkHasTypeButIsExpectedMsg eType expectedType}" + +/-- Throw an error if `e` is not def-eq to `expected` -/ +def assertIsDefEq (e expected : Expr) : MetaM Unit := do + if !(←isDefEq e expected) then + throwError "expected:\n {expected}\nbut found:\n {e}" + +/-- +Rewrites `e` via some `eq`, producing a proof `e = e'` for some `e'`. +Rewrites with a fresh metavariable as the ambient goal. +Fails if the rewrite produces any subgoals. +-/ +-- source: https://github.com/leanprover-community/mathlib4/blob/b35703fe5a80f1fa74b82a2adc22f3631316a5b3/Mathlib/Lean/Expr/Basic.lean#L476-L477 +def rewrite (e eq : Expr) : MetaM Expr := do + let ⟨_, eq', []⟩ ← (← mkFreshExprMVar none).mvarId!.rewrite e eq + | throwError "Expr.rewrite may not produce subgoals." + return eq' + +/-- +Rewrites the type of `e` via some `eq`, then moves `e` into the new type via `Eq.mp`. +Rewrites with a fresh metavariable as the ambient goal. +Fails if the rewrite produces any subgoals. +-/ +-- source: https://github.com/leanprover-community/mathlib4/blob/b35703fe5a80f1fa74b82a2adc22f3631316a5b3/Mathlib/Lean/Expr/Basic.lean#L476-L477 +def rewriteType (e eq : Expr) : MetaM Expr := do + mkEqMP (← rewrite (← inferType e) eq) e + /-! ## Tracing helpers -/ def traceHeartbeats (cls : Name) (header : Option String := none) : diff --git a/Tactics/Sym/AxEffects.lean b/Tactics/Sym/AxEffects.lean index 115d5e19..073da491 100644 --- a/Tactics/Sym/AxEffects.lean +++ b/Tactics/Sym/AxEffects.lean @@ -9,6 +9,7 @@ import Tactics.Common import Tactics.Attr import Tactics.Simp import Tactics.Sym.Common +import Tactics.Sym.MemoryEffects import Std.Data.HashMap @@ -59,17 +60,8 @@ structure AxEffects where where `f₁, ⋯, fₙ` are the keys of `fields` -/ nonEffectProof : Expr - /-- An expression of a (potentially empty) sequence of `write_mem`s - to the initial state, which describes the effects on memory. - See `memoryEffectProof` for more detail -/ - memoryEffect : Expr - /-- An expression that contains the proof of: - ```lean - ∀ n addr, - read_mem_bytes n addr - = read_mem_bytes n addr - ``` -/ - memoryEffectProof : Expr + /-- The memory effects -/ + memoryEffects : MemoryEffects /-- A proof that `.program = .program` -/ programProof : Expr /-- An optional proof of `CheckSPAlignment `. @@ -100,8 +92,8 @@ variable {m} [Monad m] [MonadReaderOf AxEffects m] def getCurrentState : m Expr := do return (← read).currentState def getInitialState : m Expr := do return (← read).initialState def getNonEffectProof : m Expr := do return (← read).nonEffectProof -def getMemoryEffect : m Expr := do return (← read).memoryEffect -def getMemoryEffectProof : m Expr := do return (← read).memoryEffectProof +def getMemoryEffect : m Expr := do return (← read).memoryEffects.effects +def getMemoryEffectProof : m Expr := do return (← read).memoryEffects.proof def getProgramProof : m Expr := do return (← read).programProof def getStackAlignmentProof? : m (Option Expr) := do @@ -136,15 +128,7 @@ def initial (state : Expr) : AxEffects where -- `fun f => rfl` mkLambda `f .default (mkConst ``StateField) <| mkEqReflArmState <| mkApp2 (mkConst ``r) (.bvar 0) state - memoryEffect := state - memoryEffectProof := - -- `fun n addr => rfl` - mkLambda `n .default (mkConst ``Nat) <| - let bv64 := mkApp (mkConst ``BitVec) (toExpr 64) - mkLambda `addr .default bv64 <| - mkApp2 (.const ``Eq.refl [1]) - (mkApp (mkConst ``BitVec) <| mkNatMul (.bvar 1) (toExpr 8)) - (mkApp3 (mkConst ``read_mem_bytes) (.bvar 1) (.bvar 0) state) + memoryEffects := .initial state programProof := -- `rfl` mkAppN (.const ``Eq.refl [1]) #[ @@ -170,8 +154,7 @@ instance : ToMessageData AxEffects where currentState := {eff.currentState}, fields := {eff.fields}, nonEffectProof := {eff.nonEffectProof}, - memoryEffect := {eff.memoryEffect}, - memoryEffectProof := {eff.memoryEffectProof}, + memoryEffects := {eff.memoryEffects}, programProof := {eff.programProof} }" @@ -280,7 +263,7 @@ Note that no effort is made to preserve `currentStateEq`; it is set to `none`! -/ private def update_write_mem (eff : AxEffects) (n addr val : Expr) : MetaM AxEffects := - withTraceNode m!"processing: write_mem {n} {addr} {val} …" (tag := "updateWriteMem") <| do + Sym.withTraceNode m!"processing: write_mem {n} {addr} {val} …" (tag := "updateWriteMem") <| do -- Update each field let fields ← eff.fields.toList.mapM fun ⟨fld, {value, proof}⟩ => do @@ -298,11 +281,10 @@ private def update_write_mem (eff : AxEffects) (n addr val : Expr) : mkLambdaFVars args proof -- ^^ `fun f ... => Eq.trans (@r_of_write_mem_bytes f ...) ` - -- Update the memory effects proof - let memoryEffectProof := - -- `read_mem_bytes_write_mem_bytes_of_read_mem_eq ...` - mkAppN (mkConst ``read_mem_bytes_write_mem_bytes_of_read_mem_eq) - #[eff.currentState, eff.memoryEffect, eff.memoryEffectProof, n, addr, val] + -- Update the memory effects + let memoryEffects ← + eff.memoryEffects.updateWriteMem eff.currentState n addr val + -- Update the program proof let programProof ← @@ -318,15 +300,13 @@ private def update_write_mem (eff : AxEffects) (n addr val : Expr) : #[eff.currentState, n, addr, val, proof] -- Assemble the result - let addWrite (e : Expr) := - -- `@write_mem_bytes ` - mkApp4 (mkConst ``write_mem_bytes) n addr val e + let currentState := -- `@write_mem_bytes ` + mkApp4 (mkConst ``write_mem_bytes) n addr val eff.currentState let eff := { eff with - currentState := addWrite eff.currentState + currentState fields := .ofList fields nonEffectProof - memoryEffect := addWrite eff.memoryEffect - memoryEffectProof + memoryEffects programProof stackAlignmentProof? } @@ -341,7 +321,7 @@ Note that no effort is made to preserve `currentStateEq`; it is set to `none`! -/ private def update_w (eff : AxEffects) (fld val : Expr) : MetaM AxEffects := do - withTraceNode m!"processing: w {fld} {val} …" (tag := "updateWrite") <| do + Sym.withTraceNode m!"processing: w {fld} {val} …" (tag := "updateWrite") <| do let rField ← reflectStateField fld -- Update all other fields @@ -398,11 +378,8 @@ private def update_w (eff : AxEffects) (fld val : Expr) : withLocalDeclD name h_neq_type fun h_neq => k (args.push h_neq) h_neq - -- Update the memory effect proof - let memoryEffectProof := - -- `read_mem_bytes_w_of_read_mem_eq ...` - mkAppN (mkConst ``read_mem_bytes_w_of_read_mem_eq) - #[eff.currentState, eff.memoryEffect, eff.memoryEffectProof, fld, val] + -- Update the memory effects + let memoryEffects ← eff.memoryEffects.updateWrite eff.currentState fld val -- Update the program proof let programProof ← @@ -434,8 +411,7 @@ private def update_w (eff : AxEffects) (fld val : Expr) : currentState := mkApp3 (mkConst ``w) fld val eff.currentState fields := Std.HashMap.ofList fields nonEffectProof - -- memory effects are unchanged - memoryEffectProof + memoryEffects programProof stackAlignmentProof? sideConditions @@ -498,7 +474,7 @@ def fromExpr (e : Expr) : MetaM AxEffects := do set `s` to be the new `currentState`, and update all proofs accordingly -/ def adjustCurrentStateWithEq (eff : AxEffects) (s eq : Expr) : MetaM AxEffects := do - withTraceNode m!"adjustCurrentStateWithEq" (tag := "adjustCurrentStateWithEq") do + Sym.withTraceNode m!"adjustCurrentStateWithEq" (tag := "adjustCurrentStateWithEq") do trace[Tactic.sym] "rewriting along {eq}" eff.traceCurrentState @@ -515,17 +491,15 @@ def adjustCurrentStateWithEq (eff : AxEffects) (s eq : Expr) : pure (field, {fieldEff with proof}) let fields := .ofList fields - withTraceNode m!"rewriting other proofs" (tag := "rewriteMisc") <| do + Sym.withTraceNode m!"rewriting other proofs" (tag := "rewriteMisc") <| do let nonEffectProof ← rewriteType eff.nonEffectProof eq - let memoryEffectProof ← rewriteType eff.memoryEffectProof eq - -- ^^ TODO: what happens if `memoryEffect` is the same as `currentState`? - -- Presumably, we would *not* want to encapsulate `memoryEffect` here + let memoryEffects ← eff.memoryEffects.adjustCurrentStateWithEq eq let programProof ← rewriteType eff.programProof eq let stackAlignmentProof? ← eff.stackAlignmentProof?.mapM (rewriteType · eq) return { eff with - currentState, fields, nonEffectProof, memoryEffectProof, programProof, + currentState, fields, nonEffectProof, memoryEffects, programProof, stackAlignmentProof? } @@ -642,7 +616,7 @@ NOTE: does not necessarily validate *which* type an expression has, validation will still pass if types are different to those we claim in the docstrings -/ def validate (eff : AxEffects) : MetaM Unit := do - withTraceNode "validating that the axiomatic effects are well-formed" + Sym.withTraceNode "validating that the axiomatic effects are well-formed" (tag := "validate") <| do eff.traceCurrentState @@ -653,13 +627,13 @@ def validate (eff : AxEffects) : MetaM Unit := do check fieldEff.value check fieldEff.proof + eff.memoryEffects.validate check eff.nonEffectProof - check eff.memoryEffect - check eff.memoryEffectProof check eff.programProof if let some h := eff.stackAlignmentProof? then check h + /-! ## Tactic Environment -/ section Tactic open Elab.Tactic @@ -678,7 +652,7 @@ that was just added to the local context -/ def addHypothesesToLContext (eff : AxEffects) (hypPrefix : String := "h_") (mvar : Option MVarId := none) : TacticM AxEffects := - withTraceNode m!"adding hypotheses to local context" + Sym.withTraceNode m!"adding hypotheses to local context" (tag := "addHypothesesToLContext") do eff.traceCurrentState let mut goal ← mvar.getDM getMainGoal @@ -704,12 +678,14 @@ def addHypothesesToLContext (eff : AxEffects) (hypPrefix : String := "h_") let nonEffectProof := Expr.fvar nonEffectProof goal := goal' - trace[Tactic.sym] "adding memory effects with {eff.memoryEffectProof}" + trace[Tactic.sym] "adding memory effects with {eff.memoryEffects.proof}" let ⟨memoryEffectProof, goal'⟩ ← goal.withContext do let name := .mkSimple s!"{hypPrefix}memory_effects" - let proof := eff.memoryEffectProof + let proof := eff.memoryEffects.proof replaceOrNote goal name proof - let memoryEffectProof := Expr.fvar memoryEffectProof + let memoryEffects := { eff.memoryEffects with + proof := Expr.fvar memoryEffectProof + } goal := goal' trace[Tactic.sym] "adding program hypothesis with {eff.programProof}" @@ -735,7 +711,7 @@ def addHypothesesToLContext (eff : AxEffects) (hypPrefix : String := "h_") replaceMainGoal [goal] return {eff with - fields, nonEffectProof, memoryEffectProof, programProof, + fields, nonEffectProof, memoryEffects, programProof, stackAlignmentProof? } where @@ -755,7 +731,7 @@ where /-- Return an array of `SimpTheorem`s of the proofs contained in the given `AxEffects` -/ def toSimpTheorems (eff : AxEffects) : MetaM (Array SimpTheorem) := do - withTraceNode m!"computing SimpTheorems for (non-)effect hypotheses" + Sym.withTraceNode m!"computing SimpTheorems for (non-)effect hypotheses" (tag := "toSimpTheorems") <| do let lctx ← getLCtx let baseName? := @@ -789,7 +765,7 @@ def toSimpTheorems (eff : AxEffects) : MetaM (Array SimpTheorem) := do thms ← add thms proof s!"field_{field}" (prio := 1500) thms ← add thms eff.nonEffectProof "nonEffectProof" - thms ← add thms eff.memoryEffectProof "memoryEffectProof" + thms ← add thms eff.memoryEffects.proof "memoryEffectProof" thms ← add thms eff.programProof "programProof" if let some stackAlignmentProof := eff.stackAlignmentProof? then thms ← add thms stackAlignmentProof "stackAlignmentProof" diff --git a/Tactics/Sym/Common.lean b/Tactics/Sym/Common.lean index 21dc8807..3e07236b 100644 --- a/Tactics/Sym/Common.lean +++ b/Tactics/Sym/Common.lean @@ -27,6 +27,13 @@ def withVerboseTraceNode (msg : MessageData) (k : m α) : m α := do Lean.withTraceNode `Tactic.sym.verbose (fun _ => pure msg) k collapsed tag +/-- Create a trace note that folds `header` with `(NOTE: can be large)`, +and prints `msg` under such a trace node. +-/ +def traceLargeMsg (header : MessageData) (msg : MessageData) : MetaM Unit := + withTraceNode m!"{header} (NOTE: can be large)" do + trace[Tactic.sym] msg + end Tracing end Sym diff --git a/Tactics/Sym/MemoryEffects.lean b/Tactics/Sym/MemoryEffects.lean new file mode 100644 index 00000000..45cb7eec --- /dev/null +++ b/Tactics/Sym/MemoryEffects.lean @@ -0,0 +1,127 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author(s): Alex Keizer, Siddharth Bhat +-/ + +import Arm.State +import Tactics.Common +import Tactics.Attr +import Tactics.Simp +import Tactics.Sym.Common + +import Std.Data.HashMap + +open Lean Meta + +structure MemoryEffects where + /-- An expression of a (potentially empty) sequence of `write_mem`s + to the initial state, which describes the effects on memory. + See `memoryEffectProof` for more detail -/ + effects : Expr + /-- An expression that contains the proof of: + ```lean + ∀ n addr, + read_mem_bytes n addr + = read_mem_bytes n addr + ``` -/ + proof : Expr +deriving Repr + +instance : ToMessageData MemoryEffects where + toMessageData eff := + m!"\ + \{ effects := {eff.effects}, + proof := {eff.proof + }" + +namespace MemoryEffects + +/-! ## Initial Reflected State -/ + +/-- An initial `MemoryEffects`, representing no memory changes to the +initial `state` -/ +def initial (state : Expr) : MemoryEffects where + effects := state + proof := + -- `fun n addr => rfl` + mkLambda `n .default (mkConst ``Nat) <| + let bv64 := mkApp (mkConst ``BitVec) (toExpr 64) + mkLambda `addr .default bv64 <| + mkApp2 (.const ``Eq.refl [1]) + (mkApp (mkConst ``BitVec) <| mkNatMul (.bvar 1) (toExpr 8)) + (mkApp3 (mkConst ``read_mem_bytes) (.bvar 1) (.bvar 0) state) + +/-- Update the memory effects with a memory write -/ +def updateWriteMem (eff : MemoryEffects) (currentState : Expr) + (n addr val : Expr) : + MetaM MemoryEffects := do + let effects := mkApp4 (mkConst ``write_mem_bytes) n addr val eff.effects + let proof := + -- `read_mem_bytes_write_mem_bytes_of_read_mem_eq ...` + mkAppN (mkConst ``read_mem_bytes_write_mem_bytes_of_read_mem_eq) + #[currentState, eff.effects, eff.proof, n, addr, val] + return { effects, proof } + +/-- Update the memory effects with a register write. + +This doesn't change the actual effect, but since the `currentState` has changed, +we need to update proofs -/ +def updateWrite (eff : MemoryEffects) (currentState : Expr) + (fld val : Expr) : + MetaM MemoryEffects := do + let proof := -- `read_mem_bytes_w_of_read_mem_eq ...` + mkAppN (mkConst ``read_mem_bytes_w_of_read_mem_eq) + #[currentState, eff.effects, eff.proof, fld, val] + return { eff with proof } + +/-- Transport all proofs along an equality `eq : = s`, +so that `s` is the new `currentState` -/ +def adjustCurrentStateWithEq (eff : MemoryEffects) (eq : Expr) : + MetaM MemoryEffects := do + let proof ← rewriteType eff.proof eq + /- ^^ This looks scary, since it can rewrite the left-hand-side of the proof + if `memoryEffect` is the same as `currentState` (which would be bad!). + However, this cannot ever happen in LNSym: every instruction has to modify + either the PC or the error field, neither of which is incorporated into + the `memoryEffect` and thus, `memoryEffect` never coincides with + `currentState` (assuming we're dealing with instruction semantics, as we + currently do!). -/ + return { eff with proof } + +/-- Convert a `MemoryEffects` into a `MessageData` for logging. -/ +def toMessageData (eff : MemoryEffects) : MetaM MessageData := do + let out := m!"effects: {eff.effects}" + return out + +/-- Trace the current state of `MemoryEffects`. -/ +def traceCurrentState (eff : MemoryEffects) : MetaM Unit := do + Sym.traceLargeMsg "memoryEffects" (← eff.toMessageData) + + + +/-- type check all expressions stored in `eff`, +throwing an error if one is not type-correct. + +In principle, the various `MemoryEffects` definitions should return only +well-formed expressions, making `validate` a no-op. +In practice, however, running `validate` is helpful for catching bugs in those +definitions. Do note that typechecking might be a bit expensive, so we generally +only call `validate` while debugging, not during normal execution. +See also the `Tactic.sym.debug` option, which controls whether `validate` is +called for each step of the `sym_n` tactic. + +NOTE: does not necessarily validate *which* type an expression has, +validation will still pass if types are different to those we claim in the +docstrings +-/ +def validate (eff : MemoryEffects) : MetaM Unit := do + let msg := "validating that the axiomatic effects are well-formed" + Sym.withTraceNode msg do + eff.traceCurrentState + check eff.effects + assertHasType eff.effects mkArmState + + check eff.proof + +end MemoryEffects