diff --git a/cluster/agent_test.go b/cluster/agent_test.go index bfd4813..bc6f922 100644 --- a/cluster/agent_test.go +++ b/cluster/agent_test.go @@ -1,41 +1,32 @@ package cluster import ( - "net" "strconv" "testing" "time" "github.com/stretchr/testify/require" "github.com/wind-c/comqtt/v2/cluster/log" + "github.com/wind-c/comqtt/v2/cluster/utils" "github.com/wind-c/comqtt/v2/config" ) -func getFreePort() (int, error) { - listener, err := net.Listen("tcp", ":0") - if err != nil { - return 0, err - } - defer listener.Close() - return listener.Addr().(*net.TCPAddr).Port, nil -} - func TestCluster(t *testing.T) { log.Init(log.DefaultOptions()) - bindPort1, err := getFreePort() + bindPort1, err := utils.GetFreePort() require.NoError(t, err, "Failed to get free port for node1") - raftPort1, err := getFreePort() + raftPort1, err := utils.GetFreePort() require.NoError(t, err, "Failed to get free port for node1 Raft") - bindPort2, err := getFreePort() + bindPort2, err := utils.GetFreePort() require.NoError(t, err, "Failed to get free port for node2") - raftPort2, err := getFreePort() + raftPort2, err := utils.GetFreePort() require.NoError(t, err, "Failed to get free port for node2 Raft") - bindPort3, err := getFreePort() + bindPort3, err := utils.GetFreePort() require.NoError(t, err, "Failed to get free port for node3") - raftPort3, err := getFreePort() + raftPort3, err := utils.GetFreePort() require.NoError(t, err, "Failed to get free port for node3 Raft") members := []string{ diff --git a/cluster/discovery/node.go b/cluster/discovery/node.go index d2d390b..029ea05 100644 --- a/cluster/discovery/node.go +++ b/cluster/discovery/node.go @@ -6,10 +6,11 @@ package discovery import ( "encoding/json" - "github.com/wind-c/comqtt/v2/mqtt" "net" "os" "strconv" + + "github.com/wind-c/comqtt/v2/mqtt" ) const ( @@ -31,7 +32,6 @@ type Node interface { BindMqttServer(server *mqtt.Server) LocalAddr() string LocalName() string - NumMembers() int Members() []Member EventChan() <-chan *Event SendToNode(nodeName string, msg []byte) error diff --git a/cluster/discovery/serf/membership.go b/cluster/discovery/serf/membership.go index b5dbab6..d06ecc9 100644 --- a/cluster/discovery/serf/membership.go +++ b/cluster/discovery/serf/membership.go @@ -8,7 +8,6 @@ import ( "strconv" "github.com/hashicorp/logutils" - "github.com/hashicorp/memberlist" "github.com/hashicorp/serf/serf" mb "github.com/wind-c/comqtt/v2/cluster/discovery" "github.com/wind-c/comqtt/v2/cluster/log" @@ -99,8 +98,8 @@ func (m *Membership) EventChan() <-chan *mb.Event { return m.eventCh } -func (m *Membership) NumMembers() int { - return m.serf.NumNodes() +func (m *Membership) numMembers() int { + return len(m.aliveMembers()) } func (m *Membership) LocalName() string { @@ -195,10 +194,6 @@ func (m *Membership) eventLoop() { } } -func (m *Membership) send(to memberlist.Address, msg []byte) error { - return m.serf.Memberlist().SendToAddress(to, msg) -} - // SendToOthers send message to all nodes except yourself func (m *Membership) SendToOthers(msg []byte) { m.Broadcast(msg) diff --git a/cluster/discovery/serf/membership_test.go b/cluster/discovery/serf/membership_test.go new file mode 100644 index 0000000..bd357ce --- /dev/null +++ b/cluster/discovery/serf/membership_test.go @@ -0,0 +1,169 @@ +package serf + +import ( + "os" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/wind-c/comqtt/v2/cluster/log" + "github.com/wind-c/comqtt/v2/cluster/utils" + "github.com/wind-c/comqtt/v2/config" +) + +func TestMain(m *testing.M) { + log.Init(log.DefaultOptions()) + code := m.Run() + os.Exit(code) +} + +func TestJoin_Leave_NumMembers(t *testing.T) { + bindPort1, err := utils.GetFreePort() + assert.NoError(t, err) + conf1 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort1, + NodeName: "test-node-1", + } + inboundMsgCh1 := make(chan []byte) + membership1 := New(conf1, inboundMsgCh1) + err = membership1.Setup() + assert.NoError(t, err) + defer membership1.Stop() + + assert.Equal(t, 1, membership1.numMembers()) + + bindPort2, err := utils.GetFreePort() + assert.NoError(t, err) + conf2 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort2, + NodeName: "test-node-2", + } + inboundMsgCh2 := make(chan []byte) + membership2 := New(conf2, inboundMsgCh2) + err = membership2.Setup() + assert.NoError(t, err) + defer membership2.Stop() + + numJoined, err := membership2.Join([]string{"127.0.0.1:" + strconv.Itoa(bindPort1)}) + assert.NoError(t, err) + assert.Equal(t, numJoined, 1) + assert.Equal(t, 2, membership1.numMembers()) + assert.Equal(t, 2, membership2.numMembers()) + + t.Log("Leave node 2") + err = membership2.Leave() + assert.NoError(t, err) + + time.Sleep(5 * time.Second) + + assert.Equal(t, 1, membership1.numMembers()) +} + +func TestSendToNode(t *testing.T) { + bindPort1, err := utils.GetFreePort() + assert.NoError(t, err) + bindPort2, err := utils.GetFreePort() + assert.NoError(t, err) + + conf1 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort1, + NodeName: "test-node-1", + } + conf2 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort2, + NodeName: "test-node-2", + Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)}, + } + inboundMsgCh1 := make(chan []byte) + inboundMsgCh2 := make(chan []byte) + + membership1 := New(conf1, inboundMsgCh1) + err = membership1.Setup() + defer membership1.Stop() + assert.NoError(t, err) + + membership2 := New(conf2, inboundMsgCh2) + err = membership2.Setup() + defer membership2.Stop() + assert.NoError(t, err) + + time.Sleep(3 * time.Second) + + err = membership1.SendToNode("test-node-2", []byte("test message")) + assert.NoError(t, err) + + select { + case msg := <-inboundMsgCh2: + assert.Equal(t, []byte("test message"), msg) + case <-time.After(5 * time.Second): + t.Fatal("Did not receive the message in membership2") + } +} + +func TestSendToOthers(t *testing.T) { + bindPort1, err := utils.GetFreePort() + assert.NoError(t, err) + bindPort2, err := utils.GetFreePort() + assert.NoError(t, err) + bindPort3, err := utils.GetFreePort() + assert.NoError(t, err) + + conf1 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort1, + NodeName: "test-node-1", + } + conf2 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort2, + NodeName: "test-node-2", + Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)}, + } + conf3 := &config.Cluster{ + BindAddr: "127.0.0.1", + BindPort: bindPort3, + NodeName: "test-node-3", + Members: []string{"127.0.0.1:" + strconv.Itoa(bindPort1)}, + } + inboundMsgCh1 := make(chan []byte) + inboundMsgCh2 := make(chan []byte) + inboundMsgCh3 := make(chan []byte) + + membership1 := New(conf1, inboundMsgCh1) + err = membership1.Setup() + defer membership1.Stop() + assert.NoError(t, err) + + membership2 := New(conf2, inboundMsgCh2) + err = membership2.Setup() + defer membership2.Stop() + assert.NoError(t, err) + + membership3 := New(conf3, inboundMsgCh3) + err = membership3.Setup() + defer membership3.Stop() + assert.NoError(t, err) + + time.Sleep(3 * time.Second) + + membership1.SendToOthers([]byte("test message")) + + select { + case msg := <-inboundMsgCh2: + assert.Equal(t, []byte("test message"), msg) + case <-time.After(5 * time.Second): + t.Fatal("Did not receive the message in membership2") + } + + select { + case msg := <-inboundMsgCh3: + assert.Equal(t, []byte("test message"), msg) + case <-time.After(5 * time.Second): + t.Fatal("Did not receive the message in membership3") + } +} diff --git a/cluster/utils/utils.go b/cluster/utils/utils.go index e5b8b6d..0c90c4c 100644 --- a/cluster/utils/utils.go +++ b/cluster/utils/utils.go @@ -6,14 +6,15 @@ package utils import ( "fmt" - "github.com/hashicorp/go-sockaddr" "net" "os" "reflect" "strings" "testing" - "github.com/satori/go.uuid" + "github.com/hashicorp/go-sockaddr" + + uuid "github.com/satori/go.uuid" ) func InArray(val interface{}, array interface{}) bool { @@ -151,3 +152,12 @@ func PathExists(path string) bool { } return true } + +func GetFreePort() (int, error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, err + } + defer listener.Close() + return listener.Addr().(*net.TCPAddr).Port, nil +}