Migration of the count tokens module to the new message hierarchy + refactoring + jtokkit lib version update
peterbanda committed Nov 29, 2023
1 parent 06816e2 commit a59cffe
Showing 6 changed files with 224 additions and 108 deletions.
2 changes: 1 addition & 1 deletion openai-count-tokens/README.md
@@ -28,7 +28,7 @@ or to *pom.xml* (if you use maven)
## Usage

```scala
-import io.cequence.openaiscala.service.OpenAICountTokensService
+import io.cequence.openaiscala.service.OpenAICountTokensHelper
import io.cequence.openaiscala.domain.{ChatRole, FunMessageSpec, FunctionSpec}

val messages: Seq[FunMessageSpec] = ??? // messages to be sent to OpenAI
// ... (remaining unchanged lines collapsed)
```
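
Since the hunk above is truncated, here is a minimal sketch of how the renamed helper is meant to be used after this migration. It assumes the new message hierarchy (`SystemMessage` and `UserMessage` taking their content as the first constructor argument) from `io.cequence.openaiscala.domain`; only `countMessageTokens` and `countFunMessageTokens` come from the new trait added in this commit.

```scala
import io.cequence.openaiscala.domain.{FunctionSpec, SystemMessage, UserMessage}
import io.cequence.openaiscala.service.OpenAICountTokensHelper

object CountTokensExample extends App with OpenAICountTokensHelper {

  val messages = Seq(
    SystemMessage("You are a helpful assistant."),
    UserMessage("How many tokens will this chat cost?")
  )

  // plain chat messages: per-message overhead + content tokens + a final +3 priming the reply
  println(countMessageTokens("gpt-4", messages))

  // chat messages plus function definitions (and an optionally forced function name)
  val functions: Seq[FunctionSpec] = ??? // function definitions to be sent to OpenAI
  println(countFunMessageTokens("gpt-4", messages, functions, responseFunctionName = None))
}
```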
9 changes: 1 addition & 8 deletions openai-count-tokens/build.sbt
@@ -1,12 +1,5 @@
-import sbt.Keys.test
-
name := "openai-scala-count-tokens"

description := "Module of OpenAI Scala client to count tokens before sending a request to ChatGPT"

-libraryDependencies ++= {
-  val jTokkitV = "0.5.0"
-  Seq(
-    "com.knuddels" % "jtokkit" % jTokkitV
-  )
-}
+libraryDependencies += "com.knuddels" % "jtokkit" % "0.6.1"
@@ -0,0 +1,160 @@
package io.cequence.openaiscala.service

import com.knuddels.jtokkit.Encodings
import com.knuddels.jtokkit.api.Encoding
import io.cequence.openaiscala.JsonFormats
import io.cequence.openaiscala.domain._
import play.api.libs.json.Json

// based on: https://jtokkit.knuddels.de/docs/getting-started/recipes/chatml
trait OpenAICountTokensHelper {

  private lazy val registry = Encodings.newLazyEncodingRegistry()

  def countMessageTokens(
    model: String,
    messages: Seq[BaseMessage]
  ): Int = {
    val encoding = registry.getEncodingForModel(model).orElseThrow
    val (tokensPerMessage, tokensPerName) = tokensPerMessageAndName(model)

    val sum = messages.map(countMessageTokensAux(tokensPerMessage, tokensPerName, encoding)).sum

    sum + 3 // every reply is primed with <|start|>assistant<|message|>
  }

  def countMessageTokens(
    model: String,
    message: BaseMessage
  ): Int = {
    val encoding = registry.getEncodingForModel(model).orElseThrow
    val (tokensPerMessage, tokensPerName) = tokensPerMessageAndName(model)

    countMessageTokensAux(tokensPerMessage, tokensPerName, encoding)(message)
  }

  private def countMessageTokensAux(
    tokensPerMessage: Int,
    tokensPerName: Int,
    encoding: Encoding)(
    message: BaseMessage
  ): Int = {
    tokensPerMessage +
      countContentAndExtra(encoding, message) +
      encoding.countTokens(message.role.toString) +
      message.nameOpt.map { name => encoding.countTokens(name) + tokensPerName }.getOrElse(0)
  }

  private def tokensPerMessageAndName(model: String) =
    model match {
      case x if x.startsWith("gpt-4") =>
        (3, 1)

      case x if x.startsWith("gpt-3.5-turbo") =>
        // every message follows <|start|>{role/name}\n{content}<|end|>\n
        // if there's a name, the role is omitted
        (4, -1)

      case _ =>
        // failover to (3, 1)
        (3, 1)
    }

  private def countContentAndExtra(
    encoding: Encoding,
    message: BaseMessage
  ): Int = {
    def count(s: String*) = s.map(encoding.countTokens).sum
    def countOpt(s: Option[String]) = s.map(count(_)).getOrElse(0)

    message match {
      case m: SystemMessage =>
        count(m.content)

      case m: UserMessage =>
        count(m.content)

      case m: UserSeqMessage =>
        val contents = m.content.map(Json.toJson(_)(JsonFormats.contentWrites).toString())
        count(contents: _*)

      case m: AssistantMessage =>
        count(m.content)

      case m: AssistantToolMessage =>
        val toolCallTokens = m.tool_calls.map { case (id, toolSpec) =>
          toolSpec match {
            case toolSpec: FunctionCallSpec =>
              count(
                id,
                toolSpec.name,
                toolSpec.arguments
              ) + 3 // plus extra three tokens per function/tool call
          }
        }

        toolCallTokens.sum + countOpt(m.content)

      case m: AssistantFunMessage =>
        val funCallTokens = m.function_call
          .map(c => count(c.name, c.arguments) + 3) // plus extra three tokens per function call
          .getOrElse(0)

        funCallTokens + countOpt(m.content)

      case m: ToolMessage =>
        count(m.tool_call_id) + countOpt(m.content)

      case m: FunMessage =>
        count(m.content)

      case m: MessageSpec =>
        count(m.content)
    }
  }

  def countFunMessageTokens(
    model: String,
    messages: Seq[BaseMessage],
    functions: Seq[FunctionSpec],
    responseFunctionName: Option[String]
  ): Int = {
    val encoding = registry.getEncodingForModel(model).orElseThrow
    val (tokensPerMessage, tokensPerName) = tokensPerMessageAndName(model)

    def countMessageTokens(message: BaseMessage) =
      countMessageTokensAux(tokensPerMessage, tokensPerName, encoding)(message)

    // the first system message gets a trailing newline appended before counting
    val messagesTokensCount = messages.foldLeft((false, 0)) {
      case ((paddedSystem, count), message) =>
        val (newPaddedFlag, paddedMessage) =
          if (message.role == ChatRole.System && !paddedSystem) {
            message match {
              case m: SystemMessage =>
                (true, m.copy(content = m.content + "\n"))
              case m: MessageSpec if m.role == ChatRole.System =>
                (true, m.copy(content = m.content + "\n"))
              case _ =>
                throw new IllegalArgumentException(s"Unexpected message: $message")
            }
          } else {
            (paddedSystem, message)
          }

        (newPaddedFlag, count + countMessageTokens(paddedMessage))
    }._2

    val functionsTokensCount = functionsTokensEstimate(encoding, functions)
    val systemRoleAdjustment = if (messages.exists(m => m.role == ChatRole.System)) -4 else 0
    val responseFunctionNameCount =
      responseFunctionName.map(name => encoding.countTokens(name) + 4).getOrElse(0)

    messagesTokensCount + functionsTokensCount + systemRoleAdjustment + responseFunctionNameCount + 3
  }

  private def functionsTokensEstimate(
    encoding: Encoding,
    functions: Seq[FunctionSpec]
  ): Int = {
    val promptDefinitions = OpenAIFunctionsImpl.formatFunctionDefinitions(functions)
    encoding.countTokens(promptDefinitions) + 9
  }
}
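
To make the arithmetic in `countFunMessageTokens` concrete, here is a hypothetical breakdown; the numbers below are illustrative only, not measured with a real encoding:

```scala
// messagesTokensCount        = 25   // sum over messages; first system message padded with "\n"
// functionsTokensCount       = 51   // countTokens(formatted function definitions) + 9
// systemRoleAdjustment       = -4   // a system message is present
// responseFunctionNameCount  =  0   // responseFunctionName is None
// total = 25 + 51 - 4 + 0 + 3 = 75  // the trailing +3 primes the assistant reply
```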

This file was deleted.

@@ -6,8 +6,9 @@
import scala.collection.mutable.ListBuffer
import scala.reflect.ClassTag

//rewritten from https://github.com/hmarr/openai-chat-tokens
+// TODO: consider using a json schema; also avoid using mutable data structures
object OpenAIFunctionsImpl {
-  def formatFunctionDefinitions(functions: List[FunctionSpec]): String = {
+  def formatFunctionDefinitions(functions: Seq[FunctionSpec]): String = {
    val lines = ListBuffer("namespace functions {", "")
    for (f: FunctionSpec <- functions) {
      if (f.description.isDefined) {
@@ -39,21 +40,26 @@ object OpenAIFunctionsImpl {
      case Some(r) => r.asInstanceOf[Seq[String]]
      case None => Seq.empty[String]
    }
+
    val lines = scala.collection.mutable.ArrayBuffer[String]()
+
    for ((name, param) <- properties) {
      val paramAsInstance = param.asInstanceOf[Map[String, Any]]
      paramAsInstance.get("description") match {
        case Some(v) if indent < 2 =>
          lines += s"// ${v}"
        case _ => ()
      }
+
      val paramType = formatType(paramAsInstance, indent)
+
      if (required.contains(name)) {
        lines += s"$name: $paramType,"
      } else {
        lines += s"$name?: $paramType,"
      }
    }
+
    lines.map(line => " " * indent + line).mkString("\n")
  }

@@ -63,6 +69,7 @@
  ): String = {
    implicit val ctMSA: ClassTag[Map[String, Any]] = ClassTag(classOf[Map[String, Any]])
    implicit val ctSS: ClassTag[Seq[String]] = ClassTag(classOf[Seq[String]])
+
    param.get("type") match {
      case Some("string") =>
        param.get("enum") match {
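
For context on what `formatFunctionDefinitions` produces (this commit only widens its signature to `Seq` and adds spacing): it renders the TypeScript-namespace format of hmarr/openai-chat-tokens, which the file header says it was rewritten from. The `FunctionSpec` field names and the exact rendering below are assumptions for illustration, not taken from this diff:

```scala
// Hypothetical function definition (field names assumed):
val getWeather = FunctionSpec(
  name = "get_weather",
  description = Some("Get the current weather"),
  parameters = Map(
    "type" -> "object",
    "properties" -> Map(
      "location" -> Map("type" -> "string", "description" -> "City name"),
      "unit" -> Map("type" -> "string", "enum" -> Seq("celsius", "fahrenheit"))
    ),
    "required" -> Seq("location")
  )
)

// OpenAIFunctionsImpl.formatFunctionDefinitions(Seq(getWeather)) should then
// yield roughly:
//
//   namespace functions {
//
//   // Get the current weather
//   type get_weather = (_: {
//   // City name
//   location: string,
//   unit?: "celsius" | "fahrenheit",
//   }) => any;
//
//   } // namespace functions
```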