Skip to content

Commit

Permalink
Upgrade to github.com/pion/dtls/v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielius1922 committed Sep 13, 2024
1 parent 1afdeb7 commit 0e51f3a
Show file tree
Hide file tree
Showing 25 changed files with 175 additions and 348 deletions.
3 changes: 0 additions & 3 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ linters-settings:
enable:
- nilness
- shadow
gomoddirectives:
replace-allow-list:
- github.com/pion/dtls/v2

linters:
enable:
Expand Down
4 changes: 2 additions & 2 deletions dtls/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"fmt"
"time"

"github.com/pion/dtls/v2"
dtlsnet "github.com/pion/dtls/v2/pkg/net"
"github.com/pion/dtls/v3"
dtlsnet "github.com/pion/dtls/v3/pkg/net"
"github.com/plgd-dev/go-coap/v3/dtls/server"
"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
Expand Down
121 changes: 41 additions & 80 deletions dtls/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"testing"
"time"

piondtls "github.com/pion/dtls/v2"
piondtls "github.com/pion/dtls/v3"
"github.com/plgd-dev/go-coap/v3/dtls"
"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
Expand All @@ -20,7 +20,6 @@ import (
coapNet "github.com/plgd-dev/go-coap/v3/net"
"github.com/plgd-dev/go-coap/v3/net/responsewriter"
"github.com/plgd-dev/go-coap/v3/options"
"github.com/plgd-dev/go-coap/v3/options/config"
"github.com/plgd-dev/go-coap/v3/pkg/runner/periodic"
"github.com/plgd-dev/go-coap/v3/udp/client"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -123,7 +122,7 @@ func TestConnGet(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
got, err := cc.Get(ctx, tt.args.path, tt.args.opts...)
if tt.wantErr {
Expand Down Expand Up @@ -216,7 +215,7 @@ func TestConnGetSeparateMessage(t *testing.T) {
require.NoError(t, errC)
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()

req, err := cc.NewGetRequest(ctx, "/a")
Expand Down Expand Up @@ -340,7 +339,7 @@ func TestConnPost(t *testing.T) {
require.NoError(t, errC)
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
got, err := cc.Post(ctx, tt.args.path, tt.args.contentFormat, tt.args.payload, tt.args.opts...)
if tt.wantErr {
Expand Down Expand Up @@ -475,7 +474,7 @@ func TestConnPut(t *testing.T) {
require.NoError(t, errC)
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
got, err := cc.Put(ctx, tt.args.path, tt.args.contentFormat, tt.args.payload, tt.args.opts...)
if tt.wantErr {
Expand Down Expand Up @@ -590,7 +589,7 @@ func TestConnDelete(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
got, err := cc.Delete(ctx, tt.args.path, tt.args.opts...)
if tt.wantErr {
Expand Down Expand Up @@ -654,58 +653,12 @@ func TestConnPing(t *testing.T) {
require.NoError(t, err)
}

func TestConnHandeShakeFailure(t *testing.T) {
dtlsCfg := &piondtls.Config{
PSK: func(hint []byte) ([]byte, error) {
fmt.Printf("Hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
PSKIdentityHint: []byte("Pion DTLS Server"),
CipherSuites: []piondtls.CipherSuiteID{piondtls.TLS_PSK_WITH_AES_128_CCM_8},
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(context.Background(), 1*time.Second)
},
}
l, err := coapNet.NewDTLSListener("udp", "", dtlsCfg)
require.NoError(t, err)
defer func() {
errC := l.Close()
require.NoError(t, errC)
}()
var wg sync.WaitGroup
defer wg.Wait()

s := dtls.NewServer()
defer s.Stop()

wg.Add(1)
go func() {
defer wg.Done()
errS := s.Serve(l)
assert.NoError(t, errS)
}()

dtlsCfgClient := &piondtls.Config{
PSK: func(hint []byte) ([]byte, error) {
fmt.Printf("Hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x24}, nil
},
PSKIdentityHint: []byte("Pion DTLS Client"),
CipherSuites: []piondtls.CipherSuiteID{piondtls.TLS_PSK_WITH_AES_128_CCM_8},
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(context.Background(), 1*time.Second)
},
}
_, err = dtls.Dial(l.Addr().String(), dtlsCfgClient)
require.Error(t, err)
}

func TestClientInactiveMonitor(t *testing.T) {
var inactivityDetected atomic.Bool

ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
serverCgf, clientCgf, _, err := createDTLSConfig(ctx)
serverCgf, clientCgf, _, err := createDTLSConfig()
require.NoError(t, err)

ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf)
Expand Down Expand Up @@ -745,7 +698,9 @@ func TestClientInactiveMonitor(t *testing.T) {
serverWg.Wait()
}()

cc, err := dtls.Dial(ld.Addr().String(), clientCgf,
cc, err := dtls.Dial(
ld.Addr().String(),
clientCgf,
options.WithInactivityMonitor(100*time.Millisecond, func(cc *client.Conn) {
require.False(t, inactivityDetected.Load())
inactivityDetected.Store(true)
Expand Down Expand Up @@ -774,65 +729,71 @@ func TestClientInactiveMonitor(t *testing.T) {
func TestClientKeepAliveMonitor(t *testing.T) {
var inactivityDetected atomic.Bool

ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
serverCgf, clientCgf, _, err := createDTLSConfig(ctx)
serverCgf, clientCgf, _, err := createDTLSConfig()
require.NoError(t, err)

ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf)
require.NoError(t, err)
defer func() {
errC := ld.Close()
require.NoError(t, errC)
}()

ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()

checkClose := semaphore.NewWeighted(1)
err = checkClose.Acquire(ctx, 1)
checkClose := semaphore.NewWeighted(2)
err = checkClose.Acquire(ctx, 2)
require.NoError(t, err)

sd := dtls.NewServer(
options.WithOnNewConn(func(cc *client.Conn) {
cc.AddOnClose(func() {
checkClose.Release(1)
})
}),
options.WithPeriodicRunner(periodic.New(ctx.Done(), time.Millisecond*10)),
options.WithRequestMonitor(func(_ *client.Conn, _ *pool.Message) (bool, error) {
// lets drop all messages, this will trigger keep alive because of inactivity
return true, nil
}),
)

var serverWg sync.WaitGroup
serverWg.Add(1)
go func() {
defer serverWg.Done()
for {
c, errA := ld.AcceptWithContext(ctx)
if errA != nil {
if errors.Is(errA, coapNet.ErrListenerIsClosed) {
return
}
}
defer c.Close()
assert.NoError(t, errA)
}
errS := sd.Serve(ld)
assert.NoError(t, errS)
}()
defer func() {
errC := ld.Close()
require.NoError(t, errC)
sd.Stop()
serverWg.Wait()
}()

cc, err := dtls.Dial(
ld.Addr().String(),
clientCgf,
options.WithKeepAlive(3, 100*time.Millisecond, func(cc *client.Conn) {
t.Log("client - close for inactivity")
require.False(t, inactivityDetected.Load())
inactivityDetected.Store(true)
errC := cc.Close()
require.NoError(t, errC)
}),
options.WithPeriodicRunner(periodic.New(ctx.Done(), time.Millisecond*10)),
options.WithReceivedMessageQueueSize(32),
options.WithProcessReceivedMessageFunc(func(req *pool.Message, cc *client.Conn, handler config.HandlerFunc[*client.Conn]) {
cc.ProcessReceivedMessageWithHandler(req, handler)
}),
)
require.NoError(t, err)
cc.AddOnClose(func() {
t.Log("connection is closed")
checkClose.Release(1)
})

// send ping to create server side connection
ctxPing, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
err = cc.Ping(ctxPing)
require.Error(t, err)
_ = cc.Ping(ctxPing)

err = checkClose.Acquire(ctx, 1)
err = checkClose.Acquire(ctx, 2)
require.NoError(t, err)
require.True(t, inactivityDetected.Load())
}
2 changes: 1 addition & 1 deletion dtls/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"log"
"time"

piondtls "github.com/pion/dtls/v2"
piondtls "github.com/pion/dtls/v3"
"github.com/plgd-dev/go-coap/v3/dtls"
"github.com/plgd-dev/go-coap/v3/net"
)
Expand Down
49 changes: 26 additions & 23 deletions dtls/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,20 @@ func (s *Server) checkAndSetListener(l Listener) error {
s.listenMutex.Lock()
defer s.listenMutex.Unlock()
if s.listen != nil {
return errors.New("server already serve listener")
return errors.New("server already serves listener")
}
s.listen = l
return nil
}

func (s *Server) popListener() Listener {
s.listenMutex.Lock()
defer s.listenMutex.Unlock()
l := s.listen
s.listen = nil
return l
}

func (s *Server) checkAcceptError(err error) bool {
if err == nil {
return true
Expand All @@ -116,7 +124,14 @@ func (s *Server) checkAcceptError(err error) bool {
}
}

func (s *Server) serveConnection(connections *connections.Connections, cc *udpClient.Conn) {
func (s *Server) serveConnection(connections *connections.Connections, rw net.Conn) {
inactivityMonitor := s.cfg.CreateInactivityMonitor()
requestMonitor := s.cfg.RequestMonitor
dtlsConn := coapNet.NewConn(rw)
cc := s.createConn(dtlsConn, inactivityMonitor, requestMonitor)
if s.cfg.OnNewConn != nil {
s.cfg.OnNewConn(cc)
}
connections.Store(cc)
defer connections.Delete(cc)

Expand All @@ -129,16 +144,14 @@ func (s *Server) Serve(l Listener) error {
if s.cfg.BlockwiseSZX > blockwise.SZX1024 {
return errors.New("invalid blockwiseSZX")
}

err := s.checkAndSetListener(l)
if err != nil {
return err
}
defer func() {
s.listenMutex.Lock()
defer s.listenMutex.Unlock()
s.listen = nil
s.Stop()
}()

var wg sync.WaitGroup
defer wg.Wait()

Expand All @@ -158,32 +171,22 @@ func (s *Server) Serve(l Listener) error {
continue
}
wg.Add(1)
var cc *udpClient.Conn
inactivityMonitor := s.cfg.CreateInactivityMonitor()
requestMonitor := s.cfg.RequestMonitor

cc = s.createConn(coapNet.NewConn(rw), inactivityMonitor, requestMonitor)
if s.cfg.OnNewConn != nil {
s.cfg.OnNewConn(cc)
}
go func() {
defer wg.Done()
s.serveConnection(connections, cc)
s.serveConnection(connections, rw)
}()
}
}

// Stop stops server without wait of ends Serve function.
func (s *Server) Stop() {
s.cancel()
s.listenMutex.Lock()
l := s.listen
s.listen = nil
s.listenMutex.Unlock()
if l != nil {
if err := l.Close(); err != nil {
s.cfg.Errors(fmt.Errorf("cannot close listener: %w", err))
}
l := s.popListener()
if l == nil {
return
}
if err := l.Close(); err != nil {
s.cfg.Errors(fmt.Errorf("cannot close listener: %w", err))
}
}

Expand Down
Loading

0 comments on commit 0e51f3a

Please sign in to comment.