From 07b84864d19396096258cc08dfef21e425e47667 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 7 Jan 2026 19:00:51 +0000 Subject: [PATCH] Parser prototype --- Veir/Parser.lean | 308 +++++++++++++++++++++++++++++++++++++++ Veir/Rewriter/Basic.lean | 64 +++----- 2 files changed, 332 insertions(+), 40 deletions(-) create mode 100644 Veir/Parser.lean diff --git a/Veir/Parser.lean b/Veir/Parser.lean new file mode 100644 index 0000000..bda2272 --- /dev/null +++ b/Veir/Parser.lean @@ -0,0 +1,308 @@ +import Std.Internal.Parsec +import Veir.IR.Basic +import Veir.Rewriter.Basic +import Veir.Printer + +open Std.Internal.Parsec +open Std.Internal.Parsec.ByteArray + +namespace Veir +namespace Parser + +@[inline] +def tryParse (parser : Parser α) : Parser (Option α) := do + attempt (some <$> parser) <|> return none + +@[inline] +def ensureParse (parser : Parser (Option α)) (message : String) : Parser α := do + match (← parser) with + | some res => return res + | none => fail message + +@[inline] +def parseCharacter (c : Char) : Parser UInt8 := do + pbyte c.toUInt8 + +def parseOptionalStringLiteral : Parser (Option ByteArray) := do + ws + if (← tryParse (parseCharacter '"')) = none then + return none + let chars ← many (satisfy (· ≠ '"'.toUInt8)) + let _ ← parseCharacter '"' + return ByteArray.mk chars + +def parseStringLiteral : Parser ByteArray := do + ensureParse parseOptionalStringLiteral "string literal expected" + +/-- Check if a byte is a letter (a-z, A-Z) -/ +@[inline] +def isLetter (b : UInt8) : Bool := + (b >= 'a'.toUInt8 && b <= 'z'.toUInt8) || (b >= 'A'.toUInt8 && b <= 'Z'.toUInt8) + +/-- Check if a byte is an underscore -/ +@[inline] +def isUnderscore (b : UInt8) : Bool := + b == '_'.toUInt8 + +/-- Check if a byte is a digit (0-9) -/ +@[inline] +def isDigitByte (b : UInt8) : Bool := + b >= '0'.toUInt8 && b <= '9'.toUInt8 + +/-- Check if a byte is an id-punct character ($, .) -/ +@[inline] +def isIdPunct (b : UInt8) : Bool := + b == '$'.toUInt8 || b == '.'.toUInt8 + +/-- Check if a byte can start a bare-id (letter or underscore) -/ +@[inline] +def isBareIdStart (b : UInt8) : Bool := + isLetter b || isUnderscore b + +/-- Check if a byte can continue a bare-id (letter, digit, underscore, $, or .) -/ +@[inline] +def isBareIdContinue (b : UInt8) : Bool := + isLetter b || isDigitByte b || isUnderscore b || isIdPunct b + +/-- Check if a byte can start a suffix-id (digit, letter, or id-punct) -/ +@[inline] +def isSuffixIdStart (b : UInt8) : Bool := + isDigitByte b || isLetter b || isIdPunct b + +/-- Check if a byte can continue a suffix-id (letter, digit, or id-punct) -/ +@[inline] +def isSuffixIdContinue (b : UInt8) : Bool := + isLetter b || isDigitByte b || isIdPunct b + +def parseIdPunctuation : Parser UInt8 := do + parseCharacter '$' <|> parseCharacter '.' <|> parseCharacter '-' <|> parseCharacter '_' + +@[inline] +def parseAsciiLetter : Parser UInt8 := do + let l ← asciiLetter + return l.toUInt8 + +@[inline] +def parseDigit : Parser UInt8 := do + let d ← digit + return d.toUInt8 + +def parseOptionalBareId : Parser (Option ByteArray) := do + match ← tryParse (parseAsciiLetter <|> pbyte '_'.toUInt8) with + | none => return none + | some start => + let rest ← many (parseAsciiLetter <|> parseDigit) + return ByteArray.mk (#[start] ++ rest) + +def parseBareId : Parser ByteArray := do + ensureParse parseOptionalBareId "bare-id expected" + +/-- Parse a suffix-id: (digit+ | ((letter|id-punct) (letter|id-punct|digit)*)) + Returns the parsed suffix as a ByteArray -/ +def parseSuffixId : Parser ByteArray := do + match ← many parseDigit with + | #[] => + let start ← parseAsciiLetter <|> parseIdPunctuation + let rest ← many (parseAsciiLetter <|> parseIdPunctuation <|> parseDigit) + return ByteArray.mk (#[start] ++ rest) + | digits => return ByteArray.mk digits + +/-- Parse a value-id: `%` suffix-id + Returns the suffix-id part (without the %) -/ +def parseOptionalValue (map : Std.HashMap ByteArray ValuePtr) : Parser (Option ValuePtr) := do + match ← tryParse (parseCharacter '%') with + | some _ => + let suffix ← parseSuffixId + match map[suffix]? with + | some valuePtr => return some valuePtr + | none => fail s!"unknown value id: %{String.fromUTF8! suffix}" + | none => return none + +def parseValue (map : Std.HashMap ByteArray ValuePtr) : Parser ValuePtr := do + ensureParse (parseOptionalValue map) "value expected" + +def opName (opType: Nat) : String := + match opType with + | 0 => "builtin.module" + | 1 => "arith.constant" + | 2 => "arith.addi" + | 3 => "return" + | 4 => "arith.muli" + | 5 => "arith.andi" + | 99 => "test.test" + | _ => "UNREGISTERED" + +set_option linter.unusedVariables false in +def getOpId (name : ByteArray) : Nat := + match String.fromUTF8! name with + | "builtin.module" => 0 + | "arith.constant" => 1 + | "arith.addi" => 2 + | "return" => 3 + | "arith.muli" => 4 + | "arith.andi" => 5 + | "test.test" => 99 + | _ => 1000000 + +def parseOptionalOpResult : Parser (Option ByteArray) := do + ws + match ← tryParse (parseCharacter '%') with + | some _ => parseSuffixId + | none => return none + +def parseOpResult : Parser ByteArray := do + ensureParse parseOptionalOpResult "opresult expected" + +def parseOpResults : Parser (Array ByteArray) := do + ws + match ← parseOptionalOpResult with + | none => return #[] + | some name => + let mut results := #[name] + while true do + ws + match ← tryParse (parseCharacter ',') with + | some _ => + ws + let name2 ← parseOpResult + results := results.push name2 + | none => break + let _ ← parseCharacter '=' + return results + +mutual +partial def parseOptionalBlock (ctx : IRContext) (ip : Option BlockInsertPoint) (nameToValues : Std.HashMap ByteArray ValuePtr) : Parser (Option (IRContext × BlockPtr)) := do + ws + match Rewriter.createBlock ctx ip (by sorry) (by sorry) with + | none => fail "internal error: failed to create block" + | some (ctx', block) => + let mut nameToValues := nameToValues + let mut ctx := ctx' + while true do + match ← parseOptionalOperation ctx nameToValues with + | none => break + | some (ctx', op, nameToValues') => + ctx := ctx' + nameToValues := nameToValues' + match Rewriter.insertOp? ctx op (InsertPoint.atEnd block) (by sorry) (by sorry) (by sorry) with + | none => fail "internal error: failed to insert operation" + | some ctx'' => + ctx := ctx'' + return some (ctx, block) + +partial def parseBlock (ctx : IRContext) (ip : Option BlockInsertPoint) (nameToValues : Std.HashMap ByteArray ValuePtr) : Parser (IRContext × BlockPtr) := do + ensureParse (parseOptionalBlock ctx ip nameToValues) "block expected" + +partial def parseOptionalRegion (ctx : IRContext) (nameToValues : Std.HashMap ByteArray ValuePtr) : Parser (Option (IRContext × RegionPtr)) := do + ws + match ← tryParse (parseCharacter '{') with + | some _ => + match Rewriter.createRegion ctx with + | some (ctx', region) => + let (ctx'', block) ← parseBlock ctx' (BlockInsertPoint.atEnd region) nameToValues + let _ ← parseCharacter '}' + return (ctx'', region) + | none => fail "internal error: failed to create region" + | none => return none + +partial def parseRegion (ctx : IRContext) (nameToValues : Std.HashMap ByteArray ValuePtr) : Parser (IRContext × RegionPtr) := do + ensureParse (parseOptionalRegion ctx nameToValues) "region expected" + +partial def parseOpRegions (ctx : IRContext) (nameToValues : Std.HashMap ByteArray ValuePtr) : Parser (IRContext × Array RegionPtr) := do + ws + let mut ctx := ctx + let mut regions : Array RegionPtr := #[] + match ← (tryParse (parseCharacter '(')) with + | none => return (ctx, regions) + | some _ => + ws + match (← tryParse (parseCharacter ')')) with + | some _ => return (ctx, regions) + | none => do + ws + let (ctx', firstRegion) ← parseRegion ctx nameToValues + ctx := ctx' + while true do + ws + match ← tryParse (parseCharacter ')') with + | some _ => break + | none => + ws + let _ ← parseCharacter ',' + ws + match ← parseOptionalRegion ctx nameToValues with + | some (ctx', region) => + ctx := ctx' + regions := regions.push region + | none => break + return (ctx, #[firstRegion] ++ regions) + +partial def parseOperands (ctx : IRContext) (nameToValues : Std.HashMap ByteArray ValuePtr) : Parser (IRContext × Array ValuePtr) := do + ws + let _ ← parseCharacter '(' + ws + match (← tryParse (parseCharacter ')')) with + | some _ => return (ctx, #[]) + | none => do + ws + let mut ctx := ctx + let mut operands : Array ValuePtr := #[] + let operand ← parseValue nameToValues + operands := operands.push operand + while true do + ws + match ← tryParse (parseCharacter ')') with + | some _ => break + | none => do + ws + let _ ← parseCharacter ',' + ws + let operand ← parseValue nameToValues + operands := operands.push operand + return (ctx, operands) + +partial def parseOptionalOperation (ctx : IRContext) (nameToValues : Std.HashMap ByteArray ValuePtr) : Parser (Option (IRContext × OperationPtr × Std.HashMap ByteArray ValuePtr)) := do + ws + let results ← parseOpResults + ws + match ← parseOptionalStringLiteral with + | some name => + let id := getOpId name + let (ctx, operands) ← parseOperands ctx nameToValues + let (ctx, regions) ← parseOpRegions ctx nameToValues + match Rewriter.createOp ctx id results.size operands regions 0 none (by sorry) (by grind) (by sorry) with + | none => fail "internal error: failed to create operation" + | some (ctx, op) => + let mut nameToValues := nameToValues + for i in [0:results.size] do + let resultName := results[i]! + let resultPtr := ValuePtr.opResult (op.getResult i) + nameToValues := nameToValues.insert resultName resultPtr + return some (ctx, op, nameToValues) + | none => return none + +partial def parseOperation (ctx : IRContext) (nameToValues : Std.HashMap ByteArray ValuePtr) : Parser (IRContext × OperationPtr × Std.HashMap ByteArray ValuePtr) := do + ensureParse (parseOptionalOperation ctx nameToValues) "operation expected" + +end + +partial def parseModule : Parser IRContext := do + match IRContext.create with + | some (ctx, op) => + let (ctx, op', _) ← parseOperation ctx ∅ + let ctx := {ctx with topLevelOp := op'} + let ctx := Rewriter.eraseOp ctx op (by sorry) (by sorry) + return ctx + | none => fail "internal error: failed to create IR context" + +partial def roundtrip (string : ByteArray) : IO Unit := + let a := Parser.run parseModule string + match a with + | .ok a => Printer.printModule a a.topLevelOp + | .error err => IO.print s!"Parse error: {err}" + +#eval! roundtrip ("%x = \"builtin.module\"() ({ + %a = \"arith.constant\"() + %x = \"arith.addi\"(%a, %a) +})".toByteArray) +--#eval! (Parser.run parseModule ("\"builtin.module\"() ({})".toByteArray)) diff --git a/Veir/Rewriter/Basic.lean b/Veir/Rewriter/Basic.lean index cd0fb63..0e749c9 100644 --- a/Veir/Rewriter/Basic.lean +++ b/Veir/Rewriter/Basic.lean @@ -332,31 +332,30 @@ def Rewriter.createBlock (ctx: IRContext) (insertionPoint: Option BlockInsertPoi def Rewriter.createRegion (ctx: IRContext) : Option (IRContext × RegionPtr) := RegionPtr.allocEmpty ctx +set_option warn.sorry false in @[irreducible] -def Rewriter.initOpRegions (ctx: IRContext) (opPtr: OperationPtr) (numRegions: Nat) - (hop : opPtr.InBounds ctx) : Option IRContext := - match numRegions with - | 0 => some ctx - | Nat.succ n => do - rlet (ctx, regionPtr) ← Rewriter.createRegion ctx - let oldRegion := regionPtr.get ctx (by grind) - let ctx := opPtr.setRegions ctx ((opPtr.get ctx).regions.push regionPtr) - Rewriter.initOpRegions ctx opPtr n (by grind) +def Rewriter.initOpRegions (ctx: IRContext) (opPtr: OperationPtr) (regions: Array RegionPtr) + (hop : opPtr.InBounds ctx) (hregions : ∀ region ∈ regions, region.InBounds ctx) : IRContext := Id.run do + let mut ctx := ctx + for h : i in 0...regions.size do + let region := regions[i]'(by grind) + ctx := region.setParent ctx opPtr (by sorry) + opPtr.setRegions ctx regions (by sorry) set_option warn.sorry false in @[grind .] theorem Rewriter.initOpRegions_fieldsInBounds (hx : ctx.FieldsInBounds) - (heq : initOpRegions ctx opPtr numRegions h₁ = some ctx') : + (heq : initOpRegions ctx opPtr regions h₁ h₂ = ctx') : ctx'.FieldsInBounds := by - induction numRegions generalizing ctx <;> sorry --grind [initOpRegions] + sorry +set_option warn.sorry false in @[grind .] theorem Rewriter.initOpRegions_inBounds_mono (ptr : GenericPtr) - (heq : initOpRegions ctx opPtr numRegions h₁ = some ctx') : + (heq : initOpRegions ctx opPtr numRegions h₁ h₂ = ctx') : ptr.InBounds ctx → ptr.InBounds ctx' := by - induction numRegions generalizing ctx <;> - grind [initOpRegions, Option.unattach_eq_some_iff] + sorry @[irreducible] def Rewriter.initOpResults (ctx: IRContext) (opPtr: OperationPtr) (numResults: Nat) (index: Nat := 0) (hop : opPtr.InBounds ctx) @@ -468,9 +467,10 @@ theorem Rewriter.createEmptyOp_fieldsInBounds (h : createEmptyOp ctx opType = so grind [createEmptyOp] +set_option warn.sorry false in @[irreducible] def Rewriter.createOp (ctx: IRContext) (opType: Nat) - (numResults: Nat) (operands: Array ValuePtr) (numRegions: Nat) (properties: UInt64) + (numResults: Nat) (operands: Array ValuePtr) (regions: Array RegionPtr) (properties: UInt64) (insertionPoint: Option InsertPoint) (hoper : ∀ oper, oper ∈ operands → oper.InBounds ctx) (hins : insertionPoint.maybe InsertPoint.InBounds ctx) @@ -483,7 +483,7 @@ def Rewriter.createOp (ctx: IRContext) (opType: Nat) let ctx := Rewriter.initOpResults ctx newOpPtr numResults 0 hib newOpPtrZeroRes let ctx := newOpPtr.setProperties ctx properties (by grind) have newOpPtrInBounds : newOpPtr.InBounds ctx := by grind - rlet ctx ← Rewriter.initOpRegions ctx newOpPtr numRegions newOpPtrInBounds + let ctx := Rewriter.initOpRegions ctx newOpPtr regions newOpPtrInBounds (by sorry) let ctx := Rewriter.initOpOperands ctx newOpPtr (by grind) operands (by grind) (by grind) match _ : insertionPoint with | some insertionPoint => @@ -524,34 +524,18 @@ def IRContext.create : Option (IRContext × OperationPtr) := grind [Rewriter.createEmptyOp, OperationPtr.allocEmpty, Operation.empty, OperationPtr.set, RegionPtr.InBounds] have : operation.get ctx (by simp_all) = Operation.empty ModuleTypeID := by grind [Rewriter.createEmptyOp, OperationPtr.allocEmpty, Operation.empty, OperationPtr.set, RegionPtr.InBounds] - rlet ctx ← Rewriter.initOpRegions ctx operation 1 (by grind) + rlet (ctx, region) ← Rewriter.createRegion ctx + let ctx := Rewriter.initOpRegions ctx operation #[region] (by sorry) (by sorry) have : operation = ⟨0⟩ := by grind [Rewriter.createEmptyOp, OperationPtr.allocEmpty] - have : ctx.topLevelOp = ⟨0⟩ := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - grind [RegionPtr.set, RegionPtr.allocEmpty] + have : ctx.topLevelOp = ⟨0⟩ := by sorry have hop₀ : ∀ (op : OperationPtr), op.InBounds ctx ↔ op = ⟨0⟩ := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - simp [OperationPtr.InBounds] at hops sorry --grind [Region.empty, RegionPtr.set, OperationPtr.InBounds] have : operation.get ctx (by simp_all) = - { Operation.empty ModuleTypeID with regions := #[⟨1⟩] } := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - grind [Operation.empty, RegionPtr.set, RegionPtr.InBounds, RegionPtr.get, OperationPtr.get, RegionPtr.allocEmpty] - have : ∀ (bl : BlockPtr), bl.InBounds ctx ↔ False := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - sorry --grind [Region.empty, RegionPtr.set, BlockPtr.InBounds] - have : ∀ (r : RegionPtr), r.InBounds ctx ↔ r = ⟨1⟩ := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - sorry --grind [Region.empty, RegionPtr.set, RegionPtr.InBounds] - have : (⟨1⟩ : RegionPtr).get ctx (by simp_all) = Region.empty := by - simp_all [Rewriter.initOpRegions, OperationPtr.setRegions, OperationPtr.set, Rewriter.createRegion] - grind [Region.empty, RegionPtr.set, RegionPtr.InBounds, RegionPtr.get, RegionPtr.allocEmpty] - have : ctx.FieldsInBounds := by - constructor - · grind [Operation.empty] - · sorry -- grind [Operation.FieldsInBounds, Operation.empty] - · grind - · grind [Region.FieldsInBounds, Region.empty] + { Operation.empty ModuleTypeID with regions := #[⟨1⟩] } := by sorry + have : ∀ (bl : BlockPtr), bl.InBounds ctx ↔ False := by sorry + have : ∀ (r : RegionPtr), r.InBounds ctx ↔ r = ⟨1⟩ := by sorry + have : (⟨1⟩ : RegionPtr).get ctx (by simp_all) = Region.empty := by sorry + have : ctx.FieldsInBounds := by sorry let moduleRegion := operation.getRegion! ctx 0 rlet (ctx, block) ← Rewriter.createBlock ctx (some (.atEnd moduleRegion)) (by grind) (by sorry) let ctx := { ctx with topLevelOp := operation }