From ff253fcb7258f604ab28c15e55c7902b0a8d85f9 Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Mon, 14 Oct 2024 21:33:46 -0500 Subject: [PATCH] feat: toInt_abs We implement `toInt_abs`. A subtle wrinkle is to note that `abs (intMin w) = intMin w`, which complicates our proof. --- src/Init/Data/BitVec/Lemmas.lean | 237 +++++++++++++++++++++++++++++-- src/Init/Data/Int/Basic.lean | 7 + src/Init/Data/Int/Lemmas.lean | 24 ++++ 3 files changed, 258 insertions(+), 10 deletions(-) diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index ba6260e0d488..b5fb948dc5d8 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -206,6 +206,7 @@ theorem eq_of_getMsbD_eq {x y : BitVec w} theorem of_length_zero {x : BitVec 0} : x = 0#0 := by ext; simp theorem toNat_zero_length (x : BitVec 0) : x.toNat = 0 := by simp [of_length_zero] +theorem toInt_length_zero (x : BitVec 0) : x.toInt = 0 := by simp [of_length_zero] theorem getLsbD_zero_length (x : BitVec 0) : x.getLsbD i = false := by simp theorem getMsbD_zero_length (x : BitVec 0) : x.getMsbD i = false := by simp theorem msb_zero_length (x : BitVec 0) : x.msb = false := by simp [BitVec.msb, of_length_zero] @@ -2070,16 +2071,6 @@ theorem smod_zero {x : BitVec n} : x.smod 0#n = x := by · simp · by_cases h : x = 0#n <;> simp [h] -/-! ### abs -/ - -@[simp, bv_toNat] -theorem toNat_abs {x : BitVec w} : x.abs.toNat = if x.msb then 2^w - x.toNat else x.toNat := by - simp only [BitVec.abs, neg_eq] - by_cases h : x.msb = true - · simp only [h, ↓reduceIte, toNat_neg] - have : 2 * x.toNat ≥ 2 ^ w := BitVec.msb_eq_true_iff_two_mul_ge.mp h - rw [Nat.mod_eq_of_lt (by omega)] - · simp [h] /-! ### mul -/ @@ -2643,6 +2634,23 @@ theorem toInt_neg_of_ne_intMin {x : BitVec w} (rs : x ≠ intMin w) : have := @Nat.two_pow_pred_mul_two w (by omega) split <;> split <;> omega + +/-- The msb of `intMin w` is `true` for all `w > 0` -/ +@[simp] +theorem msb_intMin : (intMin w).msb = decide (w > 0) := by + rw [intMin] + rw [msb_eq_decide] + simp + rcases w with rfl | w + · rfl + · simp + have : 0 < 2^w := Nat.pow_pos (by decide) + have : 2^w < 2^(w + 1) := by + rw [Nat.pow_succ] + omega + rw [Nat.mod_eq_of_lt (by omega)] + simp + /-! ### intMax -/ /-- The bitvector of width `w` that has the largest value when interpreted as an integer. -/ @@ -2674,6 +2682,215 @@ theorem getLsbD_intMax (w : Nat) : (intMax w).getLsbD i = decide (i + 1 < w) := · rw [Nat.sub_add_cancel (Nat.two_pow_pos (w - 1)), Nat.two_pow_pred_mod_two_pow (by omega)] +/-! ### abs -/ + +theorem abs_def {x : BitVec w} : x.abs = if x.msb then .neg x else x := rfl + +-- @[simp] +-- theorem abs_intMin : (intMin w).abs = intMin w := by +-- rw [abs_def] +-- simp [msb_intMin] + + +-- @[simp] theorem toInt_zero (w : Nat) : (0#w).toInt = 0 := by +-- simp [BitVec.toInt] +-- omega + +/-- the msb is true iff the bitvector , when interpreted as a signed 2s complement number, is less than zero -/ +theorem msb_eq_decide_slt_zero (x : BitVec w) : x.msb = decide (x.slt 0#w) := by + simp only [BitVec.slt, toInt_eq_msb_cond, msb_zero, Bool.false_eq_true, + ↓reduceIte, toNat_ofNat, Nat.zero_mod, Int.Nat.cast_ofNat_Int, decide_eq_true_eq] + rcases h : x.msb <;> simp [h] <;> omega + + +/-- TODO: what should I name this lemmas? -/ +theorem abs_cases (x : BitVec w) : x.abs = + if x = intMin w then (intMin w) + else if x.slt 0 then -x + else x := by + · rw [abs_def] + rw [msb_eq_decide_slt_zero] + by_cases hx : x.slt 0#w <;> by_cases hx' : x = intMin w <;> simp [hx, hx'] + + +@[simp] +theorem toInt_of_length_zero (x : BitVec 0) : x.toInt = 0 := by + simp [BitVec.of_length_zero] + + +/- +Similar to toInt_eq_toNat_cond, but rewrites in terms of power of two manipulations, +instead of ugly `2 * x < 2^n`. +-/ +@[simp] +theorem toInt_eq_toNat_cond' (x : BitVec n) (hn : n > 0) : + x.toInt = + if x.toNat < 2^(n - 1) then + (x.toNat : Int) + else + (x.toNat : Int) - (2^n : Nat) := by + rcases n with _ | n <;> try contradiction + simp only [gt_iff_lt, Nat.zero_lt_succ, Nat.add_one_sub_one] at * + rw [BitVec.toInt_eq_toNat_cond] + simp only [Nat.pow_succ, Int.natCast_mul, Int.Nat.cast_ofNat_Int] + by_cases hx : x.toNat < 2 ^ n + · simp [show 2 * x.toNat < 2 ^ n * 2 by omega, hx] + · simp [show ¬ 2 * x.toNat < 2^ n * 2 by omega, hx] + +/-- info: 'BitVec.toInt_eq_toNat_cond'' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms BitVec.toInt_eq_toNat_cond' + +-- TODO: make this the theorem, because this does not create 2^(w - 1) +-- nonsense. +-- TODO: make `msb` the simp normal form for checking if number is positive +-- or whatever. +@[bv_toNat] theorem msb_eq_decide' (x : BitVec w) : + BitVec.msb x = decide (2 ^ w ≤ 2 * x.toNat) := by + rw [x.msb_eq_getLsbD_last, x.getLsbD_last] + simp + rcases w with rfl | w <;> simp <;> omega + + +/-- info: 'BitVec.msb_eq_decide'' depends on axioms: [propext, Classical.choice, Quot.sound] -/ +#guard_msgs in #print axioms BitVec.msb_eq_decide' + +/- +Next thing we want to know: bounds on the value of `x.toInt` +-/ +theorem toInt_bounds_of_msb_eq_false {x : BitVec n} (hmsb : x.msb = false) : + 0 ≤ x.toInt ∧ 2 * x.toInt < 2^n := by + have := x.msb_eq_decide' + rw [hmsb] at this + simp only [false_eq_decide_iff, Nat.not_le] at this + rw [BitVec.toInt_eq_toNat_cond] + simp [this] + apply And.intro + · omega + · norm_cast + +/-- +info: 'BitVec.toInt_bounds_of_msb_eq_false' depends on axioms: [propext, Classical.choice, Quot.sound] +-/ +#guard_msgs in #print axioms BitVec.toInt_bounds_of_msb_eq_false + +theorem toInt_bounds_of_msb_eq_true {x : BitVec n} (hmsb : x.msb = true) : + -2^n ≤ x.toInt ∧ x.toInt < 0 := by + have := x.msb_eq_decide' + rw [hmsb] at this + simp only [true_eq_decide_iff] at this + rw [BitVec.toInt_eq_toNat_cond] + simp [show ¬ 2 * x.toNat < 2 ^ n by omega] + apply And.intro + · norm_cast + omega + · omega + +/-- +info: 'BitVec.toInt_bounds_of_msb_eq_true' depends on axioms: [propext, Classical.choice, Quot.sound] +-/ +#guard_msgs in #print axioms BitVec.toInt_bounds_of_msb_eq_true + +theorem toInt_intMin_eq_cases (n : Nat) : (BitVec.intMin n).toInt = + if n = 0 then 0 else - 2^(n - 1) := by + simp [BitVec.toInt_intMin] + rcases n with rfl | n + · simp + · simp + norm_cast + have : 2^n > 0 := by exact Nat.two_pow_pos n + have : 2^n < 2^(n + 1) := by + simp [Nat.pow_succ] + omega + rw [Nat.mod_eq_of_lt (by omega)] + +/-- info: 'BitVec.toInt_intMin_eq_cases' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms BitVec.toInt_intMin_eq_cases + +-- ### TOINT OF NEG +/-- +Define the value of (BitVec.neg.toInt) as a case split +on whether `x` is intMin or not, and showing that when this +exception does not occur, the defn obeys what mathematics says it should +-/ +theorem toInt_neg_eq_cases {x : BitVec n} : + (-x).toInt = + if x = intMin n + then x.toInt + else - x.toInt := by + by_cases hx : x = intMin n + · simp [hx] + · simp [hx] + rw [toInt_neg_of_ne_intMin hx] + +/-- info: 'BitVec.toInt_neg_eq_cases' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms BitVec.toInt_neg_eq_cases + +-- @[simp] +-- theorem Int.abs_neg (x : Int) : (-x).abs = x.abs := by +-- have hx : (-x < 0) ∨ (x = 0) ∨ (-x > 0) := by omega +-- rcases hx with hx | hx | hx +-- · rw [Int.abs_eq_neg (x := -x) hx, Int.neg_neg, Int.abs_eq_self (x := x) (by omega)] +-- · simp [hx] +-- · rw [Int.abs_eq_self (x := -x) (by omega), Int.abs_eq_neg (x := x) (by omega)] +-- +-- +-- /-- info: 'Int.abs_neg' depends on axioms: [propext, Quot.sound] -/ +-- #guard_msgs in #print axioms Int.abs_neg + + +theorem abs_cases' (x : BitVec w) : x.abs = + if x.msb = true then + if x = BitVec.intMin w then (BitVec.intMin w) else -x + else x := by + · rw [BitVec.abs_def] + by_cases hx : x.msb = true <;> by_cases hx' : x = BitVec.intMin w <;> simp [hx, hx'] + +/-- info: 'BitVec.abs_cases'' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms BitVec.abs_cases' + + + +theorem toInt_intMin_eq_twoPow (hn : 0 < n) : (intMin n).toInt = -2^(n - 1) := by + -- Delete our toInt_intMin from simp set. + rw [BitVec.toInt_intMin_eq_cases] + simp [show ¬ n = 0 by omega] + +/-- info: 'BitVec.toInt_intMin_eq_twoPow' depends on axioms: [propext, Quot.sound] -/ +#guard_msgs in #print axioms BitVec.toInt_intMin_eq_twoPow + +theorem toInt_abs (x : BitVec w) : + x.abs.toInt = if x = (intMin w) then if w = 0 then 0 else - 2^(w - 1) else x.toInt.abs := by + rcases w with rfl | w + · simp + · simp only [gt_iff_lt, Nat.zero_lt_succ, Nat.add_one_ne_zero, ↓reduceIte] + rw [BitVec.abs_cases'] + by_cases hx : x = intMin (w + 1) + · simp only [hx, reduceIte] + have := BitVec.msb_intMin (w := w + 1) + rw [this] + simp only [gt_iff_lt, Nat.zero_lt_succ, decide_True, ↓reduceIte] + rw [BitVec.toInt_intMin_eq_cases] + simp + · simp only [hx, reduceIte] + rcases hmsb : x.msb + · simp only [Bool.false_eq_true, ↓reduceIte] + have := BitVec.toInt_bounds_of_msb_eq_false hmsb + rw [Int.abs_eq_self] + omega + · simp only [reduceIte] + have hxbounds := BitVec.toInt_bounds_of_msb_eq_true hmsb + rw [BitVec.toInt_neg_eq_cases] + have hxneq : x.toInt ≠ (intMin (w + 1)).toInt := by + rw [BitVec.toInt_ne] + exact hx + rw [BitVec.toInt_intMin_eq_twoPow (by omega)] at hxneq + -- TODO: remove toInt_eq_toNat_cond from simp set. + simp only [hx, reduceIte] + rw [Int.abs_eq_neg (by omega)] + +/-- info: 'BitVec.toInt_abs' depends on axioms: [propext, Classical.choice, Quot.sound] -/ +#guard_msgs in #print axioms BitVec.toInt_abs + /-! ### Non-overflow theorems -/ /-- If `x.toNat * y.toNat < 2^w`, then the multiplication `(x * y)` does not overflow. -/ diff --git a/src/Init/Data/Int/Basic.lean b/src/Init/Data/Int/Basic.lean index dbf661c4be1b..fed9719af55b 100644 --- a/src/Init/Data/Int/Basic.lean +++ b/src/Init/Data/Int/Basic.lean @@ -333,6 +333,13 @@ instance : Min Int := minOfLe instance : Max Int := maxOfLe +/-- +Return the absolute value of an integer. +-/ +def abs : Int → Int + | ofNat n => .ofNat n + | negSucc n => .ofNat n.succ + end Int /-- diff --git a/src/Init/Data/Int/Lemmas.lean b/src/Init/Data/Int/Lemmas.lean index 4b0e560fb00d..4b555d825373 100644 --- a/src/Init/Data/Int/Lemmas.lean +++ b/src/Init/Data/Int/Lemmas.lean @@ -531,4 +531,28 @@ theorem natCast_one : ((1 : Nat) : Int) = (1 : Int) := rfl @[simp] theorem natCast_mul (a b : Nat) : ((a * b : Nat) : Int) = (a : Int) * (b : Int) := by simp +/-! abs lemmas -/ + +@[simp] +theorem abs_eq_self {x : Int} (h : x ≥ 0) : x.abs = x := by + cases x + case ofNat h => + rfl + case negSucc h => + contradiction + +@[simp] +theorem Int.abs_zero : Int.abs 0 = 0 := rfl + +@[simp] +theorem abs_eq_neg {x : Int} (h : x < 0) : x.abs = -x := by + cases x + case ofNat h => + contradiction + case negSucc n => + rfl + +@[simp] +theorem ofNat_abs (x : Nat) : (x : Int).abs = (x : Int) := rfl + end Int