diff --git a/cmd_pubsub.go b/cmd_pubsub.go index 41c94413..2fbeb7af 100644 --- a/cmd_pubsub.go +++ b/cmd_pubsub.go @@ -4,7 +4,6 @@ package miniredis import ( "fmt" - "regexp" "strings" "github.com/alicebob/miniredis/v2/server" @@ -35,11 +34,11 @@ func (m *Miniredis) cmdSubscribe(c *server.Peer, cmd string, args []string) { sub := m.subscribedState(c) for _, channel := range args { n := sub.Subscribe(channel) - c.Block(func(c *server.Writer) { - c.WriteLen(3) - c.WriteBulk("subscribe") - c.WriteBulk(channel) - c.WriteInt(n) + c.Block(func(w *server.Writer) { + w.WriteLen(3) + w.WriteBulk("subscribe") + w.WriteBulk(channel) + w.WriteInt(n) }) } }) @@ -63,11 +62,11 @@ func (m *Miniredis) cmdUnsubscribe(c *server.Peer, cmd string, args []string) { // there is no de-duplication for _, channel := range channels { n := sub.Unsubscribe(channel) - c.Block(func(c *server.Writer) { - c.WriteLen(3) - c.WriteBulk("unsubscribe") - c.WriteBulk(channel) - c.WriteInt(n) + c.Block(func(w *server.Writer) { + w.WriteLen(3) + w.WriteBulk("unsubscribe") + w.WriteBulk(channel) + w.WriteInt(n) }) } @@ -92,116 +91,16 @@ func (m *Miniredis) cmdPsubscribe(c *server.Peer, cmd string, args []string) { sub := m.subscribedState(c) for _, pat := range args { n := sub.Psubscribe(pat) - c.Block(func(c *server.Writer) { - c.WriteLen(3) - c.WriteBulk("psubscribe") - c.WriteBulk(pat) - c.WriteInt(n) + c.Block(func(w *server.Writer) { + w.WriteLen(3) + w.WriteBulk("psubscribe") + w.WriteBulk(pat) + w.WriteInt(n) }) } }) } -func compileChannelPattern(pattern string) *regexp.Regexp { - const readingLiteral uint8 = 0 - const afterEscape uint8 = 1 - const inClass uint8 = 2 - - rgx := []rune{'\\', 'A'} - state := readingLiteral - literals := []rune{} - klass := map[rune]struct{}{} - - for _, c := range pattern { - switch state { - case readingLiteral: - switch c { - case '\\': - state = afterEscape - case '?': - rgx = append(rgx, append([]rune(regexp.QuoteMeta(string(literals))), '.')...) - literals = []rune{} - case '*': - rgx = append(rgx, append([]rune(regexp.QuoteMeta(string(literals))), '.', '*')...) - literals = []rune{} - case '[': - rgx = append(rgx, []rune(regexp.QuoteMeta(string(literals)))...) - literals = []rune{} - state = inClass - default: - literals = append(literals, c) - } - case afterEscape: - literals = append(literals, c) - state = readingLiteral - case inClass: - if c == ']' { - expr := []rune{'['} - - if _, hasDash := klass['-']; hasDash { - delete(klass, '-') - expr = append(expr, '-') - } - - flatClass := make([]rune, len(klass)) - i := 0 - - for c := range klass { - flatClass[i] = c - i++ - } - - klass = map[rune]struct{}{} - expr = append(append(expr, []rune(regexp.QuoteMeta(string(flatClass)))...), ']') - - if len(expr) < 3 { - rgx = append(rgx, 'x', '\\', 'b', 'y') - } else { - rgx = append(rgx, expr...) - } - - state = readingLiteral - } else { - klass[c] = struct{}{} - } - } - } - - switch state { - case afterEscape: - rgx = append(rgx, '\\', '\\') - case inClass: - if len(klass) < 0 { - rgx = append(rgx, '\\', '[') - } else { - expr := []rune{'['} - - if _, hasDash := klass['-']; hasDash { - delete(klass, '-') - expr = append(expr, '-') - } - - flatClass := make([]rune, len(klass)) - i := 0 - - for c := range klass { - flatClass[i] = c - i++ - } - - expr = append(append(expr, []rune(regexp.QuoteMeta(string(flatClass)))...), ']') - - if len(expr) < 3 { - rgx = append(rgx, 'x', '\\', 'b', 'y') - } else { - rgx = append(rgx, expr...) - } - } - } - - return regexp.MustCompile(string(append(rgx, '\\', 'z'))) -} - // PUNSUBSCRIBE func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) { if !m.handleAuth(c) { @@ -220,11 +119,11 @@ func (m *Miniredis) cmdPunsubscribe(c *server.Peer, cmd string, args []string) { // there is no de-duplication for _, pat := range patterns { n := sub.Punsubscribe(pat) - c.Block(func(c *server.Writer) { - c.WriteLen(3) - c.WriteBulk("punsubscribe") - c.WriteBulk(pat) - c.WriteInt(n) + c.Block(func(w *server.Writer) { + w.WriteLen(3) + w.WriteBulk("punsubscribe") + w.WriteBulk(pat) + w.WriteInt(n) }) } diff --git a/cmd_pubsub_test.go b/cmd_pubsub_test.go index b01026d9..71b06d5d 100644 --- a/cmd_pubsub_test.go +++ b/cmd_pubsub_test.go @@ -147,7 +147,7 @@ func TestPsubscribe(t *testing.T) { s, err := redis.Strings(c.Receive()) ok(t, err) - equals(t, []string{"message", "event4b", "hello 4b!"}, s) + equals(t, []string{"pmessage", "event4[abc]", "event4b", "hello 4b!"}, s) } } @@ -381,440 +381,3 @@ func TestPubSubBadArgs(t *testing.T) { done() } } - -func TestPubSubInteraction(t *testing.T) { - s, err := Run() - ok(t, err) - defer s.Close() - - ch := make(chan struct{}, 8) - tasks := [5]func(){} - directTasks := [4]func(){} - - for i, tester := range [5]func(t *testing.T, s *Miniredis, c redis.Conn, chCtl chan struct{}){ - testPubSubInteractionSub1, - testPubSubInteractionSub2, - testPubSubInteractionPsub1, - testPubSubInteractionPsub2, - testPubSubInteractionPub, - } { - tasks[i] = runActualRedisClientForPubSub(t, s, ch, tester) - } - - for i, tester := range [4]func(t *testing.T, s *Miniredis, chCtl chan struct{}){ - testPubSubInteractionDirectSub1, - testPubSubInteractionDirectSub2, - testPubSubInteractionDirectPsub1, - testPubSubInteractionDirectPsub2, - } { - directTasks[i] = runDirectRedisClientForPubSub(t, s, ch, tester) - } - - for _, task := range tasks { - task() - } - - for _, task := range directTasks { - task() - } -} - -func testPubSubInteractionSub1(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { - assertCorrectSubscriptionsCounts( - t, - []int64{1, 2, 3, 4}, - runCmdDuringPubSub(t, c, 3, "SUBSCRIBE", "event1", "event2", "event3", "event4"), - ) - - ch <- struct{}{} - receiveMessagesDuringPubSub(t, c, '1', '2', '3', '4') - - assertCorrectSubscriptionsCounts( - t, - []int64{3, 2}, - runCmdDuringPubSub(t, c, 1, "UNSUBSCRIBE", "event2", "event3"), - ) - - ch <- struct{}{} - receiveMessagesDuringPubSub(t, c, '1', '4') -} - -func testPubSubInteractionSub2(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { - assertCorrectSubscriptionsCounts( - t, - []int64{1, 2, 3, 4}, - runCmdDuringPubSub(t, c, 3, "SUBSCRIBE", "event3", "event4", "event5", "event6"), - ) - - ch <- struct{}{} - receiveMessagesDuringPubSub(t, c, '3', '4', '5', '6') - - assertCorrectSubscriptionsCounts( - t, - []int64{3, 2}, - runCmdDuringPubSub(t, c, 1, "UNSUBSCRIBE", "event4", "event5"), - ) - - ch <- struct{}{} - receiveMessagesDuringPubSub(t, c, '3', '6') -} - -func testPubSubInteractionDirectSub1(t *testing.T, s *Miniredis, ch chan struct{}) { - sub := s.NewSubscriber() - defer sub.Close() - - sub.Subscribe("event1") - sub.Subscribe("event3") - sub.Subscribe("event4") - sub.Subscribe("event6") - - ch <- struct{}{} - receiveMessagesDirectlyDuringPubSub(t, sub, '1', '3', '4', '6') - - sub.Unsubscribe("event1") - sub.Unsubscribe("event4") - - ch <- struct{}{} - receiveMessagesDirectlyDuringPubSub(t, sub, '3', '6') -} - -func testPubSubInteractionDirectSub2(t *testing.T, s *Miniredis, ch chan struct{}) { - sub := s.NewSubscriber() - defer sub.Close() - - sub.Subscribe("event2") - sub.Subscribe("event3") - sub.Subscribe("event4") - sub.Subscribe("event5") - - ch <- struct{}{} - receiveMessagesDirectlyDuringPubSub(t, sub, '2', '3', '4', '5') - - sub.Unsubscribe("event3") - sub.Unsubscribe("event5") - - ch <- struct{}{} - receiveMessagesDirectlyDuringPubSub(t, sub, '2', '4') -} - -func testPubSubInteractionPsub1(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { - assertCorrectSubscriptionsCounts( - t, - []int64{1, 2, 3, 4}, - runCmdDuringPubSub(t, c, 3, "PSUBSCRIBE", "event[ab1]", "event[cd]", "event[ef3]", "event[gh]"), - ) - - ch <- struct{}{} - receiveMessagesDuringPubSub(t, c, '1', '3', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h') - - assertCorrectSubscriptionsCounts( - t, - []int64{3, 2}, - runCmdDuringPubSub(t, c, 1, "PUNSUBSCRIBE", "event[cd]", "event[ef3]"), - ) - - ch <- struct{}{} - receiveMessagesDuringPubSub(t, c, '1', 'a', 'b', 'g', 'h') -} - -func testPubSubInteractionPsub2(t *testing.T, _ *Miniredis, c redis.Conn, ch chan struct{}) { - assertCorrectSubscriptionsCounts( - t, - []int64{1, 2, 3, 4}, - runCmdDuringPubSub(t, c, 3, "PSUBSCRIBE", "event[ef]", "event[gh4]", "event[ij]", "event[kl6]"), - ) - - ch <- struct{}{} - receiveMessagesDuringPubSub(t, c, '4', '6', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l') - - assertCorrectSubscriptionsCounts( - t, - []int64{3, 2}, - runCmdDuringPubSub(t, c, 1, "PUNSUBSCRIBE", "event[gh4]", "event[ij]"), - ) - - ch <- struct{}{} - receiveMessagesDuringPubSub(t, c, '6', 'e', 'f', 'k', 'l') -} - -func testPubSubInteractionDirectPsub1(t *testing.T, s *Miniredis, ch chan struct{}) { - sub := s.NewSubscriber() - defer sub.Close() - - sub.Psubscribe(`event[ab1]`) - sub.Psubscribe(`event[ef3]`) - sub.Psubscribe(`event[gh]`) - sub.Psubscribe(`event[kl6]`) - - ch <- struct{}{} - receiveMessagesDirectlyDuringPubSub(t, sub, '1', '3', '6', 'a', 'b', 'e', 'f', 'g', 'h', 'k', 'l') - - sub.Punsubscribe(`event[ab1]`) - sub.Punsubscribe(`event[gh]`) - - ch <- struct{}{} - receiveMessagesDirectlyDuringPubSub(t, sub, '3', '6', 'e', 'f', 'k', 'l') -} - -func testPubSubInteractionDirectPsub2(t *testing.T, s *Miniredis, ch chan struct{}) { - sub := s.NewSubscriber() - defer sub.Close() - - sub.Psubscribe(`event[cd]`) - sub.Psubscribe(`event[ef]`) - sub.Psubscribe(`event[gh4]`) - sub.Psubscribe(`event[ij]`) - - ch <- struct{}{} - receiveMessagesDirectlyDuringPubSub(t, sub, '4', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j') - - sub.Punsubscribe(`event[ef]`) - sub.Punsubscribe(`event[ij]`) - - ch <- struct{}{} - receiveMessagesDirectlyDuringPubSub(t, sub, '4', 'c', 'd', 'g', 'h') -} - -func testPubSubInteractionPub(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { - testPubSubInteractionPubStage1(t, s, c, ch) - testPubSubInteractionPubStage2(t, s, c, ch) -} - -func testPubSubInteractionPubStage1(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { - for i := uint8(0); i < 8; i++ { - <-ch - } - - for _, pattern := range []string{ - "", - "event?", - } { - assertActiveChannelsDuringPubSub(t, s, c, pattern, []string{ - "event1", "event2", "event3", "event4", "event5", "event6", - }) - } - - assertActiveChannelsDuringPubSub(t, s, c, "*[123]", []string{ - "event1", "event2", "event3", - }) - - assertNumSubDuringPubSub(t, s, c, map[string]int{ - "event1": 2, "event2": 2, "event3": 4, "event4": 4, "event5": 2, "event6": 2, - "event[ab1]": 0, "event[cd]": 0, "event[ef3]": 0, "event[gh]": 0, "event[ij]": 0, "event[kl6]": 0, - }) - - assertNumPatDuringPubSub(t, s, c, 16) - - for _, message := range [18]struct { - channelSuffix rune - subscribers uint8 - }{ - {'1', 4}, {'2', 2}, {'3', 6}, {'4', 6}, {'5', 2}, {'6', 4}, - {'a', 2}, {'b', 2}, {'c', 2}, {'d', 2}, {'e', 4}, {'f', 4}, - {'g', 4}, {'h', 4}, {'i', 2}, {'j', 2}, {'k', 2}, {'l', 2}, - } { - suffix := string([]rune{message.channelSuffix}) - replies := runCmdDuringPubSub(t, c, 0, "PUBLISH", "event"+suffix, "message"+suffix) - equals(t, []interface{}{int64(message.subscribers)}, replies) - } -} - -func testPubSubInteractionPubStage2(t *testing.T, s *Miniredis, c redis.Conn, ch chan struct{}) { - for i := uint8(0); i < 8; i++ { - <-ch - } - - for _, pattern := range []string{ - "", - "event?", - } { - assertActiveChannelsDuringPubSub(t, s, c, pattern, []string{ - "event1", "event2", "event3", "event4", "event6", - }) - } - - assertActiveChannelsDuringPubSub(t, s, c, "*[123]", []string{"event1", "event2", "event3"}) - - assertNumSubDuringPubSub(t, s, c, map[string]int{ - "event1": 1, "event2": 1, "event3": 2, "event4": 2, "event5": 0, "event6": 2, - "event[ab1]": 0, "event[cd]": 0, "event[ef3]": 0, "event[gh]": 0, "event[ij]": 0, "event[kl6]": 0, - }) - - assertNumPatDuringPubSub(t, s, c, 8) - - for _, message := range [18]struct { - channelSuffix rune - subscribers uint8 - }{ - {'1', 2}, {'2', 1}, {'3', 3}, {'4', 3}, {'5', 0}, {'6', 4}, - {'a', 1}, {'b', 1}, {'c', 1}, {'d', 1}, {'e', 2}, {'f', 2}, - {'g', 2}, {'h', 2}, {'i', 0}, {'j', 0}, {'k', 2}, {'l', 2}, - } { - suffix := string([]rune{message.channelSuffix}) - equals(t, int(message.subscribers), s.Publish("event"+suffix, "message"+suffix)) - } -} - -func runActualRedisClientForPubSub(t *testing.T, s *Miniredis, chCtl chan struct{}, tester func(t *testing.T, s *Miniredis, c redis.Conn, chCtl chan struct{})) (wait func()) { - t.Helper() - - c, err := redis.Dial("tcp", s.Addr()) - ok(t, err) - - ch := make(chan struct{}) - - go func() { - t.Helper() - - tester(t, s, c, chCtl) - c.Close() - close(ch) - }() - - return func() { <-ch } -} - -func runDirectRedisClientForPubSub(t *testing.T, s *Miniredis, chCtl chan struct{}, tester func(t *testing.T, s *Miniredis, chCtl chan struct{})) (wait func()) { - t.Helper() - - ch := make(chan struct{}) - - go func() { - t.Helper() - - tester(t, s, chCtl) - close(ch) - }() - - return func() { <-ch } -} - -func runCmdDuringPubSub(t *testing.T, c redis.Conn, followUpMessages uint8, command string, args ...interface{}) (replies []interface{}) { - t.Helper() - - replies = make([]interface{}, followUpMessages+1) - - reply, err := c.Do(command, args...) - ok(t, err) - - replies[0] = reply - i := 1 - - for ; followUpMessages > 0; followUpMessages-- { - reply, err := c.Receive() - ok(t, err) - - replies[i] = reply - i++ - } - - return -} - -func assertCorrectSubscriptionsCounts(t *testing.T, subscriptionsCounts []int64, replies []interface{}) { - t.Helper() - - for i, subscriptionsCount := range subscriptionsCounts { - if arrayReply, isArrayReply := replies[i].([]interface{}); isArrayReply && len(arrayReply) > 2 { - equals(t, subscriptionsCount, arrayReply[2]) - } - } -} - -func receiveMessagesDuringPubSub(t *testing.T, c redis.Conn, suffixes ...rune) { - t.Helper() - - for _, suffix := range suffixes { - msg, err := c.Receive() - ok(t, err) - - suff := string([]rune{suffix}) - equals(t, []interface{}{[]byte("message"), []byte("event" + suff), []byte("message" + suff)}, msg) - } -} - -func receiveMessagesDirectlyDuringPubSub(t *testing.T, sub *Subscriber, suffixes ...rune) { - t.Helper() - - for _, suffix := range suffixes { - suff := string([]rune{suffix}) - equals(t, PubsubMessage{"event" + suff, "message" + suff}, <-sub.Messages()) - } -} - -func assertActiveChannelsDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, pattern string, channels []string) { - var args []interface{} - if pattern == "" { - args = []interface{}{"CHANNELS"} - } else { - args = []interface{}{"CHANNELS", pattern} - } - - actual, err := redis.Strings(c.Do("PUBSUB", args...)) - ok(t, err) - - equals(t, channels, actual) - - equals(t, channels, s.PubSubChannels(pattern)) -} - -func assertNumSubDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, channels map[string]int) { - t.Helper() - - args := make([]interface{}, 1+len(channels)) - args[0] = "NUMSUB" - i := 1 - - flatChannels := make([]string, len(channels)) - j := 0 - - for channel := range channels { - args[i] = channel - i++ - - flatChannels[j] = channel - j++ - } - - a, err := redis.Values(c.Do("PUBSUB", args...)) - ok(t, err) - equals(t, len(channels)*2, len(a)) - - actualChannels := make(map[string]int, len(a)) - - var currentChannel string - currentState := uint8(0) - - for _, item := range a { - if currentState&uint8(1) == 0 { - if channelString, channelIsString := item.([]byte); channelIsString { - currentChannel = string(channelString) - currentState |= 2 - } else { - currentState &= ^uint8(2) - } - - currentState |= 1 - } else { - if subsInt, subsIsInt := item.(int64); subsIsInt && currentState&uint8(2) != 0 { - actualChannels[currentChannel] = int(subsInt) - } - - currentState &= ^uint8(1) - } - } - - equals(t, channels, actualChannels) - - equals(t, channels, s.PubSubNumSub(flatChannels...)) -} - -func assertNumPatDuringPubSub(t *testing.T, s *Miniredis, c redis.Conn, numPat int) { - t.Helper() - - a, err := redis.Int(c.Do("PUBSUB", "NUMPAT")) - ok(t, err) - equals(t, numPat, a) - - equals(t, numPat, s.PubSubNumPat()) -} diff --git a/integration/pubsub_test.go b/integration/pubsub_test.go index 909cd973..30fcd0fa 100644 --- a/integration/pubsub_test.go +++ b/integration/pubsub_test.go @@ -27,43 +27,77 @@ func TestSubscribe(t *testing.T) { ) } -func TestPSubscribe(t *testing.T) { - testCommands(t, - fail("PSUBSCRIBE"), +func TestPsubscribe(t *testing.T) { + testClients2(t, func(c1, c2 chan<- command) { + c1 <- fail("PSUBSCRIBE") - succ("PSUBSCRIBE", "foo"), - succ("PUNSUBSCRIBE"), + c1 <- succ("PSUBSCRIBE", "foo") + c2 <- succ("PUBLISH", "foo", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE") - succ("PSUBSCRIBE", "foo"), - succ("PUNSUBSCRIBE", "foo"), + c1 <- succ("PSUBSCRIBE", "foo") + c2 <- succ("PUBLISH", "foo", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "foo") - succ("PSUBSCRIBE", "foo", "bar"), - succ("PUNSUBSCRIBE", "foo", "bar"), + c1 <- succ("PSUBSCRIBE", "foo", "bar") + c2 <- succ("PUBLISH", "foo", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "foo", "bar") - succ("PSUBSCRIBE", "f?o"), - succ("PUNSUBSCRIBE", "f?o"), + c1 <- succ("PSUBSCRIBE", "f?o") + c2 <- succ("PUBLISH", "foo", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "f?o") - succ("PSUBSCRIBE", "f*o"), - succ("PUNSUBSCRIBE", "f*o"), + c1 <- succ("PSUBSCRIBE", "f*o") + c2 <- succ("PUBLISH", "foo", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "f*o") - succ("PSUBSCRIBE", "f[oO]o"), - succ("PUNSUBSCRIBE", "f[oO]o"), + c1 <- succ("PSUBSCRIBE", "f[oO]o") + c2 <- succ("PUBLISH", "foo", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "f[oO]o") - succ("PSUBSCRIBE", "f\\?o"), - succ("PUNSUBSCRIBE", "f\\?o"), + c1 <- succ("PSUBSCRIBE", "f\\?o") + c2 <- succ("PUBLISH", "f?o", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "f\\?o") - succ("PSUBSCRIBE", "f\\*o"), - succ("PUNSUBSCRIBE", "f\\*o"), + c1 <- succ("PSUBSCRIBE", "f\\*o") + c2 <- succ("PUBLISH", "f*o", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "f\\*o") - succ("PSUBSCRIBE", "f\\[oO]o"), - succ("PUNSUBSCRIBE", "f\\[oO]o"), + c1 <- succ("PSUBSCRIBE", "f\\[oO]o") + c2 <- succ("PUBLISH", "f[oO]o", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "f\\[oO]o") - succ("PSUBSCRIBE", "f\\\\oo"), - succ("PUNSUBSCRIBE", "f\\\\oo"), + c1 <- succ("PSUBSCRIBE", "f\\\\oo") + c2 <- succ("PUBLISH", "f\\\\oo", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", "f\\\\oo") - succ("PSUBSCRIBE", -1), - succ("PUNSUBSCRIBE", -1), - ) + c1 <- succ("PSUBSCRIBE", -1) + c2 <- succ("PUBLISH", "foo", "hi") + c1 <- receive() + c1 <- succ("PUNSUBSCRIBE", -1) + }) + + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("PSUBSCRIBE", "news*") + c2 <- succ("PUBLISH", "news", "fire!") + c1 <- receive() + }) + + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("PSUBSCRIBE", "news") // no pattern + c2 <- succ("PUBLISH", "news", "fire!") + c1 <- receive() + }) } func TestPublish(t *testing.T) { @@ -104,27 +138,19 @@ func TestPubSub(t *testing.T) { } func TestPubsubFull(t *testing.T) { - var wg1 sync.WaitGroup - wg1.Add(1) - testMultiCommands(t, - func(r chan<- command, _ *miniredis.Miniredis) { - r <- succ("SUBSCRIBE", "news", "sport") - r <- receive() - wg1.Done() - r <- receive() - r <- receive() - r <- receive() - r <- succ("UNSUBSCRIBE", "news", "sport") - r <- receive() - }, - func(r chan<- command, _ *miniredis.Miniredis) { - wg1.Wait() - r <- succ("PUBLISH", "news", "revolution!") - r <- succ("PUBLISH", "news", "alien invasion!") - r <- succ("PUBLISH", "sport", "lady biked too fast") - r <- succ("PUBLISH", "gossip", "man bites dog") - }, - ) + testClients2(t, func(c1, c2 chan<- command) { + c1 <- succ("SUBSCRIBE", "news", "sport") + c1 <- receive() + c2 <- succ("PUBLISH", "news", "revolution!") + c2 <- succ("PUBLISH", "news", "alien invasion!") + c2 <- succ("PUBLISH", "sport", "lady biked too fast") + c2 <- succ("PUBLISH", "gossip", "man bites dog") + c1 <- receive() + c1 <- receive() + c1 <- receive() + c1 <- succ("UNSUBSCRIBE", "news", "sport") + c1 <- receive() + }) } func TestPubsubMulti(t *testing.T) { diff --git a/keys.go b/keys.go index b7cd98fb..8d9a4199 100644 --- a/keys.go +++ b/keys.go @@ -1,13 +1,13 @@ package miniredis -// Translate the 'KEYS' argument ('foo*', 'f??', &c.) into a regexp. +// Translate the 'KEYS' or 'PSUBSCRIBE' argument ('foo*', 'f??', &c.) into a regexp. import ( "bytes" "regexp" ) -// patternRE compiles a KEYS argument to a regexp. Returns nil if the given +// patternRE compiles a glob to a regexp. Returns nil if the given // pattern will never match anything. // The general strategy is to sandwich all non-meta characters between \Q...\E. func patternRE(k string) *regexp.Regexp { diff --git a/miniredis.go b/miniredis.go index 2926808e..7ea0ffc2 100644 --- a/miniredis.go +++ b/miniredis.go @@ -461,6 +461,7 @@ func (m *Miniredis) subscribedState(c *server.Peer) *Subscriber { ctx.subscriber = sub go monitorPublish(c, sub.publish) + go monitorPpublish(c, sub.ppublish) return sub } diff --git a/pubsub.go b/pubsub.go index 0a9b400d..fb4d9d76 100644 --- a/pubsub.go +++ b/pubsub.go @@ -14,9 +14,16 @@ type PubsubMessage struct { Message string } +type PubsubPmessage struct { + Pattern string + Channel string + Message string +} + // Subscriber has the (p)subscriptions. type Subscriber struct { publish chan PubsubMessage + ppublish chan PubsubPmessage channels map[string]struct{} patterns map[string]*regexp.Regexp mu sync.Mutex @@ -27,6 +34,7 @@ type Subscriber struct { func newSubscriber() *Subscriber { return &Subscriber{ publish: make(chan PubsubMessage), + ppublish: make(chan PubsubPmessage), channels: map[string]struct{}{}, patterns: map[string]*regexp.Regexp{}, } @@ -35,6 +43,7 @@ func newSubscriber() *Subscriber { // Close the listening channel func (s *Subscriber) Close() { close(s.publish) + close(s.ppublish) } // Count the total number of channels and patterns @@ -74,7 +83,7 @@ func (s *Subscriber) Psubscribe(pat string) int { s.mu.Lock() defer s.mu.Unlock() - s.patterns[pat] = compileChannelPattern(pat) + s.patterns[pat] = patternRE(pat) return s.count() } @@ -132,9 +141,9 @@ subs: } pats: - for _, pat := range s.patterns { - if pat.MatchString(c) { - s.publish <- PubsubMessage{c, msg} + for orig, pat := range s.patterns { + if pat != nil && pat.MatchString(c) { + s.ppublish <- PubsubPmessage{orig, c, msg} found++ break pats } @@ -143,11 +152,18 @@ pats: return found } -// The channel to read messages for this subscriber +// The channel to read messages for this subscriber. Only for messages matching +// a SUBSCRIBE. func (s *Subscriber) Messages() <-chan PubsubMessage { return s.publish } +// The channel to read messages for this subscriber. Only for messages matching +// a PSUBSCRIBE. +func (s *Subscriber) Pmessages() <-chan PubsubPmessage { + return s.ppublish +} + // List all pubsub channels. If `pat` isn't empty channels names must match the // pattern. Channels are returned alphabetically. func activeChannels(subs []*Subscriber, pat string) []string { @@ -160,7 +176,7 @@ func activeChannels(subs []*Subscriber, pat string) []string { var cpat *regexp.Regexp if pat != "" { - cpat = compileChannelPattern(pat) + cpat = patternRE(pat) } var cs []string @@ -209,3 +225,16 @@ func monitorPublish(conn *server.Peer, msgs <-chan PubsubMessage) { }) } } + +func monitorPpublish(conn *server.Peer, msgs <-chan PubsubPmessage) { + for msg := range msgs { + conn.Block(func(c *server.Writer) { + c.WriteLen(4) + c.WriteBulk("pmessage") + c.WriteBulk(msg.Pattern) + c.WriteBulk(msg.Channel) + c.WriteBulk(msg.Message) + c.Flush() + }) + } +}