From ac0dba947e3f581e64b8e63ca4da705975b60d60 Mon Sep 17 00:00:00 2001 From: Pavel Potapov Date: Wed, 15 Oct 2025 10:46:19 -0400 Subject: [PATCH 1/2] NKIFE-206 fix: check name uniqueness against args --- KLR/Trace/FromNKI.lean | 83 +++++++++++++++++++++++++++--------------- KLR/Trace/NKI.lean | 14 +++---- 2 files changed, 60 insertions(+), 37 deletions(-) diff --git a/KLR/Trace/FromNKI.lean b/KLR/Trace/FromNKI.lean index f5896e1b..dca23622 100644 --- a/KLR/Trace/FromNKI.lean +++ b/KLR/Trace/FromNKI.lean @@ -225,35 +225,60 @@ instance : FromNKI TensorName where | t => throw s!"expecting 'tensor', got '{Term.kindStr t}'" instance : FromNKI AluOp where - fromNKI? t := - match fromEnum t with - -- bitwise operations - | some "invert" => return .bitwise_not - | some "bitwise_and" => return .bitwise_and - | some "bitwise_or" => return .bitwise_or - | some "bitwise_xor" => return .bitwise_xor - | some "left_shift" => return .logical_shift_left - | some "right_shift" => return .logical_shift_right - -- arithemetic operations - | some "add" => return .add - | some "subtract" => return .subtract - | some "multiply" => return .mult - | some "maximum" => return .max - | some "minimum" => return .min - | some "equal" => return .is_equal - | some "not_equal" => return .not_equal - | some "greater_equal" => return .is_ge - | some "greater" => return .is_gt - | some "less_equal" => return .is_le - | some "less" => return .is_lt - | some "logical_not" => throw "'logical_not' operator not supported" - | some "logical_and" => return .logical_and - | some "logical_or" => return .logical_or - | some "logical_xor" => return .logical_xor - | some "rsqrt" => return .rsqrt - | some "abs" => return .abs - | some "power" => return .pow - | _ => throw s!"expecting operator, got '{Term.kindStr t}'" + fromNKI? + | .none => return .bypass + | .source {name, ..} + | .var name => + match name with + -- bitwise operations + | `nki.language.invert => return .bitwise_not + | `nki.language.bitwise_and => return .bitwise_and + | `nki.language.bitwise_or => return .bitwise_or + | `nki.language.bitwise_xor => return .bitwise_xor + | `nki.language.left_shift => return .logical_shift_left + | `nki.language.right_shift => return .logical_shift_right + -- numpy variants + | `numpy.bitwise_not => return .bitwise_not + | `numpy.bitwise_invert => return .bitwise_not + | `numpy.bitwise_and => return .bitwise_and + | `numpy.bitwise_or => return .bitwise_or + | `numpy.bitwise_xor => return .bitwise_xor + | `numpy.bitwise_left_shift => return .logical_shift_left + | `numpy.bitwise_right_shift => return .logical_shift_right + -- arithemetic operations + | `nki.language.add => return .add + | `nki.language.subtract => return .subtract + | `nki.language.multiply => return .mult + | `nki.language.maximum => return .max + | `nki.language.minimum => return .min + | `nki.language.equal => return .is_equal + | `nki.language.not_equal => return .not_equal + | `nki.language.greater_equal => return .is_ge + | `nki.language.greater => return .is_gt + | `nki.language.less_equal => return .is_le + | `nki.language.less => return .is_lt + | `nki.language.logical_not => throw "'logical_not' operator not supported" + | `nki.language.logical_and => return .logical_and + | `nki.language.logical_or => return .logical_or + | `nki.language.logical_xor => return .logical_xor + -- numpy variants + | `numpy.add => return .add + | `numpy.subtract => return .subtract + | `numpy.multiply => return .mult + | `numpy.maximum => return .max + | `numpy.minimum => return .min + | `numpy.equal => return .is_equal + | `numpy.not_equal => return .not_equal + | `numpy.greater_equal => return .is_ge + | `numpy.greater => return .is_gt + | `numpy.less_equal => return .is_le + | `numpy.less => return .is_lt + | `numpy.logical_not => throw "'logical_not' operator not supported" + | `numpy.logical_and => return .logical_and + | `numpy.logical_or => return .logical_or + | `numpy.logical_xor => return .logical_xor + | _ => throw s!"unsupported operator {name}" + | t => throw s!"expecting operator, got '{Term.kindStr t}'" instance : FromNKI ActivationFunc where fromNKI? t := diff --git a/KLR/Trace/NKI.lean b/KLR/Trace/NKI.lean index f08d62ac..27a7e124 100644 --- a/KLR/Trace/NKI.lean +++ b/KLR/Trace/NKI.lean @@ -543,17 +543,15 @@ private def globals (k : Kernel) : Trace Unit := do catch _ => pure () -private def processArgs (args : List Arg) : List Value × List Keyword := Id.run do - let mut inputs : List Value := [] - let mut kws : List Keyword := [] - for ⟨ name, e ⟩ in args do +private def processArgs (args : List Arg) : Trace (List Value × List Keyword) := do + let (inputs, kws) <- args.foldlM (init := ([], [])) fun (inputs, kws) ⟨name, e⟩ => do + modify fun s => {s with tensorNames := s.tensorNames.insert name} match e with | ⟨ .value (.tensor s d _), pos ⟩ => let t := .tensor s d name - inputs := t :: inputs let e' := ⟨ .value t, pos ⟩ - kws := .mk name e' :: kws - | _ => kws := .mk name e :: kws + return (t :: inputs, .mk name e' :: kws) + | _ => return (inputs, .mk name e :: kws) return (inputs.reverse, kws.reverse) @@ -573,7 +571,7 @@ def traceKernel (k : Kernel) : Trace Core.Kernel := do match k.funs.find? fun f => f.name == k.entry with | none => throw s!"function {k.entry} not found" | some f => do - let (inputs, args) := processArgs k.args + let (inputs, args) <- processArgs k.args let res <- fnCall (.source f) [] (<- args.mapM keyword) let inputs <- inputs.mapM value let inputs := Core.tensors inputs From 1c2151bb169ee3952047932d345f3e5b1f13f1a2 Mon Sep 17 00:00:00 2001 From: Pavel Potapov Date: Mon, 12 Jan 2026 14:27:46 -0500 Subject: [PATCH 2/2] Fix kwargs existing check --- KLR/Trace/FromNKI.lean | 83 +++++++++++++++--------------------------- KLR/Trace/NKI.lean | 4 ++ 2 files changed, 33 insertions(+), 54 deletions(-) diff --git a/KLR/Trace/FromNKI.lean b/KLR/Trace/FromNKI.lean index dca23622..f5896e1b 100644 --- a/KLR/Trace/FromNKI.lean +++ b/KLR/Trace/FromNKI.lean @@ -225,60 +225,35 @@ instance : FromNKI TensorName where | t => throw s!"expecting 'tensor', got '{Term.kindStr t}'" instance : FromNKI AluOp where - fromNKI? - | .none => return .bypass - | .source {name, ..} - | .var name => - match name with - -- bitwise operations - | `nki.language.invert => return .bitwise_not - | `nki.language.bitwise_and => return .bitwise_and - | `nki.language.bitwise_or => return .bitwise_or - | `nki.language.bitwise_xor => return .bitwise_xor - | `nki.language.left_shift => return .logical_shift_left - | `nki.language.right_shift => return .logical_shift_right - -- numpy variants - | `numpy.bitwise_not => return .bitwise_not - | `numpy.bitwise_invert => return .bitwise_not - | `numpy.bitwise_and => return .bitwise_and - | `numpy.bitwise_or => return .bitwise_or - | `numpy.bitwise_xor => return .bitwise_xor - | `numpy.bitwise_left_shift => return .logical_shift_left - | `numpy.bitwise_right_shift => return .logical_shift_right - -- arithemetic operations - | `nki.language.add => return .add - | `nki.language.subtract => return .subtract - | `nki.language.multiply => return .mult - | `nki.language.maximum => return .max - | `nki.language.minimum => return .min - | `nki.language.equal => return .is_equal - | `nki.language.not_equal => return .not_equal - | `nki.language.greater_equal => return .is_ge - | `nki.language.greater => return .is_gt - | `nki.language.less_equal => return .is_le - | `nki.language.less => return .is_lt - | `nki.language.logical_not => throw "'logical_not' operator not supported" - | `nki.language.logical_and => return .logical_and - | `nki.language.logical_or => return .logical_or - | `nki.language.logical_xor => return .logical_xor - -- numpy variants - | `numpy.add => return .add - | `numpy.subtract => return .subtract - | `numpy.multiply => return .mult - | `numpy.maximum => return .max - | `numpy.minimum => return .min - | `numpy.equal => return .is_equal - | `numpy.not_equal => return .not_equal - | `numpy.greater_equal => return .is_ge - | `numpy.greater => return .is_gt - | `numpy.less_equal => return .is_le - | `numpy.less => return .is_lt - | `numpy.logical_not => throw "'logical_not' operator not supported" - | `numpy.logical_and => return .logical_and - | `numpy.logical_or => return .logical_or - | `numpy.logical_xor => return .logical_xor - | _ => throw s!"unsupported operator {name}" - | t => throw s!"expecting operator, got '{Term.kindStr t}'" + fromNKI? t := + match fromEnum t with + -- bitwise operations + | some "invert" => return .bitwise_not + | some "bitwise_and" => return .bitwise_and + | some "bitwise_or" => return .bitwise_or + | some "bitwise_xor" => return .bitwise_xor + | some "left_shift" => return .logical_shift_left + | some "right_shift" => return .logical_shift_right + -- arithemetic operations + | some "add" => return .add + | some "subtract" => return .subtract + | some "multiply" => return .mult + | some "maximum" => return .max + | some "minimum" => return .min + | some "equal" => return .is_equal + | some "not_equal" => return .not_equal + | some "greater_equal" => return .is_ge + | some "greater" => return .is_gt + | some "less_equal" => return .is_le + | some "less" => return .is_lt + | some "logical_not" => throw "'logical_not' operator not supported" + | some "logical_and" => return .logical_and + | some "logical_or" => return .logical_or + | some "logical_xor" => return .logical_xor + | some "rsqrt" => return .rsqrt + | some "abs" => return .abs + | some "power" => return .pow + | _ => throw s!"expecting operator, got '{Term.kindStr t}'" instance : FromNKI ActivationFunc where fromNKI? t := diff --git a/KLR/Trace/NKI.lean b/KLR/Trace/NKI.lean index 27a7e124..8c1a4607 100644 --- a/KLR/Trace/NKI.lean +++ b/KLR/Trace/NKI.lean @@ -273,6 +273,10 @@ partial def bindArgs : Trace (List (String × Term)) := do if args.length + kwargs.length > f.args.length then throw "too many arguments given (varargs not supported)" + let validArgNames := f.args.map (·.name) + kwargs.forM fun (name, _) => do + if !validArgNames.contains name then + throw s!"unexpected keyword argument '{name}'" f.args.zipIdx.mapM fun ({name := x, dflt := d}, i) => do if h:args.length > i then pure ⟨x, args.get (Fin.mk i h)⟩