Skip to content

Commit cdbb789

Browse files
committed
fix: UB in shifts
1 parent 1effaa3 commit cdbb789

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

src/Init/Data/SInt/Basic.lean

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def Int8.lor (a b : Int8) : Int8 := ⟨⟨a.toBitVec ||| b.toBitVec⟩⟩
9696
@[extern "lean_int8_xor"]
9797
def Int8.xor (a b : Int8) : Int8 := ⟨⟨a.toBitVec ^^^ b.toBitVec⟩⟩
9898
@[extern "lean_int8_shift_left"]
99-
def Int8.shiftLeft (a b : Int8) : Int8 := ⟨⟨a.toBitVec <<< (mod b 8).toBitVec⟩⟩
99+
def Int8.shiftLeft (a b : Int8) : Int8 := ⟨⟨a.toBitVec <<< (b.toBitVec.smod 8)⟩⟩
100100
@[extern "lean_int8_shift_right"]
101-
def Int8.shiftRight (a b : Int8) : Int8 := ⟨⟨BitVec.sshiftRight' a.toBitVec (mod b 8).toBitVec⟩⟩
101+
def Int8.shiftRight (a b : Int8) : Int8 := ⟨⟨BitVec.sshiftRight' a.toBitVec (b.toBitVec.smod 8)⟩⟩
102102
@[extern "lean_int8_complement"]
103103
def Int8.complement (a : Int8) : Int8 := ⟨⟨~~~a.toBitVec⟩⟩
104104

@@ -193,9 +193,9 @@ def Int16.lor (a b : Int16) : Int16 := ⟨⟨a.toBitVec ||| b.toBitVec⟩⟩
193193
@[extern "lean_int16_xor"]
194194
def Int16.xor (a b : Int16) : Int16 := ⟨⟨a.toBitVec ^^^ b.toBitVec⟩⟩
195195
@[extern "lean_int16_shift_left"]
196-
def Int16.shiftLeft (a b : Int16) : Int16 := ⟨⟨a.toBitVec <<< (mod b 16).toBitVec⟩⟩
196+
def Int16.shiftLeft (a b : Int16) : Int16 := ⟨⟨a.toBitVec <<< (b.toBitVec.smod 16)⟩⟩
197197
@[extern "lean_int16_shift_right"]
198-
def Int16.shiftRight (a b : Int16) : Int16 := ⟨⟨BitVec.sshiftRight' a.toBitVec (mod b 16).toBitVec⟩⟩
198+
def Int16.shiftRight (a b : Int16) : Int16 := ⟨⟨BitVec.sshiftRight' a.toBitVec (b.toBitVec.smod 16)⟩⟩
199199
@[extern "lean_int16_complement"]
200200
def Int16.complement (a : Int16) : Int16 := ⟨⟨~~~a.toBitVec⟩⟩
201201

@@ -294,9 +294,9 @@ def Int32.lor (a b : Int32) : Int32 := ⟨⟨a.toBitVec ||| b.toBitVec⟩⟩
294294
@[extern "lean_int32_xor"]
295295
def Int32.xor (a b : Int32) : Int32 := ⟨⟨a.toBitVec ^^^ b.toBitVec⟩⟩
296296
@[extern "lean_int32_shift_left"]
297-
def Int32.shiftLeft (a b : Int32) : Int32 := ⟨⟨a.toBitVec <<< (mod b 32).toBitVec⟩⟩
297+
def Int32.shiftLeft (a b : Int32) : Int32 := ⟨⟨a.toBitVec <<< (b.toBitVec.smod 32)⟩⟩
298298
@[extern "lean_int32_shift_right"]
299-
def Int32.shiftRight (a b : Int32) : Int32 := ⟨⟨BitVec.sshiftRight' a.toBitVec (mod b 32).toBitVec⟩⟩
299+
def Int32.shiftRight (a b : Int32) : Int32 := ⟨⟨BitVec.sshiftRight' a.toBitVec (b.toBitVec.smod 32)⟩⟩
300300
@[extern "lean_int32_complement"]
301301
def Int32.complement (a : Int32) : Int32 := ⟨⟨~~~a.toBitVec⟩⟩
302302

@@ -399,9 +399,9 @@ def Int64.lor (a b : Int64) : Int64 := ⟨⟨a.toBitVec ||| b.toBitVec⟩⟩
399399
@[extern "lean_int64_xor"]
400400
def Int64.xor (a b : Int64) : Int64 := ⟨⟨a.toBitVec ^^^ b.toBitVec⟩⟩
401401
@[extern "lean_int64_shift_left"]
402-
def Int64.shiftLeft (a b : Int64) : Int64 := ⟨⟨a.toBitVec <<< (mod b 64).toBitVec⟩⟩
402+
def Int64.shiftLeft (a b : Int64) : Int64 := ⟨⟨a.toBitVec <<< (b.toBitVec.smod 64)⟩⟩
403403
@[extern "lean_int64_shift_right"]
404-
def Int64.shiftRight (a b : Int64) : Int64 := ⟨⟨BitVec.sshiftRight' a.toBitVec (mod b 64).toBitVec⟩⟩
404+
def Int64.shiftRight (a b : Int64) : Int64 := ⟨⟨BitVec.sshiftRight' a.toBitVec (b.toBitVec.smod 64)⟩⟩
405405
@[extern "lean_int64_complement"]
406406
def Int64.complement (a : Int64) : Int64 := ⟨⟨~~~a.toBitVec⟩⟩
407407

src/include/lean/lean.h

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,16 +1954,16 @@ static inline uint8_t lean_int8_xor(uint8_t a1, uint8_t a2) {
19541954

19551955
static inline uint8_t lean_int8_shift_right(uint8_t a1, uint8_t a2) {
19561956
int8_t lhs = (int8_t)a1;
1957-
int8_t rhs = (int8_t)a2;
1957+
int8_t rhs = (((int8_t)a2 % 8) + 8) % 8; // this is smod 8
19581958

1959-
return (uint8_t)(lhs >> (rhs % 8));
1959+
return (uint8_t)(lhs >> rhs);
19601960
}
19611961

19621962
static inline uint8_t lean_int8_shift_left(uint8_t a1, uint8_t a2) {
19631963
int8_t lhs = (int8_t)a1;
1964-
int8_t rhs = (int8_t)a2;
1964+
int8_t rhs = (((int8_t)a2 % 8) + 8) % 8; // this is smod 8
19651965

1966-
return (uint8_t)(lhs << (rhs % 8));
1966+
return (uint8_t)(lhs << rhs);
19671967
}
19681968

19691969
static inline uint8_t lean_int8_complement(uint8_t a) {
@@ -2094,16 +2094,16 @@ static inline uint16_t lean_int16_xor(uint16_t a1, uint16_t a2) {
20942094

20952095
static inline uint16_t lean_int16_shift_right(uint16_t a1, uint16_t a2) {
20962096
int16_t lhs = (int16_t)a1;
2097-
int16_t rhs = (int16_t)a2;
2097+
int16_t rhs = (((int16_t)a2 % 16) + 16) % 16; // this is smod 16
20982098

2099-
return (uint16_t)(lhs >> (rhs % 16));
2099+
return (uint16_t)(lhs >> rhs);
21002100
}
21012101

21022102
static inline uint16_t lean_int16_shift_left(uint16_t a1, uint16_t a2) {
21032103
int16_t lhs = (int16_t)a1;
2104-
int16_t rhs = (int16_t)a2;
2104+
int16_t rhs = (((int16_t)a2 % 16) + 16) % 16; // this is smod 16
21052105

2106-
return (uint16_t)(lhs << (rhs % 16));
2106+
return (uint16_t)(lhs << rhs);
21072107
}
21082108

21092109
static inline uint16_t lean_int16_complement(uint16_t a) {
@@ -2233,16 +2233,16 @@ static inline uint32_t lean_int32_xor(uint32_t a1, uint32_t a2) {
22332233

22342234
static inline uint32_t lean_int32_shift_right(uint32_t a1, uint32_t a2) {
22352235
int32_t lhs = (int32_t)a1;
2236-
int32_t rhs = (int32_t)a2;
2236+
int32_t rhs = (((int32_t)a2 % 32) + 32) % 32; // this is smod 32
22372237

2238-
return (uint32_t)(lhs >> (rhs % 32));
2238+
return (uint32_t)(lhs >> rhs);
22392239
}
22402240

22412241
static inline uint32_t lean_int32_shift_left(uint32_t a1, uint32_t a2) {
22422242
int32_t lhs = (int32_t)a1;
2243-
int32_t rhs = (int32_t)a2;
2243+
int32_t rhs = (((int32_t)a2 % 32) + 32) % 32; // this is smod 32
22442244

2245-
return (uint32_t)(lhs << (rhs % 32));
2245+
return (uint32_t)(lhs << rhs);
22462246
}
22472247

22482248
static inline uint32_t lean_int32_complement(uint32_t a) {
@@ -2372,16 +2372,16 @@ static inline uint64_t lean_int64_xor(uint64_t a1, uint64_t a2) {
23722372

23732373
static inline uint64_t lean_int64_shift_right(uint64_t a1, uint64_t a2) {
23742374
int64_t lhs = (int64_t)a1;
2375-
int64_t rhs = (int64_t)a2;
2375+
int64_t rhs = (((int64_t)a2 % 64) + 64) % 64; // this is smod 64
23762376

2377-
return (uint64_t)(lhs >> (rhs % 64));
2377+
return (uint64_t)(lhs >> rhs);
23782378
}
23792379

23802380
static inline uint64_t lean_int64_shift_left(uint64_t a1, uint64_t a2) {
23812381
int64_t lhs = (int64_t)a1;
2382-
int64_t rhs = (int64_t)a2;
2382+
int64_t rhs = (((int64_t)a2 % 64) + 64) % 64; // this is smod 64
23832383

2384-
return (uint64_t)(lhs << (rhs % 64));
2384+
return (uint64_t)(lhs << rhs);
23852385
}
23862386

23872387
static inline uint64_t lean_int64_complement(uint64_t a) {

0 commit comments

Comments
 (0)