diff --git a/KLR/Trace/Term.lean b/KLR/Trace/Term.lean index a11cfdaa..4b5a6b9d 100644 --- a/KLR/Trace/Term.lean +++ b/KLR/Trace/Term.lean @@ -279,10 +279,16 @@ def toIndex (shape : List Nat) (ts : List Term) : Err (List Index) := do | _ => ts toIndex' shape ts +-- Container type information for preserving original types during access +inductive ContainerType where + | tuple : ContainerType + | list : ContainerType + deriving Repr, BEq + -- Note, a list index can be negative, which means index from end of list. -- Python also allows l[True] and l[False] -- TODO: add case for slice -def listAccess (l : List Term) : List Term -> Err Term +def listAccess (l : List Term) (containerType : Option ContainerType := none) : List Term -> Err Term | [.bool false] => do if h:l.length > 0 then return l.get (Fin.mk 0 h) else throw "index out of bounds" @@ -306,7 +312,9 @@ def listAccess (l : List Term) : List Term -> Err Term throw "slice index out of bounds" let sliced := List.range ((e - start + step - 1) / step).toNat |>.map fun i => l[start.toNat + i * step.toNat]! - return .list sliced.toArray + match containerType with + | some .tuple => return .tuple sliced + | some .list | none => return .list sliced.toArray | e => throw s!"index must be an integer or slice, got {repr e}" def dictAccess (arr : AA) : List Term -> Err Term @@ -408,8 +416,8 @@ partial def access (e : Term) (indexes : List Term) : Trace Term := do match e with | .ref name _ => access (<- lookup name) indexes | .string _ => throw "string subscript not implemented" - | .tuple l => listAccess l indexes - | .list l => listAccess l.toList indexes + | .tuple l => listAccess l (some .tuple) indexes + | .list l => listAccess l.toList (some .list) indexes | .dict arr => dictAccess arr indexes | .pointer addr => pointerAccess addr indexes | .access (.simple tensor) => do