diff --git a/internal/bigmod/nat.go b/internal/bigmod/nat.go index 30176da..bbb6148 100644 --- a/internal/bigmod/nat.go +++ b/internal/bigmod/nat.go @@ -228,6 +228,19 @@ func (x *Nat) setBytes(b []byte) error { return nil } +// SetUint assigns x = y, and returns an error if y >= m. +// +// The output will be resized to the size of m and overwritten. +func (x *Nat) SetUint(y uint, m *Modulus) (*Nat, error) { + x.resetFor(m) + // Modulus is never zero, so always at least one limb. + x.limbs[0] = y + if x.CmpGeq(m.nat) == yes { + return nil, errors.New("input overflows the modulus") + } + return x, nil +} + // Equal returns 1 if x == y, and 0 otherwise. // // Both operands must have the same announced length. @@ -323,19 +336,20 @@ func (x *Nat) sub(y *Nat) (c uint) { // Modulus is used for modular arithmetic, precomputing relevant constants. // -// Moduli are assumed to be odd numbers. Moduli can also leak the exact -// number of bits needed to store their value, and are stored without padding. -// -// Their actual value is still kept secret. +// A Modulus can leak the exact number of bits needed to store its value +// and is stored without padding. Its actual value is still kept secret. type Modulus struct { // The underlying natural number for this modulus. // // This will be stored without any padding, and shouldn't alias with any // other natural number being used. nat *Nat - leading int // number of leading zeros in the modulus - m0inv uint // -nat.limbs[0]⁻¹ mod _W - rr *Nat // R*R for montgomeryRepresentation + leading int // number of leading zeros in the modulus + + // If m is even, the following fields are not set. + odd bool + m0inv uint // -nat.limbs[0]⁻¹ mod _W + rr *Nat // R*R for montgomeryRepresentation } // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). @@ -406,17 +420,20 @@ func minusInverseModW(x uint) uint { // NewModulus creates a new Modulus from a slice of big-endian bytes. // -// The value must be odd. The number of significant bits (and nothing else) is -// leaked through timing side-channels. +// The number of significant bits and whether the modulus is even is leaked +// through timing side-channels. func NewModulus(b []byte) (*Modulus, error) { - if len(b) == 0 || b[len(b)-1]&1 != 1 { - return nil, errors.New("modulus must be > 0 and odd") - } m := &Modulus{} m.nat = NewNat().resetToBytes(b) + if len(m.nat.limbs) == 0 { + return nil, errors.New("modulus must be > 0") + } m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) - m.m0inv = minusInverseModW(m.nat.limbs[0]) - m.rr = rr(m) + if m.nat.limbs[0]&1 == 1 { + m.odd = true + m.m0inv = minusInverseModW(m.nat.limbs[0]) + m.rr = rr(m) + } return m, nil } @@ -775,17 +792,73 @@ func addMulVVW(z, x []uint, y uint) (carry uint) { // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { - // A Montgomery multiplication by a value out of the Montgomery domain - // takes the result out of Montgomery representation. - xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m - return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m + if m.odd { + // A Montgomery multiplication by a value out of the Montgomery domain + // takes the result out of Montgomery representation. + xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m + return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m + } + n := len(m.nat.limbs) + xLimbs := x.limbs[:n] + yLimbs := y.limbs[:n] + switch n { + default: + // Attempt to use a stack-allocated backing array. + T := make([]uint, 0, preallocLimbs*2) + if cap(T) < n*2 { + T = make([]uint, 0, n*2) + } + T = T[:n*2] + // T = x * y + for i := 0; i < n; i++ { + T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i]) + } + // x = T mod m + return x.Mod(&Nat{limbs: T}, m) + // The following specialized cases follow the exact same algorithm, but + // optimized for the sizes most used in RSA. See montgomeryMul for details. + case 256 / _W: // optimization for 256 bits nat + const n = 256 / _W // compiler hint + T := make([]uint, n*2) + for i := 0; i < n; i++ { + T[n+i] = addMulVVW256(&T[i], &xLimbs[0], yLimbs[i]) + } + return x.Mod(&Nat{limbs: T}, m) + case 1024 / _W: + const n = 1024 / _W // compiler hint + T := make([]uint, n*2) + for i := 0; i < n; i++ { + T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i]) + } + return x.Mod(&Nat{limbs: T}, m) + case 1536 / _W: + const n = 1536 / _W // compiler hint + T := make([]uint, n*2) + for i := 0; i < n; i++ { + T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i]) + } + return x.Mod(&Nat{limbs: T}, m) + case 2048 / _W: + const n = 2048 / _W // compiler hint + T := make([]uint, n*2) + for i := 0; i < n; i++ { + T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i]) + } + return x.Mod(&Nat{limbs: T}, m) + } } // Exp calculates out = x^e mod m. // // The exponent e is represented in big-endian order. The output will be resized // to the size of m and overwritten. x must already be reduced modulo m. +// +// m must be odd, or Exp will panic. func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { + if !m.odd { + panic("bigmod: modulus for Exp must be odd") + } + // We use a 4 bit window. For our RSA workload, 4 bit windows are faster // than 2 bit windows, but use an extra 12 nats worth of scratch space. // Using bit sizes that don't divide 8 are more complex to implement, but @@ -834,7 +907,12 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { // // The output will be resized to the size of m and overwritten. x must already // be reduced modulo m. This leaks the exponent through timing side-channels. +// +// m must be odd, or ExpShortVarTime will panic. func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { + if !m.odd { + panic("bigmod: modulus for ExpShortVarTime must be odd") + } // For short exponents, precomputing a table and using a window like in Exp // doesn't pay off. Instead, we do a simple conditional square-and-multiply // chain, skipping the initial run of zeroes. diff --git a/internal/bigmod/nat_test.go b/internal/bigmod/nat_test.go index 18516a6..d6af772 100644 --- a/internal/bigmod/nat_test.go +++ b/internal/bigmod/nat_test.go @@ -6,6 +6,7 @@ package bigmod import ( "bytes" + cryptorand "crypto/rand" "encoding/hex" "fmt" "math/big" @@ -17,6 +18,19 @@ import ( "testing/quick" ) +// setBig assigns x = n, optionally resizing n to the appropriate size. +// +// The announced length of x is set based on the actual bit size of the input, +// ignoring leading zeroes. +func (x *Nat) setBig(n *big.Int) *Nat { + limbs := n.Bits() + x.reset(len(limbs)) + for i := range limbs { + x.limbs[i] = uint(limbs[i]) + } + return x +} + func (n *Nat) String() string { var limbs []string for i := range n.limbs { @@ -312,19 +326,6 @@ func TestExpShort(t *testing.T) { } } -// setBig assigns x = n, optionally resizing n to the appropriate size. -// -// The announced length of x is set based on the actual bit size of the input, -// ignoring leading zeroes. -func (x *Nat) setBig(n *big.Int) *Nat { - limbs := n.Bits() - x.reset(len(limbs)) - for i := range limbs { - x.limbs[i] = uint(limbs[i]) - } - return x -} - // TestMulReductions tests that Mul reduces results equal or slightly greater // than the modulus. Some Montgomery algorithms don't and need extra care to // return correct results. See https://go.dev/issue/13907. @@ -353,6 +354,52 @@ func TestMulReductions(t *testing.T) { } } +func TestMul(t *testing.T) { + t.Run("760", func(t *testing.T) { testMul(t, 760/8) }) + t.Run("256", func(t *testing.T) { testMul(t, 256/8) }) + t.Run("1024", func(t *testing.T) { testMul(t, 1024/8) }) + t.Run("1536", func(t *testing.T) { testMul(t, 1536/8) }) + t.Run("2048", func(t *testing.T) { testMul(t, 2048/8) }) +} + +func testMul(t *testing.T, n int) { + a, b, m := make([]byte, n), make([]byte, n), make([]byte, n) + cryptorand.Read(a) + cryptorand.Read(b) + cryptorand.Read(m) + // Pick the highest as the modulus. + if bytes.Compare(a, m) > 0 { + a, m = m, a + } + if bytes.Compare(b, m) > 0 { + b, m = m, b + } + M, err := NewModulus(m) + if err != nil { + t.Fatal(err) + } + A, err := NewNat().SetBytes(a, M) + if err != nil { + t.Fatal(err) + } + B, err := NewNat().SetBytes(b, M) + if err != nil { + t.Fatal(err) + } + A.Mul(B, M) + ABytes := A.Bytes(M) + mBig := new(big.Int).SetBytes(m) + aBig := new(big.Int).SetBytes(a) + bBig := new(big.Int).SetBytes(b) + nBig := new(big.Int).Mul(aBig, bBig) + nBig.Mod(nBig, mBig) + nBigBytes := make([]byte, len(ABytes)) + nBig.FillBytes(nBigBytes) + if !bytes.Equal(ABytes, nBigBytes) { + t.Errorf("got %x, want %x", ABytes, nBigBytes) + } +} + func natBytes(n *Nat) []byte { return n.Bytes(maxModulus(uint(len(n.limbs)))) } diff --git a/sm2/sm2.go b/sm2/sm2.go index df6a682..3c971ee 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -309,6 +309,8 @@ func encodingCiphertextASN1(C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, err // Most applications should use [crypto/rand.Reader] as rand. Note that the // returned key does not depend deterministically on the bytes read from rand, // and may change between calls and/or between versions. +// +// According GB/T 32918.1-2016, the private key must be in [1, n-2]. func GenerateKey(rand io.Reader) (*PrivateKey, error) { randutil.MaybeReadByte(rand) @@ -331,6 +333,8 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) { // NewPrivateKey checks that key is valid and returns a SM2 PrivateKey. // // key - the private key byte slice, the length must be 32 for SM2. +// +// According GB/T 32918.1-2016, the private key must be in [1, n-2]. func NewPrivateKey(key []byte) (*PrivateKey, error) { c := p256() if len(key) != c.N.Size() { @@ -364,6 +368,8 @@ func NewPrivateKeyFromInt(key *big.Int) (*PrivateKey, error) { } // NewPublicKey checks that key is valid and returns a PublicKey. +// +// According GB/T 32918.1-2016, the private key must be in [1, n-2]. func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) { c := p256() // Reject the point at infinity and compressed encodings. @@ -598,7 +604,7 @@ func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, erro dp1Bytes []byte ) priv.inverseOfKeyPlus1Once.Do(func() { - oneNat, _ = bigmod.NewNat().SetBytes(one.Bytes(), c.N) + oneNat, _ = bigmod.NewNat().SetUint(1, c.N) dp1Inv, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N) if err == nil { dp1Inv.Add(oneNat, c.N)