-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Special handling for O1 models (chat completion)
- Loading branch information
1 parent
0298726
commit 1bdb8d3
Showing
5 changed files
with
227 additions
and
40 deletions.
There are no files selected for viewing
39 changes: 39 additions & 0 deletions
39
...scala/io/cequence/openaiscala/service/OpenAIChatCompletionStreamedConversionAdapter.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package io.cequence.openaiscala.service | ||
|
||
import akka.NotUsed | ||
import akka.stream.scaladsl.Source | ||
import io.cequence.openaiscala.domain.BaseMessage | ||
import io.cequence.openaiscala.domain.response.ChatCompletionChunkResponse | ||
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings | ||
|
||
class OpenAIChatCompletionStreamedConversionAdapter { | ||
def apply( | ||
service: OpenAIChatCompletionStreamedServiceExtra, | ||
messagesConversion: Seq[BaseMessage] => Seq[BaseMessage], | ||
settingsConversion: CreateChatCompletionSettings => CreateChatCompletionSettings | ||
): OpenAIChatCompletionStreamedServiceExtra = | ||
new OpenAIChatCompletionStreamedConversionAdapterImpl( | ||
service, | ||
messagesConversion, | ||
settingsConversion | ||
) | ||
|
||
final private class OpenAIChatCompletionStreamedConversionAdapterImpl( | ||
underlying: OpenAIChatCompletionStreamedServiceExtra, | ||
messagesConversion: Seq[BaseMessage] => Seq[BaseMessage], | ||
settingsConversion: CreateChatCompletionSettings => CreateChatCompletionSettings | ||
) extends OpenAIChatCompletionStreamedServiceExtra { | ||
|
||
override def createChatCompletionStreamed( | ||
messages: Seq[BaseMessage], | ||
settings: CreateChatCompletionSettings | ||
): Source[ChatCompletionChunkResponse, NotUsed] = | ||
underlying.createChatCompletionStreamed( | ||
messagesConversion(messages), | ||
settingsConversion(settings) | ||
) | ||
|
||
override def close(): Unit = | ||
underlying.close() | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
...ain/scala/io/cequence/openaiscala/service/adapter/ChatCompletionSettingsConversions.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
package io.cequence.openaiscala.service.adapter | ||
|
||
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings | ||
import org.slf4j.LoggerFactory | ||
|
||
object ChatCompletionSettingsConversions { | ||
|
||
private val logger = LoggerFactory.getLogger(getClass) | ||
|
||
type SettingsConversion = CreateChatCompletionSettings => CreateChatCompletionSettings | ||
|
||
case class FieldConversionDef( | ||
doConversion: CreateChatCompletionSettings => Boolean, | ||
convert: CreateChatCompletionSettings => CreateChatCompletionSettings, | ||
loggingMessage: Option[String], | ||
warning: Boolean = false | ||
) | ||
|
||
def generic( | ||
fieldConversions: Seq[FieldConversionDef] | ||
): SettingsConversion = (settings: CreateChatCompletionSettings) => | ||
fieldConversions.foldLeft(settings) { | ||
case (acc, FieldConversionDef(isDefined, convert, loggingMessage, warning)) => | ||
if (isDefined(acc)) { | ||
loggingMessage.foreach(message => | ||
if (warning) logger.warn(message) else logger.debug(message) | ||
) | ||
convert(acc) | ||
} else acc | ||
} | ||
|
||
private val o1Conversions = Seq( | ||
// max tokens | ||
FieldConversionDef( | ||
_.max_tokens.isDefined, | ||
settings => | ||
settings.copy( | ||
max_tokens = None, | ||
extra_params = | ||
settings.extra_params + ("max_completion_tokens" -> settings.max_tokens.get) | ||
), | ||
Some("O1 models don't support max_tokens, converting to max_completion_tokens") | ||
), | ||
// temperature | ||
FieldConversionDef( | ||
settings => settings.temperature.isDefined && settings.temperature.get != 1, | ||
_.copy(temperature = Some(1d)), | ||
Some("O1 models don't support temperature values other than the default of 1, converting to 1."), | ||
warning = true | ||
), | ||
// top_p | ||
FieldConversionDef( | ||
settings => settings.top_p.isDefined && settings.top_p.get != 1, | ||
_.copy(top_p = Some(1d)), | ||
Some("O1 models don't support top p values other than the default of 1, converting to 1."), | ||
warning = true | ||
), | ||
// presence_penalty | ||
FieldConversionDef( | ||
settings => settings.presence_penalty.isDefined && settings.presence_penalty.get != 0, | ||
_.copy(presence_penalty = Some(0d)), | ||
Some("O1 models don't support presence penalty values other than the default of 0, converting to 0."), | ||
warning = true | ||
), | ||
// frequency_penalty | ||
FieldConversionDef( | ||
settings => settings.frequency_penalty.isDefined && settings.frequency_penalty.get != 0, | ||
_.copy(frequency_penalty = Some(0d)), | ||
Some("O1 models don't support frequency penalty values other than the default of 0, converting to 0."), | ||
warning = true | ||
) | ||
) | ||
|
||
val o1Specific: SettingsConversion = generic(o1Conversions) | ||
} |
35 changes: 35 additions & 0 deletions
35
openai-core/src/main/scala/io/cequence/openaiscala/service/adapter/MessageConversions.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
package io.cequence.openaiscala.service.adapter | ||
|
||
import io.cequence.openaiscala.domain.{BaseMessage, SystemMessage, UserMessage} | ||
import org.slf4j.LoggerFactory | ||
|
||
object MessageConversions { | ||
|
||
private val logger = LoggerFactory.getLogger(getClass) | ||
|
||
type MessageConversion = Seq[BaseMessage] => Seq[BaseMessage] | ||
|
||
val systemToUserMessages: MessageConversion = | ||
(messages: Seq[BaseMessage]) => { | ||
val nonSystemMessages = messages.map { | ||
case SystemMessage(content, _) => | ||
logger.warn(s"System message found but not supported by an underlying model. Converting to a user message instead: '${content}'") | ||
UserMessage(s"System: ${content}") | ||
|
||
case x: BaseMessage => x | ||
} | ||
|
||
// there cannot be two consecutive user messages, so we need to merge them | ||
nonSystemMessages.foldLeft(Seq.empty[BaseMessage]) { | ||
case (acc, UserMessage(content, _)) if acc.nonEmpty => | ||
acc.last match { | ||
case UserMessage(lastContent, _) => | ||
acc.init :+ UserMessage(lastContent + "\n" + content) | ||
case _ => | ||
acc :+ UserMessage(content) | ||
} | ||
|
||
case (acc, message) => acc :+ message | ||
} | ||
} | ||
} |