Skip to content

Commit

Permalink
brain: remove ForgetDuring and ForgetUserSince methods
Browse files Browse the repository at this point in the history
We keep the information needed to implement these in terms of only
message IDs elsewhere.

Fixes #52.
For #90.
  • Loading branch information
zephyrtronium committed Nov 21, 2024
1 parent 0ef3deb commit dd538e3
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 1,393 deletions.
54 changes: 4 additions & 50 deletions brain/braintest/braintest.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ import (
// If a brain cannot be created without error, new should call t.Fatal.
func Test(ctx context.Context, t *testing.T, new func(context.Context) brain.Brain) {
t.Run("speak", testSpeak(ctx, new(ctx)))
t.Run("forgetMessage", testForgetMessage(ctx, new(ctx)))
t.Run("forgetDuring", testForgetDuring(ctx, new(ctx)))
t.Run("forgetMessage", testForget(ctx, new(ctx)))
t.Run("combinatoric", testCombinatoric(ctx, new(ctx)))
}

Expand Down Expand Up @@ -182,11 +181,11 @@ func testSpeak(ctx context.Context, br brain.Brain) func(t *testing.T) {
}
}

// testForgetMessage tests that a brain can forget messages by ID.
func testForgetMessage(ctx context.Context, br brain.Brain) func(t *testing.T) {
// testForget tests that a brain can forget messages by ID.
func testForget(ctx context.Context, br brain.Brain) func(t *testing.T) {
return func(t *testing.T) {
learn(ctx, t, br)
if err := br.ForgetMessage(ctx, "kessoku", messages[0].ID); err != nil {
if err := br.Forget(ctx, "kessoku", messages[0].ID); err != nil {
t.Errorf("failed to forget first message: %v", err)
}
got := speak(ctx, t, br, "kessoku", "", 2048)
Expand Down Expand Up @@ -233,51 +232,6 @@ func testForgetMessage(ctx context.Context, br brain.Brain) func(t *testing.T) {
}
}

// testForgetDuring tests that a brain can forget messages in a time range.
func testForgetDuring(ctx context.Context, br brain.Brain) func(t *testing.T) {
return func(t *testing.T) {
learn(ctx, t, br)
if err := br.ForgetDuring(ctx, "kessoku", time.Unix(1, 0).Add(-time.Millisecond), time.Unix(2, 0).Add(time.Millisecond)); err != nil {
t.Errorf("failed to forget: %v", err)
}
got := speak(ctx, t, br, "kessoku", "", 2048)
want := map[string]struct{}{
"1#member bocchi": {},
"1 4#member bocchi": {},
"1 4#member kita": {},
"4#member kita": {},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("wrong messages after forgetting (+got/-want):\n%s", diff)
}
got = speak(ctx, t, br, "sickhack", "", 2048)
want = map[string]struct{}{
"5#member bocchi": {},
"5 6#member bocchi": {},
"5 7#member bocchi": {},
"5 8#member bocchi": {},
"5 6#member ryou": {},
"6#member ryou": {},
"6 7#member ryou": {},
"6 8#member ryou": {},
"5 7#member nijika": {},
"6 7#member nijika": {},
"7#member nijika": {},
"7 8#member nijika": {},
"5 8#member kita": {},
"6 8#member kita": {},
"7 8#member kita": {},
"8#member kita": {},
"9#manager seika": {},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("wrong spoken messages for sickhack (+got/-want):\n%s", diff)
}
}
}

// TODO(zeph): testForgetUser

// testCombinatoric tests that chains can generate even with substantial
// overlap in learned material.
func testCombinatoric(ctx context.Context, br brain.Brain) func(t *testing.T) {
Expand Down
44 changes: 1 addition & 43 deletions brain/braintest/braintest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,55 +66,13 @@ func (m *membrain) forgetIDLocked(tag, id string) {
}
}

func (m *membrain) Forget(ctx context.Context, tag string, tuples []brain.Tuple) error {
m.mu.Lock()
defer m.mu.Unlock()
for _, tup := range tuples {
p := strings.Join(tup.Prefix, "\xff")
u := m.tups[tag][p]
k := slices.IndexFunc(u, func(v [2]string) bool { return v[1] == tup.Suffix })
if k < 0 {
continue
}
u[k], u[len(u)-1] = u[len(u)-1], u[k]
m.tups[tag][p] = u[:len(u)-1]
}
return nil
}

func (m *membrain) ForgetMessage(ctx context.Context, tag, id string) error {
func (m *membrain) Forget(ctx context.Context, tag, id string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.forgetIDLocked(tag, id)
return nil
}

func (m *membrain) ForgetDuring(ctx context.Context, tag string, since, before time.Time) error {
m.mu.Lock()
defer m.mu.Unlock()
s, b := since.UnixNano(), before.UnixNano()
for tm, u := range m.tms[tag] {
if tm < s || tm > b {
continue
}
for _, v := range u {
m.forgetIDLocked(tag, v)
}
delete(m.tms[tag], tm) // yea i modify the map during iteration, yea i'm cool
}
return nil
}

func (m *membrain) ForgetUser(ctx context.Context, user *userhash.Hash) error {
m.mu.Lock()
defer m.mu.Unlock()
for _, v := range m.users[*user] {
m.forgetIDLocked(v[0], v[1])
}
delete(m.users, *user)
return nil
}

func (m *membrain) Speak(ctx context.Context, tag string, prompt []string, w *brain.Builder) error {
m.mu.Lock()
defer m.mu.Unlock()
Expand Down
91 changes: 2 additions & 89 deletions brain/kvbrain/forget.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"slices"
"sync"
"time"

"github.com/zephyrtronium/robot/userhash"
)
Expand Down Expand Up @@ -53,45 +52,9 @@ func (p *past) findID(id string) [][]byte {
return nil
}

// findDuring finds all knowledge keys of messages recorded with timestamps in
// the given time span.
func (p *past) findDuring(since, before int64) [][]byte {
r := make([][]byte, 0, 64)
p.mu.Lock()
defer p.mu.Unlock()
for k, v := range p.time {
if since <= v && v <= before {
keys := p.key[k]
r = slices.Grow(r, len(keys))
for _, v := range keys {
r = append(r, bytes.Clone(v))
}
}
}
return r
}

// findUser finds all knowledge keys of messages recorded from a given user
// since a timestamp.
func (p *past) findUser(user userhash.Hash) [][]byte {
r := make([][]byte, 0, 64)
p.mu.Lock()
defer p.mu.Unlock()
for k, v := range p.user {
if v == user {
keys := p.key[k]
r = slices.Grow(r, len(keys))
for _, v := range keys {
r = append(r, bytes.Clone(v))
}
}
}
return r
}

// ForgetMessage forgets everything learned from a single given message.
// Forget forgets everything learned from a single given message.
// If nothing has been learned from the message, it should be ignored.
func (br *Brain) ForgetMessage(ctx context.Context, tag, id string) error {
func (br *Brain) Forget(ctx context.Context, tag, id string) error {
past, _ := br.past.Load(tag)
if past == nil {
return nil
Expand All @@ -111,53 +74,3 @@ func (br *Brain) ForgetMessage(ctx context.Context, tag, id string) error {
}
return nil
}

// ForgetDuring forgets all messages learned in the given time span.
func (br *Brain) ForgetDuring(ctx context.Context, tag string, since, before time.Time) error {
past, _ := br.past.Load(tag)
if past == nil {
return nil
}
keys := past.findDuring(since.UnixNano(), before.UnixNano())
batch := br.knowledge.NewWriteBatch()
defer batch.Cancel()
for _, key := range keys {
err := batch.Delete(key)
if err != nil {
return err
}
}
err := batch.Flush()
if err != nil {
return fmt.Errorf("couldn't commit deleting between times %v and %v: %w", since, before, err)
}
return nil
}

// ForgetUser forgets all messages associated with a userhash.
func (br *Brain) ForgetUser(ctx context.Context, user *userhash.Hash) error {
var rangeErr error
u := *user
br.past.Range(func(tag string, past *past) bool {
keys := past.findUser(u)
if len(keys) == 0 {
return true
}
batch := br.knowledge.NewWriteBatch()
defer batch.Cancel()
for _, key := range keys {
err := batch.Delete(key)
if err != nil {
rangeErr = err
return false
}
}
err := batch.Flush()
if err != nil {
rangeErr = fmt.Errorf("couldn't commit deleting messages by user: %w", err)
return false
}
return false
})
return rangeErr
}
Loading

0 comments on commit dd538e3

Please sign in to comment.