Skip to content

Commit

Permalink
feat(gateway): to support update provider
Browse files Browse the repository at this point in the history
  • Loading branch information
hanrw committed Aug 22, 2024
1 parent d89c0c3 commit e8096b8
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
package com.tddworks.openai.gateway.api

import com.tddworks.openai.api.chat.api.Chat
import com.tddworks.openai.api.chat.api.OpenAIModel
import com.tddworks.openai.api.legacy.completions.api.Completions

/**
* Interface for connecting to the OpenAI Gateway to chat.
*/
interface OpenAIGateway : Chat, Completions {
fun updateProvider(
id: String,
name: String,
config: OpenAIProviderConfig,
models: List<OpenAIModel>
)

fun addProvider(provider: OpenAIProvider): OpenAIGateway
fun removeProvider(name: String)
fun getProviders(): List<OpenAIProvider>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ import com.tddworks.openai.api.legacy.completions.api.Completions
* Represents a provider for the OpenAI chat functionality.
*/
interface OpenAIProvider : Chat, Completions {

/**
* The id of the provider.
*/
val id: String

/**
* The name of the provider.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ import com.tddworks.openai.api.chat.api.ChatCompletionChunk as OpenAIChatComplet

@OptIn(ExperimentalSerializationApi::class)
class AnthropicOpenAIProvider(
override val name: String = "Anthropic",
override val models: List<OpenAIModel> = AnthropicModel.availableModels.map {
override var id: String = "anthropic",
override var name: String = "Anthropic",
override var models: List<OpenAIModel> = AnthropicModel.availableModels.map {
OpenAIModel(it.value)
},
override val config: AnthropicOpenAIProviderConfig,
override var config: AnthropicOpenAIProviderConfig,

private val client: Anthropic = Anthropic.create(
AnthropicConfig(
Expand Down Expand Up @@ -75,6 +76,7 @@ class AnthropicOpenAIProvider(
}

fun OpenAIProvider.Companion.anthropic(
id: String = "anthropic",
config: AnthropicOpenAIProviderConfig,
models: List<OpenAIModel> = AnthropicModel.availableModels.map {
OpenAIModel(it.value)
Expand All @@ -87,5 +89,10 @@ fun OpenAIProvider.Companion.anthropic(
)
)
): OpenAIProvider {
return AnthropicOpenAIProvider(config = config, models = models, client = client)
return AnthropicOpenAIProvider(
id = id,
config = config,
models = models,
client = client
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package com.tddworks.openai.gateway.api.internal
import com.tddworks.openai.api.chat.api.ChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.OpenAIModel
import com.tddworks.openai.api.legacy.completions.api.Completion
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest
import com.tddworks.openai.gateway.api.OpenAIGateway
import com.tddworks.openai.gateway.api.OpenAIProvider
import com.tddworks.openai.gateway.api.OpenAIProviderConfig
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.ExperimentalSerializationApi

Expand All @@ -22,6 +24,24 @@ class DefaultOpenAIGateway(
private val availableProviders: MutableList<OpenAIProvider> =
providers.toMutableList()


override fun updateProvider(
id: String,
name: String,
config: OpenAIProviderConfig,
models: List<OpenAIModel>
) {
availableProviders.removeAll { it.id == id }
availableProviders.add(
DefaultOpenAIProvider(
id = id,
name = name,
config = config,
models = models
)
)
}

override fun addProvider(provider: OpenAIProvider): OpenAIGateway {
availableProviders.add(provider)
return this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.ExperimentalSerializationApi

class DefaultOpenAIProvider(
override val id: String = "openai",
override val name: String = "OpenAI",
override val models: List<OpenAIModel> = availableModels,
override val config: OpenAIProviderConfig,
private val openAI: OpenAI = OpenAI.create(config.toOpenAIConfig())
private val openAI: OpenAI = OpenAI.create(config.toOpenAIConfig()),
) : OpenAIProvider {

override fun supports(model: OpenAIModel): Boolean {
Expand All @@ -40,9 +41,13 @@ class DefaultOpenAIProvider(
}

fun OpenAIProvider.Companion.openAI(
id: String = "openai",
config: OpenAIProviderConfig,
models: List<OpenAIModel>,
openAI: OpenAI = OpenAI.create(config.toOpenAIConfig())
): OpenAIProvider {
return DefaultOpenAIProvider(config = config, models = models, openAI = openAI)
return DefaultOpenAIProvider(
id = id,
config = config, models = models, openAI = openAI
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import kotlinx.serialization.ExperimentalSerializationApi

@OptIn(ExperimentalSerializationApi::class)
class OllamaOpenAIProvider(
override val id: String = "ollama",
override val name: String = "Ollama",
override val config: OllamaOpenAIProviderConfig,
override val models: List<OpenAIModel> = OllamaModel.availableModels.map {
Expand Down Expand Up @@ -72,6 +73,7 @@ class OllamaOpenAIProvider(
}

fun OpenAIProvider.Companion.ollama(
id: String = "ollama",
config: OllamaOpenAIProviderConfig,
models: List<OpenAIModel> = OllamaModel.availableModels.map {
OpenAIModel(it.value)
Expand All @@ -84,5 +86,10 @@ fun OpenAIProvider.Companion.ollama(
)
)
): OpenAIProvider {
return OllamaOpenAIProvider(config = config, models = models, client = client)
return OllamaOpenAIProvider(
id = id,
config = config,
models = models,
client = client
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.OpenAIModel
import com.tddworks.openai.gateway.api.OpenAIProvider
import com.tddworks.openai.gateway.api.OpenAIProviderConfig
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.ExperimentalSerializationApi
Expand All @@ -20,11 +21,13 @@ import com.tddworks.anthropic.api.AnthropicModel as AnthropicModel
@OptIn(ExperimentalSerializationApi::class)
class DefaultOpenAIGatewayTest {
private val anthropic = mock<OpenAIProvider> {
on(it.id).thenReturn("anthropic")
on(it.supports(OpenAIModel(AnthropicModel.CLAUDE_3_HAIKU.value))).thenReturn(true)
on(it.name).thenReturn("Anthropic")
}

private val ollama = mock<OpenAIProvider> {
on(it.id).thenReturn("ollama")
on(it.supports(OpenAIModel(OllamaModel.LLAMA2.value))).thenReturn(true)
on(it.name).thenReturn("Ollama")
}
Expand All @@ -38,6 +41,28 @@ class DefaultOpenAIGatewayTest {
providers,
)

@Test
fun `should able to update provider`() {
// Given
val id = "anthropic"
val name = "new Anthropic"
val config = OpenAIProviderConfig.anthropic(
apiKey = { "" },
)

val models = listOf(OpenAIModel.GPT_3_5_TURBO)

// When
openAIGateway.updateProvider(id, name, config, models)

// Then
assertEquals(2, openAIGateway.getProviders().size)
val openAIProvider = openAIGateway.getProviders().first { it.id == id }
assertEquals(name, openAIProvider.name)
assertEquals(config, openAIProvider.config)
assertEquals(models, openAIProvider.models)
}

@Test
fun `should able to remove provider`() {
openAIGateway.removeProvider(anthropic.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ class DefaultOpenAIProviderTest {
fun setUp() {
provider =
OpenAIProvider.openAI(
OpenAIProviderConfig.default(
config = OpenAIProviderConfig.default(
apiKey = { "" },
),
listOf(OpenAIModel.GPT_3_5_TURBO),
client
models = listOf(OpenAIModel.GPT_3_5_TURBO),
openAI = client
)
}

Expand Down

0 comments on commit e8096b8

Please sign in to comment.