Skip to content

Commit

Permalink
Openai: use message attachment as GPT context
Browse files Browse the repository at this point in the history
  • Loading branch information
brainexe committed Nov 3, 2023
1 parent f29110e commit 8bbb1a0
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 7 deletions.
32 changes: 32 additions & 0 deletions bot/interaction.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package bot

import (
"bytes"
"strings"
"sync"

"github.com/innogames/slack-bot/v2/bot/msg"
Expand All @@ -18,12 +20,18 @@ func (b *Bot) handleEvent(eventsAPIEvent slackevents.EventsAPIEvent) {
switch eventsAPIEvent.Type {
case slackevents.CallbackEvent:
innerEvent := eventsAPIEvent.InnerEvent

switch ev := innerEvent.Data.(type) {
case *slackevents.MessageEvent:
if ev.SubType == "message_changed" {
// don't listen to edited messages
return
}

if len(ev.Files) > 0 {
ev.Text += b.loadFileContent(ev)
}

message := &slack.MessageEvent{
Msg: slack.Msg{
Text: ev.Text,
Expand Down Expand Up @@ -51,6 +59,30 @@ func (b *Bot) handleEvent(eventsAPIEvent slackevents.EventsAPIEvent) {
}
}

func (b *Bot) loadFileContent(event *slackevents.MessageEvent) string {
response := ""

for _, file := range event.Files {
if !strings.HasPrefix(file.Mimetype, "text/") {
log.Infof("Can't load file %s: mimetype is %s", file.Name, file.Mimetype)
continue
}

var downloadedText bytes.Buffer
log.Infof("Downloading message attachment file %s", file.Name)

err := b.slackClient.Client.GetFile(file.URLPrivate, &downloadedText)
if err != nil {
log.Errorf("Failed to download file %s: %s", file.URLPrivate, err.Error())
continue
}

response += "\n" + downloadedText.String()
}

return response
}

func (b *Bot) handleInteraction(payload slack.InteractionCallback) bool {
if !b.isUserActionAllowed(payload.User.ID) {
log.Warnf("User %s tried to execute a command", payload.User.ID)
Expand Down
7 changes: 4 additions & 3 deletions bot/stats/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/innogames/slack-bot/v2/bot/storage"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/constraints"
)

const collection = "stats"
Expand All @@ -28,9 +29,9 @@ func IncreaseOne(key string) {
}

// Increase is increasing the stats counter
func Increase(key string, count uint) {
func Increase[T constraints.Signed](key string, count T) {
storage.Atomic(func() {
var value uint
var value T
_ = storage.Read(collection, key, &value)

value += count
Expand All @@ -42,7 +43,7 @@ func Increase(key string, count uint) {
}

// Set the stats to a specific value
func Set(key string, value uint) {
func Set[T constraints.Signed](key string, value T) {
if err := storage.Write(collection, key, value); err != nil {
log.Warn(errors.Wrap(err, "error while set stats"))
}
Expand Down
6 changes: 6 additions & 0 deletions command/openai/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
"github.com/innogames/slack-bot/v2/bot/config"
"github.com/innogames/slack-bot/v2/bot/matcher"
"github.com/innogames/slack-bot/v2/bot/msg"
"github.com/innogames/slack-bot/v2/bot/stats"
"github.com/innogames/slack-bot/v2/bot/storage"
"github.com/innogames/slack-bot/v2/bot/util"

log "github.com/sirupsen/logrus"
"github.com/slack-go/slack"
)
Expand Down Expand Up @@ -235,6 +237,10 @@ func (c *chatGPTCommand) callAndStore(messages []ChatMessage, storageIdentifier
log.Warnf("Error while storing openai history: %s", err)
}

stats.IncreaseOne("openai_calls")
stats.Increase("openai_input_tokens", inputTokens)
stats.Increase("openai_output_tokens", estimateTokensForMessage(responseText.String()))

log.Infof(
"Openai %s call took %s with %d sub messages (%d tokens). Message: '%s'. Response: '%s'",
c.cfg.Model,
Expand Down
6 changes: 3 additions & 3 deletions command/openai/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func truncateMessages(model string, inputMessages []ChatMessage) ([]ChatMessage,
truncatedMessages := 0
maxTokens := getMaxTokensForModel(model)
for _, message := range inputMessages {
tokens := estimateTokensForMessage(message)
tokens := estimateTokensForMessage(message.Content)

if currentTokens+tokens >= maxTokens {
truncatedMessages++
Expand All @@ -49,8 +49,8 @@ func getMaxTokensForModel(model string) int {
return 4000
}

func estimateTokensForMessage(message ChatMessage) int {
func estimateTokensForMessage(message string) int {
// to lower the dependency to heavy external libs we use the rule of thumbs which is totally fine here
// https://platform.openai.com/tokenizer
return len(message.Content) / 4
return len(message) / 4
}
2 changes: 1 addition & 1 deletion command/openai/tokens_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestTruncate(t *testing.T) {

func TestCountTokens(t *testing.T) {
t.Run("Count", func(t *testing.T) {
actual := estimateTokensForMessage(ChatMessage{Content: "hello you!"})
actual := estimateTokensForMessage("hello you!")
assert.Equal(t, 2, actual)
})
}

0 comments on commit 8bbb1a0

Please sign in to comment.