Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: BitVec.sshiftRight' in bv_decide #5995

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Init/Data/BitVec/Bitblast.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1092,8 +1092,8 @@ def sshiftRightRec (x : BitVec w₁) (y : BitVec w₂) (n : Nat) : BitVec w₁ :

@[simp]
theorem sshiftRightRec_zero_eq (x : BitVec w₁) (y : BitVec w₂) :
sshiftRightRec x y 0 = x.sshiftRight' (y &&& 1#w₂) := by
simp only [sshiftRightRec, twoPow_zero]
sshiftRightRec x y 0 = x.sshiftRight' (y &&& twoPow w₂ 0) := by
simp only [sshiftRightRec]

@[simp]
theorem sshiftRightRec_succ_eq (x : BitVec w₁) (y : BitVec w₂) (n : Nat) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ where
mkApp4 (mkConst ``BVExpr.shiftLeft) (toExpr m) (toExpr n) (go lhs) (go rhs)
| .shiftRight (m := m) (n := n) lhs rhs =>
mkApp4 (mkConst ``BVExpr.shiftRight) (toExpr m) (toExpr n) (go lhs) (go rhs)
| .arithShiftRight (m := m) (n := n) lhs rhs =>
mkApp4 (mkConst ``BVExpr.arithShiftRight) (toExpr m) (toExpr n) (go lhs) (go rhs)

instance : ToExpr BVBinPred where
toExpr x :=
Expand Down
14 changes: 10 additions & 4 deletions src/Lean/Elab/Tactic/BVDecide/Frontend/BVDecide/Reify.lean
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ where
``BVUnOp.shiftLeftConst
``Std.Tactic.BVDecide.Reflect.BitVec.shiftLeftNat_congr
else
let_expr BitVec _ := β | return none
shiftReflection
β
distanceExpr
innerExpr
.shiftLeft
Expand All @@ -78,8 +78,8 @@ where
``BVUnOp.shiftRightConst
``Std.Tactic.BVDecide.Reflect.BitVec.shiftRightNat_congr
else
let_expr BitVec _ := β | return none
shiftReflection
β
distanceExpr
innerExpr
.shiftRight
Expand All @@ -92,6 +92,13 @@ where
innerExpr
.arithShiftRightConst
``BVUnOp.arithShiftRightConst
``Std.Tactic.BVDecide.Reflect.BitVec.arithShiftRightNat_congr
| BitVec.sshiftRight' _ _ innerExpr distanceExpr =>
shiftReflection
distanceExpr
innerExpr
.arithShiftRight
``BVExpr.arithShiftRight
``Std.Tactic.BVDecide.Reflect.BitVec.arithShiftRight_congr
| BitVec.zeroExtend _ newWidthExpr innerExpr =>
let some newWidth ← getNatValue? newWidthExpr | return none
Expand Down Expand Up @@ -258,11 +265,10 @@ where
let some distance ← ReifiedBVExpr.getNatOrBvValue? β distanceExpr | return none
shiftConstLikeReflection distance innerExpr shiftOp shiftOpName congrThm

shiftReflection (β : Expr) (distanceExpr : Expr) (innerExpr : Expr)
shiftReflection (distanceExpr : Expr) (innerExpr : Expr)
(shiftOp : {m n : Nat} → BVExpr m → BVExpr n → BVExpr m) (shiftOpName : Name)
(congrThm : Name) :
LemmaM (Option ReifiedBVExpr) := do
let_expr BitVec _ ← β | return none
let some inner ← goOrAtom innerExpr | return none
let some distance ← goOrAtom distanceExpr | return none
let bvExpr : BVExpr inner.width := shiftOp inner.bvExpr distance.bvExpr
Expand Down
11 changes: 11 additions & 0 deletions src/Std/Tactic/BVDecide/Bitblast/BVExpr/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,10 @@ inductive BVExpr : Nat → Type where
shift right by another BitVec expression. For constant shifts there exists a `BVUnop`.
-/
| shiftRight (lhs : BVExpr m) (rhs : BVExpr n) : BVExpr m
/--
shift right arithmetically by another BitVec expression. For constant shifts there exists a `BVUnop`.
-/
| arithShiftRight (lhs : BVExpr m) (rhs : BVExpr n) : BVExpr m

namespace BVExpr

Expand All @@ -260,6 +264,7 @@ def toString : BVExpr w → String
| .signExtend v expr => s!"(sext {v} {expr.toString})"
| .shiftLeft lhs rhs => s!"({lhs.toString} << {rhs.toString})"
| .shiftRight lhs rhs => s!"({lhs.toString} >> {rhs.toString})"
| .arithShiftRight lhs rhs => s!"({lhs.toString} >>a {rhs.toString})"


instance : ToString (BVExpr w) := ⟨toString⟩
Expand Down Expand Up @@ -299,6 +304,7 @@ def eval (assign : Assignment) : BVExpr w → BitVec w
| .signExtend v expr => BitVec.signExtend v (eval assign expr)
| .shiftLeft lhs rhs => (eval assign lhs) <<< (eval assign rhs)
| .shiftRight lhs rhs => (eval assign lhs) >>> (eval assign rhs)
| .arithShiftRight lhs rhs => BitVec.sshiftRight' (eval assign lhs) (eval assign rhs)

@[simp]
theorem eval_var : eval assign ((.var idx) : BVExpr w) = (assign.getD idx).bv.truncate _ := by
Expand Down Expand Up @@ -343,6 +349,11 @@ theorem eval_shiftLeft : eval assign (.shiftLeft lhs rhs) = (eval assign lhs) <<
theorem eval_shiftRight : eval assign (.shiftRight lhs rhs) = (eval assign lhs) >>> (eval assign rhs) := by
rfl

@[simp]
theorem eval_arithShiftRight :
eval assign (.arithShiftRight lhs rhs) = BitVec.sshiftRight' (eval assign lhs) (eval assign rhs) := by
rfl

end BVExpr

/--
Expand Down
22 changes: 22 additions & 0 deletions src/Std/Tactic/BVDecide/Bitblast/BVExpr/Circuit/Impl/Expr.lean
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,18 @@ where
dsimp only at hlaig hraig
omega
⟨res, this⟩
| .arithShiftRight lhs rhs =>
let ⟨⟨aig, lhs⟩, hlaig⟩ := go aig lhs
let ⟨⟨aig, rhs⟩, hraig⟩ := go aig rhs
let lhs := lhs.cast <| by
dsimp only at hlaig hraig
omega
let res := bitblast.blastArithShiftRight aig ⟨_, lhs, rhs⟩
have := by
apply AIG.LawfulVecOperator.le_size_of_le_aig_size (f := bitblast.blastArithShiftRight)
dsimp only at hlaig hraig
omega
⟨res, this⟩

theorem bitblast.go_decl_eq (aig : AIG BVBit) (expr : BVExpr w) :
∀ (idx : Nat) (h1) (h2), (go aig expr).val.aig.decls[idx]'h2 = aig.decls[idx]'h1 := by
Expand Down Expand Up @@ -300,6 +312,16 @@ theorem bitblast.go_decl_eq (aig : AIG BVBit) (expr : BVExpr w) :
· omega
· apply Nat.lt_of_lt_of_le h1
apply Nat.le_trans <;> assumption
| arithShiftRight lhs rhs lih rih =>
dsimp only [go]
have := (bitblast.go aig lhs).property
have := (bitblast.go aig lhs).property
have := (go (go aig lhs).1.aig rhs).property
rw [AIG.LawfulVecOperator.decl_eq (f := blastArithShiftRight)]
rw [rih, lih]
· omega
· apply Nat.lt_of_lt_of_le h1
apply Nat.le_trans <;> assumption

instance : AIG.LawfulVecOperator BVBit (fun _ w => BVExpr w) bitblast where
le_size := by
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,14 @@ instance : AIG.LawfulVecOperator α AIG.ShiftTarget blastArithShiftRightConst wh
unfold blastArithShiftRightConst
simp

namespace blastShiftRight

structure TwoPowShiftTarget (aig : AIG α) (w : Nat) where
n : Nat
lhs : AIG.RefVec aig w
rhs : AIG.RefVec aig n
pow : Nat

namespace blastShiftRight

def twoPowShift (aig : AIG α) (target : TwoPowShiftTarget aig w) : AIG.RefVecEntry α w :=
let ⟨n, lhs, rhs, pow⟩ := target
if h : pow < n then
Expand Down Expand Up @@ -246,6 +246,120 @@ instance : AIG.LawfulVecOperator α AIG.ArbitraryShiftTarget blastShiftRight whe
apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastShiftRight.twoPowShift)
assumption

namespace blastArithShiftRight

def twoPowShift (aig : AIG α) (target : TwoPowShiftTarget aig w) : AIG.RefVecEntry α w :=
let ⟨n, lhs, rhs, pow⟩ := target
if h : pow < n then
let res := blastArithShiftRightConst aig ⟨lhs, 2 ^ pow⟩
let aig := res.aig
let shifted := res.vec

have := AIG.LawfulVecOperator.le_size (f := blastArithShiftRightConst) ..
let rhs := rhs.cast this
let lhs := lhs.cast this
AIG.RefVec.ite aig ⟨rhs.get pow h, shifted, lhs⟩
else
⟨aig, lhs⟩

instance : AIG.LawfulVecOperator α TwoPowShiftTarget twoPowShift where
le_size := by
intros
unfold twoPowShift
dsimp only
split
· apply AIG.LawfulVecOperator.le_size_of_le_aig_size (f := AIG.RefVec.ite)
apply AIG.LawfulVecOperator.le_size (f := blastArithShiftRightConst)
· simp
decl_eq := by
intros
unfold twoPowShift
dsimp only
split
· rw [AIG.LawfulVecOperator.decl_eq (f := AIG.RefVec.ite)]
rw [AIG.LawfulVecOperator.decl_eq (f := blastArithShiftRightConst)]
apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastArithShiftRightConst)
assumption
· simp

end blastArithShiftRight

def blastArithShiftRight (aig : AIG α) (target : AIG.ArbitraryShiftTarget aig w) :
AIG.RefVecEntry α w :=
let ⟨n, input, distance⟩ := target
if n = 0 then
⟨aig, input⟩
else
let res := blastArithShiftRight.twoPowShift aig ⟨_, input, distance, 0⟩
let aig := res.aig
let acc := res.vec
have := AIG.LawfulVecOperator.le_size (f := blastArithShiftRight.twoPowShift) ..
let distance := distance.cast this
go aig distance 0 acc
where
go {n : Nat} (aig : AIG α) (distance : AIG.RefVec aig n) (curr : Nat)
(acc : AIG.RefVec aig w) :
AIG.RefVecEntry α w :=
if curr < n - 1 then
let res := blastArithShiftRight.twoPowShift aig ⟨_, acc, distance, curr + 1⟩
let aig := res.aig
let acc := res.vec
have := AIG.LawfulVecOperator.le_size (f := blastArithShiftRight.twoPowShift) ..
let distance := distance.cast this
go aig distance (curr + 1) acc
else
⟨aig, acc⟩
termination_by n - 1 - curr

theorem blastArithShiftRight.go_le_size (aig : AIG α) (distance : AIG.RefVec aig n) (curr : Nat)
(acc : AIG.RefVec aig w) :
aig.decls.size ≤ (go aig distance curr acc).aig.decls.size := by
unfold go
dsimp only
split
· refine Nat.le_trans ?_ (by apply go_le_size)
apply AIG.LawfulVecOperator.le_size (f := blastArithShiftRight.twoPowShift)
· simp
termination_by n - 1 - curr

theorem blastArithShiftRight.go_decl_eq (aig : AIG α) (distance : AIG.RefVec aig n) (curr : Nat)
(acc : AIG.RefVec aig w) :
∀ (idx : Nat) (h1) (h2),
(go aig distance curr acc).aig.decls[idx]'h2 = aig.decls[idx]'h1 := by
generalize hgo : go aig distance curr acc = res
unfold go at hgo
dsimp only at hgo
split at hgo
· rw [← hgo]
intros
rw [blastArithShiftRight.go_decl_eq]
rw [AIG.LawfulVecOperator.decl_eq (f := blastArithShiftRight.twoPowShift)]
apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastArithShiftRight.twoPowShift)
assumption
· simp [← hgo]
termination_by n - 1 - curr


instance : AIG.LawfulVecOperator α AIG.ArbitraryShiftTarget blastArithShiftRight where
le_size := by
intros
unfold blastArithShiftRight
dsimp only
split
· simp
· refine Nat.le_trans ?_ (by apply blastArithShiftRight.go_le_size)
apply AIG.LawfulVecOperator.le_size (f := blastArithShiftRight.twoPowShift)
decl_eq := by
intros
unfold blastArithShiftRight
dsimp only
split
· simp
· rw [blastArithShiftRight.go_decl_eq]
rw [AIG.LawfulVecOperator.decl_eq (f := blastArithShiftRight.twoPowShift)]
apply AIG.LawfulVecOperator.lt_size_of_lt_aig_size (f := blastArithShiftRight.twoPowShift)
assumption

end bitblast
end BVExpr

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,18 @@ theorem go_denote_eq (aig : AIG BVBit) (expr : BVExpr w) (assign : Assignment) :
· simp [Ref.hgate]
· intros
rw [← rih]
| arithShiftRight lhs rhs lih rih =>
simp only [go, eval_arithShiftRight]
apply denote_blastArithShiftRight
· intros
dsimp only
rw [go_denote_mem_prefix]
rw [← lih (aig := aig)]
· simp
· assumption
· simp [Ref.hgate]
· intros
rw [← rih]
| bin lhs op rhs lih rih =>
cases op with
| and =>
Expand Down
Loading
Loading