From ee7565edc47f6ab18b4c198d90436647bcdf6d4d Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Sat, 10 Jan 2026 21:08:24 +0000 Subject: [PATCH] Modify createOp to take as input the regions Previously, we passed the number of regions to create. It makes more sense to pass existing regions instead. --- Veir/Benchmarks.lean | 34 ++++----- Veir/PatternRewriter/Basic.lean | 5 +- Veir/Rewriter/Basic.lean | 75 ++++++++----------- Veir/Rewriter/GetSetInBounds.lean | 35 +++++---- .../Rewriter/WellFormed/Builder/OpRegion.lean | 4 +- .../WellFormed/Builder/Operation.lean | 3 +- 6 files changed, 73 insertions(+), 83 deletions(-) diff --git a/Veir/Benchmarks.lean b/Veir/Benchmarks.lean index 9302f63..5760b99 100644 --- a/Veir/Benchmarks.lean +++ b/Veir/Benchmarks.lean @@ -43,7 +43,7 @@ def addIConstantFolding (rewriter: PatternRewriter) (op: OperationPtr) : Option -- Sum both constant values let newVal := lhsOpStruct.properties + rhsOpStruct.properties - let (rewriter, newOp) ← rewriter.createOp OpCode.constant 1 #[] 0 newVal (some $ .before op) sorry sorry + let (rewriter, newOp) ← rewriter.createOp OpCode.constant 1 #[] #[] newVal (some $ .before op) sorry sorry sorry let mut rewriter ← rewriter.replaceOp op newOp sorry sorry sorry if (lhsValuePtr.getFirstUse rewriter.ctx (by sorry)).isNone then @@ -96,7 +96,7 @@ def mulITwoReduce (rewriter: PatternRewriter) (op: OperationPtr) : Option Patter -- Get the lhs value let lhsValuePtr := op.getOperand rewriter.ctx 0 (by sorry) (by sorry) - let (rewriter, newOp) ← rewriter.createOp OpCode.addi 1 #[lhsValuePtr, lhsValuePtr] 0 0 (some $ .before op) sorry sorry + let (rewriter, newOp) ← rewriter.createOp OpCode.addi 1 #[lhsValuePtr, lhsValuePtr] #[] 0 (some $ .before op) sorry sorry sorry let mut rewriter ← rewriter.replaceOp op newOp sorry sorry sorry if (rhsValuePtr.getFirstUse rewriter.ctx (by sorry)).isNone then @@ -136,7 +136,7 @@ def addIConstantFolding (ctx: IRContext) (op: OperationPtr) : Option IRContext : -- Sum both constant values let newVal := lhsOpStruct.properties + rhsOpStruct.properties - let (ctx, newOp) ← Rewriter.createOp ctx OpCode.constant 1 #[] 0 newVal (some $ .before op) sorry sorry sorry + let (ctx, newOp) ← Rewriter.createOp ctx OpCode.constant 1 #[] #[] newVal (some $ .before op) sorry sorry sorry sorry let mut ctx ← Rewriter.replaceOp? ctx op newOp sorry sorry sorry sorry if (lhsValuePtr.getFirstUse ctx (by sorry)).isNone then @@ -189,7 +189,7 @@ def mulITwoReduce (ctx: IRContext) (op: OperationPtr) : Option IRContext := do -- Get the lhs value let lhsValuePtr := op.getOperand ctx 0 (by sorry) (by sorry) - let (ctx, newOp) ← Rewriter.createOp ctx OpCode.addi 1 #[lhsValuePtr, lhsValuePtr] 0 0 (some $ .before op) sorry sorry sorry + let (ctx, newOp) ← Rewriter.createOp ctx OpCode.addi 1 #[lhsValuePtr, lhsValuePtr] #[] 0 (some $ .before op) sorry sorry sorry sorry let mut ctx ← Rewriter.replaceOp? ctx op newOp sorry sorry sorry sorry if (rhsValuePtr.getFirstUse ctx (by sorry)).isNone then @@ -252,19 +252,19 @@ def empty : Option (IRContext × InsertPoint) := do -- ... def constFoldTree (opcode: Nat) (size pc: Nat) (root inc: UInt64) : Option IRContext := do let (gctx, insertPoint) ← empty - let mut (gctx, gacc) ← Rewriter.createOp gctx OpCode.constant 1 #[] 0 root insertPoint sorry sorry sorry + let mut (gctx, gacc) ← Rewriter.createOp gctx OpCode.constant 1 #[] #[] root insertPoint sorry sorry sorry sorry for i in [0:size] do let thisOp := if (i % 100 < pc) then opcode else OpCode.andi let (ctx, acc) := (gctx, gacc) - let (ctx, rhsOp) ← Rewriter.createOp ctx OpCode.constant 1 #[] 0 inc insertPoint sorry sorry sorry + let (ctx, rhsOp) ← Rewriter.createOp ctx OpCode.constant 1 #[] #[] inc insertPoint sorry sorry sorry sorry let lhsVal := acc.getResult 0 let rhsVal := rhsOp.getResult 0 - let (ctx, acc) ← Rewriter.createOp ctx thisOp 1 #[lhsVal, rhsVal] 0 0 insertPoint sorry sorry sorry + let (ctx, acc) ← Rewriter.createOp ctx thisOp 1 #[lhsVal, rhsVal] #[] 0 insertPoint sorry sorry sorry sorry (gctx, gacc) := (ctx, acc) let accRes := gacc.getResult 0 - let (ctx, op) ← Rewriter.createOp gctx OpCode.test 0 #[accRes] 0 0 insertPoint sorry sorry sorry + let (ctx, op) ← Rewriter.createOp gctx OpCode.test 0 #[accRes] #[] 0 insertPoint sorry sorry sorry sorry ctx def addZeroTree (size pc: Nat) : Option IRContext := @@ -286,8 +286,8 @@ def mulTwoTree (size pc: Nat) : Option IRContext := -- ... def constReuseTree (opcode: Nat) (size pc: Nat) (root inc: UInt64) : Option IRContext := do let (ctx, insertPoint) ← empty - let (ctx, acc) ← Rewriter.createOp ctx OpCode.constant 1 #[] 0 root insertPoint sorry sorry sorry - let (ctx, reuse) ← Rewriter.createOp ctx OpCode.constant 1 #[] 0 inc insertPoint sorry sorry sorry + let (ctx, acc) ← Rewriter.createOp ctx OpCode.constant 1 #[] #[] root insertPoint sorry sorry sorry sorry + let (ctx, reuse) ← Rewriter.createOp ctx OpCode.constant 1 #[] #[] inc insertPoint sorry sorry sorry sorry let mut (gctx, gacc) := (ctx, acc) for i in [0:size] do @@ -296,12 +296,12 @@ def constReuseTree (opcode: Nat) (size pc: Nat) (root inc: UInt64) : Option IRCo let (ctx, acc) := (gctx, gacc) let lhsVal := acc.getResult 0 let rhsVal := reuse.getResult 0 - let (ctx, acc) ← Rewriter.createOp ctx thisOp 1 #[lhsVal, rhsVal] 0 0 insertPoint sorry sorry sorry + let (ctx, acc) ← Rewriter.createOp ctx thisOp 1 #[lhsVal, rhsVal] #[] 0 insertPoint sorry sorry sorry sorry (gctx, gacc) := (ctx, acc) let (ctx, acc) := (gctx, gacc) let accRes := acc.getResult 0 - let (ctx, op) ← Rewriter.createOp ctx OpCode.test 0 #[accRes] 0 0 insertPoint sorry sorry sorry + let (ctx, op) ← Rewriter.createOp ctx OpCode.test 0 #[accRes] #[] 0 insertPoint sorry sorry sorry sorry ctx def addZeroReuseTree (size pc: Nat) : Option IRContext := @@ -318,11 +318,11 @@ def addZeroReuseTree (size pc: Nat) : Option IRContext := -- ... def constLotsOfReuseTree (opcode: Nat) (size pc: Nat) (lhs rhs: UInt64) : Option IRContext := do let (ctx, insertPoint) ← empty - let (ctx, lhsOp) ← Rewriter.createOp ctx OpCode.constant 1 #[] 0 lhs insertPoint sorry sorry sorry - let (ctx, rhsOp) ← Rewriter.createOp ctx OpCode.constant 1 #[] 0 rhs insertPoint sorry sorry sorry + let (ctx, lhsOp) ← Rewriter.createOp ctx OpCode.constant 1 #[] #[] lhs insertPoint sorry sorry sorry sorry + let (ctx, rhsOp) ← Rewriter.createOp ctx OpCode.constant 1 #[] #[] rhs insertPoint sorry sorry sorry sorry let lhsVal := lhsOp.getResult 0 let rhsVal := rhsOp.getResult 0 - let (ctx, reuse) ← Rewriter.createOp ctx opcode 1 #[lhsVal, rhsVal] 0 0 insertPoint sorry sorry sorry + let (ctx, reuse) ← Rewriter.createOp ctx opcode 1 #[lhsVal, rhsVal] #[] 0 insertPoint sorry sorry sorry sorry let mut (gctx, gacc) := (ctx, reuse) for i in [0:size] do @@ -331,12 +331,12 @@ def constLotsOfReuseTree (opcode: Nat) (size pc: Nat) (lhs rhs: UInt64) : Option let (ctx, acc) := (gctx, gacc) let lhsVal := acc.getResult 0 let rhsVal := reuse.getResult 0 - let (ctx, acc) ← Rewriter.createOp ctx thisOp 1 #[lhsVal, rhsVal] 0 0 insertPoint sorry sorry sorry + let (ctx, acc) ← Rewriter.createOp ctx thisOp 1 #[lhsVal, rhsVal] #[] 0 insertPoint sorry sorry sorry sorry (gctx, gacc) := (ctx, acc) let (ctx, acc) := (gctx, gacc) let accRes := acc.getResult 0 - let (ctx, op) ← Rewriter.createOp ctx OpCode.test 0 #[accRes] 0 0 insertPoint sorry sorry sorry + let (ctx, op) ← Rewriter.createOp ctx OpCode.test 0 #[accRes] #[] 0 insertPoint sorry sorry sorry sorry ctx def addZeroLotsOfReuseTree (size pc: Nat) : Option IRContext := diff --git a/Veir/PatternRewriter/Basic.lean b/Veir/PatternRewriter/Basic.lean index 5c7f4de..fca57bf 100644 --- a/Veir/PatternRewriter/Basic.lean +++ b/Veir/PatternRewriter/Basic.lean @@ -142,11 +142,12 @@ theorem addUsersInWorklist_same_ctx : simp [addUsersInWorklist] def createOp (rewriter: PatternRewriter) (opType: Nat) - (numResults: Nat) (operands: Array ValuePtr) (numRegions: Nat) (properties: UInt64) + (numResults: Nat) (operands: Array ValuePtr) (regions: Array RegionPtr) (properties: UInt64) (insertionPoint: Option InsertPoint) (hoper : ∀ oper, oper ∈ operands → oper.InBounds rewriter.ctx) + (hregions : ∀ region, region ∈ regions → region.InBounds rewriter.ctx) (hins : insertionPoint.maybe InsertPoint.InBounds rewriter.ctx) : Option (PatternRewriter × OperationPtr) := do - rlet (newCtx, op) ← Rewriter.createOp rewriter.ctx opType numResults operands numRegions properties insertionPoint hoper hins (by grind) + rlet (newCtx, op) ← Rewriter.createOp rewriter.ctx opType numResults operands regions properties insertionPoint hoper hregions hins (by grind) if h : insertionPoint.isNone then ({ rewriter with ctx := newCtx, ctx_fib := by grind }, op) else diff --git a/Veir/Rewriter/Basic.lean b/Veir/Rewriter/Basic.lean index cd0fb63..1b622a8 100644 --- a/Veir/Rewriter/Basic.lean +++ b/Veir/Rewriter/Basic.lean @@ -333,30 +333,29 @@ def Rewriter.createRegion (ctx: IRContext) : Option (IRContext × RegionPtr) := RegionPtr.allocEmpty ctx @[irreducible] -def Rewriter.initOpRegions (ctx: IRContext) (opPtr: OperationPtr) (numRegions: Nat) - (hop : opPtr.InBounds ctx) : Option IRContext := - match numRegions with - | 0 => some ctx - | Nat.succ n => do - rlet (ctx, regionPtr) ← Rewriter.createRegion ctx - let oldRegion := regionPtr.get ctx (by grind) - let ctx := opPtr.setRegions ctx ((opPtr.get ctx).regions.push regionPtr) - Rewriter.initOpRegions ctx opPtr n (by grind) +def Rewriter.initOpRegions (ctx: IRContext) (opPtr: OperationPtr) (regions : Array RegionPtr) (n : Nat := regions.size) + (opPtrInBounds : opPtr.InBounds ctx := by grind) + (hregionInBounds : ∀ region, region ∈ regions → region.InBounds ctx := by grind) + (hctx : ctx.FieldsInBounds := by grind) (hn : 0 ≤ n ∧ n ≤ regions.size := by grind) : IRContext := + match h : n with + | 0 => opPtr.setRegions ctx regions (by grind) + | Nat.succ n' => + let index := regions.size - n + let regionPtr := regions[index]'(by grind) + let ctx := regionPtr.setParent ctx opPtr (by grind) + Rewriter.initOpRegions ctx opPtr regions n' -set_option warn.sorry false in @[grind .] -theorem Rewriter.initOpRegions_fieldsInBounds - (hx : ctx.FieldsInBounds) - (heq : initOpRegions ctx opPtr numRegions h₁ = some ctx') : - ctx'.FieldsInBounds := by - induction numRegions generalizing ctx <;> sorry --grind [initOpRegions] +theorem Rewriter.initOpRegions_fieldsInBounds : + ctx.FieldsInBounds → + (initOpRegions ctx opPtr regions n opPtrInBounds hregions hctx hn).FieldsInBounds := by + intros hctx + induction n generalizing ctx <;> simp only [initOpRegions] <;> grind @[grind .] -theorem Rewriter.initOpRegions_inBounds_mono (ptr : GenericPtr) - (heq : initOpRegions ctx opPtr numRegions h₁ = some ctx') : - ptr.InBounds ctx → ptr.InBounds ctx' := by - induction numRegions generalizing ctx <;> - grind [initOpRegions, Option.unattach_eq_some_iff] +theorem Rewriter.initOpRegions_inBounds_mono (ptr : GenericPtr) : + ptr.InBounds ctx → ptr.InBounds (initOpRegions ctx opPtr regions n opPtrInBounds hregions hctx hn) := by + induction n generalizing ctx <;> simp only [initOpRegions] <;> grind @[irreducible] def Rewriter.initOpResults (ctx: IRContext) (opPtr: OperationPtr) (numResults: Nat) (index: Nat := 0) (hop : opPtr.InBounds ctx) @@ -470,9 +469,10 @@ theorem Rewriter.createEmptyOp_fieldsInBounds (h : createEmptyOp ctx opType = so @[irreducible] def Rewriter.createOp (ctx: IRContext) (opType: Nat) - (numResults: Nat) (operands: Array ValuePtr) (numRegions: Nat) (properties: UInt64) + (numResults: Nat) (operands: Array ValuePtr) (regions: Array RegionPtr) (properties: UInt64) (insertionPoint: Option InsertPoint) (hoper : ∀ oper, oper ∈ operands → oper.InBounds ctx) + (hregions : ∀ reg, reg ∈ regions → reg.InBounds ctx) (hins : insertionPoint.maybe InsertPoint.InBounds ctx) (hx : ctx.FieldsInBounds) : Option (IRContext × OperationPtr) := rlet (ctx, newOpPtr) ← Rewriter.createEmptyOp ctx opType @@ -483,11 +483,11 @@ def Rewriter.createOp (ctx: IRContext) (opType: Nat) let ctx := Rewriter.initOpResults ctx newOpPtr numResults 0 hib newOpPtrZeroRes let ctx := newOpPtr.setProperties ctx properties (by grind) have newOpPtrInBounds : newOpPtr.InBounds ctx := by grind - rlet ctx ← Rewriter.initOpRegions ctx newOpPtr numRegions newOpPtrInBounds + let ctx := Rewriter.initOpRegions ctx newOpPtr regions let ctx := Rewriter.initOpOperands ctx newOpPtr (by grind) operands (by grind) (by grind) match _ : insertionPoint with | some insertionPoint => - rlet ctx ← Rewriter.insertOp? ctx newOpPtr insertionPoint (by grind) (by cases insertionPoint <;> grind (ematch := 6) [Option.maybe]) (by grind) in + rlet ctx ← Rewriter.insertOp? ctx newOpPtr insertionPoint (by grind) (by cases insertionPoint <;> grind (ematch := 10) [Option.maybe]) (by grind) in some (ctx, newOpPtr) | none => (ctx, newOpPtr) @@ -524,28 +524,17 @@ def IRContext.create : Option (IRContext × OperationPtr) := grind [Rewriter.createEmptyOp, OperationPtr.allocEmpty, Operation.empty, OperationPtr.set, RegionPtr.InBounds] have : operation.get ctx (by simp_all) = Operation.empty ModuleTypeID := by grind [Rewriter.createEmptyOp, OperationPtr.allocEmpty, Operation.empty, OperationPtr.set, RegionPtr.InBounds] - rlet ctx ← Rewriter.initOpRegions ctx operation 1 (by grind) + rlet (ctx, region) ← Rewriter.createRegion ctx + have : ctx.FieldsInBounds := by sorry + let ctx := Rewriter.initOpRegions ctx operation #[region] have : operation = ⟨0⟩ := by grind [Rewriter.createEmptyOp, OperationPtr.allocEmpty] - have : ctx.topLevelOp = ⟨0⟩ := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - grind [RegionPtr.set, RegionPtr.allocEmpty] - have hop₀ : ∀ (op : OperationPtr), op.InBounds ctx ↔ op = ⟨0⟩ := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - simp [OperationPtr.InBounds] at hops - sorry --grind [Region.empty, RegionPtr.set, OperationPtr.InBounds] + have : ctx.topLevelOp = ⟨0⟩ := by sorry + have hop₀ : ∀ (op : OperationPtr), op.InBounds ctx ↔ op = ⟨0⟩ := by sorry --grind [Region.empty, RegionPtr.set, OperationPtr.InBounds] have : operation.get ctx (by simp_all) = - { Operation.empty ModuleTypeID with regions := #[⟨1⟩] } := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - grind [Operation.empty, RegionPtr.set, RegionPtr.InBounds, RegionPtr.get, OperationPtr.get, RegionPtr.allocEmpty] - have : ∀ (bl : BlockPtr), bl.InBounds ctx ↔ False := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - sorry --grind [Region.empty, RegionPtr.set, BlockPtr.InBounds] - have : ∀ (r : RegionPtr), r.InBounds ctx ↔ r = ⟨1⟩ := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - sorry --grind [Region.empty, RegionPtr.set, RegionPtr.InBounds] - have : (⟨1⟩ : RegionPtr).get ctx (by simp_all) = Region.empty := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - grind [Region.empty, RegionPtr.set, RegionPtr.InBounds, RegionPtr.get, RegionPtr.allocEmpty] + { Operation.empty ModuleTypeID with regions := #[⟨1⟩] } := by sorry + have : ∀ (bl : BlockPtr), bl.InBounds ctx ↔ False := by sorry + have : ∀ (r : RegionPtr), r.InBounds ctx ↔ r = ⟨1⟩ := by sorry + have : (⟨1⟩ : RegionPtr).get ctx (by simp_all) = Region.empty := by sorry have : ctx.FieldsInBounds := by constructor · grind [Operation.empty] diff --git a/Veir/Rewriter/GetSetInBounds.lean b/Veir/Rewriter/GetSetInBounds.lean index d7dec0c..3416018 100644 --- a/Veir/Rewriter/GetSetInBounds.lean +++ b/Veir/Rewriter/GetSetInBounds.lean @@ -1952,31 +1952,30 @@ theorem OperationPtr.getNumOperands_iff_replaceValue? @[grind .] theorem Rewriter.createOp_inBounds_mono (ptr : GenericPtr) - (heq : createOp ctx opType numResults operands numRegions props ip h₁ h₂ h₃ = some (newCtx, newOp)) : + (heq : createOp ctx opType numResults operands regions props ip h₁ h₂ h₃ h₄ = some (newCtx, newOp)) : ptr.InBounds ctx → ptr.InBounds newCtx := by - simp [createOp] at heq + simp only [createOp] at heq + split at heq; grind split at heq + · split at heq; grind + intros hptr + rename_i h + simp at heq + have ⟨_, _⟩ := heq + subst newOp + subst newCtx + rw [←Rewriter.insertOp?_inBounds_mono ptr h] + grind · grind - · split at heq - · grind - · split at heq - · split at heq - case h_1 => grind - case h_2 ctx hctx => - have := Rewriter.insertOp?_inBounds_mono ptr hctx - grind - · grind @[grind .] theorem Rewriter.createOp_fieldsInBounds - (heq : createOp ctx opType numResults operands numRegions props ip h₁ h₂ h₃ = some (newCtx, newOp)) : + (heq : createOp ctx opType numResults operands numRegions props ip h₁ h₂ h₃ h₄ = some (newCtx, newOp)) : ctx.FieldsInBounds → newCtx.FieldsInBounds := by - intros hx - simp [createOp] at heq + simp only [createOp] at heq + split at heq; grind split at heq - · grind · split at heq · grind - · split at heq - · split at heq <;> grind - · grind + · grind + · grind diff --git a/Veir/Rewriter/WellFormed/Builder/OpRegion.lean b/Veir/Rewriter/WellFormed/Builder/OpRegion.lean index b25eab5..7516c55 100644 --- a/Veir/Rewriter/WellFormed/Builder/OpRegion.lean +++ b/Veir/Rewriter/WellFormed/Builder/OpRegion.lean @@ -5,8 +5,8 @@ import Veir.Rewriter.Basic namespace Veir set_option warn.sorry false in -theorem Rewriter.initOpRegions_WellFormed (ctx: IRContext) (opPtr: OperationPtr) (numRegions: Nat) +theorem Rewriter.initOpRegions_WellFormed (opPtr: OperationPtr) (hop : opPtr.InBounds ctx) (hctx : IRContext.WellFormed ctx) (newCtx : IRContext): - Rewriter.initOpRegions ctx opPtr numRegions hop = some newCtx → + Rewriter.initOpRegions ctx opPtr regions n hop regionInBounds (by grind) hn = some newCtx → newCtx.WellFormed := by sorry diff --git a/Veir/Rewriter/WellFormed/Builder/Operation.lean b/Veir/Rewriter/WellFormed/Builder/Operation.lean index 12c6d17..039ee2c 100644 --- a/Veir/Rewriter/WellFormed/Builder/Operation.lean +++ b/Veir/Rewriter/WellFormed/Builder/Operation.lean @@ -15,11 +15,12 @@ theorem Rewriter.createOp_WellFormed (ctx: IRContext) (opType: Nat) (numResults: Nat) (operands: Array ValuePtr) (numRegions: Nat) (properties: UInt64) (insertionPoint: Option InsertPoint) (hoper : ∀ oper, oper ∈ operands → oper.InBounds ctx) + hregions (hins : insertionPoint.maybe InsertPoint.InBounds ctx) (hx : ctx.FieldsInBounds) (hctx : IRContext.WellFormed ctx) (newCtx : IRContext) (newOp : OperationPtr) : - Rewriter.createOp ctx opType numResults operands numRegions properties insertionPoint hoper hins hx = some (newCtx, newOp) → + Rewriter.createOp ctx opType numResults operands regions properties insertionPoint hoper hregions hins hx = some (newCtx, newOp) → newCtx.WellFormed := by intros heq constructor