Chat completion provider refactored
peterbanda committed Sep 18, 2024
1 parent 51adcbf commit f6a9d5f
Showing 18 changed files with 63 additions and 102 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -5,7 +5,7 @@ This is a no-nonsense async Scala client for OpenAI API supporting all the avail

 * **Models**: [listModels](https://platform.openai.com/docs/api-reference/models/list), and [retrieveModel](https://platform.openai.com/docs/api-reference/models/retrieve)
 * **Completions**: [createCompletion](https://platform.openai.com/docs/api-reference/completions/create)
-* **Chat Completions**: [createChatCompletion](https://platform.openai.com/docs/api-reference/chat/create) (also with GPT vision support!), [createChatFunCompletion](https://platform.openai.com/docs/api-reference/chat/create) (deprecated), and [createChatToolCompletion](https://platform.openai.com/docs/api-reference/chat/create)
+* **Chat Completions**: [createChatCompletion](https://platform.openai.com/docs/api-reference/chat/create) (also with JSON schema support 🔥), [createChatFunCompletion](https://platform.openai.com/docs/api-reference/chat/create) (deprecated), and [createChatToolCompletion](https://platform.openai.com/docs/api-reference/chat/create)
 * **Edits**: [createEdit](https://platform.openai.com/docs/api-reference/edits/create) (deprecated)
 * **Images**: [createImage](https://platform.openai.com/docs/api-reference/images/create), [createImageEdit](https://platform.openai.com/docs/api-reference/images/create-edit), and [createImageVariation](https://platform.openai.com/docs/api-reference/images/create-variation)
 * **Embeddings**: [createEmbeddings](https://platform.openai.com/docs/api-reference/embeddings/create)
@@ -33,10 +33,11 @@ In addition to the OpenAI API, this library also supports API-compatible provide
 - [Azure AI](https://azure.microsoft.com/en-us/products/ai-studio) - cloud-based, offers a vast selection of open-source models
 - [Anthropic](https://www.anthropic.com/api) - cloud-based, a major competitor to OpenAI, features proprietary/closed-source models such as Claude3 - Haiku, Sonnet, and Opus
 - [Google Vertex AI](https://cloud.google.com/vertex-ai) (🔥 **New**) - cloud-based, features proprietary/closed-source models such as Gemini 1.5 Pro and flash
-- [Groq](https://wow.groq.com/) - cloud-based provider, known for its super-fast inference with LPUs
+- [Groq](https://wow.groq.com/) - cloud-based provider, known for its superfast inference with LPUs
 - [Fireworks AI](https://fireworks.ai/) - cloud-based provider
 - [OctoAI](https://octo.ai/) - cloud-based provider
 - [TogetherAI](https://www.together.ai/) (🔥 **New**) - cloud-based provider
+- [Cerebras](https://cerebras.ai/) (🔥 **New**) - cloud-based provider, superfast (akin to Groq)
 - [Mistral](https://mistral.ai/) (🔥 **New**) - cloud-based, leading open-source LLM company
 - [Ollama](https://ollama.com/) - runs locally, serves as an umbrella for open-source LLMs including LLaMA3, dbrx, and Command-R
 - [FastChat](https://github.com/lm-sys/FastChat) - runs locally, serves as an umbrella for open-source LLMs such as Vicuna, Alpaca, and FastChat-T5
@@ -124,6 +124,7 @@ final case class FunMessage(

 /**
  * Deprecation warning: Use typed Message(s), such as SystemMessage, UserMessage, instead.
+ * Will be dropped in the next major version.
  */
 @Deprecated
 final case class MessageSpec(
@@ -2,128 +2,102 @@ package io.cequence.openaiscala.examples

 import akka.stream.Materializer
 import io.cequence.openaiscala.anthropic.service.AnthropicServiceFactory
-import io.cequence.openaiscala.service.{
-  OpenAIChatCompletionService,
-  OpenAIChatCompletionServiceFactory,
-  OpenAIChatCompletionStreamedServiceExtra,
-  OpenAIChatCompletionStreamedServiceFactory
-}
+import io.cequence.openaiscala.service.OpenAIChatCompletionServiceFactory
+import io.cequence.openaiscala.service.OpenAIStreamedServiceImplicits._
 import io.cequence.openaiscala.vertexai.service.VertexAIServiceFactory
 import io.cequence.wsclient.domain.WsRequestContext

 import scala.concurrent.ExecutionContext
-import io.cequence.openaiscala.service.StreamedServiceTypes
+import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIChatCompletionStreamedService

 object ChatCompletionProvider {
-  case class ProviderSettings(
+  private case class ProviderSettings(
     coreUrl: String,
     apiKeyEnvVariable: String
   )

-  val Cerebras = ProviderSettings("https://api.cerebras.ai/v1/", "CEREBRAS_API_KEY")
-  val Groq = ProviderSettings("https://api.groq.com/openai/v1/", "GROQ_API_KEY")
-  val Fireworks =
+  private val Cerebras = ProviderSettings("https://api.cerebras.ai/v1/", "CEREBRAS_API_KEY")
+  private val Groq = ProviderSettings("https://api.groq.com/openai/v1/", "GROQ_API_KEY")
+  private val Fireworks =
     ProviderSettings("https://api.fireworks.ai/inference/v1/", "FIREWORKS_API_KEY")
-  val Mistral = ProviderSettings("https://api.mistral.ai/v1/", "MISTRAL_API_KEY")
-  val OctoML = ProviderSettings("https://text.octoai.run/v1/", "OCTOAI_TOKEN")
-  val TogetherAI = ProviderSettings("https://api.together.xyz/v1/", "TOGETHERAI_API_KEY")
+  private val Mistral = ProviderSettings("https://api.mistral.ai/v1/", "MISTRAL_API_KEY")
+  private val OctoML = ProviderSettings("https://text.octoai.run/v1/", "OCTOAI_TOKEN")
+  private val TogetherAI = ProviderSettings("https://api.together.xyz/v1/", "TOGETHERAI_API_KEY")

+  /**
+   * Requires `CEREBRAS_API_KEY`
+   */
   def cerebras(
     implicit ec: ExecutionContext,
     m: Materializer
-  ): OpenAIChatCompletionService = provide(Cerebras)
+  ): OpenAIChatCompletionStreamedService = provide(Cerebras)

+  /**
+   * Requires `GROQ_API_KEY`
+   */
   def groq(
     implicit ec: ExecutionContext,
     m: Materializer
-  ): OpenAIChatCompletionService = provide(Groq)
+  ): OpenAIChatCompletionStreamedService = provide(Groq)

+  /**
+   * Requires `FIREWORKS_API_KEY`
+   */
   def fireworks(
     implicit ec: ExecutionContext,
     m: Materializer
-  ): OpenAIChatCompletionService = provide(Fireworks)
+  ): OpenAIChatCompletionStreamedService = provide(Fireworks)

+  /**
+   * Requires `MISTRAL_API_KEY`
+   */
   def mistral(
     implicit ec: ExecutionContext,
     m: Materializer
-  ): OpenAIChatCompletionService = provide(Mistral)
+  ): OpenAIChatCompletionStreamedService = provide(Mistral)

+  /**
+   * Requires `OCTOAI_TOKEN`
+   */
   def octoML(
     implicit ec: ExecutionContext,
     m: Materializer
-  ): OpenAIChatCompletionService = provide(OctoML)
+  ): OpenAIChatCompletionStreamedService = provide(OctoML)

+  /**
+   * Requires `TOGETHERAI_API_KEY`
+   */
   def togetherAI(
     implicit ec: ExecutionContext,
     m: Materializer
-  ): OpenAIChatCompletionService = provide(TogetherAI)
+  ): OpenAIChatCompletionStreamedService = provide(TogetherAI)

+  /**
+   * Requires `VERTEXAI_API_KEY` and "VERTEXAI_LOCATION"
+   */
   def vertexAI(
-    implicit ec: ExecutionContext,
-    m: Materializer
-  ): StreamedServiceTypes.OpenAIChatCompletionStreamedService =
+    implicit ec: ExecutionContext
+  ): OpenAIChatCompletionStreamedService =
     VertexAIServiceFactory.asOpenAI()

+  /**
+   * Requires `ANTHROPIC_API_KEY`
+   */
   def anthropic(
     implicit ec: ExecutionContext,
     m: Materializer
-  ): StreamedServiceTypes.OpenAIChatCompletionStreamedService =
+  ): OpenAIChatCompletionStreamedService =
     AnthropicServiceFactory.asOpenAI()

-  object streamed {
-    def cerebras(
-      implicit ec: ExecutionContext,
-      m: Materializer
-    ): OpenAIChatCompletionStreamedServiceExtra = provideStreamed(Cerebras)
-
-    def groq(
-      implicit ec: ExecutionContext,
-      m: Materializer
-    ): OpenAIChatCompletionStreamedServiceExtra = provideStreamed(Groq)
-
-    def fireworks(
-      implicit ec: ExecutionContext,
-      m: Materializer
-    ): OpenAIChatCompletionStreamedServiceExtra = provideStreamed(Fireworks)
-
-    def mistral(
-      implicit ec: ExecutionContext,
-      m: Materializer
-    ): OpenAIChatCompletionStreamedServiceExtra = provideStreamed(Mistral)
-
-    def octoML(
-      implicit ec: ExecutionContext,
-      m: Materializer
-    ): OpenAIChatCompletionStreamedServiceExtra = provideStreamed(OctoML)
-
-    def togetherAI(
-      implicit ec: ExecutionContext,
-      m: Materializer
-    ): OpenAIChatCompletionStreamedServiceExtra = provideStreamed(TogetherAI)
-  }
-
   private def provide(
     settings: ProviderSettings
   )(
     implicit ec: ExecutionContext,
     m: Materializer
-  ): OpenAIChatCompletionService = OpenAIChatCompletionServiceFactory(
-    coreUrl = settings.coreUrl,
-    WsRequestContext(authHeaders =
-      Seq(("Authorization", s"Bearer ${sys.env(settings.apiKeyEnvVariable)}"))
-    )
-  )
-
-  private def provideStreamed(
-    settings: ProviderSettings
-  )(
-    implicit ec: ExecutionContext,
-    m: Materializer
-  ): OpenAIChatCompletionStreamedServiceExtra = OpenAIChatCompletionStreamedServiceFactory(
+  ): OpenAIChatCompletionStreamedService = OpenAIChatCompletionServiceFactory.withStreaming(
     coreUrl = settings.coreUrl,
     WsRequestContext(authHeaders =
       Seq(("Authorization", s"Bearer ${sys.env(settings.apiKeyEnvVariable)}"))
     )
   )

 }
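
In this diff, `provide` reads the API key with `sys.env(...)`, which throws a `NoSuchElementException` when the variable is not set. The following is a minimal, library-free sketch of the same settings-to-header step with the missing-key case made explicit. It is not part of the commit: `AuthHeaderSketch`, `authHeaders`, and the `Either`-based error handling are hypothetical names and choices introduced only for illustration, and `ProviderSettings` is redeclared here so that the snippet stands alone.

```scala
// Hypothetical stand-alone copy of the settings type from the diff above.
final case class ProviderSettings(coreUrl: String, apiKeyEnvVariable: String)

object AuthHeaderSketch {

  // Builds the bearer-token header from an environment map.
  // Returns Left with a readable message instead of throwing when the
  // variable is missing or empty (an assumption; the commit uses sys.env(...)).
  def authHeaders(
    settings: ProviderSettings,
    env: Map[String, String] = sys.env
  ): Either[String, Seq[(String, String)]] =
    env.get(settings.apiKeyEnvVariable) match {
      case Some(key) if key.nonEmpty =>
        Right(Seq("Authorization" -> s"Bearer $key"))
      case _ =>
        Left(s"Environment variable ${settings.apiKeyEnvVariable} is not set")
    }
}
```

Passing the environment as a parameter keeps the function testable without touching the real process environment; a caller could map the `Left` case to whatever failure style the surrounding code prefers.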
@@ -18,8 +18,7 @@ object AnthropicCreateChatCompletionStreamedWithOpenAIAdapter

   private val logger = LoggerFactory.getLogger(this.getClass)

-  override val service: OpenAIChatCompletionStreamedService =
-    ChatCompletionProvider.anthropic
+  override val service: OpenAIChatCompletionStreamedService = ChatCompletionProvider.anthropic

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -11,8 +11,7 @@ import scala.concurrent.Future
 object AnthropicCreateChatCompletionWithOpenAIAdapter
     extends ExampleBase[OpenAIChatCompletionService] {

-  override val service: OpenAIChatCompletionService =
-    ChatCompletionProvider.anthropic
+  override val service: OpenAIChatCompletionService = ChatCompletionProvider.anthropic

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -12,8 +12,7 @@ import scala.concurrent.Future
  */
 object CerebrasCreateChatCompletion extends ExampleBase[OpenAIChatCompletionService] {

-  override val service: OpenAIChatCompletionService =
-    ChatCompletionProvider.cerebras
+  override val service: OpenAIChatCompletionService = ChatCompletionProvider.cerebras

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -12,8 +12,7 @@ import scala.concurrent.Future
 object CerebrasCreateChatCompletionStreamed
     extends ExampleBase[OpenAIChatCompletionStreamedServiceExtra] {

-  override val service: OpenAIChatCompletionStreamedServiceExtra =
-    ChatCompletionProvider.streamed.cerebras
+  override val service: OpenAIChatCompletionStreamedServiceExtra = ChatCompletionProvider.cerebras

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -16,8 +16,7 @@ import scala.concurrent.Future
 object FireworksAICreateChatCompletion extends ExampleBase[OpenAIChatCompletionService] {

   private val fireworksModelPrefix = "accounts/fireworks/models/"
-  override val service: OpenAIChatCompletionService =
-    ChatCompletionProvider.fireworks
+  override val service: OpenAIChatCompletionService = ChatCompletionProvider.fireworks

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -13,8 +13,7 @@ object FireworksAICreateChatCompletionStreamed
     extends ExampleBase[OpenAIChatCompletionStreamedServiceExtra] {

   private val fireworksModelPrefix = "accounts/fireworks/models/"
-  override val service: OpenAIChatCompletionStreamedServiceExtra =
-    ChatCompletionProvider.streamed.fireworks
+  override val service: OpenAIChatCompletionStreamedServiceExtra = ChatCompletionProvider.fireworks

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -12,8 +12,7 @@ import scala.concurrent.Future
  */
 object GroqCreateChatCompletion extends ExampleBase[OpenAIChatCompletionService] {

-  override val service: OpenAIChatCompletionService =
-    ChatCompletionProvider.groq
+  override val service: OpenAIChatCompletionService = ChatCompletionProvider.groq

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -12,8 +12,7 @@ import scala.concurrent.Future
 object GroqCreateChatCompletionStreamed
     extends ExampleBase[OpenAIChatCompletionStreamedServiceExtra] {

-  override val service: OpenAIChatCompletionStreamedServiceExtra =
-    ChatCompletionProvider.streamed.groq
+  override val service: OpenAIChatCompletionStreamedServiceExtra = ChatCompletionProvider.groq

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -10,8 +10,7 @@ import scala.concurrent.Future
 // requires `MISTRAL_API_KEY` environment variable to be set
 object MistralCreateChatCompletion extends ExampleBase[OpenAIChatCompletionService] {

-  override val service: OpenAIChatCompletionService =
-    ChatCompletionProvider.mistral
+  override val service: OpenAIChatCompletionService = ChatCompletionProvider.mistral

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -12,8 +12,7 @@ import scala.concurrent.Future
 object MistralCreateChatCompletionStreamed
     extends ExampleBase[OpenAIChatCompletionStreamedServiceExtra] {

-  override val service: OpenAIChatCompletionStreamedServiceExtra =
-    ChatCompletionProvider.streamed.mistral
+  override val service: OpenAIChatCompletionStreamedServiceExtra = ChatCompletionProvider.mistral

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -10,8 +10,7 @@ import scala.concurrent.Future
 // requires `OCTOAI_TOKEN` environment variable to be set
 object OctoMLCreateChatCompletion extends ExampleBase[OpenAIChatCompletionService] {

-  override val service: OpenAIChatCompletionService =
-    ChatCompletionProvider.octoML
+  override val service: OpenAIChatCompletionService = ChatCompletionProvider.octoML

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -12,8 +12,7 @@ import scala.concurrent.Future
 object OctoMLCreateChatCompletionStreamed
     extends ExampleBase[OpenAIChatCompletionStreamedServiceExtra] {

-  override val service: OpenAIChatCompletionStreamedServiceExtra =
-    ChatCompletionProvider.streamed.octoML
+  override val service: OpenAIChatCompletionStreamedServiceExtra = ChatCompletionProvider.octoML

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -12,8 +12,7 @@ import scala.concurrent.Future
  */
 object TogetherAICreateChatCompletion extends ExampleBase[OpenAIChatCompletionService] {

-  override val service: OpenAIChatCompletionService =
-    ChatCompletionProvider.togetherAI
+  override val service: OpenAIChatCompletionService = ChatCompletionProvider.togetherAI

   private val messages = Seq(
     SystemMessage("You are a helpful assistant."),
@@ -13,8 +13,7 @@ import scala.concurrent.Future
 object VertexAICreateChatCompletionStreamedWithOpenAIAdapter
     extends ExampleBase[OpenAIChatCompletionService] {

-  override val service: OpenAIChatCompletionStreamedService =
-    ChatCompletionProvider.vertexAI
+  override val service: OpenAIChatCompletionStreamedService = ChatCompletionProvider.vertexAI

   private val model = NonOpenAIModelId.gemini_1_5_flash_001

@@ -11,8 +11,7 @@ import scala.concurrent.Future
 object VertexAICreateChatCompletionWithOpenAIAdapter
     extends ExampleBase[OpenAIChatCompletionService] {

-  override val service: OpenAIChatCompletionService =
-    ChatCompletionProvider.vertexAI
+  override val service: OpenAIChatCompletionService = ChatCompletionProvider.vertexAI

   private val model = NonOpenAIModelId.gemini_1_5_pro_001

