-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathrobot.go
314 lines (300 loc) · 9.52 KB
/
robot.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
package main
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"gitlab.com/zephyrtronium/tmi"
"golang.org/x/sync/errgroup"
"golang.org/x/time/rate"
"github.com/zephyrtronium/robot/auth"
"github.com/zephyrtronium/robot/brain"
"github.com/zephyrtronium/robot/channel"
"github.com/zephyrtronium/robot/metrics"
"github.com/zephyrtronium/robot/pet"
"github.com/zephyrtronium/robot/privacy"
"github.com/zephyrtronium/robot/spoken"
"github.com/zephyrtronium/robot/syncmap"
"github.com/zephyrtronium/robot/twitch"
"github.com/zephyrtronium/robot/userhash"
)
// Robot is the overall configuration for the bot.
//
// New initializes channels, works, hashes, and metrics; every other field is
// left at its zero value by New.
type Robot struct {
	// brain is the robot's brain.Interface implementation (presumably the
	// store messages are learned into and generated from — confirm).
	brain brain.Interface
	// privacy is the robot's privacy list (presumably the users who have
	// opted out of learning — TODO confirm).
	privacy *privacy.List
	// spoken is the history of generated messages.
	spoken *spoken.History
	// channels holds the channels the bot is configured for. streamsLoop sets
	// each channel's Enabled flag from its Twitch live status.
	// NOTE(review): the key format is not visible here — confirm against the
	// code that populates the map.
	channels *syncmap.Map[string, *channel.Channel]
	// works is the worker queue. Each element is itself a channel of work
	// functions; New buffers the queue to poolSize entries.
	works chan chan func(context.Context)
	// hashes returns a fresh userhash.Hasher on each call; New makes every
	// hasher use the same usersKey.
	hashes func() userhash.Hasher
	// owner is the username of the owner.
	owner string
	// ownerContact describes contact information for the owner.
	ownerContact string
	// tmi contains the bot's Twitch chat client settings, including its
	// OAuth2 token source. It may be nil if there is no Twitch configuration,
	// in which case Run does not start any Twitch loops.
	tmi *client[*tmi.Message, *tmi.Message]
	// twitch is the Twitch API client. Its HTTP field is used for token
	// validation.
	twitch twitch.Client
	// pet is the robot's pet status.
	pet pet.Status
	// metrics are a collection of custom domain specific metrics. Run passes
	// their collectors to the HTTP API when a listen address is given.
	metrics *metrics.Metrics
}
// client is the settings for OAuth2 and related elements of one chat
// service. Send and Receive are the element types of the send and recv
// channels; Robot instantiates both as *tmi.Message for Twitch chat.
type client[Send, Receive any] struct {
	// send is the channel on which outgoing messages are handed to the
	// connection. runTwitch passes it to both tmi.Connect and tmiLoop.
	send chan Send
	// recv is the channel on which received messages are communicated.
	// runTwitch passes it to both tmi.Connect and tmiLoop.
	recv chan Receive
	// clientID is the OAuth2 application client ID.
	clientID string
	// name is the bot's username. The interpretation of this is domain-specific.
	// For Twitch it is lowercased and used as the IRC nick.
	name string
	// userID is the bot's user ID. The interpretation of this is domain-specific.
	userID string
	// owner is the user ID of the owner. The interpretation of this is
	// domain-specific.
	owner string
	// rate is the global rate limiter for messages sent to this client.
	rate *rate.Limiter
	// tokens is the source of OAuth2 tokens. Token obtains the current token
	// and Refresh replaces one the service has rejected.
	tokens auth.TokenSource
}
// New creates a new robot instance.
//
// The returned Robot has an empty channel map, a worker queue buffered to
// poolSize entries, a hasher factory whose hashers are all keyed by usersKey,
// and freshly created metrics. All other fields are left at their zero
// values.
func New(usersKey []byte, poolSize int) *Robot {
	robo := new(Robot)
	robo.channels = syncmap.New[string, *channel.Channel]()
	robo.works = make(chan chan func(context.Context), poolSize)
	robo.hashes = func() userhash.Hasher {
		// Every hasher shares the same key.
		return userhash.New(usersKey)
	}
	robo.metrics = newMetrics()
	return robo
}
// Run starts the robot's services and blocks until they all stop.
//
// The Twitch loops run when robo.tmi is non-nil. The HTTP API is served on
// listen when listen is non-empty and is given the metrics collectors. The
// first error from any service cancels the others, and Run returns that
// error. The exception is a cancellation error, which means a normal
// shutdown and is reported as nil.
func (robo *Robot) Run(ctx context.Context, listen string) error {
	group, ctx := errgroup.WithContext(ctx)
	// TODO(zeph): stdin?
	if robo.tmi != nil {
		group.Go(func() error { return robo.runTwitch(ctx, group) })
	}
	if listen != "" {
		group.Go(func() error { return robo.api(ctx, listen, new(http.ServeMux), robo.metrics.Collectors()) })
	}
	err := group.Wait()
	// If the first error is a context cancellation, then we are shutting down
	// normally (presumably in response to a SIGINT). Use errors.Is rather than
	// ==. Services such as twitchValidateLoop wrap errors from calls that fail
	// because ctx was canceled, and those should not be reported as failures.
	if errors.Is(err, context.Canceled) {
		return nil
	}
	return err
}
// runTwitch starts the Twitch background loops on group: message handling,
// hourly token validation, and live-stream polling. It then keeps the Twitch
// chat (TMI) connection open.
//
// A connection that ends without error is reopened with the same token. A
// connection that fails authentication triggers a token refresh followed by
// a new connection. runTwitch returns when a token cannot be obtained or
// refreshed, or when the connection fails for any other reason.
func (robo *Robot) runTwitch(ctx context.Context, group *errgroup.Group) error {
	group.Go(func() error {
		robo.tmiLoop(ctx, group, robo.tmi.send, robo.tmi.recv)
		return nil
	})
	group.Go(func() error { return robo.twitchValidateLoop(ctx) })
	group.Go(func() error { return robo.streamsLoop(ctx, robo.channels) })

	tok, err := robo.tmi.tokens.Token(ctx)
	if err != nil {
		return err
	}
	for {
		conf := tmi.ConnectConfig{
			Dial:         new(tls.Dialer).DialContext,
			RetryWait:    tmi.RetryList(true, 0, time.Second, time.Minute, 5*time.Minute),
			Nick:         strings.ToLower(robo.tmi.name),
			Pass:         "oauth:" + tok.AccessToken,
			Capabilities: []string{"twitch.tv/commands", "twitch.tv/tags"},
			Timeout:      300 * time.Second,
		}
		err = tmi.Connect(ctx, conf, &tmiSlog{slog.Default()}, robo.tmi.send, robo.tmi.recv)
		if err == nil {
			// We received a RECONNECT and exited normally. Reconnect with the
			// same token. It's likely (though not guaranteed) we'll need a
			// refresh, but we can worry about that when we're told to do it.
			continue
		}
		if !errors.Is(err, tmi.ErrAuthenticationFailed) {
			return err
		}
		// Authentication failed: swap in a refreshed token and try again.
		if tok, err = robo.tmi.tokens.Refresh(ctx, tok); err != nil {
			return err
		}
	}
}
// twitchValidateLoop validates the bot's user access token with Twitch once
// an hour until ctx is done, at which point it returns ctx.Err().
//
// A successful validation is logged at info level. When Twitch reports that
// the token needs a refresh, the token is refreshed and the loop waits for
// the next tick. Any other failure to obtain, refresh, or validate the token
// ends the loop with a wrapped error.
func (robo *Robot) twitchValidateLoop(ctx context.Context) error {
	hourly := time.NewTicker(time.Hour)
	defer hourly.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-hourly.C:
		}

		tok, err := robo.tmi.tokens.Token(ctx)
		if err != nil {
			return fmt.Errorf("validation loop failed to get user access token: %w", err)
		}
		val, err := twitch.Validate(ctx, robo.twitch.HTTP, tok)
		if err == nil {
			slog.InfoContext(ctx, "validation loop",
				slog.String("clientid", val.ClientID),
				slog.String("userid", val.UserID),
				slog.String("login", val.Login),
				slog.Int("expires", val.ExpiresIn),
			)
			continue
		}
		if errors.Is(err, twitch.ErrNeedRefresh) {
			// The refreshed token is not used here. Presumably tokens keeps it
			// for later Token calls; TODO confirm.
			if _, rerr := robo.tmi.tokens.Refresh(ctx, tok); rerr != nil {
				return fmt.Errorf("validation loop failed to refresh user access token: %w", rerr)
			}
			continue
		}
		// Validate may still return a response describing the failure.
		if val != nil {
			slog.ErrorContext(ctx, "validation loop", slog.Int("status", val.Status), slog.String("message", val.Message))
		}
		return fmt.Errorf("validation loop failed to validate user access token: %w", err)
	}
}
// streamsLoop keeps each channel's Enabled flag in sync with whether the
// corresponding Twitch broadcaster is live. It polls the Twitch API once at
// startup and then once per minute until ctx is done.
//
// A channel's Twitch login is its Name with any leading '#' removed,
// lowercased. Each poll makes up to five attempts and refreshes the user
// access token whenever Twitch reports that one is needed. If a query fails
// for any other reason, the error is logged. On periodic polls every channel
// is then marked offline; the startup poll leaves the flags as they are.
//
// streamsLoop returns ctx.Err() once ctx is done, or an error if it cannot
// obtain or refresh an access token.
func (robo *Robot) streamsLoop(ctx context.Context, channels *syncmap.Map[string, *channel.Channel]) error {
	// TODO(zeph): one day we should switch to eventsub
	// TODO(zeph): remove anything learned since the last check when offline
	tok, err := robo.tmi.tokens.Token(ctx)
	if err != nil {
		return err
	}
	// Scratch buffers reused across polls so each minute doesn't reallocate.
	streams := make([]twitch.Stream, 0, channels.Len())
	online := make(map[string]bool, channels.Len())

	// login maps a channel name such as "#Name" to the Twitch login "name".
	// Both the query and the online lookup go through it, so they always agree.
	login := func(name string) string {
		return strings.ToLower(strings.TrimPrefix(name, "#"))
	}

	// poll performs one round of live-status queries and updates every
	// channel's Enabled flag. offlineOnError controls whether a failed query
	// marks all channels offline. It returns an error only when the access
	// token cannot be refreshed.
	poll := func(offlineOnError bool) error {
		for range 5 {
			// Rebuild the query on every attempt. The result of UserStreams
			// replaces the streams slice, so after a failed attempt it may no
			// longer list the channels to ask about.
			streams = streams[:0]
			for _, ch := range channels.All() {
				streams = append(streams, twitch.Stream{UserLogin: login(ch.Name)})
			}
			// TODO(zeph): limit to 100
			streams, err = twitch.UserStreams(ctx, robo.twitch, tok, streams)
			switch {
			case err == nil:
				// Mark online streams as enabled.
				// First map names to online status.
				clear(online)
				for _, s := range streams {
					slog.DebugContext(ctx, "stream",
						slog.String("login", s.UserLogin),
						slog.String("display", s.UserName),
						slog.String("id", s.UserID),
						slog.String("type", s.Type),
					)
					online[strings.ToLower(s.UserLogin)] = true
				}
				// Now loop all streams.
				for _, ch := range channels.All() {
					ch.Enabled.Store(online[login(ch.Name)])
				}
				return nil
			case errors.Is(err, twitch.ErrNeedRefresh):
				tok, err = robo.tmi.tokens.Refresh(ctx, tok)
				if err != nil {
					slog.ErrorContext(ctx, "failed to refresh token", slog.Any("err", err))
					return fmt.Errorf("couldn't get valid access token: %w", err)
				}
				// Retry with the refreshed token on the next attempt.
			default:
				slog.ErrorContext(ctx, "failed to query online broadcasters", slog.Any("streams", streams), slog.Any("err", err))
				if offlineOnError {
					for _, ch := range channels.All() {
						ch.Enabled.Store(false)
					}
				}
				return nil
			}
		}
		return nil
	}

	// Run once at the start so we start learning in online streams
	// immediately. A failed startup query leaves the flags untouched, on the
	// assumption that channels start out offline.
	if err = poll(false); err != nil {
		return err
	}
	tick := time.NewTicker(time.Minute)
	defer tick.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-tick.C:
			if err = poll(true); err != nil {
				return err
			}
		}
	}
}
// deviceCodePrompt prints OAuth2 device code flow instructions to standard
// output: the verification URI and the user code to enter there. When
// verURIComplete is non-empty, it is printed first as an alternative to
// entering the code by hand.
func deviceCodePrompt(userCode, verURI, verURIComplete string) {
	const banner = "\n---- OAuth2 Device Code Flow ----"
	fmt.Println(banner)
	if len(verURIComplete) > 0 {
		fmt.Printf("%s\n\nOR\n\n", verURIComplete)
	}
	fmt.Printf("Enter code at %s\n\n\t%s\n\n", verURI, userCode)
}
// tmiSlog adapts an *slog.Logger to the logger that tmi.Connect reports
// connection events to. Errors log at error level, status at info level,
// sent and received lines at debug level, and pings one level below debug.
type tmiSlog struct {
	l *slog.Logger
}

// Error logs an error reported by the TMI connection.
func (t *tmiSlog) Error(err error) {
	t.l.LogAttrs(context.Background(), slog.LevelError, "TMI error", slog.String("err", err.Error()))
}

// Status logs a connection status message.
func (t *tmiSlog) Status(s string) {
	t.l.LogAttrs(context.Background(), slog.LevelInfo, "TMI status", slog.String("message", s))
}

// Send logs a line sent to the server.
func (t *tmiSlog) Send(s string) {
	t.l.LogAttrs(context.Background(), slog.LevelDebug, "TMI send", slog.String("message", s))
}

// Recv logs a line received from the server.
func (t *tmiSlog) Recv(s string) {
	t.l.LogAttrs(context.Background(), slog.LevelDebug, "TMI recv", slog.String("message", s))
}

// Ping logs ping traffic below debug level, so it only appears when logging
// is configured more verbosely than debug.
func (t *tmiSlog) Ping(s string) {
	t.l.LogAttrs(context.Background(), slog.LevelDebug-1, "TMI ping", slog.String("message", s))
}