@@ -18,6 +18,15 @@ const (
18
18
_S = _W / 8
19
19
)
20
20
21
+ // Note: These functions make many loops over all the words in a Nat.
22
+ // These loops used to be in assembly, invisible to -race, -asan, and -msan,
23
+ // but now they are in Go and incur significant overhead in those modes.
24
+ // To bring the old performance back, we mark all functions that loop
25
+ // over Nat words with //go:norace. Because //go:norace does not
26
+ // propagate across inlining, we must also mark functions that inline
27
+ // //go:norace functions - specifically, those that inline add, addMulVVW,
28
+ // assign, cmpGeq, rshift1, and sub.
29
+
21
30
// choice represents a constant-time boolean. The value of choice is always
22
31
// either 1 or 0. We use an int instead of bool in order to make decisions in
23
32
// constant time by turning it into a mask.
@@ -40,14 +49,6 @@ func ctEq(x, y uint) choice {
40
49
return not (choice (c1 | c2 ))
41
50
}
42
51
43
- // ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this
44
- // function does not depend on its inputs.
45
- func ctGeq (x , y uint ) choice {
46
- // If x < y, then x - y generates a carry.
47
- _ , carry := bits .Sub (x , y , 0 )
48
- return not (choice (carry ))
49
- }
50
-
51
52
// Nat represents an arbitrary natural number
52
53
//
53
54
// Each Nat has an announced length, which is the number of limbs it has stored.
@@ -84,6 +85,7 @@ func (x *Nat) expand(n int) *Nat {
84
85
return x
85
86
}
86
87
extraLimbs := x .limbs [len (x .limbs ):n ]
88
+ // clear(extraLimbs)
87
89
for i := range extraLimbs {
88
90
extraLimbs [i ] = 0
89
91
}
@@ -97,6 +99,7 @@ func (x *Nat) reset(n int) *Nat {
97
99
x .limbs = make ([]uint , n )
98
100
return x
99
101
}
102
+ // clear(x.limbs)
100
103
for i := range x .limbs {
101
104
x .limbs [i ] = 0
102
105
}
@@ -131,7 +134,7 @@ func (x *Nat) trim() *Nat {
131
134
}
132
135
133
136
// set assigns x = y, optionally resizing x to the appropriate size.
134
- func (x * Nat ) Set (y * Nat ) * Nat {
137
+ func (x * Nat ) set (y * Nat ) * Nat {
135
138
x .reset (len (y .limbs ))
136
139
copy (x .limbs , y .limbs )
137
140
return x
@@ -164,12 +167,14 @@ func (x *Nat) Bytes(m *Modulus) []byte {
164
167
// SetBytes returns an error if b >= m.
165
168
//
166
169
// The output will be resized to the size of m and overwritten.
170
+ //
171
+ //go:norace
167
172
func (x * Nat ) SetBytes (b []byte , m * Modulus ) (* Nat , error ) {
168
173
x .resetFor (m )
169
174
if err := x .setBytes (b ); err != nil {
170
175
return nil , err
171
176
}
172
- if x .CmpGeq (m .nat ) == yes {
177
+ if x .cmpGeq (m .nat ) == yes {
173
178
return nil , errors .New ("input overflows the modulus" )
174
179
}
175
180
return x , nil
@@ -195,20 +200,6 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
195
200
return x , nil
196
201
}
197
202
198
- // SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
199
- //
200
- // The output will be resized to the size of m and overwritten.
201
- func (x * Nat ) SetOverflowedBytes (b []byte , m * Modulus ) * Nat {
202
- mMinusOne := NewNat ().Set (m .nat )
203
- mMinusOne .limbs [0 ]-- // due to m is odd, so we can safely subtract 1
204
- one := NewNat ().resetFor (m )
205
- one .limbs [0 ] = 1
206
- x .resetToBytes (b )
207
- x = NewNat ().modNat (x , mMinusOne ) // x = x mod (m-1)
208
- x .add (one ) // we can safely add 1, no need to check overflow
209
- return x
210
- }
211
-
212
203
// bigEndianUint returns the contents of buf interpreted as a
213
204
// big-endian encoded uint value.
214
205
func bigEndianUint (buf []byte ) uint {
@@ -309,8 +300,6 @@ func (x *Nat) IsMinusOne(m *Modulus) choice {
309
300
}
310
301
311
302
// IsOdd returns 1 if x is odd, and 0 otherwise.
312
- //
313
- //go:norace
314
303
func (x * Nat ) IsOdd () choice {
315
304
if len (x .limbs ) == 0 {
316
305
return no
@@ -333,12 +322,12 @@ func (x *Nat) TrailingZeroBitsVarTime() uint {
333
322
return t
334
323
}
335
324
336
- // CmpGeq returns 1 if x >= y, and 0 otherwise.
325
+ // cmpGeq returns 1 if x >= y, and 0 otherwise.
337
326
//
338
327
// Both operands must have the same announced length.
339
328
//
340
329
//go:norace
341
- func (x * Nat ) CmpGeq (y * Nat ) choice {
330
+ func (x * Nat ) cmpGeq (y * Nat ) choice {
342
331
// Eliminate bounds checks in the loop.
343
332
size := len (x .limbs )
344
333
xLimbs := x .limbs [:size ]
@@ -564,6 +553,8 @@ func NewModulus(b []byte) (*Modulus, error) {
564
553
565
554
// NewModulusProduct creates a new Modulus from the product of two numbers
566
555
// represented as big-endian byte slices. The result must be greater than one.
556
+ //
557
+ //go:norace
567
558
func NewModulusProduct (a , b []byte ) (* Modulus , error ) {
568
559
x := NewNat ().resetToBytes (a )
569
560
y := NewNat ().resetToBytes (b )
@@ -602,30 +593,23 @@ func (m *Modulus) Nat() *Nat {
602
593
// Make a copy so that the caller can't modify m.nat or alias it with
603
594
// another Nat in a modulus operation.
604
595
n := NewNat ()
605
- n .Set (m .nat )
596
+ n .set (m .nat )
606
597
return n
607
598
}
608
599
609
- // shiftIn calculates x = x << _W + y mod m.
610
- //
611
- // This assumes that x is already reduced mod m, and that y < 2^_W.
612
- func (x * Nat ) shiftIn (y uint , m * Modulus ) * Nat {
613
- return x .shiftInNat (y , m .nat )
614
- }
615
-
616
600
// shiftIn calculates x = x << _W + y mod m.
617
601
//
618
602
// This assumes that x is already reduced mod m, and that y < 2^_W.
619
603
//
620
604
//go:norace
621
- func (x * Nat ) shiftInNat (y uint , m * Nat ) * Nat {
622
- d := NewNat ().reset ( len ( m . limbs ) )
605
+ func (x * Nat ) shiftIn (y uint , m * Modulus ) * Nat {
606
+ d := NewNat ().resetFor ( m )
623
607
624
608
// Eliminate bounds checks in the loop.
625
- size := len (m .limbs )
609
+ size := len (m .nat . limbs )
626
610
xLimbs := x .limbs [:size ]
627
611
dLimbs := d .limbs [:size ]
628
- mLimbs := m .limbs [:size ]
612
+ mLimbs := m .nat . limbs [:size ]
629
613
630
614
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
631
615
// from y. Effectively, it left-shifts x and adds y one bit at a time,
@@ -657,17 +641,10 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
657
641
// This works regardless how large the value of x is.
658
642
//
659
643
// The output will be resized to the size of m and overwritten.
660
- func (out * Nat ) Mod (x * Nat , m * Modulus ) * Nat {
661
- return out .modNat (x , m .nat )
662
- }
663
-
664
- // Mod calculates out = x mod m.
665
644
//
666
- // This works regardless how large the value of x is.
667
- //
668
- // The output will be resized to the size of m and overwritten.
669
- func (out * Nat ) modNat (x * Nat , m * Nat ) * Nat {
670
- out .reset (len (m .limbs ))
645
+ //go:norace
646
+ func (out * Nat ) Mod (x * Nat , m * Modulus ) * Nat {
647
+ out .resetFor (m )
671
648
// Working our way from the most significant to the least significant limb,
672
649
// we can insert each limb at the least significant position, shifting all
673
650
// previous limbs left by _W. This way each limb will get shifted by the
@@ -676,7 +653,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
676
653
i := len (x .limbs ) - 1
677
654
// For the first N - 1 limbs we can skip the actual shifting and position
678
655
// them at the shifted position, which starts at min(N - 2, i).
679
- start := len (m .limbs ) - 2
656
+ start := len (m .nat . limbs ) - 2
680
657
if i < start {
681
658
start = i
682
659
}
@@ -686,7 +663,7 @@ func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
686
663
}
687
664
// We shift in the remaining limbs, reducing modulo m each time.
688
665
for i >= 0 {
689
- out .shiftInNat (x .limbs [i ], m )
666
+ out .shiftIn (x .limbs [i ], m )
690
667
i --
691
668
}
692
669
return out
@@ -715,8 +692,10 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
715
692
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
716
693
//
717
694
// x and m operands must have the same announced length.
695
+ //
696
+ //go:norace
718
697
func (x * Nat ) maybeSubtractModulus (always choice , m * Modulus ) {
719
- t := NewNat ().Set (x )
698
+ t := NewNat ().set (x )
720
699
underflow := t .sub (m .nat )
721
700
// We keep the result if x - m didn't underflow (meaning x >= m)
722
701
// or if always was set.
@@ -728,10 +707,12 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
728
707
//
729
708
// The length of both operands must be the same as the modulus. Both operands
730
709
// must already be reduced modulo m.
710
+ //
711
+ //go:norace
731
712
func (x * Nat ) Sub (y * Nat , m * Modulus ) * Nat {
732
713
underflow := x .sub (y )
733
714
// If the subtraction underflowed, add m.
734
- t := NewNat ().Set (x )
715
+ t := NewNat ().set (x )
735
716
t .add (m .nat )
736
717
x .assign (choice (underflow ), t )
737
718
return x
@@ -752,6 +733,8 @@ func (x *Nat) SubOne(m *Modulus) *Nat {
752
733
//
753
734
// The length of both operands must be the same as the modulus. Both operands
754
735
// must already be reduced modulo m.
736
+ //
737
+ //go:norace
755
738
func (x * Nat ) Add (y * Nat , m * Modulus ) * Nat {
756
739
overflow := x .add (y )
757
740
x .maybeSubtractModulus (choice (overflow ), m )
@@ -789,6 +772,8 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
789
772
//
790
773
// All inputs should be the same length and already reduced modulo m.
791
774
// x will be resized to the size of m and overwritten.
775
+ //
776
+ //go:norace
792
777
func (x * Nat ) montgomeryMul (a * Nat , b * Nat , m * Modulus ) * Nat {
793
778
n := len (m .nat .limbs )
794
779
mLimbs := m .nat .limbs [:n ]
@@ -946,11 +931,13 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
946
931
//
947
932
// The length of both operands must be the same as the modulus. Both operands
948
933
// must already be reduced modulo m.
934
+ //
935
+ //go:norace
949
936
func (x * Nat ) Mul (y * Nat , m * Modulus ) * Nat {
950
937
if m .odd {
951
938
// A Montgomery multiplication by a value out of the Montgomery domain
952
939
// takes the result out of Montgomery representation.
953
- xR := NewNat ().Set (x ).montgomeryRepresentation (m ) // xR = x * R mod m
940
+ xR := NewNat ().set (x ).montgomeryRepresentation (m ) // xR = x * R mod m
954
941
return x .montgomeryMul (xR , y , m ) // x = xR * y / R mod m
955
942
}
956
943
n := len (m .nat .limbs )
@@ -1009,6 +996,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
1009
996
// to the size of m and overwritten. x must already be reduced modulo m.
1010
997
//
1011
998
// m must be odd, or Exp will panic.
999
+ //
1000
+ //go:norace
1012
1001
func (out * Nat ) Exp (x * Nat , e []byte , m * Modulus ) * Nat {
1013
1002
if ! m .odd {
1014
1003
panic ("bigmod: modulus for Exp must be odd" )
@@ -1025,7 +1014,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
1025
1014
NewNat (), NewNat (), NewNat (), NewNat (), NewNat (),
1026
1015
NewNat (), NewNat (), NewNat (), NewNat (), NewNat (),
1027
1016
}
1028
- table [0 ].Set (x ).montgomeryRepresentation (m )
1017
+ table [0 ].set (x ).montgomeryRepresentation (m )
1029
1018
for i := 1 ; i < len (table ); i ++ {
1030
1019
table [i ].montgomeryMul (table [i - 1 ], table [0 ], m )
1031
1020
}
@@ -1071,8 +1060,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
1071
1060
// For short exponents, precomputing a table and using a window like in Exp
1072
1061
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
1073
1062
// chain, skipping the initial run of zeroes.
1074
- xR := NewNat ().Set (x ).montgomeryRepresentation (m )
1075
- out .Set (xR )
1063
+ xR := NewNat ().set (x ).montgomeryRepresentation (m )
1064
+ out .set (xR )
1076
1065
for i := bits .UintSize - bits .Len (e ) + 1 ; i < bits .UintSize ; i ++ {
1077
1066
out .montgomeryMul (out , out , m )
1078
1067
if k := (e >> (bits .UintSize - i - 1 )) & 1 ; k != 0 {
@@ -1088,6 +1077,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
1088
1077
//
1089
1078
// a must be reduced modulo m, but doesn't need to have the same size. The
1090
1079
// output will be resized to the size of m and overwritten.
1080
+ //
1081
+ //go:norace
1091
1082
func (x * Nat ) InverseVarTime (a * Nat , m * Modulus ) (* Nat , bool ) {
1092
1083
// This is the extended binary GCD algorithm described in the Handbook of
1093
1084
// Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound
@@ -1121,7 +1112,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
1121
1112
return x , false
1122
1113
}
1123
1114
1124
- u := NewNat ().Set (a ).ExpandFor (m )
1115
+ u := NewNat ().set (a ).ExpandFor (m )
1125
1116
v := m .Nat ()
1126
1117
1127
1118
A := NewNat ().reset (len (m .nat .limbs ))
@@ -1148,7 +1139,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
1148
1139
// If both u and v are odd, subtract the smaller from the larger.
1149
1140
// If u = v, we need to subtract from v to hit the modified exit condition.
1150
1141
if u .IsOdd () == yes && v .IsOdd () == yes {
1151
- if v .CmpGeq (u ) == no {
1142
+ if v .cmpGeq (u ) == no {
1152
1143
u .sub (v )
1153
1144
A .Add (C , m )
1154
1145
B .Add (D , & Modulus {nat : a })
@@ -1189,7 +1180,7 @@ func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
1189
1180
if u .IsOne () == no {
1190
1181
return x , false
1191
1182
}
1192
- return x .Set (A ), true
1183
+ return x .set (A ), true
1193
1184
}
1194
1185
}
1195
1186
}
0 commit comments