Skip to content

Commit

Permalink
Special handling for O1 models (chat completion)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbanda committed Sep 17, 2024
1 parent 0298726 commit 1bdb8d3
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 40 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package io.cequence.openaiscala.service

import akka.NotUsed
import akka.stream.scaladsl.Source
import io.cequence.openaiscala.domain.BaseMessage
import io.cequence.openaiscala.domain.response.ChatCompletionChunkResponse
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings

/**
 * Wraps an [[OpenAIChatCompletionStreamedServiceExtra]] so that every streamed
 * chat-completion call first passes its messages and settings through the
 * supplied conversion functions before delegating to the underlying service.
 */
class OpenAIChatCompletionStreamedConversionAdapter {

  /**
   * @param service            the underlying streamed chat-completion service
   * @param messagesConversion applied to the messages before each call
   * @param settingsConversion applied to the settings before each call
   * @return a service that delegates to `service` after both conversions
   */
  def apply(
    service: OpenAIChatCompletionStreamedServiceExtra,
    messagesConversion: Seq[BaseMessage] => Seq[BaseMessage],
    settingsConversion: CreateChatCompletionSettings => CreateChatCompletionSettings
  ): OpenAIChatCompletionStreamedServiceExtra =
    new OpenAIChatCompletionStreamedServiceExtra {

      // delegate with converted inputs
      override def createChatCompletionStreamed(
        messages: Seq[BaseMessage],
        settings: CreateChatCompletionSettings
      ): Source[ChatCompletionChunkResponse, NotUsed] =
        service.createChatCompletionStreamed(
          messagesConversion(messages),
          settingsConversion(settings)
        )

      // closing the adapter closes the wrapped service
      override def close(): Unit =
        service.close()
    }
}
5 changes: 0 additions & 5 deletions openai-client/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,3 @@ libraryDependencies ++= Seq(
"org.scalatest" %% "scalatest" % "3.2.18" % Test,
"org.scalamock" %% "scalamock" % scalaMock % Test
)

//libraryDependencies ++= Seq(
// "com.typesafe.scala-logging" %% "scala-logging" % "3.9.5",
// "ch.qos.logback" % "logback-classic" % "1.4.7"
//)
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
package io.cequence.openaiscala.service.impl

import io.cequence.openaiscala.JsonFormats._
import io.cequence.openaiscala.domain.BaseMessage
import io.cequence.openaiscala.domain.{BaseMessage, ModelId}
import io.cequence.openaiscala.domain.response._
import io.cequence.openaiscala.domain.settings._
import io.cequence.openaiscala.service.adapter.{
ChatCompletionSettingsConversions,
MessageConversions
}
import io.cequence.openaiscala.service.{OpenAIChatCompletionService, OpenAIServiceConsts}
import io.cequence.wsclient.JsonUtil
import io.cequence.wsclient.ResponseImplicits._
import io.cequence.wsclient.service.WSClient
import io.cequence.wsclient.service.WSClientWithEngineTypes.WSClientWithEngine
import play.api.libs.json.{JsValue, Json}
import play.api.libs.json.{JsObject, JsValue, Json}

import scala.concurrent.Future

Expand Down Expand Up @@ -45,76 +50,114 @@ trait ChatCompletionBodyMaker {

this: WSClient =>

// Model ids that require O1-specific handling in createBodyParamsForChatCompletion
// (system-message rewriting and settings conversion).
private val o1Models = Set(
ModelId.o1_preview,
ModelId.o1_preview_2024_09_12,
ModelId.o1_mini,
ModelId.o1_mini_2024_09_12
)

/**
 * Assembles the JSON body params for a chat-completion request.
 *
 * O1 models need some special treatment (revisit this later):
 *  - system messages are rewritten as user messages
 *    (via [[MessageConversions.systemToUserMessages]]), and
 *  - unsupported settings are normalized
 *    (via [[ChatCompletionSettingsConversions.o1Specific]]).
 *
 * @param messagesAux chat messages; must be non-empty
 * @param settings    the requested chat-completion settings
 * @param stream      whether a streamed response is requested
 * @return (param, optional JSON value) pairs for the request body
 */
protected def createBodyParamsForChatCompletion(
  messagesAux: Seq[BaseMessage],
  settings: CreateChatCompletionSettings,
  stream: Boolean
): Seq[(Param, Option[JsValue])] = {
  assert(messagesAux.nonEmpty, "At least one message expected.")

  // hoisted: both conversions below key off the same model check
  val isO1Model = o1Models.contains(settings.model)

  // O1 models don't accept system messages - convert them to user messages
  val messagesFinal =
    if (isO1Model) MessageConversions.systemToUserMessages(messagesAux)
    else messagesAux

  val messageJsons = messagesFinal.map(Json.toJson(_)(messageWrites))

  // O1 models reject several standard settings - normalize them
  val settingsFinal =
    if (isO1Model) ChatCompletionSettingsConversions.o1Specific(settings)
    else settings

  jsonBodyParams(
    Param.messages -> Some(messageJsons),
    Param.model -> Some(settingsFinal.model),
    Param.temperature -> settingsFinal.temperature,
    Param.top_p -> settingsFinal.top_p,
    Param.n -> settingsFinal.n,
    Param.stream -> Some(stream),
    Param.stop -> {
      // single stop sequence is sent as a scalar, multiple as an array
      settingsFinal.stop.size match {
        case 0 => None
        case 1 => Some(settingsFinal.stop.head)
        case _ => Some(settingsFinal.stop)
      }
    },
    Param.max_tokens -> settingsFinal.max_tokens,
    Param.presence_penalty -> settingsFinal.presence_penalty,
    Param.frequency_penalty -> settingsFinal.frequency_penalty,
    Param.logit_bias -> {
      if (settingsFinal.logit_bias.isEmpty) None else Some(settingsFinal.logit_bias)
    },
    Param.user -> settingsFinal.user,
    Param.logprobs -> settingsFinal.logprobs,
    Param.top_logprobs -> settingsFinal.top_logprobs,
    Param.seed -> settingsFinal.seed,
    Param.response_format -> {
      settingsFinal.response_format_type.map {
        (formatType: ChatCompletionResponseFormatType) =>
          if (formatType != ChatCompletionResponseFormatType.json_schema)
            Map("type" -> formatType.toString)
          else
            handleJsonSchema(settingsFinal)
      }
    },
    Param.extra_params -> {
      if (settingsFinal.extra_params.nonEmpty) Some(settingsFinal.extra_params) else None
    }
  )
}

private def handleJsonSchema(
settings: CreateChatCompletionSettings
): Map[String, Any] =
settings.jsonSchema.map { case JsonSchema(name, strict, structure) =>
val adjustedSchema = if (strict) {
settings.jsonSchema.map { case JsonSchemaDef(name, strict, structure) =>
val schemaMap: Map[String, Any] = structure match {
case Left(schema) =>
val json = Json.toJson(schema).as[JsObject]
JsonUtil.toValueMap(json)

case Right(schema) => schema
}

val adjustedSchema: Map[String, Any] = if (strict) {
// set "additionalProperties" -> false on "object" types if strict
def addFlagAux(map: Map[String, Any]): Map[String, Any] = {
val newMap = map.map { case (key, value) =>
val newValue = value match {
case obj: Map[String, Any] => addFlagAux(obj)
case other => other
val unwrappedValue = value match {
case Some(value) => value
case other => other
}

val newValue = unwrappedValue match {
case obj: Map[String, Any] =>
addFlagAux(obj)

case other =>
other
}
key -> newValue
}

if (map.get("type").contains("object"))
if (Seq("object", Some("object")).contains(map.getOrElse("type", ""))) {
newMap + ("additionalProperties" -> false)
else
} else
newMap
}

addFlagAux(structure)
} else structure
addFlagAux(schemaMap)
} else schemaMap

Map(
"type" -> "json_schema",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package io.cequence.openaiscala.service.adapter

import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import org.slf4j.LoggerFactory

/**
 * Conversions that rewrite [[CreateChatCompletionSettings]] for models with
 * restricted parameter support (currently the O1 family).
 */
object ChatCompletionSettingsConversions {

  private val logger = LoggerFactory.getLogger(getClass)

  type SettingsConversion = CreateChatCompletionSettings => CreateChatCompletionSettings

  /**
   * One conversion rule: when `doConversion` holds for the current settings,
   * `convert` is applied and `loggingMessage` (if any) is emitted - as a
   * warning when `warning` is true, otherwise at debug level.
   */
  case class FieldConversionDef(
    doConversion: CreateChatCompletionSettings => Boolean,
    convert: CreateChatCompletionSettings => CreateChatCompletionSettings,
    loggingMessage: Option[String],
    warning: Boolean = false
  )

  /** Chains the given rules left-to-right into a single settings conversion. */
  def generic(
    fieldConversions: Seq[FieldConversionDef]
  ): SettingsConversion = { initialSettings =>
    fieldConversions.foldLeft(initialSettings) { (current, rule) =>
      if (!rule.doConversion(current)) current
      else {
        rule.loggingMessage.foreach { message =>
          if (rule.warning) logger.warn(message) else logger.debug(message)
        }
        rule.convert(current)
      }
    }
  }

  // Rules for the O1 model family: max_tokens is renamed, the remaining
  // unsupported fields are forced back to their defaults (with a warning).
  private val o1Conversions = Seq(
    FieldConversionDef(
      doConversion = _.max_tokens.isDefined,
      convert = settings =>
        settings.copy(
          max_tokens = None,
          extra_params =
            settings.extra_params + ("max_completion_tokens" -> settings.max_tokens.get)
        ),
      loggingMessage =
        Some("O1 models don't support max_tokens, converting to max_completion_tokens")
    ),
    FieldConversionDef(
      doConversion = _.temperature.exists(_ != 1),
      convert = _.copy(temperature = Some(1d)),
      loggingMessage = Some(
        "O1 models don't support temperature values other than the default of 1, converting to 1."
      ),
      warning = true
    ),
    FieldConversionDef(
      doConversion = _.top_p.exists(_ != 1),
      convert = _.copy(top_p = Some(1d)),
      loggingMessage = Some(
        "O1 models don't support top p values other than the default of 1, converting to 1."
      ),
      warning = true
    ),
    FieldConversionDef(
      doConversion = _.presence_penalty.exists(_ != 0),
      convert = _.copy(presence_penalty = Some(0d)),
      loggingMessage = Some(
        "O1 models don't support presence penalty values other than the default of 0, converting to 0."
      ),
      warning = true
    ),
    FieldConversionDef(
      doConversion = _.frequency_penalty.exists(_ != 0),
      convert = _.copy(frequency_penalty = Some(0d)),
      loggingMessage = Some(
        "O1 models don't support frequency penalty values other than the default of 0, converting to 0."
      ),
      warning = true
    )
  )

  /** The combined O1-specific settings conversion. */
  val o1Specific: SettingsConversion = generic(o1Conversions)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package io.cequence.openaiscala.service.adapter

import io.cequence.openaiscala.domain.{BaseMessage, SystemMessage, UserMessage}
import org.slf4j.LoggerFactory

/** Message-sequence conversions for models with restricted message support. */
object MessageConversions {

  private val logger = LoggerFactory.getLogger(getClass)

  type MessageConversion = Seq[BaseMessage] => Seq[BaseMessage]

  /**
   * Rewrites system messages as "System: ..." user messages (for models that
   * don't support the system role), then merges each run of consecutive user
   * messages into a single one, since two user messages may not follow each
   * other.
   */
  val systemToUserMessages: MessageConversion =
    (messages: Seq[BaseMessage]) => {
      val withoutSystem = messages.map {
        case SystemMessage(content, _) =>
          logger.warn(s"System message found but not supported by an underlying model. Converting to a user message instead: '${content}'")
          UserMessage(s"System: ${content}")

        case other => other
      }

      // collapse consecutive user messages by joining their contents
      withoutSystem.foldLeft(Seq.empty[BaseMessage]) { (acc, message) =>
        message match {
          case UserMessage(content, _) if acc.nonEmpty =>
            acc.last match {
              case UserMessage(previousContent, _) =>
                acc.init :+ UserMessage(previousContent + "\n" + content)
              case _ =>
                // NOTE(review): rebuilding here drops UserMessage's second
                // field (mirrors the original behavior) - confirm intended
                acc :+ UserMessage(content)
            }

          case _ =>
            acc :+ message
        }
      }
    }
}

0 comments on commit 1bdb8d3

Please sign in to comment.