From 070afdc16d7f94a4827e00f6db2482a2905acd08 Mon Sep 17 00:00:00 2001 From: Kar Date: Thu, 29 Aug 2024 04:26:18 +0200 Subject: [PATCH] feat: syncmap for channels (#71) * feat: syncmap for channels * fix syncmap.Map#All * fix tests --- command/command.go | 3 +- command/echo.go | 2 +- config.go | 2 +- privmsg.go | 2 +- robot.go | 21 +++---- syncmap/syncmap.go | 66 +++++++++++++++++++++ syncmap/syncmap_test.go | 128 ++++++++++++++++++++++++++++++++++++++++ tmi.go | 8 +-- 8 files changed, 214 insertions(+), 18 deletions(-) create mode 100644 syncmap/syncmap.go create mode 100644 syncmap/syncmap_test.go diff --git a/command/command.go b/command/command.go index ec9e745..f551e92 100644 --- a/command/command.go +++ b/command/command.go @@ -9,13 +9,14 @@ import ( "github.com/zephyrtronium/robot/message" "github.com/zephyrtronium/robot/privacy" "github.com/zephyrtronium/robot/spoken" + "github.com/zephyrtronium/robot/syncmap" "github.com/zephyrtronium/robot/userhash" ) // Robot is the bot state as is visible to commands. type Robot struct { Log *slog.Logger - Channels map[string]*channel.Channel // TODO(zeph): syncmap[string]channel.Channel + Channels *syncmap.Map[string, *channel.Channel] Brain brain.Brain Privacy *privacy.List Spoken *spoken.History diff --git a/command/echo.go b/command/echo.go index 268dc0a..d73e6a8 100644 --- a/command/echo.go +++ b/command/echo.go @@ -10,7 +10,7 @@ import ( // - msg: Message to send. func EchoIn(ctx context.Context, robo *Robot, call *Invocation) { t := call.Args["in"] - ch := robo.Channels[t] + ch, _ := robo.Channels.Load(t) if ch == nil { robo.Log.WarnContext(ctx, "echo into unknown channel", slog.String("target", t)) return diff --git a/config.go b/config.go index 79716be..120b908 100644 --- a/config.go +++ b/config.go @@ -307,7 +307,7 @@ func (robo *Robot) SetTwitchChannels(ctx context.Context, global Global, channel msg := message.Format(reply, v.Name, "%s", text) robo.sendTMI(ctx, robo.tmi.send, msg) } - robo.channels[p] = v + robo.channels.Store(p, v) } } return nil diff --git a/privmsg.go b/privmsg.go index d0e8963..4964211 100644 --- a/privmsg.go +++ b/privmsg.go @@ -23,7 +23,7 @@ import ( // tmiMessage processes a PRIVMSG from TMI. func (robo *Robot) tmiMessage(ctx context.Context, group *errgroup.Group, send chan<- *tmi.Message, msg *tmi.Message) { - ch := robo.channels[msg.To()] + ch, _ := robo.channels.Load(msg.To()) if ch == nil { // TMI gives a WHISPER for a direct message, so this is a message to a // channel that isn't configured. Ignore it. diff --git a/robot.go b/robot.go index 894d4dd..c1297d8 100644 --- a/robot.go +++ b/robot.go @@ -18,6 +18,7 @@ import ( "github.com/zephyrtronium/robot/channel" "github.com/zephyrtronium/robot/privacy" "github.com/zephyrtronium/robot/spoken" + "github.com/zephyrtronium/robot/syncmap" "github.com/zephyrtronium/robot/twitch" ) @@ -30,7 +31,7 @@ type Robot struct { // spoken is the history of generated messages. spoken *spoken.History // channels are the channels. - channels map[string]*channel.Channel // TODO(zeph): syncmap[string]channel.Channel + channels *syncmap.Map[string, *channel.Channel] // works is the worker queue. works chan chan func(context.Context) // secrets are the bot's keys. @@ -71,7 +72,7 @@ type client[Send, Receive any] struct { // to initialize the robot. func New(poolSize int) *Robot { return &Robot{ - channels: make(map[string]*channel.Channel), + channels: syncmap.New[string, *channel.Channel](), works: make(chan chan func(context.Context), poolSize), } } @@ -154,18 +155,18 @@ func (robo *Robot) twitchValidateLoop(ctx context.Context) error { } } -func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel.Channel) error { +func (robo *Robot) streamsLoop(ctx context.Context, channels *syncmap.Map[string, *channel.Channel]) error { // TODO(zeph): one day we should switch to eventsub // TODO(zeph): remove anything learned since the last check when offline tok, err := robo.tmi.tokens.Token(ctx) if err != nil { return err } - streams := make([]twitch.Stream, 0, len(channels)) - m := make(map[string]bool, len(channels)) + streams := make([]twitch.Stream, 0, channels.Len()) + m := make(map[string]bool, channels.Len()) // Run once at the start so we start learning in online streams immediately. streams = streams[:0] - for _, ch := range channels { + for _, ch := range channels.All() { n := strings.ToLower(strings.TrimPrefix(ch.Name, "#")) streams = append(streams, twitch.Stream{UserLogin: n}) } @@ -188,7 +189,7 @@ func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel m[n] = true } // Now loop all streams. - for _, ch := range channels { + for _, ch := range channels.All() { n := strings.ToLower(strings.TrimPrefix(ch.Name, "#")) ch.Enabled.Store(m[n]) } @@ -218,7 +219,7 @@ func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel case <-ctx.Done(): return ctx.Err() case <-tick.C: - for _, ch := range channels { + for _, ch := range channels.All() { n := strings.TrimPrefix(ch.Name, "#") streams = append(streams, twitch.Stream{UserLogin: n}) } @@ -241,7 +242,7 @@ func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel m[n] = true } // Now loop all streams. - for _, ch := range channels { + for _, ch := range channels.All() { n := strings.ToLower(strings.TrimPrefix(ch.Name, "#")) ch.Enabled.Store(m[n]) } @@ -255,7 +256,7 @@ func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel default: slog.ErrorContext(ctx, "failed to query online broadcasters", slog.Any("streams", streams), slog.Any("err", err)) // Set all streams as offline. - for _, ch := range channels { + for _, ch := range channels.All() { ch.Enabled.Store(false) } } diff --git a/syncmap/syncmap.go b/syncmap/syncmap.go new file mode 100644 index 0000000..e626adc --- /dev/null +++ b/syncmap/syncmap.go @@ -0,0 +1,66 @@ +package syncmap + +import ( + "iter" + "sync" +) + +// Map is a regular map but synchronized with a mutex. +type Map[K comparable, V any] struct { + mu sync.Mutex + m map[K]V +} + +// New returns a new syncmap. +func New[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ + m: make(map[K]V), + } +} + +// Load returns the value for a key. +func (m *Map[K, V]) Load(key K) (V, bool) { + m.mu.Lock() + defer m.mu.Unlock() + v, ok := m.m[key] + return v, ok +} + +// Store sets the value for a key. +func (m *Map[K, V]) Store(key K, value V) { + m.mu.Lock() + defer m.mu.Unlock() + m.m[key] = value +} + +// Delete deletes a key. +func (m *Map[K, V]) Delete(key K) { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.m, key) +} + +// Len returns the number of elements in the map. +func (m *Map[K, V]) Len() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.m) +} + +// All iterates over all elements in the map +func (m *Map[K, V]) All() iter.Seq2[K, V] { + return func(f func(K, V) bool) { + m.mu.Lock() + for k, v := range m.m { + m.mu.Unlock() + if !f(k, v) { + m.mu.Lock() + break + } + + m.mu.Lock() + } + + m.mu.Unlock() + } +} diff --git a/syncmap/syncmap_test.go b/syncmap/syncmap_test.go new file mode 100644 index 0000000..a13e05b --- /dev/null +++ b/syncmap/syncmap_test.go @@ -0,0 +1,128 @@ +package syncmap + +import ( + "fmt" + "math/rand/v2" + "sync" + "testing" + "testing/quick" + "time" +) + +// TestMap_All requires the race detector to be useful +func TestMap_All(t *testing.T) { + m := New[string, int]() + + // Test empty map + count := 0 + for range m.All() { + count++ + } + if count != 0 { + t.Errorf("Expected 0 elements in empty map, got %d", count) + } + + // Test with elements + testData := map[string]int{ + "one": 1, + "two": 2, + "three": 3, + } + + for k, v := range testData { + m.Store(k, v) + } + + foundItems := make(map[string]int) + for k, v := range m.All() { + foundItems[k] = v + } + + if len(foundItems) != len(testData) { + t.Errorf("Expected %d elements, got %d", len(testData), len(foundItems)) + } + + for k, v := range testData { + if foundV, ok := foundItems[k]; !ok || foundV != v { + t.Errorf("Missing or incorrect value for key %s: expected %d, got %d", k, v, foundV) + } + } + + // Test early termination + count = 0 + m.All()(func(k string, v int) bool { + count++ + return count < 2 + }) + if count != 2 { + t.Errorf("Expected early termination after 2 elements, got %d", count) + } + + // Test concurrent modification + var editwg sync.WaitGroup + var done = make(chan struct{}) + + editwg.Add(1) + go func() { + defer editwg.Done() + start := time.Now() + for time.Since(start) < time.Second { + switch rand.IntN(3) { + case 0: + m.Store(fmt.Sprintf("key%d", rand.IntN(100)), rand.IntN(1000)) + case 1: + m.Delete(fmt.Sprintf("key%d", rand.IntN(100))) + case 2: + m.Store(fmt.Sprintf("key%d", rand.IntN(100)), rand.IntN(1000)) + m.Delete(fmt.Sprintf("key%d", rand.IntN(100))) + } + } + }() + + go func() { + for { + for k, v := range m.All() { + foundItems[k] = v // sink + } + + select { + case <-done: + return + default: + } + } + }() + + editwg.Wait() + close(done) +} + +func TestMap_All_Quick(t *testing.T) { + f := func(entries map[string]int) bool { + m := New[string, int]() + for k, v := range entries { + m.Store(k, v) + } + + seen := make(map[string]int) + for k, v := range m.All() { + seen[k] = v + } + + if len(seen) != len(entries) { + return false + } + + for k, v := range entries { + if seen[k] != v { + return false + } + } + + return true + } + + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } +} diff --git a/tmi.go b/tmi.go index 4006c07..e155d7f 100644 --- a/tmi.go +++ b/tmi.go @@ -48,8 +48,8 @@ func (robo *Robot) tmiLoop(ctx context.Context, group *errgroup.Group, send chan } func (robo *Robot) joinTwitch(ctx context.Context, send chan<- *tmi.Message) { - ls := make([]string, 0, len(robo.channels)) - for _, ch := range robo.channels { + ls := make([]string, 0, robo.channels.Len()) + for _, ch := range robo.channels.All() { ls = append(ls, ch.Name) } burst := 20 @@ -79,7 +79,7 @@ func (robo *Robot) clearchat(ctx context.Context, group *errgroup.Group, msg *tm if len(msg.Params) == 0 { return } - ch := robo.channels[msg.To()] + ch, _ := robo.channels.Load(msg.To()) if ch == nil { return } @@ -144,7 +144,7 @@ func (robo *Robot) clearmsg(ctx context.Context, group *errgroup.Group, msg *tmi if len(msg.Params) == 0 { return } - ch := robo.channels[msg.To()] + ch, _ := robo.channels.Load(msg.To()) if ch == nil { return }