diff --git a/KLR/Fuzz/Main.lean b/KLR/Fuzz/Main.lean new file mode 100644 index 00000000..23587342 --- /dev/null +++ b/KLR/Fuzz/Main.lean @@ -0,0 +1,377 @@ +/- +Copyright KLR Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +-/ + +import Cli +import KLR.Py.Gen +import KLR.Py.Pretty +import KLR.NKI.Simplify +import KLR.Compile.Pass +import KLR.Python + +/-! +# AST Fuzzer for Simplify Pass Validation + +Generates random Python ASTs, runs them through the simplify pass, +and outputs valid Python for seeds that pass simplification. +-/ + +open Cli +open KLR.Py +open KLR.Compile.Pass (runPasses runPass) +open System (FilePath) + +-- Inhabited instance for Python.Expr' +instance : Inhabited KLR.Python.Expr' where + default := .const .none + +instance : Inhabited KLR.Python.Expr where + default := { expr := .const .none, pos := { line := 0 } } + +instance : Inhabited KLR.Python.Stmt' where + default := .pass + +instance : Inhabited KLR.Python.Stmt where + default := { stmt := .pass, pos := { line := 0 } } + +def defaultPos : KLR.Core.Pos := { line := 1 } + +/-- Convert Py.Value to Python.Const -/ +def valueToConst : Value → KLR.Python.Const + | .none => .none + | .bool b => .bool b + | .int i => .int i + | .float f => .float f + | .string s => .string s + | .ellipsis => .ellipsis + +/-- Convert Py.UnaryOp to Python.UnaryOp -/ +def unaryOpConv : UnaryOp → KLR.Python.UnaryOp + | .pos => .uadd + | .neg => .usub + | .bitwiseNot => .invert + | .lnot => .not + +/-- Convert Py.BinOp to Python.BinOp (for non-comparison ops) -/ +def binOpConv : BinOp → KLR.Python.BinOp + | .land | .lor => .and -- handled specially + | .bitwiseOr => .or + | .bitwiseXor => .xor + | .bitwiseAnd => .and + | .lshift => .lshift + | .rshift => .rshift + | .eq | .ne | .lt | .le | .gt | .ge => .add -- comparisons handled separately + | .add => .add + | .sub => .sub + | .mul => .mul + | .div => .div + | .mod => .mod + | .pow => .pow + | .floor => .floor + | .matmul => .matmul + +/-- Check if BinOp is a comparison -/ +def isCompareOp : BinOp → Bool + | .eq | .ne | .lt | .le | .gt | .ge => true + | _ => false + +/-- Convert BinOp to CmpOp -/ +def binOpToCmpOp : BinOp → KLR.Python.CmpOp + | .eq => .eq + | .ne => .ne + | .lt => .lt + | .le => .le + | .gt => .gt + | .ge => .ge + | _ => .eq -- shouldn't happen + +/-- Check if BinOp is a boolean op -/ +def isBoolOp : BinOp → Bool + | .land | .lor => true + | _ => false + +/-- Convert BinOp to BoolOp -/ +def binOpToBoolOp : BinOp → KLR.Python.BoolOp + | .land => .land + | .lor => .lor + | _ => .land -- shouldn't happen + +mutual +/-- Convert Py.Exp to Python.Expr -/ +partial def pyExpToExpr (e : Exp) : KLR.Python.Expr := + { expr := pyExp'ToExpr' e.exp, pos := defaultPos } + +partial def pyExp'ToExpr' : Exp' → KLR.Python.Expr' + | .var name => .name name .load + | .value v => .const (valueToConst v) + | .unaryOp op x => .unaryOp (unaryOpConv op) (pyExpToExpr x) + | .binOp op x y => + if isCompareOp op then + .compare (pyExpToExpr x) [binOpToCmpOp op] [pyExpToExpr y] + else if isBoolOp op then + .boolOp (binOpToBoolOp op) [pyExpToExpr x, pyExpToExpr y] + else + .binOp (binOpConv op) (pyExpToExpr x) (pyExpToExpr y) + | .tuple es => .tuple (es.map pyExpToExpr) .load + | .list es => .list (es.map pyExpToExpr) .load + | .ifExp test body orelse => + .ifExp (pyExpToExpr test) (pyExpToExpr body) (pyExpToExpr orelse) + | .call f _ args => + .call (pyExpToExpr f) + (args.filterMap fun a => if a.keyword.isNone then some (pyExpToExpr a.val) else none) + (args.filterMap fun a => a.keyword.map fun k => ⟨some k, pyExpToExpr a.val, defaultPos⟩) + | .access e indices => .subscript (pyExpToExpr e) (indicesToExpr indices) .load + | .attr e field => .attr (pyExpToExpr e) field .load + +partial def indexToExpr : Index → KLR.Python.Expr + | .coord i => pyExpToExpr i + | .slice l u step => + { expr := .slice (l.map pyExpToExpr) (u.map pyExpToExpr) (step.map pyExpToExpr) + , pos := defaultPos } + | .dynamic t _ _ => pyExpToExpr t + +partial def indicesToExpr (indices : List Index) : KLR.Python.Expr := + match indices with + | [idx] => indexToExpr idx + | _ => + let exprs := indices.map indexToExpr + { expr := .tuple exprs .load, pos := defaultPos } +end + +/-- Convert Py.Pattern to Python.Expr -/ +def pyPatternToExpr : Pattern → KLR.Python.Expr + | .var name => { expr := .name name .store, pos := defaultPos } + | .tuple pats => { expr := .tuple (pats.map pyPatternToExpr) .store, pos := defaultPos } + +mutual +/-- Convert Py.Stmt to Python.Stmt -/ +partial def pyStmtToStmt (s : Stmt) : KLR.Python.Stmt := + { stmt := pyStmt'ToStmt' s.stmt, pos := defaultPos } + +partial def pyStmt'ToStmt' : Stmt' → KLR.Python.Stmt' + | .exp e => .expr (pyExpToExpr e) + | .imprt _ _ => .pass + | .imprtFrom _ _ _ => .pass + | .ret e => .ret (pyExpToExpr e) + | .assign lhs _ rhs => .assign [pyExpToExpr lhs] (pyExpToExpr rhs) + | .assert e => .assert (pyExpToExpr e) none + | .funcDef _ => .pass -- nested functions not supported + | .ifStm cond thn _ els => + let thnStmts := thn.map pyStmtToStmt + let elsStmts := els.map (·.map pyStmtToStmt) |>.getD [] + .ifStm (pyExpToExpr cond) thnStmts elsStmts + | .forLoop pat iter body => + .forLoop (pyPatternToExpr pat) (pyExpToExpr iter) (body.map pyStmtToStmt) [] + | .whileLoop cond body => + .whileLoop (pyExpToExpr cond) (body.map pyStmtToStmt) [] + | .pass => .pass + | .breakLoop => .breakLoop + | .continueLoop => .continueLoop +end + +/-- Convert Py.FuncDef to Python.Fun -/ +def pyFuncDefToFun (f : FuncDef) : KLR.Python.Fun := { + name := f.name + fileName := "generated.py" + line := 1 + source := "" + decorators := [] + args := { + posonlyargs := [] + args := f.params.map (·.name) + defaults := f.params.filterMap (·.dflt) |>.map pyExpToExpr + vararg := none + kwonlyargs := [] + kw_defaults := [] + kwarg := none + } + body := f.body.map pyStmtToStmt +} + +/-- Generate random enum class -/ +def genEnumClass (seed : Nat) (idx : Nat) : KLR.Python.Class := + let rng := (seed * 1103515245 + 12345 + idx * 7) % (2^31) + let nMembers := (rng % 4) + 1 + let fields := (List.range nMembers).map fun i => + let val : Int := ((rng * (i + 1)) % 200 : Nat) + ⟨some s!"MEMBER_{i}", { expr := .const (.int val), pos := defaultPos }, defaultPos⟩ + { + name := s!"MyEnum_{idx}" + bases := [{ expr := .name "Enum" .load, pos := defaultPos }] + decorators := [] + fields := fields + methods := [] + } + +/-- Generate random NKIObject class with methods -/ +def genNKIObjectClass (seed : Nat) (idx : Nat) : KLR.Python.Class := + let rng := (seed * 1103515245 + 12345 + idx * 13) % (2^31) + let nFields := (rng % 3) + 1 + let fields := (List.range nFields).map fun i => + let val : Int := ((rng * (i + 1)) % 100 : Nat) + ⟨some s!"field_{i}", { expr := .const (.int val), pos := defaultPos }, defaultPos⟩ + let nMethods := (rng % 2) + 1 + let methods : List KLR.Python.Fun := (List.range nMethods).map fun i => + let retVal : Int := ((rng + i) % 50 : Nat) + let methodBody : KLR.Python.Stmt := { + stmt := .ret { expr := .const (.int retVal), pos := defaultPos } + pos := defaultPos + } + { + name := s!"method_{i}" + fileName := "generated.py" + line := 1 + source := "" + decorators := [] + args := { posonlyargs := [], args := ["self"], defaults := [], vararg := none, kwonlyargs := [], kw_defaults := [], kwarg := none } + body := [methodBody] + } + { + name := s!"MyClass_{idx}" + bases := [{ expr := .name "NKIObject" .load, pos := defaultPos }] + decorators := [{ expr := .name "dataclass" .load, pos := defaultPos }] + fields := fields + methods := methods + } + +/-- Convert Py.Prog to Python.Kernel for simplify -/ +def pyProgToKernel (p : Prog) (entry : String) (seed : Nat) : KLR.Python.Kernel := + let funcs := p.stmts.filterMap fun s => + match s.stmt with + | .funcDef f => some (pyFuncDefToFun f) + | _ => none + + -- Generate enums and classes based on seed + let rng := seed % (2^31) + let nEnums := (rng % 3) + let nClasses := ((rng / 3) % 3) + let enums := (List.range nEnums).map (genEnumClass seed) + let classes := (List.range nClasses).map (genNKIObjectClass seed) + + { + entry := entry + funcs := funcs + classes := enums ++ classes + args := [] + kwargs := [] + globals := [] + arch := 0 + grid := 1 + scheduleEdges := [] + flags := [] + } + +/-- Try to run simplify on a generated program, return error if failed -/ +def trySimplify (prog : Prog) (seed : Nat) : Option KLR.NKI.Kernel × Option String := + let kernel := pyProgToKernel prog "test_kernel" seed + let result := runPasses (runPass (KLR.NKI.simplify kernel)) + match result.result with + | some k => (some k, none) + | none => (none, result.errors.get? 0) + +/-- Generate Python code for an enum class -/ +def enumClassToString (seed : Nat) (idx : Nat) : String := + let rng := (seed * 1103515245 + 12345 + idx * 7) % (2^31) + let nMembers := (rng % 4) + 1 + let members := (List.range nMembers).map fun i => + let val := (rng * (i + 1)) % 200 + s!" MEMBER_{i} = {val}" + s!"class MyEnum_{idx}(Enum):\n{"\n".intercalate members}\n" + +/-- Generate Python code for an NKIObject class -/ +def nkiClassToString (seed : Nat) (idx : Nat) : String := + let rng := (seed * 1103515245 + 12345 + idx * 13) % (2^31) + let nFields := (rng % 3) + 1 + let fields := (List.range nFields).map fun i => + let val := (rng * (i + 1)) % 100 + s!" field_{i} = {val}" + let nMethods := (rng % 2) + 1 + let methods := (List.range nMethods).map fun i => + let retVal := (rng + i) % 50 + s!" def method_{i}(self):\n return {retVal}" + s!"@dataclass\nclass MyClass_{idx}(NKIObject):\n{"\n".intercalate fields}\n{"\n".intercalate methods}\n" + +/-- Generate class/enum code based on seed -/ +def genClassesCode (seed : Nat) : String := + let rng := seed % (2^31) + let nEnums := rng % 3 + let nClasses := (rng / 3) % 3 + let enumCode := (List.range nEnums).map (enumClassToString seed) |> "\n".intercalate + let classCode := (List.range nClasses).map (nkiClassToString seed) |> "\n".intercalate + let imports := if nEnums > 0 || nClasses > 0 then "from enum import Enum\nfrom dataclasses import dataclass\n\nclass NKIObject: pass\n\n" else "" + imports ++ enumCode ++ (if enumCode.isEmpty then "" else "\n") ++ classCode ++ (if classCode.isEmpty then "" else "\n") + +/-- Generate and test a single seed -/ +def generateAndTest (seed : Nat) (outputDir : String) (verbose : Bool) : IO Bool := do + let prog := Gen.genProg seed + match trySimplify prog seed with + | (some _, _) => + -- Simplify succeeded, write the Python file with classes + let classesCode := genClassesCode seed + let funcCode := Gen.progToString prog + let content := s!"# Generated from seed: {seed}\n\n{classesCode}{funcCode}" + let path := FilePath.mk outputDir / s!"{seed}.py" + IO.FS.writeFile path content + return true + | (none, err) => + if verbose then + IO.println s!"✗ Seed {seed}: {err.getD "unknown error"}" + return false + +/-- Simple LCG for generating random seeds -/ +def nextSeed (s : Nat) : Nat := + (s * 1103515245 + 12345) % (2^31) + +/-- Main fuzzer command -/ +def fuzz (p : Parsed) : IO UInt32 := do + let outputDir := p.positionalArg! "output" |>.as! String + let count := (p.flag? "count").map (·.as! Nat) |>.getD 100 + let initSeed ← match p.flag? "seed" with + | some f => pure (f.as! Nat) + | none => IO.rand 0 (2^31 - 1) + + IO.FS.createDirAll outputDir + + let mut passed := 0 + let mut failed := 0 + let mut rng := initSeed + + for _ in [:count] do + let seed := rng + rng := nextSeed rng + if ← generateAndTest seed outputDir (p.hasFlag "verbose") then + passed := passed + 1 + IO.println s!"✓ Seed {seed} passed" + else + failed := failed + 1 + + IO.println s!"\nResults: {passed} passed, {failed} failed out of {count}" + return 0 + +def fuzzCmd : Cmd := `[Cli| + fuzz VIA fuzz; + "Generate random ASTs and test simplify pass" + + FLAGS: + n, count : Nat; "Number of programs to generate (default: 100)" + s, seed : Nat; "Starting seed (default: 0)" + v, verbose; "Print failed seeds" + + ARGS: + output : String; "Output directory for passing programs" +] + +def main (args : List String) : IO UInt32 := do + fuzzCmd.validate args diff --git a/KLR/Py/Basic.lean b/KLR/Py/Basic.lean index 8b5c5868..b27e165d 100644 --- a/KLR/Py/Basic.lean +++ b/KLR/Py/Basic.lean @@ -33,6 +33,11 @@ def KLR.Core.Dtype.toString : Dtype → String | .uint16 => "uint16" | .uint32 => "uint32" | .uint64 => "uint64" + | .float8_e4m3 => "float8_e4m3" + | .float8_e4m3fn => "float8_e4m3fn" + | .float8_e5m2_x4 => "float8_e5m2_x4" + | .float8_e4m3fn_x4 => "float8_e4m3fn_x4" + | .float4_e2m1fn_x4 => "float4_e2m1fn_x4" namespace KLR.Py diff --git a/KLR/Py/Gen.lean b/KLR/Py/Gen.lean new file mode 100644 index 00000000..275956e2 --- /dev/null +++ b/KLR/Py/Gen.lean @@ -0,0 +1,348 @@ +/- +Copyright KLR Contributors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +-/ + +import KLR.Py.Basic +import KLR.Py.Pretty + +/-! +# Random AST Generator for KLR.Py + +Generates random Python ASTs based on the definitions in Basic.lean. +Used for fuzzing the simplify pass. +-/ + +namespace KLR.Py.Gen + +-- Add Inhabited instances for AST types +instance : Inhabited Index where + default := .coord default + +instance : Inhabited Pattern where + default := .var "x" + +instance : Inhabited Stmt where + default := ⟨{}, .pass⟩ + +structure GenState where + seed : Nat + varCounter : Nat := 0 + vars : List Ident := [] + deriving Inhabited + +abbrev GenM := StateM GenState + +def nextRand : GenM Nat := do + let s ← get + -- LCG parameters (same as glibc) + let next := (s.seed * 1103515245 + 12345) % (2^31) + set { s with seed := next } + return next + +def randNat (max : Nat) : GenM Nat := do + if max == 0 then return 0 + let r ← nextRand + return r % (max + 1) + +def randInt (lo hi : Int) : GenM Int := do + let range := (hi - lo).toNat + 1 + let r ← randNat range + return lo + r + +def randBool : GenM Bool := do + let r ← randNat 1 + return r == 1 + +def randChoice [Inhabited α] (xs : List α) : GenM α := do + match xs with + | [] => return default + | _ => + let idx ← randNat (xs.length - 1) + return xs[idx]! + +def freshVar : GenM Ident := do + let s ← get + let name := s!"v{s.varCounter}" + set { s with varCounter := s.varCounter + 1, vars := name :: s.vars } + return name + +def pickVar : GenM Ident := do + let s ← get + match s.vars with + | [] => freshVar + | vs => randChoice vs + +def span : Span := {} + +def genValue : GenM Value := do + let choice ← randNat 4 + match choice with + | 0 => return .none + | 1 => return .bool (← randBool) + | 2 => return .int (← randInt (-100) 100) + | 3 => return .float (Float.ofInt (← randInt (-100) 100)) + | _ => do + let n ← randNat 99 + return .string s!"s{n}" + +def unaryOps : List UnaryOp := [.pos, .neg, .bitwiseNot, .lnot] +def arithOps : List BinOp := [.add, .sub, .mul, .div, .mod, .pow, .floor] +def cmpOps : List BinOp := [.eq, .ne, .lt, .le, .gt, .ge] +def logicOps : List BinOp := [.land, .lor] +def bitOps : List BinOp := [.bitwiseOr, .bitwiseXor, .bitwiseAnd, .lshift, .rshift] +def allBinOps : List BinOp := arithOps ++ cmpOps ++ logicOps ++ bitOps + +mutual +partial def genLeafExp : GenM Exp := do + let choice ← randNat 3 + match choice with + | 0 => + let s ← get + if !s.vars.isEmpty then + let v ← pickVar + return ⟨span, .var v⟩ + else + let v ← genValue + return ⟨span, .value v⟩ + | 1 => + let v ← genValue + return ⟨span, .value v⟩ + | _ => + -- Call with no args as leaf + let fname ← randChoice (["f", "g", "h"] : List String) + return ⟨span, .call ⟨span, .var fname⟩ [] []⟩ + +partial def genIndex (depth : Nat) : GenM Index := do + let choice ← randNat 4 + match choice with + | 0 | 1 => + let i ← genExp depth + return .coord i + | 2 => + -- Slice with expressions + let hasL ← randBool + let hasU ← randBool + let l ← if hasL then some <$> genExp depth else pure none + let u ← if hasU then some <$> genExp depth else pure none + let step ← if (← randNat 3) == 0 then some <$> genExp depth else pure none + return .slice l u step + | _ => + -- Empty slice + return .slice none none none + +partial def genExp (depth : Nat) : GenM Exp := do + if depth >= 5 then + genLeafExp + else + let choice ← randNat 15 + match choice with + | 0 => genLeafExp + | 1 => + -- Unary op + let op ← randChoice unaryOps + let x ← genExp (depth + 1) + return ⟨span, .unaryOp op x⟩ + | 2 | 3 | 4 => + -- Binary op (common) + let op ← randChoice allBinOps + let x ← genExp (depth + 1) + let y ← genExp (depth + 1) + return ⟨span, .binOp op x y⟩ + | 5 => + -- Nested binary ops + let op1 ← randChoice allBinOps + let op2 ← randChoice allBinOps + let a ← genExp (depth + 1) + let b ← genExp (depth + 1) + let c ← genExp (depth + 1) + let inner : Exp := ⟨span, .binOp op1 a b⟩ + return ⟨span, .binOp op2 inner c⟩ + | 6 => + -- Tuple + let n ← randNat 3 + let es ← (List.range (n + 1)).mapM fun _ => genExp (depth + 1) + return ⟨span, .tuple es⟩ + | 7 => + -- List + let n ← randNat 3 + let es ← (List.range (n + 1)).mapM fun _ => genExp (depth + 1) + return ⟨span, .list es⟩ + | 8 | 9 => + -- Conditional (common) + let test ← genExp (depth + 1) + let body ← genExp (depth + 1) + let orelse ← genExp (depth + 1) + return ⟨span, .ifExp test body orelse⟩ + | 10 | 11 => + -- Function call with args + let fname ← randChoice (["f", "g", "h"] : List String) + let nargs ← randNat 2 + let args ← (List.range (nargs + 1)).mapM fun _ => do + let val ← genExp (depth + 1) + return Arg.mk none val + return ⟨span, .call ⟨span, .var fname⟩ [] args⟩ + | 12 => + -- Chained call: f(g(x)) + let f1 ← randChoice (["f", "g", "h"] : List String) + let f2 ← randChoice (["f", "g", "h"] : List String) + let inner ← genExp (depth + 1) + let innerCall : Exp := ⟨span, .call ⟨span, .var f2⟩ [] [⟨none, inner⟩]⟩ + return ⟨span, .call ⟨span, .var f1⟩ [] [⟨none, innerCall⟩]⟩ + | 13 => + -- Subscript access + let base ← genExp (depth + 1) + let idx ← genIndex (depth + 1) + return ⟨span, .access base [idx]⟩ + | _ => + -- Multi-index access: x[a, b] or x[a][b] + let base ← genExp (depth + 1) + let idx1 ← genIndex (depth + 1) + let idx2 ← genIndex (depth + 1) + let useMulti ← randBool + if useMulti then + return ⟨span, .access base [idx1, idx2]⟩ + else + let first : Exp := ⟨span, .access base [idx1]⟩ + return ⟨span, .access first [idx2]⟩ + +partial def genPattern (depth : Nat) : GenM Pattern := do + -- Simplify pass only supports simple variables in for-loops + let v ← freshVar + return .var v + +partial def genStmt (depth : Nat) (inLoop : Bool) : GenM Stmt := do + if depth >= 3 then + -- At max depth, generate simple but meaningful statements + let choice ← randNat 2 + match choice with + | 0 => + let v ← freshVar + let rhs ← genExp depth + return ⟨span, .assign ⟨span, .var v⟩ none rhs⟩ + | 1 => + let e ← genExp depth + return ⟨span, .ret e⟩ + | _ => + let e ← genExp depth + return ⟨span, .exp e⟩ + let choice ← randNat 10 + match choice with + | 0 | 1 => + -- Assignment (more common) + let v ← freshVar + let rhs ← genExp depth + return ⟨span, .assign ⟨span, .var v⟩ none rhs⟩ + | 2 => + let e ← genExp depth + return ⟨span, .ret e⟩ + | 3 | 4 => + -- If statement (more common) + let cond ← genExp depth + let thn ← genStmts (depth + 1) inLoop + let hasElse ← randBool + let els ← if hasElse then some <$> genStmts (depth + 1) inLoop else pure none + return ⟨span, .ifStm cond thn [] els⟩ + | 5 | 6 => + -- For loop (more common) + let pat ← genPattern 0 + let bound ← randNat 4 + let bound := bound + 1 + let iter : Exp := ⟨span, .call ⟨span, .var "range"⟩ [] [⟨none, ⟨span, .value (.int bound)⟩⟩]⟩ + let body ← genStmts (depth + 1) true + return ⟨span, .forLoop pat iter body⟩ + | 7 => + let cond ← genExp depth + let body ← genStmts (depth + 1) true + return ⟨span, .whileLoop cond body⟩ + | 8 => + if inLoop then return ⟨span, .breakLoop⟩ + else + let v ← freshVar + let rhs ← genExp depth + return ⟨span, .assign ⟨span, .var v⟩ none rhs⟩ + | 9 => + if inLoop then return ⟨span, .continueLoop⟩ + else + let e ← genExp depth + return ⟨span, .exp e⟩ + | _ => + let e ← genExp depth + return ⟨span, .exp e⟩ + +partial def genStmts (depth : Nat) (inLoop : Bool) : GenM (List Stmt) := do + let n ← randNat 2 + let n := n + 2 -- 2-4 statements + (List.range n).mapM fun _ => genStmt depth inLoop +end + +def genFuncDef (name : Ident := "test_kernel") (allowParams : Bool := true) : GenM FuncDef := do + let nParams ← if allowParams then randNat 3 else pure 0 + let params ← (List.range nParams).mapM fun i => do + let pname := s!"p{i}" + modify fun s => { s with vars := pname :: s.vars } + return Param.mk pname none none + let body ← genStmts 0 false + return { + name := name + typParams := [] + params := params + returns := none + body := body + decorators := [] + whereBounds := [] + } + +/-- Generate helper functions with varied signatures -/ +def genHelperFuncs : GenM (List FuncDef) := do + let nHelpers ← randNat 3 + (List.range nHelpers).mapM fun i => do + modify fun s => { s with vars := [] } + let nParams ← randNat 2 + let params ← (List.range nParams).mapM fun j => do + let pname := s!"arg{j}" + modify fun s => { s with vars := pname :: s.vars } + return Param.mk pname none none + let body ← genStmts 0 false + return { + name := s!"func_{i}" + typParams := [] + params := params + returns := none + body := body + decorators := [] + whereBounds := [] + } + +def genProg (seed : Nat) : Prog := + let gen : GenM Prog := do + -- Generate helper functions + let helpers ← genHelperFuncs + + -- Generate main kernel function (no params to avoid "argument not supplied") + modify fun s => { s with vars := [] } + let mainFunc ← genFuncDef "test_kernel" false + + let helperStmts := helpers.map fun f => Stmt.mk span (.funcDef f) + let mainStmt := Stmt.mk span (.funcDef mainFunc) + + return { file := s!"seed_{seed}.py", stmts := helperStmts ++ [mainStmt] } + + let (prog, _) := gen.run { seed := seed } + prog + +def progToString (p : Prog) : String := + s!"# Generated from seed (see filename)\n\n{p.prettyPrint}" + +end KLR.Py.Gen diff --git a/lakefile.lean b/lakefile.lean index cd9b2bad..80127551 100644 --- a/lakefile.lean +++ b/lakefile.lean @@ -32,6 +32,10 @@ lean_exe "klr" where root := `Main supportInterpreter := true +lean_exe "klr-fuzz" where + root := `KLR.Fuzz.Main + supportInterpreter := true + require Cli from git "https://github.com/leanprover/lean4-cli.git" @ "v4.23.0"