Skip to content

Commit

Permalink
Provider settings introduced + examples adjusted
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbanda committed Sep 18, 2024
1 parent f6a9d5f commit a29edb7
Show file tree
Hide file tree
Showing 20 changed files with 140 additions and 69 deletions.
85 changes: 63 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,7 @@ Then you can obtain a service in one of the following ways.

- `OpenAIChatCompletionService` providing solely `createChatCompletion`

1. [Groq](https://wow.groq.com/)
```scala
val service = OpenAIChatCompletionServiceFactory(
coreUrl = "https://api.groq.com/openai/v1/",
authHeaders = Seq(("Authorization", s"Bearer ${sys.env("GROQ_API_KEY")}"))
)
```

2. [Azure AI](https://azure.microsoft.com/en-us/products/ai-studio) - e.g. Cohere R+ model
1. [Azure AI](https://azure.microsoft.com/en-us/products/ai-studio) - e.g. Cohere R+ model
```scala
val service = OpenAIChatCompletionServiceFactory.forAzureAI(
endpoint = sys.env("AZURE_AI_COHERE_R_PLUS_ENDPOINT"),
Expand All @@ -148,35 +140,84 @@ Then you can obtain a service in one of the following ways.
)
```

3. [Anthropic](https://www.anthropic.com/api) (requires our `openai-scala-anthropic-client` lib)
2. [Anthropic](https://www.anthropic.com/api) - requires `openai-scala-anthropic-client` lib and `ANTHROPIC_API_KEY`
```scala
val service = AnthropicServiceFactory.asOpenAI()
```

4. [Fireworks AI](https://fireworks.ai/)
3. [Google Vertex AI](https://cloud.google.com/vertex-ai) - requires `openai-scala-google-vertexai-client` lib and `VERTEXAI_LOCATION` + `VERTEXAI_PROJECT_ID`
```scala
val service = OpenAIChatCompletionServiceFactory(
coreUrl = "https://api.fireworks.ai/inference/v1/",
authHeaders = Seq(("Authorization", s"Bearer ${sys.env("FIREWORKS_API_KEY")}"))
)
val service = VertexAIServiceFactory.asOpenAI()
```

5. [Octo AI](https://octo.ai/)
4. [Groq](https://wow.groq.com/) - requires `GROQ_API_KEY"`
```scala
val service = OpenAIChatCompletionServiceFactory(
coreUrl = "https://text.octoai.run/v1/",
authHeaders = Seq(("Authorization", s"Bearer ${sys.env("OCTOAI_TOKEN")}"))
)
val service = OpenAIChatCompletionServiceFactory(ChatProviderSettings.groq)
```
or with streaming
```scala
val service = OpenAIChatCompletionServiceFactory.withStreaming(ChatProviderSettings.groq)
```

6. [Ollama](https://ollama.com/)
5. [Fireworks AI](https://fireworks.ai/) - requires `FIREWORKS_API_KEY"`
```scala
val service = OpenAIChatCompletionServiceFactory(ChatProviderSettings.fireworks)
```
or with streaming
```scala
val service = OpenAIChatCompletionServiceFactory.withStreaming(ChatProviderSettings.fireworks)
```

6. [Octo AI](https://octo.ai/) - requires `OCTOAI_TOKEN`
```scala
val service = OpenAIChatCompletionServiceFactory(ChatProviderSettings.octoML)
```
or with streaming
```scala
val service = OpenAIChatCompletionServiceFactory.withStreaming(ChatProviderSettings.octoML)
```

7. [TogetherAI](https://www.together.ai/) requires `TOGETHERAI_API_KEY`
```scala
val service = OpenAIChatCompletionServiceFactory(ChatProviderSettings.togetherAI)
```
or with streaming
```scala
val service = OpenAIChatCompletionServiceFactory.withStreaming(ChatProviderSettings.togetherAI)
```

8. [Cerebras](https://cerebras.ai/) requires `CEREBRAS_API_KEY`
```scala
val service = OpenAIChatCompletionServiceFactory(ChatProviderSettings.cerebras)
```
or with streaming
```scala
val service = OpenAIChatCompletionServiceFactory.withStreaming(ChatProviderSettings.cerebras)
```

9. [Mistral](https://mistral.ai/) requires `MISTRAL_API_KEY`
```scala
val service = OpenAIChatCompletionServiceFactory(ChatProviderSettings.mistral)
```
or with streaming
```scala
val service = OpenAIChatCompletionServiceFactory.withStreaming(ChatProviderSettings.mistral)
```

10. [Ollama](https://ollama.com/)
```scala
val service = OpenAIChatCompletionServiceFactory(
coreUrl = "http://localhost:11434/v1/"
)
```
or with streaming
```scala
val service = OpenAIChatCompletionServiceFactory.withStreaming(
coreUrl = "http://localhost:11434/v1/"
)
```

- Services with additional streaming support - `createCompletionStreamed` and `createChatCompletionStreamed` provided by [OpenAIStreamedServiceExtra](./openai-client-stream/src/main/scala/io/cequence/openaiscala/service/OpenAIStreamedServiceExtra.scala) (requires `openai-scala-client-stream` lib)
- Note that services with additional streaming support - `createCompletionStreamed` and `createChatCompletionStreamed` provided by [OpenAIStreamedServiceExtra](./openai-client-stream/src/main/scala/io/cequence/openaiscala/service/OpenAIStreamedServiceExtra.scala) (requires `openai-scala-client-stream` lib)

```scala
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIStreamedService
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.cequence.openaiscala.service

import akka.stream.Materializer
import io.cequence.openaiscala.domain.ProviderSettings
import io.cequence.openaiscala.service.impl.OpenAIChatCompletionServiceImpl
import io.cequence.wsclient.domain.WsRequestContext
import io.cequence.wsclient.service.WSClientEngine
Expand Down Expand Up @@ -33,6 +34,20 @@ object OpenAIChatCompletionServiceFactory

// propose a new name for the trait
trait IOpenAIChatCompletionServiceFactory[F] extends RawWsServiceFactory[F] {

def apply(
providerSettings: ProviderSettings
)(
implicit ec: ExecutionContext,
materializer: Materializer
): F =
apply(
coreUrl = providerSettings.coreUrl,
WsRequestContext(authHeaders =
Seq(("Authorization", s"Bearer ${sys.env(providerSettings.apiKeyEnvVariable)}"))
)
)

def forAzureAI(
endpoint: String,
region: String,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package io.cequence.openaiscala.domain

case class ProviderSettings(
coreUrl: String,
apiKeyEnvVariable: String
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.cequence.openaiscala.service

import io.cequence.openaiscala.domain.ProviderSettings

object ChatProviderSettings {

val cerebras = ProviderSettings("https://api.cerebras.ai/v1/", "CEREBRAS_API_KEY")
val groq = ProviderSettings("https://api.groq.com/openai/v1/", "GROQ_API_KEY")
val fireworks =
ProviderSettings("https://api.fireworks.ai/inference/v1/", "FIREWORKS_API_KEY")
val mistral = ProviderSettings("https://api.mistral.ai/v1/", "MISTRAL_API_KEY")
val octoML = ProviderSettings("https://text.octoai.run/v1/", "OCTOAI_TOKEN")
val togetherAI =
ProviderSettings("https://api.together.xyz/v1/", "TOGETHERAI_API_KEY")
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import akka.stream.scaladsl.{RestartSource, Sink, Source}
import io.cequence.openaiscala.OpenAIScalaClientException
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.domain.{NonOpenAIModelId, SystemMessage, UserMessage}
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIChatCompletionStreamedService
import org.slf4j.LoggerFactory

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.cequence.openaiscala.examples.nonopenai

import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.domain.{NonOpenAIModelId, SystemMessage, UserMessage}
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionService

import scala.concurrent.Future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@ package io.cequence.openaiscala.examples.nonopenai

import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.service.OpenAIChatCompletionService
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.{
ChatProviderSettings,
OpenAIChatCompletionService,
OpenAIChatCompletionServiceFactory
}

import scala.concurrent.Future

Expand All @@ -12,7 +16,8 @@ import scala.concurrent.Future
*/
object CerebrasCreateChatCompletion extends ExampleBase[OpenAIChatCompletionService] {

override val service: OpenAIChatCompletionService = ChatCompletionProvider.cerebras
override val service: OpenAIChatCompletionService =
OpenAIChatCompletionServiceFactory(ChatProviderSettings.cerebras)

private val messages = Seq(
SystemMessage("You are a helpful assistant."),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,22 @@ package io.cequence.openaiscala.examples.nonopenai
import akka.stream.scaladsl.Sink
import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.service.OpenAIChatCompletionStreamedServiceExtra
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIStreamedServiceImplicits.ChatCompletionStreamFactoryExt
import io.cequence.openaiscala.service.{
ChatProviderSettings,
OpenAIChatCompletionServiceFactory,
OpenAIChatCompletionStreamedServiceExtra
}

import scala.concurrent.Future

// requires `openai-scala-client-stream` as a dependency and `CEREBRAS_API_KEY` environment variable to be set
object CerebrasCreateChatCompletionStreamed
extends ExampleBase[OpenAIChatCompletionStreamedServiceExtra] {

override val service: OpenAIChatCompletionStreamedServiceExtra = ChatCompletionProvider.cerebras
override val service: OpenAIChatCompletionStreamedServiceExtra =
OpenAIChatCompletionServiceFactory.withStreaming(ChatProviderSettings.cerebras)

private val messages = Seq(
SystemMessage("You are a helpful assistant."),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,76 +1,64 @@
package io.cequence.openaiscala.examples
package io.cequence.openaiscala.examples.nonopenai

import akka.stream.Materializer
import io.cequence.openaiscala.anthropic.service.AnthropicServiceFactory
import io.cequence.openaiscala.service.OpenAIChatCompletionServiceFactory
import io.cequence.openaiscala.domain.ProviderSettings
import io.cequence.openaiscala.service.{ChatProviderSettings, OpenAIChatCompletionServiceFactory}
import io.cequence.openaiscala.service.OpenAIStreamedServiceImplicits._
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIChatCompletionStreamedService
import io.cequence.openaiscala.vertexai.service.VertexAIServiceFactory
import io.cequence.wsclient.domain.WsRequestContext

import scala.concurrent.ExecutionContext
import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIChatCompletionStreamedService

object ChatCompletionProvider {
private case class ProviderSettings(
coreUrl: String,
apiKeyEnvVariable: String
)

private val Cerebras = ProviderSettings("https://api.cerebras.ai/v1/", "CEREBRAS_API_KEY")
private val Groq = ProviderSettings("https://api.groq.com/openai/v1/", "GROQ_API_KEY")
private val Fireworks =
ProviderSettings("https://api.fireworks.ai/inference/v1/", "FIREWORKS_API_KEY")
private val Mistral = ProviderSettings("https://api.mistral.ai/v1/", "MISTRAL_API_KEY")
private val OctoML = ProviderSettings("https://text.octoai.run/v1/", "OCTOAI_TOKEN")
private val TogetherAI = ProviderSettings("https://api.together.xyz/v1/", "TOGETHERAI_API_KEY")

/**
* Requires `CEREBRAS_API_KEY`
*/
def cerebras(
implicit ec: ExecutionContext,
m: Materializer
): OpenAIChatCompletionStreamedService = provide(Cerebras)
): OpenAIChatCompletionStreamedService = provide(ChatProviderSettings.cerebras)

/**
* Requires `GROQ_API_KEY`
*/
def groq(
implicit ec: ExecutionContext,
m: Materializer
): OpenAIChatCompletionStreamedService = provide(Groq)
): OpenAIChatCompletionStreamedService = provide(ChatProviderSettings.groq)

/**
* Requires `FIREWORKS_API_KEY`
*/
def fireworks(
implicit ec: ExecutionContext,
m: Materializer
): OpenAIChatCompletionStreamedService = provide(Fireworks)
): OpenAIChatCompletionStreamedService = provide(ChatProviderSettings.fireworks)

/**
* Requires `MISTRAL_API_KEY`
*/
def mistral(
implicit ec: ExecutionContext,
m: Materializer
): OpenAIChatCompletionStreamedService = provide(Mistral)
): OpenAIChatCompletionStreamedService = provide(ChatProviderSettings.mistral)

/**
* Requires `OCTOAI_TOKEN`
*/
def octoML(
implicit ec: ExecutionContext,
m: Materializer
): OpenAIChatCompletionStreamedService = provide(OctoML)
): OpenAIChatCompletionStreamedService = provide(ChatProviderSettings.octoML)

/**
* Requires `TOGETHERAI_API_KEY`
*/
def togetherAI(
implicit ec: ExecutionContext,
m: Materializer
): OpenAIChatCompletionStreamedService = provide(TogetherAI)
): OpenAIChatCompletionStreamedService = provide(ChatProviderSettings.togetherAI)

/**
* Requires `VERTEXAI_API_KEY` and "VERTEXAI_LOCATION"
Expand All @@ -94,10 +82,5 @@ object ChatCompletionProvider {
)(
implicit ec: ExecutionContext,
m: Materializer
): OpenAIChatCompletionStreamedService = OpenAIChatCompletionServiceFactory.withStreaming(
coreUrl = settings.coreUrl,
WsRequestContext(authHeaders =
Seq(("Authorization", s"Bearer ${sys.env(settings.apiKeyEnvVariable)}"))
)
)
): OpenAIChatCompletionStreamedService = OpenAIChatCompletionServiceFactory.withStreaming(settings)
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.cequence.openaiscala.examples.nonopenai

import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionService

import scala.concurrent.Future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.cequence.openaiscala.examples.nonopenai
import akka.stream.scaladsl.Sink
import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionStreamedServiceExtra

import scala.concurrent.Future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.cequence.openaiscala.examples.nonopenai

import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionService

import scala.concurrent.Future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.cequence.openaiscala.examples.nonopenai
import akka.stream.scaladsl.Sink
import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionStreamedServiceExtra

import scala.concurrent.Future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.cequence.openaiscala.examples.nonopenai

import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionService

import scala.concurrent.Future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.cequence.openaiscala.examples.nonopenai
import akka.stream.scaladsl.Sink
import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionStreamedServiceExtra

import scala.concurrent.Future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.cequence.openaiscala.examples.nonopenai

import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionService

import scala.concurrent.Future
Expand Down
Loading

0 comments on commit a29edb7

Please sign in to comment.