Commit
Migration of the count tokens module to the new message hierarchy + refactoring + jtokkit lib version update
1 parent 06816e2 · commit a59cffe
Showing 6 changed files with 224 additions and 108 deletions.
Build definition of the openai-scala-count-tokens module:

@@ -1,12 +1,5 @@
-import sbt.Keys.test
-
 name := "openai-scala-count-tokens"
 
 description := "Module of OpenAI Scala client to count tokens before sending a request to ChatGPT"
 
-libraryDependencies ++= {
-  val jTokkitV = "0.5.0"
-  Seq(
-    "com.knuddels" % "jtokkit" % jTokkitV
-  )
-}
+libraryDependencies += "com.knuddels" % "jtokkit" % "0.6.1"
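
For orientation, a downstream build would normally depend on this module rather than on jtokkit directly and pick up the updated jtokkit version transitively. A minimal sketch, assuming the io.cequence organisation used by the other openai-scala-client artifacts and a placeholder version string:

// build.sbt of a consuming project (sketch; organisation and version are assumptions,
// replace "x.y.z" with the actual released version)
libraryDependencies += "io.cequence" %% "openai-scala-count-tokens" % "x.y.z"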
...count-tokens/src/main/scala/io/cequence/openaiscala/service/OpenAICountTokensHelper.scala (160 additions, 0 deletions)
@@ -0,0 +1,160 @@
package io.cequence.openaiscala.service

import com.knuddels.jtokkit.Encodings
import com.knuddels.jtokkit.api.Encoding
import io.cequence.openaiscala.JsonFormats
import io.cequence.openaiscala.domain._
import play.api.libs.json.Json

// based on: https://jtokkit.knuddels.de/docs/getting-started/recipes/chatml
trait OpenAICountTokensHelper {

  private lazy val registry = Encodings.newLazyEncodingRegistry()

  // Estimates the prompt tokens consumed by a whole sequence of chat messages for the given model.
  def countMessageTokens(
    model: String,
    messages: Seq[BaseMessage]
  ) = {
    val encoding = registry.getEncodingForModel(model).orElseThrow
    val (tokensPerMessage, tokensPerName) = tokensPerMessageAndName(model)

    val sum = messages.map(countMessageTokensAux(tokensPerMessage, tokensPerName, encoding)).sum

    sum + 3 // every reply is primed with <|start|>assistant<|message|>
  }

  // Estimates the tokens attributable to a single message.
  def countMessageTokens(
    model: String,
    message: BaseMessage
  ) = {
    val encoding = registry.getEncodingForModel(model).orElseThrow
    val (tokensPerMessage, tokensPerName) = tokensPerMessageAndName(model)

    countMessageTokensAux(tokensPerMessage, tokensPerName, encoding)(message)
  }

  private def countMessageTokensAux(
    tokensPerMessage: Int,
    tokensPerName: Int,
    encoding: Encoding)(
    message: BaseMessage
  ) = {
    tokensPerMessage +
      countContentAndExtra(encoding, message) +
      encoding.countTokens(message.role.toString) +
      message.nameOpt.map { name => encoding.countTokens(name) + tokensPerName }.getOrElse(0)
  }

  private def tokensPerMessageAndName(model: String) =
    model match {
      case x if x.startsWith("gpt-4") =>
        (3, 1)

      case x if x.startsWith("gpt-3.5-turbo") =>
        // every message follows <|start|>{role/name}\n{content}<|end|>\n
        // if there's a name, the role is omitted
        (4, -1)

      case _ =>
        // failover to (3, 1)
        (3, 1)
    }

  private def countContentAndExtra(
    encoding: Encoding,
    message: BaseMessage
  ): Int = {
    def count(s: String*) = s.map(encoding.countTokens).sum
    def countOpt(s: Option[String]) = s.map(count(_)).getOrElse(0)

    message match {
      case m: SystemMessage =>
        count(m.content)

      case m: UserMessage =>
        count(m.content)

      case m: UserSeqMessage =>
        val contents = m.content.map(Json.toJson(_)(JsonFormats.contentWrites).toString())
        count(contents: _*)

      case m: AssistantMessage =>
        count(m.content)

      case m: AssistantToolMessage =>
        val toolCallTokens = m.tool_calls.map { case (id, toolSpec) =>
          toolSpec match {
            case toolSpec: FunctionCallSpec =>
              count(
                id,
                toolSpec.name,
                toolSpec.arguments
              ) + 3 // plus extra three tokens per function/tool call
          }
        }

        toolCallTokens.sum + countOpt(m.content)

      case m: AssistantFunMessage =>
        val funCallTokens = m.function_call
          .map(c => count(c.name, c.arguments) + 3 // plus extra three tokens per function call
          )
          .getOrElse(0)

        funCallTokens + countOpt(m.content)

      case m: ToolMessage =>
        count(m.tool_call_id) + countOpt(m.content)

      case m: FunMessage =>
        count(m.content)

      case m: MessageSpec =>
        count(m.content)
    }
  }

  // Estimates the prompt tokens for messages sent together with function definitions
  // (function calling); responseFunctionName accounts for forcing a specific function.
  def countFunMessageTokens(
    model: String,
    messages: Seq[BaseMessage],
    functions: Seq[FunctionSpec],
    responseFunctionName: Option[String]
  ): Int = {
    val encoding = registry.getEncodingForModel(model).orElseThrow
    val (tokensPerMessage, tokensPerName) = tokensPerMessageAndName(model)

    def countMessageTokens(message: BaseMessage) =
      countMessageTokensAux(tokensPerMessage, tokensPerName, encoding)(message)

    // The first system message is padded with a trailing newline before counting.
    val messagesTokensCount = messages.foldLeft((false, 0)) { case ((paddedSystem, count), message) =>
      val (newPaddedFlag, paddedMessage) =
        if (message.role == ChatRole.System && !paddedSystem) {
          message match {
            case m: SystemMessage =>
              (true, m.copy(content = m.content + "\n"))
            case m: MessageSpec if m.role == ChatRole.System =>
              (true, m.copy(content = m.content + "\n"))
            case _ =>
              throw new IllegalArgumentException(s"Unexpected message: $message")
          }
        } else {
          (paddedSystem, message)
        }

      (newPaddedFlag, count + countMessageTokens(paddedMessage))
    }._2

    val functionsTokensCount = functionsTokensEstimate(encoding, functions)
    val systemRoleAdjustment = if (messages.exists(m => m.role == ChatRole.System)) -4 else 0
    val responseFunctionNameCount =
      responseFunctionName.map(name => encoding.countTokens(name) + 4).getOrElse(0)

    messagesTokensCount + functionsTokensCount + systemRoleAdjustment + responseFunctionNameCount + 3
  }

  // Token estimate for the function definitions as they are serialized into the prompt.
  private def functionsTokensEstimate(
    encoding: Encoding,
    functions: Seq[FunctionSpec]
  ): Int = {
    val promptDefinitions = OpenAIFunctionsImpl.formatFunctionDefinitions(functions)
    encoding.countTokens(promptDefinitions) + 9
  }
}
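
Not part of the commit, but a minimal usage sketch of the new trait: a hypothetical object mixes in OpenAICountTokensHelper and estimates the prompt size of a short conversation. It assumes SystemMessage and UserMessage are case classes whose first argument is the message content, as suggested by the pattern matches above.

import io.cequence.openaiscala.domain.{BaseMessage, SystemMessage, UserMessage}
import io.cequence.openaiscala.service.OpenAICountTokensHelper

// Hypothetical example object, not part of the repository.
object CountTokensExample extends OpenAICountTokensHelper {
  def main(args: Array[String]): Unit = {
    // Assumes the message case classes accept the content as their first argument.
    val messages: Seq[BaseMessage] = Seq(
      SystemMessage("You are a helpful assistant."),
      UserMessage("How many tokens will this conversation consume?")
    )

    // Tokens the whole chat prompt would use for a gpt-4 request,
    // including the 3 tokens that prime the assistant reply.
    val promptTokens = countMessageTokens("gpt-4", messages)

    // Tokens attributable to a single message.
    val singleMessageTokens = countMessageTokens("gpt-4", messages.head)

    println(s"prompt: $promptTokens tokens, first message: $singleMessageTokens tokens")
  }
}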
...ount-tokens/src/main/scala/io/cequence/openaiscala/service/OpenAICountTokensService.scala (0 additions, 70 deletions)
This file was deleted.