diff --git a/context.go b/context.go index dd3dd4b..9f1b246 100644 --- a/context.go +++ b/context.go @@ -4,6 +4,7 @@ package srtp import ( + "bytes" "fmt" "github.com/pion/transport/v3/replaydetector" @@ -56,6 +57,10 @@ type Context struct { newSRTCPReplayDetector func() replaydetector.ReplayDetector newSRTPReplayDetector func() replaydetector.ReplayDetector + + profile ProtectionProfile + sendMKI []byte + mkis map[string]srtpCipher } // CreateContext creates a new SRTP Context. @@ -66,52 +71,99 @@ type Context struct { // // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) { - keyLen, err := profile.keyLen() + c = &Context{ + srtpSSRCStates: map[uint32]*srtpSSRCState{}, + srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + profile: profile, + mkis: map[string]srtpCipher{}, + } + + for _, o := range append( + []ContextOption{ // Default options + SRTPNoReplayProtection(), + SRTCPNoReplayProtection(), + }, + opts..., // User specified options + ) { + if errOpt := o(c); errOpt != nil { + return nil, errOpt + } + } + + err = c.AddCipherForMKI(c.sendMKI, masterKey, masterSalt) if err != nil { return nil, err } + c.cipher = c.mkis[string(c.sendMKI)] - saltLen, err := profile.saltLen() + return c, nil +} + +// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option to enable MKI support. +func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error { + if len(mki) != len(c.sendMKI) { + return errInvalidMKILength + } + if _, ok := c.mkis[string(mki)]; ok { + return errMKIAlreadyInUse + } + + keyLen, err := c.profile.keyLen() if err != nil { - return nil, err + return err } - if masterKeyLen := len(masterKey); masterKeyLen != keyLen { - return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) - } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { - return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) + saltLen, err := c.profile.saltLen() + if err != nil { + return err } - c = &Context{ - srtpSSRCStates: map[uint32]*srtpSSRCState{}, - srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, + if masterKeyLen := len(masterKey); masterKeyLen != keyLen { + return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) + } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { + return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) } - switch profile { + var cipher srtpCipher + switch c.profile { case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm: - c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt) + cipher, err = newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki) case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: - c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt) + cipher, err = newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki) default: - return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile) + return fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile) } if err != nil { - return nil, err + return err } - for _, o := range append( - []ContextOption{ // Default options - SRTPNoReplayProtection(), - SRTCPNoReplayProtection(), - }, - opts..., // User specified options - ) { - if errOpt := o(c); errOpt != nil { - return nil, errOpt - } + c.mkis[string(mki)] = cipher + return nil +} + +// RemoveMKI removes one of MKIs. You cannot remove last MKI and one used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with decrypting packets. +func (c *Context) RemoveMKI(mki []byte) error { + if _, ok := c.mkis[string(mki)]; !ok { + return ErrMKINotFound + } + if bytes.Equal(mki, c.sendMKI) { + return errMKIAlreadyInUse } + delete(c.mkis, string(mki)) + return nil +} - return c, nil +// SetSendMKI switches MKI and cipher used for encrypting RTP/RTCP packets. +// Operation is not thread-safe, you need to provide synchronization with encrypting packets. +func (c *Context) SetSendMKI(mki []byte) error { + cipher, ok := c.mkis[string(mki)] + if !ok { + return ErrMKINotFound + } + c.sendMKI = mki + c.cipher = cipher + return nil } // https://tools.ietf.org/html/rfc3550#appendix-A.1 diff --git a/context_test.go b/context_test.go index 60b078c..3df0f36 100644 --- a/context_test.go +++ b/context_test.go @@ -5,6 +5,8 @@ package srtp import ( "testing" + + "github.com/stretchr/testify/assert" ) func TestContextROC(t *testing.T) { @@ -44,3 +46,123 @@ func TestContextIndex(t *testing.T) { t.Errorf("Index is set to 100, but returned %d", index) } } + +func TestContextWithoutMKI(t *testing.T) { + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR) + if err != nil { + t.Fatal(err) + } + + err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.SetSendMKI(nil) + assert.NoError(t, err) + + err = c.SetSendMKI(make([]byte, 0)) + assert.NoError(t, err) + + err = c.RemoveMKI(nil) + assert.Error(t, err) + + err = c.RemoveMKI(make([]byte, 0)) + assert.Error(t, err) + + err = c.RemoveMKI(make([]byte, 2)) + assert.Error(t, err) +} + +func TestAddMKIToContextWithMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) + + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + assert.Error(t, err) +} + +func TestContextSetSendMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.SetSendMKI(mki1) + assert.NoError(t, err) + + err = c.SetSendMKI(mki2) + assert.NoError(t, err) + + err = c.SetSendMKI(make([]byte, 4)) + assert.Error(t, err) +} + +func TestContextRemoveMKI(t *testing.T) { + mki1 := []byte{1, 2, 3, 4} + mki2 := []byte{2, 3, 4, 5} + mki3 := []byte{3, 4, 5, 6} + + c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + err = c.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14)) + if err != nil { + t.Fatal(err) + } + + err = c.RemoveMKI(make([]byte, 4)) + assert.Error(t, err) + + err = c.RemoveMKI(mki1) + assert.Error(t, err) + + err = c.SetSendMKI(mki3) + assert.NoError(t, err) + + err = c.RemoveMKI(mki1) + assert.NoError(t, err) + + err = c.RemoveMKI(mki2) + assert.NoError(t, err) + + err = c.RemoveMKI(mki3) + assert.Error(t, err) +} diff --git a/errors.go b/errors.go index 5b1751d..83a99c3 100644 --- a/errors.go +++ b/errors.go @@ -9,6 +9,11 @@ import ( ) var ( + // ErrFailedToVerifyAuthTag is returned when decryption fails due to invalid authentication tag + ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag") + // ErrFailedToVerifyAuthTag is returned when decryption fails due to unknown MKI value in packet + ErrMKINotFound = errors.New("MKI not found") + errDuplicated = errors.New("duplicated packet") errShortSrtpMasterKey = errors.New("SRTP master key is not long enough") errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough") @@ -17,12 +22,14 @@ var ( errExporterWrongLabel = errors.New("exporter called with wrong label") errNoConfig = errors.New("no config provided") errNoConn = errors.New("no conn provided") - errFailedToVerifyAuthTag = errors.New("failed to verify auth tag") - errTooShortRTCP = errors.New("packet is too short to be rtcp packet") + errTooShortRTP = errors.New("packet is too short to be RTP packet") + errTooShortRTCP = errors.New("packet is too short to be RTCP packet") errPayloadDiffers = errors.New("payload differs") errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") errBadIVLength = errors.New("bad iv length in xorBytesCTR") errExceededMaxPackets = errors.New("exceeded the maximum number of packets") + errMKIAlreadyInUse = errors.New("MKI already in use") + errInvalidMKILength = errors.New("invalid MKI length") errStreamNotInited = errors.New("stream has not been inited, unable to close") errStreamAlreadyClosed = errors.New("stream is already closed") diff --git a/option.go b/option.go index 67bdf2e..f7f011b 100644 --- a/option.go +++ b/option.go @@ -71,3 +71,14 @@ type nopReplayDetector struct{} func (s *nopReplayDetector) Check(uint64) (func() bool, bool) { return func() bool { return true }, true } + +// MasterKeyIndicator sets MKI for RTP and RTCP. +func MasterKeyIndicator(mki []byte) ContextOption { + return func(c *Context) error { + if len(mki) > 0 { + c.sendMKI = make([]byte, len(mki)) + copy(c.sendMKI, mki) + } + return nil + } +} diff --git a/srtcp.go b/srtcp.go index 7fd0746..b81040a 100644 --- a/srtcp.go +++ b/srtcp.go @@ -23,9 +23,10 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { if err != nil { return nil, err } - tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + mkiLen := len(c.sendMKI) + tailOffset := len(encrypted) - (authTagLen + mkiLen + srtcpIndexSize) - if tailOffset < aeadAuthTagLen { + if tailOffset < aeadAuthTagLen+8 { return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted)) } else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 { return out, nil @@ -40,7 +41,17 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index} } - out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc) + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := c.cipher.getMKI(encrypted, false) + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } + } + + out, err = cipher.decryptRTCP(out, encrypted, index, ssrc) if err != nil { return nil, err } @@ -63,6 +74,10 @@ func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byt } func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { + if len(decrypted) < 8 { + return nil, errTooShortRTCP + } + ssrc := binary.BigEndian.Uint32(decrypted[4:]) s := c.getSRTCPSSRCState(ssrc) diff --git a/srtcp_test.go b/srtcp_test.go index 8cf8e3a..273656e 100644 --- a/srtcp_test.go +++ b/srtcp_test.go @@ -594,3 +594,163 @@ func TestRTCPReplayDetectorFactory(t *testing.T) { } assert.Equal(1, cntFactory) } + +func TestRTCPInvalidMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + _, err = decryptContext.DecryptRTCP(nil, encrypted, nil) + if err == nil { + t.Errorf("Managed to decrypt with incorrect MKI for packet with SSRC: %d", pkt.ssrc) + } else { + assert.ErrorIs(t, err, ErrMKINotFound) + } + } + }) + } +} + +func TestRTCPHandleMultipleMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + masterKey2 := make([]byte, len(testCase.masterKey)) + copy(masterKey2, testCase.masterKey) + masterKey2[0] = ^masterKey2[0] + + encryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + encryptContext2, err := CreateContext(masterKey2, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + err = decryptContext.AddCipherForMKI(mki2, masterKey2, testCase.masterSalt) + if err != nil { + t.Errorf("AddCipherForMKI failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted1, err := encryptContext1.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + encrypted2, err := encryptContext2.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + decrypted1, err := decryptContext.DecryptRTCP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext.DecryptRTCP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, rtcpPacket, decrypted1) + assert.Equal(t, rtcpPacket, decrypted2) + } + }) + } +} + +func TestRTCPSwitchMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + for caseName, testCase := range rtcpTestCases() { + testCase := testCase + t.Run(caseName, func(t *testing.T) { + masterKey2 := make([]byte, len(testCase.masterKey)) + copy(masterKey2, testCase.masterKey) + masterKey2[0] = ^masterKey2[0] + + encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + err = encryptContext.AddCipherForMKI(mki2, masterKey2, testCase.masterSalt) + if err != nil { + t.Errorf("AddCipherForMKI failed: %v", err) + } + + decryptContext1, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki1)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + decryptContext2, err := CreateContext(masterKey2, testCase.masterSalt, testCase.algo, MasterKeyIndicator(mki2)) + if err != nil { + t.Errorf("CreateContext failed: %v", err) + } + + for _, pkt := range testCase.packets { + rtcpPacket := append([]byte{}, pkt.decrypted...) + encrypted1, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + err = encryptContext.SetSendMKI(mki2) + if err != nil { + t.Fatal(err) + } + + encrypted2, err := encryptContext.encryptRTCP(nil, rtcpPacket) + if err != nil { + t.Error(err) + } + + assert.NotEqual(t, encrypted1, encrypted2) + + decrypted1, err := decryptContext1.DecryptRTCP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext2.DecryptRTCP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, rtcpPacket, decrypted1) + assert.Equal(t, rtcpPacket, decrypted2) + + err = encryptContext.SetSendMKI(mki1) + if err != nil { + t.Fatal(err) + } + } + }) + } +} diff --git a/srtp.go b/srtp.go index 42c71be..f7f0186 100644 --- a/srtp.go +++ b/srtp.go @@ -9,6 +9,14 @@ import ( ) func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) { + authTagLen, err := c.cipher.rtpAuthTagLen() + if err != nil { + return nil, err + } + if len(ciphertext) < headerLen+len(c.sendMKI)+authTagLen { + return nil, errTooShortRTP + } + s := c.getSRTPSSRCState(header.SSRC) roc, diff, _ := s.nextRolloverCount(header.SequenceNumber) @@ -21,13 +29,19 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL } } - authTagLen, err := c.cipher.rtpAuthTagLen() - if err != nil { - return nil, err + cipher := c.cipher + if len(c.mkis) > 0 { + // Find cipher for MKI + actualMKI := c.cipher.getMKI(ciphertext, true) + cipher, ok = c.mkis[string(actualMKI)] + if !ok { + return nil, ErrMKINotFound + } } - dst = growBufferSize(dst, len(ciphertext)-authTagLen) - dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) + dst = growBufferSize(dst, len(ciphertext)-authTagLen-len(c.sendMKI)) + + dst, err = cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) if err != nil { return nil, err } diff --git a/srtp_cipher.go b/srtp_cipher.go index db50147..67ac561 100644 --- a/srtp_cipher.go +++ b/srtp_cipher.go @@ -16,6 +16,7 @@ type srtpCipher interface { // See the note below. aeadAuthTagLen() (int, error) getRTCPIndex([]byte) uint32 + getMKI([]byte, bool) []byte encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error) encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) diff --git a/srtp_cipher_aead_aes_gcm.go b/srtp_cipher_aead_aes_gcm.go index 90643d9..b980584 100644 --- a/srtp_cipher_aead_aes_gcm.go +++ b/srtp_cipher_aead_aes_gcm.go @@ -21,9 +21,11 @@ type srtpCipherAeadAesGcm struct { srtpCipher, srtcpCipher cipher.AEAD srtpSessionSalt, srtcpSessionSalt []byte + + mki []byte } -func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAeadAesGcm, error) { +func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt, mki []byte) (*srtpCipherAeadAesGcm, error) { s := &srtpCipherAeadAesGcm{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) @@ -62,6 +64,12 @@ func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt [] return nil, err } + mkiLen := len(mki) + if mkiLen > 0 { + s.mki = make([]byte, mkiLen) + copy(s.mki, mki) + } + return s, nil } @@ -71,7 +79,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa if err != nil { return nil, err } - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen+len(s.mki)) n, err := header.MarshalTo(dst) if err != nil { @@ -80,6 +88,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payloa iv := s.rtpInitializationVector(header, roc) s.srtpCipher.Seal(dst[n:n], iv[:], payload, dst[:n]) + + // Add MKI after the encrypted payload + if len(s.mki) > 0 { + copy(dst[len(dst)-len(s.mki):], s.mki) + } + return dst, nil } @@ -89,17 +103,18 @@ func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.He if err != nil { return nil, err } - nDst := len(ciphertext) - authTagLen - if nDst < 0 { + nDst := len(ciphertext) - authTagLen - len(s.mki) + if nDst < headerLen { // Size of ciphertext is shorter than AEAD auth tag len. - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) iv := s.rtpInitializationVector(header, roc) + nEnd := len(ciphertext) - len(s.mki) if _, err := s.srtpCipher.Open( - dst[headerLen:headerLen], iv[:], ciphertext[headerLen:], ciphertext[:headerLen], + dst[headerLen:headerLen], iv[:], ciphertext[headerLen:nEnd], ciphertext[:headerLen], ); err != nil { return nil, err } @@ -115,7 +130,7 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin } aadPos := len(decrypted) + authTagLen // Grow the given buffer to fit the output. - dst = growBufferSize(dst, aadPos+srtcpIndexSize) + dst = growBufferSize(dst, aadPos+srtcpIndexSize+len(s.mki)) iv := s.rtcpInitializationVector(srtcpIndex, ssrc) aad := s.rtcpAdditionalAuthenticatedData(decrypted, srtcpIndex) @@ -124,11 +139,12 @@ func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uin copy(dst[:8], decrypted[:8]) copy(dst[aadPos:aadPos+4], aad[8:12]) + copy(dst[aadPos+4:], s.mki) return dst, nil } func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { - aadPos := len(encrypted) - srtcpIndexSize + aadPos := len(encrypted) - srtcpIndexSize - len(s.mki) // Grow the given buffer to fit the output. authTagLen, err := s.aeadAuthTagLen() if err != nil { @@ -137,7 +153,7 @@ func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ss nDst := aadPos - authTagLen if nDst < 0 { // Size of ciphertext is shorter than AEAD auth tag len. - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) @@ -205,5 +221,15 @@ func (s *srtpCipherAeadAesGcm) rtcpAdditionalAuthenticatedData(rtcpPacket []byte } func (s *srtpCipherAeadAesGcm) getRTCPIndex(in []byte) uint32 { - return binary.BigEndian.Uint32(in[len(in)-4:]) &^ (rtcpEncryptionFlag << 24) + return binary.BigEndian.Uint32(in[len(in)-len(s.mki)-4:]) &^ (rtcpEncryptionFlag << 24) +} + +func (s *srtpCipherAeadAesGcm) getMKI(in []byte, _ bool) []byte { + mkiLen := len(s.mki) + if mkiLen == 0 { + return nil + } + + tailOffset := len(in) - mkiLen + return in[tailOffset:] } diff --git a/srtp_cipher_aead_aes_gcm_test.go b/srtp_cipher_aead_aes_gcm_test.go index 90fd88c..11ba650 100644 --- a/srtp_cipher_aead_aes_gcm_test.go +++ b/srtp_cipher_aead_aes_gcm_test.go @@ -9,6 +9,13 @@ import ( "github.com/stretchr/testify/assert" ) +func addMkiToAeadPacket(packet, mki []byte) []byte { + p := make([]byte, len(packet)+len(mki)) + copy(p, packet) + copy(p[len(packet):], mki) + return p +} + func TestSrtpCipherAedAes128Gcm(t *testing.T) { decryptedRTPPacket := []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, @@ -41,6 +48,8 @@ func TestSrtpCipherAedAes128Gcm(t *testing.T) { masterKey := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} masterSalt := []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab} + mki := []byte{0x01, 0x02, 0x03, 0x04} + t.Run("Encrypt RTP", func(t *testing.T) { ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) assert.NoError(t, err) @@ -63,6 +72,28 @@ func TestSrtpCipherAedAes128Gcm(t *testing.T) { }) }) + t.Run("Encrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, addMkiToAeadPacket(encryptedRTPPacket, mki), actualEncrypted) + }) + }) + + t.Run("Decrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, addMkiToAeadPacket(encryptedRTPPacket, mki), nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRTPPacket, actualDecrypted) + }) + }) + t.Run("Encrypt RTCP", func(t *testing.T) { ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) assert.NoError(t, err) @@ -84,6 +115,28 @@ func TestSrtpCipherAedAes128Gcm(t *testing.T) { assert.Equal(t, decryptedRtcpPacket, actualDecrypted) }) }) + + t.Run("Encrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, decryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, addMkiToAeadPacket(encryptedRtcpPacket, mki), actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, addMkiToAeadPacket(encryptedRtcpPacket, mki), nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRtcpPacket, actualDecrypted) + }) + }) } func TestSrtpCipherAedAes256Gcm(t *testing.T) { @@ -118,6 +171,8 @@ func TestSrtpCipherAedAes256Gcm(t *testing.T) { masterKey := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} masterSalt := []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab} + mki := []byte{0x01, 0x02, 0x03, 0x04} + t.Run("Encrypt RTP", func(t *testing.T) { ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes256Gcm) assert.NoError(t, err) @@ -140,6 +195,28 @@ func TestSrtpCipherAedAes256Gcm(t *testing.T) { }) }) + t.Run("Encrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes256Gcm, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, addMkiToAeadPacket(encryptedRTPPacket, mki), actualEncrypted) + }) + }) + + t.Run("Decrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes256Gcm, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, addMkiToAeadPacket(encryptedRTPPacket, mki), nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRTPPacket, actualDecrypted) + }) + }) + t.Run("Encrypt RTCP", func(t *testing.T) { ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes256Gcm) assert.NoError(t, err) @@ -161,4 +238,26 @@ func TestSrtpCipherAedAes256Gcm(t *testing.T) { assert.Equal(t, decryptedRtcpPacket, actualDecrypted) }) }) + + t.Run("Encrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes256Gcm, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, decryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, addMkiToAeadPacket(encryptedRtcpPacket, mki), actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes256Gcm, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, addMkiToAeadPacket(encryptedRtcpPacket, mki), nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRtcpPacket, actualDecrypted) + }) + }) } diff --git a/srtp_cipher_aes_cm_hmac_sha1.go b/srtp_cipher_aes_cm_hmac_sha1.go index d56e6af..5701083 100644 --- a/srtp_cipher_aes_cm_hmac_sha1.go +++ b/srtp_cipher_aes_cm_hmac_sha1.go @@ -25,9 +25,11 @@ type srtpCipherAesCmHmacSha1 struct { srtcpSessionSalt []byte srtcpSessionAuth hash.Hash srtcpBlock cipher.Block + + mki []byte } -func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAesCmHmacSha1, error) { +func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt, mki []byte) (*srtpCipherAesCmHmacSha1, error) { s := &srtpCipherAesCmHmacSha1{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { @@ -66,6 +68,13 @@ func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt s.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag) s.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag) + + mkiLen := len(mki) + if mkiLen > 0 { + s.mki = make([]byte, mkiLen) + copy(s.mki, mki) + } + return s, nil } @@ -75,7 +84,7 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay if err != nil { return nil, err } - dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) + dst = growBufferSize(dst, header.MarshalSize()+len(payload)+len(s.mki)+authTagLen) // Copy the header unencrypted. n, err := header.MarshalTo(dst) @@ -96,6 +105,12 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay return nil, err } + // Append the MKI (if used) + if len(s.mki) > 0 { + copy(dst[n:], s.mki) + n += len(s.mki) + } + // Write the auth tag to the dest. copy(dst[n:], authTag) @@ -103,13 +118,14 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, pay } func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { - // Split the auth tag and the cipher text into two parts. authTagLen, err := s.rtpAuthTagLen() if err != nil { return nil, err } + + // Split the auth tag and the cipher text into two parts. actualTag := ciphertext[len(ciphertext)-authTagLen:] - ciphertext = ciphertext[:len(ciphertext)-authTagLen] + ciphertext = ciphertext[:len(ciphertext)-len(s.mki)-authTagLen] // Generate the auth tag we expect to see from the ciphertext. expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc) @@ -120,7 +136,7 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } // Write the plaintext header to the destination buffer. @@ -148,10 +164,18 @@ func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex binary.BigEndian.PutUint32(dst[len(dst)-4:], srtcpIndex) dst[len(dst)-4] |= 0x80 + // Generate the authentication tag authTag, err := s.generateSrtcpAuthTag(dst) if err != nil { return nil, err } + + // Include the MKI if provided + if len(s.mki) > 0 { + dst = append(dst, s.mki...) + } + + // Append the auth tag at the end of the buffer return append(dst, authTag...), nil } @@ -160,17 +184,18 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc if err != nil { return nil, err } - tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) + + tailOffset := len(encrypted) - (authTagLen + len(s.mki) + srtcpIndexSize) out = out[0:tailOffset] - expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen]) + expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-len(s.mki)-authTagLen]) if err != nil { return nil, err } actualTag := encrypted[len(encrypted)-authTagLen:] if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { - return nil, errFailedToVerifyAuthTag + return nil, ErrFailedToVerifyAuthTag } counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) @@ -244,7 +269,23 @@ func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, erro func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { authTagLen, _ := s.rtcpAuthTagLen() - tailOffset := len(in) - (authTagLen + srtcpIndexSize) + tailOffset := len(in) - (authTagLen + srtcpIndexSize + len(s.mki)) srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } + +func (s *srtpCipherAesCmHmacSha1) getMKI(in []byte, rtp bool) []byte { + mkiLen := len(s.mki) + if mkiLen == 0 { + return nil + } + + var authTagLen int + if rtp { + authTagLen, _ = s.rtpAuthTagLen() + } else { + authTagLen, _ = s.rtcpAuthTagLen() + } + tailOffset := len(in) - (authTagLen + mkiLen) + return in[tailOffset : tailOffset+mkiLen] +} diff --git a/srtp_cipher_aes_cm_hmac_sha1_test.go b/srtp_cipher_aes_cm_hmac_sha1_test.go new file mode 100644 index 0000000..ad3fa84 --- /dev/null +++ b/srtp_cipher_aes_cm_hmac_sha1_test.go @@ -0,0 +1,259 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package srtp + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func addMkiToAesPacket(packet, mki []byte, authTagLen int) []byte { + p := make([]byte, len(packet)+len(mki)) + copy(p, packet[:len(packet)-authTagLen]) + copy(p[len(packet)-authTagLen:], mki) + copy(p[len(packet)-authTagLen+len(mki):], packet[len(packet)-authTagLen:]) + return p +} + +func TestSrtpCipherAes128CmHmacSha1_32(t *testing.T) { + decryptedRTPPacket := []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, + } + encryptedRTPPacket := []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xe2, 0xd8, 0xdf, 0x8f, + 0x7a, 0x75, 0xd6, 0x88, 0xc3, 0x50, 0x2e, 0xee, + 0xc2, 0xa9, 0x80, 0x66, 0xcd, 0x7c, 0x0d, 0x09, + } + decryptedRtcpPacket := []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + } + encryptedRtcpPacket := []byte{ + 0x81, 0xc8, 0x0, 0x00b, 0xca, 0xfe, 0xba, 0xbe, + 0x56, 0x74, 0xbf, 0x01, 0x81, 0x3d, 0xc0, 0x62, + 0xac, 0x1d, 0xf6, 0xf7, 0x5f, 0x77, 0xc6, 0x88, + 0x80, 0x00, 0x00, 0x01, 0x3d, 0xb7, 0xa1, 0x98, + 0x37, 0xff, 0x64, 0xe5, 0xcb, 0xd2, + } + + masterKey := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} + masterSalt := []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0x0ac, 0xad} + + mki := []byte{0x01, 0x02, 0x03, 0x04} + + t.Run("Encrypt RTP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, encryptedRTPPacket, actualEncrypted) + }) + }) + + t.Run("Decrypt RTP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, encryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRTPPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, addMkiToAesPacket(encryptedRTPPacket, mki, 4), actualEncrypted) + }) + }) + + t.Run("Decrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, addMkiToAesPacket(encryptedRTPPacket, mki, 4), nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRTPPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTCP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, decryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, encryptedRtcpPacket, actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, encryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRtcpPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, decryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, addMkiToAesPacket(encryptedRtcpPacket, mki, 10), actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, addMkiToAesPacket(encryptedRtcpPacket, mki, 10), nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRtcpPacket, actualDecrypted) + }) + }) +} + +func TestSrtpCipherAes128CmHmacSha1_80(t *testing.T) { + decryptedRTPPacket := []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, + } + encryptedRTPPacket := []byte{ + 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, + 0xca, 0xfe, 0xba, 0xbe, 0xe2, 0xd8, 0xdf, 0x8f, + 0x7a, 0x75, 0xd6, 0x88, 0xc3, 0x50, 0x2e, 0xee, + 0xc2, 0xa9, 0x80, 0x66, 0xcd, 0x7c, 0x0d, 0x09, + 0xca, 0x44, 0x32, 0xa5, 0x6e, 0x3d, + } + decryptedRtcpPacket := []byte{ + 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, + } + encryptedRtcpPacket := []byte{ + 0x81, 0xc8, 0x0, 0x00b, 0xca, 0xfe, 0xba, 0xbe, + 0x56, 0x74, 0xbf, 0x01, 0x81, 0x3d, 0xc0, 0x62, + 0xac, 0x1d, 0xf6, 0xf7, 0x5f, 0x77, 0xc6, 0x88, + 0x80, 0x00, 0x00, 0x01, 0x3d, 0xb7, 0xa1, 0x98, + 0x37, 0xff, 0x64, 0xe5, 0xcb, 0xd2, + } + + masterKey := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} + masterSalt := []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0x0ac, 0xad} + + mki := []byte{0x01, 0x02, 0x03, 0x04} + + t.Run("Encrypt RTP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_80) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, encryptedRTPPacket, actualEncrypted) + }) + }) + + t.Run("Decrypt RTP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_80) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, encryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRTPPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_80, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) + assert.NoError(t, err) + assert.Equal(t, addMkiToAesPacket(encryptedRTPPacket, mki, 10), actualEncrypted) + }) + }) + + t.Run("Decrypt RTP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_80, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTP(nil, addMkiToAesPacket(encryptedRTPPacket, mki, 10), nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRTPPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTCP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_80) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, decryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, encryptedRtcpPacket, actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_80) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, encryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRtcpPacket, actualDecrypted) + }) + }) + + t.Run("Encrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_80, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualEncrypted, err := ctx.EncryptRTCP(nil, decryptedRtcpPacket, nil) + assert.NoError(t, err) + assert.Equal(t, addMkiToAesPacket(encryptedRtcpPacket, mki, 10), actualEncrypted) + }) + }) + + t.Run("Decrypt RTCP with MKI", func(t *testing.T) { + ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_80, MasterKeyIndicator(mki)) + assert.NoError(t, err) + + t.Run("New Allocation", func(t *testing.T) { + actualDecrypted, err := ctx.DecryptRTCP(nil, addMkiToAesPacket(encryptedRtcpPacket, mki, 10), nil) + assert.NoError(t, err) + assert.Equal(t, decryptedRtcpPacket, actualDecrypted) + }) + }) +} diff --git a/srtp_test.go b/srtp_test.go index c27f598..05324a0 100644 --- a/srtp_test.go +++ b/srtp_test.go @@ -884,3 +884,161 @@ func TestRTPBurstLossWithSetROC(t *testing.T) { }) } } + +func TestRTPInvalidMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + encryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + out, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + if _, err := decryptContext.DecryptRTP(nil, out, nil); err == nil { + t.Errorf("Managed to decrypt with incorrect MKI for packet with SeqNum: %d", testCase.sequenceNumber) + } else { + assert.ErrorIs(t, err, ErrMKINotFound) + } + } +} + +func TestRTPHandleMultipleMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + masterKey2 := []byte{0xff, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt2 := []byte{0xff, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext1, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + + encryptContext2, err := CreateContext(masterKey2, masterSalt2, profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + decryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = decryptContext.AddCipherForMKI(mki2, masterKey2, masterSalt2) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + encrypted1, err := encryptContext1.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + encrypted2, err := encryptContext2.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + decrypted1, err := decryptContext.DecryptRTP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext.DecryptRTP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, pktRaw, decrypted1) + assert.Equal(t, pktRaw, decrypted2) + } +} + +func TestRTPSwitchMKI(t *testing.T) { + mki1 := []byte{0x01, 0x02, 0x03, 0x04} + mki2 := []byte{0x02, 0x03, 0x04, 0x05} + + masterKey2 := []byte{0xff, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} + masterSalt2 := []byte{0xff, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} + + encryptContext, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + err = encryptContext.AddCipherForMKI(mki2, masterKey2, masterSalt2) + if err != nil { + t.Fatal(err) + } + + decryptContext1, err := buildTestContext(profileCTR, MasterKeyIndicator(mki1)) + if err != nil { + t.Fatal(err) + } + decryptContext2, err := CreateContext(masterKey2, masterSalt2, profileCTR, MasterKeyIndicator(mki2)) + if err != nil { + t.Fatal(err) + } + + for _, testCase := range rtpTestCases() { + pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} + pktRaw, err := pkt.Marshal() + if err != nil { + t.Fatal(err) + } + + encrypted1, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + err = encryptContext.SetSendMKI(mki2) + if err != nil { + t.Fatal(err) + } + + encrypted2, err := encryptContext.EncryptRTP(nil, pktRaw, nil) + if err != nil { + t.Fatal(err) + } + + assert.NotEqual(t, encrypted1, encrypted2) + + decrypted1, err := decryptContext1.DecryptRTP(nil, encrypted1, nil) + if err != nil { + t.Fatal(err) + } + decrypted2, err := decryptContext2.DecryptRTP(nil, encrypted2, nil) + if err != nil { + t.Fatal(err) + } + + assert.Equal(t, pktRaw, decrypted1) + assert.Equal(t, pktRaw, decrypted2) + + err = encryptContext.SetSendMKI(mki1) + if err != nil { + t.Fatal(err) + } + } +}