Skip to content

Commit

Permalink
feat: syncmap for channels (#71)
Browse files Browse the repository at this point in the history
* feat: syncmap for channels

* fix syncmap.Map#All

* fix tests
  • Loading branch information
karitham authored Aug 29, 2024
1 parent 2d2d681 commit 070afdc
Show file tree
Hide file tree
Showing 8 changed files with 214 additions and 18 deletions.
3 changes: 2 additions & 1 deletion command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@ import (
"github.com/zephyrtronium/robot/message"
"github.com/zephyrtronium/robot/privacy"
"github.com/zephyrtronium/robot/spoken"
"github.com/zephyrtronium/robot/syncmap"
"github.com/zephyrtronium/robot/userhash"
)

// Robot is the bot state as is visible to commands.
type Robot struct {
Log *slog.Logger
Channels map[string]*channel.Channel // TODO(zeph): syncmap[string]channel.Channel
Channels *syncmap.Map[string, *channel.Channel]
Brain brain.Brain
Privacy *privacy.List
Spoken *spoken.History
Expand Down
2 changes: 1 addition & 1 deletion command/echo.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
// - msg: Message to send.
func EchoIn(ctx context.Context, robo *Robot, call *Invocation) {
t := call.Args["in"]
ch := robo.Channels[t]
ch, _ := robo.Channels.Load(t)
if ch == nil {
robo.Log.WarnContext(ctx, "echo into unknown channel", slog.String("target", t))
return
Expand Down
2 changes: 1 addition & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ func (robo *Robot) SetTwitchChannels(ctx context.Context, global Global, channel
msg := message.Format(reply, v.Name, "%s", text)
robo.sendTMI(ctx, robo.tmi.send, msg)
}
robo.channels[p] = v
robo.channels.Store(p, v)
}
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion privmsg.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

// tmiMessage processes a PRIVMSG from TMI.
func (robo *Robot) tmiMessage(ctx context.Context, group *errgroup.Group, send chan<- *tmi.Message, msg *tmi.Message) {
ch := robo.channels[msg.To()]
ch, _ := robo.channels.Load(msg.To())
if ch == nil {
// TMI gives a WHISPER for a direct message, so this is a message to a
// channel that isn't configured. Ignore it.
Expand Down
21 changes: 11 additions & 10 deletions robot.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/zephyrtronium/robot/channel"
"github.com/zephyrtronium/robot/privacy"
"github.com/zephyrtronium/robot/spoken"
"github.com/zephyrtronium/robot/syncmap"
"github.com/zephyrtronium/robot/twitch"
)

Expand All @@ -30,7 +31,7 @@ type Robot struct {
// spoken is the history of generated messages.
spoken *spoken.History
// channels are the channels.
channels map[string]*channel.Channel // TODO(zeph): syncmap[string]channel.Channel
channels *syncmap.Map[string, *channel.Channel]
// works is the worker queue.
works chan chan func(context.Context)
// secrets are the bot's keys.
Expand Down Expand Up @@ -71,7 +72,7 @@ type client[Send, Receive any] struct {
// to initialize the robot.
func New(poolSize int) *Robot {
return &Robot{
channels: make(map[string]*channel.Channel),
channels: syncmap.New[string, *channel.Channel](),
works: make(chan chan func(context.Context), poolSize),
}
}
Expand Down Expand Up @@ -154,18 +155,18 @@ func (robo *Robot) twitchValidateLoop(ctx context.Context) error {
}
}

func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel.Channel) error {
func (robo *Robot) streamsLoop(ctx context.Context, channels *syncmap.Map[string, *channel.Channel]) error {
// TODO(zeph): one day we should switch to eventsub
// TODO(zeph): remove anything learned since the last check when offline
tok, err := robo.tmi.tokens.Token(ctx)
if err != nil {
return err
}
streams := make([]twitch.Stream, 0, len(channels))
m := make(map[string]bool, len(channels))
streams := make([]twitch.Stream, 0, channels.Len())
m := make(map[string]bool, channels.Len())
// Run once at the start so we start learning in online streams immediately.
streams = streams[:0]
for _, ch := range channels {
for _, ch := range channels.All() {
n := strings.ToLower(strings.TrimPrefix(ch.Name, "#"))
streams = append(streams, twitch.Stream{UserLogin: n})
}
Expand All @@ -188,7 +189,7 @@ func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel
m[n] = true
}
// Now loop all streams.
for _, ch := range channels {
for _, ch := range channels.All() {
n := strings.ToLower(strings.TrimPrefix(ch.Name, "#"))
ch.Enabled.Store(m[n])
}
Expand Down Expand Up @@ -218,7 +219,7 @@ func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel
case <-ctx.Done():
return ctx.Err()
case <-tick.C:
for _, ch := range channels {
for _, ch := range channels.All() {
n := strings.TrimPrefix(ch.Name, "#")
streams = append(streams, twitch.Stream{UserLogin: n})
}
Expand All @@ -241,7 +242,7 @@ func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel
m[n] = true
}
// Now loop all streams.
for _, ch := range channels {
for _, ch := range channels.All() {
n := strings.ToLower(strings.TrimPrefix(ch.Name, "#"))
ch.Enabled.Store(m[n])
}
Expand All @@ -255,7 +256,7 @@ func (robo *Robot) streamsLoop(ctx context.Context, channels map[string]*channel
default:
slog.ErrorContext(ctx, "failed to query online broadcasters", slog.Any("streams", streams), slog.Any("err", err))
// Set all streams as offline.
for _, ch := range channels {
for _, ch := range channels.All() {
ch.Enabled.Store(false)
}
}
Expand Down
66 changes: 66 additions & 0 deletions syncmap/syncmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package syncmap

import (
"iter"
"sync"
)

// Map is a regular map but synchronized with a mutex.
type Map[K comparable, V any] struct {
mu sync.Mutex
m map[K]V
}

// New returns a new syncmap.
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{
m: make(map[K]V),
}
}

// Load returns the value for a key.
func (m *Map[K, V]) Load(key K) (V, bool) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.m[key]
return v, ok
}

// Store sets the value for a key.
func (m *Map[K, V]) Store(key K, value V) {
m.mu.Lock()
defer m.mu.Unlock()
m.m[key] = value
}

// Delete deletes a key.
func (m *Map[K, V]) Delete(key K) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.m, key)
}

// Len returns the number of elements in the map.
func (m *Map[K, V]) Len() int {
m.mu.Lock()
defer m.mu.Unlock()
return len(m.m)
}

// All iterates over all elements in the map
func (m *Map[K, V]) All() iter.Seq2[K, V] {
return func(f func(K, V) bool) {
m.mu.Lock()
for k, v := range m.m {
m.mu.Unlock()
if !f(k, v) {
m.mu.Lock()
break
}

m.mu.Lock()
}

m.mu.Unlock()
}
}
128 changes: 128 additions & 0 deletions syncmap/syncmap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package syncmap

import (
"fmt"
"math/rand/v2"
"sync"
"testing"
"testing/quick"
"time"
)

// TestMap_All requires the race detector to be useful
func TestMap_All(t *testing.T) {
m := New[string, int]()

// Test empty map
count := 0
for range m.All() {
count++
}
if count != 0 {
t.Errorf("Expected 0 elements in empty map, got %d", count)
}

// Test with elements
testData := map[string]int{
"one": 1,
"two": 2,
"three": 3,
}

for k, v := range testData {
m.Store(k, v)
}

foundItems := make(map[string]int)
for k, v := range m.All() {
foundItems[k] = v
}

if len(foundItems) != len(testData) {
t.Errorf("Expected %d elements, got %d", len(testData), len(foundItems))
}

for k, v := range testData {
if foundV, ok := foundItems[k]; !ok || foundV != v {
t.Errorf("Missing or incorrect value for key %s: expected %d, got %d", k, v, foundV)
}
}

// Test early termination
count = 0
m.All()(func(k string, v int) bool {
count++
return count < 2
})
if count != 2 {
t.Errorf("Expected early termination after 2 elements, got %d", count)
}

// Test concurrent modification
var editwg sync.WaitGroup
var done = make(chan struct{})

editwg.Add(1)
go func() {
defer editwg.Done()
start := time.Now()
for time.Since(start) < time.Second {
switch rand.IntN(3) {
case 0:
m.Store(fmt.Sprintf("key%d", rand.IntN(100)), rand.IntN(1000))
case 1:
m.Delete(fmt.Sprintf("key%d", rand.IntN(100)))
case 2:
m.Store(fmt.Sprintf("key%d", rand.IntN(100)), rand.IntN(1000))
m.Delete(fmt.Sprintf("key%d", rand.IntN(100)))
}
}
}()

go func() {
for {
for k, v := range m.All() {
foundItems[k] = v // sink
}

select {
case <-done:
return
default:
}
}
}()

editwg.Wait()
close(done)
}

func TestMap_All_Quick(t *testing.T) {
f := func(entries map[string]int) bool {
m := New[string, int]()
for k, v := range entries {
m.Store(k, v)
}

seen := make(map[string]int)
for k, v := range m.All() {
seen[k] = v
}

if len(seen) != len(entries) {
return false
}

for k, v := range entries {
if seen[k] != v {
return false
}
}

return true
}

if err := quick.Check(f, nil); err != nil {
t.Error(err)
}
}
8 changes: 4 additions & 4 deletions tmi.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func (robo *Robot) tmiLoop(ctx context.Context, group *errgroup.Group, send chan
}

func (robo *Robot) joinTwitch(ctx context.Context, send chan<- *tmi.Message) {
ls := make([]string, 0, len(robo.channels))
for _, ch := range robo.channels {
ls := make([]string, 0, robo.channels.Len())
for _, ch := range robo.channels.All() {
ls = append(ls, ch.Name)
}
burst := 20
Expand Down Expand Up @@ -79,7 +79,7 @@ func (robo *Robot) clearchat(ctx context.Context, group *errgroup.Group, msg *tm
if len(msg.Params) == 0 {
return
}
ch := robo.channels[msg.To()]
ch, _ := robo.channels.Load(msg.To())
if ch == nil {
return
}
Expand Down Expand Up @@ -144,7 +144,7 @@ func (robo *Robot) clearmsg(ctx context.Context, group *errgroup.Group, msg *tmi
if len(msg.Params) == 0 {
return
}
ch := robo.channels[msg.To()]
ch, _ := robo.channels.Load(msg.To())
if ch == nil {
return
}
Expand Down

0 comments on commit 070afdc

Please sign in to comment.