Skip to content

Commit

Permalink
Add NewGCMTLS13 (#51)
Browse files Browse the repository at this point in the history
  • Loading branch information
dagood authored Jan 5, 2024
1 parent ebaf9de commit c8f7184
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 52 deletions.
63 changes: 56 additions & 7 deletions cng/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func (c *aesCipher) NewGCM(nonceSize, tagSize int) (cipher.AEAD, error) {
if tagSize != gcmTagSize {
return cipher.NewGCMWithTagSize(&noGCM{c}, tagSize)
}
return newGCM(c.key, false)
return newGCM(c.key, cipherGCMTLSNone)
}

// NewGCMTLS returns a GCM cipher specific to TLS
Expand All @@ -116,7 +116,17 @@ func NewGCMTLS(c cipher.Block) (cipher.AEAD, error) {
}

func (c *aesCipher) NewGCMTLS() (cipher.AEAD, error) {
return newGCM(c.key, true)
return newGCM(c.key, cipherGCMTLS12)
}

// NewGCMTLS13 returns a GCM cipher specific to TLS 1.3 and should not be used
// for non-TLS purposes.
func NewGCMTLS13(c cipher.Block) (cipher.AEAD, error) {
return c.(*aesCipher).NewGCMTLS13()
}

func (c *aesCipher) NewGCMTLS13() (cipher.AEAD, error) {
return newGCM(c.key, cipherGCMTLS13)
}

type cbcCipher struct {
Expand Down Expand Up @@ -197,17 +207,32 @@ const (
gcmTlsFixedNonceSize = 4
)

type cipherGCMTLS uint8

const (
cipherGCMTLSNone cipherGCMTLS = iota
cipherGCMTLS12
cipherGCMTLS13
)

type aesGCM struct {
kh bcrypt.KEY_HANDLE
tls bool
kh bcrypt.KEY_HANDLE
tls cipherGCMTLS
// minNextNonce is the minimum value that the next nonce can be, enforced by
// all TLS modes.
minNextNonce uint64
// mask is the nonce mask used in TLS 1.3 mode.
mask uint64
// maskInitialized is true if mask has been initialized. This happens during
// the first Seal. The initialized mask may be 0. Used by TLS 1.3 mode.
maskInitialized bool
}

func (g *aesGCM) finalize() {
bcrypt.DestroyKey(g.kh)
}

func newGCM(key []byte, tls bool) (*aesGCM, error) {
func newGCM(key []byte, tls cipherGCMTLS) (*aesGCM, error) {
kh, err := newCipherHandle(bcrypt.AES_ALGORITHM, bcrypt.CHAIN_MODE_GCM, key)
if err != nil {
return nil, err
Expand Down Expand Up @@ -235,15 +260,39 @@ func (g *aesGCM) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
if len(dst)+len(plaintext)+gcmTagSize < len(dst) {
panic("cipher: message too large for buffer")
}
if g.tls {
if g.tls != cipherGCMTLSNone {
if len(additionalData) != gcmTlsAddSize {
panic("cipher: incorrect additional data length given to GCM TLS")
}
counter := bigUint64(nonce[gcmTlsFixedNonceSize:])
if g.tls == cipherGCMTLS13 {
// In TLS 1.3, the counter in the nonce has a mask and requires
// further decoding.
if !g.maskInitialized {
// According to TLS 1.3 nonce construction details at
// https://tools.ietf.org/html/rfc8446#section-5.3:
//
// the first record transmitted under a particular traffic
// key MUST use sequence number 0.
//
// The padded sequence number is XORed with [a mask].
//
// The resulting quantity (of length iv_length) is used as
// the per-record nonce.
//
// We need to convert from the given nonce to sequence numbers
// to keep track of minNextNonce and enforce the counter
// maximum. On the first call, we know counter^mask is 0^mask,
// so we can simply store it as the mask.
g.mask = counter
g.maskInitialized = true
}
counter ^= g.mask
}
// BoringCrypto enforces strictly monotonically increasing explicit nonces
// and to fail after 2^64 - 1 keys as per FIPS 140-2 IG A.5,
// but BCrypt does not perform this check, so it is implemented here.
const maxUint64 = 1<<64 - 1
counter := bigUint64(nonce[gcmTlsFixedNonceSize:])
if counter == maxUint64 {
panic("cipher: nonce counter must be less than 2^64 - 1")
}
Expand Down
115 changes: 70 additions & 45 deletions cng/aes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,51 +71,76 @@ func TestSealAndOpen(t *testing.T) {
}

func TestSealAndOpenTLS(t *testing.T) {
ci, err := NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
gcm, err := NewGCMTLS(ci)
if err != nil {
t.Fatal(err)
}
nonce := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
nonce1 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
nonce9 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}
nonce10 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10}
nonceMax := [12]byte{0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255}
plainText := []byte{0x01, 0x02, 0x03}
additionalData := make([]byte, 13)
additionalData[11] = byte(len(plainText) >> 8)
additionalData[12] = byte(len(plainText))
sealed := gcm.Seal(nil, nonce[:], plainText, additionalData)
assertPanic(t, func() {
gcm.Seal(nil, nonce[:], plainText, additionalData)
})
sealed1 := gcm.Seal(nil, nonce1[:], plainText, additionalData)
gcm.Seal(nil, nonce10[:], plainText, additionalData)
assertPanic(t, func() {
gcm.Seal(nil, nonce9[:], plainText, additionalData)
})
assertPanic(t, func() {
gcm.Seal(nil, nonceMax[:], plainText, additionalData)
})
if bytes.Equal(sealed, sealed1) {
t.Errorf("different nonces should produce different outputs\ngot: %#v\nexp: %#v", sealed, sealed1)
}
decrypted, err := gcm.Open(nil, nonce[:], sealed, additionalData)
if err != nil {
t.Error(err)
}
decrypted1, err := gcm.Open(nil, nonce1[:], sealed1, additionalData)
if err != nil {
t.Error(err)
}
if !bytes.Equal(decrypted, plainText) {
t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText)
}
if !bytes.Equal(decrypted, decrypted1) {
t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, decrypted1)
tests := []struct {
name string
new func(c cipher.Block) (cipher.AEAD, error)
mask func(n *[12]byte)
}{
{"1.2", NewGCMTLS, nil},
{"1.3", NewGCMTLS13, nil},
{"1.3_masked", NewGCMTLS13, func(n *[12]byte) {
// Arbitrary mask in the high bits.
n[9] ^= 0x42
// Mask the very first bit. This makes sure that if Seal doesn't
// handle the mask, the counter appears to go backwards and panics
// when it shouldn't.
n[11] ^= 0x1
}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ci, err := NewAESCipher(key)
if err != nil {
t.Fatal(err)
}
gcm, err := tt.new(ci)
if err != nil {
t.Fatal(err)
}
nonce := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
nonce1 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
nonce9 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}
nonce10 := [12]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10}
nonceMax := [12]byte{0, 0, 0, 0, 255, 255, 255, 255, 255, 255, 255, 255}
if tt.mask != nil {
for _, m := range []*[12]byte{&nonce, &nonce1, &nonce9, &nonce10, &nonceMax} {
tt.mask(m)
}
}
plainText := []byte{0x01, 0x02, 0x03}
additionalData := make([]byte, 13)
additionalData[11] = byte(len(plainText) >> 8)
additionalData[12] = byte(len(plainText))
sealed := gcm.Seal(nil, nonce[:], plainText, additionalData)
assertPanic(t, func() {
gcm.Seal(nil, nonce[:], plainText, additionalData)
})
sealed1 := gcm.Seal(nil, nonce1[:], plainText, additionalData)
gcm.Seal(nil, nonce10[:], plainText, additionalData)
assertPanic(t, func() {
gcm.Seal(nil, nonce9[:], plainText, additionalData)
})
assertPanic(t, func() {
gcm.Seal(nil, nonceMax[:], plainText, additionalData)
})
if bytes.Equal(sealed, sealed1) {
t.Errorf("different nonces should produce different outputs\ngot: %#v\nexp: %#v", sealed, sealed1)
}
decrypted, err := gcm.Open(nil, nonce[:], sealed, additionalData)
if err != nil {
t.Error(err)
}
decrypted1, err := gcm.Open(nil, nonce1[:], sealed1, additionalData)
if err != nil {
t.Error(err)
}
if !bytes.Equal(decrypted, plainText) {
t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, plainText)
}
if !bytes.Equal(decrypted, decrypted1) {
t.Errorf("unexpected decrypted result\ngot: %#v\nexp: %#v", decrypted, decrypted1)
}
})
}
}

Expand Down

0 comments on commit c8f7184

Please sign in to comment.