Skip to content

Commit

Permalink
Small integer arithmetics for values smaller that 64 bits.
Browse files Browse the repository at this point in the history
  • Loading branch information
markkurossi committed Dec 9, 2023
1 parent 9fe9317 commit cf5b992
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 40 deletions.
71 changes: 46 additions & 25 deletions compiler/mpa/mpint.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ func (z *Int) BitLen() int {
// or greater than x.
func (z *Int) Cmp(x *Int) int {
if z.small && x.small {
if z.i64 < x.i64 {
zi64 := z.Int64()
xi64 := x.Int64()
if zi64 < xi64 {
return -1
} else if z.i64 > x.i64 {
} else if zi64 > xi64 {
return 1
} else {
return 0
Expand Down Expand Up @@ -122,7 +124,12 @@ func (z *Int) signed(signBit int32) *big.Int {
// represented as int64, the result is undefined.
func (z *Int) Int64() int64 {
if z.small {
return z.i64
signBit := int64(0x1) << (z.bits - 1)
if z.bits == 64 || z.i64&signBit == 0 {
return z.i64
}
signBit <<= 1
return -(signBit - z.i64)
}
return z.values.Int64()
}
Expand All @@ -137,31 +144,33 @@ func (z *Int) String() string {
// Add sets z to x+y and returns z.
func (z *Int) Add(x, y *Int) *Int {
if x.small && y.small {
z.setShort(x.i64 + y.i64)
z.bits = max(x.bits, y.bits)
z.setSmall(x.i64 + y.i64)
return z
}
return z.bin(circuits.NewAdder, x, y)
}

// And sets z to x&y and returns z.
func (z *Int) And(x, y *Int) *Int {
z.bits = max(x.bits, y.bits)
if x.small && y.small {
z.setShort(x.i64 & y.i64)
z.setSmall(x.i64 & y.i64)
} else {
z.values = big.NewInt(0).And(x.big(), y.big())
z.small = false
}
z.bits = max(x.bits, y.bits)
return z
}

// Div sets z to x/y and returns z.
func (z *Int) Div(x, y *Int) *Int {
if x.small && y.small {
z.bits = max(x.bits, y.bits)
if y.i64 == 0 {
z.setShort(-1)
z.setSmall(-1)
} else {
z.setShort(x.i64 / y.i64)
z.setSmall(x.i64 / y.i64)
}
return z
}
Expand Down Expand Up @@ -220,7 +229,8 @@ func (z *Int) Div(x, y *Int) *Int {
// Lsh sets z to x<<n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int {
if x.small {
z.setShort(x.i64 << n)
z.bits = x.bits
z.setSmall(x.i64 << n)
return z
}
if z != x {
Expand All @@ -238,10 +248,11 @@ func (z *Int) Lsh(x *Int, n uint) *Int {
// Mod sets z to x%y and returns z.
func (z *Int) Mod(x, y *Int) *Int {
if x.small && y.small {
z.bits = max(x.bits, y.bits)
if y.i64 == 0 {
z.setShort(x.i64)
z.setSmall(x.i64)
} else {
z.setShort(x.i64 % y.i64)
z.setSmall(x.i64 % y.i64)
}
return z
}
Expand Down Expand Up @@ -300,7 +311,8 @@ func (z *Int) Mod(x, y *Int) *Int {
// Mul sets z to x*y and returns z.
func (z *Int) Mul(x, y *Int) *Int {
if x.small && y.small {
z.setShort(x.i64 * y.i64)
z.bits = max(x.bits, y.bits)
z.setSmall(x.i64 * y.i64)
return z
}
return z.bin(func(cc *circuits.Compiler, x, y, z []*circuits.Wire) error {
Expand All @@ -310,20 +322,21 @@ func (z *Int) Mul(x, y *Int) *Int {

// Or sets z to x|y and returns z.
func (z *Int) Or(x, y *Int) *Int {
z.bits = max(x.bits, y.bits)
if x.small && y.small {
z.setShort(x.i64 | y.i64)
z.setSmall(x.i64 | y.i64)
} else {
z.values = big.NewInt(0).Or(x.big(), y.big())
z.small = false
}
z.bits = max(x.bits, y.bits)
return z
}

// Rsh sets z to x>>n and returns z.
func (z *Int) Rsh(x *Int, n uint) *Int {
if x.small {
z.setShort(x.i64 >> n)
z.bits = x.bits
z.setSmall(x.i64 >> n)
return z
}
if z != x {
Expand All @@ -335,10 +348,10 @@ func (z *Int) Rsh(x *Int, n uint) *Int {
return z
}

// SetBig sets z to x and returns z.
func (z *Int) SetBig(x *big.Int) *Int {
func (z *Int) setBig(x *big.Int) *Int {
if x.IsInt64() {
z.setShort(x.Int64())
z.bits = 64
z.setSmall(x.Int64())
return z
}
z.bits = int32(x.BitLen())
Expand All @@ -350,10 +363,17 @@ func (z *Int) SetBig(x *big.Int) *Int {
return z
}

func (z *Int) setShort(x int64) {
z.i64 = x
z.values = nil
func (z *Int) setSmall(x int64) {
if z.bits > 64 {
panic(fmt.Sprintf("Int.setSmall: bits=%v > 64", z.bits))
}

mask := uint64(0xffffffffffffffff)
mask >>= 64 - z.bits
z.i64 = int64(uint64(x) & mask)

z.small = true
z.values = nil
}

// SetString sets z to s according to its ascii value. The argument
Expand All @@ -363,7 +383,7 @@ func (z *Int) SetString(s string, base int) (*Int, bool) {
if !ok {
return nil, false
}
z.SetBig(i)
z.setBig(i)
return z, true
}

Expand All @@ -384,21 +404,22 @@ func (z *Int) Sign() int {
// Sub sets z to x-y and returns z.
func (z *Int) Sub(x, y *Int) *Int {
if x.small && y.small {
z.setShort(x.i64 - y.i64)
z.bits = max(x.bits, y.bits)
z.setSmall(x.i64 - y.i64)
return z
}
return z.bin(circuits.NewSubtractor, x, y)
}

// Xor sets z to x^y and returns z.
func (z *Int) Xor(x, y *Int) *Int {
z.bits = max(x.bits, y.bits)
if x.small && y.small {
z.setShort(x.i64 ^ y.i64)
z.setSmall(x.i64 ^ y.i64)
} else {
z.values = big.NewInt(0).Xor(x.big(), y.big())
z.small = false
}
z.bits = max(x.bits, y.bits)
return z
}

Expand Down
Loading

0 comments on commit cf5b992

Please sign in to comment.