Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

getLsb_sshiftRight #3

Closed
wants to merge 15 commits into from
344 changes: 344 additions & 0 deletions src/Init/Data/BitVec/Lemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ import Init.Data.Bool
import Init.Data.BitVec.Basic
import Init.Data.Fin.Lemmas
import Init.Data.Nat.Lemmas
import Init.Data.Int.Bitwise.Lemmas
import Init.Data.BitVec.Basic
import Init.Data.Nat.Div
import Init.Data.Int.DivModLemmas
import Init.Data.Int.DivMod

namespace BitVec

Expand Down Expand Up @@ -140,12 +145,14 @@ theorem ofBool_eq_iff_eq : ∀(b b' : Bool), BitVec.ofBool b = BitVec.ofBool b'
@[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (x#w).toNat = x % 2^w := by
simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat']


-- Remark: we don't use `[simp]` here because simproc` subsumes it for literals.
-- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea.
theorem getLsb_ofNat (n : Nat) (x : Nat) (i : Nat) :
getLsb (x#n) i = (i < n && x.testBit i) := by
simp [getLsb, BitVec.ofNat, Fin.val_ofNat']


@[simp, deprecated toNat_ofNat] theorem toNat_zero (n : Nat) : (0#n).toNat = 0 := by trivial

@[simp] theorem getLsb_zero : (0#w).getLsb i = false := by simp [getLsb]
Expand Down Expand Up @@ -266,6 +273,68 @@ theorem toInt_ofNat {n : Nat} (x : Nat) :
have p : 0 ≤ i % (2^n : Nat) := by omega
simp [toInt_eq_toNat_bmod, Int.toNat_of_nonneg p]

/-- 'BitVec.ofInt' only cares about values upto `emod`. -/
theorem ofInt_eq_ofInt_emod {n : Nat} (i : Int) :
(BitVec.ofInt n i) = BitVec.ofInt n (i % (2^n)) := by
apply BitVec.eq_of_toNat_eq
simp only [toNat_ofInt]
congr 1
have h2n : (2 : Int)^n = (((2 : Nat)^(n : Nat) : Nat) : Int) := by
rw [Int.natCast_pow]
rfl
rw [h2n, Int.emod_emod]

-- /- A variant of emod that knows that the output is a natural number. -/
-- abbrev Int.emod' (i : Int) (n : Nat) : Nat :=
-- (Int.emod i n).toNat

/-- Write ofInt in terms of `ofNat`
of the canonical natural number between 0 and 2^n.-/
theorem ofInt_eq_ofNat_emod {n : Nat} (i : Int) :
(BitVec.ofInt n i) = BitVec.ofNat n (i % (2^n)).toNat := by
have h2n : (2 : Int)^n = (((2 : Nat)^(n : Nat) : Nat) : Int) := by
rw [Int.natCast_pow]
rfl
apply BitVec.eq_of_toNat_eq
simp only [ofInt_eq_ofInt_emod, Int.emod_emod, toNat_ofInt, toNat_ofNat]
conv => lhs; simp [(· % ·), Mod.mod]
rw [Nat.mod_eq_of_lt]
rw [h2n]
congr 1
· apply Int.emod_emod
· have hlt : (i % 2^n) < 2^n := by
apply Int.emod_lt_of_pos
rw [h2n]
norm_cast
apply Nat.pow_pos
decide
rw [Int.toNat_lt]
· rw [h2n]
apply Int.emod_lt_of_pos
rw [Int.natCast_pow]
/- (0 : Int) < ↑2 ^ n -/
clear hlt h2n
apply Lean.Omega.Int.pos_pow_of_pos
decide
· apply Int.emod_nonneg
/- (2 : Int) ^ n ≠ (0 : Int) -/
apply Int.ne_of_gt
apply Lean.Omega.Int.pos_pow_of_pos
decide

@[simp]
theorem Int.testBit_natCast (n : Nat) : (n : Int).testBit i = n.testBit i := rfl

theorem Int.natCast_mod_natCast (n m : Nat) : (n : Int) % (m : Int) = ((n % m : Nat) : Int) := rfl

theorem Int.negSucc_mod_natCast (n m : Nat) :
(Int.negSucc m) % (n : Int) = n - ((m % n) + 1) := rfl

@[simp] theorem Int.toNat_natCast (n : Nat) : (n : Int).toNat = n := rfl

@[simp] theorem ofInt_natCast (w n : Nat) :
BitVec.ofInt w n = BitVec.ofNat w n := rfl

/-! ### zeroExtend and truncate -/

@[simp, bv_toNat] theorem toNat_zeroExtend' {m n : Nat} (p : m ≤ n) (x : BitVec m) :
Expand Down Expand Up @@ -628,6 +697,110 @@ theorem BitVec.shiftLeft_shiftLeft {w : Nat} (x : BitVec w) (n m : Nat) :
getLsb (x >>> i) j = getLsb x (i+j) := by
unfold getLsb ; simp

/-! ### sshiftRight-/

theorem sshiftRight_eq {x : BitVec n} {i : Nat} :
x.sshiftRight i = BitVec.ofInt n (x.toInt >>> i) := by
apply BitVec.eq_of_toInt_eq
simp [BitVec.sshiftRight]


theorem BitVec.toInt_eq_toNat_of_toInt_pos {x : BitVec n} (hx: x.toInt ≥ 0) :
x.toInt = ↑ (x.toNat) := by
rw [toInt_eq_toNat_cond]
simp
intros hx'
simp [BitVec.toInt] at hx
split at hx <;> omega


-- theorem sshiftRight_eq_ushiftRight_of_pos {x : BitVec n} (hx: x.toInt ≥ 0) :
-- x.sshiftRight i = x.ushiftRight i := by
-- rw [sshiftRight_eq, BitVec.ofInt_eq_ofNat_emod]
-- rw [BitVec.toInt_eq_toNat_of_toInt_pos hx]
-- rw [ushiftRight_eq]
-- rw [Int.emod_eq_of_lt]
-- · norm_cast
-- sorry
-- · /- need norm_num-/
-- sorry
-- · have ⟨x', hx'lt⟩ := x
-- simp
-- /- ↑x' >>> i < 2 ^ n -/
-- /- need norm_num -/
-- sorry

/-- The MSB of a bitvector is `true` iff its integer interpretetation is greater than or equal to zero. -/
theorem msb_eq_true_iff_toInt_lt_zero {x : BitVec w}
: x.msb = true ↔ x.toInt < 0 := by
rcases w with rfl | w <;> try simp <;> try omega
· rw [Subsingleton.elim x (0#0)]
simp
rw [BitVec.toInt_eq_toNat_cond]
constructor
· intro h
split
case inl hpos =>
rw [BitVec.msb_eq_decide] at h
simp at h
omega
case inr hneg =>
simp_all
omega
· intro h
rw [BitVec.msb_eq_decide]
simp
omega

/-- The MSB of a bitvector is `true` iff its numerical value is larger than half the bitwidth. -/
theorem msb_eq_true_iff_large {x : BitVec w}
: x.msb = true ↔ 2 * x.toNat ≥ 2^w := by
rcases w with rfl | w <;> try simp <;> try omega
constructor
· intro h
rw [BitVec.msb_eq_decide] at h
simp at h
omega
· intro h
rw [BitVec.msb_eq_decide]
simp
omega
/-- The MSB of a bitvector is `false` iff its numerical value is smaller than half the bitwidth. -/
theorem msb_eq_false_iff_small {x : BitVec w}
: x.msb = false ↔ 2 * x.toNat < 2^w := by
rcases w with rfl | w <;> try simp <;> try omega
constructor
· intro h
rw [BitVec.msb_eq_decide] at h
simp at h
omega
· intro h
rw [BitVec.msb_eq_decide]
simp
omega

/-- The MSB of a bitvector is `false` iff its integer interpretetation is greater than or equal to zero. -/
theorem msb_eq_false_iff_toInt_geq_zero {x : BitVec w}
: x.msb = false ↔ x.toInt ≥ 0 := by
constructor
· intro h
rw [toInt_eq_toNat_cond]
have hsize := msb_eq_false_iff_small.mp h
split <;> omega
· intro h
have hx : x.toNat < 2^w := by exact isLt x
rw [BitVec.msb_eq_decide]
simp
rw [BitVec.toInt_eq_toNat_cond] at h
split at h <;> try omega
case inl hsz =>
norm_cast
simp_all
rcases w with rfl | w
· simp
· simp [Nat.lt_succ_iff]
omega

/-! ### append -/

theorem append_def (x : BitVec v) (y : BitVec w) :
Expand Down Expand Up @@ -948,6 +1121,11 @@ theorem neg_eq_not_add (x : BitVec w) : -x = ~~~x + 1 := by
have hx : x.toNat < 2^w := x.isLt
rw [Nat.sub_sub, Nat.add_comm 1 x.toNat, ← Nat.sub_sub, Nat.sub_add_cancel (by omega)]

-- @[simp, bv_toNat] theorem toInt_sub (x y : BitVec w) :
-- (x - y).toInt = ((x.toNat + (((2 : Nat) ^ w : Nat) - y.toNat)) : Int).bmod (2 ^ w) := by
-- simp [toInt_eq_toNat_bmod]
-- norm_cast
-- simp
/-! ### mul -/

theorem mul_def {n} {x y : BitVec n} : x * y = (ofFin <| x.toFin * y.toFin) := by rfl
Expand Down Expand Up @@ -1042,4 +1220,170 @@ theorem toNat_intMax_eq : (intMax w).toNat = 2^w - 1 := by
(ofBoolListLE bs).getMsb i = (decide (i < bs.length) && bs.getD (bs.length - 1 - i) false) := by
simp [getMsb_eq_getLsb]

theorem Int.negSucc_div_ofNat (a b : Nat) (hb : b ≠ 0) :
(Int.negSucc a) / (Int.ofNat b) = Int.negSucc (((a / b) : Nat)) := by
rcases b with rfl | b
· contradiction
· norm_cast

@[simp] theorem toInt_sshiftRight (x : BitVec n) (i : Nat) :
(x.sshiftRight i).toInt = (x.toInt >>> i).bmod (2^n) := by
rw [sshiftRight_eq, BitVec.toInt_ofInt]

@[simp] private theorem Int.ofNat_sub_ofNat_eq_ofNat_implies_eq_ofNat_sub {x y : Nat}
(h : (x : Int) - (y : Int) = Int.ofNat z) : (x : Int) - (y : Int) = ((x - y : Nat) : Int) := by
simp at h
omega
/-- (n₁ : Int) - (n₂ : Int) = (n₃ : Int) for n₁, n₂, n₃ naturals iff n₁ ≥ n₂ -/
private theorem Int.sub_sub_eq_ofNat_iff_geq {x y : Nat} :
(∃ (z : Nat), (x : Int) - (y : Int) = Int.ofNat z) ↔ x ≥ y := by
constructor
· intros h
simp only [Int.ofNat_eq_coe] at h
omega
· intro h
exists (x - y)
simp only [Int.ofNat_eq_coe]
omega

/-- (n₁ : Int) - (n₂ : Int) = (negSucc n₃) for n₁, n₂, n₃ naturals iff n₁ < n₂ -/
private theorem Int.sub_sub_eq_negSucc_implies_le {x y : Nat}
(h : ∃ (z : Nat), (x : Int) - (y : Int) = Int.negSucc z) : x < y := by
have hxlty:(x < y) ∨ (x >= y) := by omega
rcases hxlty with hxlty | hxlty
· omega
· obtain ⟨z, hz⟩ := Int.sub_sub_eq_ofNat_iff_geq.mpr hxlty
rw [hz] at h
obtain ⟨w, hw⟩ := h
contradiction

theorem Int.toNat_sub_toNat_eq_negSucc_ofLt {n m : Nat} (hlt : n < m) :
(n : Int) - (m : Int) = (Int.negSucc (m - 1 - n)) := by
rw [Int.negSucc_eq] -- TODO: consider adding this to omega cleanup set.
omega

/-- Testing the ith bit of `x.toInt` is the same as testing `x.toNat`-/
theorem testBit_toInt_eq_testBit_toNat {x : BitVec w} (hi : i < w) :
x.toInt.testBit i = x.toNat.testBit i := by
rw [BitVec.toInt]
rcases w with rfl | w
· omega
· by_cases hx : x.toNat < 2^w
· have hx' : 2 * x.toNat < 2 ^ (w + 1) := by omega
simp [hx']
· have hx' : ¬ (2 * x.toNat < 2 ^ (w + 1)) := by omega
simp [hx']
rw [Int.toNat_sub_toNat_eq_negSucc_ofLt (by omega)]
simp
suffices 2 ^ (w + 1) - 1 - x.toNat = (~~~ x).toNat by
rw [this]
rw [← getLsb]
simp [hi]
rfl
simp

/-- the value of testBit of the integer value,
when the index being tested is larger or equal to the bitwidth is the msb of the bitvector. -/
theorem testBit_toInt_eq_msb {x : BitVec w} (hi : i ≥ w) :
x.toInt.testBit i = x.msb := by
rw [BitVec.toInt]
split
case inl h =>
rw [msb_eq_false_iff_small.mpr h]
simp only [Int.testBit_natCast]
rw [testBit_toNat]
exact getLsb_ge x i hi
case inr h =>
rcases w with rfl | w
· simp at h
· have hx' : x.toNat < 2^(w + 1) := by omega
have hz := Int.toNat_sub_toNat_eq_negSucc_ofLt hx'
rw [hz]
simp
rw [Nat.testBit]
rw [Nat.shiftRight_eq_div_pow]
rw [Nat.div_eq_of_lt]
· simp only [Nat.and_one_is_mod, Nat.zero_mod, bne_self_eq_false, Bool.not_false, Bool.true_eq]
rw [msb_eq_true_iff_large]
omega
· rw [Nat.pow_succ]
have hpow : 2^w < 2^i := by
refine (Nat.pow_lt_pow_iff_right ?h).mpr hi
omega
omega
/--
info: 'BitVec.testBit_toInt_eq_msb' depends on axioms: [propext, Classical.choice, Quot.sound]
-/
#guard_msgs in #print axioms testBit_toInt_eq_msb


private theorem Int.mod_ofNat_eq (m : Nat) (n : Int) :
(Int.ofNat m) % n = Int.ofNat (m % Int.natAbs n) := rfl

private theorem Int.mod_negSucc_eq (m : Nat) (n : Int) :
(Int.negSucc m) % n = Int.subNatNat (Int.natAbs n) (Nat.succ (m % Int.natAbs n)) := rfl

/-- if the msb is false, the arithmetic shift right equals logical shift right -/
theorem sshiftRight_eq_of_msb_false {x : BitVec w} {s : Nat} (h : x.msb = false) :
(x.sshiftRight s) = x.ushiftRight s := by
rcases w with rfl | w
· rw [Subsingleton.elim x (0#0)]
simp [ushiftRight, sshiftRight]
rfl
· apply BitVec.eq_of_toNat_eq
rw [BitVec.sshiftRight_eq]
rw [BitVec.toInt_eq_toNat_cond]
have hxbound : 2 * x.toNat < 2 ^ (w + 1) := BitVec.msb_eq_false_iff_small.mp h
simp [hxbound]
rw [Nat.mod_eq_of_lt]
rw [Nat.shiftRight_eq_div_pow]
suffices x.toNat / 2^s ≤ x.toNat by
apply Nat.lt_of_le_of_lt this x.isLt
apply Nat.div_le_self

/-- if the msb is true, the arithmetic shift right equals negating, the logical shifting right, then negating again -/
theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) :
(x.sshiftRight s) = ~~~((~~~x).ushiftRight s) := by
rcases w with rfl | w
· apply BitVec.eq_of_toNat_eq
simp
· apply BitVec.eq_of_toNat_eq
rw [BitVec.sshiftRight_eq]
rw [BitVec.toInt_eq_toNat_cond]
have hxbound : (2 * x.toNat ≥ 2 ^ (w + 1)) := BitVec.msb_eq_true_iff_large.mp h
have hxbound' : ¬ (2 * x.toNat < 2 ^ (w + 1)) := by omega
simp [hxbound']
rw [Int.toNat_sub_toNat_eq_negSucc_ofLt (by omega)]
rw [Int.shiftRight_negSucc]
rw [Int.mod_negSucc_eq]
simp only [Int.natAbs_ofNat, Nat.succ_eq_add_one]
rw [Int.subNatNat_of_le]
simp
rw [Nat.mod_eq_of_lt]
omega
rw [Nat.shiftRight_eq_div_pow]
suffices (2 ^ (w + 1) - 1 - x.toNat) / 2 ^ s < 2 ^ (w + 1) - 1 by
omega
apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega)
omega

theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) :
getLsb (x.sshiftRight s) i =
if i ≥ w then false
else if (s + i) < w then getLsb x (s + i)
else x.msb := by
rcases hmsb:x.msb with rfl | rfl
· simp [sshiftRight_eq_of_msb_false hmsb]
by_cases hi : i ≥ w <;> simp [hi]
· apply getLsb_ge
omega
· intros hlsb
apply BitVec.lt_of_getLsb _ _ hlsb
· by_cases hi:(i ≥ w)
· simp [hi]
· simp [hi, sshiftRight_eq_of_msb_true hmsb] <;> omega

/-- info: 'BitVec.getLsb_sshiftRight' depends on axioms: [propext, Quot.sound, Classical.choice] -/
#guard_msgs in #print axioms getLsb_sshiftRight

end BitVec
Loading
Loading