diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 13f4bf8a8a95..3a300a47146c 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -9,6 +9,11 @@ import Init.Data.Bool import Init.Data.BitVec.Basic import Init.Data.Fin.Lemmas import Init.Data.Nat.Lemmas +import Init.Data.Int.Bitwise.Lemmas +import Init.Data.BitVec.Basic +import Init.Data.Nat.Div +import Init.Data.Int.DivModLemmas +import Init.Data.Int.DivMod namespace BitVec @@ -140,12 +145,14 @@ theorem ofBool_eq_iff_eq : ∀(b b' : Bool), BitVec.ofBool b = BitVec.ofBool b' @[simp, bv_toNat] theorem toNat_ofNat (x w : Nat) : (x#w).toNat = x % 2^w := by simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat'] + -- Remark: we don't use `[simp]` here because simproc` subsumes it for literals. -- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea. theorem getLsb_ofNat (n : Nat) (x : Nat) (i : Nat) : getLsb (x#n) i = (i < n && x.testBit i) := by simp [getLsb, BitVec.ofNat, Fin.val_ofNat'] + @[simp, deprecated toNat_ofNat] theorem toNat_zero (n : Nat) : (0#n).toNat = 0 := by trivial @[simp] theorem getLsb_zero : (0#w).getLsb i = false := by simp [getLsb] @@ -266,6 +273,68 @@ theorem toInt_ofNat {n : Nat} (x : Nat) : have p : 0 ≤ i % (2^n : Nat) := by omega simp [toInt_eq_toNat_bmod, Int.toNat_of_nonneg p] +/-- 'BitVec.ofInt' only cares about values upto `emod`. -/ +theorem ofInt_eq_ofInt_emod {n : Nat} (i : Int) : + (BitVec.ofInt n i) = BitVec.ofInt n (i % (2^n)) := by + apply BitVec.eq_of_toNat_eq + simp only [toNat_ofInt] + congr 1 + have h2n : (2 : Int)^n = (((2 : Nat)^(n : Nat) : Nat) : Int) := by + rw [Int.natCast_pow] + rfl + rw [h2n, Int.emod_emod] + +-- /- A variant of emod that knows that the output is a natural number. -/ +-- abbrev Int.emod' (i : Int) (n : Nat) : Nat := +-- (Int.emod i n).toNat + +/-- Write ofInt in terms of `ofNat` +of the canonical natural number between 0 and 2^n.-/ +theorem ofInt_eq_ofNat_emod {n : Nat} (i : Int) : + (BitVec.ofInt n i) = BitVec.ofNat n (i % (2^n)).toNat := by + have h2n : (2 : Int)^n = (((2 : Nat)^(n : Nat) : Nat) : Int) := by + rw [Int.natCast_pow] + rfl + apply BitVec.eq_of_toNat_eq + simp only [ofInt_eq_ofInt_emod, Int.emod_emod, toNat_ofInt, toNat_ofNat] + conv => lhs; simp [(· % ·), Mod.mod] + rw [Nat.mod_eq_of_lt] + rw [h2n] + congr 1 + · apply Int.emod_emod + · have hlt : (i % 2^n) < 2^n := by + apply Int.emod_lt_of_pos + rw [h2n] + norm_cast + apply Nat.pow_pos + decide + rw [Int.toNat_lt] + · rw [h2n] + apply Int.emod_lt_of_pos + rw [Int.natCast_pow] + /- (0 : Int) < ↑2 ^ n -/ + clear hlt h2n + apply Lean.Omega.Int.pos_pow_of_pos + decide + · apply Int.emod_nonneg + /- (2 : Int) ^ n ≠ (0 : Int) -/ + apply Int.ne_of_gt + apply Lean.Omega.Int.pos_pow_of_pos + decide + +@[simp] +theorem Int.testBit_natCast (n : Nat) : (n : Int).testBit i = n.testBit i := rfl + +theorem Int.natCast_mod_natCast (n m : Nat) : (n : Int) % (m : Int) = ((n % m : Nat) : Int) := rfl + +theorem Int.negSucc_mod_natCast (n m : Nat) : + (Int.negSucc m) % (n : Int) = n - ((m % n) + 1) := rfl + +@[simp] theorem Int.toNat_natCast (n : Nat) : (n : Int).toNat = n := rfl + +@[simp] theorem ofInt_natCast (w n : Nat) : + BitVec.ofInt w n = BitVec.ofNat w n := rfl + /-! ### zeroExtend and truncate -/ @[simp, bv_toNat] theorem toNat_zeroExtend' {m n : Nat} (p : m ≤ n) (x : BitVec m) : @@ -628,6 +697,110 @@ theorem BitVec.shiftLeft_shiftLeft {w : Nat} (x : BitVec w) (n m : Nat) : getLsb (x >>> i) j = getLsb x (i+j) := by unfold getLsb ; simp +/-! ### sshiftRight-/ + +theorem sshiftRight_eq {x : BitVec n} {i : Nat} : + x.sshiftRight i = BitVec.ofInt n (x.toInt >>> i) := by + apply BitVec.eq_of_toInt_eq + simp [BitVec.sshiftRight] + + +theorem BitVec.toInt_eq_toNat_of_toInt_pos {x : BitVec n} (hx: x.toInt ≥ 0) : + x.toInt = ↑ (x.toNat) := by + rw [toInt_eq_toNat_cond] + simp + intros hx' + simp [BitVec.toInt] at hx + split at hx <;> omega + + +-- theorem sshiftRight_eq_ushiftRight_of_pos {x : BitVec n} (hx: x.toInt ≥ 0) : +-- x.sshiftRight i = x.ushiftRight i := by +-- rw [sshiftRight_eq, BitVec.ofInt_eq_ofNat_emod] +-- rw [BitVec.toInt_eq_toNat_of_toInt_pos hx] +-- rw [ushiftRight_eq] +-- rw [Int.emod_eq_of_lt] +-- · norm_cast +-- sorry +-- · /- need norm_num-/ +-- sorry +-- · have ⟨x', hx'lt⟩ := x +-- simp +-- /- ↑x' >>> i < 2 ^ n -/ +-- /- need norm_num -/ +-- sorry + +/-- The MSB of a bitvector is `true` iff its integer interpretetation is greater than or equal to zero. -/ +theorem msb_eq_true_iff_toInt_lt_zero {x : BitVec w} + : x.msb = true ↔ x.toInt < 0 := by + rcases w with rfl | w <;> try simp <;> try omega + · rw [Subsingleton.elim x (0#0)] + simp + rw [BitVec.toInt_eq_toNat_cond] + constructor + · intro h + split + case inl hpos => + rw [BitVec.msb_eq_decide] at h + simp at h + omega + case inr hneg => + simp_all + omega + · intro h + rw [BitVec.msb_eq_decide] + simp + omega + +/-- The MSB of a bitvector is `true` iff its numerical value is larger than half the bitwidth. -/ +theorem msb_eq_true_iff_large {x : BitVec w} + : x.msb = true ↔ 2 * x.toNat ≥ 2^w := by + rcases w with rfl | w <;> try simp <;> try omega + constructor + · intro h + rw [BitVec.msb_eq_decide] at h + simp at h + omega + · intro h + rw [BitVec.msb_eq_decide] + simp + omega +/-- The MSB of a bitvector is `false` iff its numerical value is smaller than half the bitwidth. -/ +theorem msb_eq_false_iff_small {x : BitVec w} + : x.msb = false ↔ 2 * x.toNat < 2^w := by + rcases w with rfl | w <;> try simp <;> try omega + constructor + · intro h + rw [BitVec.msb_eq_decide] at h + simp at h + omega + · intro h + rw [BitVec.msb_eq_decide] + simp + omega + +/-- The MSB of a bitvector is `false` iff its integer interpretetation is greater than or equal to zero. -/ +theorem msb_eq_false_iff_toInt_geq_zero {x : BitVec w} + : x.msb = false ↔ x.toInt ≥ 0 := by + constructor + · intro h + rw [toInt_eq_toNat_cond] + have hsize := msb_eq_false_iff_small.mp h + split <;> omega + · intro h + have hx : x.toNat < 2^w := by exact isLt x + rw [BitVec.msb_eq_decide] + simp + rw [BitVec.toInt_eq_toNat_cond] at h + split at h <;> try omega + case inl hsz => + norm_cast + simp_all + rcases w with rfl | w + · simp + · simp [Nat.lt_succ_iff] + omega + /-! ### append -/ theorem append_def (x : BitVec v) (y : BitVec w) : @@ -948,6 +1121,11 @@ theorem neg_eq_not_add (x : BitVec w) : -x = ~~~x + 1 := by have hx : x.toNat < 2^w := x.isLt rw [Nat.sub_sub, Nat.add_comm 1 x.toNat, ← Nat.sub_sub, Nat.sub_add_cancel (by omega)] +-- @[simp, bv_toNat] theorem toInt_sub (x y : BitVec w) : +-- (x - y).toInt = ((x.toNat + (((2 : Nat) ^ w : Nat) - y.toNat)) : Int).bmod (2 ^ w) := by +-- simp [toInt_eq_toNat_bmod] +-- norm_cast +-- simp /-! ### mul -/ theorem mul_def {n} {x y : BitVec n} : x * y = (ofFin <| x.toFin * y.toFin) := by rfl @@ -1042,4 +1220,170 @@ theorem toNat_intMax_eq : (intMax w).toNat = 2^w - 1 := by (ofBoolListLE bs).getMsb i = (decide (i < bs.length) && bs.getD (bs.length - 1 - i) false) := by simp [getMsb_eq_getLsb] +theorem Int.negSucc_div_ofNat (a b : Nat) (hb : b ≠ 0) : + (Int.negSucc a) / (Int.ofNat b) = Int.negSucc (((a / b) : Nat)) := by + rcases b with rfl | b + · contradiction + · norm_cast + +@[simp] theorem toInt_sshiftRight (x : BitVec n) (i : Nat) : + (x.sshiftRight i).toInt = (x.toInt >>> i).bmod (2^n) := by + rw [sshiftRight_eq, BitVec.toInt_ofInt] + +@[simp] private theorem Int.ofNat_sub_ofNat_eq_ofNat_implies_eq_ofNat_sub {x y : Nat} + (h : (x : Int) - (y : Int) = Int.ofNat z) : (x : Int) - (y : Int) = ((x - y : Nat) : Int) := by + simp at h + omega +/-- (n₁ : Int) - (n₂ : Int) = (n₃ : Int) for n₁, n₂, n₃ naturals iff n₁ ≥ n₂ -/ +private theorem Int.sub_sub_eq_ofNat_iff_geq {x y : Nat} : + (∃ (z : Nat), (x : Int) - (y : Int) = Int.ofNat z) ↔ x ≥ y := by + constructor + · intros h + simp only [Int.ofNat_eq_coe] at h + omega + · intro h + exists (x - y) + simp only [Int.ofNat_eq_coe] + omega + +/-- (n₁ : Int) - (n₂ : Int) = (negSucc n₃) for n₁, n₂, n₃ naturals iff n₁ < n₂ -/ +private theorem Int.sub_sub_eq_negSucc_implies_le {x y : Nat} + (h : ∃ (z : Nat), (x : Int) - (y : Int) = Int.negSucc z) : x < y := by + have hxlty:(x < y) ∨ (x >= y) := by omega + rcases hxlty with hxlty | hxlty + · omega + · obtain ⟨z, hz⟩ := Int.sub_sub_eq_ofNat_iff_geq.mpr hxlty + rw [hz] at h + obtain ⟨w, hw⟩ := h + contradiction + +theorem Int.toNat_sub_toNat_eq_negSucc_ofLt {n m : Nat} (hlt : n < m) : + (n : Int) - (m : Int) = (Int.negSucc (m - 1 - n)) := by + rw [Int.negSucc_eq] -- TODO: consider adding this to omega cleanup set. + omega + +/-- Testing the ith bit of `x.toInt` is the same as testing `x.toNat`-/ +theorem testBit_toInt_eq_testBit_toNat {x : BitVec w} (hi : i < w) : + x.toInt.testBit i = x.toNat.testBit i := by +rw [BitVec.toInt] +rcases w with rfl | w +· omega +· by_cases hx : x.toNat < 2^w + · have hx' : 2 * x.toNat < 2 ^ (w + 1) := by omega + simp [hx'] + · have hx' : ¬ (2 * x.toNat < 2 ^ (w + 1)) := by omega + simp [hx'] + rw [Int.toNat_sub_toNat_eq_negSucc_ofLt (by omega)] + simp + suffices 2 ^ (w + 1) - 1 - x.toNat = (~~~ x).toNat by + rw [this] + rw [← getLsb] + simp [hi] + rfl + simp + +/-- the value of testBit of the integer value, + when the index being tested is larger or equal to the bitwidth is the msb of the bitvector. -/ +theorem testBit_toInt_eq_msb {x : BitVec w} (hi : i ≥ w) : + x.toInt.testBit i = x.msb := by +rw [BitVec.toInt] +split +case inl h => + rw [msb_eq_false_iff_small.mpr h] + simp only [Int.testBit_natCast] + rw [testBit_toNat] + exact getLsb_ge x i hi +case inr h => + rcases w with rfl | w + · simp at h + · have hx' : x.toNat < 2^(w + 1) := by omega + have hz := Int.toNat_sub_toNat_eq_negSucc_ofLt hx' + rw [hz] + simp + rw [Nat.testBit] + rw [Nat.shiftRight_eq_div_pow] + rw [Nat.div_eq_of_lt] + · simp only [Nat.and_one_is_mod, Nat.zero_mod, bne_self_eq_false, Bool.not_false, Bool.true_eq] + rw [msb_eq_true_iff_large] + omega + · rw [Nat.pow_succ] + have hpow : 2^w < 2^i := by + refine (Nat.pow_lt_pow_iff_right ?h).mpr hi + omega + omega +/-- +info: 'BitVec.testBit_toInt_eq_msb' depends on axioms: [propext, Classical.choice, Quot.sound] +-/ +#guard_msgs in #print axioms testBit_toInt_eq_msb + + +private theorem Int.mod_ofNat_eq (m : Nat) (n : Int) : + (Int.ofNat m) % n = Int.ofNat (m % Int.natAbs n) := rfl + +private theorem Int.mod_negSucc_eq (m : Nat) (n : Int) : + (Int.negSucc m) % n = Int.subNatNat (Int.natAbs n) (Nat.succ (m % Int.natAbs n)) := rfl + +/-- if the msb is false, the arithmetic shift right equals logical shift right -/ +theorem sshiftRight_eq_of_msb_false {x : BitVec w} {s : Nat} (h : x.msb = false) : + (x.sshiftRight s) = x.ushiftRight s := by + rcases w with rfl | w + · rw [Subsingleton.elim x (0#0)] + simp [ushiftRight, sshiftRight] + rfl + · apply BitVec.eq_of_toNat_eq + rw [BitVec.sshiftRight_eq] + rw [BitVec.toInt_eq_toNat_cond] + have hxbound : 2 * x.toNat < 2 ^ (w + 1) := BitVec.msb_eq_false_iff_small.mp h + simp [hxbound] + rw [Nat.mod_eq_of_lt] + rw [Nat.shiftRight_eq_div_pow] + suffices x.toNat / 2^s ≤ x.toNat by + apply Nat.lt_of_le_of_lt this x.isLt + apply Nat.div_le_self + +/-- if the msb is true, the arithmetic shift right equals negating, the logical shifting right, then negating again -/ +theorem sshiftRight_eq_of_msb_true {x : BitVec w} {s : Nat} (h : x.msb = true) : + (x.sshiftRight s) = ~~~((~~~x).ushiftRight s) := by + rcases w with rfl | w + · apply BitVec.eq_of_toNat_eq + simp + · apply BitVec.eq_of_toNat_eq + rw [BitVec.sshiftRight_eq] + rw [BitVec.toInt_eq_toNat_cond] + have hxbound : (2 * x.toNat ≥ 2 ^ (w + 1)) := BitVec.msb_eq_true_iff_large.mp h + have hxbound' : ¬ (2 * x.toNat < 2 ^ (w + 1)) := by omega + simp [hxbound'] + rw [Int.toNat_sub_toNat_eq_negSucc_ofLt (by omega)] + rw [Int.shiftRight_negSucc] + rw [Int.mod_negSucc_eq] + simp only [Int.natAbs_ofNat, Nat.succ_eq_add_one] + rw [Int.subNatNat_of_le] + simp + rw [Nat.mod_eq_of_lt] + omega + rw [Nat.shiftRight_eq_div_pow] + suffices (2 ^ (w + 1) - 1 - x.toNat) / 2 ^ s < 2 ^ (w + 1) - 1 by + omega + apply Nat.lt_of_le_of_lt (Nat.div_le_self _ _) (by omega) + omega + +theorem getLsb_sshiftRight (x : BitVec w) (s i : Nat) : + getLsb (x.sshiftRight s) i = + if i ≥ w then false + else if (s + i) < w then getLsb x (s + i) + else x.msb := by + rcases hmsb:x.msb with rfl | rfl + · simp [sshiftRight_eq_of_msb_false hmsb] + by_cases hi : i ≥ w <;> simp [hi] + · apply getLsb_ge + omega + · intros hlsb + apply BitVec.lt_of_getLsb _ _ hlsb + · by_cases hi:(i ≥ w) + · simp [hi] + · simp [hi, sshiftRight_eq_of_msb_true hmsb] <;> omega + +/-- info: 'BitVec.getLsb_sshiftRight' depends on axioms: [propext, Quot.sound, Classical.choice] -/ +#guard_msgs in #print axioms getLsb_sshiftRight + end BitVec diff --git a/src/Init/Data/Int/Bitwise.lean b/src/Init/Data/Int/Bitwise.lean index 2bcce0a8e8a8..d13d86f6c1c0 100644 --- a/src/Init/Data/Int/Bitwise.lean +++ b/src/Init/Data/Int/Bitwise.lean @@ -37,7 +37,7 @@ complement and shifts the value to the right. ```lean ( 0b0111:Int) >>> 1 = 0b0011 ( 0b1000:Int) >>> 1 = 0b0100 -(-0b1000:Int) >>> 1 = -0b0100 +(-0b1000:Int) >>> 1 = -0b010 (-0b0111:Int) >>> 1 = -0b0100 ``` -/ @@ -47,4 +47,17 @@ protected def shiftRight : Int → Nat → Int instance : HShiftRight Int Nat Int := ⟨.shiftRight⟩ +/- +### testBit +We define an operation for testing individual bits in the binary representation +of a number. +-/ + +-- -m = !m + 1 +-- -(m + 1) = -m - 1 = !m +/-- `testBit m n` returns whether the `(n+1)` least significant bit is `1` or `0`-/ +def testBit : Int → Nat → Bool + | .ofNat m, n => Nat.testBit m n + | .negSucc m, n => !(Nat.testBit m n) + end Int diff --git a/src/Init/Data/Int/Bitwise/Lemmas.lean b/src/Init/Data/Int/Bitwise/Lemmas.lean new file mode 100644 index 000000000000..d3edf650c62f --- /dev/null +++ b/src/Init/Data/Int/Bitwise/Lemmas.lean @@ -0,0 +1,98 @@ +/- +Copyright (c) 2023 Siddharth Bhat. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Siddharth Bhat, Jeremy Avigad +-- https://github.com/leanprover-community/mathlib4/blob/12b9e3064e35636d35f0c349bf2be29fb9b75bf2/Mathlib/Data/Int/Bitwise.lean#L490-L490 +-/ +prelude +import Init.Data.Nat.Bitwise.Basic +import Init.Data.Nat.Bitwise.Lemmas +import Init.Data.Int.Basic +import Init.Data.Int.Bitwise +import Init.Data.Int.Pow +import Init.Data.Bool +import Init.Data.Fin.Lemmas +import Init.Data.Nat.Lemmas +import Init.Omega.Int +import Init.Data.Int.DivMod +import Init.Data.Int.Order +import Init.Data.Nat.Dvd + +namespace Int + +theorem shiftRight_eq (n : Int) (s : Nat) : n >>> s = Int.shiftRight n s := rfl +@[simp] +theorem shiftRight_ofNat (n s : Nat) : (n : Int) >>> s = Int.ofNat (n >>> s) := rfl +theorem natCast_shiftRight (n s : Nat) : ((↑n) : Int) >>> s = n >>> s := rfl + +@[simp] +theorem shiftRight_negSucc (m : Nat) (n : Nat) : + -[m+1] >>> n = -[m >>>n +1] := rfl + +theorem Int.shiftRight_shiftRight (i : Int) (m n : Nat) : + i >>> m >>> n = i >>> (m + n) := by + cases i + case ofNat i => + simp [natCast_shiftRight, Nat.shiftRight_add] + case negSucc i => + simp [Int.shiftRight_negSucc, Nat.shiftRight_add] + +theorem shiftRight_eq_div_pow (m : Int) (n : Nat) : m >>> n = m / ((((2 : Nat) ^ n) : Nat) : Int) := by + rcases m + case ofNat m => + simp only [Int.ofNat_eq_coe, shiftRight_eq, Int.shiftRight, Nat.shiftRight_eq_div_pow] + simp [Int.natCast_pow] + case negSucc m => + rw [Int.shiftRight_negSucc] + rw [negSucc_ediv] + rw [Nat.shiftRight_eq_div_pow] + . norm_cast + · norm_cast + apply Nat.pow_pos + omega + +@[simp] +theorem zero_shiftRight (n : Nat) : (0 : Int) >>> n = 0 := by + simp [Int.shiftRight_eq_div_pow] + +@[simp] theorem zero_testBit (i : Nat) : Int.testBit 0 i = false := by + simp only [testBit, zero_shiftRight, Nat.zero_and, bne_self_eq_false, Nat.zero_testBit i] + + +private theorem Nat.mod2_cases (x : Nat) : (x % 2 = 0) ∨ (x % 2 = 1) := by omega +private theorem Int.mod2_cases (x : Int) : (x % 2 = 0) ∨ (x % 2 = 1) := by omega + +@[simp] theorem Int.mod2_ofNat_eq (x : Nat) : (Int.ofNat x % 2) = (x % 2) := by + simp [Int.mod_def'] + +@[simp] theorem Int.mod2_negSucc_eq (x : Nat) : (Int.negSucc x % 2) = (1 - x % 2) := by + simp only [mod_def'] + unfold Int.emod + simp only [subNatNat, Int.reduceAbs, Nat.succ_eq_add_one, Nat.reduceSubDiff, ofNat_eq_coe, + ofNat_emod, Nat.cast_ofNat_Int] + split <;> omega + +@[simp] theorem testBit_ofNat (x : Nat) (i : Nat) : (x : Int).testBit i = x.testBit i := rfl +@[simp] theorem testBit_negSucc (x : Nat) (i : Nat) : (Int.negSucc x).testBit i = !(x.testBit i) := rfl + +@[simp] theorem testBit_zero (x : Int) : Int.testBit x 0 = decide (x % 2 = 1) := by + rcases x with x | x + · simp only [ofNat_eq_coe, testBit_ofNat, Nat.testBit_zero, decide_eq_decide]; omega + · simp only [testBit_negSucc, Nat.testBit_zero, Int.mod2_negSucc_eq] + rcases (Nat.mod2_cases x) with h | h <;> simp_all <;> omega + +@[simp] theorem testBit_succ (x : Int) (i : Nat) : Int.testBit x (Nat.succ i) = testBit (x/2) i := by + unfold testBit + cases x <;> simp <;> rfl + + +-- theorem toNat_testBit (x i : Nat) : +-- (x.testBit i).toNat = x / 2 ^ i % 2 := by +-- rw [Nat.testBit_to_div_mod] +-- rcases Nat.mod_two_eq_zero_or_one (x / 2^i) <;> simp_all + +@[simp] theorem testBit_shiftRight (x : Int) (i j : Nat) : testBit (x >>> i) j = testBit x (i+j) := by + cases x <;> simp [testBit] + + +end Int