diff --git a/.gitignore b/.gitignore index 80f25a0f..d647d149 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,10 @@ __pycache__/ build/ dist/ klr.bin +KLR/gotest +KLR/test.py +.gitignore +KLR/Compile/test.lean +KLR/test.ast +KLR/.DS_Store +.DS_Store diff --git a/KLR/Compile/Dataflow.lean b/KLR/Compile/Dataflow.lean index e9f85b5b..7d48ada8 100644 --- a/KLR/Compile/Dataflow.lean +++ b/KLR/Compile/Dataflow.lean @@ -198,6 +198,7 @@ section Basics infix:100 "⊔" => Max.max + def NodeMap.instBEq {α β : Type} [NodeMap α] [BEq β] : BEq ⟦α, β⟧ := { beq μ₀ μ₁ := μ₀ fold⟪true, (fun a prev => prev ∧ (μ₀◃a == μ₁◃a))⟫ } @@ -244,6 +245,14 @@ section Basics } } + theorem NodeMap.of_const_map {α β γ: Type} [BEq β] [BEq γ] [LawfulBEq γ] [NodeMap α] + (b : β) (f : β → γ) : (NodeMap.const (α:=α) b) map⟪f⟫ == ⟪↦(f b)⟫ := by { + rw [beq_ext] + intro a + rw [of_map_get, of_const_get, of_const_get] + simp + } + instance {α β : Type} [NodeMap α] [ToString α] [ToString β] : ToString ⟦α, β⟧ where toString μ := μ fold⟪"", (fun nd repr => repr ++ "\n{" ++ toString nd.data ++ ": " ++ toString (μ◃nd) ++ "}")⟫ @@ -300,7 +309,6 @@ section Basics end Basics - /- The section `DataflowProblemSolver ` is parameterized on an instance of `DataflowProblem α β`. It builds on the definitions of maps `⟦α, β⟧` from `NodeMap α`, and on the transition functions @@ -313,20 +321,29 @@ end Basics the dataflow problem, and a `I' ν : Prop` - which captures that `ν` satisfies the dataflow problem. -/ section DataflowProblemSolver + variable {α β : Type} [BEq α] {DP: DataflowProblem α β} open DataflowProblem + def ν₀ : ⟦α, (β × Bool)⟧ := ⟪↦(⊥, true)⟫ + def ε (a₀ a₁ : Node α) : Bool := List.elem a₁ (σ◃a₀) + def strip_bools (ν : ⟦α, (β × Bool)⟧) := ν map⟪fun (β, _)=>β⟫ + def E (P : (Node α) → (Node α) → Prop) := ∀ (a₀ a₁) (_:ε a₀ a₁), P a₀ a₁ + def R (ν₀ : ⟦α, (β × Bool)⟧) (ν₁ : ⟦α, β⟧) [LE β]: Prop := E (fun a₀ a₁ => (ν₀◃a₀).2 ∨ (τ◃a₀) ((ν₀◃a₀).1) ≤ (ν₁◃a₁)) + def I (ν : ⟦α, (β × Bool)⟧) : Prop := R ν (strip_bools ν) + def R' (ν₀ ν₁ : ⟦α, β⟧) : Prop := E (fun a₀ a₁ => (τ◃a₀) (ν₀◃a₀) ≤ ν₁◃a₁) + def I' (ν : ⟦α, β⟧) : Prop := R' ν ν theorem base_case : @I α β _ DP ν₀ := by { @@ -337,17 +354,21 @@ section DataflowProblemSolver rw [NodeMap.of_const_get] } + def δ (ν : ⟦α, (β × Bool)⟧) (a : Node α) : ⟦α, β⟧ := -- step of_func⟪(fun a' => if ε a a' then ((τ◃a) (ν◃a).1) else ⊥)⟫ + def Δ₀ (ν : ⟦α, (β × Bool)⟧) : ⟦α, β⟧ := ν fold⟪ν map⟪(·.1)⟫, (fun a ν₀ => if (ν◃a).2 then ν₀ ⊔ (δ ν a) else ν₀)⟫ + def Δ (ν : ⟦α, (β × Bool)⟧) : ⟦α, (β × Bool)⟧ := let ν' := Δ₀ ν of_func⟪fun a => let (β, β') := ((ν◃a).1, (ν'◃a)); (β', β != β')⟫ + def is_fix (ν : ⟦α, (β × Bool)⟧) : Bool := ν map⟪fun x↦x.2⟫ == ⟪↦false⟫ @@ -529,6 +550,7 @@ section DataflowProblemSolver } } + -- don't want to unroll this automatically def DataflowProblem.solve_to_depth {α β : Type} (depth : ℕ) (DP : DataflowProblem α β) @@ -546,6 +568,7 @@ section DataflowProblemSolver else solve_to_depth depth' DP ν' h' + def DataflowProblem.solve {α β : Type} [BEq α] (DP : DataflowProblem α β) : Option ((ν : ⟦α, β⟧) ×' I' ν) @@ -585,27 +608,35 @@ section FiniteDataflowProblemSolver le_supl (β₀ β₁ : β) : β₀ ≤ Max.max β₀ β₁ le_supr (β₀ β₁ : β) : β₁ ≤ Max.max β₀ β₁ + def LtProp : NodeProp ℕ where node_prop n' := n' < n + def NodeT := @Node ℕ (LtProp n) + def node_to_fin (nd : NodeT n) : (Fin n) := {val := @nd.data, isLt := @nd.sound} + def fin_to_node (fin : Fin n) : (NodeT n) := @Node.mk ℕ (LtProp n) fin.val fin.isLt + def nodes : Vector (NodeT n) n := Vector.ofFn (fin_to_node n) + def vector_fn {β : Type} (f : NodeT n → β) : Vector β n := Vector.ofFn (f ∘ (fin_to_node n)) + def FiniteNodeProp : NodeProp ℕ := { node_prop n' := n' < n } + def FiniteNodeMap : NodeMap ℕ := { FiniteNodeProp n with μ β := Vector β n @@ -697,6 +728,7 @@ section FiniteDataflowProblemSolver This is the end of the section because the returned instance provides the `DataflowProblem.solve` function. -/ + def FiniteDataflowProblem {β : Type} [BEq β] [P:Preorder β] @@ -810,23 +842,27 @@ section InnerMapImpl vals (n k : ℕ) : (n < num_nodes) → (k < num_keys) → ρ props (n m k : ℕ) : (hn : n < num_nodes) → (hm : m < num_nodes) → (hk : k < num_keys) → (edges n m) → transitions n k (vals n k hn hk) ≤ (vals m k hm hk) + key_labels : ℕ → Option String --for debugging printing def SolutionT.toString [ToString ρ] (𝕊 : SolutionT ρ num_nodes num_keys edges transitions) : String := let 𝕍 := 𝕊.vals let nd_to_string n (hn :n < num_nodes) : String := - let entries := (List.range num_keys).filterMap - (fun k => if hk: k < num_keys then some (ToString.toString (𝕍 n k hn hk)) else none) + let entries := (List.range num_keys).filterMap -- all entries will map to some _ but this isn't a dependent map + (fun k => if hk: k < num_keys then + let pre := match 𝕊.key_labels k with | some s => s!"{s}:" | none => ""; + some (s!"{pre}{(𝕍 n k hn hk)}") + else none) String.intercalate " " entries - let lines := (List.range num_nodes).filterMap + let lines := (List.range num_nodes).filterMap -- all entries will map to some _ but this isn't a dependent map (fun n => if hn: n < num_nodes then ( let s := nd_to_string n hn; some (s!"Node {n}: {s}") ) else none) String.intercalate "\n" ([""] ++ lines ++ [""]) - instance [ToString ρ] : ToString (SolutionT ρ num_nodes num_keys edges transitions) where - toString := (SolutionT.toString ρ num_nodes num_keys edges transitions) + instance [ToString ρ] : ToString (SolutionT ρ num_nodes num_keys edges transitions) where + toString := SolutionT.toString ρ num_nodes num_keys edges transitions end SolutionImpl @@ -986,9 +1022,13 @@ section InnerMapImpl some { vals := vals props := props + key_labels _ := none } end InnerMapImpl -/- + + +/- EXAMPLE + The section `ConcreteMapImpl` serves to illustrate an end-to-end usage of the dataflow solver defined above. In particular: @@ -1424,4 +1464,5 @@ namespace UseDefImpl end UseDefImpl + -- thanks for reading! - Julia 💕 diff --git a/KLR/Compile/DataflowTestKernels.lean b/KLR/Compile/DataflowTestKernels.lean new file mode 100644 index 00000000..a42ee4b7 --- /dev/null +++ b/KLR/Compile/DataflowTestKernels.lean @@ -0,0 +1,524 @@ +/- + This file defines a `class HasKernel` that can provide an instance of a NKI kernel function + if present. it also defines a few instances of this class corresponding to different kernels. + `kernel_str : String` is an additional attribute of `HasKernel` that records the source-level + code for output sake. +-/ + +import KLR.NKI.Basic +open KLR.NKI + +section TestKernels + +def start_highlight := "⦃!" --"◯" --"⇉" +def end_highlight := "⦄" --"◯" --"⇇" + +class HasKernel where + kernel : Fun + kernel_str : String + +instance : ToString Pos where toString p := s!"{p.line}, {p.lineEnd}, {p.column}, {p.columnEnd}" + +def highlight_pos_set [HasKernel] (actions : List Pos) (s : String) : String := + let newlines : List Nat := (List.range s.length).filter (fun n ↦ s.toList[n]! = '\n') + + let findStart (pos : Pos) : List Nat := [newlines[pos.line - 1]! + pos.column + 1] + let findEnd (pos : Pos) : List Nat := match pos.lineEnd, pos.columnEnd with + | some l, some c => [newlines[l - 1]! + c + 1] + | _, _ => [] + let starts := actions.flatMap findStart + let ends := actions.flatMap findEnd + let out_str_at n : List Char := + let st := if n ∈ starts then start_highlight.toList else [] + let ed := if n ∈ ends then end_highlight.toList else [] + st ++ ed ++ [s.toList[n]!] + ⟨((List.range s.length).flatMap out_str_at)⟩ + +def kernel_highlighted_repr [HasKernel] (actions : List Pos) : String := + highlight_pos_set actions HasKernel.kernel_str + + +def safe_kernel_1 : HasKernel where + kernel_str :=" + def test(): + x = 0 + c = 0 + p = 0 + if c: + p(x) + else: + y = 0 + p(y) + p(x)" + kernel := { name := "test.test", + file := "unknown", + line := 1, + body := [{ stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "x", + pos := { line := 2, column := 1, lineEnd := some 2, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 2, column := 5, lineEnd := some 2, columnEnd := some 6 } }), + pos := { line := 2, column := 1, lineEnd := some 2, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "c", + pos := { line := 3, column := 1, lineEnd := some 3, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 3, column := 5, lineEnd := some 3, columnEnd := some 6 } }), + pos := { line := 3, column := 1, lineEnd := some 3, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 4, column := 1, lineEnd := some 4, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 4, column := 5, lineEnd := some 4, columnEnd := some 6 } }), + pos := { line := 4, column := 1, lineEnd := some 4, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.ifStm + { expr := KLR.NKI.Expr'.var "c", + pos := { line := 5, column := 4, lineEnd := some 5, columnEnd := some 5 } } + [{ stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 6, + column := 2, + lineEnd := some 6, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "x", + pos := { line := 6, + column := 4, + lineEnd := some 6, + columnEnd := some 5 } }] + [], + pos := { line := 6, + column := 2, + lineEnd := some 6, + columnEnd := some 6 } }, + pos := { line := 6, column := 2, lineEnd := some 6, columnEnd := some 6 } }] + [{ stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "y", + pos := { line := 8, + column := 2, + lineEnd := some 8, + columnEnd := some 3 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 8, + column := 6, + lineEnd := some 8, + columnEnd := some 7 } }), + pos := { line := 8, column := 2, lineEnd := some 8, columnEnd := some 7 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 9, + column := 2, + lineEnd := some 9, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "y", + pos := { line := 9, + column := 4, + lineEnd := some 9, + columnEnd := some 5 } }] + [], + pos := { line := 9, + column := 2, + lineEnd := some 9, + columnEnd := some 6 } }, + pos := { line := 9, column := 2, lineEnd := some 9, columnEnd := some 6 } }], + pos := { line := 5, column := 1, lineEnd := some 9, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 10, + column := 1, + lineEnd := some 10, + columnEnd := some 2 } } + [{ expr := KLR.NKI.Expr'.var "x", + pos := { line := 10, + column := 3, + lineEnd := some 10, + columnEnd := some 4 } }] + [], + pos := { line := 10, column := 1, lineEnd := some 10, columnEnd := some 5 } }, + pos := { line := 10, column := 1, lineEnd := some 10, columnEnd := some 5 } }], + args := [] } + +def unsafe_kernel_2 : HasKernel where + kernel_str := " +def test(): + x = 0 + c = 0 + p = 0 + if c: + p(x) + else: + y = 0 + p(y) + p(y) + " + kernel := { name := "test.test", + file := "unknown", + line := 1, + body := [{ stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "x", + pos := { line := 2, column := 1, lineEnd := some 2, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 2, column := 5, lineEnd := some 2, columnEnd := some 6 } }), + pos := { line := 2, column := 1, lineEnd := some 2, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "c", + pos := { line := 3, column := 1, lineEnd := some 3, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 3, column := 5, lineEnd := some 3, columnEnd := some 6 } }), + pos := { line := 3, column := 1, lineEnd := some 3, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 4, column := 1, lineEnd := some 4, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 4, column := 5, lineEnd := some 4, columnEnd := some 6 } }), + pos := { line := 4, column := 1, lineEnd := some 4, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.ifStm + { expr := KLR.NKI.Expr'.var "c", + pos := { line := 5, column := 4, lineEnd := some 5, columnEnd := some 5 } } + [{ stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 6, + column := 2, + lineEnd := some 6, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "x", + pos := { line := 6, + column := 4, + lineEnd := some 6, + columnEnd := some 5 } }] + [], + pos := { line := 6, + column := 2, + lineEnd := some 6, + columnEnd := some 6 } }, + pos := { line := 6, column := 2, lineEnd := some 6, columnEnd := some 6 } }] + [{ stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "y", + pos := { line := 8, + column := 2, + lineEnd := some 8, + columnEnd := some 3 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 8, + column := 6, + lineEnd := some 8, + columnEnd := some 7 } }), + pos := { line := 8, column := 2, lineEnd := some 8, columnEnd := some 7 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 9, + column := 2, + lineEnd := some 9, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "y", + pos := { line := 9, + column := 4, + lineEnd := some 9, + columnEnd := some 5 } }] + [], + pos := { line := 9, + column := 2, + lineEnd := some 9, + columnEnd := some 6 } }, + pos := { line := 9, column := 2, lineEnd := some 9, columnEnd := some 6 } }], + pos := { line := 5, column := 1, lineEnd := some 9, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 10, + column := 1, + lineEnd := some 10, + columnEnd := some 2 } } + [{ expr := KLR.NKI.Expr'.var "y", + pos := { line := 10, + column := 3, + lineEnd := some 10, + columnEnd := some 4 } }] + [], + pos := { line := 10, column := 1, lineEnd := some 10, columnEnd := some 5 } }, + pos := { line := 10, column := 1, lineEnd := some 10, columnEnd := some 5 } }], + args := [] } + + +def unsafe_kernel_3 : HasKernel where + kernel_str := " +def test(): + x = 0 + c = 0 + p = 0 + if c: + p(x) + for z in x: + w = 0 + p(w) + p(z) + p(w) + p(z) + else: + y = 0 + p(x) + p(y) + p(x) + p(y) + p(z)" + kernel := {name := "test.test", + file := "unknown", + line := 1, + body := [{ stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "x", + pos := { line := 2, column := 1, lineEnd := some 2, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 2, column := 5, lineEnd := some 2, columnEnd := some 6 } }), + pos := { line := 2, column := 1, lineEnd := some 2, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "c", + pos := { line := 3, column := 1, lineEnd := some 3, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 3, column := 5, lineEnd := some 3, columnEnd := some 6 } }), + pos := { line := 3, column := 1, lineEnd := some 3, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 4, column := 1, lineEnd := some 4, columnEnd := some 2 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 4, column := 5, lineEnd := some 4, columnEnd := some 6 } }), + pos := { line := 4, column := 1, lineEnd := some 4, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.ifStm + { expr := KLR.NKI.Expr'.var "c", + pos := { line := 5, column := 4, lineEnd := some 5, columnEnd := some 5 } } + [{ stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 6, + column := 2, + lineEnd := some 6, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "x", + pos := { line := 6, + column := 4, + lineEnd := some 6, + columnEnd := some 5 } }] + [], + pos := { line := 6, + column := 2, + lineEnd := some 6, + columnEnd := some 6 } }, + pos := { line := 6, column := 2, lineEnd := some 6, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.forLoop + { expr := KLR.NKI.Expr'.var "z", + pos := { line := 7, + column := 6, + lineEnd := some 7, + columnEnd := some 7 } } + { expr := KLR.NKI.Expr'.var "x", + pos := { line := 7, + column := 11, + lineEnd := some 7, + columnEnd := some 12 } } + [{ stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "w", + pos := { line := 8, + column := 3, + lineEnd := some 8, + columnEnd := some 4 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 8, + column := 7, + lineEnd := some 8, + columnEnd := some 8 } }), + pos := { line := 8, + column := 3, + lineEnd := some 8, + columnEnd := some 8 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 9, + column := 3, + lineEnd := some 9, + columnEnd := some 4 } } + [{ expr := KLR.NKI.Expr'.var "w", + pos := { line := 9, + column := 5, + lineEnd := some 9, + columnEnd := some 6 } }] + [], + pos := { line := 9, + column := 3, + lineEnd := some 9, + columnEnd := some 7 } }, + pos := { line := 9, + column := 3, + lineEnd := some 9, + columnEnd := some 7 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 10, + column := 3, + lineEnd := some 10, + columnEnd := some 4 } } + [{ expr := KLR.NKI.Expr'.var "z", + pos := { line := 10, + column := 5, + lineEnd := some 10, + columnEnd := some 6 } }] + [], + pos := { line := 10, + column := 3, + lineEnd := some 10, + columnEnd := some 7 } }, + pos := { line := 10, + column := 3, + lineEnd := some 10, + columnEnd := some 7 } }], + pos := { line := 7, column := 2, lineEnd := some 10, columnEnd := some 7 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 11, + column := 2, + lineEnd := some 11, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "w", + pos := { line := 11, + column := 4, + lineEnd := some 11, + columnEnd := some 5 } }] + [], + pos := { line := 11, + column := 2, + lineEnd := some 11, + columnEnd := some 6 } }, + pos := { line := 11, column := 2, lineEnd := some 11, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 12, + column := 2, + lineEnd := some 12, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "z", + pos := { line := 12, + column := 4, + lineEnd := some 12, + columnEnd := some 5 } }] + [], + pos := { line := 12, + column := 2, + lineEnd := some 12, + columnEnd := some 6 } }, + pos := { line := 12, column := 2, lineEnd := some 12, columnEnd := some 6 } }] + [{ stmt := KLR.NKI.Stmt'.assign + { expr := KLR.NKI.Expr'.var "y", + pos := { line := 14, + column := 2, + lineEnd := some 14, + columnEnd := some 3 } } + none + (some { expr := KLR.NKI.Expr'.value (KLR.NKI.Value.int 0), + pos := { line := 14, + column := 6, + lineEnd := some 14, + columnEnd := some 7 } }), + pos := { line := 14, column := 2, lineEnd := some 14, columnEnd := some 7 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 15, + column := 2, + lineEnd := some 15, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "x", + pos := { line := 15, + column := 4, + lineEnd := some 15, + columnEnd := some 5 } }] + [], + pos := { line := 15, + column := 2, + lineEnd := some 15, + columnEnd := some 6 } }, + pos := { line := 15, column := 2, lineEnd := some 15, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 16, + column := 2, + lineEnd := some 16, + columnEnd := some 3 } } + [{ expr := KLR.NKI.Expr'.var "y", + pos := { line := 16, + column := 4, + lineEnd := some 16, + columnEnd := some 5 } }] + [], + pos := { line := 16, + column := 2, + lineEnd := some 16, + columnEnd := some 6 } }, + pos := { line := 16, column := 2, lineEnd := some 16, columnEnd := some 6 } }], + pos := { line := 5, column := 1, lineEnd := some 16, columnEnd := some 6 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 17, + column := 1, + lineEnd := some 17, + columnEnd := some 2 } } + [{ expr := KLR.NKI.Expr'.var "x", + pos := { line := 17, + column := 3, + lineEnd := some 17, + columnEnd := some 4 } }] + [], + pos := { line := 17, column := 1, lineEnd := some 17, columnEnd := some 5 } }, + pos := { line := 17, column := 1, lineEnd := some 17, columnEnd := some 5 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 18, + column := 1, + lineEnd := some 18, + columnEnd := some 2 } } + [{ expr := KLR.NKI.Expr'.var "y", + pos := { line := 18, + column := 3, + lineEnd := some 18, + columnEnd := some 4 } }] + [], + pos := { line := 18, column := 1, lineEnd := some 18, columnEnd := some 5 } }, + pos := { line := 18, column := 1, lineEnd := some 18, columnEnd := some 5 } }, + { stmt := KLR.NKI.Stmt'.expr + { expr := KLR.NKI.Expr'.call + { expr := KLR.NKI.Expr'.var "p", + pos := { line := 19, + column := 1, + lineEnd := some 19, + columnEnd := some 2 } } + [{ expr := KLR.NKI.Expr'.var "z", + pos := { line := 19, + column := 3, + lineEnd := some 19, + columnEnd := some 4 } }] + [], + pos := { line := 19, column := 1, lineEnd := some 19, columnEnd := some 5 } }, + pos := { line := 19, column := 1, lineEnd := some 19, columnEnd := some 5 } }], + args := [] } + +end TestKernels diff --git a/KLR/Compile/NKIDataflow.lean b/KLR/Compile/NKIDataflow.lean new file mode 100644 index 00000000..75c63d16 --- /dev/null +++ b/KLR/Compile/NKIDataflow.lean @@ -0,0 +1,643 @@ +/- +# NKI Dataflow + +This file uses the Dataflow solver (`InnerMapImpl.Solution`) from `Dataflow.lean` +to analyize NKI functions (`HasKernel.kernel`) from `DataflowTestKernels.lean`. +the final output is the def `decide_safety` - it is built on a kernel, a succesful +dataflow solution, and a safety analysis of the kernel. This is all arranges as follows: + +`section DefVarAction` defines the `VarAction` inductive which describes the + semantically significant actions of NKI statements for the sake of our + analysis: namely, named reads and writes to variables. + +`section DefNKIWalker`defines the `NKIWalker` structure that can walk a + NKI AST to construct a walker instance that contains the entire CFG + structure of the kernel: + `def NKIWalker.processFun (f : Fun) : NKIWalker := ...` + +the remainder of the file is organized in a module/functor-like structure +of parameterization of sections on typeclasses. each of the below +sections takes an instance of a class that bundles an important computational +step, and performs computation on it leading to some dependent output. + +`class HasKernel where kernel : Fun` - defined in `DataflowTestKernels.lean`, wraps a `NKI.Fun` kernel function + +`section WithKernel [HasKernel]` - uses a kernel to construct an instance + of a dataflow problem, whose (option-wrapped) solution `𝕏opt` is the final + output + + `class HasSuccess where success : 𝕏opt.isSome` - makes available the result + that the dataflow analysis was succesful, i.e `𝕏opt ≠ none` + + `section WithSuccess [HasSuccess]` - uses a success result to finish defining + our desired semantic properties of paths, and checks source functions + to ensure that reads occur only in places the (succesful) dataflow + analysis deemed safe + + `class HasSafety where safety : is_safe` - makes availabe + the result that the syntactic safety chcking was succesful + + `section WithSafety [HasSafety]` - defines + `def no_read_without_a_write [HasKernel] [HasSuccess] [HasSafety] : walker.sound := ...` + which provides an instance `is_safe`: + `abbrev is_safe : Prop := ∀ (n : 𝕟) (v : 𝕍), walker.reads n v → var_def n v` + of soundness for this NKI program, conditional on + a nki program being avaiable (`HasKernel`), dataflow analysis succeeding (`HasSuccess`) + and syntactic safety checks succeeding (`HasSafety`). + + `def decide_sound [HasKernel]: Maybe (walker.sound) := ...` - + exists inside `WithKernel` (so instantiating depends on a kernel), + but outside `WithSuccess` and `WithSafety` because success and safety + are decided based on the kernel not releid upon as parameters. + +this provides the final workflow: + - provide a NKI kernel to analyze as an `instance : HasKernel` + - read `decide_sound [HasKernel]`, which will evalute all + success and safety checks and provide a result of success, + with propositional proof of desired path semantics, or a + failure with a message constructed from the NKI source +-/ + +import KLR.NKI.Basic +import KLR.Compile.Dataflow +import KLR.Compile.DataflowTestKernels + +open KLR.NKI + +section DefVarAction + + inductive VarAction where + | Read (name : String) (pos : Pos) + | Write (name : String) (ty : Option Expr) (pos : Pos) + | None + + instance VarAction.toString : ToString VarAction where + toString := fun + | Read name pos => s!"Read({name} @ {pos.line}, {pos.column})" + | Write name _ pos => s!"Write({name} @ {pos.line}, {pos.column})" + | _ => "None" + + def VarAction.var := fun + | Read name _ => some name + | Write name _ _ => some name + | _ => none + +end DefVarAction + +section DefNKIWalker + + structure NKIWalker where + num_nodes : ℕ + num_nodes_nonzero : num_nodes > 0 + last_node : ℕ + actions : ℕ → VarAction + edges : ℕ → ℕ → Bool + breaks : List ℕ + conts : List ℕ + rets : List ℕ + vars : List String --list of varnames seen + + instance NKIWalker.toString : ToString NKIWalker where + toString walker := + let row n := + let tgts := (List.range walker.num_nodes).filter (walker.edges n) + let num := if n = walker.last_node then s!"[{n} (exit)]" else s!"[{n}]" + s!"Node {num} : {walker.actions n} ↦ Nodes {tgts}\n" + String.intercalate "\n" ((List.range walker.num_nodes).map row ++ ["vars: ", walker.vars.toString]) + + def NKIWalker.init : NKIWalker := { + num_nodes := 1 -- zero is always the first node + num_nodes_nonzero := by trivial + last_node := 0 -- zero is always the first node + actions _ := VarAction.None + edges _ _ := false + breaks := [] + conts := [] + rets := [] + vars := [] + } + + def NKIWalker.Node (walker : NKIWalker) : Type := Fin (walker.num_nodes) + def NKIWalker.Var (walker : NKIWalker) : Type := Fin (walker.vars.length) + + def NKIWalker.reads (walker : NKIWalker) (n : walker.Node) (v : walker.Var) : Bool := + match walker.actions n.val with + | VarAction.Read name _ => name = walker.vars.get v + | _ => false + + def NKIWalker.writes (walker : NKIWalker) (n : walker.Node) (v : walker.Var) : Bool := + match walker.actions n.val with + | VarAction.Write name _ _ => name = walker.vars.get v + | _ => false + + def NKIWalker.is_path (walker : NKIWalker) : List walker.Node → Bool := fun + | [] => True + | [n] => walker.edges 0 n.val + | n₁ :: n₀ :: tl => walker.is_path (n₀ :: tl) ∧ (walker.edges n₀.val n₁.val) + + def NKIWalker.is_path_lowers (walker : NKIWalker) : + ∀ n ℓ, walker.is_path (n::ℓ) → walker.is_path ℓ := by { + intro n₁ ℓ₁ h + cases ℓ₁ with | nil => simp [is_path] | cons n₀ ℓ₀ + simp_all [is_path] + } + + structure NKIWalker.Path (walker : NKIWalker) where + nodes : List walker.Node + nodes_sound : walker.is_path nodes + + + -- a path can always be unrolled into a shorter valid one, with proof of an edge across the unrolling + def NKIWalker.Path.unroll (walker : NKIWalker) (𝕡 : walker.Path) + : 𝕡.nodes.length ≥ 2 → + ∃ (n₁ n₀ : walker.Node) (tl : List walker.Node), + (walker.edges n₀.val n₁.val) ∧ (n₁ :: n₀ :: tl = 𝕡.nodes) ∧ (walker.is_path (n₀ :: tl)) := by { + intro not_tiny + rcases 𝕡_def : 𝕡.nodes + simp [𝕡_def] at not_tiny + rename_i n₁ tl₁ + rcases tl₁_def : tl₁ + simp [𝕡_def, tl₁_def] at not_tiny + rename_i n₀ tl₀ + exists n₁, n₀, tl₀ + apply And.intro + { + let sound := 𝕡.nodes_sound + simp [𝕡_def, tl₁_def, is_path] at sound + exact sound.right + } + { + simp [←tl₁_def] + apply walker.is_path_lowers n₁ tl₁ + rw [←𝕡_def] + apply 𝕡.nodes_sound + } + } + + def NKIWalker.Path.writes_somewhere (walker : NKIWalker) (𝕡 : walker.Path) (v : walker.Var) : Bool := + 𝕡.nodes.tail.any (walker.writes . v) + + -- easier to rewrite this than find it in the library lol + abbrev mem_lifts {α} (a : α) (ℓ : List α) : a ∈ ℓ.tail → a ∈ ℓ := by { + intro h + cases ℓ + contradiction + simp_all + } + + def NKIWalker.Path.writes_somewhere_lifts (walker : NKIWalker) (𝕡₀ 𝕡₁ : walker.Path) (v : walker.Var) + : 𝕡₁.nodes.tail = 𝕡₀.nodes → 𝕡₀.writes_somewhere walker v → 𝕡₁.writes_somewhere walker v := by { + simp [writes_somewhere] + intro unroll n₀ n₀_in n₀_writes + exists n₀ + apply And.intro + simp [unroll] + apply mem_lifts + assumption + assumption + } + + def NKIWalker.Path.true_at_terminus (walker : NKIWalker) (𝕡 : walker.Path) (motive : walker.Node → Bool) : Bool := + match 𝕡.nodes with + | n :: _ => motive n + | _ => false + + def NKIWalker.Path.reads_at_terminus (walker : NKIWalker) (𝕡 : walker.Path) (v : walker.Var) : Bool := + 𝕡.true_at_terminus walker (walker.reads . v) + + -- proving (or failing to prove) this is the goal!! + def NKIWalker.sound (walker : NKIWalker) : Prop := + ∀ (𝕡 : walker.Path) v, (𝕡.reads_at_terminus walker v) → (𝕡.writes_somewhere walker v) + + def NKIWalker.processAction (walker : NKIWalker) (action : VarAction) : NKIWalker := + let N := walker.num_nodes + {walker with + num_nodes := N + 1 + num_nodes_nonzero := by simp + last_node := N + actions n := if n = N then action else walker.actions n + edges A B := (A, B) = (walker.last_node, N) + ∨ (walker.edges A B) + vars := match action.var with + | some var => if var ∈ walker.vars then walker.vars else walker.vars.concat var + | none => walker.vars + } + + + def NKIWalker.setLast (walker : NKIWalker) (last_node : ℕ) : NKIWalker := {walker with + last_node := last_node + } + + + def NKIWalker.addEdge (walker : NKIWalker) (a b : ℕ) : NKIWalker := {walker with + edges A B := (A, B) = (a, b) ∨ walker.edges A B + } + + + def NKIWalker.addBreak (walker : NKIWalker) : NKIWalker := {walker with + breaks := walker.breaks ++ [walker.last_node] + } + + + def NKIWalker.clearBreaks (walker : NKIWalker) : NKIWalker := {walker with + breaks := [] + } + + def NKIWalker.addContinue (walker : NKIWalker): NKIWalker := {walker with + conts := walker.conts ++ [walker.last_node] + } + + + def NKIWalker.clearConts (walker : NKIWalker) : NKIWalker := {walker with + conts := [] + } + + + def NKIWalker.addReturn (walker : NKIWalker) : NKIWalker := {walker with + rets := walker.rets ++ [walker.last_node] + } + mutual def NKIWalker.processExpr (walker : NKIWalker) (expr : Expr) : NKIWalker := + let ⟨expr, pos⟩ := expr + match _ : expr with + | Expr'.value _ => walker + | Expr'.var (name : String) => walker.processAction (VarAction.Read name pos) + | Expr'.proj (expr : Expr) _ => walker.processExpr expr + | Expr'.tuple (elements : List Expr) => walker.processExprList elements + | Expr'.access (expr : Expr) _ => walker.processExpr expr + | Expr'.binOp _ left right => (walker.processExpr left).processExpr right + | Expr'.ifExp test body orelse => + let body_walker := ((walker.processExpr test).processExpr body) + let orelse_walker := ((body_walker.setLast walker.last_node).processExpr orelse) + let complete_walker := (orelse_walker.processAction VarAction.None) + complete_walker.addEdge body_walker.last_node complete_walker.last_node + | Expr'.call (f: Expr) (args: List Expr) (_ : List Keyword) => + (walker.processExpr f).processExprList args + termination_by sizeOf expr + decreasing_by + all_goals { + try {rename_i expr' _<;> rcases h' : (expr, expr') with ⟨⟨⟨⟩, ⟨⟩⟩, ⟨⟨⟩, ⟨⟩⟩⟩ <;> simp_all <;> omega} + try {rcases h' : expr with ⟨⟨⟩, ⟨⟩⟩ <;> simp_all <;> omega} + } + def NKIWalker.processExprList (walker : NKIWalker) (exprs : List Expr) : NKIWalker := + exprs.foldl NKIWalker.processExpr walker + termination_by sizeOf exprs + end + + mutual def NKIWalker.processStmt (walker : NKIWalker) (stmt : Stmt) : NKIWalker := + let ⟨stmt, pos⟩ := stmt + match _ : stmt with + | Stmt'.expr (e : Expr) => walker.processExpr e + | Stmt'.assert (e : Expr) => walker.processExpr e + | Stmt'.ret (e : Expr) => (walker.processExpr e).addReturn + | Stmt'.assign ⟨Expr'.var name, _⟩ (ty : Option Expr) (e : Option Expr) => + let withty := (match ty with | some ty => walker.processExpr ty | none => walker) + let withe := (match e with | some e => withty.processExpr e | none => withty) + withe.processAction (VarAction.Write name ty pos) + | Stmt'.assign _ (ty : Option Expr) (e : Option Expr) => + let withty := (match ty with | some ty => walker.processExpr ty | none => walker) + let withe := (match e with | some e => withty.processExpr e | none => withty) + withe.processAction (VarAction.Write "" ty pos) + | Stmt'.ifStm (e : Expr) (thn : List Stmt) (els : List Stmt) => + let cond_walker := walker.processExpr e + let then_walker := cond_walker.processStmtList thn + let else_walker := (then_walker.setLast cond_walker.last_node).processStmtList els + let complete := else_walker.processAction VarAction.None + complete.addEdge then_walker.last_node complete.last_node + | Stmt'.forLoop (x : Expr) (iter: Expr) (body: List Stmt) => + let intro_walker := walker.processExpr iter + let outer_breaks := intro_walker.breaks + let outer_conts := intro_walker.conts + let inner_walker := ((intro_walker.clearBreaks).clearConts).processAction VarAction.None + let enter_node := inner_walker.last_node + let inner_pre_walk := match x with + | ⟨Expr'.var name, pos⟩ => inner_walker.processAction (VarAction.Write name none pos) + | _ => inner_walker + let inner_walked := inner_pre_walk.processStmtList body + let nearly_complete := (inner_walked.addEdge inner_walked.last_node enter_node).setLast enter_node + let complete := nearly_complete.processAction VarAction.None + let exit_node := complete.last_node + let with_conts := complete.conts.foldl (fun walker cont ↦ walker.addEdge cont enter_node) complete + let with_breaks := complete.breaks.foldl (fun walker brk ↦ walker.addEdge brk exit_node) with_conts + {with_breaks with + conts := outer_conts + breaks := outer_breaks + } + | Stmt'.breakLoop => (walker.processAction VarAction.None).addBreak + | Stmt'.continueLoop => (walker.processAction VarAction.None).addContinue + termination_by sizeOf stmt + decreasing_by + try rcases h : (thn, stmt) with ⟨⟨⟨⟩, ⟨⟩⟩, ⟨⟨⟩, ⟨⟩⟩⟩ <;> simp_all <;> omega + try rcases h : (els, stmt) with ⟨⟨⟨⟩, ⟨⟩⟩, ⟨⟨⟩, ⟨⟩⟩⟩ <;> simp_all <;> omega + try rcases h : (body, stmt) with ⟨⟨⟨⟩, ⟨⟩⟩, ⟨⟨⟩, ⟨⟩⟩⟩ <;> simp_all <;> omega + def NKIWalker.processStmtList (walker : NKIWalker) (stmts : List Stmt) : NKIWalker := + stmts.foldl NKIWalker.processStmt walker + termination_by sizeOf stmts + end + + + def NKIWalker.processFun (f : Fun) : NKIWalker := + let body_walker := (NKIWalker.init.processStmtList f.body).processAction VarAction.None + body_walker.rets.foldl (fun walker ret ↦ walker.addEdge ret body_walker.last_node) body_walker + + + -- WIP + def NKIWalker.processCoFun (f : Fun) : NKIWalker := + let walker := processFun f + let invert n := walker.num_nodes - n - 1 + let invert_action := fun + | VarAction.Read name pos => VarAction.Write name none pos + | VarAction.Write name _ pos => VarAction.Read name pos + | VarAction.None => VarAction.None + {walker with + edges n₀ n₁ := walker.edges (invert n₁) (invert n₀) + actions n := invert_action (walker.actions (invert n))} + + def NKIWalker.isClosed (walker : NKIWalker) := walker.breaks.isEmpty ∧ walker.conts.isEmpty + +end DefNKIWalker + +section WithKernel + variable [HK : HasKernel] + + abbrev 𝕂 := HasKernel.kernel + + /- + Perform the walk of the AST, converting it into a CFG + -/ + def walker [HasKernel] : NKIWalker := NKIWalker.processFun 𝕂 + + /- + extract the transitions from the walker + -/ + def transitions (n k : ℕ) (pre : Bool) : Bool := + (n = 0) ∨ + if _ : k < walker.vars.length then + match walker.actions n with + | VarAction.Write name _ _ => ¬ (name = walker.vars[k]) ∧ pre + | _ => pre + else + pre + + instance : Preorder Bool where + le_refl := by trivial + le_trans := by trivial + + instance : HasBot Bool where + bot := false + + instance : ToString Bool where + toString := fun + | true => "❌" + | false => "✅" + + /- + perform dataflow analysis + -/ + def 𝕏opt := (Solution + (ρ:=Bool) + (le_supl:=by trivial) + (le_supr:=by trivial) + (num_nodes:=walker.num_nodes) + (num_keys:=walker.vars.length) + (edges:=walker.edges) + (transitions:=transitions)).map (fun a ↦ {a with + key_labels k := walker.vars[k]? + }) + + class HasSuccess where + success : 𝕏opt.isSome + section WithSuccess + variable [HS : HasSuccess] + + + abbrev 𝕏 := 𝕏opt.get HasSuccess.success + abbrev ℙ := walker.Path + abbrev 𝕟 := walker.Node + abbrev 𝕍 := walker.Var + abbrev 𝔼 (n₀ n₁ : walker.Node) := walker.edges n₀.val n₁.val + + abbrev ν (n : 𝕟) (v : 𝕍) := 𝕏.vals n.val v.val n.isLt v.isLt + + abbrev σ (n₀ n₁ : 𝕟) (v : 𝕍) (𝔼n:𝔼 n₀ n₁): transitions n₀.val v.val (ν n₀ v) ≤ ν n₁ v := by { + apply 𝕏.props n₀.val n₁.val v.val n₀.isLt n₁.isLt v.isLt 𝔼n + } + + --#check 𝕏 + --#check ν + --#check σ + --#check ℙ + + abbrev var_def (n : 𝕟) (v : 𝕍) : Bool := ν n v = false + def NKIWalker.Path.var_def_at_terminus (𝕡 : ℙ) (v : 𝕍) : Bool := 𝕡.true_at_terminus walker (var_def . v) + + def NKIWalker.Path.not_def_at_entry (𝕡 : ℙ) (v : 𝕍) : 𝕡.nodes.length = 1 → ¬ 𝕡.var_def_at_terminus v := + match h : 𝕡.nodes with + | [n] => by { + intro + cases v + rename_i k hk + simp [NKIWalker.Path.var_def_at_terminus, NKIWalker.Path.true_at_terminus] + rw [h] + simp + have h_edge: walker.edges 0 n.val := by { + have h𝕡 := 𝕡.nodes_sound + unfold NKIWalker.is_path at h𝕡 + rw [h] at h𝕡 + simp at h𝕡 + assumption + } + apply σ ⟨0, walker.num_nodes_nonzero⟩ n ⟨k, hk⟩ h_edge + simp [transitions, LE.le, instLEOfPreorder, Preorder.toLE, instPreorderBool_compile, Bool.instLE] + } + | [] | _ :: _ :: _ => by simp + + @[simp] + abbrev NKIWalker.Path.motive (𝕡 : ℙ) (v : 𝕍) : Prop + := 𝕡.var_def_at_terminus v → 𝕡.writes_somewhere walker v + + @[simp] + abbrev length_motive n := ∀ (𝕡 : ℙ) v, 𝕡.nodes.length = n → (𝕡.motive v) + + abbrev sound_at_zero : length_motive 0 := by { + simp [NKIWalker.Path.var_def_at_terminus, NKIWalker.Path.true_at_terminus, NKIWalker.Path.writes_somewhere] + intro _ _ is_zero + simp [is_zero] + } + + abbrev sound_at_one : length_motive 1 := by { + simp + intro 𝕡 v _ _ + exfalso + apply (𝕡.not_def_at_entry v) + assumption + assumption + } + + abbrev sound_ind : ∀ len, len ≥ 1 → length_motive len → length_motive (len + 1) := by { + unfold length_motive + intro len len_nonzero IndHyp 𝕡₁ v 𝕡₁_len ν₁ + cases 𝕡₁_def : 𝕡₁ + rename_i nodes₁ is_path₁ + let ⟨n₁, n₀, tl₀, ε, unroll, is_path₀⟩ := 𝕡₁.unroll walker (by omega) + simp [NKIWalker.Path.var_def_at_terminus, NKIWalker.Path.true_at_terminus, ←unroll] at ν₁ + let 𝕡₀ : ℙ := ⟨n₀ :: tl₀, is_path₀⟩ + cases ν₀ : ν n₀ v + { + -- v is defined at n₀ - the terminus of 𝕡₀, so writes somewhere by ind hypo, then lift + rw [←𝕡₁_def] + apply (NKIWalker.Path.writes_somewhere_lifts walker 𝕡₀ 𝕡₁ v); simp [←unroll, 𝕡₀] + apply IndHyp + simp [←unroll] at 𝕡₁_len + simp [𝕡₀] + assumption + simp [NKIWalker.Path.var_def_at_terminus, NKIWalker.Path.true_at_terminus, 𝕡₀] + assumption + } + { + -- is not defined at n₀ -- the terminus of 𝕡₀, but is at n₁, the terminus of 𝕡₁ + -- since we have ε : edge from n₀ to n₁, σ n₀ n₀ + let σ' := σ n₀ n₁ v ε + simp [transitions, LE.le, instLEOfPreorder, Preorder.toLE, instPreorderBool_compile, Bool.instLE, ν₀, ν₁] at σ' + let ⟨_, σ''⟩ := σ' + cases action_def : walker.actions n₀.val <;> rw [action_def] at σ'' <;> try simp at σ'' + rename_i _ name _ + simp [NKIWalker.Path.writes_somewhere] + simp [𝕡₁_def] at unroll + simp [←unroll, action_def, NKIWalker.writes] + apply Or.inl + assumption + } + } + + abbrev sound_everywhere : ∀ n, length_motive n := fun + | 0 => sound_at_zero + | 1 => sound_at_one + | n + 2 => sound_ind (n + 1) (by omega) (sound_everywhere (n + 1)) + + def no_def_without_a_write : ∀ (𝕡 : ℙ) v, (𝕡.var_def_at_terminus v) → (𝕡.writes_somewhere walker v) := by { + intro 𝕡 v + apply sound_everywhere + rfl + } + + abbrev is_safe_at (n : 𝕟) (v : 𝕍) : Prop := walker.reads n v → var_def n v + + abbrev is_safe : Prop := ∀ (n : 𝕟) (v : 𝕍), is_safe_at n v + + abbrev local_safety_decidable : ∀ n v, Decidable (is_safe_at n v) := by { + intro n v + unfold is_safe_at + cases reads? : walker.reads n v <;> + cases defs? : var_def n v <;> + simp [is_safe_at] <;> try {apply isTrue; trivial} + apply isFalse; trivial + } + + inductive Maybe (P : Prop) -- option type plus message option + | Yes : P → Maybe P + | No : Maybe P + | NoBC : String → Maybe P --no because of message + + instance Maybe.toString : ToString (Maybe P) where + toString := fun + | Yes _ => s!"YES [SAFETY PROVEN]" + | No => "NO [SAFETY NOT PROVEN]" + | NoBC s => s!"NO [SAFETY NOT PROVEN] BECAUSE: {s}" + + def Maybe.well? (s : Maybe P) := match s with + | No => false + | _ => true + + def decide_success : Maybe (𝕏opt.isSome) := by { + cases h : 𝕏opt with | none => apply Maybe.No | some => { + apply Maybe.Yes; simp + } + } + + abbrev forall_fin {n} (f : Fin n → Bool) : Bool := (Vector.ofFn f).all (.) + + abbrev forall_fin_sound (f : Fin n → Bool) : forall_fin f → (m : Fin n) → (f m) := by { + simp [forall_fin] + intro h m + apply h + } + + abbrev 𝕀 (α) (a : α) := a + + def get_unsafe_reads : List VarAction := + (List.ofFn (fun n : 𝕟 ↦ (n, List.ofFn (𝕀 𝕍)))).flatMap (fun (n, vs) ↦ + if vs.any (fun v ↦ ¬ decide (is_safe_at n v)) then [walker.actions n.val] else []) + + def get_unsafe_pos : List Pos := + get_unsafe_reads.flatMap (fun | VarAction.Read _ pos => [pos] | _ => []) + + --def print_unsafe_reads : String := + --(get_unsafe_reads h𝕏).foldl + + def decide_safety : Maybe is_safe := by { + let safe := forall_fin (fun n ↦ forall_fin (fun v ↦ decide (is_safe_at n v))) + by_cases safety : safe + swap; + -- if any reads occur where a var isnt def this will hit and fail + apply Maybe.NoBC; apply kernel_highlighted_repr; apply get_unsafe_pos + apply Maybe.Yes + unfold is_safe + intro n v + have safety_at_n := forall_fin_sound _ safety n + have safety := (forall_fin_sound _ safety_at_n v) + apply of_decide_eq_true + assumption + } + + class IsSafe where + safety : is_safe + section WithSafety + variable [IS : IsSafe] + + def no_read_without_a_def : ∀ (𝕡 : ℙ) v, (𝕡.reads_at_terminus walker v) → (𝕡.var_def_at_terminus v) + := by { + simp [NKIWalker.Path.var_def_at_terminus, NKIWalker.Path.reads_at_terminus, NKIWalker.Path.true_at_terminus] + intro 𝕡 v h + cases nodes_def : 𝕡.nodes with | nil | cons n ℓ <;> simp_all + let safety := IS.safety + simp [is_safe, is_safe_at] at safety + apply safety + assumption + } + + def no_read_without_a_write : walker.sound := by { + unfold NKIWalker.sound + intro 𝕡 name reads + apply no_def_without_a_write + apply no_read_without_a_def + assumption + } + end WithSafety + end WithSuccess + + def decide_sound : Maybe (walker.sound) := by { + cases decide_success with + | No | NoBC _ => apply Maybe.No + | Yes success + have HS : HasSuccess := ⟨success⟩ + cases decide_safety with + | No => apply Maybe.No + | NoBC s => apply Maybe.NoBC s + | Yes safety + have IS : IsSafe := ⟨safety⟩ + apply Maybe.Yes + apply no_read_without_a_write + } +end WithKernel + +instance : HasKernel := safe_kernel_1 + +#eval decide_sound + +instance : HasKernel := unsafe_kernel_2 + +#eval decide_sound + +instance : HasKernel := unsafe_kernel_3 + +#eval decide_sound