Skip to content

Commit

Permalink
减少GTP协议握手内存与流量
Browse files Browse the repository at this point in the history
  • Loading branch information
pangdogs committed Sep 20, 2024
1 parent 27a3e2a commit 0d25d1d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 30 deletions.
34 changes: 19 additions & 15 deletions plugins/gate/acceptor_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"git.golaxy.org/core/utils/uid"
Expand Down Expand Up @@ -59,7 +60,7 @@ func (acc *_Acceptor) handshake(ctx context.Context, conn net.Conn) (*_Session,
var cs gtp.CipherSuite
var cm gtp.Compression
var cliRandom, servRandom []byte
var cliHelloBytes, servHelloBytes []byte
var cliHelloHash, servHelloHash [sha256.Size]byte
var continueFlow, encryptionFlow, authFlow bool
var session *_Session

Expand All @@ -70,12 +71,6 @@ func (acc *_Acceptor) handshake(ctx context.Context, conn net.Conn) (*_Session,
if servRandom != nil {
binaryutil.BytesPool.Put(servRandom)
}
if cliHelloBytes != nil {
binaryutil.BytesPool.Put(cliHelloBytes)
}
if servHelloBytes != nil {
binaryutil.BytesPool.Put(servHelloBytes)
}
}()

// 与客户端互相hello
Expand Down Expand Up @@ -176,21 +171,30 @@ func (acc *_Acceptor) handshake(ctx context.Context, conn net.Conn) (*_Session,

// 开启加密时,记录双方hello数据,用于ecdh后加密验证
if encryptionFlow {
cliHelloBytes = binaryutil.BytesPool.Get(cliHello.Msg.Size())
if _, err := cliHello.Msg.Read(cliHelloBytes); err != nil {
h := sha256.New()

hashBuff := binaryutil.BytesPool.Get(8 * 1024)
defer binaryutil.BytesPool.Put(hashBuff)

h.Reset()
_, err := io.CopyBuffer(h, cliHello.Msg, hashBuff)
if err != nil {
return transport.Event[gtp.MsgHello]{}, &transport.RstError{
Code: gtp.Code_EncryptFailed,
Message: err.Error(),
}
}
copy(cliHelloHash[:], h.Sum(nil))

servHelloBytes = binaryutil.BytesPool.Get(servHello.Msg.Size())
if _, err := servHello.Msg.Read(servHelloBytes); err != nil {
h.Reset()
_, err = io.CopyBuffer(h, servHello.Msg, hashBuff)
if err != nil {
return transport.Event[gtp.MsgHello]{}, &transport.RstError{
Code: gtp.Code_EncryptFailed,
Message: err.Error(),
}
}
copy(servHelloHash[:], h.Sum(nil))
}

return servHello, nil
Expand All @@ -201,7 +205,7 @@ func (acc *_Acceptor) handshake(ctx context.Context, conn net.Conn) (*_Session,

// 开启加密时,与客户端交换秘钥
if encryptionFlow {
err = acc.secretKeyExchange(ctx, handshake, cs, cm, cliRandom, servRandom, cliHelloBytes, servHelloBytes, session.GetId())
err = acc.secretKeyExchange(ctx, handshake, cs, cm, cliRandom, servRandom, cliHelloHash, servHelloHash, session.GetId())
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -325,7 +329,7 @@ func (acc *_Acceptor) handshake(ctx context.Context, conn net.Conn) (*_Session,

// secretKeyExchange 秘钥交换过程
func (acc *_Acceptor) secretKeyExchange(ctx context.Context, handshake *transport.HandshakeProtocol, cs gtp.CipherSuite, cm gtp.Compression,
cliRandom, servRandom, cliHelloBytes, servHelloBytes []byte, sessionId uid.Id) (err error) {
cliRandom, servRandom []byte, cliHelloHash, servHelloHash [sha256.Size]byte, sessionId uid.Id) (err error) {
// 控制协议
ctrl := transport.CtrlProtocol{
Transceiver: handshake.Transceiver,
Expand Down Expand Up @@ -486,7 +490,7 @@ func (acc *_Acceptor) secretKeyExchange(ctx context.Context, handshake *transpor
encryptionModule[0] = codec.NewEncryptionModule(cipher[0], padding[0], fetchNonce[0])

// 加密hello消息
encryptedHello, err = encryptionModule[0].Transforming(nil, servHelloBytes)
encryptedHello, err = encryptionModule[0].Transforming(nil, servHelloHash[:])
if err != nil {
return transport.Event[gtp.MsgChangeCipherSpec]{}, &transport.RstError{
Code: gtp.Code_EncryptFailed,
Expand Down Expand Up @@ -519,7 +523,7 @@ func (acc *_Acceptor) secretKeyExchange(ctx context.Context, handshake *transpor
}
defer decryptedHello.Release()

if bytes.Compare(decryptedHello.Data(), cliHelloBytes) != 0 {
if bytes.Compare(decryptedHello.Data(), cliHelloHash[:]) != 0 {
return &transport.RstError{
Code: gtp.Code_EncryptFailed,
Message: "verify hello failed",
Expand Down
34 changes: 19 additions & 15 deletions plugins/gate/cli/connector_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"bytes"
"context"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"git.golaxy.org/core/utils/uid"
Expand Down Expand Up @@ -60,7 +61,7 @@ func (ctor *_Connector) handshake(ctx context.Context, conn net.Conn, client *Cl
cs := ctor.options.EncCipherSuite
cm := ctor.options.Compression
var cliRandom, servRandom []byte
var cliHelloBytes, servHelloBytes []byte
var cliHelloHash, servHelloHash [sha256.Size]byte
var continueFlow, encryptionFlow, authFlow bool

defer func() {
Expand All @@ -70,12 +71,6 @@ func (ctor *_Connector) handshake(ctx context.Context, conn net.Conn, client *Cl
if servRandom != nil {
binaryutil.BytesPool.Put(servRandom)
}
if cliHelloBytes != nil {
binaryutil.BytesPool.Put(cliHelloBytes)
}
if servHelloBytes != nil {
binaryutil.BytesPool.Put(servHelloBytes)
}
}()

// 生成客户端随机数
Expand Down Expand Up @@ -127,15 +122,24 @@ func (ctor *_Connector) handshake(ctx context.Context, conn net.Conn, client *Cl
copy(servRandom, servHello.Msg.Random)

// 记录双方hello数据,用于ecdh后加密验证
cliHelloBytes = binaryutil.BytesPool.Get(cliHello.Msg.Size())
if _, err := cliHello.Msg.Read(cliHelloBytes); err != nil {
h := sha256.New()

hashBuff := binaryutil.BytesPool.Get(8 * 1024)
defer binaryutil.BytesPool.Put(hashBuff)

h.Reset()
_, err := io.CopyBuffer(h, cliHello.Msg, hashBuff)
if err != nil {
return err
}
copy(cliHelloHash[:], h.Sum(nil))

servHelloBytes = binaryutil.BytesPool.Get(servHello.Msg.Size())
if _, err := servHello.Msg.Read(servHelloBytes); err != nil {
h.Reset()
_, err = io.CopyBuffer(h, servHello.Msg, hashBuff)
if err != nil {
return err
}
copy(servHelloHash[:], h.Sum(nil))
}

return nil
Expand All @@ -146,7 +150,7 @@ func (ctor *_Connector) handshake(ctx context.Context, conn net.Conn, client *Cl

// 开启加密时,与服务端交换秘钥
if encryptionFlow {
err = ctor.secretKeyExchange(ctx, handshake, cs, cm, cliRandom, servRandom, cliHelloBytes, servHelloBytes, sessionId)
err = ctor.secretKeyExchange(ctx, handshake, cs, cm, cliRandom, servRandom, cliHelloHash, servHelloHash, sessionId)
if err != nil {
return err
}
Expand Down Expand Up @@ -234,7 +238,7 @@ func (ctor *_Connector) handshake(ctx context.Context, conn net.Conn, client *Cl

// secretKeyExchange 秘钥交换过程
func (ctor *_Connector) secretKeyExchange(ctx context.Context, handshake *transport.HandshakeProtocol, cs gtp.CipherSuite, cm gtp.Compression,
cliRandom, servRandom, cliHelloBytes, servHelloBytes []byte, sessionId uid.Id) error {
cliRandom, servRandom []byte, cliHelloHash, servHelloHash [sha256.Size]byte, sessionId uid.Id) error {
// 选择秘钥交换函数,并与客户端交换秘钥
switch cs.SecretKeyExchange {
case gtp.SecretKeyExchange_ECDHE:
Expand Down Expand Up @@ -356,7 +360,7 @@ func (ctor *_Connector) secretKeyExchange(ctx context.Context, handshake *transp
}
defer decryptedHello.Release()

if bytes.Compare(decryptedHello.Data(), servHelloBytes) != 0 {
if bytes.Compare(decryptedHello.Data(), servHelloHash[:]) != 0 {
return transport.Event[gtp.MsgChangeCipherSpec]{}, errors.New("verify hello failed")
}
}
Expand All @@ -368,7 +372,7 @@ func (ctor *_Connector) secretKeyExchange(ctx context.Context, handshake *transp
// 加密hello消息
if verifyEncryption {
var err error
encryptedHello, err = encryptionModule[0].Transforming(nil, cliHelloBytes)
encryptedHello, err = encryptionModule[0].Transforming(nil, cliHelloHash[:])
if err != nil {
return transport.Event[gtp.MsgChangeCipherSpec]{}, fmt.Errorf("encrypt hello failed, %s", err)
}
Expand Down

0 comments on commit 0d25d1d

Please sign in to comment.