From 2816dee617e13face97422f295a33cc4e5b93ea3 Mon Sep 17 00:00:00 2001 From: Harmen Date: Wed, 6 Mar 2019 16:14:54 +0100 Subject: [PATCH 01/13] go mod init --- go.mod | 7 +++++++ go.sum | 10 ++++++++++ 2 files changed, 17 insertions(+) create mode 100644 go.mod create mode 100644 go.sum diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..f15be7b4 --- /dev/null +++ b/go.mod @@ -0,0 +1,7 @@ +module github.com/alicebob/miniredis + +require ( + github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 + github.com/gomodule/redigo v2.0.0+incompatible + github.com/yuin/gopher-lua v0.0.0-20190206043414-8bfc7677f583 +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..3bc2524c --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6 h1:45bxf7AZMwWcqkLzDAQugVEwedisr5nRJ1r+7LYnv0U= +github.com/alicebob/gopher-json v0.0.0-20180125190556-5a6b3ba71ee6/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/gomodule/redigo v2.0.0+incompatible h1:K/R+8tc58AaqLkqG2Ol3Qk+DR/TlNuhuh457pBFPtt0= +github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= +github.com/yuin/gopher-lua v0.0.0-20190206043414-8bfc7677f583 h1:SZPG5w7Qxq7bMcMVl6e3Ht2X7f+AAGQdzjkbyOnNNZ8= +github.com/yuin/gopher-lua v0.0.0-20190206043414-8bfc7677f583/go.mod h1:gqRgreBUhTSL0GeU64rtZ3Uq3wtjOa/TB2YfrtkCbVQ= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= From 5f79d7899372d406c5fc23f80ba83e6fe75df466 Mon Sep 17 00:00:00 2001 From: Harmen Date: Wed, 6 Mar 2019 16:20:16 +0100 Subject: [PATCH 02/13] fix inttest to search for redis-server in the PATH --- integration/ephemeral.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integration/ephemeral.go b/integration/ephemeral.go index 1b5477b4..3dcb7e98 100644 --- a/integration/ephemeral.go +++ b/integration/ephemeral.go @@ -37,7 +37,7 @@ func runRedis(extraConfig string) (*ephemeral, string) { port := arbitraryPort() // we prefer the executable from ./redis_src, if any. See ./get_redis.sh - os.Setenv("PATH", fmt.Sprintf("%s:PATH", localSrc)) + os.Setenv("PATH", fmt.Sprintf("%s:%s", localSrc, os.Getenv("PATH"))) c := exec.Command(executable, "-") stdin, err := c.StdinPipe() From 95d4352583a8de91cde78f1bb0c4c16128a66c92 Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Mon, 3 Sep 2018 13:34:59 +0200 Subject: [PATCH 03/13] PUBSUB --- cmd_pubsub.go | 501 ++++++++++++++++++++++ cmd_pubsub_test.go | 1024 ++++++++++++++++++++++++++++++++++++++++++++ direct.go | 446 +++++++++++++++++++ miniredis.go | 125 ++++-- redis.go | 4 + server/server.go | 138 +++++- test_test.go | 21 + 7 files changed, 2205 insertions(+), 54 deletions(-) create mode 100644 cmd_pubsub.go create mode 100644 cmd_pubsub_test.go diff --git a/cmd_pubsub.go b/cmd_pubsub.go new file mode 100644 index 00000000..53a1b2d6 --- /dev/null +++ b/cmd_pubsub.go @@ -0,0 +1,501 @@ +// Commands from https://redis.io/commands#pubsub + +package miniredis + +import ( + "github.com/alicebob/miniredis/server" + "regexp" + "strings" +) + +// commandsPubsub handles all PUB/SUB operations. +func commandsPubsub(m *Miniredis) { + m.srv.Register("SUBSCRIBE", m.cmdSubscribe) + m.srv.Register("UNSUBSCRIBE", m.cmdUnsubscribe) + m.srv.Register("PSUBSCRIBE", m.cmdPSubscribe) + m.srv.Register("PUNSUBSCRIBE", m.cmdPUnsubscribe) + m.srv.Register("PUBLISH", m.cmdPublish) + m.srv.Register("PUBSUB", m.cmdPubSub) +} + +// SUBSCRIBE +func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + subscriptionsAmounts := make([]int, len(args)) + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + var cache peerCache + var hasCache bool + + if cache, hasCache = m.peers[c]; !hasCache { + cache = peerCache{subscriptions: map[int]peerSubscriptions{}} + m.peers[c] = cache + } + + var dbSubs peerSubscriptions + var hasDbSubs bool + + if dbSubs, hasDbSubs = cache.subscriptions[ctx.selectedDB]; !hasDbSubs { + dbSubs = peerSubscriptions{channels: map[string]struct{}{}, patterns: map[string]struct{}{}} + cache.subscriptions[ctx.selectedDB] = dbSubs + } + + subscribedChannels := m.db(ctx.selectedDB).subscribedChannels + + for i, channel := range args { + var peers map[*server.Peer]struct{} + var hasPeers bool + + if peers, hasPeers = subscribedChannels[channel]; !hasPeers { + peers = map[*server.Peer]struct{}{} + subscribedChannels[channel] = peers + } + + peers[c] = struct{}{} + + dbSubs.channels[channel] = struct{}{} + + subscriptionsAmounts[i] = m.getSubscriptionsAmount(c, ctx) + } + + for i, channel := range args { + c.WriteLen(3) + c.WriteBulk("subscribe") + c.WriteBulk(channel) + c.WriteInt(subscriptionsAmounts[i]) + } + }) +} + +func (m *Miniredis) getSubscriptionsAmount(c *server.Peer, ctx *connCtx) (total int) { + if cache, hasCache := m.peers[c]; hasCache { + if dbSubs, hasDbSubs := cache.subscriptions[ctx.selectedDB]; hasDbSubs { + total = len(dbSubs.channels) + len(dbSubs.patterns) + } + } + + return +} + +// UNSUBSCRIBE +func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) { + if !m.handleAuth(c) { + return + } + + var channels []string = nil + var subscriptionsAmounts []int = nil + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + if cache, hasCache := m.peers[c]; hasCache { + if dbSubs, hasDbSubs := cache.subscriptions[ctx.selectedDB]; hasDbSubs { + subscribedChannels := m.db(ctx.selectedDB).subscribedChannels + + if len(args) > 0 { + channels = args + } else { + channels = make([]string, len(dbSubs.channels)) + i := 0 + + for channel := range dbSubs.channels { + channels[i] = channel + i++ + } + } + + subscriptionsAmounts = make([]int, len(channels)) + + for i, channel := range channels { + if peers, hasPeers := subscribedChannels[channel]; hasPeers { + delete(peers, c) + delete(dbSubs.channels, channel) + + if len(peers) < 1 { + delete(subscribedChannels, channel) + } + + if len(dbSubs.channels) < 1 && len(dbSubs.patterns) < 1 { + delete(cache.subscriptions, ctx.selectedDB) + + if len(cache.subscriptions) < 1 { + delete(m.peers, c) + } + } + } + + subscriptionsAmounts[i] = m.getSubscriptionsAmount(c, ctx) + } + } + } + + var subscriptionsAmount int + + if channels == nil { + subscriptionsAmount = m.getSubscriptionsAmount(c, ctx) + } + + if channels == nil { + for _, channel := range args { + c.WriteLen(3) + c.WriteBulk("unsubscribe") + c.WriteBulk(channel) + c.WriteInt(subscriptionsAmount) + } + } else { + for i, channel := range channels { + c.WriteLen(3) + c.WriteBulk("unsubscribe") + c.WriteBulk(channel) + c.WriteInt(subscriptionsAmounts[i]) + } + } + }) +} + +// PSUBSCRIBE +func (m *Miniredis) cmdPSubscribe(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + subscriptionsAmounts := make([]int, len(args)) + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + var cache peerCache + var hasCache bool + + if cache, hasCache = m.peers[c]; !hasCache { + cache = peerCache{subscriptions: map[int]peerSubscriptions{}} + m.peers[c] = cache + } + + var dbSubs peerSubscriptions + var hasDbSubs bool + + if dbSubs, hasDbSubs = cache.subscriptions[ctx.selectedDB]; !hasDbSubs { + dbSubs = peerSubscriptions{channels: map[string]struct{}{}, patterns: map[string]struct{}{}} + cache.subscriptions[ctx.selectedDB] = dbSubs + } + + subscribedPatterns := m.db(ctx.selectedDB).subscribedPatterns + + for i, pattern := range args { + var peers map[*server.Peer]struct{} + var hasPeers bool + + if peers, hasPeers = subscribedPatterns[pattern]; !hasPeers { + peers = map[*server.Peer]struct{}{} + subscribedPatterns[pattern] = peers + } + + peers[c] = struct{}{} + + dbSubs.patterns[pattern] = struct{}{} + + if _, hasRgx := m.channelPatterns[pattern]; !hasRgx { + m.channelPatterns[pattern] = compileChannelPattern(pattern) + } + + subscriptionsAmounts[i] = m.getSubscriptionsAmount(c, ctx) + } + + for i, pattern := range args { + c.WriteLen(3) + c.WriteBulk("psubscribe") + c.WriteBulk(pattern) + c.WriteInt(subscriptionsAmounts[i]) + } + }) +} + +func compileChannelPattern(pattern string) *regexp.Regexp { + const readingLiteral uint8 = 0 + const afterEscape uint8 = 1 + const inClass uint8 = 2 + + rgx := []rune{'\\', 'A'} + state := readingLiteral + literals := []rune{} + klass := map[rune]struct{}{} + + for _, c := range pattern { + switch state { + case readingLiteral: + switch c { + case '\\': + state = afterEscape + case '?': + rgx = append(rgx, append([]rune(regexp.QuoteMeta(string(literals))), '.')...) + literals = []rune{} + case '*': + rgx = append(rgx, append([]rune(regexp.QuoteMeta(string(literals))), '.', '*')...) + literals = []rune{} + case '[': + rgx = append(rgx, []rune(regexp.QuoteMeta(string(literals)))...) + literals = []rune{} + state = inClass + default: + literals = append(literals, c) + } + case afterEscape: + literals = append(literals, c) + state = readingLiteral + case inClass: + if c == ']' { + expr := []rune{'['} + + if _, hasDash := klass['-']; hasDash { + delete(klass, '-') + expr = append(expr, '-') + } + + flatClass := make([]rune, len(klass)) + i := 0 + + for c := range klass { + flatClass[i] = c + i++ + } + + klass = map[rune]struct{}{} + expr = append(append(expr, []rune(regexp.QuoteMeta(string(flatClass)))...), ']') + + if len(expr) < 3 { + rgx = append(rgx, 'x', '\\', 'b', 'y') + } else { + rgx = append(rgx, expr...) + } + + state = readingLiteral + } else { + klass[c] = struct{}{} + } + } + } + + switch state { + case afterEscape: + rgx = append(rgx, '\\', '\\') + case inClass: + if len(klass) < 0 { + rgx = append(rgx, '\\', '[') + } else { + expr := []rune{'['} + + if _, hasDash := klass['-']; hasDash { + delete(klass, '-') + expr = append(expr, '-') + } + + flatClass := make([]rune, len(klass)) + i := 0 + + for c := range klass { + flatClass[i] = c + i++ + } + + expr = append(append(expr, []rune(regexp.QuoteMeta(string(flatClass)))...), ']') + + if len(expr) < 3 { + rgx = append(rgx, 'x', '\\', 'b', 'y') + } else { + rgx = append(rgx, expr...) + } + } + } + + return regexp.MustCompile(string(append(rgx, '\\', 'z'))) +} + +// PUNSUBSCRIBE +func (m *Miniredis) cmdPUnsubscribe(c *server.Peer, cmd string, args []string) { + if !m.handleAuth(c) { + return + } + + var patterns []string = nil + var subscriptionsAmounts []int = nil + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + if cache, hasCache := m.peers[c]; hasCache { + if dbSubs, hasDbSubs := cache.subscriptions[ctx.selectedDB]; hasDbSubs { + subscribedPatterns := m.db(ctx.selectedDB).subscribedPatterns + + if len(args) > 0 { + patterns = args + } else { + patterns = make([]string, len(dbSubs.patterns)) + i := 0 + + for pattern := range dbSubs.patterns { + patterns[i] = pattern + i++ + } + } + + subscriptionsAmounts = make([]int, len(patterns)) + + for i, pattern := range patterns { + if peers, hasPeers := subscribedPatterns[pattern]; hasPeers { + delete(peers, c) + delete(dbSubs.patterns, pattern) + + if len(peers) < 1 { + delete(subscribedPatterns, pattern) + } + + if len(dbSubs.patterns) < 1 && len(dbSubs.channels) < 1 { + delete(cache.subscriptions, ctx.selectedDB) + + if len(cache.subscriptions) < 1 { + delete(m.peers, c) + } + } + } + + subscriptionsAmounts[i] = m.getSubscriptionsAmount(c, ctx) + } + } + } + + var subscriptionsAmount int + + if patterns == nil { + subscriptionsAmount = m.getSubscriptionsAmount(c, ctx) + } + + if patterns == nil { + for _, pattern := range args { + c.WriteLen(3) + c.WriteBulk("punsubscribe") + c.WriteBulk(pattern) + c.WriteInt(subscriptionsAmount) + } + } else { + for i, pattern := range patterns { + c.WriteLen(3) + c.WriteBulk("punsubscribe") + c.WriteBulk(pattern) + c.WriteInt(subscriptionsAmounts[i]) + } + } + }) +} + +type queuedPubSubMessage struct { + channel, message string +} + +func (m *queuedPubSubMessage) Write(c *server.Peer) { + c.WriteLen(3) + c.WriteBulk("message") + c.WriteBulk(m.channel) + c.WriteBulk(m.message) +} + +// PUBLISH +func (m *Miniredis) cmdPublish(c *server.Peer, cmd string, args []string) { + if len(args) != 2 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + if !m.handleAuth(c) { + return + } + + channel := args[0] + message := args[1] + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + c.WriteInt(m.db(ctx.selectedDB).publishMessage(channel, message)) + }) +} + +// PUBSUB +func (m *Miniredis) cmdPubSub(c *server.Peer, cmd string, args []string) { + if len(args) < 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + + subcommand := strings.ToUpper(args[0]) + subargs := args[1:] + var argsOk bool + + switch subcommand { + case "CHANNELS": + argsOk = len(subargs) < 2 + case "NUMSUB": + argsOk = true + case "NUMPAT": + argsOk = len(subargs) == 0 + default: + argsOk = false + } + + if !argsOk { + setDirty(c) + c.WriteError(errInvalidPubsubArgs(subcommand)) + return + } + + if !m.handleAuth(c) { + return + } + + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + switch subcommand { + case "CHANNELS": + var channels map[string]struct{} + + if len(subargs) == 1 { + pattern := subargs[0] + + var rgx *regexp.Regexp + var hasRgx bool + + if rgx, hasRgx = m.channelPatterns[pattern]; !hasRgx { + rgx = compileChannelPattern(pattern) + m.channelPatterns[pattern] = rgx + } + + channels = m.db(ctx.selectedDB).pubSubChannelsNoLock(rgx) + } else { + channels = m.db(ctx.selectedDB).pubSubChannelsNoLock(nil) + } + + c.WriteLen(len(channels)) + + for channel := range channels { + c.WriteBulk(channel) + } + case "NUMSUB": + numSub := m.db(ctx.selectedDB).pubSubNumSubNoLock(subargs...) + + c.WriteLen(len(numSub) * 2) + + for channel, subs := range numSub { + c.WriteBulk(channel) + c.WriteInt(subs) + } + case "NUMPAT": + c.WriteInt(m.db(ctx.selectedDB).pubSubNumPatNoLock()) + } + }) +} diff --git a/cmd_pubsub_test.go b/cmd_pubsub_test.go new file mode 100644 index 00000000..5141d48a --- /dev/null +++ b/cmd_pubsub_test.go @@ -0,0 +1,1024 @@ +package miniredis + +import ( + "github.com/gomodule/redigo/redis" + "regexp" + "testing" +) + +func TestSubscribe(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + a, err := redis.Values(c.Do("SUBSCRIBE", "event1")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event1"), int64(1)}, a) + + equals(t, 1, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("SUBSCRIBE", "event2")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event2"), int64(2)}, a) + + equals(t, 2, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("SUBSCRIBE", "event3", "event4")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event3"), int64(3)}, a) + + equals(t, 4, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event4"), int64(4)}, a) + + equals(t, 4, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 1, len(s.peers)) + } + + { + sub := s.NewSubscriber() + defer sub.Close() + + equals(t, map[string]struct{}{}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) + + sub.Subscribe() + equals(t, map[string]struct{}{}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) + + sub.Subscribe("event1") + equals(t, map[string]struct{}{"event1": {}}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{"event1": {sub: {}}}, sub.db.directlySubscribedChannels) + + sub.Subscribe("event2") + equals(t, map[string]struct{}{"event1": {}, "event2": {}}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{"event1": {sub: {}}, "event2": {sub: {}}}, sub.db.directlySubscribedChannels) + + sub.Subscribe("event3", "event4") + equals(t, map[string]struct{}{"event1": {}, "event2": {}, "event3": {}, "event4": {}}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{"event1": {sub: {}}, "event2": {sub: {}}, "event3": {sub: {}}, "event4": {sub: {}}}, sub.db.directlySubscribedChannels) + } +} + +func TestUnsubscribe(t *testing.T) { + s, c, done := setup(t) + defer done() + + c.Do("SUBSCRIBE", "event1", "event2", "event3") + c.Receive() + c.Receive() + + { + a, err := redis.Values(c.Do("UNSUBSCRIBE", "event1", "event2")) + ok(t, err) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event1"), int64(2)}, a) + + equals(t, 1, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event2"), int64(1)}, a) + + equals(t, 1, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("UNSUBSCRIBE", "event3")) + ok(t, err) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event3"), int64(0)}, a) + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 0, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("UNSUBSCRIBE", "event4")) + ok(t, err) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event4"), int64(0)}, a) + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 0, len(s.peers)) + } + + { + sub := s.NewSubscriber() + defer sub.Close() + + sub.Subscribe("event1", "event2", "event3") + + sub.Unsubscribe() + equals(t, map[string]struct{}{"event1": {}, "event2": {}, "event3": {}}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{"event1": {sub: {}}, "event2": {sub: {}}, "event3": {sub: {}}}, sub.db.directlySubscribedChannels) + + sub.Unsubscribe("event1", "event2") + equals(t, map[string]struct{}{"event3": {}}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{"event3": {sub: {}}}, sub.db.directlySubscribedChannels) + + sub.Unsubscribe("event3") + equals(t, map[string]struct{}{}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) + + sub.Unsubscribe("event4") + equals(t, map[string]struct{}{}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) + } +} + +func TestUnsubscribeAll(t *testing.T) { + s, c, done := setup(t) + defer done() + + c.Do("SUBSCRIBE", "event1", "event2", "event3") + c.Receive() + c.Receive() + + channels := map[string]struct{}{"event1": {}, "event2": {}, "event3": {}} + + { + a, err := redis.Values(c.Do("UNSUBSCRIBE")) + ok(t, err) + + if oneOf(t, mkSubReplySet([]byte("unsubscribe"), channels, 2), a) { + delete(channels, string(a[1].([]byte))) + } + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 0, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + + if oneOf(t, mkSubReplySet([]byte("unsubscribe"), channels, 1), a) { + delete(channels, string(a[1].([]byte))) + } + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 0, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + + if oneOf(t, mkSubReplySet([]byte("unsubscribe"), channels, 0), a) { + delete(channels, string(a[1].([]byte))) + } + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) + equals(t, 0, len(s.peers)) + } + + { + sub := s.NewSubscriber() + defer sub.Close() + + sub.Subscribe("event1", "event2", "event3") + + sub.UnsubscribeAll() + equals(t, map[string]struct{}{}, sub.channels) + equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) + } +} + +func TestPSubscribe(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + a, err := redis.Values(c.Do("PSUBSCRIBE", "event1")) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event1"), int64(1)}, a) + + equals(t, 1, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("PSUBSCRIBE", "event2?")) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event2?"), int64(2)}, a) + + equals(t, 2, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("PSUBSCRIBE", "event3*", "event4[abc]")) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event3*"), int64(3)}, a) + + equals(t, 4, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event4[abc]"), int64(4)}, a) + + equals(t, 4, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("PSUBSCRIBE", "event5[]")) + ok(t, err) + equals(t, []interface{}{[]byte("psubscribe"), []byte("event5[]"), int64(5)}, a) + + equals(t, 5, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + sub := s.NewSubscriber() + defer sub.Close() + + rgxs := [5]*regexp.Regexp{ + regexp.MustCompile(`\Aevent1\z`), + regexp.MustCompile(`\Aevent2.\z`), + regexp.MustCompile(`\Aevent3`), + regexp.MustCompile(`\Aevent4[abc]\z`), + regexp.MustCompile(`\Aevent5X\bY\z`), + } + + equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) + + sub.PSubscribe() + equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) + + sub.PSubscribe(rgxs[0]) + equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}}, sub.db.directlySubscribedPatterns) + + sub.PSubscribe(rgxs[1]) + equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}, rgxs[1]: {}}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}, rgxs[1]: {sub: {}}}, sub.db.directlySubscribedPatterns) + + sub.PSubscribe(rgxs[2], rgxs[3]) + equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}, rgxs[1]: {}, rgxs[2]: {}, rgxs[3]: {}}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}, rgxs[1]: {sub: {}}, rgxs[2]: {sub: {}}, rgxs[3]: {sub: {}}}, sub.db.directlySubscribedPatterns) + + sub.PSubscribe(rgxs[4]) + equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}, rgxs[1]: {}, rgxs[2]: {}, rgxs[3]: {}, rgxs[4]: {}}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}, rgxs[1]: {sub: {}}, rgxs[2]: {sub: {}}, rgxs[3]: {sub: {}}, rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) + } +} + +func TestPUnsubscribe(t *testing.T) { + s, c, done := setup(t) + defer done() + + c.Do("PSUBSCRIBE", "event1", "event2?", "event3*", "event4[abc]", "event5[]") + c.Receive() + c.Receive() + c.Receive() + c.Receive() + + { + a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event1", "event2?")) + ok(t, err) + equals(t, []interface{}{[]byte("punsubscribe"), []byte("event1"), int64(4)}, a) + + equals(t, 3, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, []interface{}{[]byte("punsubscribe"), []byte("event2?"), int64(3)}, a) + + equals(t, 3, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event3*")) + ok(t, err) + equals(t, []interface{}{[]byte("punsubscribe"), []byte("event3*"), int64(2)}, a) + + equals(t, 2, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event4[abc]")) + ok(t, err) + equals(t, []interface{}{[]byte("punsubscribe"), []byte("event4[abc]"), int64(1)}, a) + + equals(t, 1, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 1, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event5[]")) + ok(t, err) + equals(t, []interface{}{[]byte("punsubscribe"), []byte("event5[]"), int64(0)}, a) + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 0, len(s.peers)) + } + + { + a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event6")) + ok(t, err) + equals(t, []interface{}{[]byte("punsubscribe"), []byte("event6"), int64(0)}, a) + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 0, len(s.peers)) + } + + { + sub := s.NewSubscriber() + defer sub.Close() + + rgxs := [5]*regexp.Regexp{ + regexp.MustCompile(`\Aevent1\z`), + regexp.MustCompile(`\Aevent2.\z`), + regexp.MustCompile(`\Aevent3`), + regexp.MustCompile(`\Aevent4[abc]\z`), + regexp.MustCompile(`\Aevent5X\bY\z`), + } + + sub.PSubscribe(rgxs[:]...) + + sub.PUnsubscribe() + equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}, rgxs[1]: {}, rgxs[2]: {}, rgxs[3]: {}, rgxs[4]: {}}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}, rgxs[1]: {sub: {}}, rgxs[2]: {sub: {}}, rgxs[3]: {sub: {}}, rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) + + sub.PUnsubscribe(rgxs[0], rgxs[1]) + equals(t, map[*regexp.Regexp]struct{}{rgxs[2]: {}, rgxs[3]: {}, rgxs[4]: {}}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[2]: {sub: {}}, rgxs[3]: {sub: {}}, rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) + + sub.PUnsubscribe(rgxs[2]) + equals(t, map[*regexp.Regexp]struct{}{rgxs[3]: {}, rgxs[4]: {}}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[3]: {sub: {}}, rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) + + sub.PUnsubscribe(rgxs[3]) + equals(t, map[*regexp.Regexp]struct{}{rgxs[4]: {}}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) + + sub.PUnsubscribe(rgxs[4]) + equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) + + sub.PUnsubscribe(regexp.MustCompile(`\Aevent6\z`)) + equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) + } +} + +func TestPUnsubscribeAll(t *testing.T) { + s, c, done := setup(t) + defer done() + + c.Do("PSUBSCRIBE", "event1", "event2?", "event3*", "event4[abc]", "event5[]") + c.Receive() + c.Receive() + c.Receive() + c.Receive() + + patterns := map[string]struct{}{"event1": {}, "event2?": {}, "event3*": {}, "event4[abc]": {}, "event5[]": {}} + + { + a, err := redis.Values(c.Do("PUNSUBSCRIBE")) + ok(t, err) + + if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 4), a) { + delete(patterns, string(a[1].([]byte))) + } + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 0, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + + if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 3), a) { + delete(patterns, string(a[1].([]byte))) + } + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 0, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + + if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 2), a) { + delete(patterns, string(a[1].([]byte))) + } + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 0, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + + if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 1), a) { + delete(patterns, string(a[1].([]byte))) + } + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 0, len(s.peers)) + } + + { + a, err := redis.Values(c.Receive()) + ok(t, err) + + if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 0), a) { + delete(patterns, string(a[1].([]byte))) + } + + equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) + equals(t, 0, len(s.peers)) + } + + { + sub := s.NewSubscriber() + defer sub.Close() + + sub.PSubscribe( + regexp.MustCompile(`\Aevent1\z`), + regexp.MustCompile(`\Aevent2.\z`), + regexp.MustCompile(`\Aevent3`), + regexp.MustCompile(`\Aevent4[abc]\z`), + regexp.MustCompile(`\Aevent5X\bY\z`), + ) + + sub.PUnsubscribeAll() + equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) + equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) + } +} + +func mkSubReplySet(subject []byte, channels map[string]struct{}, subs int64) []interface{} { + result := make([]interface{}, len(channels)) + i := 0 + + for channel := range channels { + result[i] = []interface{}{subject, []byte(channel), subs} + i++ + } + + return result +} + +func TestPublish(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + a, err := redis.Int(c.Do("PUBLISH", "event1", "message2")) + ok(t, err) + equals(t, 0, a) + } + + equals(t, 0, s.Publish("event1", "message2")) +} + +func TestPubSubChannels(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + a, err := redis.Values(c.Do("PUBSUB", "CHANNELS")) + ok(t, err) + equals(t, []interface{}{}, a) + } + + { + a, err := redis.Values(c.Do("PUBSUB", "CHANNELS", "event1?*[abc]")) + ok(t, err) + equals(t, []interface{}{}, a) + } + + equals(t, map[string]struct{}{}, s.PubSubChannels(nil)) + equals(t, map[string]struct{}{}, s.PubSubChannels(regexp.MustCompile(`\Aevent1..*[abc]\z`))) +} + +func TestPubSubNumSub(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + a, err := redis.Values(c.Do("PUBSUB", "NUMSUB")) + ok(t, err) + equals(t, []interface{}{}, a) + } + + { + a, err := redis.Values(c.Do("PUBSUB", "NUMSUB", "event1")) + ok(t, err) + equals(t, []interface{}{[]byte("event1"), int64(0)}, a) + } + + { + a, err := redis.Values(c.Do("PUBSUB", "NUMSUB", "event1", "event2")) + ok(t, err) + oneOf(t, []interface{}{ + []interface{}{[]byte("event1"), int64(0), []byte("event2"), int64(0)}, + []interface{}{[]byte("event2"), int64(0), []byte("event1"), int64(0)}, + }, a) + } + + equals(t, map[string]int{"event1": 0}, s.PubSubNumSub("event1")) +} + +func TestPubSubNumPat(t *testing.T) { + s, c, done := setup(t) + defer done() + + { + a, err := redis.Int(c.Do("PUBSUB", "NUMPAT")) + ok(t, err) + equals(t, 0, a) + } + + equals(t, 0, s.PubSubNumPat()) +} + +func TestPubSubBadArgs(t *testing.T) { + for _, command := range [9]struct { + command string + args []interface{} + err string + }{ + {"SUBSCRIBE", []interface{}{}, "ERR wrong number of arguments for 'subscribe' command"}, + {"PSUBSCRIBE", []interface{}{}, "ERR wrong number of arguments for 'psubscribe' command"}, + {"PUBLISH", []interface{}{}, "ERR wrong number of arguments for 'publish' command"}, + {"PUBLISH", []interface{}{"event1"}, "ERR wrong number of arguments for 'publish' command"}, + {"PUBLISH", []interface{}{"event1", "message2", "message3"}, "ERR wrong number of arguments for 'publish' command"}, + {"PUBSUB", []interface{}{}, "ERR wrong number of arguments for 'pubsub' command"}, + {"PUBSUB", []interface{}{"FOOBAR"}, "ERR Unknown PUBSUB subcommand or wrong number of arguments for 'FOOBAR'"}, + {"PUBSUB", []interface{}{"NUMPAT", "FOOBAR"}, "ERR Unknown PUBSUB subcommand or wrong number of arguments for 'NUMPAT'"}, + {"PUBSUB", []interface{}{"CHANNELS", "FOOBAR1", "FOOBAR2"}, "ERR Unknown PUBSUB subcommand or wrong number of arguments for 'CHANNELS'"}, + } { + _, c, done := setup(t) + + _, err := c.Do(command.command, command.args...) + mustFail(t, err, command.err) + + done() + } +} + +func TestPubSubInteraction(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + + ch := make(chan struct{}, 8) + tasks := [5]func(){} + directTasks := [4]func(){} + + for i, tester := range [5]func(t *testing.T, s *Miniredis, c redis.Conn, chCtl chan struct{}){ + testPubSubInteractionSub1, + testPubSubInteractionSub2, + testPubSubInteractionPsub1, + testPubSubInteractionPsub2, + testPubSubInteractionPub, + } { + tasks[i] = runActualRedisClientForPubSub(t, s, ch, tester) + } + + for i, tester := range [4]func(t *testing.T, s *Miniredis, chCtl chan struct{}){ + testPubSubInteractionDirectSub1, + testPubSubInteractionDirectSub2, + testPubSubInteractionDirectPsub1, + testPubSubInteractionDirectPsub2, + } { + directTasks[i] = runDirectRedisClientForPubSub(t, s, ch, tester) + } + + for _, task := range tasks { + task() + } + + for _, task := range directTasks { + task() + } +} + +func testPubSubInteractionSub1(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { + assertCorrectSubscriptionsCounts( + t, + []int64{1, 2, 3, 4}, + runCmdDuringPubSub(t, c, 3, "SUBSCRIBE", "event1", "event2", "event3", "event4"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '1', '2', '3', '4') + + assertCorrectSubscriptionsCounts( + t, + []int64{3, 2}, + runCmdDuringPubSub(t, c, 1, "UNSUBSCRIBE", "event2", "event3"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '1', '4') +} + +func testPubSubInteractionSub2(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { + assertCorrectSubscriptionsCounts( + t, + []int64{1, 2, 3, 4}, + runCmdDuringPubSub(t, c, 3, "SUBSCRIBE", "event3", "event4", "event5", "event6"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '3', '4', '5', '6') + + assertCorrectSubscriptionsCounts( + t, + []int64{3, 2}, + runCmdDuringPubSub(t, c, 1, "UNSUBSCRIBE", "event4", "event5"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '3', '6') +} + +func testPubSubInteractionDirectSub1(t *testing.T, s *Miniredis, ch chan struct{}) { + sub := s.NewSubscriber() + defer sub.Close() + + sub.Subscribe("event1", "event3", "event4", "event6") + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '1', '3', '4', '6') + + sub.Unsubscribe("event1", "event4") + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '3', '6') +} + +func testPubSubInteractionDirectSub2(t *testing.T, s *Miniredis, ch chan struct{}) { + sub := s.NewSubscriber() + defer sub.Close() + + sub.Subscribe("event2", "event3", "event4", "event5") + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '2', '3', '4', '5') + + sub.Unsubscribe("event3", "event5") + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '2', '4') +} + +func testPubSubInteractionPsub1(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { + assertCorrectSubscriptionsCounts( + t, + []int64{1, 2, 3, 4}, + runCmdDuringPubSub(t, c, 3, "PSUBSCRIBE", "event[ab1]", "event[cd]", "event[ef3]", "event[gh]"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '1', '3', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h') + + assertCorrectSubscriptionsCounts( + t, + []int64{3, 2}, + runCmdDuringPubSub(t, c, 1, "PUNSUBSCRIBE", "event[cd]", "event[ef3]"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '1', 'a', 'b', 'g', 'h') +} + +func testPubSubInteractionPsub2(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { + assertCorrectSubscriptionsCounts( + t, + []int64{1, 2, 3, 4}, + runCmdDuringPubSub(t, c, 3, "PSUBSCRIBE", "event[ef]", "event[gh4]", "event[ij]", "event[kl6]"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '4', '6', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l') + + assertCorrectSubscriptionsCounts( + t, + []int64{3, 2}, + runCmdDuringPubSub(t, c, 1, "PUNSUBSCRIBE", "event[gh4]", "event[ij]"), + ) + + ch <- struct{}{} + receiveMessagesDuringPubSub(t, c, '6', 'e', 'f', 'k', 'l') +} + +func testPubSubInteractionDirectPsub1(t *testing.T, s *Miniredis, ch chan struct{}) { + rgx := regexp.MustCompile + sub := s.NewSubscriber() + defer sub.Close() + + sub.PSubscribe(rgx(`\Aevent[ab1]\z`), rgx(`\Aevent[ef3]\z`), rgx(`\Aevent[gh]\z`), rgx(`\Aevent[kl6]\z`)) + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '1', '3', '6', 'a', 'b', 'e', 'f', 'g', 'h', 'k', 'l') + + sub.PUnsubscribe(rgx(`\Aevent[ab1]\z`), rgx(`\Aevent[gh]\z`)) + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '3', '6', 'e', 'f', 'k', 'l') +} + +func testPubSubInteractionDirectPsub2(t *testing.T, s *Miniredis, ch chan struct{}) { + rgx := regexp.MustCompile + sub := s.NewSubscriber() + defer sub.Close() + + sub.PSubscribe(rgx(`\Aevent[cd]\z`), rgx(`\Aevent[ef]\z`), rgx(`\Aevent[gh4]\z`), rgx(`\Aevent[ij]\z`)) + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '4', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j') + + sub.PUnsubscribe(rgx(`\Aevent[ef]\z`), rgx(`\Aevent[ij]\z`)) + + ch <- struct{}{} + receiveMessagesDirectlyDuringPubSub(t, sub, '4', 'c', 'd', 'g', 'h') +} + +func testPubSubInteractionPub(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { + testPubSubInteractionPubStage1(t, s, c, ch) + testPubSubInteractionPubStage2(t, s, c, ch) +} + +func testPubSubInteractionPubStage1(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { + for i := uint8(0); i < 8; i++ { + <-ch + } + + for _, pattern := range [2]struct { + pattern string + rgx *regexp.Regexp + }{{"", nil}, {"event?", regexp.MustCompile(`\Aevent.\z`)}} { + assertActiveChannelsDuringPubSub(t, s, c, pattern.pattern, pattern.rgx, map[string]struct{}{ + "event1": {}, "event2": {}, "event3": {}, "event4": {}, "event5": {}, "event6": {}, + }) + } + + assertActiveChannelsDuringPubSub(t, s, c, "*[123]", regexp.MustCompile(`[123]\z`), map[string]struct{}{ + "event1": {}, "event2": {}, "event3": {}, + }) + + assertNumSubDuringPubSub(t, s, c, map[string]int{ + "event1": 2, "event2": 2, "event3": 4, "event4": 4, "event5": 2, "event6": 2, + "event[ab1]": 0, "event[cd]": 0, "event[ef3]": 0, "event[gh]": 0, "event[ij]": 0, "event[kl6]": 0, + }) + + assertNumPatDuringPubSub(t, s, c, 16) + + for _, message := range [18]struct { + channelSuffix rune + subscribers uint8 + }{ + {'1', 4}, {'2', 2}, {'3', 6}, {'4', 6}, {'5', 2}, {'6', 4}, + {'a', 2}, {'b', 2}, {'c', 2}, {'d', 2}, {'e', 4}, {'f', 4}, + {'g', 4}, {'h', 4}, {'i', 2}, {'j', 2}, {'k', 2}, {'l', 2}, + } { + suffix := string([]rune{message.channelSuffix}) + replies := runCmdDuringPubSub(t, c, 0, "PUBLISH", "event"+suffix, "message"+suffix) + equals(t, []interface{}{int64(message.subscribers)}, replies) + } +} + +func testPubSubInteractionPubStage2(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { + for i := uint8(0); i < 8; i++ { + <-ch + } + + for _, pattern := range [2]struct { + pattern string + rgx *regexp.Regexp + }{{"", nil}, {"event?", regexp.MustCompile(`\Aevent.\z`)}} { + assertActiveChannelsDuringPubSub(t, s, c, pattern.pattern, pattern.rgx, map[string]struct{}{ + "event1": {}, "event2": {}, "event3": {}, "event4": {}, "event6": {}, + }) + } + + assertActiveChannelsDuringPubSub(t, s, c, "*[123]", regexp.MustCompile(`[123]\z`), map[string]struct{}{ + "event1": {}, "event2": {}, "event3": {}, + }) + + assertNumSubDuringPubSub(t, s, c, map[string]int{ + "event1": 1, "event2": 1, "event3": 2, "event4": 2, "event5": 0, "event6": 2, + "event[ab1]": 0, "event[cd]": 0, "event[ef3]": 0, "event[gh]": 0, "event[ij]": 0, "event[kl6]": 0, + }) + + assertNumPatDuringPubSub(t, s, c, 8) + + for _, message := range [18]struct { + channelSuffix rune + subscribers uint8 + }{ + {'1', 2}, {'2', 1}, {'3', 3}, {'4', 3}, {'5', 0}, {'6', 4}, + {'a', 1}, {'b', 1}, {'c', 1}, {'d', 1}, {'e', 2}, {'f', 2}, + {'g', 2}, {'h', 2}, {'i', 0}, {'j', 0}, {'k', 2}, {'l', 2}, + } { + suffix := string([]rune{message.channelSuffix}) + equals(t, int(message.subscribers), s.Publish("event"+suffix, "message"+suffix)) + } +} + +func runActualRedisClientForPubSub(t *testing.T, s *Miniredis, chCtl chan struct{}, tester func(t *testing.T, s *Miniredis, c redis.Conn, chCtl chan struct{})) (wait func()) { + t.Helper() + + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + ch := make(chan struct{}) + + go func() { + t.Helper() + + tester(t, s, c, chCtl) + c.Close() + close(ch) + }() + + return func() { <-ch } +} + +func runDirectRedisClientForPubSub(t *testing.T, s *Miniredis, chCtl chan struct{}, tester func(t *testing.T, s *Miniredis, chCtl chan struct{})) (wait func()) { + t.Helper() + + ch := make(chan struct{}) + + go func() { + t.Helper() + + tester(t, s, chCtl) + close(ch) + }() + + return func() { <-ch } +} + +func runCmdDuringPubSub(t *testing.T, c redis.Conn, followUpMessages uint8, command string, args ...interface{}) (replies []interface{}) { + t.Helper() + + replies = make([]interface{}, followUpMessages+1) + + reply, err := c.Do(command, args...) + ok(t, err) + + replies[0] = reply + i := 1 + + for ; followUpMessages > 0; followUpMessages-- { + reply, err := c.Receive() + ok(t, err) + + replies[i] = reply + i++ + } + + return +} + +func assertCorrectSubscriptionsCounts(t *testing.T, subscriptionsCounts []int64, replies []interface{}) { + t.Helper() + + for i, subscriptionsCount := range subscriptionsCounts { + if arrayReply, isArrayReply := replies[i].([]interface{}); isArrayReply && len(arrayReply) > 2 { + equals(t, subscriptionsCount, arrayReply[2]) + } + } +} + +func receiveMessagesDuringPubSub(t *testing.T, c redis.Conn, suffixes ...rune) { + t.Helper() + + for _, suffix := range suffixes { + msg, err := c.Receive() + ok(t, err) + + suff := string([]rune{suffix}) + equals(t, []interface{}{[]byte("message"), []byte("event" + suff), []byte("message" + suff)}, msg) + } +} + +func receiveMessagesDirectlyDuringPubSub(t *testing.T, sub *Subscriber, suffixes ...rune) { + t.Helper() + + for _, suffix := range suffixes { + suff := string([]rune{suffix}) + equals(t, Message{"event" + suff, "message" + suff}, <-sub.Messages) + } +} + +func assertActiveChannelsDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, pattern string, rgx *regexp.Regexp, channels map[string]struct{}) { + t.Helper() + + var args []interface{} + if pattern == "" { + args = []interface{}{"CHANNELS"} + } else { + args = []interface{}{"CHANNELS", pattern} + } + + a, err := redis.Values(c.Do("PUBSUB", args...)) + ok(t, err) + + actualChannels := make(map[string]struct{}, len(a)) + + for _, channel := range a { + if channelString, channelIsString := channel.([]byte); channelIsString { + actualChannels[string(channelString)] = struct{}{} + } + } + + equals(t, channels, actualChannels) + + equals(t, channels, s.PubSubChannels(rgx)) +} + +func assertNumSubDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, channels map[string]int) { + t.Helper() + + args := make([]interface{}, 1+len(channels)) + args[0] = "NUMSUB" + i := 1 + + flatChannels := make([]string, len(channels)) + j := 0 + + for channel := range channels { + args[i] = channel + i++ + + flatChannels[j] = channel + j++ + } + + a, err := redis.Values(c.Do("PUBSUB", args...)) + ok(t, err) + equals(t, len(channels)*2, len(a)) + + actualChannels := make(map[string]int, len(a)) + + var currentChannel string + currentState := uint8(0) + + for _, item := range a { + if currentState&uint8(1) == 0 { + if channelString, channelIsString := item.([]byte); channelIsString { + currentChannel = string(channelString) + currentState |= 2 + } else { + currentState &= ^uint8(2) + } + + currentState |= 1 + } else { + if subsInt, subsIsInt := item.(int64); subsIsInt && currentState&uint8(2) != 0 { + actualChannels[currentChannel] = int(subsInt) + } + + currentState &= ^uint8(1) + } + } + + equals(t, channels, actualChannels) + + equals(t, channels, s.PubSubNumSub(flatChannels...)) +} + +func assertNumPatDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, numPat int) { + t.Helper() + + a, err := redis.Int(c.Do("PUBSUB", "NUMPAT")) + ok(t, err) + equals(t, numPat, a) + + equals(t, numPat, s.PubSubNumPat()) +} diff --git a/direct.go b/direct.go index ca41449f..0670da9d 100644 --- a/direct.go +++ b/direct.go @@ -4,6 +4,10 @@ package miniredis import ( "errors" + "github.com/alicebob/miniredis/server" + "regexp" + "sync" + "sync/atomic" "time" ) @@ -547,3 +551,445 @@ func (db *RedisDB) ZScore(k, member string) (float64, error) { } return db.ssetScore(k, member), nil } + +type Message struct { + Channel, Message string +} + +type messageQueue struct { + sync.Mutex + messages []Message + hasNewMessages chan struct{} +} + +func (q *messageQueue) Enqueue(message Message) { + q.Lock() + defer q.Unlock() + + q.messages = append(q.messages, message) + + select { + case q.hasNewMessages <- struct{}{}: + break + default: + break + } +} + +type Subscriber struct { + Messages chan Message + close chan struct{} + db *RedisDB + channels map[string]struct{} + patterns map[*regexp.Regexp]struct{} + queue messageQueue +} + +func (s *Subscriber) Close() error { + close(s.close) + + s.db.master.Lock() + defer s.db.master.Unlock() + + for channel := range s.channels { + subscribers := s.db.directlySubscribedChannels[channel] + delete(subscribers, s) + + if len(subscribers) < 1 { + delete(s.db.directlySubscribedChannels, channel) + } + } + + for pattern := range s.patterns { + subscribers := s.db.directlySubscribedPatterns[pattern] + delete(subscribers, s) + + if len(subscribers) < 1 { + delete(s.db.directlySubscribedPatterns, pattern) + } + } + + return nil +} + +func (s *Subscriber) Subscribe(channels ...string) { + s.db.master.Lock() + defer s.db.master.Unlock() + + for _, channel := range channels { + s.channels[channel] = struct{}{} + + var peers map[*Subscriber]struct{} + var hasPeers bool + + if peers, hasPeers = s.db.directlySubscribedChannels[channel]; !hasPeers { + peers = map[*Subscriber]struct{}{} + s.db.directlySubscribedChannels[channel] = peers + } + + peers[s] = struct{}{} + } +} + +func (s *Subscriber) Unsubscribe(channels ...string) { + s.db.master.Lock() + defer s.db.master.Unlock() + + for _, channel := range channels { + if _, hasChannel := s.channels[channel]; hasChannel { + delete(s.channels, channel) + + peers := s.db.directlySubscribedChannels[channel] + delete(peers, s) + + if len(peers) < 1 { + delete(s.db.directlySubscribedChannels, channel) + } + } + } +} + +func (s *Subscriber) UnsubscribeAll() { + s.db.master.Lock() + defer s.db.master.Unlock() + + for channel := range s.channels { + subscribers := s.db.directlySubscribedChannels[channel] + delete(subscribers, s) + + if len(subscribers) < 1 { + delete(s.db.directlySubscribedChannels, channel) + } + } + + s.channels = map[string]struct{}{} +} + +func (s *Subscriber) PSubscribe(patterns ...*regexp.Regexp) { + s.db.master.Lock() + defer s.db.master.Unlock() + + decompiledDSPs := s.db.master.decompiledDirectlySubscribedPatterns + + for _, pattern := range patterns { + decompiled := pattern.String() + + if decompiledDSP, hasDDSP := decompiledDSPs[decompiled]; hasDDSP { + pattern = decompiledDSP + } else { + decompiledDSPs[decompiled] = pattern + } + + s.patterns[pattern] = struct{}{} + + var peers map[*Subscriber]struct{} + var hasPeers bool + + if peers, hasPeers = s.db.directlySubscribedPatterns[pattern]; !hasPeers { + peers = map[*Subscriber]struct{}{} + s.db.directlySubscribedPatterns[pattern] = peers + } + + peers[s] = struct{}{} + } +} + +func (s *Subscriber) PUnsubscribe(patterns ...*regexp.Regexp) { + s.db.master.Lock() + defer s.db.master.Unlock() + + decompiledDSPs := s.db.master.decompiledDirectlySubscribedPatterns + + for _, pattern := range patterns { + if decompiledDSP, hasDDSP := decompiledDSPs[pattern.String()]; hasDDSP { + pattern = decompiledDSP + } + + if _, hasChannel := s.patterns[pattern]; hasChannel { + delete(s.patterns, pattern) + + peers := s.db.directlySubscribedPatterns[pattern] + delete(peers, s) + + if len(peers) < 1 { + delete(s.db.directlySubscribedPatterns, pattern) + } + } + } +} + +func (s *Subscriber) PUnsubscribeAll() { + s.db.master.Lock() + defer s.db.master.Unlock() + + for pattern := range s.patterns { + subscribers := s.db.directlySubscribedPatterns[pattern] + delete(subscribers, s) + + if len(subscribers) < 1 { + delete(s.db.directlySubscribedPatterns, pattern) + } + } + + s.patterns = map[*regexp.Regexp]struct{}{} +} + +func (s *Subscriber) streamMessages() { + defer close(s.Messages) + + for { + select { + case <-s.queue.hasNewMessages: + s.queue.Lock() + + select { + case <-s.queue.hasNewMessages: + break + default: + break + } + + messages := s.queue.messages + s.queue.messages = []Message{} + + s.queue.Unlock() + + for _, message := range messages { + select { + case s.Messages <- message: + break + case <-s.close: + return + } + } + case <-s.close: + return + } + } +} + +func (m *Miniredis) NewSubscriber() *Subscriber { + return m.DB(m.selectedDB).NewSubscriber() +} + +func (db *RedisDB) NewSubscriber() *Subscriber { + s := &Subscriber{ + Messages: make(chan Message), + close: make(chan struct{}), + db: db, + channels: map[string]struct{}{}, + patterns: map[*regexp.Regexp]struct{}{}, + queue: messageQueue{ + messages: []Message{}, + hasNewMessages: make(chan struct{}, 1), + }, + } + + go s.streamMessages() + + return s +} + +func (m *Miniredis) Publish(channel, message string) int { + return m.DB(m.selectedDB).Publish(channel, message) +} + +func (db *RedisDB) Publish(channel, message string) int { + db.master.Lock() + defer db.master.Unlock() + + return db.publishMessage(channel, message) +} + +func (m *Miniredis) PubSubChannels(pattern *regexp.Regexp) map[string]struct{} { + return m.DB(m.selectedDB).PubSubChannels(pattern) +} + +func (db *RedisDB) PubSubChannels(pattern *regexp.Regexp) map[string]struct{} { + db.master.Lock() + defer db.master.Unlock() + + return db.pubSubChannelsNoLock(pattern) +} + +func (m *Miniredis) PubSubNumSub(channels ...string) map[string]int { + return m.DB(m.selectedDB).PubSubNumSub(channels...) +} + +func (db *RedisDB) PubSubNumSub(channels ...string) map[string]int { + db.master.Lock() + defer db.master.Unlock() + + return db.pubSubNumSubNoLock(channels...) +} + +func (m *Miniredis) PubSubNumPat() int { + return m.DB(m.selectedDB).PubSubNumPat() +} + +func (db *RedisDB) PubSubNumPat() int { + db.master.Lock() + defer db.master.Unlock() + + return db.pubSubNumPatNoLock() +} + +func (db *RedisDB) pubSubChannelsNoLock(pattern *regexp.Regexp) map[string]struct{} { + channels := map[string]struct{}{} + + if pattern == nil { + for channel := range db.subscribedChannels { + channels[channel] = struct{}{} + } + + for channel := range db.directlySubscribedChannels { + channels[channel] = struct{}{} + } + } else { + for channel := range db.subscribedChannels { + if pattern.MatchString(channel) { + channels[channel] = struct{}{} + } + } + + for channel := range db.directlySubscribedChannels { + if pattern.MatchString(channel) { + channels[channel] = struct{}{} + } + } + } + + return channels +} + +func (db *RedisDB) pubSubNumSubNoLock(channels ...string) map[string]int { + numSub := map[string]int{} + + for _, channel := range channels { + numSub[channel] = len(db.subscribedChannels[channel]) + len(db.directlySubscribedChannels[channel]) + } + + return numSub +} + +func (db *RedisDB) pubSubNumPatNoLock() (numPat int) { + for _, peers := range db.subscribedPatterns { + numPat += len(peers) + } + + for _, subscribers := range db.directlySubscribedPatterns { + numPat += len(subscribers) + } + + return +} + +func (db *RedisDB) publishMessage(channel, message string) int { + count := 0 + + var allPeers map[*server.Peer]struct{} = nil + + if peers, hasPeers := db.subscribedChannels[channel]; hasPeers { + allPeers = make(map[*server.Peer]struct{}, len(peers)) + + for peer := range peers { + allPeers[peer] = struct{}{} + } + } + + for pattern, peers := range db.subscribedPatterns { + if db.master.channelPatterns[pattern].MatchString(channel) { + if allPeers == nil { + allPeers = make(map[*server.Peer]struct{}, len(peers)) + } + + for peer := range peers { + allPeers[peer] = struct{}{} + } + } + } + + if allPeers != nil { + count += len(allPeers) + + wait := publishMessagesAsync(allPeers, channel, message) + defer wait() + } + + var allSubscribers map[*Subscriber]struct{} = nil + + if subscribers, hasSubscribers := db.directlySubscribedChannels[channel]; hasSubscribers { + allSubscribers = make(map[*Subscriber]struct{}, len(subscribers)) + + for subscriber := range subscribers { + allSubscribers[subscriber] = struct{}{} + } + } + + for pattern, subscribers := range db.directlySubscribedPatterns { + if pattern.MatchString(channel) { + if allSubscribers == nil { + allSubscribers = make(map[*Subscriber]struct{}, len(subscribers)) + } + + for subscriber := range subscribers { + allSubscribers[subscriber] = struct{}{} + } + } + } + + if allSubscribers != nil { + count += len(allSubscribers) + + wait := publishMessagesToOurselvesAsync(allSubscribers, channel, message) + defer wait() + } + + return count +} + +func publishMessagesAsync(peers map[*server.Peer]struct{}, channel, message string) (wait func()) { + chCtl := make(chan struct{}) + go publishMessages(peers, channel, message, chCtl) + + return func() { <-chCtl } +} + +func publishMessages(peers map[*server.Peer]struct{}, channel, message string, chCtl chan struct{}) { + pending := uint64(len(peers)) + + for peer := range peers { + go publishMessage(peer, channel, message, &pending, chCtl) + } +} + +func publishMessage(peer *server.Peer, channel, message string, pending *uint64, chCtl chan struct{}) { + peer.MsgQueue.Enqueue(&queuedPubSubMessage{channel, message}) + + if atomic.AddUint64(pending, ^uint64(0)) == 0 { + close(chCtl) + } +} + +func publishMessagesToOurselvesAsync(subscribers map[*Subscriber]struct{}, channel, message string) (wait func()) { + chCtl := make(chan struct{}) + go publishMessagesToOurselves(subscribers, channel, message, chCtl) + + return func() { <-chCtl } +} + +func publishMessagesToOurselves(subscribers map[*Subscriber]struct{}, channel, message string, chCtl chan struct{}) { + pending := uint64(len(subscribers)) + + for subscriber := range subscribers { + go publishMessageToOurselves(subscriber, channel, message, &pending, chCtl) + } +} + +func publishMessageToOurselves(subscriber *Subscriber, channel, message string, pending *uint64, chCtl chan struct{}) { + subscriber.queue.Enqueue(Message{channel, message}) + + if atomic.AddUint64(pending, ^uint64(0)) == 0 { + close(chCtl) + } +} diff --git a/miniredis.go b/miniredis.go index 0688bdfe..ecfc0434 100644 --- a/miniredis.go +++ b/miniredis.go @@ -17,6 +17,7 @@ package miniredis import ( "fmt" "net" + "regexp" "strconv" "sync" "time" @@ -32,29 +33,44 @@ type setKey map[string]struct{} // RedisDB holds a single (numbered) Redis database. type RedisDB struct { - master *sync.Mutex // pointer to the lock in Miniredis - id int // db id - keys map[string]string // Master map of keys with their type - stringKeys map[string]string // GET/SET &c. keys - hashKeys map[string]hashKey // MGET/MSET &c. keys - listKeys map[string]listKey // LPUSH &c. keys - setKeys map[string]setKey // SADD &c. keys - sortedsetKeys map[string]sortedSet // ZADD &c. keys - ttl map[string]time.Duration // effective TTL values - keyVersion map[string]uint // used to watch values + master *Miniredis // pointer to the lock in Miniredis + id int // db id + keys map[string]string // Master map of keys with their type + stringKeys map[string]string // GET/SET &c. keys + hashKeys map[string]hashKey // MGET/MSET &c. keys + listKeys map[string]listKey // LPUSH &c. keys + setKeys map[string]setKey // SADD &c. keys + sortedsetKeys map[string]sortedSet // ZADD &c. keys + ttl map[string]time.Duration // effective TTL values + keyVersion map[string]uint // used to watch values + subscribedChannels map[string]map[*server.Peer]struct{} + subscribedPatterns map[string]map[*server.Peer]struct{} + directlySubscribedChannels map[string]map[*Subscriber]struct{} + directlySubscribedPatterns map[*regexp.Regexp]map[*Subscriber]struct{} +} + +type peerSubscriptions struct { + channels, patterns map[string]struct{} +} + +type peerCache struct { + subscriptions map[int]peerSubscriptions } // Miniredis is a Redis server implementation. type Miniredis struct { sync.Mutex - srv *server.Server - port int - password string - dbs map[int]*RedisDB - selectedDB int // DB id used in the direct Get(), Set() &c. - scripts map[string]string // sha1 -> lua src - signal *sync.Cond - now time.Time // used to make a duration from EXPIREAT. time.Now() if not set. + srv *server.Server + port int + password string + dbs map[int]*RedisDB + selectedDB int // DB id used in the direct Get(), Set() &c. + scripts map[string]string // sha1 -> lua src + signal *sync.Cond + now time.Time // used to make a duration from EXPIREAT. time.Now() if not set. + peers map[*server.Peer]peerCache + channelPatterns map[string]*regexp.Regexp + decompiledDirectlySubscribedPatterns map[string]*regexp.Regexp } type txCmd func(*server.Peer, *connCtx) @@ -77,25 +93,32 @@ type connCtx struct { // NewMiniRedis makes a new, non-started, Miniredis object. func NewMiniRedis() *Miniredis { m := Miniredis{ - dbs: map[int]*RedisDB{}, - scripts: map[string]string{}, + dbs: map[int]*RedisDB{}, + scripts: map[string]string{}, + peers: map[*server.Peer]peerCache{}, + channelPatterns: map[string]*regexp.Regexp{}, + decompiledDirectlySubscribedPatterns: map[string]*regexp.Regexp{}, } m.signal = sync.NewCond(&m) return &m } -func newRedisDB(id int, l *sync.Mutex) RedisDB { +func newRedisDB(id int, l *Miniredis) RedisDB { return RedisDB{ - id: id, - master: l, - keys: map[string]string{}, - stringKeys: map[string]string{}, - hashKeys: map[string]hashKey{}, - listKeys: map[string]listKey{}, - setKeys: map[string]setKey{}, - sortedsetKeys: map[string]sortedSet{}, - ttl: map[string]time.Duration{}, - keyVersion: map[string]uint{}, + id: id, + master: l, + keys: map[string]string{}, + stringKeys: map[string]string{}, + hashKeys: map[string]hashKey{}, + listKeys: map[string]listKey{}, + setKeys: map[string]setKey{}, + sortedsetKeys: map[string]sortedSet{}, + ttl: map[string]time.Duration{}, + keyVersion: map[string]uint{}, + subscribedChannels: map[string]map[*server.Peer]struct{}{}, + subscribedPatterns: map[string]map[*server.Peer]struct{}{}, + directlySubscribedChannels: map[string]map[*Subscriber]struct{}{}, + directlySubscribedPatterns: map[*regexp.Regexp]map[*Subscriber]struct{}{}, } } @@ -137,11 +160,14 @@ func (m *Miniredis) start(s *server.Server) error { commandsString(m) commandsHash(m) commandsList(m) + commandsPubsub(m) commandsSet(m) commandsSortedSet(m) commandsTransaction(m) commandsScripting(m) + s.OnDisconnect(m.onDisconnect) + return nil } @@ -182,7 +208,7 @@ func (m *Miniredis) db(i int) *RedisDB { if db, ok := m.dbs[i]; ok { return db } - db := newRedisDB(i, &m.Mutex) // the DB has our lock. + db := newRedisDB(i, m) // the DB has our lock. m.dbs[i] = &db return &db } @@ -323,6 +349,41 @@ func (m *Miniredis) handleAuth(c *server.Peer) bool { return true } +func (m *Miniredis) onDisconnect(c *server.Peer) { + go m.unsubscribeAll(c) +} + +func (m *Miniredis) unsubscribeAll(c *server.Peer) { + m.Lock() + defer m.Unlock() + + if cache, hasCache := m.peers[c]; hasCache { + for dbIdx, subscriptions := range cache.subscriptions { + db := m.dbs[dbIdx] + + for channel := range subscriptions.channels { + peers := db.subscribedChannels[channel] + delete(peers, c) + + if len(peers) < 1 { + delete(db.subscribedChannels, channel) + } + } + + for pattern := range subscriptions.patterns { + peers := db.subscribedPatterns[pattern] + delete(peers, c) + + if len(peers) < 1 { + delete(db.subscribedPatterns, pattern) + } + } + } + + delete(m.peers, c) + } +} + func getCtx(c *server.Peer) *connCtx { if c.Ctx == nil { c.Ctx = &connCtx{} diff --git a/redis.go b/redis.go index 49ff7bc3..e4eff543 100644 --- a/redis.go +++ b/redis.go @@ -41,6 +41,10 @@ func errLuaParseError(err error) string { return fmt.Sprintf("ERR Error compiling script (new function): %s", err.Error()) } +func errInvalidPubsubArgs(subcommand string) string { + return fmt.Sprintf("ERR Unknown PUBSUB subcommand or wrong number of arguments for '%s'", subcommand) +} + // withTx wraps the non-argument-checking part of command handling code in // transaction logic. func withTx( diff --git a/server/server.go b/server/server.go index 1796453d..b29ecdac 100644 --- a/server/server.go +++ b/server/server.go @@ -23,22 +23,26 @@ func errUnknownCommand(cmd string, args []string) string { // Cmd is what Register expects type Cmd func(c *Peer, cmd string, args []string) +type DisconnectHandler func(c *Peer) + // Server is a simple redis server type Server struct { - l net.Listener - cmds map[string]Cmd - peers map[net.Conn]struct{} - mu sync.Mutex - wg sync.WaitGroup - infoConns int - infoCmds int + l net.Listener + cmds map[string]Cmd + peers map[net.Conn]struct{} + mu sync.Mutex + wg sync.WaitGroup + infoConns int + infoCmds int + onDisconnect DisconnectHandler } // NewServer makes a server listening on addr. Close with .Close(). func NewServer(addr string) (*Server, error) { s := Server{ - cmds: map[string]Cmd{}, - peers: map[net.Conn]struct{}{}, + cmds: map[string]Cmd{}, + peers: map[net.Conn]struct{}{}, + onDisconnect: func(c *Peer) {}, } l, err := net.Listen("tcp", addr) @@ -76,7 +80,7 @@ func (s *Server) ServeConn(conn net.Conn) { s.infoConns++ s.mu.Unlock() - s.servePeer(conn) + s.onDisconnect(s.servePeer(conn)) s.mu.Lock() delete(s.peers, conn) @@ -122,22 +126,87 @@ func (s *Server) Register(cmd string, f Cmd) error { return nil } -func (s *Server) servePeer(c net.Conn) { +func (s *Server) OnDisconnect(handler DisconnectHandler) { + s.onDisconnect = handler +} + +func (s *Server) servePeer(c net.Conn) (cl *Peer) { r := bufio.NewReader(c) - cl := &Peer{ + cl = &Peer{ w: bufio.NewWriter(c), + MsgQueue: MessageQueue{ + messages: []QueuedMessage{}, + hasNewMessages: make(chan struct{}, 1), + }, } + + chReceivedArray, chReadNext := readArrayAsync(r) + defer close(chReadNext) + + chReadNext <- struct{}{} + for { - args, err := readArray(r) - if err != nil { - return + select { + case message := <-chReceivedArray: + if message.err != nil { + return + } + + s.dispatch(cl, message.array) + cl.w.Flush() + + if cl.closed { + c.Close() + return + } + + chReadNext <- struct{}{} + case <-cl.MsgQueue.hasNewMessages: + cl.MsgQueue.Lock() + + select { + case <-cl.MsgQueue.hasNewMessages: + break + default: + break + } + + messages := cl.MsgQueue.messages + cl.MsgQueue.messages = []QueuedMessage{} + + cl.MsgQueue.Unlock() + + for _, message := range messages { + message.Write(cl) + } + + cl.Flush() } - s.dispatch(cl, args) - cl.w.Flush() - if cl.closed { - c.Close() + } +} + +type receivedArray struct { + array []string + err error +} + +func readArrayAsync(r *bufio.Reader) (chReceivedArray chan receivedArray, chReadNext chan struct{}) { + chReceivedArray = make(chan receivedArray) + chReadNext = make(chan struct{}) + + go readArraySync(r, chReceivedArray, chReadNext) + return +} + +func readArraySync(r *bufio.Reader, chReceivedArray chan receivedArray, chReadNext chan struct{}) { + for { + if _, isOpen := <-chReadNext; !isOpen { + close(chReceivedArray) return } + + args, err := readArray(r) + chReceivedArray <- receivedArray{args, err} } } @@ -180,11 +249,36 @@ func (s *Server) TotalConnections() int { return s.infoConns } +type QueuedMessage interface { + Write(c *Peer) +} + +type MessageQueue struct { + sync.Mutex + messages []QueuedMessage + hasNewMessages chan struct{} +} + +func (q *MessageQueue) Enqueue(message QueuedMessage) { + q.Lock() + defer q.Unlock() + + q.messages = append(q.messages, message) + + select { + case q.hasNewMessages <- struct{}{}: + break + default: + break + } +} + // Peer is a client connected to the server type Peer struct { - w *bufio.Writer - closed bool - Ctx interface{} // anything goes, server won't touch this + w *bufio.Writer + closed bool + Ctx interface{} // anything goes, server won't touch this + MsgQueue MessageQueue } // Flush the write buffer. Called automatically after every redis command diff --git a/test_test.go b/test_test.go index a20c941b..7868351c 100644 --- a/test_test.go +++ b/test_test.go @@ -1,7 +1,9 @@ package miniredis import ( + "fmt" "reflect" + "strings" "testing" ) @@ -41,3 +43,22 @@ func mustFail(tb testing.TB, err error, want string) { tb.Errorf("have %q, want %q", have, want) } } + +func oneOf(tb testing.TB, exps []interface{}, act interface{}) bool { + tb.Helper() + + for _, exp := range exps { + if reflect.DeepEqual(exp, act) { + return true + } + } + + expPP := make([]string, len(exps)) + for i, exp := range exps { + expPP[i] = fmt.Sprintf("%#v", exp) + } + + tb.Errorf("expected one of: %s got: %#v", strings.Join(expPP, ", "), act) + + return false +} From ca45031e80aacffa1b729c28d09167392ab5b65a Mon Sep 17 00:00:00 2001 From: "Alexander A. Klimov" Date: Fri, 16 Nov 2018 10:31:58 +0100 Subject: [PATCH 04/13] Add integration tests refs #1 --- integration/pubsub_test.go | 101 +++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 integration/pubsub_test.go diff --git a/integration/pubsub_test.go b/integration/pubsub_test.go new file mode 100644 index 00000000..0952a81d --- /dev/null +++ b/integration/pubsub_test.go @@ -0,0 +1,101 @@ +// +build int + +package main + +import ( + "testing" +) + +func TestSubscribe(t *testing.T) { + testCommands(t, + fail("SUBSCRIBE"), + + succ("SUBSCRIBE", "foo"), + succ("UNSUBSCRIBE"), + + succ("SUBSCRIBE", "foo"), + succ("UNSUBSCRIBE", "foo"), + + succ("SUBSCRIBE", "foo", "bar"), + succ("UNSUBSCRIBE", "foo", "bar"), + + succ("SUBSCRIBE", -1), + succ("UNSUBSCRIBE", -1), + ) +} + +func TestPSubscribe(t *testing.T) { + testCommands(t, + fail("PSUBSCRIBE"), + + succ("PSUBSCRIBE", "foo"), + succ("PUNSUBSCRIBE"), + + succ("PSUBSCRIBE", "foo"), + succ("PUNSUBSCRIBE", "foo"), + + succ("PSUBSCRIBE", "foo", "bar"), + succ("PUNSUBSCRIBE", "foo", "bar"), + + succ("PSUBSCRIBE", "f?o"), + succ("PUNSUBSCRIBE", "f?o"), + + succ("PSUBSCRIBE", "f*o"), + succ("PUNSUBSCRIBE", "f*o"), + + succ("PSUBSCRIBE", "f[oO]o"), + succ("PUNSUBSCRIBE", "f[oO]o"), + + succ("PSUBSCRIBE", "f\\?o"), + succ("PUNSUBSCRIBE", "f\\?o"), + + succ("PSUBSCRIBE", "f\\*o"), + succ("PUNSUBSCRIBE", "f\\*o"), + + succ("PSUBSCRIBE", "f\\[oO]o"), + succ("PUNSUBSCRIBE", "f\\[oO]o"), + + succ("PSUBSCRIBE", "f\\\\oo"), + succ("PUNSUBSCRIBE", "f\\\\oo"), + + succ("PSUBSCRIBE", -1), + succ("PUNSUBSCRIBE", -1), + ) +} + +func TestPublish(t *testing.T) { + testCommands(t, + fail("PUBLISH"), + fail("PUBLISH", "foo"), + succ("PUBLISH", "foo", "bar"), + fail("PUBLISH", "foo", "bar", "deadbeef"), + succ("PUBLISH", -1, -2), + ) +} + +func TestPubSub(t *testing.T) { + testCommands(t, + fail("PUBSUB"), + fail("PUBSUB", "FOO"), + + succ("PUBSUB", "CHANNELS"), + succ("PUBSUB", "CHANNELS", "foo"), + fail("PUBSUB", "CHANNELS", "foo", "bar"), + succ("PUBSUB", "CHANNELS", "f?o"), + succ("PUBSUB", "CHANNELS", "f*o"), + succ("PUBSUB", "CHANNELS", "f[oO]o"), + succ("PUBSUB", "CHANNELS", "f\\?o"), + succ("PUBSUB", "CHANNELS", "f\\*o"), + succ("PUBSUB", "CHANNELS", "f\\[oO]o"), + succ("PUBSUB", "CHANNELS", "f\\\\oo"), + succ("PUBSUB", "CHANNELS", -1), + + succ("PUBSUB", "NUMSUB"), + succ("PUBSUB", "NUMSUB", "foo"), + succ("PUBSUB", "NUMSUB", "foo", "bar"), + succ("PUBSUB", "NUMSUB", -1), + + succ("PUBSUB", "NUMPAT"), + fail("PUBSUB", "NUMPAT", "foo"), + ) +} From 8262e0afa2524ca47f471ed6be0dce2bc0525196 Mon Sep 17 00:00:00 2001 From: Harmen Date: Wed, 6 Mar 2019 17:00:02 +0100 Subject: [PATCH 05/13] update error message for redis 5.X --- cmd_pubsub.go | 6 ++++-- cmd_pubsub_test.go | 6 +++--- redis.go | 5 +---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/cmd_pubsub.go b/cmd_pubsub.go index 53a1b2d6..508754c7 100644 --- a/cmd_pubsub.go +++ b/cmd_pubsub.go @@ -3,9 +3,11 @@ package miniredis import ( - "github.com/alicebob/miniredis/server" + "fmt" "regexp" "strings" + + "github.com/alicebob/miniredis/server" ) // commandsPubsub handles all PUB/SUB operations. @@ -451,7 +453,7 @@ func (m *Miniredis) cmdPubSub(c *server.Peer, cmd string, args []string) { if !argsOk { setDirty(c) - c.WriteError(errInvalidPubsubArgs(subcommand)) + c.WriteError(fmt.Sprintf(msgFPubsubUsage, subcommand)) return } diff --git a/cmd_pubsub_test.go b/cmd_pubsub_test.go index 5141d48a..0175a779 100644 --- a/cmd_pubsub_test.go +++ b/cmd_pubsub_test.go @@ -575,9 +575,9 @@ func TestPubSubBadArgs(t *testing.T) { {"PUBLISH", []interface{}{"event1"}, "ERR wrong number of arguments for 'publish' command"}, {"PUBLISH", []interface{}{"event1", "message2", "message3"}, "ERR wrong number of arguments for 'publish' command"}, {"PUBSUB", []interface{}{}, "ERR wrong number of arguments for 'pubsub' command"}, - {"PUBSUB", []interface{}{"FOOBAR"}, "ERR Unknown PUBSUB subcommand or wrong number of arguments for 'FOOBAR'"}, - {"PUBSUB", []interface{}{"NUMPAT", "FOOBAR"}, "ERR Unknown PUBSUB subcommand or wrong number of arguments for 'NUMPAT'"}, - {"PUBSUB", []interface{}{"CHANNELS", "FOOBAR1", "FOOBAR2"}, "ERR Unknown PUBSUB subcommand or wrong number of arguments for 'CHANNELS'"}, + {"PUBSUB", []interface{}{"FOOBAR"}, "ERR Unknown subcommand or wrong number of arguments for 'FOOBAR'. Try PUBSUB HELP."}, + {"PUBSUB", []interface{}{"NUMPAT", "FOOBAR"}, "ERR Unknown subcommand or wrong number of arguments for 'NUMPAT'. Try PUBSUB HELP."}, + {"PUBSUB", []interface{}{"CHANNELS", "FOOBAR1", "FOOBAR2"}, "ERR Unknown subcommand or wrong number of arguments for 'CHANNELS'. Try PUBSUB HELP."}, } { _, c, done := setup(t) diff --git a/redis.go b/redis.go index e4eff543..ab353dbc 100644 --- a/redis.go +++ b/redis.go @@ -29,6 +29,7 @@ const ( msgInvalidKeysNumber = "ERR Number of keys can't be greater than number of args" msgNegativeKeysNumber = "ERR Number of keys can't be negative" msgFScriptUsage = "ERR Unknown subcommand or wrong number of arguments for '%s'. Try SCRIPT HELP." + msgFPubsubUsage = "ERR Unknown subcommand or wrong number of arguments for '%s'. Try PUBSUB HELP." msgSingleElementPair = "ERR INCR option supports a single increment-element pair" msgNoScriptFound = "NOSCRIPT No matching script. Please use EVAL." ) @@ -41,10 +42,6 @@ func errLuaParseError(err error) string { return fmt.Sprintf("ERR Error compiling script (new function): %s", err.Error()) } -func errInvalidPubsubArgs(subcommand string) string { - return fmt.Sprintf("ERR Unknown PUBSUB subcommand or wrong number of arguments for '%s'", subcommand) -} - // withTx wraps the non-argument-checking part of command handling code in // transaction logic. func withTx( From cdf0ee01ed2fb37f0e22a0c3dfdd839053cab436 Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 16 Mar 2019 10:32:16 +0100 Subject: [PATCH 06/13] some failing tests --- integration/pubsub_test.go | 93 ++++++++++++++++++++++++++++++++++++++ integration/test.go | 68 ++++++++++++++++++++++++---- 2 files changed, 152 insertions(+), 9 deletions(-) diff --git a/integration/pubsub_test.go b/integration/pubsub_test.go index 0952a81d..a8a7e566 100644 --- a/integration/pubsub_test.go +++ b/integration/pubsub_test.go @@ -3,7 +3,10 @@ package main import ( + "sync" "testing" + + "github.com/alicebob/miniredis" ) func TestSubscribe(t *testing.T) { @@ -99,3 +102,93 @@ func TestPubSub(t *testing.T) { fail("PUBSUB", "NUMPAT", "foo"), ) } + +func TestPubsubFull(t *testing.T) { + t.Skip() // exit 1, no idea why + var wg1 sync.WaitGroup + wg1.Add(1) + testMultiCommands(t, + func(r chan<- command, _ *miniredis.Miniredis) { + r <- succ("SUBSCRIBE", "news", "sport") + r <- receive() + /* + wg1.Done() + r <- receive() + r <- receive() + r <- receive() + r <- succ("UNSUBSCRIBE", "news", "sport") + r <- receive() + */ + }, + /* + func(r chan<- command, _ *miniredis.Miniredis) { + wg1.Wait() + r <- succ("PUBLISH", "news", "revolution!") + r <- succ("PUBLISH", "news", "alien invasion!") + r <- succ("PUBLISH", "sport", "lady biked too fast") + r <- succ("PUBLISH", "gossip", "man bites dog") + }, + */ + ) +} + +func TestPubsubMulti(t *testing.T) { + t.Skip() // hangs. No idea why. + var wg1 sync.WaitGroup + wg1.Add(2) + testMultiCommands(t, + func(r chan<- command, _ *miniredis.Miniredis) { + r <- succ("SUBSCRIBE", "news", "sport") + r <- receive() + wg1.Done() + r <- receive() + r <- receive() + r <- receive() + r <- succ("UNSUBSCRIBE", "news", "sport") + r <- receive() + }, + func(r chan<- command, _ *miniredis.Miniredis) { + r <- succ("SUBSCRIBE", "sport") + r <- receive() + wg1.Done() + r <- receive() + r <- succ("UNSUBSCRIBE", "sport") + r <- receive() + }, + func(r chan<- command, _ *miniredis.Miniredis) { + wg1.Wait() + r <- succ("PUBLISH", "news", "revolution!") + r <- succ("PUBLISH", "news", "alien invasion!") + r <- succ("PUBLISH", "sport", "lady biked too fast") + }, + ) +} + +func TestPubsubSelect(t *testing.T) { + t.Skip() // known broken + var wg1 sync.WaitGroup + wg1.Add(1) + testMultiCommands(t, + func(r chan<- command, _ *miniredis.Miniredis) { + r <- succ("SUBSCRIBE", "news", "sport") + r <- receive() + wg1.Done() + r <- receive() + }, + func(r chan<- command, _ *miniredis.Miniredis) { + wg1.Wait() + r <- succ("SELECT", 3) + r <- succ("PUBLISH", "news", "revolution!") + }, + ) +} + +func TestPubsubMode(t *testing.T) { + t.Skip() // known broken + testCommands(t, + succ("SUBSCRIBE", "news", "sport"), + receive(), + fail("ECHO", "foo"), + fail("HGET", "foo", "bar"), + ) +} diff --git a/integration/test.go b/integration/test.go index 5db7a36f..ec1ebef8 100644 --- a/integration/test.go +++ b/integration/test.go @@ -16,12 +16,13 @@ import ( ) type command struct { - cmd string // 'GET', 'SET', &c. - args []interface{} - error bool // Whether the command should return an error or not. - sort bool // Sort real redis's result. Used for 'keys'. - loosely bool // Don't compare values, only structure. (for random things) - errorSub string // Both errors need this substring + cmd string // 'GET', 'SET', &c. + args []interface{} + error bool // Whether the command should return an error or not. + sort bool // Sort real redis's result. Used for 'keys'. + loosely bool // Don't compare values, only structure. (for random things) + errorSub string // Both errors need this substring + receiveOnly bool // no command, only receive. For pubsub messages. } func succ(cmd string, args ...interface{}) command { @@ -78,6 +79,13 @@ func failLoosely(cmd string, args ...interface{}) command { } } +// don't send a message, only read one. For pubsub messages. +func receive() command { + return command{ + receiveOnly: true, + } +} + // ok fails the test if an err is not nil. func ok(tb testing.TB, err error) { tb.Helper() @@ -109,7 +117,7 @@ func testMultiCommands(t *testing.T, cs ...func(chan<- command, *miniredis.Minir var wg sync.WaitGroup for _, c := range cs { - // one connections per cs + // one connection per cs cMini, err := redis.Dial("tcp", sMini.Addr()) ok(t, err) @@ -160,8 +168,35 @@ func runCommands(t *testing.T, realAddr, miniAddr string, commands []command) { func runCommand(t *testing.T, cMini, cReal redis.Conn, p command) { t.Helper() - vReal, errReal := cReal.Do(p.cmd, p.args...) - vMini, errMini := cMini.Do(p.cmd, p.args...) + var ( + vReal, vMini interface{} + errReal, errMini error + ) + if p.receiveOnly { + vReal, errReal = cReal.Receive() + dump(vReal, "-real-") + vMini, errMini = cMini.Receive() + dump(vMini, "-mini-") + for _, k := range vReal.([]interface{}) { + switch k := k.(type) { + case []byte: + t.Errorf(" -real- %s", string(k)) + default: + t.Errorf(" -real- %#v", k) + } + } + for _, k := range vMini.([]interface{}) { + switch k := k.(type) { + case []byte: + t.Errorf(" -mini- %s", string(k)) + default: + t.Errorf(" -mini- %#v", k) + } + } + } else { + vReal, errReal = cReal.Do(p.cmd, p.args...) + vMini, errMini = cMini.Do(p.cmd, p.args...) + } if p.error { if errReal == nil { t.Errorf("got no error from realredis. case: %#v", p) @@ -211,6 +246,8 @@ func runCommand(t *testing.T, cMini, cReal redis.Conn, p command) { } else { if !reflect.DeepEqual(vReal, vMini) { t.Errorf("value error. expected: %#v got: %#v case: %#v", vReal, vMini, p) + dump(vReal, " --real-") + dump(vMini, " --mini-") return } } @@ -259,3 +296,16 @@ func looselyEqual(a, b interface{}) bool { panic(fmt.Sprintf("unhandled case, got a %#v", a)) } } + +func dump(r interface{}, prefix string) { + if ls, ok := r.([]interface{}); ok { + for _, k := range ls { + switch k := k.(type) { + case []byte: + fmt.Printf(" %s %s\n", prefix, string(k)) + default: + fmt.Printf(" %s %#v\n", prefix, k) + } + } + } +} From f4b3ed6109f198ee3338f7fc8f553b772ed1750c Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 16 Mar 2019 22:35:56 +0100 Subject: [PATCH 07/13] simplify pubsub --- cmd_connection.go | 31 +- cmd_connection_test.go | 20 ++ cmd_hash.go | 6 + cmd_pubsub.go | 340 +++++-------------- cmd_pubsub_test.go | 650 +++++++++++++----------------------- cmd_string.go | 4 + direct.go | 453 ++----------------------- integration/generic_test.go | 8 + integration/pubsub_test.go | 81 ++--- integration/test.go | 69 +++- miniredis.go | 209 +++++++----- pubsub.go | 204 +++++++++++ server/server.go | 184 ++++------ 13 files changed, 875 insertions(+), 1384 deletions(-) create mode 100644 pubsub.go diff --git a/cmd_connection.go b/cmd_connection.go index ca648f4b..39689949 100644 --- a/cmd_connection.go +++ b/cmd_connection.go @@ -21,7 +21,33 @@ func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } - c.WriteInline("PONG") + + if len(args) > 1 { + setDirty(c) + c.WriteError(errWrongNumber(cmd)) + return + } + + payload := "" + if len(args) > 0 { + payload = args[0] + } + + // PING is allowed in subscribed state + if sub := getCtx(c).subscriber; sub != nil { + c.Block(func(c *server.Peer) { + c.WriteLen(2) + c.WriteBulk("pong") + c.WriteBulk(payload) + }) + return + } + + if payload == "" { + c.WriteInline("PONG") + return + } + c.WriteBulk(payload) } // AUTH @@ -58,6 +84,9 @@ func (m *Miniredis) cmdEcho(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } msg := args[0] c.WriteBulk(msg) diff --git a/cmd_connection_test.go b/cmd_connection_test.go index 62b74212..eb28569a 100644 --- a/cmd_connection_test.go +++ b/cmd_connection_test.go @@ -30,6 +30,26 @@ func TestAuth(t *testing.T) { ok(t, err) } +func TestPing(t *testing.T) { + s, err := Run() + ok(t, err) + defer s.Close() + c, err := redis.Dial("tcp", s.Addr()) + ok(t, err) + + r, err := redis.String(c.Do("PING")) + ok(t, err) + equals(t, "PONG", r) + + r, err = redis.String(c.Do("PING", "hi")) + ok(t, err) + equals(t, "hi", r) + + _, err = c.Do("PING", "foo", "bar") + mustFail(t, err, errWrongNumber("ping")) + +} + func TestEcho(t *testing.T) { s, err := Run() ok(t, err) diff --git a/cmd_hash.go b/cmd_hash.go index 1c65ebec..c1248099 100644 --- a/cmd_hash.go +++ b/cmd_hash.go @@ -37,6 +37,9 @@ func (m *Miniredis) cmdHset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field, value := args[0], args[1], args[2] @@ -138,6 +141,9 @@ func (m *Miniredis) cmdHget(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field := args[0], args[1] diff --git a/cmd_pubsub.go b/cmd_pubsub.go index 508754c7..1fc5398b 100644 --- a/cmd_pubsub.go +++ b/cmd_pubsub.go @@ -14,14 +14,15 @@ import ( func commandsPubsub(m *Miniredis) { m.srv.Register("SUBSCRIBE", m.cmdSubscribe) m.srv.Register("UNSUBSCRIBE", m.cmdUnsubscribe) - m.srv.Register("PSUBSCRIBE", m.cmdPSubscribe) - m.srv.Register("PUNSUBSCRIBE", m.cmdPUnsubscribe) + m.srv.Register("PSUBSCRIBE", m.cmdPsubscribe) + m.srv.Register("PUNSUBSCRIBE", m.cmdPunsubscribe) m.srv.Register("PUBLISH", m.cmdPublish) m.srv.Register("PUBSUB", m.cmdPubSub) } // SUBSCRIBE func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { + // TODO: figure out transactions. if len(args) < 1 { setDirty(c) c.WriteError(errWrongNumber(cmd)) @@ -31,60 +32,16 @@ func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { return } - subscriptionsAmounts := make([]int, len(args)) - - withTx(m, c, func(c *server.Peer, ctx *connCtx) { - var cache peerCache - var hasCache bool - - if cache, hasCache = m.peers[c]; !hasCache { - cache = peerCache{subscriptions: map[int]peerSubscriptions{}} - m.peers[c] = cache - } - - var dbSubs peerSubscriptions - var hasDbSubs bool - - if dbSubs, hasDbSubs = cache.subscriptions[ctx.selectedDB]; !hasDbSubs { - dbSubs = peerSubscriptions{channels: map[string]struct{}{}, patterns: map[string]struct{}{}} - cache.subscriptions[ctx.selectedDB] = dbSubs - } - - subscribedChannels := m.db(ctx.selectedDB).subscribedChannels - - for i, channel := range args { - var peers map[*server.Peer]struct{} - var hasPeers bool - - if peers, hasPeers = subscribedChannels[channel]; !hasPeers { - peers = map[*server.Peer]struct{}{} - subscribedChannels[channel] = peers - } - - peers[c] = struct{}{} - - dbSubs.channels[channel] = struct{}{} - - subscriptionsAmounts[i] = m.getSubscriptionsAmount(c, ctx) - } - - for i, channel := range args { + sub := m.subscribedState(c) + for _, channel := range args { + n := sub.Subscribe(channel) + c.Block(func(c *server.Peer) { c.WriteLen(3) c.WriteBulk("subscribe") c.WriteBulk(channel) - c.WriteInt(subscriptionsAmounts[i]) - } - }) -} - -func (m *Miniredis) getSubscriptionsAmount(c *server.Peer, ctx *connCtx) (total int) { - if cache, hasCache := m.peers[c]; hasCache { - if dbSubs, hasDbSubs := cache.subscriptions[ctx.selectedDB]; hasDbSubs { - total = len(dbSubs.channels) + len(dbSubs.patterns) - } + c.WriteInt(n) + }) } - - return } // UNSUBSCRIBE @@ -93,77 +50,33 @@ func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) { return } - var channels []string = nil - var subscriptionsAmounts []int = nil - - withTx(m, c, func(c *server.Peer, ctx *connCtx) { - if cache, hasCache := m.peers[c]; hasCache { - if dbSubs, hasDbSubs := cache.subscriptions[ctx.selectedDB]; hasDbSubs { - subscribedChannels := m.db(ctx.selectedDB).subscribedChannels - - if len(args) > 0 { - channels = args - } else { - channels = make([]string, len(dbSubs.channels)) - i := 0 - - for channel := range dbSubs.channels { - channels[i] = channel - i++ - } - } - - subscriptionsAmounts = make([]int, len(channels)) + sub := m.subscribedState(c) - for i, channel := range channels { - if peers, hasPeers := subscribedChannels[channel]; hasPeers { - delete(peers, c) - delete(dbSubs.channels, channel) + // TODO: tx, which also locks. - if len(peers) < 1 { - delete(subscribedChannels, channel) - } - - if len(dbSubs.channels) < 1 && len(dbSubs.patterns) < 1 { - delete(cache.subscriptions, ctx.selectedDB) - - if len(cache.subscriptions) < 1 { - delete(m.peers, c) - } - } - } - - subscriptionsAmounts[i] = m.getSubscriptionsAmount(c, ctx) - } - } - } - - var subscriptionsAmount int + channels := args + if len(channels) == 0 { + channels = sub.Channels() + } - if channels == nil { - subscriptionsAmount = m.getSubscriptionsAmount(c, ctx) - } + // there is no de-duplication + for _, channel := range channels { + n := sub.Unsubscribe(channel) + c.Block(func(c *server.Peer) { + c.WriteLen(3) + c.WriteBulk("unsubscribe") + c.WriteBulk(channel) + c.WriteInt(n) + }) + } - if channels == nil { - for _, channel := range args { - c.WriteLen(3) - c.WriteBulk("unsubscribe") - c.WriteBulk(channel) - c.WriteInt(subscriptionsAmount) - } - } else { - for i, channel := range channels { - c.WriteLen(3) - c.WriteBulk("unsubscribe") - c.WriteBulk(channel) - c.WriteInt(subscriptionsAmounts[i]) - } - } - }) + if sub.Count() == 0 { + endSubscriber(m, c) + } } // PSUBSCRIBE -func (m *Miniredis) cmdPSubscribe(c *server.Peer, cmd string, args []string) { +func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) { if len(args) < 1 { setDirty(c) c.WriteError(errWrongNumber(cmd)) @@ -173,54 +86,16 @@ func (m *Miniredis) cmdPSubscribe(c *server.Peer, cmd string, args []string) { return } - subscriptionsAmounts := make([]int, len(args)) - - withTx(m, c, func(c *server.Peer, ctx *connCtx) { - var cache peerCache - var hasCache bool - - if cache, hasCache = m.peers[c]; !hasCache { - cache = peerCache{subscriptions: map[int]peerSubscriptions{}} - m.peers[c] = cache - } - - var dbSubs peerSubscriptions - var hasDbSubs bool - - if dbSubs, hasDbSubs = cache.subscriptions[ctx.selectedDB]; !hasDbSubs { - dbSubs = peerSubscriptions{channels: map[string]struct{}{}, patterns: map[string]struct{}{}} - cache.subscriptions[ctx.selectedDB] = dbSubs - } - - subscribedPatterns := m.db(ctx.selectedDB).subscribedPatterns - - for i, pattern := range args { - var peers map[*server.Peer]struct{} - var hasPeers bool - - if peers, hasPeers = subscribedPatterns[pattern]; !hasPeers { - peers = map[*server.Peer]struct{}{} - subscribedPatterns[pattern] = peers - } - - peers[c] = struct{}{} - - dbSubs.patterns[pattern] = struct{}{} - - if _, hasRgx := m.channelPatterns[pattern]; !hasRgx { - m.channelPatterns[pattern] = compileChannelPattern(pattern) - } - - subscriptionsAmounts[i] = m.getSubscriptionsAmount(c, ctx) - } - - for i, pattern := range args { + sub := m.subscribedState(c) + for _, pat := range args { + n := sub.Psubscribe(pat) + c.Block(func(c *server.Peer) { c.WriteLen(3) c.WriteBulk("psubscribe") - c.WriteBulk(pattern) - c.WriteInt(subscriptionsAmounts[i]) - } - }) + c.WriteBulk(pat) + c.WriteInt(n) + }) + } } func compileChannelPattern(pattern string) *regexp.Regexp { @@ -324,89 +199,33 @@ func compileChannelPattern(pattern string) *regexp.Regexp { } // PUNSUBSCRIBE -func (m *Miniredis) cmdPUnsubscribe(c *server.Peer, cmd string, args []string) { +func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } - var patterns []string = nil - var subscriptionsAmounts []int = nil - - withTx(m, c, func(c *server.Peer, ctx *connCtx) { - if cache, hasCache := m.peers[c]; hasCache { - if dbSubs, hasDbSubs := cache.subscriptions[ctx.selectedDB]; hasDbSubs { - subscribedPatterns := m.db(ctx.selectedDB).subscribedPatterns - - if len(args) > 0 { - patterns = args - } else { - patterns = make([]string, len(dbSubs.patterns)) - i := 0 - - for pattern := range dbSubs.patterns { - patterns[i] = pattern - i++ - } - } - - subscriptionsAmounts = make([]int, len(patterns)) - - for i, pattern := range patterns { - if peers, hasPeers := subscribedPatterns[pattern]; hasPeers { - delete(peers, c) - delete(dbSubs.patterns, pattern) - - if len(peers) < 1 { - delete(subscribedPatterns, pattern) - } - - if len(dbSubs.patterns) < 1 && len(dbSubs.channels) < 1 { - delete(cache.subscriptions, ctx.selectedDB) - - if len(cache.subscriptions) < 1 { - delete(m.peers, c) - } - } - } - - subscriptionsAmounts[i] = m.getSubscriptionsAmount(c, ctx) - } - } - } - - var subscriptionsAmount int + sub := m.subscribedState(c) - if patterns == nil { - subscriptionsAmount = m.getSubscriptionsAmount(c, ctx) - } + patterns := args + if len(patterns) == 0 { + patterns = sub.Patterns() + } - if patterns == nil { - for _, pattern := range args { - c.WriteLen(3) - c.WriteBulk("punsubscribe") - c.WriteBulk(pattern) - c.WriteInt(subscriptionsAmount) - } - } else { - for i, pattern := range patterns { - c.WriteLen(3) - c.WriteBulk("punsubscribe") - c.WriteBulk(pattern) - c.WriteInt(subscriptionsAmounts[i]) - } - } - }) -} + // there is no de-duplication + for _, pat := range patterns { + n := sub.Punsubscribe(pat) + c.Block(func(c *server.Peer) { + c.WriteLen(3) + c.WriteBulk("punsubscribe") + c.WriteBulk(pat) + c.WriteInt(n) + }) + } -type queuedPubSubMessage struct { - channel, message string -} + if sub.Count() == 0 { + endSubscriber(m, c) + } -func (m *queuedPubSubMessage) Write(c *server.Peer) { - c.WriteLen(3) - c.WriteBulk("message") - c.WriteBulk(m.channel) - c.WriteBulk(m.message) } // PUBLISH @@ -419,12 +238,14 @@ func (m *Miniredis) cmdPublish(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } - channel := args[0] - message := args[1] + channel, mesg := args[0], args[1] withTx(m, c, func(c *server.Peer, ctx *connCtx) { - c.WriteInt(m.db(ctx.selectedDB).publishMessage(channel, message)) + c.WriteInt(m.publish(channel, mesg)) }) } @@ -436,6 +257,10 @@ func (m *Miniredis) cmdPubSub(c *server.Peer, cmd string, args []string) { return } + if m.checkPubsub(c) { + return + } + subcommand := strings.ToUpper(args[0]) subargs := args[1:] var argsOk bool @@ -464,40 +289,27 @@ func (m *Miniredis) cmdPubSub(c *server.Peer, cmd string, args []string) { withTx(m, c, func(c *server.Peer, ctx *connCtx) { switch subcommand { case "CHANNELS": - var channels map[string]struct{} - + pat := "" if len(subargs) == 1 { - pattern := subargs[0] - - var rgx *regexp.Regexp - var hasRgx bool - - if rgx, hasRgx = m.channelPatterns[pattern]; !hasRgx { - rgx = compileChannelPattern(pattern) - m.channelPatterns[pattern] = rgx - } - - channels = m.db(ctx.selectedDB).pubSubChannelsNoLock(rgx) - } else { - channels = m.db(ctx.selectedDB).pubSubChannelsNoLock(nil) + pat = subargs[0] } - c.WriteLen(len(channels)) + channels := activeChannels(m.allSubscribers(), pat) - for channel := range channels { + c.WriteLen(len(channels)) + for _, channel := range channels { c.WriteBulk(channel) } - case "NUMSUB": - numSub := m.db(ctx.selectedDB).pubSubNumSubNoLock(subargs...) - - c.WriteLen(len(numSub) * 2) - for channel, subs := range numSub { + case "NUMSUB": + subs := m.allSubscribers() + c.WriteLen(len(subargs) * 2) + for _, channel := range subargs { c.WriteBulk(channel) - c.WriteInt(subs) + c.WriteInt(countSubs(subs, channel)) } case "NUMPAT": - c.WriteInt(m.db(ctx.selectedDB).pubSubNumPatNoLock()) + c.WriteInt(countPsubs(m.allSubscribers())) } }) } diff --git a/cmd_pubsub_test.go b/cmd_pubsub_test.go index 0175a779..b01026d9 100644 --- a/cmd_pubsub_test.go +++ b/cmd_pubsub_test.go @@ -1,31 +1,26 @@ package miniredis import ( - "github.com/gomodule/redigo/redis" - "regexp" "testing" + + "github.com/gomodule/redigo/redis" ) func TestSubscribe(t *testing.T) { s, c, done := setup(t) defer done() + defer c.Close() { a, err := redis.Values(c.Do("SUBSCRIBE", "event1")) ok(t, err) equals(t, []interface{}{[]byte("subscribe"), []byte("event1"), int64(1)}, a) - - equals(t, 1, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 1, len(s.peers)) } { a, err := redis.Values(c.Do("SUBSCRIBE", "event2")) ok(t, err) equals(t, []interface{}{[]byte("subscribe"), []byte("event2"), int64(2)}, a) - - equals(t, 2, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 1, len(s.peers)) } { @@ -33,171 +28,84 @@ func TestSubscribe(t *testing.T) { ok(t, err) equals(t, []interface{}{[]byte("subscribe"), []byte("event3"), int64(3)}, a) - equals(t, 4, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 1, len(s.peers)) - } - - { - a, err := redis.Values(c.Receive()) + a, err = redis.Values(c.Receive()) ok(t, err) equals(t, []interface{}{[]byte("subscribe"), []byte("event4"), int64(4)}, a) - - equals(t, 4, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 1, len(s.peers)) } { - sub := s.NewSubscriber() - defer sub.Close() - - equals(t, map[string]struct{}{}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) - - sub.Subscribe() - equals(t, map[string]struct{}{}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) - - sub.Subscribe("event1") - equals(t, map[string]struct{}{"event1": {}}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{"event1": {sub: {}}}, sub.db.directlySubscribedChannels) + // publish something! + a, err := redis.Values(c.Do("SUBSCRIBE", "colors")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("colors"), int64(5)}, a) - sub.Subscribe("event2") - equals(t, map[string]struct{}{"event1": {}, "event2": {}}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{"event1": {sub: {}}, "event2": {sub: {}}}, sub.db.directlySubscribedChannels) + n := s.Publish("colors", "green") + equals(t, 1, n) - sub.Subscribe("event3", "event4") - equals(t, map[string]struct{}{"event1": {}, "event2": {}, "event3": {}, "event4": {}}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{"event1": {sub: {}}, "event2": {sub: {}}, "event3": {sub: {}}, "event4": {sub: {}}}, sub.db.directlySubscribedChannels) + s, err := redis.Strings(c.Receive()) + ok(t, err) + equals(t, []string{"message", "colors", "green"}, s) } } func TestUnsubscribe(t *testing.T) { - s, c, done := setup(t) + _, c, done := setup(t) defer done() - c.Do("SUBSCRIBE", "event1", "event2", "event3") + ok(t, c.Send("SUBSCRIBE", "event1", "event2", "event3", "event4", "event5")) + c.Flush() + c.Receive() + c.Receive() + c.Receive() c.Receive() c.Receive() { a, err := redis.Values(c.Do("UNSUBSCRIBE", "event1", "event2")) ok(t, err) - equals(t, []interface{}{[]byte("unsubscribe"), []byte("event1"), int64(2)}, a) - - equals(t, 1, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 1, len(s.peers)) - } + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event1"), int64(4)}, a) - { - a, err := redis.Values(c.Receive()) + a, err = redis.Values(c.Receive()) ok(t, err) - equals(t, []interface{}{[]byte("unsubscribe"), []byte("event2"), int64(1)}, a) - - equals(t, 1, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 1, len(s.peers)) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event2"), int64(3)}, a) } { a, err := redis.Values(c.Do("UNSUBSCRIBE", "event3")) ok(t, err) - equals(t, []interface{}{[]byte("unsubscribe"), []byte("event3"), int64(0)}, a) - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 0, len(s.peers)) - } - - { - a, err := redis.Values(c.Do("UNSUBSCRIBE", "event4")) - ok(t, err) - equals(t, []interface{}{[]byte("unsubscribe"), []byte("event4"), int64(0)}, a) - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 0, len(s.peers)) - } - - { - sub := s.NewSubscriber() - defer sub.Close() - - sub.Subscribe("event1", "event2", "event3") - - sub.Unsubscribe() - equals(t, map[string]struct{}{"event1": {}, "event2": {}, "event3": {}}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{"event1": {sub: {}}, "event2": {sub: {}}, "event3": {sub: {}}}, sub.db.directlySubscribedChannels) - - sub.Unsubscribe("event1", "event2") - equals(t, map[string]struct{}{"event3": {}}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{"event3": {sub: {}}}, sub.db.directlySubscribedChannels) - - sub.Unsubscribe("event3") - equals(t, map[string]struct{}{}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) - - sub.Unsubscribe("event4") - equals(t, map[string]struct{}{}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) - } -} - -func TestUnsubscribeAll(t *testing.T) { - s, c, done := setup(t) - defer done() - - c.Do("SUBSCRIBE", "event1", "event2", "event3") - c.Receive() - c.Receive() - - channels := map[string]struct{}{"event1": {}, "event2": {}, "event3": {}} - - { - a, err := redis.Values(c.Do("UNSUBSCRIBE")) - ok(t, err) - - if oneOf(t, mkSubReplySet([]byte("unsubscribe"), channels, 2), a) { - delete(channels, string(a[1].([]byte))) - } - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 0, len(s.peers)) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event3"), int64(2)}, a) } { - a, err := redis.Values(c.Receive()) + a, err := redis.Values(c.Do("UNSUBSCRIBE", "event999")) ok(t, err) - - if oneOf(t, mkSubReplySet([]byte("unsubscribe"), channels, 1), a) { - delete(channels, string(a[1].([]byte))) - } - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 0, len(s.peers)) + equals(t, []interface{}{[]byte("unsubscribe"), []byte("event999"), int64(2)}, a) } { - a, err := redis.Values(c.Receive()) - ok(t, err) - - if oneOf(t, mkSubReplySet([]byte("unsubscribe"), channels, 0), a) { - delete(channels, string(a[1].([]byte))) + // unsub the rest + ok(t, c.Send("UNSUBSCRIBE")) + c.Flush() + seen := map[string]bool{} + for i := 0; i < 2; i++ { + vs, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, 3, len(vs)) + equals(t, "unsubscribe", string(vs[0].([]byte))) + seen[string(vs[1].([]byte))] = true + equals(t, 1-i, int(vs[2].(int64))) } - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedChannels)) - equals(t, 0, len(s.peers)) - } - - { - sub := s.NewSubscriber() - defer sub.Close() - - sub.Subscribe("event1", "event2", "event3") - - sub.UnsubscribeAll() - equals(t, map[string]struct{}{}, sub.channels) - equals(t, map[string]map[*Subscriber]struct{}{}, sub.db.directlySubscribedChannels) + equals(t, + map[string]bool{ + "event4": true, + "event5": true, + }, + seen, + ) } } -func TestPSubscribe(t *testing.T) { +func TestPsubscribe(t *testing.T) { s, c, done := setup(t) defer done() @@ -205,18 +113,12 @@ func TestPSubscribe(t *testing.T) { a, err := redis.Values(c.Do("PSUBSCRIBE", "event1")) ok(t, err) equals(t, []interface{}{[]byte("psubscribe"), []byte("event1"), int64(1)}, a) - - equals(t, 1, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) } { a, err := redis.Values(c.Do("PSUBSCRIBE", "event2?")) ok(t, err) equals(t, []interface{}{[]byte("psubscribe"), []byte("event2?"), int64(2)}, a) - - equals(t, 2, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) } { @@ -224,308 +126,199 @@ func TestPSubscribe(t *testing.T) { ok(t, err) equals(t, []interface{}{[]byte("psubscribe"), []byte("event3*"), int64(3)}, a) - equals(t, 4, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) - } - - { - a, err := redis.Values(c.Receive()) + a, err = redis.Values(c.Receive()) ok(t, err) equals(t, []interface{}{[]byte("psubscribe"), []byte("event4[abc]"), int64(4)}, a) - - equals(t, 4, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) } { a, err := redis.Values(c.Do("PSUBSCRIBE", "event5[]")) ok(t, err) equals(t, []interface{}{[]byte("psubscribe"), []byte("event5[]"), int64(5)}, a) - - equals(t, 5, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) } { - sub := s.NewSubscriber() - defer sub.Close() - - rgxs := [5]*regexp.Regexp{ - regexp.MustCompile(`\Aevent1\z`), - regexp.MustCompile(`\Aevent2.\z`), - regexp.MustCompile(`\Aevent3`), - regexp.MustCompile(`\Aevent4[abc]\z`), - regexp.MustCompile(`\Aevent5X\bY\z`), - } - - equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) - - sub.PSubscribe() - equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) - - sub.PSubscribe(rgxs[0]) - equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}}, sub.db.directlySubscribedPatterns) - - sub.PSubscribe(rgxs[1]) - equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}, rgxs[1]: {}}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}, rgxs[1]: {sub: {}}}, sub.db.directlySubscribedPatterns) + // publish some things! + n := s.Publish("event4b", "hello 4b!") + equals(t, 1, n) - sub.PSubscribe(rgxs[2], rgxs[3]) - equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}, rgxs[1]: {}, rgxs[2]: {}, rgxs[3]: {}}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}, rgxs[1]: {sub: {}}, rgxs[2]: {sub: {}}, rgxs[3]: {sub: {}}}, sub.db.directlySubscribedPatterns) + n = s.Publish("event4d", "hello 4d?") + equals(t, 0, n) - sub.PSubscribe(rgxs[4]) - equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}, rgxs[1]: {}, rgxs[2]: {}, rgxs[3]: {}, rgxs[4]: {}}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}, rgxs[1]: {sub: {}}, rgxs[2]: {sub: {}}, rgxs[3]: {sub: {}}, rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) + s, err := redis.Strings(c.Receive()) + ok(t, err) + equals(t, []string{"message", "event4b", "hello 4b!"}, s) } } -func TestPUnsubscribe(t *testing.T) { - s, c, done := setup(t) +func TestPunsubscribe(t *testing.T) { + _, c, done := setup(t) defer done() - c.Do("PSUBSCRIBE", "event1", "event2?", "event3*", "event4[abc]", "event5[]") + c.Send("PSUBSCRIBE", "event1", "event2?", "event3*", "event4[abc]", "event5[]") + c.Flush() + c.Receive() c.Receive() c.Receive() c.Receive() c.Receive() { - a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event1", "event2?")) - ok(t, err) - equals(t, []interface{}{[]byte("punsubscribe"), []byte("event1"), int64(4)}, a) - - equals(t, 3, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) - } - - { - a, err := redis.Values(c.Receive()) - ok(t, err) - equals(t, []interface{}{[]byte("punsubscribe"), []byte("event2?"), int64(3)}, a) - - equals(t, 3, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) - } - - { - a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event3*")) - ok(t, err) - equals(t, []interface{}{[]byte("punsubscribe"), []byte("event3*"), int64(2)}, a) - - equals(t, 2, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) - } - - { - a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event4[abc]")) - ok(t, err) - equals(t, []interface{}{[]byte("punsubscribe"), []byte("event4[abc]"), int64(1)}, a) - - equals(t, 1, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 1, len(s.peers)) - } - - { - a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event5[]")) - ok(t, err) - equals(t, []interface{}{[]byte("punsubscribe"), []byte("event5[]"), int64(0)}, a) - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 0, len(s.peers)) - } - - { - a, err := redis.Values(c.Do("PUNSUBSCRIBE", "event6")) - ok(t, err) - equals(t, []interface{}{[]byte("punsubscribe"), []byte("event6"), int64(0)}, a) - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 0, len(s.peers)) + ok(t, c.Send("PUNSUBSCRIBE", "event1", "event2?")) + c.Flush() + seen := map[string]bool{} + for i := 0; i < 2; i++ { + vs, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, 3, len(vs)) + equals(t, "punsubscribe", string(vs[0].([]byte))) + seen[string(vs[1].([]byte))] = true + equals(t, 4-i, int(vs[2].(int64))) + } + equals(t, + map[string]bool{ + "event1": true, + "event2?": true, + }, + seen, + ) } + // punsub the rest { - sub := s.NewSubscriber() - defer sub.Close() - - rgxs := [5]*regexp.Regexp{ - regexp.MustCompile(`\Aevent1\z`), - regexp.MustCompile(`\Aevent2.\z`), - regexp.MustCompile(`\Aevent3`), - regexp.MustCompile(`\Aevent4[abc]\z`), - regexp.MustCompile(`\Aevent5X\bY\z`), + ok(t, c.Send("PUNSUBSCRIBE")) + c.Flush() + seen := map[string]bool{} + for i := 0; i < 3; i++ { + vs, err := redis.Values(c.Receive()) + ok(t, err) + equals(t, 3, len(vs)) + equals(t, "punsubscribe", string(vs[0].([]byte))) + seen[string(vs[1].([]byte))] = true + equals(t, 2-i, int(vs[2].(int64))) } - - sub.PSubscribe(rgxs[:]...) - - sub.PUnsubscribe() - equals(t, map[*regexp.Regexp]struct{}{rgxs[0]: {}, rgxs[1]: {}, rgxs[2]: {}, rgxs[3]: {}, rgxs[4]: {}}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[0]: {sub: {}}, rgxs[1]: {sub: {}}, rgxs[2]: {sub: {}}, rgxs[3]: {sub: {}}, rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) - - sub.PUnsubscribe(rgxs[0], rgxs[1]) - equals(t, map[*regexp.Regexp]struct{}{rgxs[2]: {}, rgxs[3]: {}, rgxs[4]: {}}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[2]: {sub: {}}, rgxs[3]: {sub: {}}, rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) - - sub.PUnsubscribe(rgxs[2]) - equals(t, map[*regexp.Regexp]struct{}{rgxs[3]: {}, rgxs[4]: {}}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[3]: {sub: {}}, rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) - - sub.PUnsubscribe(rgxs[3]) - equals(t, map[*regexp.Regexp]struct{}{rgxs[4]: {}}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{rgxs[4]: {sub: {}}}, sub.db.directlySubscribedPatterns) - - sub.PUnsubscribe(rgxs[4]) - equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) - - sub.PUnsubscribe(regexp.MustCompile(`\Aevent6\z`)) - equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) + equals(t, + map[string]bool{ + "event3*": true, + "event4[abc]": true, + "event5[]": true, + }, + seen, + ) } } -func TestPUnsubscribeAll(t *testing.T) { - s, c, done := setup(t) +func TestPublishMode(t *testing.T) { + // only pubsub related commands should be accepted while there are + // subscriptions. + _, c, done := setup(t) defer done() - c.Do("PSUBSCRIBE", "event1", "event2?", "event3*", "event4[abc]", "event5[]") - c.Receive() - c.Receive() - c.Receive() - c.Receive() - - patterns := map[string]struct{}{"event1": {}, "event2?": {}, "event3*": {}, "event4[abc]": {}, "event5[]": {}} - - { - a, err := redis.Values(c.Do("PUNSUBSCRIBE")) - ok(t, err) + _, err := c.Do("SUBSCRIBE", "birds") + ok(t, err) - if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 4), a) { - delete(patterns, string(a[1].([]byte))) - } + _, err = c.Do("SET", "foo", "bar") + mustFail(t, err, "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") - equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 0, len(s.peers)) - } + _, err = c.Do("UNSUBSCRIBE", "birds") + ok(t, err) - { - a, err := redis.Values(c.Receive()) - ok(t, err) + // no subs left. All should be fine now. + _, err = c.Do("SET", "foo", "bar") + ok(t, err) +} - if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 3), a) { - delete(patterns, string(a[1].([]byte))) - } +func TestPublish(t *testing.T) { + s, c, c2, done := setup2(t) + defer done() - equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 0, len(s.peers)) - } + a, err := redis.Values(c2.Do("SUBSCRIBE", "event1")) + ok(t, err) + equals(t, []interface{}{[]byte("subscribe"), []byte("event1"), int64(1)}, a) { - a, err := redis.Values(c.Receive()) + n, err := redis.Int(c.Do("PUBLISH", "event1", "message2")) ok(t, err) + equals(t, 1, n) - if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 2), a) { - delete(patterns, string(a[1].([]byte))) - } - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 0, len(s.peers)) - } - - { - a, err := redis.Values(c.Receive()) + s, err := redis.Strings(c2.Receive()) ok(t, err) - - if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 1), a) { - delete(patterns, string(a[1].([]byte))) - } - - equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 0, len(s.peers)) + equals(t, []string{"message", "event1", "message2"}, s) } + // direct access { - a, err := redis.Values(c.Receive()) - ok(t, err) - - if oneOf(t, mkSubReplySet([]byte("punsubscribe"), patterns, 0), a) { - delete(patterns, string(a[1].([]byte))) - } + equals(t, 1, s.Publish("event1", "message3")) - equals(t, 0, len(s.dbs[s.selectedDB].subscribedPatterns)) - equals(t, 0, len(s.peers)) + s, err := redis.Strings(c2.Receive()) + ok(t, err) + equals(t, []string{"message", "event1", "message3"}, s) } + // Wrong usage { - sub := s.NewSubscriber() - defer sub.Close() - - sub.PSubscribe( - regexp.MustCompile(`\Aevent1\z`), - regexp.MustCompile(`\Aevent2.\z`), - regexp.MustCompile(`\Aevent3`), - regexp.MustCompile(`\Aevent4[abc]\z`), - regexp.MustCompile(`\Aevent5X\bY\z`), - ) - - sub.PUnsubscribeAll() - equals(t, map[*regexp.Regexp]struct{}{}, sub.patterns) - equals(t, map[*regexp.Regexp]map[*Subscriber]struct{}{}, sub.db.directlySubscribedPatterns) + _, err := c2.Do("PUBLISH", "foo", "bar") + mustFail(t, err, "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") } } -func mkSubReplySet(subject []byte, channels map[string]struct{}, subs int64) []interface{} { - result := make([]interface{}, len(channels)) - i := 0 +func TestPublishMix(t *testing.T) { + // SUBSCRIBE and PSUBSCRIBE + _, c, done := setup(t) + defer done() - for channel := range channels { - result[i] = []interface{}{subject, []byte(channel), subs} - i++ - } + a, err := redis.Values(c.Do("SUBSCRIBE", "c1")) + ok(t, err) + equals(t, 1, int(a[2].(int64))) - return result -} + a, err = redis.Values(c.Do("PSUBSCRIBE", "c1")) + ok(t, err) + equals(t, 2, int(a[2].(int64))) -func TestPublish(t *testing.T) { - s, c, done := setup(t) - defer done() + a, err = redis.Values(c.Do("SUBSCRIBE", "c2")) + ok(t, err) + equals(t, 3, int(a[2].(int64))) - { - a, err := redis.Int(c.Do("PUBLISH", "event1", "message2")) - ok(t, err) - equals(t, 0, a) - } + a, err = redis.Values(c.Do("PUNSUBSCRIBE", "c1")) + ok(t, err) + equals(t, 2, int(a[2].(int64))) - equals(t, 0, s.Publish("event1", "message2")) + a, err = redis.Values(c.Do("UNSUBSCRIBE", "c1")) + ok(t, err) + equals(t, 1, int(a[2].(int64))) } -func TestPubSubChannels(t *testing.T) { - s, c, done := setup(t) +func TestPubsubChannels(t *testing.T) { + _, c1, c2, done := setup2(t) defer done() - { - a, err := redis.Values(c.Do("PUBSUB", "CHANNELS")) - ok(t, err) - equals(t, []interface{}{}, a) - } + a, err := redis.Strings(c1.Do("PUBSUB", "CHANNELS")) + ok(t, err) + equals(t, []string{}, a) - { - a, err := redis.Values(c.Do("PUBSUB", "CHANNELS", "event1?*[abc]")) - ok(t, err) - equals(t, []interface{}{}, a) - } + a, err = redis.Strings(c1.Do("PUBSUB", "CHANNELS", "event1[abc]")) + ok(t, err) + equals(t, []string{}, a) - equals(t, map[string]struct{}{}, s.PubSubChannels(nil)) - equals(t, map[string]struct{}{}, s.PubSubChannels(regexp.MustCompile(`\Aevent1..*[abc]\z`))) + _, err = c2.Do("SUBSCRIBE", "event1", "event1b", "event1c") + ok(t, err) + + a, err = redis.Strings(c1.Do("PUBSUB", "CHANNELS")) + ok(t, err) + equals(t, []string{"event1", "event1b", "event1c"}, a) + + a, err = redis.Strings(c1.Do("PUBSUB", "CHANNELS", "event1[abc]")) + ok(t, err) + equals(t, []string{"event1b", "event1c"}, a) } -func TestPubSubNumSub(t *testing.T) { - s, c, done := setup(t) +func TestPubsubNumsub(t *testing.T) { + _, c, c2, done := setup2(t) defer done() + _, err := c2.Do("SUBSCRIBE", "event1", "event2", "event3") + ok(t, err) + { a, err := redis.Values(c.Do("PUBSUB", "NUMSUB")) ok(t, err) @@ -535,22 +328,23 @@ func TestPubSubNumSub(t *testing.T) { { a, err := redis.Values(c.Do("PUBSUB", "NUMSUB", "event1")) ok(t, err) - equals(t, []interface{}{[]byte("event1"), int64(0)}, a) + equals(t, []interface{}{[]byte("event1"), int64(1)}, a) } { - a, err := redis.Values(c.Do("PUBSUB", "NUMSUB", "event1", "event2")) + a, err := redis.Values(c.Do("PUBSUB", "NUMSUB", "event12", "event3")) ok(t, err) - oneOf(t, []interface{}{ - []interface{}{[]byte("event1"), int64(0), []byte("event2"), int64(0)}, - []interface{}{[]byte("event2"), int64(0), []byte("event1"), int64(0)}, - }, a) + equals(t, + []interface{}{ + []byte("event12"), int64(0), + []byte("event3"), int64(1), + }, + a, + ) } - - equals(t, map[string]int{"event1": 0}, s.PubSubNumSub("event1")) } -func TestPubSubNumPat(t *testing.T) { +func TestPubsubNumpat(t *testing.T) { s, c, done := setup(t) defer done() @@ -669,12 +463,16 @@ func testPubSubInteractionDirectSub1(t *testing.T, s *Miniredis, ch chan struct{ sub := s.NewSubscriber() defer sub.Close() - sub.Subscribe("event1", "event3", "event4", "event6") + sub.Subscribe("event1") + sub.Subscribe("event3") + sub.Subscribe("event4") + sub.Subscribe("event6") ch <- struct{}{} receiveMessagesDirectlyDuringPubSub(t, sub, '1', '3', '4', '6') - sub.Unsubscribe("event1", "event4") + sub.Unsubscribe("event1") + sub.Unsubscribe("event4") ch <- struct{}{} receiveMessagesDirectlyDuringPubSub(t, sub, '3', '6') @@ -684,12 +482,16 @@ func testPubSubInteractionDirectSub2(t *testing.T, s *Miniredis, ch chan struct{ sub := s.NewSubscriber() defer sub.Close() - sub.Subscribe("event2", "event3", "event4", "event5") + sub.Subscribe("event2") + sub.Subscribe("event3") + sub.Subscribe("event4") + sub.Subscribe("event5") ch <- struct{}{} receiveMessagesDirectlyDuringPubSub(t, sub, '2', '3', '4', '5') - sub.Unsubscribe("event3", "event5") + sub.Unsubscribe("event3") + sub.Unsubscribe("event5") ch <- struct{}{} receiveMessagesDirectlyDuringPubSub(t, sub, '2', '4') @@ -736,32 +538,38 @@ func testPubSubInteractionPsub2(t *testing.T, _ *Miniredis, c redis.Conn, ch cha } func testPubSubInteractionDirectPsub1(t *testing.T, s *Miniredis, ch chan struct{}) { - rgx := regexp.MustCompile sub := s.NewSubscriber() defer sub.Close() - sub.PSubscribe(rgx(`\Aevent[ab1]\z`), rgx(`\Aevent[ef3]\z`), rgx(`\Aevent[gh]\z`), rgx(`\Aevent[kl6]\z`)) + sub.Psubscribe(`event[ab1]`) + sub.Psubscribe(`event[ef3]`) + sub.Psubscribe(`event[gh]`) + sub.Psubscribe(`event[kl6]`) ch <- struct{}{} receiveMessagesDirectlyDuringPubSub(t, sub, '1', '3', '6', 'a', 'b', 'e', 'f', 'g', 'h', 'k', 'l') - sub.PUnsubscribe(rgx(`\Aevent[ab1]\z`), rgx(`\Aevent[gh]\z`)) + sub.Punsubscribe(`event[ab1]`) + sub.Punsubscribe(`event[gh]`) ch <- struct{}{} receiveMessagesDirectlyDuringPubSub(t, sub, '3', '6', 'e', 'f', 'k', 'l') } func testPubSubInteractionDirectPsub2(t *testing.T, s *Miniredis, ch chan struct{}) { - rgx := regexp.MustCompile sub := s.NewSubscriber() defer sub.Close() - sub.PSubscribe(rgx(`\Aevent[cd]\z`), rgx(`\Aevent[ef]\z`), rgx(`\Aevent[gh4]\z`), rgx(`\Aevent[ij]\z`)) + sub.Psubscribe(`event[cd]`) + sub.Psubscribe(`event[ef]`) + sub.Psubscribe(`event[gh4]`) + sub.Psubscribe(`event[ij]`) ch <- struct{}{} receiveMessagesDirectlyDuringPubSub(t, sub, '4', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j') - sub.PUnsubscribe(rgx(`\Aevent[ef]\z`), rgx(`\Aevent[ij]\z`)) + sub.Punsubscribe(`event[ef]`) + sub.Punsubscribe(`event[ij]`) ch <- struct{}{} receiveMessagesDirectlyDuringPubSub(t, sub, '4', 'c', 'd', 'g', 'h') @@ -777,17 +585,17 @@ func testPubSubInteractionPubStage1(t *testing.T, s *Miniredis, c redis.Conn, ch <-ch } - for _, pattern := range [2]struct { - pattern string - rgx *regexp.Regexp - }{{"", nil}, {"event?", regexp.MustCompile(`\Aevent.\z`)}} { - assertActiveChannelsDuringPubSub(t, s, c, pattern.pattern, pattern.rgx, map[string]struct{}{ - "event1": {}, "event2": {}, "event3": {}, "event4": {}, "event5": {}, "event6": {}, + for _, pattern := range []string{ + "", + "event?", + } { + assertActiveChannelsDuringPubSub(t, s, c, pattern, []string{ + "event1", "event2", "event3", "event4", "event5", "event6", }) } - assertActiveChannelsDuringPubSub(t, s, c, "*[123]", regexp.MustCompile(`[123]\z`), map[string]struct{}{ - "event1": {}, "event2": {}, "event3": {}, + assertActiveChannelsDuringPubSub(t, s, c, "*[123]", []string{ + "event1", "event2", "event3", }) assertNumSubDuringPubSub(t, s, c, map[string]int{ @@ -816,18 +624,16 @@ func testPubSubInteractionPubStage2(t *testing.T, s *Miniredis, c redis.Conn, ch <-ch } - for _, pattern := range [2]struct { - pattern string - rgx *regexp.Regexp - }{{"", nil}, {"event?", regexp.MustCompile(`\Aevent.\z`)}} { - assertActiveChannelsDuringPubSub(t, s, c, pattern.pattern, pattern.rgx, map[string]struct{}{ - "event1": {}, "event2": {}, "event3": {}, "event4": {}, "event6": {}, + for _, pattern := range []string{ + "", + "event?", + } { + assertActiveChannelsDuringPubSub(t, s, c, pattern, []string{ + "event1", "event2", "event3", "event4", "event6", }) } - assertActiveChannelsDuringPubSub(t, s, c, "*[123]", regexp.MustCompile(`[123]\z`), map[string]struct{}{ - "event1": {}, "event2": {}, "event3": {}, - }) + assertActiveChannelsDuringPubSub(t, s, c, "*[123]", []string{"event1", "event2", "event3"}) assertNumSubDuringPubSub(t, s, c, map[string]int{ "event1": 1, "event2": 1, "event3": 2, "event4": 2, "event5": 0, "event6": 2, @@ -932,13 +738,11 @@ func receiveMessagesDirectlyDuringPubSub(t *testing.T, sub *Subscriber, suffixes for _, suffix := range suffixes { suff := string([]rune{suffix}) - equals(t, Message{"event" + suff, "message" + suff}, <-sub.Messages) + equals(t, PubsubMessage{"event" + suff, "message" + suff}, <-sub.Messages()) } } -func assertActiveChannelsDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, pattern string, rgx *regexp.Regexp, channels map[string]struct{}) { - t.Helper() - +func assertActiveChannelsDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, pattern string, channels []string) { var args []interface{} if pattern == "" { args = []interface{}{"CHANNELS"} @@ -946,20 +750,12 @@ func assertActiveChannelsDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, args = []interface{}{"CHANNELS", pattern} } - a, err := redis.Values(c.Do("PUBSUB", args...)) + actual, err := redis.Strings(c.Do("PUBSUB", args...)) ok(t, err) - actualChannels := make(map[string]struct{}, len(a)) - - for _, channel := range a { - if channelString, channelIsString := channel.([]byte); channelIsString { - actualChannels[string(channelString)] = struct{}{} - } - } - - equals(t, channels, actualChannels) + equals(t, channels, actual) - equals(t, channels, s.PubSubChannels(rgx)) + equals(t, channels, s.PubSubChannels(pattern)) } func assertNumSubDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, channels map[string]int) { diff --git a/cmd_string.go b/cmd_string.go index 930da992..658d77d0 100644 --- a/cmd_string.go +++ b/cmd_string.go @@ -48,6 +48,10 @@ func (m *Miniredis) cmdSet(c *server.Peer, cmd string, args []string) { return } + if m.checkPubsub(c) { + return + } + var ( nx = false // set iff not exists xx = false // set iff exists diff --git a/direct.go b/direct.go index 0670da9d..533224d0 100644 --- a/direct.go +++ b/direct.go @@ -4,10 +4,6 @@ package miniredis import ( "errors" - "github.com/alicebob/miniredis/server" - "regexp" - "sync" - "sync/atomic" "time" ) @@ -552,444 +548,39 @@ func (db *RedisDB) ZScore(k, member string) (float64, error) { return db.ssetScore(k, member), nil } -type Message struct { - Channel, Message string -} - -type messageQueue struct { - sync.Mutex - messages []Message - hasNewMessages chan struct{} -} - -func (q *messageQueue) Enqueue(message Message) { - q.Lock() - defer q.Unlock() - - q.messages = append(q.messages, message) - - select { - case q.hasNewMessages <- struct{}{}: - break - default: - break - } -} - -type Subscriber struct { - Messages chan Message - close chan struct{} - db *RedisDB - channels map[string]struct{} - patterns map[*regexp.Regexp]struct{} - queue messageQueue -} - -func (s *Subscriber) Close() error { - close(s.close) - - s.db.master.Lock() - defer s.db.master.Unlock() - - for channel := range s.channels { - subscribers := s.db.directlySubscribedChannels[channel] - delete(subscribers, s) - - if len(subscribers) < 1 { - delete(s.db.directlySubscribedChannels, channel) - } - } - - for pattern := range s.patterns { - subscribers := s.db.directlySubscribedPatterns[pattern] - delete(subscribers, s) - - if len(subscribers) < 1 { - delete(s.db.directlySubscribedPatterns, pattern) - } - } - - return nil -} - -func (s *Subscriber) Subscribe(channels ...string) { - s.db.master.Lock() - defer s.db.master.Unlock() - - for _, channel := range channels { - s.channels[channel] = struct{}{} - - var peers map[*Subscriber]struct{} - var hasPeers bool - - if peers, hasPeers = s.db.directlySubscribedChannels[channel]; !hasPeers { - peers = map[*Subscriber]struct{}{} - s.db.directlySubscribedChannels[channel] = peers - } - - peers[s] = struct{}{} - } -} - -func (s *Subscriber) Unsubscribe(channels ...string) { - s.db.master.Lock() - defer s.db.master.Unlock() - - for _, channel := range channels { - if _, hasChannel := s.channels[channel]; hasChannel { - delete(s.channels, channel) - - peers := s.db.directlySubscribedChannels[channel] - delete(peers, s) - - if len(peers) < 1 { - delete(s.db.directlySubscribedChannels, channel) - } - } - } -} - -func (s *Subscriber) UnsubscribeAll() { - s.db.master.Lock() - defer s.db.master.Unlock() - - for channel := range s.channels { - subscribers := s.db.directlySubscribedChannels[channel] - delete(subscribers, s) - - if len(subscribers) < 1 { - delete(s.db.directlySubscribedChannels, channel) - } - } - - s.channels = map[string]struct{}{} -} - -func (s *Subscriber) PSubscribe(patterns ...*regexp.Regexp) { - s.db.master.Lock() - defer s.db.master.Unlock() - - decompiledDSPs := s.db.master.decompiledDirectlySubscribedPatterns - - for _, pattern := range patterns { - decompiled := pattern.String() - - if decompiledDSP, hasDDSP := decompiledDSPs[decompiled]; hasDDSP { - pattern = decompiledDSP - } else { - decompiledDSPs[decompiled] = pattern - } - - s.patterns[pattern] = struct{}{} - - var peers map[*Subscriber]struct{} - var hasPeers bool - - if peers, hasPeers = s.db.directlySubscribedPatterns[pattern]; !hasPeers { - peers = map[*Subscriber]struct{}{} - s.db.directlySubscribedPatterns[pattern] = peers - } - - peers[s] = struct{}{} - } -} - -func (s *Subscriber) PUnsubscribe(patterns ...*regexp.Regexp) { - s.db.master.Lock() - defer s.db.master.Unlock() - - decompiledDSPs := s.db.master.decompiledDirectlySubscribedPatterns - - for _, pattern := range patterns { - if decompiledDSP, hasDDSP := decompiledDSPs[pattern.String()]; hasDDSP { - pattern = decompiledDSP - } - - if _, hasChannel := s.patterns[pattern]; hasChannel { - delete(s.patterns, pattern) - - peers := s.db.directlySubscribedPatterns[pattern] - delete(peers, s) - - if len(peers) < 1 { - delete(s.db.directlySubscribedPatterns, pattern) - } - } - } -} - -func (s *Subscriber) PUnsubscribeAll() { - s.db.master.Lock() - defer s.db.master.Unlock() - - for pattern := range s.patterns { - subscribers := s.db.directlySubscribedPatterns[pattern] - delete(subscribers, s) - - if len(subscribers) < 1 { - delete(s.db.directlySubscribedPatterns, pattern) - } - } - - s.patterns = map[*regexp.Regexp]struct{}{} -} - -func (s *Subscriber) streamMessages() { - defer close(s.Messages) - - for { - select { - case <-s.queue.hasNewMessages: - s.queue.Lock() - - select { - case <-s.queue.hasNewMessages: - break - default: - break - } - - messages := s.queue.messages - s.queue.messages = []Message{} - - s.queue.Unlock() - - for _, message := range messages { - select { - case s.Messages <- message: - break - case <-s.close: - return - } - } - case <-s.close: - return - } - } -} - -func (m *Miniredis) NewSubscriber() *Subscriber { - return m.DB(m.selectedDB).NewSubscriber() -} - -func (db *RedisDB) NewSubscriber() *Subscriber { - s := &Subscriber{ - Messages: make(chan Message), - close: make(chan struct{}), - db: db, - channels: map[string]struct{}{}, - patterns: map[*regexp.Regexp]struct{}{}, - queue: messageQueue{ - messages: []Message{}, - hasNewMessages: make(chan struct{}, 1), - }, - } - - go s.streamMessages() - - return s -} - func (m *Miniredis) Publish(channel, message string) int { - return m.DB(m.selectedDB).Publish(channel, message) -} - -func (db *RedisDB) Publish(channel, message string) int { - db.master.Lock() - defer db.master.Unlock() - - return db.publishMessage(channel, message) -} - -func (m *Miniredis) PubSubChannels(pattern *regexp.Regexp) map[string]struct{} { - return m.DB(m.selectedDB).PubSubChannels(pattern) + m.Lock() + defer m.Unlock() + return m.publish(channel, message) } -func (db *RedisDB) PubSubChannels(pattern *regexp.Regexp) map[string]struct{} { - db.master.Lock() - defer db.master.Unlock() +// PubSubChannels is "PUBSUB CHANNELS ". An empty pattern is fine. +// Returned channels will be ordered alphabetically. +func (m *Miniredis) PubSubChannels(pattern string) []string { + m.Lock() + defer m.Unlock() - return db.pubSubChannelsNoLock(pattern) + return activeChannels(m.allSubscribers(), pattern) } +// PubSubNumSub is "PUBSUB NUMSUB [channels]". It returns all channels with their +// subscriber count. func (m *Miniredis) PubSubNumSub(channels ...string) map[string]int { - return m.DB(m.selectedDB).PubSubNumSub(channels...) -} - -func (db *RedisDB) PubSubNumSub(channels ...string) map[string]int { - db.master.Lock() - defer db.master.Unlock() - - return db.pubSubNumSubNoLock(channels...) -} - -func (m *Miniredis) PubSubNumPat() int { - return m.DB(m.selectedDB).PubSubNumPat() -} - -func (db *RedisDB) PubSubNumPat() int { - db.master.Lock() - defer db.master.Unlock() - - return db.pubSubNumPatNoLock() -} - -func (db *RedisDB) pubSubChannelsNoLock(pattern *regexp.Regexp) map[string]struct{} { - channels := map[string]struct{}{} - - if pattern == nil { - for channel := range db.subscribedChannels { - channels[channel] = struct{}{} - } - - for channel := range db.directlySubscribedChannels { - channels[channel] = struct{}{} - } - } else { - for channel := range db.subscribedChannels { - if pattern.MatchString(channel) { - channels[channel] = struct{}{} - } - } - - for channel := range db.directlySubscribedChannels { - if pattern.MatchString(channel) { - channels[channel] = struct{}{} - } - } - } - - return channels -} - -func (db *RedisDB) pubSubNumSubNoLock(channels ...string) map[string]int { - numSub := map[string]int{} + m.Lock() + defer m.Unlock() + subs := m.allSubscribers() + res := map[string]int{} for _, channel := range channels { - numSub[channel] = len(db.subscribedChannels[channel]) + len(db.directlySubscribedChannels[channel]) + res[channel] = countSubs(subs, channel) } - - return numSub -} - -func (db *RedisDB) pubSubNumPatNoLock() (numPat int) { - for _, peers := range db.subscribedPatterns { - numPat += len(peers) - } - - for _, subscribers := range db.directlySubscribedPatterns { - numPat += len(subscribers) - } - - return + return res } -func (db *RedisDB) publishMessage(channel, message string) int { - count := 0 - - var allPeers map[*server.Peer]struct{} = nil - - if peers, hasPeers := db.subscribedChannels[channel]; hasPeers { - allPeers = make(map[*server.Peer]struct{}, len(peers)) - - for peer := range peers { - allPeers[peer] = struct{}{} - } - } - - for pattern, peers := range db.subscribedPatterns { - if db.master.channelPatterns[pattern].MatchString(channel) { - if allPeers == nil { - allPeers = make(map[*server.Peer]struct{}, len(peers)) - } - - for peer := range peers { - allPeers[peer] = struct{}{} - } - } - } - - if allPeers != nil { - count += len(allPeers) - - wait := publishMessagesAsync(allPeers, channel, message) - defer wait() - } - - var allSubscribers map[*Subscriber]struct{} = nil - - if subscribers, hasSubscribers := db.directlySubscribedChannels[channel]; hasSubscribers { - allSubscribers = make(map[*Subscriber]struct{}, len(subscribers)) - - for subscriber := range subscribers { - allSubscribers[subscriber] = struct{}{} - } - } - - for pattern, subscribers := range db.directlySubscribedPatterns { - if pattern.MatchString(channel) { - if allSubscribers == nil { - allSubscribers = make(map[*Subscriber]struct{}, len(subscribers)) - } - - for subscriber := range subscribers { - allSubscribers[subscriber] = struct{}{} - } - } - } - - if allSubscribers != nil { - count += len(allSubscribers) - - wait := publishMessagesToOurselvesAsync(allSubscribers, channel, message) - defer wait() - } - - return count -} - -func publishMessagesAsync(peers map[*server.Peer]struct{}, channel, message string) (wait func()) { - chCtl := make(chan struct{}) - go publishMessages(peers, channel, message, chCtl) - - return func() { <-chCtl } -} - -func publishMessages(peers map[*server.Peer]struct{}, channel, message string, chCtl chan struct{}) { - pending := uint64(len(peers)) - - for peer := range peers { - go publishMessage(peer, channel, message, &pending, chCtl) - } -} - -func publishMessage(peer *server.Peer, channel, message string, pending *uint64, chCtl chan struct{}) { - peer.MsgQueue.Enqueue(&queuedPubSubMessage{channel, message}) - - if atomic.AddUint64(pending, ^uint64(0)) == 0 { - close(chCtl) - } -} - -func publishMessagesToOurselvesAsync(subscribers map[*Subscriber]struct{}, channel, message string) (wait func()) { - chCtl := make(chan struct{}) - go publishMessagesToOurselves(subscribers, channel, message, chCtl) - - return func() { <-chCtl } -} - -func publishMessagesToOurselves(subscribers map[*Subscriber]struct{}, channel, message string, chCtl chan struct{}) { - pending := uint64(len(subscribers)) - - for subscriber := range subscribers { - go publishMessageToOurselves(subscriber, channel, message, &pending, chCtl) - } -} - -func publishMessageToOurselves(subscriber *Subscriber, channel, message string, pending *uint64, chCtl chan struct{}) { - subscriber.queue.Enqueue(Message{channel, message}) +// PubSubNumPat is "PUBSUB NUMPAT" +func (m *Miniredis) PubSubNumPat() int { + m.Lock() + defer m.Unlock() - if atomic.AddUint64(pending, ^uint64(0)) == 0 { - close(chCtl) - } + return countPsubs(m.allSubscribers()) } diff --git a/integration/generic_test.go b/integration/generic_test.go index 45d4e09a..f1487a26 100644 --- a/integration/generic_test.go +++ b/integration/generic_test.go @@ -18,6 +18,14 @@ func TestEcho(t *testing.T) { ) } +func TestPing(t *testing.T) { + testCommands(t, + succ("PING"), + succ("PING", "hello world"), + fail("PING", "hello", "world"), + ) +} + func TestKeys(t *testing.T) { testCommands(t, succ("SET", "one", "1"), diff --git a/integration/pubsub_test.go b/integration/pubsub_test.go index a8a7e566..7b747195 100644 --- a/integration/pubsub_test.go +++ b/integration/pubsub_test.go @@ -104,36 +104,30 @@ func TestPubSub(t *testing.T) { } func TestPubsubFull(t *testing.T) { - t.Skip() // exit 1, no idea why var wg1 sync.WaitGroup wg1.Add(1) testMultiCommands(t, func(r chan<- command, _ *miniredis.Miniredis) { r <- succ("SUBSCRIBE", "news", "sport") r <- receive() - /* - wg1.Done() - r <- receive() - r <- receive() - r <- receive() - r <- succ("UNSUBSCRIBE", "news", "sport") - r <- receive() - */ + wg1.Done() + r <- receive() + r <- receive() + r <- receive() + r <- succ("UNSUBSCRIBE", "news", "sport") + r <- receive() + }, + func(r chan<- command, _ *miniredis.Miniredis) { + wg1.Wait() + r <- succ("PUBLISH", "news", "revolution!") + r <- succ("PUBLISH", "news", "alien invasion!") + r <- succ("PUBLISH", "sport", "lady biked too fast") + r <- succ("PUBLISH", "gossip", "man bites dog") }, - /* - func(r chan<- command, _ *miniredis.Miniredis) { - wg1.Wait() - r <- succ("PUBLISH", "news", "revolution!") - r <- succ("PUBLISH", "news", "alien invasion!") - r <- succ("PUBLISH", "sport", "lady biked too fast") - r <- succ("PUBLISH", "gossip", "man bites dog") - }, - */ ) } func TestPubsubMulti(t *testing.T) { - t.Skip() // hangs. No idea why. var wg1 sync.WaitGroup wg1.Add(2) testMultiCommands(t, @@ -149,11 +143,9 @@ func TestPubsubMulti(t *testing.T) { }, func(r chan<- command, _ *miniredis.Miniredis) { r <- succ("SUBSCRIBE", "sport") - r <- receive() wg1.Done() r <- receive() r <- succ("UNSUBSCRIBE", "sport") - r <- receive() }, func(r chan<- command, _ *miniredis.Miniredis) { wg1.Wait() @@ -165,30 +157,43 @@ func TestPubsubMulti(t *testing.T) { } func TestPubsubSelect(t *testing.T) { - t.Skip() // known broken - var wg1 sync.WaitGroup - wg1.Add(1) - testMultiCommands(t, - func(r chan<- command, _ *miniredis.Miniredis) { - r <- succ("SUBSCRIBE", "news", "sport") - r <- receive() - wg1.Done() - r <- receive() - }, - func(r chan<- command, _ *miniredis.Miniredis) { - wg1.Wait() - r <- succ("SELECT", 3) - r <- succ("PUBLISH", "news", "revolution!") - }, - ) + testClients2(t, func(r1, r2 chan<- command) { + r1 <- succ("SUBSCRIBE", "news", "sport") + r1 <- receive() + r2 <- succ("SELECT", 3) + r2 <- succ("PUBLISH", "news", "revolution!") + r1 <- receive() + }) } func TestPubsubMode(t *testing.T) { - t.Skip() // known broken testCommands(t, succ("SUBSCRIBE", "news", "sport"), receive(), + succ("PING"), + succ("PING", "foo"), fail("ECHO", "foo"), fail("HGET", "foo", "bar"), + fail("SET", "foo", "bar"), + succ("QUIT"), ) } + +func TestSubscriptions(t *testing.T) { + testClients2(t, func(r1, r2 chan<- command) { + r1 <- succ("SUBSCRIBE", "foo", "bar", "foo") + r2 <- succ("PUBSUB", "NUMSUB") + r1 <- succ("UNSUBSCRIBE", "bar", "bar", "bar") + r2 <- succ("PUBSUB", "NUMSUB") + }) +} + +func TestPubsubUnsub(t *testing.T) { + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("SUBSCRIBE", "news", "sport") + c1 <- receive() + c2 <- succSorted("PUBSUB", "CHANNELS") + c1 <- succ("QUIT") + c2 <- succSorted("PUBSUB", "CHANNELS") + }) +} diff --git a/integration/test.go b/integration/test.go index ec1ebef8..9c398b56 100644 --- a/integration/test.go +++ b/integration/test.go @@ -4,6 +4,7 @@ package main import ( "bytes" + "context" "fmt" "reflect" "sort" @@ -142,6 +143,58 @@ func testMultiCommands(t *testing.T, cs ...func(chan<- command, *miniredis.Minir wg.Wait() } +// like testCommands, but multiple connections +func testClients2(t *testing.T, f func(c1, c2 chan<- command)) { + t.Helper() + sMini, err := miniredis.Run() + ok(t, err) + defer sMini.Close() + + sReal, realAddr := Redis() + defer sReal.Close() + + type aChan struct { + c chan command + cMini, cReal redis.Conn + } + chans := [2]aChan{} + for i := range chans { + gen := make(chan command) + cMini, err := redis.Dial("tcp", sMini.Addr()) + ok(t, err) + + cReal, err := redis.Dial("tcp", realAddr) + ok(t, err) + chans[i] = aChan{ + c: gen, + cMini: cMini, + cReal: cReal, + } + } + + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + f(chans[0].c, chans[1].c) + cancel() + for _, c := range chans { + close(c.c) + } + }() + +loop: + for { + select { + case <-ctx.Done(): + break loop + case cm := <-chans[0].c: + runCommand(t, chans[0].cMini, chans[0].cReal, cm) + case cm := <-chans[1].c: + runCommand(t, chans[1].cMini, chans[1].cReal, cm) + } + } +} + func testAuthCommands(t *testing.T, passwd string, commands ...command) { sMini, err := miniredis.Run() ok(t, err) @@ -177,22 +230,6 @@ func runCommand(t *testing.T, cMini, cReal redis.Conn, p command) { dump(vReal, "-real-") vMini, errMini = cMini.Receive() dump(vMini, "-mini-") - for _, k := range vReal.([]interface{}) { - switch k := k.(type) { - case []byte: - t.Errorf(" -real- %s", string(k)) - default: - t.Errorf(" -real- %#v", k) - } - } - for _, k := range vMini.([]interface{}) { - switch k := k.(type) { - case []byte: - t.Errorf(" -mini- %s", string(k)) - default: - t.Errorf(" -mini- %#v", k) - } - } } else { vReal, errReal = cReal.Do(p.cmd, p.args...) vMini, errMini = cMini.Do(p.cmd, p.args...) diff --git a/miniredis.go b/miniredis.go index ecfc0434..6b8f3aba 100644 --- a/miniredis.go +++ b/miniredis.go @@ -17,7 +17,6 @@ package miniredis import ( "fmt" "net" - "regexp" "strconv" "sync" "time" @@ -33,44 +32,30 @@ type setKey map[string]struct{} // RedisDB holds a single (numbered) Redis database. type RedisDB struct { - master *Miniredis // pointer to the lock in Miniredis - id int // db id - keys map[string]string // Master map of keys with their type - stringKeys map[string]string // GET/SET &c. keys - hashKeys map[string]hashKey // MGET/MSET &c. keys - listKeys map[string]listKey // LPUSH &c. keys - setKeys map[string]setKey // SADD &c. keys - sortedsetKeys map[string]sortedSet // ZADD &c. keys - ttl map[string]time.Duration // effective TTL values - keyVersion map[string]uint // used to watch values - subscribedChannels map[string]map[*server.Peer]struct{} - subscribedPatterns map[string]map[*server.Peer]struct{} - directlySubscribedChannels map[string]map[*Subscriber]struct{} - directlySubscribedPatterns map[*regexp.Regexp]map[*Subscriber]struct{} -} - -type peerSubscriptions struct { - channels, patterns map[string]struct{} -} - -type peerCache struct { - subscriptions map[int]peerSubscriptions + master *Miniredis // pointer to the lock in Miniredis + id int // db id + keys map[string]string // Master map of keys with their type + stringKeys map[string]string // GET/SET &c. keys + hashKeys map[string]hashKey // MGET/MSET &c. keys + listKeys map[string]listKey // LPUSH &c. keys + setKeys map[string]setKey // SADD &c. keys + sortedsetKeys map[string]sortedSet // ZADD &c. keys + ttl map[string]time.Duration // effective TTL values + keyVersion map[string]uint // used to watch values } // Miniredis is a Redis server implementation. type Miniredis struct { sync.Mutex - srv *server.Server - port int - password string - dbs map[int]*RedisDB - selectedDB int // DB id used in the direct Get(), Set() &c. - scripts map[string]string // sha1 -> lua src - signal *sync.Cond - now time.Time // used to make a duration from EXPIREAT. time.Now() if not set. - peers map[*server.Peer]peerCache - channelPatterns map[string]*regexp.Regexp - decompiledDirectlySubscribedPatterns map[string]*regexp.Regexp + srv *server.Server + port int + password string + dbs map[int]*RedisDB + selectedDB int // DB id used in the direct Get(), Set() &c. + scripts map[string]string // sha1 -> lua src + signal *sync.Cond + now time.Time // used to make a duration from EXPIREAT. time.Now() if not set. + subscribers map[*Subscriber]struct{} } type txCmd func(*server.Peer, *connCtx) @@ -86,18 +71,17 @@ type connCtx struct { selectedDB int // selected DB authenticated bool // auth enabled and a valid AUTH seen transaction []txCmd // transaction callbacks. Or nil. - dirtyTransaction bool // any error during QUEUEing. - watch map[dbKey]uint // WATCHed keys. + dirtyTransaction bool // any error during QUEUEing + watch map[dbKey]uint // WATCHed keys + subscriber *Subscriber // client is in PUBSUB mode if not nil } // NewMiniRedis makes a new, non-started, Miniredis object. func NewMiniRedis() *Miniredis { m := Miniredis{ - dbs: map[int]*RedisDB{}, - scripts: map[string]string{}, - peers: map[*server.Peer]peerCache{}, - channelPatterns: map[string]*regexp.Regexp{}, - decompiledDirectlySubscribedPatterns: map[string]*regexp.Regexp{}, + dbs: map[int]*RedisDB{}, + scripts: map[string]string{}, + subscribers: map[*Subscriber]struct{}{}, } m.signal = sync.NewCond(&m) return &m @@ -105,20 +89,16 @@ func NewMiniRedis() *Miniredis { func newRedisDB(id int, l *Miniredis) RedisDB { return RedisDB{ - id: id, - master: l, - keys: map[string]string{}, - stringKeys: map[string]string{}, - hashKeys: map[string]hashKey{}, - listKeys: map[string]listKey{}, - setKeys: map[string]setKey{}, - sortedsetKeys: map[string]sortedSet{}, - ttl: map[string]time.Duration{}, - keyVersion: map[string]uint{}, - subscribedChannels: map[string]map[*server.Peer]struct{}{}, - subscribedPatterns: map[string]map[*server.Peer]struct{}{}, - directlySubscribedChannels: map[string]map[*Subscriber]struct{}{}, - directlySubscribedPatterns: map[*regexp.Regexp]map[*Subscriber]struct{}{}, + id: id, + master: l, + keys: map[string]string{}, + stringKeys: map[string]string{}, + hashKeys: map[string]hashKey{}, + listKeys: map[string]listKey{}, + setKeys: map[string]setKey{}, + sortedsetKeys: map[string]sortedSet{}, + ttl: map[string]time.Duration{}, + keyVersion: map[string]uint{}, } } @@ -166,8 +146,6 @@ func (m *Miniredis) start(s *server.Server) error { commandsTransaction(m) commandsScripting(m) - s.OnDisconnect(m.onDisconnect) - return nil } @@ -180,12 +158,16 @@ func (m *Miniredis) Restart() error { // Close shuts down a Miniredis. func (m *Miniredis) Close() { m.Lock() - defer m.Unlock() if m.srv == nil { return } + m.Unlock() + m.srv.Close() + + m.Lock() m.srv = nil + m.Unlock() } // RequireAuth makes every connection need to AUTH first. Disable again by @@ -349,39 +331,19 @@ func (m *Miniredis) handleAuth(c *server.Peer) bool { return true } -func (m *Miniredis) onDisconnect(c *server.Peer) { - go m.unsubscribeAll(c) -} - -func (m *Miniredis) unsubscribeAll(c *server.Peer) { +// handlePubsub sends an error to the user if the connection is in PUBSUB mode. +// It'll return true if it did. +func (m *Miniredis) checkPubsub(c *server.Peer) bool { m.Lock() defer m.Unlock() - if cache, hasCache := m.peers[c]; hasCache { - for dbIdx, subscriptions := range cache.subscriptions { - db := m.dbs[dbIdx] - - for channel := range subscriptions.channels { - peers := db.subscribedChannels[channel] - delete(peers, c) - - if len(peers) < 1 { - delete(db.subscribedChannels, channel) - } - } - - for pattern := range subscriptions.patterns { - peers := db.subscribedPatterns[pattern] - delete(peers, c) - - if len(peers) < 1 { - delete(db.subscribedPatterns, pattern) - } - } - } - - delete(m.peers, c) + ctx := getCtx(c) + if ctx.subscriber == nil { + return false } + + c.WriteError("ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context") + return true } func getCtx(c *server.Peer) *connCtx { @@ -432,3 +394,76 @@ func setDirty(c *server.Peer) { func setAuthenticated(c *server.Peer) { getCtx(c).authenticated = true } + +func (m *Miniredis) addSubscriber(s *Subscriber) { + m.Lock() + m.subscribers[s] = struct{}{} + m.Unlock() +} + +// closes and remove the subscriber +func (m *Miniredis) removeSubscriber(s *Subscriber) { + m.Lock() + _, ok := m.subscribers[s] + delete(m.subscribers, s) + m.Unlock() + if ok { + s.Close() + } +} + +func (m *Miniredis) publish(c, msg string) int { + n := 0 + for s := range m.subscribers { + n += s.Publish(c, msg) + } + return n +} + +// enter 'subscribed state', or return the existing one. +func (m *Miniredis) subscribedState(c *server.Peer) *Subscriber { + ctx := getCtx(c) + sub := ctx.subscriber + if sub != nil { + return sub + } + + sub = m.NewSubscriber() + c.OnDisconnect(func() { + m.removeSubscriber(sub) + }) + + ctx.subscriber = sub + + go monitorPublish(c, sub.publish) + + return sub +} + +// whenever the p?sub count drops to 0 subscribed state should be stopped, and +// all redis commands are allowed again. +func endSubscriber(m *Miniredis, c *server.Peer) { + ctx := getCtx(c) + if sub := ctx.subscriber; sub != nil { + m.removeSubscriber(sub) // will Close() the sub + } + ctx.subscriber = nil +} + +// Start a new pubsub subscriber. It can (un) subscribe to channels and +// patterns, and has a channel to get published messages. Close it with +// Close(). +// Does not close itself when there are no subscriptions left. +func (m *Miniredis) NewSubscriber() *Subscriber { + sub := newSubscriber() + m.addSubscriber(sub) + return sub +} + +func (m *Miniredis) allSubscribers() []*Subscriber { + var subs []*Subscriber + for s := range m.subscribers { + subs = append(subs, s) + } + return subs +} diff --git a/pubsub.go b/pubsub.go new file mode 100644 index 00000000..8ccf88be --- /dev/null +++ b/pubsub.go @@ -0,0 +1,204 @@ +package miniredis + +import ( + "regexp" + "sort" + "sync" + + "github.com/alicebob/miniredis/server" +) + +// PubsubMessage is what gets broadcasted over pubsub channels. +type PubsubMessage struct { + Channel string + Message string +} + +// Subscriber has the (p)subscriptions. +type Subscriber struct { + publish chan PubsubMessage + channels map[string]struct{} + patterns map[string]*regexp.Regexp + mu sync.Mutex +} + +func newSubscriber() *Subscriber { + return &Subscriber{ + publish: make(chan PubsubMessage), + channels: map[string]struct{}{}, + patterns: map[string]*regexp.Regexp{}, + } +} + +func (s *Subscriber) Close() { + close(s.publish) +} + +// total number of channels and patterns +func (s *Subscriber) Count() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.count() +} + +func (s *Subscriber) count() int { + return len(s.channels) + len(s.patterns) +} + +func (s *Subscriber) Subscribe(c string) int { + s.mu.Lock() + defer s.mu.Unlock() + + s.channels[c] = struct{}{} + return s.count() +} + +func (s *Subscriber) Unsubscribe(c string) int { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.channels, c) + return s.count() +} + +func (s *Subscriber) Psubscribe(pat string) int { + s.mu.Lock() + defer s.mu.Unlock() + + s.patterns[pat] = compileChannelPattern(pat) + return s.count() +} + +func (s *Subscriber) Punsubscribe(pat string) int { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.patterns, pat) + return s.count() +} + +// List all subscribed channels +func (s *Subscriber) Channels() []string { + s.mu.Lock() + defer s.mu.Unlock() + + var cs []string + for c := range s.channels { + cs = append(cs, c) + } + sort.Strings(cs) + return cs +} + +// List all subscribed patters +func (s *Subscriber) Patterns() []string { + s.mu.Lock() + defer s.mu.Unlock() + + var ps []string + for p := range s.patterns { + ps = append(ps, p) + } + sort.Strings(ps) + return ps +} + +// Publish a message. Will return return how often we sent the message (can be +// a match for a subscription and for a psubscription. +func (s *Subscriber) Publish(c, msg string) int { + s.mu.Lock() + defer s.mu.Unlock() + + found := 0 + +subs: + for sub := range s.channels { + if sub == c { + s.publish <- PubsubMessage{c, msg} + found++ + break subs + } + } + +pats: + for _, pat := range s.patterns { + if pat.MatchString(c) { + s.publish <- PubsubMessage{c, msg} + found++ + break pats + } + } + + return found +} + +// reads the subscriptions +func (s *Subscriber) Messages() <-chan PubsubMessage { + return s.publish +} + +// List all pubsub channels. If `pat` isn't empty channels names must match the +// pattern. Channels are returned alphabetically. +func activeChannels(subs []*Subscriber, pat string) []string { + channels := map[string]struct{}{} + for _, s := range subs { + for c := range s.channels { + channels[c] = struct{}{} + } + } + + var cpat *regexp.Regexp + if pat != "" { + cpat = compileChannelPattern(pat) + } + + var cs []string + for k := range channels { + if cpat != nil && !cpat.MatchString(k) { + continue + } + cs = append(cs, k) + } + sort.Strings(cs) + return cs +} + +// Count all subscribed (not psubscribed) clients for the given channel +// pattern. Channels are returned alphabetically. +func countSubs(subs []*Subscriber, channel string) int { + n := 0 + for _, p := range subs { + for c := range p.channels { + if c == channel { + n++ + break + } + } + } + return n +} + +// Count the total of all client psubscriptions. +func countPsubs(subs []*Subscriber) int { + n := 0 + for _, p := range subs { + n += len(p.patterns) + } + return n +} + +func monitorPublish(conn *server.Peer, c <-chan PubsubMessage) { + for msg := range c { + conn.Block(func(c *server.Peer) { + c.WriteLen(3) + c.WriteBulk("message") + c.WriteBulk(msg.Channel) + c.WriteBulk(msg.Message) + c.Flush() + if c.Error != nil { + // what now? :( + // fmt.Printf("publish err: %s\n", c.Error) + } + }) + } +} diff --git a/server/server.go b/server/server.go index b29ecdac..6d06ae25 100644 --- a/server/server.go +++ b/server/server.go @@ -27,22 +27,21 @@ type DisconnectHandler func(c *Peer) // Server is a simple redis server type Server struct { - l net.Listener - cmds map[string]Cmd - peers map[net.Conn]struct{} - mu sync.Mutex - wg sync.WaitGroup - infoConns int - infoCmds int - onDisconnect DisconnectHandler + l net.Listener + cmds map[string]Cmd + peers map[net.Conn]struct{} + mu sync.Mutex + wg sync.WaitGroup + infoConns int + infoCmds int } // NewServer makes a server listening on addr. Close with .Close(). func NewServer(addr string) (*Server, error) { s := Server{ - cmds: map[string]Cmd{}, - peers: map[net.Conn]struct{}{}, - onDisconnect: func(c *Peer) {}, + cmds: map[string]Cmd{}, + peers: map[net.Conn]struct{}{}, + // onDisconnect: func(c *Peer) {}, } l, err := net.Listen("tcp", addr) @@ -80,7 +79,7 @@ func (s *Server) ServeConn(conn net.Conn) { s.infoConns++ s.mu.Unlock() - s.onDisconnect(s.servePeer(conn)) + s.servePeer(conn) s.mu.Lock() delete(s.peers, conn) @@ -126,87 +125,27 @@ func (s *Server) Register(cmd string, f Cmd) error { return nil } -func (s *Server) OnDisconnect(handler DisconnectHandler) { - s.onDisconnect = handler -} - -func (s *Server) servePeer(c net.Conn) (cl *Peer) { +func (s *Server) servePeer(c net.Conn) { r := bufio.NewReader(c) - cl = &Peer{ + peer := &Peer{ w: bufio.NewWriter(c), - MsgQueue: MessageQueue{ - messages: []QueuedMessage{}, - hasNewMessages: make(chan struct{}, 1), - }, } - - chReceivedArray, chReadNext := readArrayAsync(r) - defer close(chReadNext) - - chReadNext <- struct{}{} - - for { - select { - case message := <-chReceivedArray: - if message.err != nil { - return - } - - s.dispatch(cl, message.array) - cl.w.Flush() - - if cl.closed { - c.Close() - return - } - - chReadNext <- struct{}{} - case <-cl.MsgQueue.hasNewMessages: - cl.MsgQueue.Lock() - - select { - case <-cl.MsgQueue.hasNewMessages: - break - default: - break - } - - messages := cl.MsgQueue.messages - cl.MsgQueue.messages = []QueuedMessage{} - - cl.MsgQueue.Unlock() - - for _, message := range messages { - message.Write(cl) - } - - cl.Flush() + defer func() { + for _, f := range peer.onDisconnect { + f() } - } -} - -type receivedArray struct { - array []string - err error -} - -func readArrayAsync(r *bufio.Reader) (chReceivedArray chan receivedArray, chReadNext chan struct{}) { - chReceivedArray = make(chan receivedArray) - chReadNext = make(chan struct{}) - - go readArraySync(r, chReceivedArray, chReadNext) - return -} + }() -func readArraySync(r *bufio.Reader, chReceivedArray chan receivedArray, chReadNext chan struct{}) { for { - if _, isOpen := <-chReadNext; !isOpen { - close(chReceivedArray) + args, err := readArray(r) + if err != nil { return } - - args, err := readArray(r) - chReceivedArray <- receivedArray{args, err} + s.dispatch(peer, args) + peer.w.Flush() + if peer.closed { + c.Close() + } } } @@ -249,41 +188,21 @@ func (s *Server) TotalConnections() int { return s.infoConns } -type QueuedMessage interface { - Write(c *Peer) -} - -type MessageQueue struct { - sync.Mutex - messages []QueuedMessage - hasNewMessages chan struct{} -} - -func (q *MessageQueue) Enqueue(message QueuedMessage) { - q.Lock() - defer q.Unlock() - - q.messages = append(q.messages, message) - - select { - case q.hasNewMessages <- struct{}{}: - break - default: - break - } -} - // Peer is a client connected to the server type Peer struct { - w *bufio.Writer - closed bool - Ctx interface{} // anything goes, server won't touch this - MsgQueue MessageQueue + w *bufio.Writer + closed bool + Ctx interface{} // anything goes, server won't touch this + onDisconnect []func() // list of callbacks + Error error // set if any Write* call had an error + mu sync.Mutex // for Block() } // Flush the write buffer. Called automatically after every redis command func (c *Peer) Flush() { - c.w.Flush() + if err := c.w.Flush(); err != nil { + c.Error = err + } } // Close the client connection after the current command is done. @@ -291,14 +210,31 @@ func (c *Peer) Close() { c.closed = true } +// Register a function to execute on disconnect. There can be multiple +// functions registered. +func (c *Peer) OnDisconnect(f func()) { + c.onDisconnect = append(c.onDisconnect, f) +} + +// issue multiple calls, guarded with a mutex +func (c *Peer) Block(f func(*Peer)) { + c.mu.Lock() + defer c.mu.Unlock() + f(c) +} + // WriteError writes a redis 'Error' func (c *Peer) WriteError(e string) { - fmt.Fprintf(c.w, "-%s\r\n", toInline(e)) + if _, err := fmt.Fprintf(c.w, "-%s\r\n", toInline(e)); err != nil { + c.Error = err + } } // WriteInline writes a redis inline string func (c *Peer) WriteInline(s string) { - fmt.Fprintf(c.w, "+%s\r\n", toInline(s)) + if _, err := fmt.Fprintf(c.w, "+%s\r\n", toInline(s)); err != nil { + c.Error = err + } } // WriteOK write the inline string `OK` @@ -308,22 +244,30 @@ func (c *Peer) WriteOK() { // WriteBulk writes a bulk string func (c *Peer) WriteBulk(s string) { - fmt.Fprintf(c.w, "$%d\r\n%s\r\n", len(s), s) + if _, err := fmt.Fprintf(c.w, "$%d\r\n%s\r\n", len(s), s); err != nil { + c.Error = err + } } // WriteNull writes a redis Null element func (c *Peer) WriteNull() { - fmt.Fprintf(c.w, "$-1\r\n") + if _, err := fmt.Fprintf(c.w, "$-1\r\n"); err != nil { + c.Error = err + } } // WriteLen starts an array with the given length func (c *Peer) WriteLen(n int) { - fmt.Fprintf(c.w, "*%d\r\n", n) + if _, err := fmt.Fprintf(c.w, "*%d\r\n", n); err != nil { + c.Error = err + } } // WriteInt writes an integer func (c *Peer) WriteInt(i int) { - fmt.Fprintf(c.w, ":%d\r\n", i) + if _, err := fmt.Fprintf(c.w, ":%d\r\n", i); err != nil { + c.Error = err + } } func toInline(s string) string { From 7238be090eba65dc98bfeeb9b6c29c1072520cd7 Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 23 Mar 2019 13:02:48 +0100 Subject: [PATCH 08/13] support PUBSUB commands in transactions --- cmd_pubsub.go | 116 +++++++++++++++++++------------------ cmd_transactions.go | 5 ++ integration/pubsub_test.go | 60 +++++++++++++++++++ integration/test.go | 2 - integration/tx_test.go | 24 ++++++++ miniredis.go | 16 +++-- pubsub.go | 4 +- 7 files changed, 162 insertions(+), 65 deletions(-) diff --git a/cmd_pubsub.go b/cmd_pubsub.go index 1fc5398b..ec407024 100644 --- a/cmd_pubsub.go +++ b/cmd_pubsub.go @@ -22,7 +22,6 @@ func commandsPubsub(m *Miniredis) { // SUBSCRIBE func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { - // TODO: figure out transactions. if len(args) < 1 { setDirty(c) c.WriteError(errWrongNumber(cmd)) @@ -32,16 +31,18 @@ func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { return } - sub := m.subscribedState(c) - for _, channel := range args { - n := sub.Subscribe(channel) - c.Block(func(c *server.Peer) { - c.WriteLen(3) - c.WriteBulk("subscribe") - c.WriteBulk(channel) - c.WriteInt(n) - }) - } + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + sub := m.subscribedState(c) + for _, channel := range args { + n := sub.Subscribe(channel) + c.Block(func(c *server.Peer) { + c.WriteLen(3) + c.WriteBulk("subscribe") + c.WriteBulk(channel) + c.WriteInt(n) + }) + } + }) } // UNSUBSCRIBE @@ -52,27 +53,28 @@ func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) { sub := m.subscribedState(c) - // TODO: tx, which also locks. - channels := args - if len(channels) == 0 { - channels = sub.Channels() - } - // there is no de-duplication - for _, channel := range channels { - n := sub.Unsubscribe(channel) - c.Block(func(c *server.Peer) { - c.WriteLen(3) - c.WriteBulk("unsubscribe") - c.WriteBulk(channel) - c.WriteInt(n) - }) - } + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + if len(channels) == 0 { + channels = sub.Channels() + } - if sub.Count() == 0 { - endSubscriber(m, c) - } + // there is no de-duplication + for _, channel := range channels { + n := sub.Unsubscribe(channel) + c.Block(func(c *server.Peer) { + c.WriteLen(3) + c.WriteBulk("unsubscribe") + c.WriteBulk(channel) + c.WriteInt(n) + }) + } + + if sub.Count() == 0 { + endSubscriber(m, c) + } + }) } // PSUBSCRIBE @@ -86,16 +88,18 @@ func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) { return } - sub := m.subscribedState(c) - for _, pat := range args { - n := sub.Psubscribe(pat) - c.Block(func(c *server.Peer) { - c.WriteLen(3) - c.WriteBulk("psubscribe") - c.WriteBulk(pat) - c.WriteInt(n) - }) - } + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + sub := m.subscribedState(c) + for _, pat := range args { + n := sub.Psubscribe(pat) + c.Block(func(c *server.Peer) { + c.WriteLen(3) + c.WriteBulk("psubscribe") + c.WriteBulk(pat) + c.WriteInt(n) + }) + } + }) } func compileChannelPattern(pattern string) *regexp.Regexp { @@ -207,25 +211,27 @@ func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) { sub := m.subscribedState(c) patterns := args - if len(patterns) == 0 { - patterns = sub.Patterns() - } - // there is no de-duplication - for _, pat := range patterns { - n := sub.Punsubscribe(pat) - c.Block(func(c *server.Peer) { - c.WriteLen(3) - c.WriteBulk("punsubscribe") - c.WriteBulk(pat) - c.WriteInt(n) - }) - } + withTx(m, c, func(c *server.Peer, ctx *connCtx) { + if len(patterns) == 0 { + patterns = sub.Patterns() + } - if sub.Count() == 0 { - endSubscriber(m, c) - } + // there is no de-duplication + for _, pat := range patterns { + n := sub.Punsubscribe(pat) + c.Block(func(c *server.Peer) { + c.WriteLen(3) + c.WriteBulk("punsubscribe") + c.WriteBulk(pat) + c.WriteInt(n) + }) + } + if sub.Count() == 0 { + endSubscriber(m, c) + } + }) } // PUBLISH diff --git a/cmd_transactions.go b/cmd_transactions.go index 64912cf5..ad1b269a 100644 --- a/cmd_transactions.go +++ b/cmd_transactions.go @@ -24,6 +24,9 @@ func (m *Miniredis) cmdMulti(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } ctx := getCtx(c) @@ -57,6 +60,8 @@ func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) { if ctx.dirtyTransaction { c.WriteError("EXECABORT Transaction discarded because of previous errors.") + // a failed EXEC finishes the tx + stopTx(ctx) return } diff --git a/integration/pubsub_test.go b/integration/pubsub_test.go index 7b747195..9e4930f2 100644 --- a/integration/pubsub_test.go +++ b/integration/pubsub_test.go @@ -197,3 +197,63 @@ func TestPubsubUnsub(t *testing.T) { c2 <- succSorted("PUBSUB", "CHANNELS") }) } + +func TestPubsubTx(t *testing.T) { + // publish is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("SUBSCRIBE", "foo") + c2 <- succ("MULTI") + c2 <- succ("PUBSUB", "CHANNELS") + c2 <- succ("PUBLISH", "foo", "hello one") + c2 <- fail("GET") + c2 <- succ("PUBLISH", "foo", "hello two") + c2 <- fail("EXEC") + + c2 <- succ("PUBLISH", "foo", "post tx") + c1 <- receive() + }) + + // SUBSCRIBE is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("MULTI") + c1 <- succ("SUBSCRIBE", "foo") + c2 <- succ("PUBSUB", "CHANNELS") + c1 <- succ("EXEC") + c2 <- succ("PUBSUB", "CHANNELS") + + c1 <- fail("MULTI") // we're in SUBSCRIBE mode + }) + + // UNSUBSCRIBE is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("MULTI") + c1 <- succ("SUBSCRIBE", "foo") + c1 <- succ("UNSUBSCRIBE", "foo") + c2 <- succ("PUBSUB", "CHANNELS") + c1 <- succ("EXEC") + c2 <- succ("PUBSUB", "CHANNELS") + c1 <- succ("PUBSUB", "CHANNELS") + }) + + // PSUBSCRIBE is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("MULTI") + c1 <- succ("PSUBSCRIBE", "foo") + c2 <- succ("PUBSUB", "NUMPAT") + c1 <- succ("EXEC") + c2 <- succ("PUBSUB", "NUMPAT") + + c1 <- fail("MULTI") // we're in SUBSCRIBE mode + }) + + // PUNSUBSCRIBE is in a tx + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("MULTI") + c1 <- succ("PSUBSCRIBE", "foo") + c1 <- succ("PUNSUBSCRIBE", "foo") + c2 <- succ("PUBSUB", "NUMPAT") + c1 <- succ("EXEC") + c2 <- succ("PUBSUB", "NUMPAT") + c1 <- succ("PUBSUB", "NUMPAT") + }) +} diff --git a/integration/test.go b/integration/test.go index 9c398b56..f55f8715 100644 --- a/integration/test.go +++ b/integration/test.go @@ -227,9 +227,7 @@ func runCommand(t *testing.T, cMini, cReal redis.Conn, p command) { ) if p.receiveOnly { vReal, errReal = cReal.Receive() - dump(vReal, "-real-") vMini, errMini = cMini.Receive() - dump(vMini, "-mini-") } else { vReal, errReal = cReal.Do(p.cmd, p.args...) vMini, errMini = cMini.Do(p.cmd, p.args...) diff --git a/integration/tx_test.go b/integration/tx_test.go index e3f958d4..b34cb599 100644 --- a/integration/tx_test.go +++ b/integration/tx_test.go @@ -142,4 +142,28 @@ func TestTx(t *testing.T) { succ("BITOP", "BROKEN", "str", ""), succ("EXEC"), ) + + // fail on invalid command + testCommands(t, + succ("MULTI"), + fail("GET"), + fail("EXEC"), + ) + + /* FIXME + // fail on unknown command + testCommands(t, + succ("MULTI"), + fail("NOSUCH"), + fail("EXEC"), + ) + */ + + // failed EXEC cleaned up the tx + testCommands(t, + succ("MULTI"), + fail("GET"), + fail("EXEC"), + succ("MULTI"), + ) } diff --git a/miniredis.go b/miniredis.go index 6b8f3aba..cadb46e8 100644 --- a/miniredis.go +++ b/miniredis.go @@ -396,17 +396,13 @@ func setAuthenticated(c *server.Peer) { } func (m *Miniredis) addSubscriber(s *Subscriber) { - m.Lock() m.subscribers[s] = struct{}{} - m.Unlock() } -// closes and remove the subscriber +// closes and remove the subscriber. func (m *Miniredis) removeSubscriber(s *Subscriber) { - m.Lock() _, ok := m.subscribers[s] delete(m.subscribers, s) - m.Unlock() if ok { s.Close() } @@ -428,9 +424,13 @@ func (m *Miniredis) subscribedState(c *server.Peer) *Subscriber { return sub } - sub = m.NewSubscriber() + sub = newSubscriber() + m.addSubscriber(sub) + c.OnDisconnect(func() { + m.Lock() m.removeSubscriber(sub) + m.Unlock() }) ctx.subscriber = sub @@ -456,7 +456,11 @@ func endSubscriber(m *Miniredis, c *server.Peer) { // Does not close itself when there are no subscriptions left. func (m *Miniredis) NewSubscriber() *Subscriber { sub := newSubscriber() + + m.Lock() m.addSubscriber(sub) + m.Unlock() + return sub } diff --git a/pubsub.go b/pubsub.go index 8ccf88be..60b9dfe9 100644 --- a/pubsub.go +++ b/pubsub.go @@ -187,8 +187,8 @@ func countPsubs(subs []*Subscriber) int { return n } -func monitorPublish(conn *server.Peer, c <-chan PubsubMessage) { - for msg := range c { +func monitorPublish(conn *server.Peer, msgs <-chan PubsubMessage) { + for msg := range msgs { conn.Block(func(c *server.Peer) { c.WriteLen(3) c.WriteBulk("message") From 64c0e93d4c36499688b6e7b0db179014ca96eba1 Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 23 Mar 2019 14:24:55 +0100 Subject: [PATCH 09/13] test all commands for publish mode --- cmd_connection.go | 7 ++ cmd_generic.go | 42 +++++++++++ cmd_hash.go | 36 +++++++++ cmd_list.go | 40 ++++++++++ cmd_pubsub.go | 8 +- cmd_scripting.go | 10 +++ cmd_server.go | 16 +++- cmd_set.go | 45 ++++++++++++ cmd_sorted_set.go | 51 +++++++++++++ cmd_string.go | 67 ++++++++++++++++- cmd_transactions.go | 12 +++ integration/Makefile | 4 + integration/pubsub_test.go | 145 +++++++++++++++++++++++++++++++++++++ 13 files changed, 475 insertions(+), 8 deletions(-) diff --git a/cmd_connection.go b/cmd_connection.go index 39689949..58b22bd5 100644 --- a/cmd_connection.go +++ b/cmd_connection.go @@ -57,6 +57,10 @@ func (m *Miniredis) cmdAuth(c *server.Peer, cmd string, args []string) { c.WriteError(errWrongNumber(cmd)) return } + if m.checkPubsub(c) { + return + } + pw := args[0] m.Lock() @@ -102,6 +106,9 @@ func (m *Miniredis) cmdSelect(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } id, err := strconv.Atoi(args[0]) if err != nil { diff --git a/cmd_generic.go b/cmd_generic.go index fa394790..129df634 100644 --- a/cmd_generic.go +++ b/cmd_generic.go @@ -49,6 +49,9 @@ func makeCmdExpire(m *Miniredis, unix bool, d time.Duration) func(*server.Peer, if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] value := args[1] @@ -102,6 +105,10 @@ func (m *Miniredis) cmdTTL(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + key := args[0] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -133,6 +140,10 @@ func (m *Miniredis) cmdPTTL(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + key := args[0] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -164,6 +175,10 @@ func (m *Miniredis) cmdPersist(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + key := args[0] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -191,6 +206,9 @@ func (m *Miniredis) cmdDel(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -216,6 +234,9 @@ func (m *Miniredis) cmdType(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -242,6 +263,9 @@ func (m *Miniredis) cmdExists(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -266,6 +290,9 @@ func (m *Miniredis) cmdMove(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] targetDB, err := strconv.Atoi(args[1]) @@ -299,6 +326,9 @@ func (m *Miniredis) cmdKeys(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -323,6 +353,9 @@ func (m *Miniredis) cmdRandomkey(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -352,6 +385,9 @@ func (m *Miniredis) cmdRename(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } from, to := args[0], args[1] @@ -378,6 +414,9 @@ func (m *Miniredis) cmdRenamenx(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } from, to := args[0], args[1] @@ -409,6 +448,9 @@ func (m *Miniredis) cmdScan(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } cursor, err := strconv.Atoi(args[0]) if err != nil { diff --git a/cmd_hash.go b/cmd_hash.go index c1248099..78fffb2c 100644 --- a/cmd_hash.go +++ b/cmd_hash.go @@ -69,6 +69,9 @@ func (m *Miniredis) cmdHsetnx(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field, value := args[0], args[1], args[2] @@ -105,6 +108,9 @@ func (m *Miniredis) cmdHmset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] if len(args)%2 != 0 { @@ -178,6 +184,9 @@ func (m *Miniredis) cmdHdel(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, fields := args[0], args[1:] @@ -223,6 +232,9 @@ func (m *Miniredis) cmdHexists(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field := args[0], args[1] @@ -257,6 +269,9 @@ func (m *Miniredis) cmdHgetall(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -291,6 +306,9 @@ func (m *Miniredis) cmdHkeys(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -324,6 +342,9 @@ func (m *Miniredis) cmdHvals(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -357,6 +378,9 @@ func (m *Miniredis) cmdHlen(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -387,6 +411,9 @@ func (m *Miniredis) cmdHmget(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -425,6 +452,9 @@ func (m *Miniredis) cmdHincrby(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field, deltas := args[0], args[1], args[2] @@ -462,6 +492,9 @@ func (m *Miniredis) cmdHincrbyfloat(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, field, deltas := args[0], args[1], args[2] @@ -499,6 +532,9 @@ func (m *Miniredis) cmdHscan(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] cursor, err := strconv.Atoi(args[1]) diff --git a/cmd_list.go b/cmd_list.go index ae543dc6..23aa62f4 100644 --- a/cmd_list.go +++ b/cmd_list.go @@ -57,6 +57,10 @@ func (m *Miniredis) cmdBXpop(c *server.Peer, cmd string, args []string, lr leftr if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + timeoutS := args[len(args)-1] keys := args[:len(args)-1] @@ -121,6 +125,9 @@ func (m *Miniredis) cmdLindex(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, offsets := args[0], args[1] @@ -167,6 +174,9 @@ func (m *Miniredis) cmdLinsert(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] where := 0 @@ -231,6 +241,9 @@ func (m *Miniredis) cmdLlen(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -271,6 +284,9 @@ func (m *Miniredis) cmdXpop(c *server.Peer, cmd string, args []string, lr leftri if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -317,6 +333,9 @@ func (m *Miniredis) cmdXpush(c *server.Peer, cmd string, args []string, lr leftr if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] @@ -360,6 +379,9 @@ func (m *Miniredis) cmdXpushx(c *server.Peer, cmd string, args []string, lr left if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] @@ -398,6 +420,9 @@ func (m *Miniredis) cmdLrange(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -445,6 +470,9 @@ func (m *Miniredis) cmdLrem(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] count, err := strconv.Atoi(args[1]) @@ -514,6 +542,9 @@ func (m *Miniredis) cmdLset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] index, err := strconv.Atoi(args[1]) @@ -561,6 +592,9 @@ func (m *Miniredis) cmdLtrim(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -612,6 +646,9 @@ func (m *Miniredis) cmdRpoplpush(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } src, dst := args[0], args[1] @@ -642,6 +679,9 @@ func (m *Miniredis) cmdBrpoplpush(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } src := args[0] dst := args[1] diff --git a/cmd_pubsub.go b/cmd_pubsub.go index ec407024..c864e64b 100644 --- a/cmd_pubsub.go +++ b/cmd_pubsub.go @@ -51,11 +51,11 @@ func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) { return } - sub := m.subscribedState(c) - channels := args withTx(m, c, func(c *server.Peer, ctx *connCtx) { + sub := m.subscribedState(c) + if len(channels) == 0 { channels = sub.Channels() } @@ -208,11 +208,11 @@ func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) { return } - sub := m.subscribedState(c) - patterns := args withTx(m, c, func(c *server.Peer, ctx *connCtx) { + sub := m.subscribedState(c) + if len(patterns) == 0 { patterns = sub.Patterns() } diff --git a/cmd_scripting.go b/cmd_scripting.go index 296e61b9..13b3deca 100644 --- a/cmd_scripting.go +++ b/cmd_scripting.go @@ -113,6 +113,10 @@ func (m *Miniredis) cmdEval(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } + script, args := args[0], args[1:] withTx(m, c, func(c *server.Peer, ctx *connCtx) { @@ -129,6 +133,9 @@ func (m *Miniredis) cmdEvalsha(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } sha, args := args[0], args[1:] @@ -152,6 +159,9 @@ func (m *Miniredis) cmdScript(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } subcmd, args := args[0], args[1:] diff --git a/cmd_server.go b/cmd_server.go index c021644c..1ed9ad2f 100644 --- a/cmd_server.go +++ b/cmd_server.go @@ -27,6 +27,9 @@ func (m *Miniredis) cmdDbsize(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -45,10 +48,12 @@ func (m *Miniredis) cmdFlushall(c *server.Peer, cmd string, args []string) { c.WriteError(msgSyntaxError) return } - if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { m.flushAll() @@ -66,10 +71,12 @@ func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) { c.WriteError(msgSyntaxError) return } - if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { m.db(ctx.selectedDB).flush() @@ -77,7 +84,7 @@ func (m *Miniredis) cmdFlushdb(c *server.Peer, cmd string, args []string) { }) } -// TIME: time values are returned in string format instead of int +// TIME func (m *Miniredis) cmdTime(c *server.Peer, cmd string, args []string) { if len(args) > 0 { setDirty(c) @@ -87,6 +94,9 @@ func (m *Miniredis) cmdTime(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { now := m.now diff --git a/cmd_set.go b/cmd_set.go index 2220cf55..4cb6ee1b 100644 --- a/cmd_set.go +++ b/cmd_set.go @@ -39,6 +39,9 @@ func (m *Miniredis) cmdSadd(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, elems := args[0], args[1:] @@ -65,6 +68,9 @@ func (m *Miniredis) cmdScard(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -96,6 +102,9 @@ func (m *Miniredis) cmdSdiff(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } keys := args @@ -125,6 +134,9 @@ func (m *Miniredis) cmdSdiffstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } dest, keys := args[0], args[1:] @@ -153,6 +165,9 @@ func (m *Miniredis) cmdSinter(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } keys := args @@ -182,6 +197,9 @@ func (m *Miniredis) cmdSinterstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } dest, keys := args[0], args[1:] @@ -210,6 +228,9 @@ func (m *Miniredis) cmdSismember(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, value := args[0], args[1] @@ -244,6 +265,9 @@ func (m *Miniredis) cmdSmembers(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -279,6 +303,9 @@ func (m *Miniredis) cmdSmove(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } src, dst, member := args[0], args[1], args[2] @@ -320,6 +347,9 @@ func (m *Miniredis) cmdSpop(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] @@ -401,6 +431,9 @@ func (m *Miniredis) cmdSrandmember(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] count := 0 @@ -467,6 +500,9 @@ func (m *Miniredis) cmdSrem(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, fields := args[0], args[1:] @@ -497,6 +533,9 @@ func (m *Miniredis) cmdSunion(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } keys := args @@ -526,6 +565,9 @@ func (m *Miniredis) cmdSunionstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } dest, keys := args[0], args[1:] @@ -554,6 +596,9 @@ func (m *Miniredis) cmdSscan(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] cursor, err := strconv.Atoi(args[1]) diff --git a/cmd_sorted_set.go b/cmd_sorted_set.go index 617282d7..6eeeeab8 100644 --- a/cmd_sorted_set.go +++ b/cmd_sorted_set.go @@ -50,6 +50,9 @@ func (m *Miniredis) cmdZadd(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, args := args[0], args[1:] var ( @@ -168,6 +171,9 @@ func (m *Miniredis) cmdZcard(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -198,6 +204,9 @@ func (m *Miniredis) cmdZcount(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseFloatRange(args[1]) @@ -242,6 +251,9 @@ func (m *Miniredis) cmdZincrby(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] delta, err := strconv.ParseFloat(args[1], 64) @@ -274,6 +286,9 @@ func (m *Miniredis) cmdZinterstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } destination := args[0] numKeys, err := strconv.Atoi(args[1]) @@ -403,6 +418,9 @@ func (m *Miniredis) cmdZlexcount(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseLexrange(args[1]) @@ -451,6 +469,9 @@ func (m *Miniredis) makeCmdZrange(reverse bool) server.Cmd { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -524,6 +545,9 @@ func (m *Miniredis) makeCmdZrangebylex(reverse bool) server.Cmd { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseLexrange(args[1]) @@ -635,6 +659,9 @@ func (m *Miniredis) makeCmdZrangebyscore(reverse bool) server.Cmd { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseFloatRange(args[1]) @@ -756,6 +783,9 @@ func (m *Miniredis) makeCmdZrank(reverse bool) server.Cmd { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, member := args[0], args[1] @@ -796,6 +826,9 @@ func (m *Miniredis) cmdZrem(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, members := args[0], args[1:] @@ -832,6 +865,9 @@ func (m *Miniredis) cmdZremrangebylex(c *server.Peer, cmd string, args []string) if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseLexrange(args[1]) @@ -882,6 +918,9 @@ func (m *Miniredis) cmdZremrangebyrank(c *server.Peer, cmd string, args []string if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -929,6 +968,9 @@ func (m *Miniredis) cmdZremrangebyscore(c *server.Peer, cmd string, args []strin if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] min, minIncl, err := parseFloatRange(args[1]) @@ -977,6 +1019,9 @@ func (m *Miniredis) cmdZscore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, member := args[0], args[1] @@ -1126,6 +1171,9 @@ func (m *Miniredis) cmdZunionstore(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } destination := args[0] numKeys, err := strconv.Atoi(args[1]) @@ -1254,6 +1302,9 @@ func (m *Miniredis) cmdZscan(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] cursor, err := strconv.Atoi(args[1]) diff --git a/cmd_string.go b/cmd_string.go index 658d77d0..b99a34bd 100644 --- a/cmd_string.go +++ b/cmd_string.go @@ -47,7 +47,6 @@ func (m *Miniredis) cmdSet(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } - if m.checkPubsub(c) { return } @@ -137,6 +136,9 @@ func (m *Miniredis) cmdSetex(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] ttl, err := strconv.Atoi(args[1]) @@ -172,6 +174,9 @@ func (m *Miniredis) cmdPsetex(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] ttl, err := strconv.Atoi(args[1]) @@ -207,6 +212,9 @@ func (m *Miniredis) cmdSetnx(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, value := args[0], args[1] @@ -233,6 +241,9 @@ func (m *Miniredis) cmdMset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } if len(args)%2 != 0 { setDirty(c) @@ -265,6 +276,9 @@ func (m *Miniredis) cmdMsetnx(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } if len(args)%2 != 0 { setDirty(c) @@ -310,6 +324,9 @@ func (m *Miniredis) cmdGet(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -339,6 +356,9 @@ func (m *Miniredis) cmdGetset(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, value := args[0], args[1] @@ -373,6 +393,9 @@ func (m *Miniredis) cmdMget(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -404,6 +427,9 @@ func (m *Miniredis) cmdIncr(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -433,6 +459,9 @@ func (m *Miniredis) cmdIncrby(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] delta, err := strconv.Atoi(args[1]) @@ -470,6 +499,9 @@ func (m *Miniredis) cmdIncrbyfloat(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] delta, err := strconv.ParseFloat(args[1], 64) @@ -507,6 +539,9 @@ func (m *Miniredis) cmdDecr(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } withTx(m, c, func(c *server.Peer, ctx *connCtx) { db := m.db(ctx.selectedDB) @@ -536,6 +571,9 @@ func (m *Miniredis) cmdDecrby(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] delta, err := strconv.Atoi(args[1]) @@ -573,6 +611,9 @@ func (m *Miniredis) cmdStrlen(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] @@ -598,6 +639,9 @@ func (m *Miniredis) cmdAppend(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key, value := args[0], args[1] @@ -626,6 +670,9 @@ func (m *Miniredis) cmdGetrange(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] start, err := strconv.Atoi(args[1]) @@ -664,6 +711,9 @@ func (m *Miniredis) cmdSetrange(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] pos, err := strconv.Atoi(args[1]) @@ -709,6 +759,9 @@ func (m *Miniredis) cmdBitcount(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } var ( useRange = false @@ -771,6 +824,9 @@ func (m *Miniredis) cmdBitop(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } var ( op = strings.ToUpper(args[0]) @@ -848,6 +904,9 @@ func (m *Miniredis) cmdBitpos(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] bit, err := strconv.Atoi(args[1]) @@ -930,6 +989,9 @@ func (m *Miniredis) cmdGetbit(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] bit, err := strconv.Atoi(args[1]) @@ -973,6 +1035,9 @@ func (m *Miniredis) cmdSetbit(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } key := args[0] bit, err := strconv.Atoi(args[1]) diff --git a/cmd_transactions.go b/cmd_transactions.go index ad1b269a..d90ff73d 100644 --- a/cmd_transactions.go +++ b/cmd_transactions.go @@ -50,6 +50,9 @@ func (m *Miniredis) cmdExec(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } ctx := getCtx(c) @@ -98,6 +101,9 @@ func (m *Miniredis) cmdDiscard(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } ctx := getCtx(c) if !inTx(ctx) { @@ -119,6 +125,9 @@ func (m *Miniredis) cmdWatch(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } ctx := getCtx(c) if inTx(ctx) { @@ -146,6 +155,9 @@ func (m *Miniredis) cmdUnwatch(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { return } + if m.checkPubsub(c) { + return + } // Doesn't matter if UNWATCH is in a TX or not. Looks like a Redis bug to me. unwatch(getCtx(c)) diff --git a/integration/Makefile b/integration/Makefile index d95d141b..79218443 100644 --- a/integration/Makefile +++ b/integration/Makefile @@ -4,3 +4,7 @@ all: test test: go test -tags int + +commands.txt: ../*.go + grep Register ../*.go|perl -ne '/"(.*)"/ && print "$$1\n"' | sort > commands.txt + diff --git a/integration/pubsub_test.go b/integration/pubsub_test.go index 9e4930f2..2263fe88 100644 --- a/integration/pubsub_test.go +++ b/integration/pubsub_test.go @@ -167,6 +167,7 @@ func TestPubsubSelect(t *testing.T) { } func TestPubsubMode(t *testing.T) { + // most commands aren't allowed in publish mode testCommands(t, succ("SUBSCRIBE", "news", "sport"), receive(), @@ -177,6 +178,142 @@ func TestPubsubMode(t *testing.T) { fail("SET", "foo", "bar"), succ("QUIT"), ) + + e := "ERR only (P)SUBSCRIBE / (P)UNSUBSCRIBE / PING / QUIT allowed in this context" + cbs := []command{ + succ("SUBSCRIBE", "news"), + // failWith(e, "PING"), + // failWith(e, "PSUBSCRIBE"), + // failWith(e, "PUNSUBSCRIBE"), + // failWith(e, "QUIT"), + // failWith(e, "SUBSCRIBE"), + // failWith(e, "UNSUBSCRIBE"), + + failWith(e, "APPEND", "foo", "foo"), + failWith(e, "AUTH", "foo"), + failWith(e, "BITCOUNT", "foo"), + failWith(e, "BITOP", "OR", "foo", "bar"), + failWith(e, "BITPOS", "foo", 0), + failWith(e, "BLPOP", "key", 1), + failWith(e, "BRPOP", "key", 1), + failWith(e, "BRPOPLPUSH", "foo", "bar", 1), + failWith(e, "DBSIZE"), + failWith(e, "DECR", "foo"), + failWith(e, "DECRBY", "foo", 3), + failWith(e, "DEL", "foo"), + failWith(e, "DISCARD"), + failWith(e, "ECHO", "foo"), + failWith(e, "EVAL", "foo", "{}"), + failWith(e, "EVALSHA", "foo", "{}"), + failWith(e, "EXEC"), + failWith(e, "EXISTS", "foo"), + failWith(e, "EXPIRE", "foo", 12), + failWith(e, "EXPIREAT", "foo", 12), + failWith(e, "FLUSHALL"), + failWith(e, "FLUSHDB"), + failWith(e, "GET", "foo"), + failWith(e, "GETBIT", "foo", 12), + failWith(e, "GETRANGE", "foo", 12, 12), + failWith(e, "GETSET", "foo", "bar"), + failWith(e, "HDEL", "foo", "bar"), + failWith(e, "HEXISTS", "foo", "bar"), + failWith(e, "HGET", "foo", "bar"), + failWith(e, "HGETALL", "foo"), + failWith(e, "HINCRBY", "foo", "bar", 12), + failWith(e, "HINCRBYFLOAT", "foo", "bar", 12.34), + failWith(e, "HKEYS", "foo"), + failWith(e, "HLEN", "foo"), + failWith(e, "HMGET", "foo", "bar"), + failWith(e, "HMSET", "foo", "bar", "baz"), + failWith(e, "HSCAN", "foo", 0), + failWith(e, "HSET", "foo", "bar", "baz"), + failWith(e, "HSETNX", "foo", "bar", "baz"), + failWith(e, "HVALS", "foo"), + failWith(e, "INCR", "foo"), + failWith(e, "INCRBY", "foo", 12), + failWith(e, "INCRBYFLOAT", "foo", 12.34), + failWith(e, "KEYS", "*"), + failWith(e, "LINDEX", "foo", 0), + failWith(e, "LINSERT", "foo", "after", "bar", 0), + failWith(e, "LLEN", "foo"), + failWith(e, "LPOP", "foo"), + failWith(e, "LPUSH", "foo", "bar"), + failWith(e, "LPUSHX", "foo", "bar"), + failWith(e, "LRANGE", "foo", 1, 1), + failWith(e, "LREM", "foo", 0, "bar"), + failWith(e, "LSET", "foo", 0, "bar"), + failWith(e, "LTRIM", "foo", 0, 0), + failWith(e, "MGET", "foo", "bar"), + failWith(e, "MOVE", "foo", "bar"), + failWith(e, "MSET", "foo", "bar"), + failWith(e, "MSETNX", "foo", "bar"), + failWith(e, "MULTI"), + failWith(e, "PERSIST", "foo"), + failWith(e, "PEXPIRE", "foo", 12), + failWith(e, "PEXPIREAT", "foo", 12), + failWith(e, "PSETEX", "foo", 12, "bar"), + failWith(e, "PTTL", "foo"), + failWith(e, "PUBLISH", "foo", "bar"), + failWith(e, "PUBSUB", "CHANNELS"), + failWith(e, "RANDOMKEY"), + failWith(e, "RENAME", "foo", "bar"), + failWith(e, "RENAMENX", "foo", "bar"), + failWith(e, "RPOP", "foo"), + failWith(e, "RPOPLPUSH", "foo", "bar"), + failWith(e, "RPUSH", "foo", "bar"), + failWith(e, "RPUSHX", "foo", "bar"), + failWith(e, "SADD", "foo", "bar"), + failWith(e, "SCAN", 0), + failWith(e, "SCARD", "foo"), + failWith(e, "SCRIPT", "FLUSH"), + failWith(e, "SDIFF", "foo"), + failWith(e, "SDIFFSTORE", "foo", "bar"), + failWith(e, "SELECT", 12), + failWith(e, "SET", "foo", "bar"), + failWith(e, "SETBIT", "foo", 0, 1), + failWith(e, "SETEX", "foo", 12, "bar"), + failWith(e, "SETNX", "foo", "bar"), + failWith(e, "SETRANGE", "foo", 0, "bar"), + failWith(e, "SINTER", "foo", "bar"), + failWith(e, "SINTERSTORE", "foo", "bar", "baz"), + failWith(e, "SISMEMBER", "foo", "bar"), + failWith(e, "SMEMBERS", "foo"), + failWith(e, "SMOVE", "foo", "bar", "baz"), + failWith(e, "SPOP", "foo"), + failWith(e, "SRANDMEMBER", "foo"), + failWith(e, "SREM", "foo", "bar", "baz"), + failWith(e, "SSCAN", "foo", 0), + failWith(e, "STRLEN", "foo"), + failWith(e, "SUNION", "foo", "bar"), + failWith(e, "SUNIONSTORE", "foo", "bar", "baz"), + failWith(e, "TIME"), + failWith(e, "TTL", "foo"), + failWith(e, "TYPE", "foo"), + failWith(e, "UNWATCH"), + failWith(e, "WATCH", "foo"), + failWith(e, "ZADD", "foo", "INCR", 1, "bar"), + failWith(e, "ZCARD", "foo"), + failWith(e, "ZCOUNT", "foo", 0, 1), + failWith(e, "ZINCRBY", "foo", "bar", 12), + failWith(e, "ZINTERSTORE", "foo", 1, "bar"), + failWith(e, "ZLEXCOUNT", "foo", "-", "+"), + failWith(e, "ZRANGE", "foo", 0, -1), + failWith(e, "ZRANGEBYLEX", "foo", "-", "+"), + failWith(e, "ZRANGEBYSCORE", "foo", 0, 1), + failWith(e, "ZRANK", "foo", "bar"), + failWith(e, "ZREM", "foo", "bar"), + failWith(e, "ZREMRANGEBYLEX", "foo", "-", "+"), + failWith(e, "ZREMRANGEBYRANK", "foo", 0, 1), + failWith(e, "ZREMRANGEBYSCORE", "foo", 0, 1), + failWith(e, "ZREVRANGE", "foo", 0, -1), + failWith(e, "ZREVRANGEBYLEX", "foo", "+", "-"), + failWith(e, "ZREVRANGEBYSCORE", "foo", 0, 1), + failWith(e, "ZREVRANK", "foo", "bar"), + failWith(e, "ZSCAN", "foo", 0), + failWith(e, "ZSCORE", "foo", "bar"), + failWith(e, "ZUNIONSTORE", "foo", 1, "bar"), + } + testCommands(t, cbs...) } func TestSubscriptions(t *testing.T) { @@ -224,6 +361,14 @@ func TestPubsubTx(t *testing.T) { c1 <- fail("MULTI") // we're in SUBSCRIBE mode }) + // DISCARDing a tx prevents from entering publish mode + testCommands(t, + succ("MULTI"), + succ("SUBSCRIBE", "foo"), + succ("DISCARD"), + succ("PUBSUB", "CHANNELS"), + ) + // UNSUBSCRIBE is in a tx testClients2(t, func(c1, c2 chan<- command) { c1 <- succ("MULTI") From 69ea60b89861752778bba48acf70e286aa7003fb Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 23 Mar 2019 15:07:47 +0100 Subject: [PATCH 10/13] clean up peer writer blocking --- cmd_connection.go | 2 +- cmd_pubsub.go | 8 ++--- pubsub.go | 6 +--- server/server.go | 87 +++++++++++++++++++++++++++++++++-------------- 4 files changed, 68 insertions(+), 35 deletions(-) diff --git a/cmd_connection.go b/cmd_connection.go index 58b22bd5..1f35b98f 100644 --- a/cmd_connection.go +++ b/cmd_connection.go @@ -35,7 +35,7 @@ func (m *Miniredis) cmdPing(c *server.Peer, cmd string, args []string) { // PING is allowed in subscribed state if sub := getCtx(c).subscriber; sub != nil { - c.Block(func(c *server.Peer) { + c.Block(func(c *server.Writer) { c.WriteLen(2) c.WriteBulk("pong") c.WriteBulk(payload) diff --git a/cmd_pubsub.go b/cmd_pubsub.go index c864e64b..4763f9bf 100644 --- a/cmd_pubsub.go +++ b/cmd_pubsub.go @@ -35,7 +35,7 @@ func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { sub := m.subscribedState(c) for _, channel := range args { n := sub.Subscribe(channel) - c.Block(func(c *server.Peer) { + c.Block(func(c *server.Writer) { c.WriteLen(3) c.WriteBulk("subscribe") c.WriteBulk(channel) @@ -63,7 +63,7 @@ func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) { // there is no de-duplication for _, channel := range channels { n := sub.Unsubscribe(channel) - c.Block(func(c *server.Peer) { + c.Block(func(c *server.Writer) { c.WriteLen(3) c.WriteBulk("unsubscribe") c.WriteBulk(channel) @@ -92,7 +92,7 @@ func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) { sub := m.subscribedState(c) for _, pat := range args { n := sub.Psubscribe(pat) - c.Block(func(c *server.Peer) { + c.Block(func(c *server.Writer) { c.WriteLen(3) c.WriteBulk("psubscribe") c.WriteBulk(pat) @@ -220,7 +220,7 @@ func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) { // there is no de-duplication for _, pat := range patterns { n := sub.Punsubscribe(pat) - c.Block(func(c *server.Peer) { + c.Block(func(c *server.Writer) { c.WriteLen(3) c.WriteBulk("punsubscribe") c.WriteBulk(pat) diff --git a/pubsub.go b/pubsub.go index 60b9dfe9..dc3bcdb0 100644 --- a/pubsub.go +++ b/pubsub.go @@ -189,16 +189,12 @@ func countPsubs(subs []*Subscriber) int { func monitorPublish(conn *server.Peer, msgs <-chan PubsubMessage) { for msg := range msgs { - conn.Block(func(c *server.Peer) { + conn.Block(func(c *server.Writer) { c.WriteLen(3) c.WriteBulk("message") c.WriteBulk(msg.Channel) c.WriteBulk(msg.Message) c.Flush() - if c.Error != nil { - // what now? :( - // fmt.Printf("publish err: %s\n", c.Error) - } }) } } diff --git a/server/server.go b/server/server.go index 6d06ae25..aad48644 100644 --- a/server/server.go +++ b/server/server.go @@ -142,7 +142,7 @@ func (s *Server) servePeer(c net.Conn) { return } s.dispatch(peer, args) - peer.w.Flush() + peer.Flush() if peer.closed { c.Close() } @@ -194,15 +194,14 @@ type Peer struct { closed bool Ctx interface{} // anything goes, server won't touch this onDisconnect []func() // list of callbacks - Error error // set if any Write* call had an error mu sync.Mutex // for Block() } // Flush the write buffer. Called automatically after every redis command func (c *Peer) Flush() { - if err := c.w.Flush(); err != nil { - c.Error = err - } + c.mu.Lock() + defer c.mu.Unlock() + c.w.Flush() } // Close the client connection after the current command is done. @@ -217,24 +216,24 @@ func (c *Peer) OnDisconnect(f func()) { } // issue multiple calls, guarded with a mutex -func (c *Peer) Block(f func(*Peer)) { +func (c *Peer) Block(f func(*Writer)) { c.mu.Lock() defer c.mu.Unlock() - f(c) + f(&Writer{c.w}) } // WriteError writes a redis 'Error' func (c *Peer) WriteError(e string) { - if _, err := fmt.Fprintf(c.w, "-%s\r\n", toInline(e)); err != nil { - c.Error = err - } + c.Block(func(w *Writer) { + w.WriteError(e) + }) } // WriteInline writes a redis inline string func (c *Peer) WriteInline(s string) { - if _, err := fmt.Fprintf(c.w, "+%s\r\n", toInline(s)); err != nil { - c.Error = err - } + c.Block(func(w *Writer) { + w.WriteInline(s) + }) } // WriteOK write the inline string `OK` @@ -244,30 +243,30 @@ func (c *Peer) WriteOK() { // WriteBulk writes a bulk string func (c *Peer) WriteBulk(s string) { - if _, err := fmt.Fprintf(c.w, "$%d\r\n%s\r\n", len(s), s); err != nil { - c.Error = err - } + c.Block(func(w *Writer) { + w.WriteBulk(s) + }) } // WriteNull writes a redis Null element func (c *Peer) WriteNull() { - if _, err := fmt.Fprintf(c.w, "$-1\r\n"); err != nil { - c.Error = err - } + c.Block(func(w *Writer) { + w.WriteNull() + }) } // WriteLen starts an array with the given length func (c *Peer) WriteLen(n int) { - if _, err := fmt.Fprintf(c.w, "*%d\r\n", n); err != nil { - c.Error = err - } + c.Block(func(w *Writer) { + w.WriteLen(n) + }) } // WriteInt writes an integer func (c *Peer) WriteInt(i int) { - if _, err := fmt.Fprintf(c.w, ":%d\r\n", i); err != nil { - c.Error = err - } + c.Block(func(w *Writer) { + w.WriteInt(i) + }) } func toInline(s string) string { @@ -278,3 +277,41 @@ func toInline(s string) string { return r }, s) } + +// A Writer is given to the callback in Block() +type Writer struct { + w *bufio.Writer +} + +// WriteError writes a redis 'Error' +func (w *Writer) WriteError(e string) { + fmt.Fprintf(w.w, "-%s\r\n", toInline(e)) +} + +func (w *Writer) WriteLen(n int) { + fmt.Fprintf(w.w, "*%d\r\n", n) +} + +// WriteBulk writes a bulk string +func (w *Writer) WriteBulk(s string) { + fmt.Fprintf(w.w, "$%d\r\n%s\r\n", len(s), s) +} + +// WriteInt writes an integer +func (w *Writer) WriteInt(i int) { + fmt.Fprintf(w.w, ":%d\r\n", i) +} + +// WriteNull writes a redis Null element +func (w *Writer) WriteNull() { + fmt.Fprintf(w.w, "$-1\r\n") +} + +// WriteInline writes a redis inline string +func (w *Writer) WriteInline(s string) { + fmt.Fprintf(w.w, "+%s\r\n", toInline(s)) +} + +func (w *Writer) Flush() { + w.w.Flush() +} From d66fb3bf7cf5198e1697197bbc12cfc175adee5f Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 23 Mar 2019 15:30:12 +0100 Subject: [PATCH 11/13] fix a datarace --- miniredis.go | 10 ++++++---- server/server.go | 8 ++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/miniredis.go b/miniredis.go index cadb46e8..2160afbb 100644 --- a/miniredis.go +++ b/miniredis.go @@ -158,16 +158,18 @@ func (m *Miniredis) Restart() error { // Close shuts down a Miniredis. func (m *Miniredis) Close() { m.Lock() + if m.srv == nil { + m.Unlock() return } + srv := m.srv + m.srv = nil m.Unlock() - m.srv.Close() + // the OnDisconnect callbacks can lock m, so run Close() outside the lock. + srv.Close() - m.Lock() - m.srv = nil - m.Unlock() } // RequireAuth makes every connection need to AUTH first. Disable again by diff --git a/server/server.go b/server/server.go index aad48644..c924bb3c 100644 --- a/server/server.go +++ b/server/server.go @@ -41,7 +41,6 @@ func NewServer(addr string) (*Server, error) { s := Server{ cmds: map[string]Cmd{}, peers: map[net.Conn]struct{}{}, - // onDisconnect: func(c *Peer) {}, } l, err := net.Listen("tcp", addr) @@ -143,7 +142,10 @@ func (s *Server) servePeer(c net.Conn) { } s.dispatch(peer, args) peer.Flush() - if peer.closed { + s.mu.Lock() + closed := peer.closed + s.mu.Unlock() + if closed { c.Close() } } @@ -206,6 +208,8 @@ func (c *Peer) Flush() { // Close the client connection after the current command is done. func (c *Peer) Close() { + c.mu.Lock() + defer c.mu.Unlock() c.closed = true } From 3454bf4487d542204ce399853820d40a2a00bfdc Mon Sep 17 00:00:00 2001 From: Harmen Date: Sat, 23 Mar 2019 15:33:09 +0100 Subject: [PATCH 12/13] upgrade Go version to 1.12 --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index d9122d17..c66fc906 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,4 +10,4 @@ script: make test testrace int sudo: false go: - - 1.11 + - 1.12 From 8e2ff947abf2815d61aacc11b3cbd98dadb44eaf Mon Sep 17 00:00:00 2001 From: Harmen Date: Fri, 29 Mar 2019 16:25:40 +0100 Subject: [PATCH 13/13] update some comments --- direct.go | 4 +++- miniredis.go | 6 +++--- pubsub.go | 19 +++++++++++++++---- test_test.go | 21 --------------------- 4 files changed, 21 insertions(+), 29 deletions(-) diff --git a/direct.go b/direct.go index 533224d0..6d71f8b2 100644 --- a/direct.go +++ b/direct.go @@ -548,13 +548,15 @@ func (db *RedisDB) ZScore(k, member string) (float64, error) { return db.ssetScore(k, member), nil } +// Publish a message to subscribers. Returns the number of receivers. func (m *Miniredis) Publish(channel, message string) int { m.Lock() defer m.Unlock() return m.publish(channel, message) } -// PubSubChannels is "PUBSUB CHANNELS ". An empty pattern is fine. +// PubSubChannels is "PUBSUB CHANNELS ". An empty pattern is fine +// (meaning all channels). // Returned channels will be ordered alphabetically. func (m *Miniredis) PubSubChannels(pattern string) []string { m.Lock() diff --git a/miniredis.go b/miniredis.go index 2160afbb..a861237d 100644 --- a/miniredis.go +++ b/miniredis.go @@ -32,7 +32,7 @@ type setKey map[string]struct{} // RedisDB holds a single (numbered) Redis database. type RedisDB struct { - master *Miniredis // pointer to the lock in Miniredis + master *sync.Mutex // pointer to the lock in Miniredis id int // db id keys map[string]string // Master map of keys with their type stringKeys map[string]string // GET/SET &c. keys @@ -87,7 +87,7 @@ func NewMiniRedis() *Miniredis { return &m } -func newRedisDB(id int, l *Miniredis) RedisDB { +func newRedisDB(id int, l *sync.Mutex) RedisDB { return RedisDB{ id: id, master: l, @@ -192,7 +192,7 @@ func (m *Miniredis) db(i int) *RedisDB { if db, ok := m.dbs[i]; ok { return db } - db := newRedisDB(i, m) // the DB has our lock. + db := newRedisDB(i, &m.Mutex) // the DB has our lock. m.dbs[i] = &db return &db } diff --git a/pubsub.go b/pubsub.go index dc3bcdb0..2d0f04ec 100644 --- a/pubsub.go +++ b/pubsub.go @@ -22,6 +22,8 @@ type Subscriber struct { mu sync.Mutex } +// Make a new subscriber. The channel is not buffered, so you will need to keep +// reading using Messages(). Use Close() when done, or unsubscribe. func newSubscriber() *Subscriber { return &Subscriber{ publish: make(chan PubsubMessage), @@ -30,11 +32,12 @@ func newSubscriber() *Subscriber { } } +// Close the listening channel func (s *Subscriber) Close() { close(s.publish) } -// total number of channels and patterns +// Count the total number of channels and patterns func (s *Subscriber) Count() int { s.mu.Lock() defer s.mu.Unlock() @@ -45,6 +48,8 @@ func (s *Subscriber) count() int { return len(s.channels) + len(s.patterns) } +// Subscribe to a channel. Returns the total number of (p)subscriptions after +// subscribing. func (s *Subscriber) Subscribe(c string) int { s.mu.Lock() defer s.mu.Unlock() @@ -53,6 +58,8 @@ func (s *Subscriber) Subscribe(c string) int { return s.count() } +// Unsubscribe a channel. Returns the total number of (p)subscriptions after +// unsubscribing. func (s *Subscriber) Unsubscribe(c string) int { s.mu.Lock() defer s.mu.Unlock() @@ -61,6 +68,8 @@ func (s *Subscriber) Unsubscribe(c string) int { return s.count() } +// Subscribe to a pattern. Returns the total number of (p)subscriptions after +// subscribing. func (s *Subscriber) Psubscribe(pat string) int { s.mu.Lock() defer s.mu.Unlock() @@ -69,6 +78,8 @@ func (s *Subscriber) Psubscribe(pat string) int { return s.count() } +// Unsubscribe a pattern. Returns the total number of (p)subscriptions after +// unsubscribing. func (s *Subscriber) Punsubscribe(pat string) int { s.mu.Lock() defer s.mu.Unlock() @@ -77,7 +88,7 @@ func (s *Subscriber) Punsubscribe(pat string) int { return s.count() } -// List all subscribed channels +// List all subscribed channels, in alphabetical order func (s *Subscriber) Channels() []string { s.mu.Lock() defer s.mu.Unlock() @@ -90,7 +101,7 @@ func (s *Subscriber) Channels() []string { return cs } -// List all subscribed patters +// List all subscribed patterns, in alphabetical order func (s *Subscriber) Patterns() []string { s.mu.Lock() defer s.mu.Unlock() @@ -132,7 +143,7 @@ pats: return found } -// reads the subscriptions +// The channel to read messages for this subscriber func (s *Subscriber) Messages() <-chan PubsubMessage { return s.publish } diff --git a/test_test.go b/test_test.go index 7868351c..a20c941b 100644 --- a/test_test.go +++ b/test_test.go @@ -1,9 +1,7 @@ package miniredis import ( - "fmt" "reflect" - "strings" "testing" ) @@ -43,22 +41,3 @@ func mustFail(tb testing.TB, err error, want string) { tb.Errorf("have %q, want %q", have, want) } } - -func oneOf(tb testing.TB, exps []interface{}, act interface{}) bool { - tb.Helper() - - for _, exp := range exps { - if reflect.DeepEqual(exp, act) { - return true - } - } - - expPP := make([]string, len(exps)) - for i, exp := range exps { - expPP[i] = fmt.Sprintf("%#v", exp) - } - - tb.Errorf("expected one of: %s got: %#v", strings.Join(expPP, ", "), act) - - return false -}