Skip to content

Commit

Permalink
Add support for Master Key Indicator
Browse files Browse the repository at this point in the history
This adds support for Master Key Indicator (MKI). It is used to select
one of pre-configured SRTP/SRTCP encryption keys.

To use it, Context has to be created with MasterKeyIndicator option,
it specifies MKI for master key and salt passed to CreateContext.
Additional master keys/salts with their MKIs can be added using
AddCipherForMKI. To remove MKIs, use RemoveMKI.

All MKIs must have the same length, and use the same length of master
key and salt - they use the same crypto profile.

SRTP/SRTCP packets by default are encrypted using first key/salt/MKI.
To select other key/salt/MKI, use SetSendMKI.

key/salt/MKI used for decryption are chosen automatically, using MKI
sent in encrypted SRTP/SRTCP packet.
  • Loading branch information
daniel@poradnik-webmastera.com authored and daniel@poradnik-webmastera.com committed Jun 29, 2024
1 parent 19b0fa0 commit 0704101
Show file tree
Hide file tree
Showing 13 changed files with 1,018 additions and 55 deletions.
104 changes: 78 additions & 26 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package srtp

import (
"bytes"
"fmt"

"github.com/pion/transport/v3/replaydetector"
Expand Down Expand Up @@ -56,6 +57,10 @@ type Context struct {

newSRTCPReplayDetector func() replaydetector.ReplayDetector
newSRTPReplayDetector func() replaydetector.ReplayDetector

profile ProtectionProfile
sendMKI []byte
mkis map[string]srtpCipher
}

// CreateContext creates a new SRTP Context.
Expand All @@ -66,52 +71,99 @@ type Context struct {
//
// decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256))
func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) {
keyLen, err := profile.keyLen()
c = &Context{
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
profile: profile,
mkis: map[string]srtpCipher{},
}

for _, o := range append(
[]ContextOption{ // Default options
SRTPNoReplayProtection(),
SRTCPNoReplayProtection(),
},
opts..., // User specified options
) {
if errOpt := o(c); errOpt != nil {
return nil, errOpt
}
}

err = c.AddCipherForMKI(c.sendMKI, masterKey, masterSalt)
if err != nil {
return nil, err
}
c.cipher = c.mkis[string(c.sendMKI)]

saltLen, err := profile.saltLen()
return c, nil
}

// AddCipherForMKI adds new MKI with associated masker key and salt. Context must be created with MasterKeyIndicator option to enable MKI support.
func (c *Context) AddCipherForMKI(mki, masterKey, masterSalt []byte) error {
if len(mki) != len(c.sendMKI) {
return errInvalidMKILength
}
if _, ok := c.mkis[string(mki)]; ok {
return errMKIAlreadyInUse
}

keyLen, err := c.profile.keyLen()
if err != nil {
return nil, err
return err
}

if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
saltLen, err := c.profile.saltLen()
if err != nil {
return err
}

c = &Context{
srtpSSRCStates: map[uint32]*srtpSSRCState{},
srtcpSSRCStates: map[uint32]*srtcpSSRCState{},
if masterKeyLen := len(masterKey); masterKeyLen != keyLen {
return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen)
} else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen {
return fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen)
}

switch profile {
var cipher srtpCipher
switch c.profile {
case ProtectionProfileAeadAes128Gcm, ProtectionProfileAeadAes256Gcm:
c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt)
cipher, err = newSrtpCipherAeadAesGcm(c.profile, masterKey, masterSalt, mki)
case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80:
c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt)
cipher, err = newSrtpCipherAesCmHmacSha1(c.profile, masterKey, masterSalt, mki)
default:
return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile)
return fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, c.profile)
}
if err != nil {
return nil, err
return err
}

for _, o := range append(
[]ContextOption{ // Default options
SRTPNoReplayProtection(),
SRTCPNoReplayProtection(),
},
opts..., // User specified options
) {
if errOpt := o(c); errOpt != nil {
return nil, errOpt
}
c.mkis[string(mki)] = cipher
return nil
}

// RemoveMKI removes one of MKIs. You cannot remove last MKI and one used for encrypting RTP/RTCP packets.
// Operation is not thread-safe, you need to provide synchronization with decrypting packets.
func (c *Context) RemoveMKI(mki []byte) error {
if _, ok := c.mkis[string(mki)]; !ok {
return ErrMKINotFound
}
if bytes.Equal(mki, c.sendMKI) {
return errMKIAlreadyInUse
}
delete(c.mkis, string(mki))
return nil
}

return c, nil
// SetSendMKI switches MKI and cipher used for encrypting RTP/RTCP packets.
// Operation is not thread-safe, you need to provide synchronization with encrypting packets.
func (c *Context) SetSendMKI(mki []byte) error {
cipher, ok := c.mkis[string(mki)]
if !ok {
return ErrMKINotFound
}
c.sendMKI = mki
c.cipher = cipher
return nil
}

// https://tools.ietf.org/html/rfc3550#appendix-A.1
Expand Down
122 changes: 122 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package srtp

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestContextROC(t *testing.T) {
Expand Down Expand Up @@ -44,3 +46,123 @@ func TestContextIndex(t *testing.T) {
t.Errorf("Index is set to 100, but returned %d", index)
}
}

func TestContextWithoutMKI(t *testing.T) {
c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR)
if err != nil {
t.Fatal(err)
}

err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 4), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.SetSendMKI(nil)
assert.NoError(t, err)

err = c.SetSendMKI(make([]byte, 0))
assert.NoError(t, err)

err = c.RemoveMKI(nil)
assert.Error(t, err)

err = c.RemoveMKI(make([]byte, 0))
assert.Error(t, err)

err = c.RemoveMKI(make([]byte, 2))
assert.Error(t, err)
}

func TestAddMKIToContextWithMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.AddCipherForMKI(nil, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 0), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(make([]byte, 3), make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(mki1, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)

err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
assert.Error(t, err)
}

func TestContextSetSendMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.SetSendMKI(mki1)
assert.NoError(t, err)

err = c.SetSendMKI(mki2)
assert.NoError(t, err)

err = c.SetSendMKI(make([]byte, 4))
assert.Error(t, err)
}

func TestContextRemoveMKI(t *testing.T) {
mki1 := []byte{1, 2, 3, 4}
mki2 := []byte{2, 3, 4, 5}
mki3 := []byte{3, 4, 5, 6}

c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR, MasterKeyIndicator(mki1))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki2, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}
err = c.AddCipherForMKI(mki3, make([]byte, 16), make([]byte, 14))
if err != nil {
t.Fatal(err)
}

err = c.RemoveMKI(make([]byte, 4))
assert.Error(t, err)

err = c.RemoveMKI(mki1)
assert.Error(t, err)

err = c.SetSendMKI(mki3)
assert.NoError(t, err)

err = c.RemoveMKI(mki1)
assert.NoError(t, err)

err = c.RemoveMKI(mki2)
assert.NoError(t, err)

err = c.RemoveMKI(mki3)
assert.Error(t, err)
}
9 changes: 7 additions & 2 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
)

var (
ErrFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
ErrMKINotFound = errors.New("MKI not found")

errDuplicated = errors.New("duplicated packet")
errShortSrtpMasterKey = errors.New("SRTP master key is not long enough")
errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough")
Expand All @@ -17,12 +20,14 @@ var (
errExporterWrongLabel = errors.New("exporter called with wrong label")
errNoConfig = errors.New("no config provided")
errNoConn = errors.New("no conn provided")
errFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
errTooShortRTCP = errors.New("packet is too short to be rtcp packet")
errTooShortRTP = errors.New("packet is too short to be RTP packet")
errTooShortRTCP = errors.New("packet is too short to be RTCP packet")
errPayloadDiffers = errors.New("payload differs")
errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed")
errBadIVLength = errors.New("bad iv length in xorBytesCTR")
errExceededMaxPackets = errors.New("exceeded the maximum number of packets")
errMKIAlreadyInUse = errors.New("MKI already in use")
errInvalidMKILength = errors.New("invalid MKI length")

errStreamNotInited = errors.New("stream has not been inited, unable to close")
errStreamAlreadyClosed = errors.New("stream is already closed")
Expand Down
11 changes: 11 additions & 0 deletions option.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,14 @@ type nopReplayDetector struct{}
func (s *nopReplayDetector) Check(uint64) (func() bool, bool) {
return func() bool { return true }, true
}

// MasterKeyIndicator sets MKI for RTP and RTCP.
func MasterKeyIndicator(mki []byte) ContextOption {
return func(c *Context) error {
if len(mki) > 0 {
c.sendMKI = make([]byte, len(mki))
copy(c.sendMKI, mki)
}
return nil
}
}
21 changes: 18 additions & 3 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
if err != nil {
return nil, err
}
tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize)
mkiLen := len(c.sendMKI)
tailOffset := len(encrypted) - (authTagLen + mkiLen + srtcpIndexSize)

if tailOffset < aeadAuthTagLen {
if tailOffset < aeadAuthTagLen + 8 {
return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted))
} else if isEncrypted := encrypted[tailOffset] >> 7; isEncrypted == 0 {
return out, nil
Expand All @@ -40,7 +41,17 @@ func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) {
return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index}
}

out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc)
cipher := c.cipher
if len(c.mkis) > 0 {
// Find cipher for MKI
actualMKI := c.cipher.getMKI(encrypted, false)
cipher, ok = c.mkis[string(actualMKI)]
if !ok {
return nil, ErrMKINotFound
}
}

out, err = cipher.decryptRTCP(out, encrypted, index, ssrc)
if err != nil {
return nil, err
}
Expand All @@ -63,6 +74,10 @@ func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byt
}

func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
if len(decrypted) < 8 {
return nil, errTooShortRTCP
}

ssrc := binary.BigEndian.Uint32(decrypted[4:])
s := c.getSRTCPSSRCState(ssrc)

Expand Down
Loading

0 comments on commit 0704101

Please sign in to comment.