Skip to content

Commit

Permalink
feat: use match_bv when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
arthur-adjedj committed Feb 11, 2025
1 parent a5a1072 commit 6fd3752
Show file tree
Hide file tree
Showing 3 changed files with 305 additions and 12 deletions.
264 changes: 264 additions & 0 deletions src/sail_lean_backend/Sail/BitVec.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
/-
Copyright (c) 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Shilpi Goel, Siddharth Bhat
-/

-- Taken from https://github.com/leanprover/LNSym/blob/main/Arm/BitVec.lean

import Lean.Elab.Term
import Lean.Meta.Reduce
import Std.Tactic.BVDecide

open BitVec

/- Bitvector pattern component syntax category, originally written by
Leonardo de Moura. -/
declare_syntax_cat bvpat_comp
-- A run of binary digits written as a numeric literal, e.g. `0011010000`;
-- every digit counts as one bit (see `BVPatComp.length` / `BVPatComp.toBVLit?`).
syntax num : bvpat_comp
-- A pattern variable with an optional bit width, e.g. `Rm:5`; a bare
-- identifier is given width 1 by `BVPatComp.length`.
syntax ident (":" num)? : bvpat_comp
-- An anonymous field of the given width, e.g. `_:3`.
syntax "_" ":" num : bvpat_comp

/--
Bitvector pattern syntax category.
Example: [sf:1,0011010000,Rm:5,000000,Rn:5,Rd:5]
-/
declare_syntax_cat bvpat
-- A pattern is a bracketed, comma-separated list of components.
-- `BVPat.getComponents` reverses the list and the consumers count bit offsets
-- from 0, so the last component written occupies the least-significant bits.
syntax "[" bvpat_comp,* "]" : bvpat

open Lean

-- Typed-syntax aliases for the two syntax categories declared above.
abbrev BVPatComp := TSyntax `bvpat_comp
abbrev BVPat := TSyntax `bvpat

/--
Width in bits of one pattern component:
* a binary literal contributes one bit per digit character;
* `x:n` and `_:n` contribute `n` bits;
* a bare identifier contributes a single bit.

Anything that matches none of these shapes has width `0`.
-/
def BVPatComp.length (c : BVPatComp) : Nat :=
  match c with
  | `(bvpat_comp| $n:num) =>
    -- The literal is measured by its source text, so leading zeros count.
    if let some digits := n.raw.isLit? `num then digits.length else 0
  | `(bvpat_comp| $_:ident : $n:num) => n.raw.toNat
  | `(bvpat_comp| $_:ident ) => 1
  | `(bvpat_comp| _ : $n:num) => n.raw.toNat
  | _ => 0

/--
When the component is a binary literal, build the term
`BitVec.ofNat <width> <value>` that denotes it; every other kind of component
yields `none`.

Throws (at the component) if the literal cannot be read back as a `num`
literal or contains a character other than `'0'` / `'1'`.
-/
def BVPatComp.toBVLit? (c : BVPatComp) : MacroM (Option Term) := do
  let `(bvpat_comp| $n:num) := c | return none
  let some digits := n.raw.isLit? `num
    | Macro.throwErrorAt c "invalid bit-vector literal"
  -- Horner evaluation of the digit string, most-significant digit first.
  let value ← digits.toList.foldlM (init := (0 : Nat)) fun acc d =>
    if d = '1' then
      pure (2 * acc + 1)
    else if d = '0' then
      pure (2 * acc)
    else
      Macro.throwErrorAt c "invalid bit-vector literal, '0'/'1's expected"
  let width := c.length
  let lit ← `(BitVec.ofNat $(quote width) $(quote value))
  return some lit

/--
If the component is a pattern variable, written either `<id>:<size>` or as a
bare `<id>`, return `some id`; otherwise return `none`.
-/
def BVPatComp.toBVVar? (c : BVPatComp) : MacroM (Option (TSyntax `ident)) :=
  match c with
  | `(bvpat_comp| $x:ident $[: $_:num]?) => pure (some x)
  | _ => pure none

/--
The components of a bit-vector pattern in reverse source order, so that the
component written last comes first. Yields the empty array when `p` is not of
the form `[c₁, …, cₙ]`.
-/
def BVPat.getComponents (p : BVPat) : Array BVPatComp :=
  if let `(bvpat| [$comps,*]) := p then
    comps.getElems.reverse
  else
    #[]

/--
Total width of a bit-vector pattern: the sum of the widths of its components.
-/
def BVPat.length (p : BVPat) : Nat :=
  p.getComponents.foldl (fun total comp => total + comp.length) 0

/--
Build a Boolean term that is `true` when every discriminant in `vars` is an
instance of the pattern at the same position in `pats`.

The term is a conjunction, seeded with `true`, of one
`extractLsb hi lo var == literal` test per literal component; variable and
wildcard components only advance the bit offset. Throws if `vars` and `pats`
have different sizes.
-/
def genBVPatMatchTest (vars : Array Term) (pats : Array BVPat) : MacroM Term := do
  unless vars.size == pats.size do
    Macro.throwError "incorrect number of patterns"
  let mut conj ← `(true)

  for (pat, var) in pats.zip vars do
    -- Offset of the current component's least-significant bit.
    let mut lo := 0
    for comp in pat.getComponents do
      let width := comp.length
      match (← comp.toBVLit?) with
      | some lit =>
        let hi := lo + (width - 1)
        let check ← `(extractLsb $(quote hi) $(quote lo) $var == $lit)
        conj ← `($conj && $check)
      | none => pure ()
      lo := lo + width
  return conj

/--
Given discriminants `vars` that are known to match the patterns `pats`, and a
term `rhs`, return a term of the form
```
let y₁ := BitVec.extractLsb' start₁ width₁ var
...
let yₙ := BitVec.extractLsb' startₙ widthₙ var
rhs
```
where the `yᵢ` are the pattern variables occurring in `pats`.

Each slice is given as start offset and width rather than as `hi`/`lo`
bounds: `hi = shift + (len - 1)` uses truncated subtraction on `Nat`, so a
zero-width variable (`x:0`) would otherwise be bound to a one-bit slice.
With `extractLsb'` the bound variable has type `BitVec len` for every `len`.

`pats` and `vars` are paired with `Array.zip`, so surplus entries on either
side are ignored here; `genBVPatMatchTest` is the function that rejects a
size mismatch.
-/
def declBVPatVars (vars : Array Term) (pats : Array BVPat) (rhs : Term) : MacroM Term := do
  let mut result := rhs
  for (pat, var) in pats.zip vars do
    -- Offset of the current component's least-significant bit.
    let mut shift := 0
    for c in pat.getComponents do
      let len := c.length
      if let some y ← c.toBVVar? then
        let slice ← `(BitVec.extractLsb' $(quote shift) $(quote len) $var)
        result ← `(let $y := $slice; $result)
      shift := shift + len
  return result

/--
Syntax of the `match_bv` term:
`match_bv d₁, …, dₙ with | p₁, …, pₙ => rhs | … | _ => rhsElse`.

Each alternative lists comma-separated bit-vector patterns for the
comma-separated discriminants. The trailing `| _ => rhsElse` alternative is
optional: when it is absent, `elabMatchBv` fills the fall-through case with
`by bv_decide`, so the listed alternatives then have to be exhaustive.
-/
syntax (name := matchBv) "match_bv " term,+ "with" (atomic("| " bvpat,+) " => " term)* ("| " "_ " " => " term)? : term

open Lean
open Elab
open Term

/--
Check the widths of the patterns of a `match_bv`.

`lens[i]` is the width of the `i`-th discriminant when it is known (`none`
otherwise) and `pss` holds one row of patterns per alternative. For every
column `i` this checks, row by row, that

* the row has exactly `lens.size` patterns,
* the pattern's width equals the discriminant's width, when that is known,
* the pattern's width equals that of the first row in the same column.

The first violation is reported with `throwError` / `throwErrorAt`.
-/
def checkBVPatLengths (lens : Array (Option Nat)) (pss : Array (Array BVPat)) : TermElabM Unit := do
  for (len, i) in lens.zipWithIndex do
    -- Width of the first pattern seen in column `i`.
    let mut patLen := none
    for ps in pss do
      unless ps.size == lens.size do
        throwError "Expected {lens.size} patterns, found {ps.size}"
      let p := ps[i]!
      let pLen := p.length

      -- compare the length to that of the type of the discriminant
      -- NOTE(review): this text (typo included) is pinned by the `#guard_msgs`
      -- tests in this file, so it is kept verbatim here. It prints the
      -- pattern's width first and the discriminant's width second.
      if let some pLen' := len then
        unless pLen == pLen' do
          throwErrorAt p "Exprected pattern of length {pLen}, found {pLen'} instead"

      -- compare the lengths of the patterns
      if let some pLen' := patLen then
        unless pLen == pLen' do
          throwErrorAt p "patterns have different lengths"
      else
        patLen := some pLen

-- A dependent if-then-else that threads an accumulated hypothesis `old`
-- through both branches: the `then` branch receives `old ∧ c`, the `else`
-- branch receives `old ∧ ¬ c`. Chaining these collects the facts that all
-- earlier pattern tests failed, which in turn allows the exhaustivity of the
-- pattern matching to be proved in the final fall-through branch.
abbrev dite_gather {α : Sort u} {old : Prop} (c : Prop) [h : Decidable c]
    (t : old ∧ c → α) (e : old ∧ ¬ c → α) (ho : old) : α :=
  if hc : c then t ⟨ho, hc⟩ else e ⟨ho, hc⟩

/--
Elaborator for `match_bv`.

1. Elaborate each discriminant once to read off its type; when that type is
   `BitVec n` and `n` reduces to a numeric literal, the width is used to
   validate the patterns via `checkBVPatLengths`.
2. Build a chain of `dite_gather` applications, last alternative innermost.
   Each link tests one alternative (`genBVPatMatchTest`) and binds its pattern
   variables around the right-hand side (`declBVPatVars`). The innermost
   continuation is the `_` alternative when present, and `by bv_decide`
   otherwise.
3. Apply the chain to `True.intro` as the initial (empty) hypothesis and
   elaborate the resulting term against the expected type.
-/
@[term_elab matchBv]
partial
def elabMatchBv : TermElab := fun stx typ? =>
  match stx with
  | `(match_bv $[$discrs:term],* with
      $[ | $[$pss:bvpat],* => $rhss:term ]*
         $[| _ => $rhsElse?:term]?) => do
    let xs := discrs

    -- try to get the length of the BV to error-out
    -- if a pattern has the wrong length
    -- TODO: is it the best way to do that?
    let lens ← discrs.mapM (fun x => do
      let x ← elabTerm x none
      let typ ← Meta.inferType x
      match_expr typ with
      | BitVec n =>
        let n ← Meta.reduce n
        match n with
        | .lit (.natVal n) => return some n
        | _ => return none
      | _ => return none)

    checkBVPatLengths lens pss

    -- The quotation is a monadic action, so it is bound with `←`: `result`
    -- has to be a `Term`, because it is spliced with `$result` and reassigned
    -- with `←` in the loop below.
    let mut result ←
      if let some rhsElse := rhsElse? then
        `(Function.const _ $rhsElse)
      else
        `(fun _ => by bv_decide)

    -- Fold from the last alternative to the first, so that the first
    -- alternative ends up as the outermost test.
    for ps in pss.reverse, rhs in rhss.reverse do
      let test ← liftMacroM <| genBVPatMatchTest xs ps
      let rhs ← liftMacroM <| declBVPatVars xs ps rhs
      result ← `(dite_gather $test (Function.const _ $rhs) $result)
    let res ← liftMacroM <| `($result True.intro)
    elabTerm res typ?
  | _ => throwError "invalid syntax"

----------- TESTS -----------

-- Single discriminant. Both patterns are 32 bits wide, and each right-hand
-- side concatenates fields whose widths add up to 16. The `_` alternative
-- supplies the fall-through value.
def test_1 (x : BitVec 32) : BitVec 16 :=
match_bv x with
| [sf:1,0011010000,Rm:5,000000,Rn:5,Rd:5] => sf ++ Rm ++ Rn ++ Rd
| [sf:1,0000010000,11111000000,Rn:5,Rd:5] => sf ++ Rn ++ Rd ++ Rd
| _ => 0#16

-- Two discriminants: every alternative gives one pattern per discriminant,
-- separated by a comma. Variables bound by the pattern for `y` are not used
-- on the right-hand sides and are spelled with a leading underscore.
def test_2 (x y : BitVec 32) : BitVec 16 :=
match_bv x, y with
| [sf:1,0011010000,Rm:5,000000,Rn:5,Rd:5], [_sf':1,0000010000,11111000000,_Rn':5,_Rd':5]
=> sf ++ Rm ++ Rn ++ Rd
| [sf:1,0000010000,11111000000,Rn:5,Rd:5], [_sf:1,0000010000,11111000000,_Rn:5,_Rd:5] => sf ++ Rn ++ Rd ++ Rd
| _ => 0#16

-- Negative test: `x` is 33 bits wide but the first pattern only covers
-- 32 bits, so elaboration must fail with the length error below. The
-- docstring is the exact message that `#guard_msgs` compares against.
/-- error: Exprected pattern of length 32, found 33 instead -/
#guard_msgs in
def test_fail_length_one_pat (x : BitVec 33) : Bool :=
match_bv x with
| [sf:1,0011010000,Rm:5,000000,Rn:5,Rd:5] => true
| [sf:1,0000010000,11111000000,Rn:6,Rd:5] => false
| _ => true

-- Negative test with two discriminants: the patterns for `x` (32 bits) are
-- fine, and the first pattern for `y` is 33 bits wide like `y` itself, but
-- the second pattern for `y` covers only 32 bits. The docstring is the exact
-- message that `#guard_msgs` compares against.
/-- error: Exprected pattern of length 32, found 33 instead -/
#guard_msgs in
def test_fail_length_two_pats (x : BitVec 32) (y : BitVec 33) : BitVec 16 :=
match_bv x, y with
| [sf:1,0011010000,Rm:5,000000,Rn:5,Rd:5], [sf':1,0000010000,11111000000,Rn':6,Rd':5]
=> sf ++ Rm ++ Rn ++ Rd
| [sf:1,0000010000,11111000000,Rn:5,Rd:5], [sf:1,0000010000,11111000000,Rn:5,Rd:5] => sf ++ Rn ++ Rd ++ Rd
| _ => 0#16

-- TODO: it would be nice to check that the pattern length corresponds to the
-- length of the bit-vector being pattern-matched against...

-- def test_exhaustive_1 (x : BitVec 1) (h : x = x) : Bool :=
-- match_bv x with
-- | [0] => true
-- | [1] => false

-- def test_exhaustive_2 (x : BitVec 2) : Bool :=
-- match_bv x with
-- | [0, _:1] => true
-- | [1, _:1] => false

-- Failing test, because it is not exhaustive!
-- TODO: have a more informative error message
-- def test_fail_exhaustive_3 (x : BitVec 2) : Bool :=
-- match_bv x with
-- | [01] => true
-- | [1, _:1] => false
44 changes: 36 additions & 8 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -280,18 +280,24 @@ let lean_escape_string s = Str.global_replace (Str.regexp "\"") "\"\"" s

(** [doc_lit lit] renders a Sail literal as a document of Lean source text.

    Single bits become the one-bit literals [0#1] / [1#1]; hexadecimal and
    binary literals are prefixed with [0x] / [0b]; string literals are quoted
    after escaping with [lean_escape_string]. Each constructor is listed
    exactly once: a repeated arm can never be selected. *)
let doc_lit (L_aux (lit, _)) =
  match lit with
  | L_unit -> string "()"
  | L_zero -> string "0#1"
  | L_one -> string "1#1"
  | L_false -> string "false"
  | L_true -> string "true"
  | L_num i -> doc_big_int i
  | L_hex n -> utf8string ("0x" ^ n)
  | L_bin n -> utf8string ("0b" ^ n)
  | L_undef -> utf8string "(Fail \"undefined value of unsupported type\")"
  | L_string s -> utf8string ("\"" ^ lean_escape_string s ^ "\"")
  | L_real s -> utf8string s (* TODO test if this is really working *)

(** [doc_vec_lit lit] renders a bit literal that occurs inside a bit-vector
    pattern as the bare digit ["0"] or ["1"].

    @raise Failure for any other literal. The message is built as one string
    and passed to [failwith] in parentheses: function application binds
    tighter than an infix operator, so [failwith "..." ^^ doc] would raise
    before the literal could be appended. [string_of_lit] is assumed to be the
    Sail [Ast_util] printer, in scope like [string_of_pat] -- TODO confirm. *)
let doc_vec_lit (L_aux (lit, _) as l) =
  match lit with
  | L_zero -> string "0"
  | L_one -> string "1"
  | _ -> failwith ("Unexpected literal found in vector: " ^ string_of_lit l)

let string_of_exp_con (E_aux (e, _)) =
match e with
| E_block _ -> "E_block"
Expand Down Expand Up @@ -356,17 +362,26 @@ let string_of_pat_con (P_aux (p, _)) =
(* Rewrites the option constructors used in patterns to the lowercase
   spelling: [Some] becomes [some] and [None] becomes [none]. Other plain
   identifiers are rebuilt unchanged, and any other kind of identifier is
   returned as it was given. *)
let fixup_match_id (Id_aux (id, l) as id') =
  match id with
  | Id "Some" -> Id_aux (Id "some", l)
  | Id "None" -> Id_aux (Id "none", l)
  | Id name -> Id_aux (Id name, l)
  | _ -> id'

(** [doc_pat ?in_vector pat] renders the Sail pattern [pat] as a pattern in
    the generated Lean code.

    [in_vector] (default [false]) is set while printing the elements of a
    vector pattern. In that mode bit literals are printed as bare digits, and
    an element annotated with type [bit], or with [bits(n)] / [bitvector(n)]
    for a constant [n], is printed with a [:1] / [:n] width suffix.

    A vector-concatenation pattern becomes a bracketed, comma-separated list.
    A vector pattern that is not itself an element of another vector is also
    bracketed, so that it is printed in the bracketed form used for
    bit-vector patterns; nested inside another vector it stays a bare run of
    digits.

    The catch-all arm comes last, so the [P_as] arm is reachable.

    @raise Failure for pattern forms that are not handled. *)
let rec doc_pat ?(in_vector = false) (P_aux (p, _) as pat) =
  match p with
  | P_wild -> underscore
  | P_lit lit when in_vector -> doc_vec_lit lit
  | P_lit lit -> doc_lit lit
  | P_typ (Typ_aux (Typ_id (Id_aux (Id "bit", _)), _), p) when in_vector -> doc_pat p ^^ string ":1"
  | P_typ (Typ_aux (Typ_app (Id_aux (Id id, _), [A_aux (A_nexp (Nexp_aux (Nexp_constant i, _)), _)]), _), p)
    when in_vector && (id = "bits" || id = "bitvector") ->
      doc_pat p ^^ string ":" ^^ doc_big_int i
  | P_typ (_, p) -> doc_pat p
  | P_id id -> fixup_match_id id |> doc_id_ctor
  | P_tuple pats -> separate (string ", ") (List.map doc_pat pats) |> parens
  | P_list pats -> separate (string ", ") (List.map doc_pat pats) |> brackets
  | P_vector pats ->
      let bits = concat (List.map (doc_pat ~in_vector:true) pats) in
      if in_vector then bits else brackets bits
  | P_vector_concat pats -> separate (string ",") (List.map (doc_pat ~in_vector:true) pats) |> brackets
  | P_app (Id_aux (Id "None", _), _) -> string "none"
  | P_app (cons, pats) -> doc_id_ctor (fixup_match_id cons) ^^ space ^^ separate_map (string ", ") doc_pat pats
  (* The alias name of an as-pattern is not printed; only the inner pattern is. *)
  | P_as (inner, _) -> doc_pat inner
  | _ -> failwith ("Doc Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")

(* Copied from the Coq PP *)
let rebind_cast_pattern_vars pat typ exp =
Expand Down Expand Up @@ -406,6 +421,19 @@ let get_fn_implicits (Typ_aux (t, _)) : bool list =
in
match t with Typ_fn (args, cod) -> List.map arg_implicit args | _ -> []

(* True when the pattern is a vector or vector-concatenation pattern,
   possibly wrapped in any number of as-patterns. *)
let rec is_bitvector_pattern (P_aux (pat, _)) =
  match pat with
  | P_as (inner, _) -> is_bitvector_pattern inner
  | P_vector _ | P_vector_concat _ -> true
  | _ -> false

(* Chooses the keyword that opens a match on the given branches: "match_bv "
   as soon as one unguarded branch has a bit-vector pattern, "match "
   otherwise. Both strings end with a space. *)
let match_or_match_bv brs =
  let has_bitvector_pattern (Pat_aux (clause, _)) =
    match clause with
    | Pat_exp (pat, _) -> is_bitvector_pattern pat
    | _ -> false
  in
  if List.exists has_bitvector_pattern brs then "match_bv " else "match "

let rec doc_match_clause (as_monadic : bool) ctx (Pat_aux (cl, l)) =
match cl with
| Pat_exp (pat, branch) -> string "| " ^^ doc_pat pat ^^ string " =>" ^^ space ^^ doc_exp as_monadic ctx branch
Expand Down Expand Up @@ -492,8 +520,8 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
wrap_with_pure as_monadic
(braces (space ^^ doc_exp false ctx exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space))
| E_match (discr, brs) ->
let cases = separate_map hardline (fun br -> doc_match_clause as_monadic ctx br) brs in
string "match " ^^ doc_exp (effectful (effect_of discr)) ctx discr ^^ string " with" ^^ hardline ^^ cases
let cases = separate_map hardline (doc_match_clause as_monadic ctx) brs in
string (match_or_match_bv brs) ^^ doc_exp (effectful (effect_of discr)) ctx discr ^^ string " with" ^^ hardline ^^ cases
| E_assign ((LE_aux (le_act, tannot) as le), e) -> (
match le_act with
| LE_id id | LE_typ (_, id) -> string "writeReg " ^^ doc_id_ctor id ^^ space ^^ doc_exp false ctx e
Expand Down
9 changes: 5 additions & 4 deletions src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ let lean_rewrites =
("move_termination_measures", []);
("instantiate_outcomes", [String_arg "coq"]);
("realize_mappings", []);
("remove_vector_subrange_pats", []);
(* ("remove_vector_subrange_pats", []); *)
("remove_duplicate_valspecs", []);
("toplevel_string_append", []);
("pat_string_append", []);
Expand All @@ -107,8 +107,8 @@ let lean_rewrites =
("tuple_assignments", []);
("vector_concat_assignments", []);
("simple_assignments", []);
("remove_vector_concat", []);
("remove_bitvector_pats", []);
(* ("remove_vector_concat", []); *)
(* ("remove_bitvector_pats", []); *)
(* ("remove_numeral_pats", []); *)
(* ("pattern_literals", [Literal_arg "lem"]); *)
("guarded_pats", []);
Expand All @@ -129,7 +129,7 @@ let lean_rewrites =
(* We need to do the exhaustiveness check before merging, because it may
introduce new wildcard clauses *)
("recheck_defs", []);
("make_cases_exhaustive", []);
(* ("make_cases_exhaustive", []); *)
(* merge funcls before adding the measure argument so that it doesn't
disappear into an internal pattern match *)
("merge_function_clauses", []);
Expand Down Expand Up @@ -187,6 +187,7 @@ let create_lake_project (out_name : string) default_sail_dir =
in
let project_main = open_out (Filename.concat project_dir (out_name_camel ^ ".lean")) in
output_string project_main ("import " ^ out_name_camel ^ ".Sail.Sail\n\n");
output_string project_main ("import " ^ out_name_camel ^ ".Sail.BitVec\n\n");
output_string project_main "open Sail\n\n";
project_main

Expand Down

0 comments on commit 6fd3752

Please sign in to comment.