diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIGateway.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIGateway.kt index 0ad85ca..b6fec50 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIGateway.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIGateway.kt @@ -1,12 +1,20 @@ package com.tddworks.openai.gateway.api import com.tddworks.openai.api.chat.api.Chat +import com.tddworks.openai.api.chat.api.OpenAIModel import com.tddworks.openai.api.legacy.completions.api.Completions /** * Interface for connecting to the OpenAI Gateway to chat. */ interface OpenAIGateway : Chat, Completions { + fun updateProvider( + id: String, + name: String, + config: OpenAIProviderConfig, + models: List + ) + fun addProvider(provider: OpenAIProvider): OpenAIGateway fun removeProvider(name: String) fun getProviders(): List diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIProvider.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIProvider.kt index 7bcf517..e9fc78a 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIProvider.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIProvider.kt @@ -8,6 +8,12 @@ import com.tddworks.openai.api.legacy.completions.api.Completions * Represents a provider for the OpenAI chat functionality. */ interface OpenAIProvider : Chat, Completions { + + /** + * The id of the provider. + */ + val id: String + /** * The name of the provider. */ diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/AnthropicOpenAIProvider.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/AnthropicOpenAIProvider.kt index a203843..0e7f29d 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/AnthropicOpenAIProvider.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/AnthropicOpenAIProvider.kt @@ -18,11 +18,12 @@ import com.tddworks.openai.api.chat.api.ChatCompletionChunk as OpenAIChatComplet @OptIn(ExperimentalSerializationApi::class) class AnthropicOpenAIProvider( - override val name: String = "Anthropic", - override val models: List = AnthropicModel.availableModels.map { + override var id: String = "anthropic", + override var name: String = "Anthropic", + override var models: List = AnthropicModel.availableModels.map { OpenAIModel(it.value) }, - override val config: AnthropicOpenAIProviderConfig, + override var config: AnthropicOpenAIProviderConfig, private val client: Anthropic = Anthropic.create( AnthropicConfig( @@ -75,6 +76,7 @@ class AnthropicOpenAIProvider( } fun OpenAIProvider.Companion.anthropic( + id: String = "anthropic", config: AnthropicOpenAIProviderConfig, models: List = AnthropicModel.availableModels.map { OpenAIModel(it.value) @@ -87,5 +89,10 @@ fun OpenAIProvider.Companion.anthropic( ) ) ): OpenAIProvider { - return AnthropicOpenAIProvider(config = config, models = models, client = client) + return AnthropicOpenAIProvider( + id = id, + config = config, + models = models, + client = client + ) } \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt index cde3a41..72487fe 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt @@ -3,10 +3,12 @@ package com.tddworks.openai.gateway.api.internal import com.tddworks.openai.api.chat.api.ChatCompletion import com.tddworks.openai.api.chat.api.ChatCompletionChunk import com.tddworks.openai.api.chat.api.ChatCompletionRequest +import com.tddworks.openai.api.chat.api.OpenAIModel import com.tddworks.openai.api.legacy.completions.api.Completion import com.tddworks.openai.api.legacy.completions.api.CompletionRequest import com.tddworks.openai.gateway.api.OpenAIGateway import com.tddworks.openai.gateway.api.OpenAIProvider +import com.tddworks.openai.gateway.api.OpenAIProviderConfig import kotlinx.coroutines.flow.Flow import kotlinx.serialization.ExperimentalSerializationApi @@ -22,6 +24,24 @@ class DefaultOpenAIGateway( private val availableProviders: MutableList = providers.toMutableList() + + override fun updateProvider( + id: String, + name: String, + config: OpenAIProviderConfig, + models: List + ) { + availableProviders.removeAll { it.id == id } + availableProviders.add( + DefaultOpenAIProvider( + id = id, + name = name, + config = config, + models = models + ) + ) + } + override fun addProvider(provider: OpenAIProvider): OpenAIGateway { availableProviders.add(provider) return this diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIProvider.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIProvider.kt index a05f8f5..9f27c8a 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIProvider.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIProvider.kt @@ -16,10 +16,11 @@ import kotlinx.coroutines.flow.Flow import kotlinx.serialization.ExperimentalSerializationApi class DefaultOpenAIProvider( + override val id: String = "openai", override val name: String = "OpenAI", override val models: List = availableModels, override val config: OpenAIProviderConfig, - private val openAI: OpenAI = OpenAI.create(config.toOpenAIConfig()) + private val openAI: OpenAI = OpenAI.create(config.toOpenAIConfig()), ) : OpenAIProvider { override fun supports(model: OpenAIModel): Boolean { @@ -40,9 +41,13 @@ class DefaultOpenAIProvider( } fun OpenAIProvider.Companion.openAI( + id: String = "openai", config: OpenAIProviderConfig, models: List, openAI: OpenAI = OpenAI.create(config.toOpenAIConfig()) ): OpenAIProvider { - return DefaultOpenAIProvider(config = config, models = models, openAI = openAI) + return DefaultOpenAIProvider( + id = id, + config = config, models = models, openAI = openAI + ) } diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProvider.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProvider.kt index ed9b7e1..d3d0dde 100644 --- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProvider.kt +++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/OllamaOpenAIProvider.kt @@ -17,6 +17,7 @@ import kotlinx.serialization.ExperimentalSerializationApi @OptIn(ExperimentalSerializationApi::class) class OllamaOpenAIProvider( + override val id: String = "ollama", override val name: String = "Ollama", override val config: OllamaOpenAIProviderConfig, override val models: List = OllamaModel.availableModels.map { @@ -72,6 +73,7 @@ class OllamaOpenAIProvider( } fun OpenAIProvider.Companion.ollama( + id: String = "ollama", config: OllamaOpenAIProviderConfig, models: List = OllamaModel.availableModels.map { OpenAIModel(it.value) @@ -84,5 +86,10 @@ fun OpenAIProvider.Companion.ollama( ) ) ): OpenAIProvider { - return OllamaOpenAIProvider(config = config, models = models, client = client) + return OllamaOpenAIProvider( + id = id, + config = config, + models = models, + client = client + ) } \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGatewayTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGatewayTest.kt index baa1b26..1b88c66 100644 --- a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGatewayTest.kt +++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGatewayTest.kt @@ -7,6 +7,7 @@ import com.tddworks.openai.api.chat.api.ChatCompletionChunk import com.tddworks.openai.api.chat.api.ChatCompletionRequest import com.tddworks.openai.api.chat.api.OpenAIModel import com.tddworks.openai.gateway.api.OpenAIProvider +import com.tddworks.openai.gateway.api.OpenAIProviderConfig import kotlinx.coroutines.flow.flow import kotlinx.coroutines.test.runTest import kotlinx.serialization.ExperimentalSerializationApi @@ -20,11 +21,13 @@ import com.tddworks.anthropic.api.AnthropicModel as AnthropicModel @OptIn(ExperimentalSerializationApi::class) class DefaultOpenAIGatewayTest { private val anthropic = mock { + on(it.id).thenReturn("anthropic") on(it.supports(OpenAIModel(AnthropicModel.CLAUDE_3_HAIKU.value))).thenReturn(true) on(it.name).thenReturn("Anthropic") } private val ollama = mock { + on(it.id).thenReturn("ollama") on(it.supports(OpenAIModel(OllamaModel.LLAMA2.value))).thenReturn(true) on(it.name).thenReturn("Ollama") } @@ -38,6 +41,28 @@ class DefaultOpenAIGatewayTest { providers, ) + @Test + fun `should able to update provider`() { + // Given + val id = "anthropic" + val name = "new Anthropic" + val config = OpenAIProviderConfig.anthropic( + apiKey = { "" }, + ) + + val models = listOf(OpenAIModel.GPT_3_5_TURBO) + + // When + openAIGateway.updateProvider(id, name, config, models) + + // Then + assertEquals(2, openAIGateway.getProviders().size) + val openAIProvider = openAIGateway.getProviders().first { it.id == id } + assertEquals(name, openAIProvider.name) + assertEquals(config, openAIProvider.config) + assertEquals(models, openAIProvider.models) + } + @Test fun `should able to remove provider`() { openAIGateway.removeProvider(anthropic.name) diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIProviderTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIProviderTest.kt index 21e3d9d..b581a6c 100644 --- a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIProviderTest.kt +++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIProviderTest.kt @@ -35,11 +35,11 @@ class DefaultOpenAIProviderTest { fun setUp() { provider = OpenAIProvider.openAI( - OpenAIProviderConfig.default( + config = OpenAIProviderConfig.default( apiKey = { "" }, ), - listOf(OpenAIModel.GPT_3_5_TURBO), - client + models = listOf(OpenAIModel.GPT_3_5_TURBO), + openAI = client ) }