Skip to content

Commit

Permalink
Merge pull request #8 from pkoukk/fix_0418
Browse files Browse the repository at this point in the history
fix allowed and disallowed token usage
  • Loading branch information
pkoukk authored Apr 18, 2023
2 parents 78cbd36 + c87e1b1 commit 0a14607
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 19 deletions.
40 changes: 22 additions & 18 deletions tiktoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ func GetEncoding(encodingName string) (*Tiktoken, error) {
if err != nil {
return nil, err
}
specialTokensSet := map[string]bool{}
specialTokensSet := map[string]any{}
for k := range enc.SpecialTokens {
specialTokensSet[k] = true
}
return &Tiktoken{
bpe: pbe,
pbeEncoding: enc,
bpe: pbe,
pbeEncoding: enc,
specialTokensSet: specialTokensSet,
}, nil
}

Expand All @@ -50,7 +51,7 @@ func (t *Tiktoken) Encode(text string, allowedSpecial []string, disallowedSpecia
var allowedSpecialSet map[string]any
if len(allowedSpecial) == 0 {
allowedSpecialSet = map[string]any{}
} else if len(disallowedSpecial) == 1 && disallowedSpecial[0] == "all" {
} else if len(allowedSpecial) == 1 && allowedSpecial[0] == "all" {
allowedSpecialSet = t.specialTokensSet
} else {
allowedSpecialSet = map[string]any{}
Expand All @@ -59,19 +60,12 @@ func (t *Tiktoken) Encode(text string, allowedSpecial []string, disallowedSpecia
}
}

var disallowedSpecialSet map[string]any
if len(disallowedSpecial) == 0 || (len(disallowedSpecial) == 1 && disallowedSpecial[0] == "all") {
disallowedSpecialSet = map[string]any{}
for k1 := range t.specialTokensSet {
if _, ok := allowedSpecialSet[k1]; !ok {
disallowedSpecialSet[k1] = nil
}
}
} else {
disallowedSpecialSet = map[string]any{}
for _, v := range disallowedSpecial {
disallowedSpecialSet[v] = nil
}
disallowedSpecialSet := map[string]any{}
for _, v := range disallowedSpecial {
disallowedSpecialSet[v] = nil
}
if len(disallowedSpecial) == 1 && disallowedSpecial[0] == "all" {
disallowedSpecialSet = difference(t.specialTokensSet, allowedSpecialSet)
}

if len(disallowedSpecialSet) > 0 {
Expand All @@ -95,7 +89,7 @@ func (t *Tiktoken) SpecialTokenRegex(disallowedSpecialSet map[string]any) *regex
for k := range disallowedSpecialSet {
specialRegexStrs = append(specialRegexStrs, regexp.QuoteMeta(k))
}
specialRegex, _ := regexp2.Compile(strings.Join(specialRegexStrs, "|"), regexp2.None)
specialRegex := regexp2.MustCompile(strings.Join(specialRegexStrs, "|"), regexp2.None)
return specialRegex
}

Expand All @@ -107,3 +101,13 @@ func findRegex2StringMatch(text string, reg *regexp2.Regexp) string {

return m.String()
}

func difference(setA, setB map[string]any) map[string]any {
result := make(map[string]any)
for k := range setA {
if _, ok := setB[k]; !ok {
result[k] = true
}
}
return result
}
17 changes: 16 additions & 1 deletion tiktoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,25 @@ func TestEncoding(t *testing.T) {
ass := assert.New(t)
enc, err := GetEncoding("cl100k_base")
ass.Nil(err, "Encoding init should not be nil")
tokens := enc.Encode("hello world!你好,世界!", nil, nil)
tokens := enc.Encode("hello world!你好,世界!", []string{"all"}, []string{"all"})
// these tokens are converted from the original python code
sourceTokens := []int{15339, 1917, 0, 57668, 53901, 3922, 3574, 244, 98220, 6447}
ass.ElementsMatch(sourceTokens, tokens, "Encoding should be equal")

tokens = enc.Encode("hello <|endoftext|>", []string{"<|endoftext|>"}, nil)
sourceTokens = []int{15339, 220, 100257}
ass.ElementsMatch(sourceTokens, tokens, "Encoding should be equal")

tokens = enc.Encode("hello <|endoftext|>", []string{"<|endoftext|>"}, []string{"all"})
sourceTokens = []int{15339, 220, 100257}
ass.ElementsMatch(sourceTokens, tokens, "Encoding should be equal")

ass.Panics(func() {
tokens = enc.Encode("hello <|endoftext|><|endofprompt|>", []string{"<|endoftext|>"}, []string{"all"})
})
ass.Panics(func() {
tokens = enc.Encode("hello <|endoftext|>", []string{"<|endoftext|>"}, []string{"<|endoftext|>"})
})
}

func TestDecoding(t *testing.T) {
Expand Down

0 comments on commit 0a14607

Please sign in to comment.