Skip to content

Commit

Permalink
Add tests for Serf
Browse files Browse the repository at this point in the history
  • Loading branch information
ohkinozomu committed Aug 8, 2024
1 parent 7c0cd02 commit c14cf7a
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 27 deletions.
23 changes: 7 additions & 16 deletions cluster/agent_test.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,32 @@
package cluster

import (
"net"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/wind-c/comqtt/v2/cluster/log"
"github.com/wind-c/comqtt/v2/cluster/utils"
"github.com/wind-c/comqtt/v2/config"
)

func getFreePort() (int, error) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
return 0, err
}
defer listener.Close()
return listener.Addr().(*net.TCPAddr).Port, nil
}

func TestCluster(t *testing.T) {
log.Init(log.DefaultOptions())

bindPort1, err := getFreePort()
bindPort1, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node1")
raftPort1, err := getFreePort()
raftPort1, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node1 Raft")

bindPort2, err := getFreePort()
bindPort2, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node2")
raftPort2, err := getFreePort()
raftPort2, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node2 Raft")

bindPort3, err := getFreePort()
bindPort3, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node3")
raftPort3, err := getFreePort()
raftPort3, err := utils.GetFreePort()
require.NoError(t, err, "Failed to get free port for node3 Raft")

members := []string{
Expand Down
4 changes: 2 additions & 2 deletions cluster/discovery/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ package discovery

import (
"encoding/json"
"github.com/wind-c/comqtt/v2/mqtt"
"net"
"os"
"strconv"

"github.com/wind-c/comqtt/v2/mqtt"
)

const (
Expand All @@ -31,7 +32,6 @@ type Node interface {
BindMqttServer(server *mqtt.Server)
LocalAddr() string
LocalName() string
NumMembers() int
Members() []Member
EventChan() <-chan *Event
SendToNode(nodeName string, msg []byte) error
Expand Down
9 changes: 2 additions & 7 deletions cluster/discovery/serf/membership.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strconv"

"github.com/hashicorp/logutils"
"github.com/hashicorp/memberlist"
"github.com/hashicorp/serf/serf"
mb "github.com/wind-c/comqtt/v2/cluster/discovery"
"github.com/wind-c/comqtt/v2/cluster/log"
Expand Down Expand Up @@ -99,8 +98,8 @@ func (m *Membership) EventChan() <-chan *mb.Event {
return m.eventCh
}

func (m *Membership) NumMembers() int {
return m.serf.NumNodes()
func (m *Membership) numMembers() int {
return len(m.aliveMembers())
}

func (m *Membership) LocalName() string {
Expand Down Expand Up @@ -195,10 +194,6 @@ func (m *Membership) eventLoop() {
}
}

func (m *Membership) send(to memberlist.Address, msg []byte) error {
return m.serf.Memberlist().SendToAddress(to, msg)
}

// SendToOthers send message to all nodes except yourself
func (m *Membership) SendToOthers(msg []byte) {
m.Broadcast(msg)
Expand Down
169 changes: 169 additions & 0 deletions cluster/discovery/serf/membership_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package serf

import (
"os"
"strconv"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/wind-c/comqtt/v2/cluster/log"
"github.com/wind-c/comqtt/v2/cluster/utils"
"github.com/wind-c/comqtt/v2/config"
)

func TestMain(m *testing.M) {
log.Init(log.DefaultOptions())
code := m.Run()
os.Exit(code)
}

func TestJoin_Leave_NumMembers(t *testing.T) {
bindPort1, err := utils.GetFreePort()
assert.NoError(t, err)
conf1 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort1,
NodeName: "test-node-1",
}
inboundMsgCh1 := make(chan []byte)
membership1 := New(conf1, inboundMsgCh1)
err = membership1.Setup()
assert.NoError(t, err)
defer membership1.Stop()

assert.Equal(t, 1, membership1.numMembers())

bindPort2, err := utils.GetFreePort()
assert.NoError(t, err)
conf2 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort2,
NodeName: "test-node-2",
}
inboundMsgCh2 := make(chan []byte)
membership2 := New(conf2, inboundMsgCh2)
err = membership2.Setup()
assert.NoError(t, err)
defer membership2.Stop()

numJoined, err := membership2.Join([]string{"127.0.0.1:" + strconv.Itoa(bindPort1)})
assert.NoError(t, err)
assert.Equal(t, numJoined, 1)
assert.Equal(t, 2, membership1.numMembers())
assert.Equal(t, 2, membership2.numMembers())

t.Log("Leave node 2")
err = membership2.Leave()
assert.NoError(t, err)

time.Sleep(5 * time.Second)

assert.Equal(t, 1, membership1.numMembers())
}

func TestSendToNode(t *testing.T) {
bindPort1, err := utils.GetFreePort()
assert.NoError(t, err)
bindPort2, err := utils.GetFreePort()
assert.NoError(t, err)

conf1 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort1,
NodeName: "test-node-1",
}
conf2 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort2,
NodeName: "test-node-2",
Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)},
}
inboundMsgCh1 := make(chan []byte)
inboundMsgCh2 := make(chan []byte)

membership1 := New(conf1, inboundMsgCh1)
err = membership1.Setup()
defer membership1.Stop()
assert.NoError(t, err)

membership2 := New(conf2, inboundMsgCh2)
err = membership2.Setup()
defer membership2.Stop()
assert.NoError(t, err)

time.Sleep(3 * time.Second)

err = membership1.SendToNode("test-node-2", []byte("test message"))
assert.NoError(t, err)

select {
case msg := <-inboundMsgCh2:
assert.Equal(t, []byte("test message"), msg)
case <-time.After(5 * time.Second):
t.Fatal("Did not receive the message in membership2")
}
}

func TestSendToOthers(t *testing.T) {
bindPort1, err := utils.GetFreePort()
assert.NoError(t, err)
bindPort2, err := utils.GetFreePort()
assert.NoError(t, err)
bindPort3, err := utils.GetFreePort()
assert.NoError(t, err)

conf1 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort1,
NodeName: "test-node-1",
}
conf2 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort2,
NodeName: "test-node-2",
Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)},
}
conf3 := &config.Cluster{
BindAddr: "127.0.0.1",
BindPort: bindPort3,
NodeName: "test-node-3",
Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)},
}
inboundMsgCh1 := make(chan []byte)
inboundMsgCh2 := make(chan []byte)
inboundMsgCh3 := make(chan []byte)

membership1 := New(conf1, inboundMsgCh1)
err = membership1.Setup()
defer membership1.Stop()
assert.NoError(t, err)

membership2 := New(conf2, inboundMsgCh2)
err = membership2.Setup()
defer membership2.Stop()
assert.NoError(t, err)

membership3 := New(conf3, inboundMsgCh3)
err = membership3.Setup()
defer membership3.Stop()
assert.NoError(t, err)

time.Sleep(3 * time.Second)

membership1.SendToOthers([]byte("test message"))

select {
case msg := <-inboundMsgCh2:
assert.Equal(t, []byte("test message"), msg)
case <-time.After(5 * time.Second):
t.Fatal("Did not receive the message in membership2")
}

select {
case msg := <-inboundMsgCh3:
assert.Equal(t, []byte("test message"), msg)
case <-time.After(5 * time.Second):
t.Fatal("Did not receive the message in membership3")
}
}
14 changes: 12 additions & 2 deletions cluster/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ package utils

import (
"fmt"
"github.com/hashicorp/go-sockaddr"
"net"
"os"
"reflect"
"strings"
"testing"

"github.com/satori/go.uuid"
"github.com/hashicorp/go-sockaddr"

uuid "github.com/satori/go.uuid"
)

func InArray(val interface{}, array interface{}) bool {
Expand Down Expand Up @@ -151,3 +152,12 @@ func PathExists(path string) bool {
}
return true
}

func GetFreePort() (int, error) {
listener, err := net.Listen("tcp", ":0")
if err != nil {
return 0, err
}
defer listener.Close()
return listener.Addr().(*net.TCPAddr).Port, nil
}

0 comments on commit c14cf7a

Please sign in to comment.