Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions KLR/Trace/Term.lean
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,16 @@ def toIndex (shape : List Nat) (ts : List Term) : Err (List Index) := do
| _ => ts
toIndex' shape ts

-- Container type information for preserving original types during access
inductive ContainerType where
| tuple : ContainerType
| list : ContainerType
deriving Repr, BEq

-- Note, a list index can be negative, which means index from end of list.
-- Python also allows l[True] and l[False]
-- TODO: add case for slice
def listAccess (l : List Term) : List Term -> Err Term
def listAccess (l : List Term) (containerType : Option ContainerType := none) : List Term -> Err Term
| [.bool false] => do
if h:l.length > 0 then return l.get (Fin.mk 0 h)
else throw "index out of bounds"
Expand All @@ -306,7 +312,9 @@ def listAccess (l : List Term) : List Term -> Err Term
throw "slice index out of bounds"
let sliced := List.range ((e - start + step - 1) / step).toNat |>.map fun i =>
l[start.toNat + i * step.toNat]!
return .list sliced.toArray
match containerType with
| some .tuple => return .tuple sliced
| some .list | none => return .list sliced.toArray
| e => throw s!"index must be an integer or slice, got {repr e}"

def dictAccess (arr : AA) : List Term -> Err Term
Expand Down Expand Up @@ -408,8 +416,8 @@ partial def access (e : Term) (indexes : List Term) : Trace Term := do
match e with
| .ref name _ => access (<- lookup name) indexes
| .string _ => throw "string subscript not implemented"
| .tuple l => listAccess l indexes
| .list l => listAccess l.toList indexes
| .tuple l => listAccess l (some .tuple) indexes
| .list l => listAccess l.toList (some .list) indexes
| .dict arr => dictAccess arr indexes
| .pointer addr => pointerAccess addr indexes
| .access (.simple tensor) => do
Expand Down
Loading