From dd252540c50f3bb477b5705b1c718abb9353c6c9 Mon Sep 17 00:00:00 2001 From: Pavel Potapov Date: Mon, 6 Oct 2025 14:32:04 -0400 Subject: [PATCH] fix: naming discrepancy fix: naming discrepancy Added reset function for pass monad --- KLR/Compile/Pass.lean | 4 ++++ KLR/Trace/NKI.lean | 2 ++ KLR/Trace/Types.lean | 6 +++--- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/KLR/Compile/Pass.lean b/KLR/Compile/Pass.lean index cf287bac..4b6870a8 100644 --- a/KLR/Compile/Pass.lean +++ b/KLR/Compile/Pass.lean @@ -213,6 +213,10 @@ def freshName (name : Name := `tmp) : PassM Name := let n := s.freshVarNum + 1 (.num name n, { s with freshVarNum := n }) +def resetPassState : PassM Unit := + modify fun st => + {st with freshVarNum := 0} + -- Emit a warning / linter message def warn (msg : String) : PassM Unit := modify (PassState.warn msg) diff --git a/KLR/Trace/NKI.lean b/KLR/Trace/NKI.lean index 889446fa..a8f93ba1 100644 --- a/KLR/Trace/NKI.lean +++ b/KLR/Trace/NKI.lean @@ -22,6 +22,7 @@ import KLR.Trace.ISA import KLR.Trace.Term import KLR.Trace.Types import KLR.Trace.Lang +import KLR.Compile.Pass /- # NKI built-ins @@ -551,6 +552,7 @@ def traceKernel (k : Kernel) : Trace Core.Kernel := do let _ <- beginBlock (<- genName `main).toString addId globals k + resetPassState match k.funs.find? fun f => f.name == k.entry with | none => throw s!"function {k.entry} not found" | some f => do diff --git a/KLR/Trace/Types.lean b/KLR/Trace/Types.lean index cd24d4ac..f30f7b95 100644 --- a/KLR/Trace/Types.lean +++ b/KLR/Trace/Types.lean @@ -68,7 +68,7 @@ namespace KLR.Trace open KLR.Core open KLR.Compile.Pass export Core (Name) -export KLR.Compile.Pass (withPos withFile getPos warn message) +export KLR.Compile.Pass (withPos withFile getPos warn message resetPassState) export NKI (Pos BinOp) abbrev SharedConstant := String × TensorLib.Tensor @@ -181,8 +181,8 @@ instance : Inhabited State where abbrev Trace := Pass State --- generate a fresh name using an existing name as a prefix -def genName (name : Name := `tmp) : Trace Name := freshName name +def genName (name : Name := `tmp) : Trace Name := do + freshName name -- add a new binding to the global environment def extend_global (x : Name) (v : Term) : Trace Unit :=