diff --git a/KLR/Trace/NKI.lean b/KLR/Trace/NKI.lean index f08d62ac..8c1a4607 100644 --- a/KLR/Trace/NKI.lean +++ b/KLR/Trace/NKI.lean @@ -273,6 +273,10 @@ partial def bindArgs : Trace (List (String × Term)) := do if args.length + kwargs.length > f.args.length then throw "too many arguments given (varargs not supported)" + let validArgNames := f.args.map (·.name) + kwargs.forM fun (name, _) => do + if !validArgNames.contains name then + throw s!"unexpected keyword argument '{name}'" f.args.zipIdx.mapM fun ({name := x, dflt := d}, i) => do if h:args.length > i then pure ⟨x, args.get (Fin.mk i h)⟩ @@ -543,17 +547,15 @@ private def globals (k : Kernel) : Trace Unit := do catch _ => pure () -private def processArgs (args : List Arg) : List Value × List Keyword := Id.run do - let mut inputs : List Value := [] - let mut kws : List Keyword := [] - for ⟨ name, e ⟩ in args do +private def processArgs (args : List Arg) : Trace (List Value × List Keyword) := do + let (inputs, kws) <- args.foldlM (init := ([], [])) fun (inputs, kws) ⟨name, e⟩ => do + modify fun s => {s with tensorNames := s.tensorNames.insert name} match e with | ⟨ .value (.tensor s d _), pos ⟩ => let t := .tensor s d name - inputs := t :: inputs let e' := ⟨ .value t, pos ⟩ - kws := .mk name e' :: kws - | _ => kws := .mk name e :: kws + return (t :: inputs, .mk name e' :: kws) + | _ => return (inputs, .mk name e :: kws) return (inputs.reverse, kws.reverse) @@ -573,7 +575,7 @@ def traceKernel (k : Kernel) : Trace Core.Kernel := do match k.funs.find? fun f => f.name == k.entry with | none => throw s!"function {k.entry} not found" | some f => do - let (inputs, args) := processArgs k.args + let (inputs, args) <- processArgs k.args let res <- fnCall (.source f) [] (<- args.mapM keyword) let inputs <- inputs.mapM value let inputs := Core.tensors inputs