Skip to content

Commit

Permalink
brain/*: learn entire messages
Browse files Browse the repository at this point in the history
Now brain implementations will have no excuse not to record message
text in addition to tuples.

For #90.
  • Loading branch information
zephyrtronium committed Nov 22, 2024
1 parent f6dfe4e commit 1c26481
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 63 deletions.
8 changes: 8 additions & 0 deletions brain/brain.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
package brain

import (
"github.com/zephyrtronium/robot/message"
"github.com/zephyrtronium/robot/userhash"
)

// Brain is a combined [Learner] and [Speaker].
type Brain interface {
Learner
Speaker
}

// Message is the message type used by a [Brain].
type Message = message.Received[userhash.Hash]
16 changes: 10 additions & 6 deletions brain/braintest/bench.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/zephyrtronium/robot/brain"
"github.com/zephyrtronium/robot/userhash"
Expand Down Expand Up @@ -45,7 +44,8 @@ func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context,
toks[len(toks)-1] = strconv.FormatInt(t, 10)
id := randid()
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
err := brain.Learn(ctx, l, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks, " "))
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
err := brain.Learn(ctx, l, "bocchi", &msg)
if err != nil {
b.Errorf("error while learning: %v", err)
}
Expand Down Expand Up @@ -83,7 +83,8 @@ func BenchLearn(ctx context.Context, b *testing.B, new func(ctx context.Context,
rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] })
id := randid()
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
err := brain.Learn(ctx, l, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks[:8], " "))
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
err := brain.Learn(ctx, l, "bocchi", &msg)
if err != nil {
b.Errorf("error while learning: %v", err)
}
Expand Down Expand Up @@ -117,7 +118,8 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context,
toks[len(toks)-1] = strconv.FormatInt(t, 10)
id := randid()
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
err := brain.Learn(ctx, br, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks, " "))
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
err := brain.Learn(ctx, br, "bocchi", &msg)
if err != nil {
b.Errorf("error while learning: %v", err)
}
Expand Down Expand Up @@ -162,7 +164,8 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context,
rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] })
id := randid()
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
err := brain.Learn(ctx, br, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks, " "))
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
err := brain.Learn(ctx, br, "bocchi", &msg)
if err != nil {
b.Errorf("error while learning: %v", err)
}
Expand Down Expand Up @@ -207,7 +210,8 @@ func BenchSpeak(ctx context.Context, b *testing.B, new func(ctx context.Context,
rand.Shuffle(len(toks), func(i, j int) { toks[i], toks[j] = toks[j], toks[i] })
id := randid()
u := userhash.Hash(randbytes(make([]byte, len(userhash.Hash{}))))
err := brain.Learn(ctx, br, "bocchi", id, u, time.Unix(t, 0), strings.Join(toks, " "))
msg := brain.Message{ID: id, Sender: u, Timestamp: t * 1e3, Text: strings.Join(toks, " ")}
err := brain.Learn(ctx, br, "bocchi", &msg)
if err != nil {
b.Errorf("error while learning: %v", err)
}
Expand Down
6 changes: 4 additions & 2 deletions brain/braintest/braintest.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ var messages = [...]struct {
func learn(ctx context.Context, t *testing.T, br brain.Learner) {
t.Helper()
for _, m := range messages {
if err := brain.Learn(ctx, br, m.Tag, m.ID, m.User, m.Time, m.Text); err != nil {
msg := brain.Message{ID: m.ID, Sender: m.User, Timestamp: m.Time.UnixMilli(), Text: m.Text}
if err := brain.Learn(ctx, br, m.Tag, &msg); err != nil {
t.Fatalf("couldn't learn message %v: %v", m.ID, err)
}
}
Expand Down Expand Up @@ -242,7 +243,8 @@ func testCombinatoric(ctx context.Context, br brain.Brain) func(t *testing.T) {
toks := toks
for len(toks) > 1 {
id := randid()
err := brain.Learn(ctx, br, "bocchi", id, u, time.Unix(0, 0), strings.Join(toks, " "))
msg := brain.Message{ID: id, Sender: u, Text: strings.Join(toks, " ")}
err := brain.Learn(ctx, br, "bocchi", &msg)
if err != nil {
t.Fatalf("couldn't learn init: %v", err)
}
Expand Down
9 changes: 4 additions & 5 deletions brain/braintest/braintest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"strings"
"sync"
"testing"
"time"

"github.com/zephyrtronium/robot/brain"
"github.com/zephyrtronium/robot/brain/braintest"
Expand All @@ -25,7 +24,7 @@ type membrain struct {

var _ brain.Brain = (*membrain)(nil)

func (m *membrain) Learn(ctx context.Context, tag, id string, user userhash.Hash, t time.Time, tuples []brain.Tuple) error {
func (m *membrain) Learn(ctx context.Context, tag string, msg *brain.Message, tuples []brain.Tuple) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.tups[tag] == nil {
Expand All @@ -37,13 +36,13 @@ func (m *membrain) Learn(ctx context.Context, tag, id string, user userhash.Hash
m.tups[tag] = make(map[string][][2]string)
m.tms[tag] = make(map[int64][]string)
}
m.users[user] = append(m.users[user], [2]string{tag, id})
m.users[msg.Sender] = append(m.users[msg.Sender], [2]string{tag, msg.ID})
tms := m.tms[tag]
tms[t.UnixNano()] = append(tms[t.UnixNano()], id)
tms[msg.Timestamp] = append(tms[msg.Timestamp], msg.ID)
r := m.tups[tag]
for _, tup := range tuples {
p := strings.Join(tup.Prefix, "\xff")
r[p] = append(r[p], [2]string{id, tup.Suffix})
r[p] = append(r[p], [2]string{msg.ID, tup.Suffix})
}
return nil
}
Expand Down
7 changes: 6 additions & 1 deletion brain/kvbrain/forget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,12 @@ func TestForget(t *testing.T) {
}
br := New(db)
for _, msg := range c.msgs {
err := br.Learn(ctx, msg.tag, msg.id, msg.user, msg.time, msg.tups)
m := brain.Message{
ID: msg.id,
Sender: msg.user,
Timestamp: msg.time.UnixMilli(),
}
err := br.Learn(ctx, msg.tag, &m, msg.tups)
if err != nil {
t.Errorf("failed to learn: %v", err)
}
Expand Down
9 changes: 4 additions & 5 deletions brain/kvbrain/learn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@ import (
"context"
"errors"
"fmt"
"time"

"github.com/zephyrtronium/robot/brain"
"github.com/zephyrtronium/robot/userhash"
)

// Learn records a set of tuples. Each tuple prefix has length equal to the
// result of Order. The tuples begin with empty strings in the prefix to
// denote the start of the message and end with one empty suffix to denote
// the end; all other tokens are non-empty. Each tuple's prefix has entropy
// reduction transformations applied.
func (br *Brain) Learn(ctx context.Context, tag, id string, user userhash.Hash, t time.Time, tuples []brain.Tuple) error {
func (br *Brain) Learn(ctx context.Context, tag string, msg *brain.Message, tuples []brain.Tuple) error {
if len(tuples) == 0 {
return errors.New("no tuples to learn")
}
Expand All @@ -31,7 +29,7 @@ func (br *Brain) Learn(ctx context.Context, tag, id string, user userhash.Hash,
b = hashTag(b[:0], tag)
b = append(appendPrefix(b, t.Prefix), '\xff')
// Write message ID.
b = append(b, id[:]...)
b = append(b, msg.ID...)
keys[i] = bytes.Clone(b)
vals[i] = []byte(t.Suffix)
}
Expand All @@ -42,7 +40,8 @@ func (br *Brain) Learn(ctx context.Context, tag, id string, user userhash.Hash,
// overwrite if that happens.
p, _ = br.past.LoadOrStore(tag, new(past))
}
p.record(id, user, t.UnixNano(), keys)
// Scale the timestamp from milliseconds to nanoseconds for historical reasons.
p.record(msg.ID, msg.Sender, msg.Timestamp*1e6, keys)

batch := br.knowledge.NewWriteBatch()
defer batch.Cancel()
Expand Down
7 changes: 6 additions & 1 deletion brain/kvbrain/learn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,12 @@ func TestLearn(t *testing.T) {
t.Fatal(err)
}
br := New(db)
if err := br.Learn(ctx, c.tag, c.id, c.user, c.time, c.tups); err != nil {
msg := brain.Message{
ID: c.id,
Sender: c.user,
Timestamp: c.time.UnixMilli(),
}
if err := br.Learn(ctx, c.tag, &msg, c.tups); err != nil {
t.Errorf("failed to learn: %v", err)
}
dbcheck(t, db, c.want)
Expand Down
12 changes: 5 additions & 7 deletions brain/learn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@ package brain
import (
"context"
"slices"
"time"

"github.com/zephyrtronium/robot/tpool"
"github.com/zephyrtronium/robot/userhash"
)

// Tuple is a single Markov chain tuple.
Expand All @@ -26,7 +24,7 @@ type Learner interface {
// of the message. The positions of each in the argument are not guaranteed.
// Each tuple's prefix has entropy reduction transformations applied.
// Tuples in the argument may share storage for prefixes.
Learn(ctx context.Context, tag, id string, user userhash.Hash, t time.Time, tuples []Tuple) error
Learn(ctx context.Context, tag string, msg *Message, tuples []Tuple) error
// Forget forgets everything learned from a single given message.
// If nothing has been learned from the message, it should prevent anything
// from being learned from a message with that ID.
Expand All @@ -35,9 +33,9 @@ type Learner interface {

var tuplesPool tpool.Pool[[]Tuple]

// Learn records tokens into a Learner.
func Learn(ctx context.Context, l Learner, tag, id string, user userhash.Hash, t time.Time, text string) error {
toks := Tokens(tokensPool.Get(), text)
// Learn records a message into a Learner.
func Learn(ctx context.Context, l Learner, tag string, msg *Message) error {
toks := Tokens(tokensPool.Get(), msg.Text)
defer func() { tokensPool.Put(toks[:0]) }()
if len(toks) == 0 {
return nil
Expand All @@ -46,7 +44,7 @@ func Learn(ctx context.Context, l Learner, tag, id string, user userhash.Hash, t
defer func() { tuplesPool.Put(tt[:0]) }()
tt = slices.Grow(tt, len(toks)+1)
tt = tupleToks(tt, toks)
return l.Learn(ctx, tag, id, user, t, tt)
return l.Learn(ctx, tag, msg, tt)
}

func tupleToks(tt []Tuple, toks []string) []Tuple {
Expand Down
6 changes: 2 additions & 4 deletions brain/learn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@ package brain_test
import (
"context"
"testing"
"time"

"github.com/google/go-cmp/cmp"

"github.com/zephyrtronium/robot/brain"
"github.com/zephyrtronium/robot/userhash"
)

type testLearner struct {
learned []brain.Tuple
}

func (t *testLearner) Learn(ctx context.Context, tag, id string, user userhash.Hash, tm time.Time, tuples []brain.Tuple) error {
func (t *testLearner) Learn(ctx context.Context, tag string, msg *brain.Message, tuples []brain.Tuple) error {
t.learned = append(t.learned, tuples...)
return nil
}
Expand Down Expand Up @@ -63,7 +61,7 @@ func TestLearn(t *testing.T) {
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
var l testLearner
err := brain.Learn(context.Background(), &l, "", "", userhash.Hash{}, time.Unix(0, 0), c.msg)
err := brain.Learn(context.Background(), &l, "", &brain.Message{Text: c.msg})
if err != nil {
t.Error(err)
}
Expand Down
32 changes: 18 additions & 14 deletions brain/sqlbrain/forget_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"strings"
"testing"
"time"

"github.com/zephyrtronium/robot/brain"
"github.com/zephyrtronium/robot/brain/sqlbrain"
Expand Down Expand Up @@ -107,19 +106,19 @@ func TestForget(t *testing.T) {
{
tag: "kessoku",
id: "2",
time: 3,
time: 3e6,
user: userhash.Hash{1},
},
{
tag: "kessoku",
id: "5",
time: 6,
time: 6e6,
user: userhash.Hash{4},
},
{
tag: "sickhack",
id: "2",
time: 3,
time: 3e6,
user: userhash.Hash{1},
},
}
Expand Down Expand Up @@ -206,20 +205,20 @@ func TestForget(t *testing.T) {
{
tag: "kessoku",
id: "2",
time: 3,
time: 3e6,
user: userhash.Hash{1},
deleted: ref("CLEARMSG"),
},
{
tag: "kessoku",
id: "5",
time: 6,
time: 6e6,
user: userhash.Hash{4},
},
{
tag: "sickhack",
id: "2",
time: 3,
time: 3e6,
user: userhash.Hash{1},
},
},
Expand Down Expand Up @@ -290,20 +289,20 @@ func TestForget(t *testing.T) {
{
tag: "kessoku",
id: "2",
time: 3,
time: 3e6,
user: userhash.Hash{1},
},
{
tag: "kessoku",
id: "5",
time: 6,
time: 6e6,
user: userhash.Hash{4},
deleted: ref("CLEARMSG"),
},
{
tag: "sickhack",
id: "2",
time: 3,
time: 3e6,
user: userhash.Hash{1},
},
},
Expand Down Expand Up @@ -374,19 +373,19 @@ func TestForget(t *testing.T) {
{
tag: "kessoku",
id: "2",
time: 3,
time: 3e6,
user: userhash.Hash{1},
},
{
tag: "kessoku",
id: "5",
time: 6,
time: 6e6,
user: userhash.Hash{4},
},
{
tag: "sickhack",
id: "2",
time: 3,
time: 3e6,
user: userhash.Hash{1},
deleted: ref("CLEARMSG"),
},
Expand All @@ -403,7 +402,12 @@ func TestForget(t *testing.T) {
t.Fatalf("couldn't open brain: %v", err)
}
for _, m := range learn {
err := br.Learn(ctx, m.tag, m.id, m.user, time.Unix(0, m.t), m.tups)
msg := brain.Message{
ID: m.id,
Sender: m.user,
Timestamp: m.t,
}
err := br.Learn(ctx, m.tag, &msg, m.tups)
if err != nil {
t.Errorf("failed to learn %v/%v: %v", m.tag, m.id, err)
}
Expand Down
Loading

0 comments on commit 1c26481

Please sign in to comment.