Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions KLR/Compile/Pass.lean
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def freshName (name : Name := `tmp) : PassM Name :=
let n := s.freshVarNum + 1
(.num name n, { s with freshVarNum := n })

def resetPassState : PassM Unit :=
modify fun st =>
{st with freshVarNum := 0}

-- Emit a warning / linter message
def warn (msg : String) : PassM Unit :=
modify (PassState.warn msg)
Expand Down
2 changes: 2 additions & 0 deletions KLR/Trace/NKI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import KLR.Trace.ISA
import KLR.Trace.Term
import KLR.Trace.Types
import KLR.Trace.Lang
import KLR.Compile.Pass

/-
# NKI built-ins
Expand Down Expand Up @@ -551,6 +552,7 @@ def traceKernel (k : Kernel) : Trace Core.Kernel := do
let _ <- beginBlock (<- genName `main).toString
addId
globals k
resetPassState
match k.funs.find? fun f => f.name == k.entry with
| none => throw s!"function {k.entry} not found"
| some f => do
Expand Down
6 changes: 3 additions & 3 deletions KLR/Trace/Types.lean
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ namespace KLR.Trace
open KLR.Core
open KLR.Compile.Pass
export Core (Name)
export KLR.Compile.Pass (withPos withFile getPos warn message)
export KLR.Compile.Pass (withPos withFile getPos warn message resetPassState)
export NKI (Pos BinOp)

abbrev SharedConstant := String × TensorLib.Tensor
Expand Down Expand Up @@ -181,8 +181,8 @@ instance : Inhabited State where

abbrev Trace := Pass State

-- generate a fresh name using an existing name as a prefix
def genName (name : Name := `tmp) : Trace Name := freshName name
def genName (name : Name := `tmp) : Trace Name := do
freshName name

-- add a new binding to the global environment
def extend_global (x : Name) (v : Term) : Trace Unit :=
Expand Down
Loading