Skip to content

Commit

Permalink
brain: remove separate Learner and Speaker interfaces
Browse files Browse the repository at this point in the history
For #94.
  • Loading branch information
zephyrtronium committed Dec 14, 2024
1 parent 0d04fc7 commit aab95b8
Show file tree
Hide file tree
Showing 19 changed files with 94 additions and 74 deletions.
46 changes: 41 additions & 5 deletions brain/brain.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,51 @@
package brain

import (
"context"

"github.com/zephyrtronium/robot/message"
"github.com/zephyrtronium/robot/userhash"
)

// Brain is a combined [Learner] and [Speaker].
type Brain interface {
Learner
Speaker
// Interface is a store of learned messages which can reproduce them by parts.
//
// It must be safe to call all methods of a brain concurrently with each other.
type Interface interface {
// Learn records a set of tuples.
//
// One tuple has an empty prefix to denote the start of the message, and
// a different tuple has the empty string as its suffix to denote the end
// of the message. The positions of each in the argument are not guaranteed.
//
// Each tuple's prefix has entropy reduction transformations applied.
//
// Tuples in the argument may share storage for prefixes.
Learn(ctx context.Context, tag string, msg *Message, tuples []Tuple) error

// Speak generates a full message and appends it to w.
//
// The prompt is in reverse order and has entropy reduction applied.
Speak(ctx context.Context, tag string, prompt []string, w *Builder) error

// Forget forgets everything learned from a single given message.
// If nothing has been learned from the message, it must prevent anything
// from being learned from a message with that ID.
Forget(ctx context.Context, tag, id string) error

// Recall reads out messages the brain knows.
// At minimum, the message ID and text of each message must be retrieved;
// other fields may be filled if they are available.
//
// Repeated calls using the pagination token returned from the previous
// must yield every message that the brain had recorded at the time of the
// first call exactly once. Messages learned after the first call of an
// enumeration are read at most once.
//
// The first call of an enumeration uses an empty pagination token as input.
// If the returned pagination token is empty, it is interpreted as the end
// of the enumeration.
Recall(ctx context.Context, tag, page string, out []Message) (n int, next string, err error)
}

// Message is the message type used by a [Brain].
// Message is the message type used by a [Interface].
type Message = message.Received[userhash.Hash]
4 changes: 2 additions & 2 deletions brain/braintest/bench.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func randid() string {

// BenchLearn runs benchmarks on the brain's speed with recording new tuples.
// The learner returned by new must be safe for concurrent use.
func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context, b *testing.B) brain.Learner, cleanup func(brain.Learner)) {
func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context, b *testing.B) brain.Interface, cleanup func(brain.Interface)) {
b.Run("similar", func(b *testing.B) {
l := new(ctx, b)
if cleanup != nil {
Expand Down Expand Up @@ -95,7 +95,7 @@ func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context,

// BenchSpeak runs benchmarks on a brain's speed with generating messages
// from tuples. The brain returned by new must be safe for concurrent use.
func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b *testing.B) brain.Brain, cleanup func(brain.Brain)) {
func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context, b *testing.B) brain.Interface, cleanup func(brain.Interface)) {
sizes := []int64{1e3, 1e4, 1e5}
for _, size := range sizes {
b.Run(fmt.Sprintf("similar-new-%d", size), func(b *testing.B) {
Expand Down
12 changes: 6 additions & 6 deletions brain/braintest/braintest.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
// Test runs the integration test suite against brains produced by new.
//
// If a brain cannot be created without error, new should call t.Fatal.
func Test(ctx context.Context, t *testing.T, new func(context.Context) brain.Brain) {
func Test(ctx context.Context, t *testing.T, new func(context.Context) brain.Interface) {
t.Run("speak", testSpeak(ctx, new(ctx)))
t.Run("forgetMessage", testForget(ctx, new(ctx)))
t.Run("combinatoric", testCombinatoric(ctx, new(ctx)))
Expand Down Expand Up @@ -94,7 +94,7 @@ var messages = [...]struct {
},
}

func learn(ctx context.Context, t *testing.T, br brain.Learner) {
func learn(ctx context.Context, t *testing.T, br brain.Interface) {
t.Helper()
for _, m := range messages {
msg := brain.Message{ID: m.ID, Sender: m.User, Timestamp: m.Time.UnixMilli(), Text: m.Text}
Expand All @@ -104,7 +104,7 @@ func learn(ctx context.Context, t *testing.T, br brain.Learner) {
}
}

func speak(ctx context.Context, t *testing.T, br brain.Speaker, tag, prompt string, iters int) map[string]struct{} {
func speak(ctx context.Context, t *testing.T, br brain.Interface, tag, prompt string, iters int) map[string]struct{} {
t.Helper()
got := make(map[string]struct{}, 20)
for range iters {
Expand All @@ -118,7 +118,7 @@ func speak(ctx context.Context, t *testing.T, br brain.Speaker, tag, prompt stri
}

// testSpeak tests that a brain can speak what it has learned.
func testSpeak(ctx context.Context, br brain.Brain) func(t *testing.T) {
func testSpeak(ctx context.Context, br brain.Interface) func(t *testing.T) {
return func(t *testing.T) {
learn(ctx, t, br)
got := speak(ctx, t, br, "kessoku", "", 2048)
Expand Down Expand Up @@ -177,7 +177,7 @@ func testSpeak(ctx context.Context, br brain.Brain) func(t *testing.T) {
}

// testForget tests that a brain can forget messages by ID.
func testForget(ctx context.Context, br brain.Brain) func(t *testing.T) {
func testForget(ctx context.Context, br brain.Interface) func(t *testing.T) {
return func(t *testing.T) {
learn(ctx, t, br)
if err := br.Forget(ctx, "kessoku", messages[0].ID); err != nil {
Expand Down Expand Up @@ -229,7 +229,7 @@ func testForget(ctx context.Context, br brain.Brain) func(t *testing.T) {

// testCombinatoric tests that chains can generate even with substantial
// overlap in learned material.
func testCombinatoric(ctx context.Context, br brain.Brain) func(t *testing.T) {
func testCombinatoric(ctx context.Context, br brain.Interface) func(t *testing.T) {
return func(t *testing.T) {
u := userhash.Hash{2}
band := []string{"bocchi", "ryou", "nijika", "kita"}
Expand Down
4 changes: 2 additions & 2 deletions brain/braintest/braintest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type membrain struct {
tms map[string]map[int64][]string // map of tags to map of timestamps to ids
}

var _ brain.Brain = (*membrain)(nil)
var _ brain.Interface = (*membrain)(nil)

func (m *membrain) Learn(ctx context.Context, tag string, msg *brain.Message, tuples []brain.Tuple) error {
m.mu.Lock()
Expand Down Expand Up @@ -107,5 +107,5 @@ func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w *br
}

func TestTests(t *testing.T) {
braintest.Test(context.Background(), t, func(ctx context.Context) brain.Brain { return new(membrain) })
braintest.Test(context.Background(), t, func(ctx context.Context) brain.Interface { return new(membrain) })
}
2 changes: 1 addition & 1 deletion brain/kvbrain/kvbrain.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type Brain struct {
past sync2.Map[string, *past]
}

var _ brain.Learner = (*Brain)(nil)
var _ brain.Interface = (*Brain)(nil)

func New(knowledge *badger.DB) *Brain {
return &Brain{
Expand Down
2 changes: 1 addition & 1 deletion brain/kvbrain/kvbrain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func TestBrain(t *testing.T) {
braintest.Test(context.Background(), t, func(ctx context.Context) brain.Brain {
braintest.Test(context.Background(), t, func(ctx context.Context) brain.Interface {
db, err := badger.Open(badger.DefaultOptions("").WithInMemory(true).WithLogger(nil))
if err != nil {
t.Fatal(err)
Expand Down
4 changes: 2 additions & 2 deletions brain/kvbrain/learn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@ func TestLearn(t *testing.T) {
}

func BenchmarkLearn(b *testing.B) {
new := func(ctx context.Context, b *testing.B) brain.Learner {
new := func(ctx context.Context, b *testing.B) brain.Interface {
db, err := badger.Open(badger.DefaultOptions(b.TempDir()).WithLogger(nil))
if err != nil {
b.Fatal(err)
}
return New(db)
}
cleanup := func(l brain.Learner) {
cleanup := func(l brain.Interface) {
br := l.(*Brain)
if err := br.knowledge.DropAll(); err != nil {
b.Fatal(err)
Expand Down
4 changes: 2 additions & 2 deletions brain/kvbrain/speak_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,14 @@ func TestSpeak(t *testing.T) {
}

func BenchmarkSpeak(b *testing.B) {
new := func(ctx context.Context, b *testing.B) brain.Brain {
new := func(ctx context.Context, b *testing.B) brain.Interface {
db, err := badger.Open(badger.DefaultOptions(b.TempDir()).WithLogger(nil).WithCompression(options.None).WithBloomFalsePositive(1.0 / 32).WithNumMemtables(16).WithLevelSizeMultiplier(4))
if err != nil {
b.Fatal(err)
}
return New(db)
}
cleanup := func(l brain.Brain) {
cleanup := func(l brain.Interface) {
br := l.(*Brain)
if err := br.knowledge.DropAll(); err != nil {
b.Fatal(err)
Expand Down
34 changes: 3 additions & 31 deletions brain/learn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,10 @@ type Tuple struct {
Suffix string
}

// Learner records Markov chain tuples.
type Learner interface {
// Learn records a set of tuples.
// One tuple has an empty prefix to denote the start of the message, and
// a different tuple has the empty string as its suffix to denote the end
// of the message. The positions of each in the argument are not guaranteed.
// Each tuple's prefix has entropy reduction transformations applied.
// Tuples in the argument may share storage for prefixes.
Learn(ctx context.Context, tag string, msg *Message, tuples []Tuple) error
// Forget forgets everything learned from a single given message.
// If nothing has been learned from the message, it should prevent anything
// from being learned from a message with that ID.
Forget(ctx context.Context, tag, id string) error
// Recall reads out messages the brain knows.
// At minimum, the message ID and text of each message must be retrieved;
// other fields may be filled if they are available.
// It must be safe to call Recall concurrently with other methods of the
// implementation.
// Repeated calls using the pagination token returned from the previous
// must yield every message that the brain had recorded at the time of the
// first call exactly once. Messages learned after the first call of an
// enumeration are read at most once.
// The first call of an enumeration uses an empty pagination token.
// If the returned pagination token is empty, it is interpreted as the end
// of the enumeration.
Recall(ctx context.Context, tag, page string, out []Message) (n int, next string, err error)
}

var tuplesPool tpool.Pool[[]Tuple]

// Learn records a message into a Learner.
func Learn(ctx context.Context, l Learner, tag string, msg *Message) error {
// Learn records a message into a brain.
func Learn(ctx context.Context, l Interface, tag string, msg *Message) error {
toks := tokens(tokensPool.Get(), msg.Text)
defer func() { tokensPool.Put(toks[:0]) }()
if len(toks) == 0 {
Expand Down Expand Up @@ -77,7 +49,7 @@ func tupleToks(tt []Tuple, toks []string) []Tuple {
}

// Recall iterates over all messages a brain knows with a given tag.
func Recall(ctx context.Context, br Learner, tag string) iter.Seq2[Message, error] {
func Recall(ctx context.Context, br Interface, tag string) iter.Seq2[Message, error] {
return func(yield func(Message, error) bool) {
var (
page string
Expand Down
5 changes: 5 additions & 0 deletions brain/learn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ func (t *testLearner) Forget(ctx context.Context, tag, id string) error {
return nil
}

// Speak implements brain.Brain.
func (t *testLearner) Speak(ctx context.Context, tag string, prompt []string, w *brain.Builder) error {
panic("unimplemented")
}

func (t *testLearner) Recall(ctx context.Context, tag string, page string, out []brain.Message) (n int, next string, err error) {
var k int
if page != "" {
Expand Down
11 changes: 2 additions & 9 deletions brain/speak.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,16 @@ import (
"github.com/zephyrtronium/robot/tpool"
)

// Speaker produces random messages.
type Speaker interface {
// Speak generates a full message and appends it to w.
// The prompt is in reverse order and has entropy reduction applied.
Speak(ctx context.Context, tag string, prompt []string, w *Builder) error
}

var (
tokensPool tpool.Pool[[]string]
builderPool = tpool.Pool[*Builder]{New: func() any { return new(Builder) }}
)

// Speak produces a new message and the trace of messages used to form it
// from the given prompt.
// If the speaker does not produce any terms, the result is the empty string
// If the brain does not produce any terms, the result is the empty string
// regardless of the prompt, with no error.
func Speak(ctx context.Context, s Speaker, tag, prompt string) (string, []string, error) {
func Speak(ctx context.Context, s Interface, tag, prompt string) (string, []string, error) {
w := builderPool.Get()
toks := tokens(tokensPool.Get(), prompt)
defer func() {
Expand Down
17 changes: 17 additions & 0 deletions brain/speak_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"github.com/google/go-cmp/cmp/cmpopts"

"github.com/zephyrtronium/robot/brain"
"github.com/zephyrtronium/robot/message"
"github.com/zephyrtronium/robot/userhash"
)

type testSpeaker struct {
Expand All @@ -22,6 +24,21 @@ func (t *testSpeaker) Speak(ctx context.Context, tag string, prompt []string, w
return nil
}

// Forget implements brain.Brain.
func (t *testSpeaker) Forget(ctx context.Context, tag string, id string) error {
panic("unimplemented")
}

// Learn implements brain.Brain.
func (t *testSpeaker) Learn(ctx context.Context, tag string, msg *message.Received[userhash.Hash], tuples []brain.Tuple) error {
panic("unimplemented")
}

// Recall implements brain.Brain.
func (t *testSpeaker) Recall(ctx context.Context, tag string, page string, out []message.Received[userhash.Hash]) (n int, next string, err error) {
panic("unimplemented")
}

func TestSpeak(t *testing.T) {
cases := []struct {
name string
Expand Down
5 changes: 1 addition & 4 deletions brain/sqlbrain/brain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,10 @@ func testDB(ctx context.Context) *sqlitex.Pool {
return pool
}

var _ brain.Learner = (*sqlbrain.Brain)(nil)
var _ brain.Speaker = (*sqlbrain.Brain)(nil)

func TestIntegrated(t *testing.T) {
t.Parallel()
ctx := context.Background()
new := func(ctx context.Context) brain.Brain {
new := func(ctx context.Context) brain.Interface {
db := testDB(ctx)
br, err := sqlbrain.Open(ctx, db)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions brain/sqlbrain/learn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ func TestRecall(t *testing.T) {

func BenchmarkLearn(b *testing.B) {
dir := filepath.ToSlash(b.TempDir())
new := func(ctx context.Context, b *testing.B) brain.Learner {
new := func(ctx context.Context, b *testing.B) brain.Interface {
dsn := fmt.Sprintf("file:%s/benchmark_learn.db?_journal=WAL", dir)
db, err := sqlitex.NewPool(dsn, sqlitex.PoolOptions{PrepareConn: sqlbrain.RecommendedPrep})
if err != nil {
Expand All @@ -589,5 +589,5 @@ func BenchmarkLearn(b *testing.B) {
}
return br
}
braintest.BenchLearn(context.Background(), b, new, func(l brain.Learner) { l.(*sqlbrain.Brain).Close() })
braintest.BenchLearn(context.Background(), b, new, func(l brain.Interface) { l.(*sqlbrain.Brain).Close() })
}
4 changes: 2 additions & 2 deletions brain/sqlbrain/speak_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ func insert(t *testing.T, conn *sqlite.Conn, know []know, msgs []msg) {

func BenchmarkSpeak(b *testing.B) {
var dbs atomic.Uint64
new := func(ctx context.Context, b *testing.B) brain.Brain {
new := func(ctx context.Context, b *testing.B) brain.Interface {
k := dbs.Add(1)
db, err := sqlitex.NewPool(fmt.Sprintf("file:%s/bench-%d.sql", b.TempDir(), k), sqlitex.PoolOptions{PrepareConn: sqlbrain.RecommendedPrep})
if err != nil {
Expand All @@ -530,7 +530,7 @@ func BenchmarkSpeak(b *testing.B) {
}
return br
}
cleanup := func(l brain.Brain) {
cleanup := func(l brain.Interface) {
br := l.(*sqlbrain.Brain)
br.Close()
}
Expand Down
2 changes: 1 addition & 1 deletion command/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
type Robot struct {
Log *slog.Logger
Channels *syncmap.Map[string, *channel.Channel]
Brain brain.Brain
Brain brain.Interface
Privacy *privacy.List
Spoken *spoken.History
Owner string
Expand Down
4 changes: 2 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ func cliSpeak(ctx context.Context, cmd *cli.Command) error {
if err != nil {
return err
}
var br brain.Brain
var br brain.Interface
if sql == nil {
if kv == nil {
panic("robot: no brain")
Expand Down Expand Up @@ -202,7 +202,7 @@ func cliAncient(ctx context.Context, cmd *cli.Command) error {
if err != nil {
return err
}
var br brain.Brain
var br brain.Interface
if sql == nil {
if kv == nil {
panic("robot: no brain")
Expand Down
Loading

0 comments on commit aab95b8

Please sign in to comment.