diff --git a/KLR/Core/LowerAP.lean b/KLR/Core/LowerAP.lean index c08d81f5..52365ee2 100644 --- a/KLR/Core/LowerAP.lean +++ b/KLR/Core/LowerAP.lean @@ -34,10 +34,17 @@ abbrev LowerAP := StateT LowerAPState KLR.Err /-- Function to convert an Access to an AccessPattern. Note: This lowering does not work in all cases, for example, if the Access in an AccessBasic whose Par dimension takes steps that are not equal to 1. Returns a None in this case. -/ -def Access.lowerAccessPattern (a : Access) : LowerAP BirAccessPattern := do +partial def Access.lowerAccessPattern (a : Access) : LowerAP BirAccessPattern := do -- Don't violate invariants of proved code if let .birPattern b := a then - return b + -- Lower vectorOffset if it exists and isn't already lowered + let vectorOffset <- b.vectorOffset.mapM fun vo => + match vo with + | .birPattern _ => pure vo + | _ => do + let lowered <- vo.lowerAccessPattern + pure (.birPattern lowered) + return { b with vectorOffset } -- The layout of a tensor in memory -- Note that because accesses are values, we have are forced to assume that all tensors are