Skip to content

Commit 865159d

Browse files
authored
internal/bigmod: add more //go:norace annotations and refactoring
1 parent 0d56114 commit 865159d

File tree

3 files changed

+88
-66
lines changed

3 files changed

+88
-66
lines changed

internal/bigmod/nat.go

Lines changed: 52 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@ const (
1818
_S = _W / 8
1919
)
2020

21+
// Note: These functions make many loops over all the words in a Nat.
22+
// These loops used to be in assembly, invisible to -race, -asan, and -msan,
23+
// but now they are in Go and incur significant overhead in those modes.
24+
// To bring the old performance back, we mark all functions that loop
25+
// over Nat words with //go:norace. Because //go:norace does not
26+
// propagate across inlining, we must also mark functions that inline
27+
// //go:norace functions - specifically, those that inline add, addMulVVW,
28+
// assign, cmpGeq, rshift1, and sub.
29+
2130
// choice represents a constant-time boolean. The value of choice is always
2231
// either 1 or 0. We use an int instead of bool in order to make decisions in
2332
// constant time by turning it into a mask.
@@ -40,14 +49,6 @@ func ctEq(x, y uint) choice {
4049
return not(choice(c1 | c2))
4150
}
4251

43-
// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
44-
// function does not depend on its inputs.
45-
func ctGeq(x, y uint) choice {
46-
// If x < y, then x - y generates a carry.
47-
_, carry := bits.Sub(x, y, 0)
48-
return not(choice(carry))
49-
}
50-
5152
// Nat represents an arbitrary natural number
5253
//
5354
// Each Nat has an announced length, which is the number of limbs it has stored.
@@ -84,6 +85,7 @@ func (x *Nat) expand(n int) *Nat {
8485
return x
8586
}
8687
extraLimbs := x.limbs[len(x.limbs):n]
88+
// clear(extraLimbs)
8789
for i := range extraLimbs {
8890
extraLimbs[i] = 0
8991
}
@@ -97,6 +99,7 @@ func (x *Nat) reset(n int) *Nat {
9799
x.limbs = make([]uint, n)
98100
return x
99101
}
102+
// clear(x.limbs)
100103
for i := range x.limbs {
101104
x.limbs[i] = 0
102105
}
@@ -131,7 +134,7 @@ func (x *Nat) trim() *Nat {
131134
}
132135

133136
// set assigns x = y, optionally resizing x to the appropriate size.
134-
func (x *Nat) Set(y *Nat) *Nat {
137+
func (x *Nat) set(y *Nat) *Nat {
135138
x.reset(len(y.limbs))
136139
copy(x.limbs, y.limbs)
137140
return x
@@ -164,12 +167,14 @@ func (x *Nat) Bytes(m *Modulus) []byte {
164167
// SetBytes returns an error if b >= m.
165168
//
166169
// The output will be resized to the size of m and overwritten.
170+
//
171+
//go:norace
167172
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
168173
x.resetFor(m)
169174
if err := x.setBytes(b); err != nil {
170175
return nil, err
171176
}
172-
if x.CmpGeq(m.nat) == yes {
177+
if x.cmpGeq(m.nat) == yes {
173178
return nil, errors.New("input overflows the modulus")
174179
}
175180
return x, nil
@@ -195,20 +200,6 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
195200
return x, nil
196201
}
197202

198-
// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
199-
//
200-
// The output will be resized to the size of m and overwritten.
201-
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
202-
mMinusOne := NewNat().Set(m.nat)
203-
mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1
204-
one := NewNat().resetFor(m)
205-
one.limbs[0] = 1
206-
x.resetToBytes(b)
207-
x = NewNat().modNat(x, mMinusOne) // x = x mod (m-1)
208-
x.add(one) // we can safely add 1, no need to check overflow
209-
return x
210-
}
211-
212203
// bigEndianUint returns the contents of buf interpreted as a
213204
// big-endian encoded uint value.
214205
func bigEndianUint(buf []byte) uint {
@@ -309,8 +300,6 @@ func (x *Nat) IsMinusOne(m *Modulus) choice {
309300
}
310301

311302
// IsOdd returns 1 if x is odd, and 0 otherwise.
312-
//
313-
//go:norace
314303
func (x *Nat) IsOdd() choice {
315304
if len(x.limbs) == 0 {
316305
return no
@@ -333,12 +322,12 @@ func (x *Nat) TrailingZeroBitsVarTime() uint {
333322
return t
334323
}
335324

336-
// CmpGeq returns 1 if x >= y, and 0 otherwise.
325+
// cmpGeq returns 1 if x >= y, and 0 otherwise.
337326
//
338327
// Both operands must have the same announced length.
339328
//
340329
//go:norace
341-
func (x *Nat) CmpGeq(y *Nat) choice {
330+
func (x *Nat) cmpGeq(y *Nat) choice {
342331
// Eliminate bounds checks in the loop.
343332
size := len(x.limbs)
344333
xLimbs := x.limbs[:size]
@@ -564,6 +553,8 @@ func NewModulus(b []byte) (*Modulus, error) {
564553

565554
// NewModulusProduct creates a new Modulus from the product of two numbers
566555
// represented as big-endian byte slices. The result must be greater than one.
556+
//
557+
//go:norace
567558
func NewModulusProduct(a, b []byte) (*Modulus, error) {
568559
x := NewNat().resetToBytes(a)
569560
y := NewNat().resetToBytes(b)
@@ -602,30 +593,23 @@ func (m *Modulus) Nat() *Nat {
602593
// Make a copy so that the caller can't modify m.nat or alias it with
603594
// another Nat in a modulus operation.
604595
n := NewNat()
605-
n.Set(m.nat)
596+
n.set(m.nat)
606597
return n
607598
}
608599

609-
// shiftIn calculates x = x << _W + y mod m.
610-
//
611-
// This assumes that x is already reduced mod m, and that y < 2^_W.
612-
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
613-
return x.shiftInNat(y, m.nat)
614-
}
615-
616600
// shiftIn calculates x = x << _W + y mod m.
617601
//
618602
// This assumes that x is already reduced mod m, and that y < 2^_W.
619603
//
620604
//go:norace
621-
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
622-
d := NewNat().reset(len(m.limbs))
605+
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
606+
d := NewNat().resetFor(m)
623607

624608
// Eliminate bounds checks in the loop.
625-
size := len(m.limbs)
609+
size := len(m.nat.limbs)
626610
xLimbs := x.limbs[:size]
627611
dLimbs := d.limbs[:size]
628-
mLimbs := m.limbs[:size]
612+
mLimbs := m.nat.limbs[:size]
629613

630614
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
631615
// from y. Effectively, it left-shifts x and adds y one bit at a time,
@@ -657,17 +641,10 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
657641
// This works regardless how large the value of x is.
658642
//
659643
// The output will be resized to the size of m and overwritten.
660-
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
661-
return out.modNat(x, m.nat)
662-
}
663-
664-
// Mod calculates out = x mod m.
665644
//
666-
// This works regardless how large the value of x is.
667-
//
668-
// The output will be resized to the size of m and overwritten.
669-
func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
670-
out.reset(len(m.limbs))
645+
//go:norace
646+
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
647+
out.resetFor(m)
671648
// Working our way from the most significant to the least significant limb,
672649
// we can insert each limb at the least significant position, shifting all
673650
// previous limbs left by _W. This way each limb will get shifted by the
@@ -676,7 +653,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
676653
i := len(x.limbs) - 1
677654
// For the first N - 1 limbs we can skip the actual shifting and position
678655
// them at the shifted position, which starts at min(N - 2, i).
679-
start := len(m.limbs) - 2
656+
start := len(m.nat.limbs) - 2
680657
if i < start {
681658
start = i
682659
}
@@ -686,7 +663,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
686663
}
687664
// We shift in the remaining limbs, reducing modulo m each time.
688665
for i >= 0 {
689-
out.shiftInNat(x.limbs[i], m)
666+
out.shiftIn(x.limbs[i], m)
690667
i--
691668
}
692669
return out
@@ -715,8 +692,10 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
715692
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
716693
//
717694
// x and m operands must have the same announced length.
695+
//
696+
//go:norace
718697
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
719-
t := NewNat().Set(x)
698+
t := NewNat().set(x)
720699
underflow := t.sub(m.nat)
721700
// We keep the result if x - m didn't underflow (meaning x >= m)
722701
// or if always was set.
@@ -728,10 +707,12 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
728707
//
729708
// The length of both operands must be the same as the modulus. Both operands
730709
// must already be reduced modulo m.
710+
//
711+
//go:norace
731712
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
732713
underflow := x.sub(y)
733714
// If the subtraction underflowed, add m.
734-
t := NewNat().Set(x)
715+
t := NewNat().set(x)
735716
t.add(m.nat)
736717
x.assign(choice(underflow), t)
737718
return x
@@ -752,6 +733,8 @@ func (x *Nat) SubOne(m *Modulus) *Nat {
752733
//
753734
// The length of both operands must be the same as the modulus. Both operands
754735
// must already be reduced modulo m.
736+
//
737+
//go:norace
755738
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
756739
overflow := x.add(y)
757740
x.maybeSubtractModulus(choice(overflow), m)
@@ -789,6 +772,8 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
789772
//
790773
// All inputs should be the same length and already reduced modulo m.
791774
// x will be resized to the size of m and overwritten.
775+
//
776+
//go:norace
792777
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
793778
n := len(m.nat.limbs)
794779
mLimbs := m.nat.limbs[:n]
@@ -946,11 +931,13 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
946931
//
947932
// The length of both operands must be the same as the modulus. Both operands
948933
// must already be reduced modulo m.
934+
//
935+
//go:norace
949936
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
950937
if m.odd {
951938
// A Montgomery multiplication by a value out of the Montgomery domain
952939
// takes the result out of Montgomery representation.
953-
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
940+
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
954941
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
955942
}
956943
n := len(m.nat.limbs)
@@ -1009,6 +996,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
1009996
// to the size of m and overwritten. x must already be reduced modulo m.
1010997
//
1011998
// m must be odd, or Exp will panic.
999+
//
1000+
//go:norace
10121001
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
10131002
if !m.odd {
10141003
panic("bigmod: modulus for Exp must be odd")
@@ -1025,7 +1014,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
10251014
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
10261015
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
10271016
}
1028-
table[0].Set(x).montgomeryRepresentation(m)
1017+
table[0].set(x).montgomeryRepresentation(m)
10291018
for i := 1; i < len(table); i++ {
10301019
table[i].montgomeryMul(table[i-1], table[0], m)
10311020
}
@@ -1071,8 +1060,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
10711060
// For short exponents, precomputing a table and using a window like in Exp
10721061
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
10731062
// chain, skipping the initial run of zeroes.
1074-
xR := NewNat().Set(x).montgomeryRepresentation(m)
1075-
out.Set(xR)
1063+
xR := NewNat().set(x).montgomeryRepresentation(m)
1064+
out.set(xR)
10761065
for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ {
10771066
out.montgomeryMul(out, out, m)
10781067
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
@@ -1088,6 +1077,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
10881077
//
10891078
// a must be reduced modulo m, but doesn't need to have the same size. The
10901079
// output will be resized to the size of m and overwritten.
1080+
//
1081+
//go:norace
10911082
func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
10921083
// This is the extended binary GCD algorithm described in the Handbook of
10931084
// Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound
@@ -1121,7 +1112,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
11211112
return x, false
11221113
}
11231114

1124-
u := NewNat().Set(a).ExpandFor(m)
1115+
u := NewNat().set(a).ExpandFor(m)
11251116
v := m.Nat()
11261117

11271118
A := NewNat().reset(len(m.nat.limbs))
@@ -1148,7 +1139,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
11481139
// If both u and v are odd, subtract the smaller from the larger.
11491140
// If u = v, we need to subtract from v to hit the modified exit condition.
11501141
if u.IsOdd() == yes && v.IsOdd() == yes {
1151-
if v.CmpGeq(u) == no {
1142+
if v.cmpGeq(u) == no {
11521143
u.sub(v)
11531144
A.Add(C, m)
11541145
B.Add(D, &Modulus{nat: a})
@@ -1189,7 +1180,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
11891180
if u.IsOne() == no {
11901181
return x, false
11911182
}
1192-
return x.Set(A), true
1183+
return x.set(A), true
11931184
}
11941185
}
11951186
}

internal/bigmod/nat_extension.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package bigmod
2+
3+
func (x *Nat) Set(y *Nat) *Nat {
4+
return x.set(y)
5+
}
6+
7+
// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
8+
//
9+
// The output will be resized to the size of m and overwritten.
10+
//
11+
//go:norace
12+
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
13+
mMinusOne := NewNat().set(m.nat)
14+
mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1
15+
mMinusOneM, _ := NewModulus(mMinusOne.Bytes(m))
16+
one := NewNat().resetFor(m)
17+
one.limbs[0] = 1
18+
x.resetToBytes(b)
19+
x = NewNat().Mod(x, mMinusOneM) // x = x mod (m-1)
20+
x.add(one) // we can safely add 1, no need to check overflow
21+
return x
22+
}
23+
24+
// CmpGeq returns 1 if x >= y, and 0 otherwise.
25+
//
26+
// Both operands must have the same announced length.
27+
//
28+
//go:norace
29+
func (x *Nat) CmpGeq(y *Nat) choice {
30+
return x.cmpGeq(y)
31+
}

internal/bigmod/nat_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
6161

6262
func testModAddCommutative(a *Nat, b *Nat) bool {
6363
m := maxModulus(uint(len(a.limbs)))
64-
aPlusB := new(Nat).Set(a)
64+
aPlusB := new(Nat).set(a)
6565
aPlusB.Add(b, m)
66-
bPlusA := new(Nat).Set(b)
66+
bPlusA := new(Nat).set(b)
6767
bPlusA.Add(a, m)
6868
return aPlusB.Equal(bPlusA) == 1
6969
}
@@ -77,7 +77,7 @@ func TestModAddCommutative(t *testing.T) {
7777

7878
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
7979
m := maxModulus(uint(len(a.limbs)))
80-
original := new(Nat).Set(a)
80+
original := new(Nat).set(a)
8181
a.Sub(b, m)
8282
a.Add(b, m)
8383
return a.Equal(original) == 1
@@ -97,9 +97,9 @@ func TestMontgomeryRoundtrip(t *testing.T) {
9797
aPlusOne := new(big.Int).SetBytes(natBytes(a))
9898
aPlusOne.Add(aPlusOne, big.NewInt(1))
9999
m, _ := NewModulus(aPlusOne.Bytes())
100-
monty := new(Nat).Set(a)
100+
monty := new(Nat).set(a)
101101
monty.montgomeryRepresentation(m)
102-
aAgain := new(Nat).Set(monty)
102+
aAgain := new(Nat).set(monty)
103103
aAgain.montgomeryMul(monty, one, m)
104104
if a.Equal(aAgain) != 1 {
105105
t.Errorf("%v != %v", a, aAgain)

0 commit comments

Comments
 (0)