diff --git a/brain/brain.go b/brain/brain.go index 01ac2b1..634e28b 100644 --- a/brain/brain.go +++ b/brain/brain.go @@ -1,15 +1,51 @@ package brain import ( + "context" + "github.com/zephyrtronium/robot/message" "github.com/zephyrtronium/robot/userhash" ) -// Brain is a combined [Learner] and [Speaker]. -type Brain interface { - Learner - Speaker +// Interface is a store of learned messages which can reproduce them by parts. +// +// It must be safe to call all methods of a brain concurrently with each other. +type Interface interface { + // Learn records a set of tuples. + // + // One tuple has an empty prefix to denote the start of the message, and + // a different tuple has the empty string as its suffix to denote the end + // of the message. The positions of each in the argument are not guaranteed. + // + // Each tuple's prefix has entropy reduction transformations applied. + // + // Tuples in the argument may share storage for prefixes. + Learn(ctx context.Context, tag string, msg *Message, tuples []Tuple) error + + // Speak generates a full message and appends it to w. + // + // The prompt is in reverse order and has entropy reduction applied. + Speak(ctx context.Context, tag string, prompt []string, w *Builder) error + + // Forget forgets everything learned from a single given message. + // If nothing has been learned from the message, it must prevent anything + // from being learned from a message with that ID. + Forget(ctx context.Context, tag, id string) error + + // Recall reads out messages the brain knows. + // At minimum, the message ID and text of each message must be retrieved; + // other fields may be filled if they are available. + // + // Repeated calls using the pagination token returned from the previous + // must yield every message that the brain had recorded at the time of the + // first call exactly once. Messages learned after the first call of an + // enumeration are read at most once. 
+ // + // The first call of an enumeration uses an empty pagination token as input. + // If the returned pagination token is empty, it is interpreted as the end + // of the enumeration. + Recall(ctx context.Context, tag, page string, out []Message) (n int, next string, err error) } -// Message is the message type used by a [Brain]. +// Message is the message type used by an [Interface]. type Message = message.Received[userhash.Hash] diff --git a/brain/braintest/bench.go b/brain/braintest/bench.go index ee0ba1f..ef704ab 100644 --- a/brain/braintest/bench.go +++ b/brain/braintest/bench.go @@ -20,7 +20,7 @@ func randid() string { // BenchLearn runs benchmarks on the brain's speed with recording new tuples. // The learner returned by new must be safe for concurrent use. -func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context, b *testing.B) brain.Learner, cleanup func(brain.Learner)) { +func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context, b *testing.B) brain.Interface, cleanup func(brain.Interface)) { b.Run("similar", func(b *testing.B) { l := new(ctx, b) if cleanup != nil { @@ -95,7 +95,7 @@ func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context, // BenchSpeak runs benchmarks on a brain's speed with generating messages // from tuples. The brain returned by new must be safe for concurrent use. 
-func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b *testing.B) brain.Brain, cleanup func(brain.Brain)) { +func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b *testing.B) brain.Interface, cleanup func(brain.Interface)) { sizes := []int64{1e3, 1e4, 1e5} for _, size := range sizes { b.Run(fmt.Sprintf("similar-new-%d", size), func(b *testing.B) { diff --git a/brain/braintest/braintest.go b/brain/braintest/braintest.go index dce841e..65bb38c 100644 --- a/brain/braintest/braintest.go +++ b/brain/braintest/braintest.go @@ -16,7 +16,7 @@ import ( // Test runs the integration test suite against brains produced by new. // // If a brain cannot be created without error, new should call t.Fatal. -func Test(ctx context.Context, t *testing.T, new func(context.Context) brain.Brain) { +func Test(ctx context.Context, t *testing.T, new func(context.Context) brain.Interface) { t.Run("speak", testSpeak(ctx, new(ctx))) t.Run("forgetMessage", testForget(ctx, new(ctx))) t.Run("combinatoric", testCombinatoric(ctx, new(ctx))) @@ -94,7 +94,7 @@ var messages = [...]struct { }, } -func learn(ctx context.Context, t *testing.T, br brain.Learner) { +func learn(ctx context.Context, t *testing.T, br brain.Interface) { t.Helper() for _, m := range messages { msg := brain.Message{ID: m.ID, Sender: m.User, Timestamp: m.Time.UnixMilli(), Text: m.Text} @@ -104,7 +104,7 @@ func learn(ctx context.Context, t *testing.T, br brain.Learner) { } } -func speak(ctx context.Context, t *testing.T, br brain.Speaker, tag, prompt string, iters int) map[string]struct{} { +func speak(ctx context.Context, t *testing.T, br brain.Interface, tag, prompt string, iters int) map[string]struct{} { t.Helper() got := make(map[string]struct{}, 20) for range iters { @@ -118,7 +118,7 @@ func speak(ctx context.Context, t *testing.T, br brain.Speaker, tag, prompt stri } // testSpeak tests that a brain can speak what it has learned. 
-func testSpeak(ctx context.Context, br brain.Brain) func(t *testing.T) { +func testSpeak(ctx context.Context, br brain.Interface) func(t *testing.T) { return func(t *testing.T) { learn(ctx, t, br) got := speak(ctx, t, br, "kessoku", "", 2048) @@ -177,7 +177,7 @@ func testSpeak(ctx context.Context, br brain.Brain) func(t *testing.T) { } // testForget tests that a brain can forget messages by ID. -func testForget(ctx context.Context, br brain.Brain) func(t *testing.T) { +func testForget(ctx context.Context, br brain.Interface) func(t *testing.T) { return func(t *testing.T) { learn(ctx, t, br) if err := br.Forget(ctx, "kessoku", messages[0].ID); err != nil { @@ -229,7 +229,7 @@ func testForget(ctx context.Context, br brain.Brain) func(t *testing.T) { // testCombinatoric tests that chains can generate even with substantial // overlap in learned material. -func testCombinatoric(ctx context.Context, br brain.Brain) func(t *testing.T) { +func testCombinatoric(ctx context.Context, br brain.Interface) func(t *testing.T) { return func(t *testing.T) { u := userhash.Hash{2} band := []string{"bocchi", "ryou", "nijika", "kita"} diff --git a/brain/braintest/braintest_test.go b/brain/braintest/braintest_test.go index 2c24bfa..156d0ff 100644 --- a/brain/braintest/braintest_test.go +++ b/brain/braintest/braintest_test.go @@ -22,7 +22,7 @@ type membrain struct { tms map[string]map[int64][]string // map of tags to map of timestamps to ids } -var _ brain.Brain = (*membrain)(nil) +var _ brain.Interface = (*membrain)(nil) func (m *membrain) Learn(ctx context.Context, tag string, msg *brain.Message, tuples []brain.Tuple) error { m.mu.Lock() @@ -107,5 +107,5 @@ func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w *br } func TestTests(t *testing.T) { - braintest.Test(context.Background(), t, func(ctx context.Context) brain.Brain { return new(membrain) }) + braintest.Test(context.Background(), t, func(ctx context.Context) brain.Interface { return new(membrain) }) } 
diff --git a/brain/kvbrain/kvbrain.go b/brain/kvbrain/kvbrain.go index 1ce2d43..a451a60 100644 --- a/brain/kvbrain/kvbrain.go +++ b/brain/kvbrain/kvbrain.go @@ -40,7 +40,7 @@ type Brain struct { past sync2.Map[string, *past] } -var _ brain.Learner = (*Brain)(nil) +var _ brain.Interface = (*Brain)(nil) func New(knowledge *badger.DB) *Brain { return &Brain{ diff --git a/brain/kvbrain/kvbrain_test.go b/brain/kvbrain/kvbrain_test.go index b4e7ffd..fefdbe7 100644 --- a/brain/kvbrain/kvbrain_test.go +++ b/brain/kvbrain/kvbrain_test.go @@ -12,7 +12,7 @@ import ( ) func TestBrain(t *testing.T) { - braintest.Test(context.Background(), t, func(ctx context.Context) brain.Brain { + braintest.Test(context.Background(), t, func(ctx context.Context) brain.Interface { db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true).WithLogger(nil)) if err != nil { t.Fatal(err) diff --git a/brain/kvbrain/learn_test.go b/brain/kvbrain/learn_test.go index 9b71d21..6612754 100644 --- a/brain/kvbrain/learn_test.go +++ b/brain/kvbrain/learn_test.go @@ -142,14 +142,14 @@ func TestLearn(t *testing.T) { } func BenchmarkLearn(b *testing.B) { - new := func(ctx context.Context, b *testing.B) brain.Learner { + new := func(ctx context.Context, b *testing.B) brain.Interface { db, err := badger.Open(badger.DefaultOptions(b.TempDir()).WithLogger(nil)) if err != nil { b.Fatal(err) } return New(db) } - cleanup := func(l brain.Learner) { + cleanup := func(l brain.Interface) { br := l.(*Brain) if err := br.knowledge.DropAll(); err != nil { b.Fatal(err) diff --git a/brain/kvbrain/speak_test.go b/brain/kvbrain/speak_test.go index 052800b..548c757 100644 --- a/brain/kvbrain/speak_test.go +++ b/brain/kvbrain/speak_test.go @@ -157,14 +157,14 @@ func TestSpeak(t *testing.T) { } func BenchmarkSpeak(b *testing.B) { - new := func(ctx context.Context, b *testing.B) brain.Brain { + new := func(ctx context.Context, b *testing.B) brain.Interface { db, err := 
badger.Open(badger.DefaultOptions(b.TempDir()).WithLogger(nil).WithCompression(options.None).WithBloomFalsePositive(1.0 / 32).WithNumMemtables(16).WithLevelSizeMultiplier(4)) if err != nil { b.Fatal(err) } return New(db) } - cleanup := func(l brain.Brain) { + cleanup := func(l brain.Interface) { br := l.(*Brain) if err := br.knowledge.DropAll(); err != nil { b.Fatal(err) diff --git a/brain/learn.go b/brain/learn.go index c0d4c5c..01f3ab2 100644 --- a/brain/learn.go +++ b/brain/learn.go @@ -17,38 +17,10 @@ type Tuple struct { Suffix string } -// Learner records Markov chain tuples. -type Learner interface { - // Learn records a set of tuples. - // One tuple has an empty prefix to denote the start of the message, and - // a different tuple has the empty string as its suffix to denote the end - // of the message. The positions of each in the argument are not guaranteed. - // Each tuple's prefix has entropy reduction transformations applied. - // Tuples in the argument may share storage for prefixes. - Learn(ctx context.Context, tag string, msg *Message, tuples []Tuple) error - // Forget forgets everything learned from a single given message. - // If nothing has been learned from the message, it should prevent anything - // from being learned from a message with that ID. - Forget(ctx context.Context, tag, id string) error - // Recall reads out messages the brain knows. - // At minimum, the message ID and text of each message must be retrieved; - // other fields may be filled if they are available. - // It must be safe to call Recall concurrently with other methods of the - // implementation. - // Repeated calls using the pagination token returned from the previous - // must yield every message that the brain had recorded at the time of the - // first call exactly once. Messages learned after the first call of an - // enumeration are read at most once. - // The first call of an enumeration uses an empty pagination token. 
- // If the returned pagination token is empty, it is interpreted as the end - // of the enumeration. - Recall(ctx context.Context, tag, page string, out []Message) (n int, next string, err error) -} - var tuplesPool tpool.Pool[[]Tuple] -// Learn records a message into a Learner. -func Learn(ctx context.Context, l Learner, tag string, msg *Message) error { +// Learn records a message into a brain. +func Learn(ctx context.Context, l Interface, tag string, msg *Message) error { toks := tokens(tokensPool.Get(), msg.Text) defer func() { tokensPool.Put(toks[:0]) }() if len(toks) == 0 { @@ -77,7 +49,7 @@ func tupleToks(tt []Tuple, toks []string) []Tuple { } // Recall iterates over all messages a brain knows with a given tag. -func Recall(ctx context.Context, br Learner, tag string) iter.Seq2[Message, error] { +func Recall(ctx context.Context, br Interface, tag string) iter.Seq2[Message, error] { return func(yield func(Message, error) bool) { var ( page string diff --git a/brain/learn_test.go b/brain/learn_test.go index 20f4158..9a30a7d 100644 --- a/brain/learn_test.go +++ b/brain/learn_test.go @@ -27,6 +27,11 @@ func (t *testLearner) Forget(ctx context.Context, tag, id string) error { return nil } +// Speak implements brain.Interface. +func (t *testLearner) Speak(ctx context.Context, tag string, prompt []string, w *brain.Builder) error { + panic("unimplemented") +} + func (t *testLearner) Recall(ctx context.Context, tag string, page string, out []brain.Message) (n int, next string, err error) { var k int if page != "" { diff --git a/brain/speak.go b/brain/speak.go index 97486e9..d973fe1 100644 --- a/brain/speak.go +++ b/brain/speak.go @@ -9,13 +9,6 @@ import ( "github.com/zephyrtronium/robot/tpool" ) -// Speaker produces random messages. -type Speaker interface { - // Speak generates a full message and appends it to w. - // The prompt is in reverse order and has entropy reduction applied. 
- Speak(ctx context.Context, tag string, prompt []string, w *Builder) error -} - var ( tokensPool tpool.Pool[[]string] builderPool = tpool.Pool[*Builder]{New: func() any { return new(Builder) }} @@ -23,9 +16,9 @@ var ( // Speak produces a new message and the trace of messages used to form it // from the given prompt. -// If the speaker does not produce any terms, the result is the empty string +// If the brain does not produce any terms, the result is the empty string // regardless of the prompt, with no error. -func Speak(ctx context.Context, s Speaker, tag, prompt string) (string, []string, error) { +func Speak(ctx context.Context, s Interface, tag, prompt string) (string, []string, error) { w := builderPool.Get() toks := tokens(tokensPool.Get(), prompt) defer func() { diff --git a/brain/speak_test.go b/brain/speak_test.go index dba28f5..25b0588 100644 --- a/brain/speak_test.go +++ b/brain/speak_test.go @@ -8,6 +8,8 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/zephyrtronium/robot/brain" + "github.com/zephyrtronium/robot/message" + "github.com/zephyrtronium/robot/userhash" ) type testSpeaker struct { @@ -22,6 +24,21 @@ func (t *testSpeaker) Speak(ctx context.Context, tag string, prompt []string, w return nil } +// Forget implements brain.Interface. +func (t *testSpeaker) Forget(ctx context.Context, tag string, id string) error { + panic("unimplemented") +} + +// Learn implements brain.Interface. +func (t *testSpeaker) Learn(ctx context.Context, tag string, msg *message.Received[userhash.Hash], tuples []brain.Tuple) error { + panic("unimplemented") +} + +// Recall implements brain.Interface. 
+func (t *testSpeaker) Recall(ctx context.Context, tag string, page string, out []message.Received[userhash.Hash]) (n int, next string, err error) { + panic("unimplemented") +} + func TestSpeak(t *testing.T) { cases := []struct { name string diff --git a/brain/sqlbrain/brain_test.go b/brain/sqlbrain/brain_test.go index af4852d..dcdf3e0 100644 --- a/brain/sqlbrain/brain_test.go +++ b/brain/sqlbrain/brain_test.go @@ -30,13 +30,10 @@ func testDB(ctx context.Context) *sqlitex.Pool { return pool } -var _ brain.Learner = (*sqlbrain.Brain)(nil) -var _ brain.Speaker = (*sqlbrain.Brain)(nil) - func TestIntegrated(t *testing.T) { t.Parallel() ctx := context.Background() - new := func(ctx context.Context) brain.Brain { + new := func(ctx context.Context) brain.Interface { db := testDB(ctx) br, err := sqlbrain.Open(ctx, db) if err != nil { diff --git a/brain/sqlbrain/learn_test.go b/brain/sqlbrain/learn_test.go index 86f613e..a0d207b 100644 --- a/brain/sqlbrain/learn_test.go +++ b/brain/sqlbrain/learn_test.go @@ -570,7 +570,7 @@ func TestRecall(t *testing.T) { func BenchmarkLearn(b *testing.B) { dir := filepath.ToSlash(b.TempDir()) - new := func(ctx context.Context, b *testing.B) brain.Learner { + new := func(ctx context.Context, b *testing.B) brain.Interface { dsn := fmt.Sprintf("file:%s/benchmark_learn.db?_journal=WAL", dir) db, err := sqlitex.NewPool(dsn, sqlitex.PoolOptions{PrepareConn: sqlbrain.RecommendedPrep}) if err != nil { @@ -589,5 +589,5 @@ func BenchmarkLearn(b *testing.B) { } return br } - braintest.BenchLearn(context.Background(), b, new, func(l brain.Learner) { l.(*sqlbrain.Brain).Close() }) + braintest.BenchLearn(context.Background(), b, new, func(l brain.Interface) { l.(*sqlbrain.Brain).Close() }) } diff --git a/brain/sqlbrain/speak_test.go b/brain/sqlbrain/speak_test.go index 0c7dcdb..b168a20 100644 --- a/brain/sqlbrain/speak_test.go +++ b/brain/sqlbrain/speak_test.go @@ -518,7 +518,7 @@ func insert(t *testing.T, conn *sqlite.Conn, know []know, msgs []msg) { 
func BenchmarkSpeak(b *testing.B) { var dbs atomic.Uint64 - new := func(ctx context.Context, b *testing.B) brain.Brain { + new := func(ctx context.Context, b *testing.B) brain.Interface { k := dbs.Add(1) db, err := sqlitex.NewPool(fmt.Sprintf("file:%s/bench-%d.sql", b.TempDir(), k), sqlitex.PoolOptions{PrepareConn: sqlbrain.RecommendedPrep}) if err != nil { @@ -530,7 +530,7 @@ func BenchmarkSpeak(b *testing.B) { } return br } - cleanup := func(l brain.Brain) { + cleanup := func(l brain.Interface) { br := l.(*sqlbrain.Brain) br.Close() } diff --git a/command/command.go b/command/command.go index 38bd0d3..8709848 100644 --- a/command/command.go +++ b/command/command.go @@ -17,7 +17,7 @@ import ( type Robot struct { Log *slog.Logger Channels *syncmap.Map[string, *channel.Channel] - Brain brain.Brain + Brain brain.Interface Privacy *privacy.List Spoken *spoken.History Owner string diff --git a/main.go b/main.go index a878742..647e83a 100644 --- a/main.go +++ b/main.go @@ -151,7 +151,7 @@ func cliSpeak(ctx context.Context, cmd *cli.Command) error { if err != nil { return err } - var br brain.Brain + var br brain.Interface if sql == nil { if kv == nil { panic("robot: no brain") @@ -202,7 +202,7 @@ func cliAncient(ctx context.Context, cmd *cli.Command) error { if err != nil { return err } - var br brain.Brain + var br brain.Interface if sql == nil { if kv == nil { panic("robot: no brain") diff --git a/robot.go b/robot.go index 00785c8..6b1708d 100644 --- a/robot.go +++ b/robot.go @@ -28,7 +28,7 @@ import ( // Robot is the overall configuration for the bot. type Robot struct { // brain is the brain. - brain brain.Brain + brain brain.Interface // privacy is the privacy. privacy *privacy.List // spoken is the history of generated messages. 
diff --git a/tmi.go b/tmi.go index 567cee4..2832443 100644 --- a/tmi.go +++ b/tmi.go @@ -195,7 +195,7 @@ func (robo *Robot) clearmsg(ctx context.Context, msg *tmi.Message) { forget(ctx, log, robo.Metrics.ForgotCount, robo.brain, ch.Send, trace...) } -func forget(ctx context.Context, log *slog.Logger, forgetCount metrics.Observer, brain brain.Brain, tag string, trace ...string) { +func forget(ctx context.Context, log *slog.Logger, forgetCount metrics.Observer, brain brain.Interface, tag string, trace ...string) { forgetCount.Observe(1) for _, id := range trace { err := brain.Forget(ctx, tag, id)