From 58e7ee96dce4c3b5a559deae6e886c50cf67683e Mon Sep 17 00:00:00 2001
From: slam
Date: Sat, 15 Jun 2024 19:46:31 +0800
Subject: [PATCH] feat: support addProvider - OpenAIGateway

---
 README.md                                     |  29 ++++-
 .../kotlin/com/tddworks/openai/api/OpenAI.kt  |  33 +++++-
 .../com/tddworks/openai/api/OpenAITest.kt     |   8 +-
 .../gateway/api/DefaultOpenAIProvider.kt      |  32 ++++++
 .../openai/gateway/api/OpenAIGateway.kt       |   5 +-
 .../api/internal/DefaultOpenAIGateway.kt      |  17 ++-
 ...wayTest.kt => DefaultOpenAIGatewayTest.kt} |  20 +++-
 .../gateway/api/DefaultOpenAIProviderTest.kt  | 103 ++++++++++++++++++
 8 files changed, 235 insertions(+), 12 deletions(-)
 create mode 100644 openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProvider.kt
 rename openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/{OpenAIGatewayTest.kt => DefaultOpenAIGatewayTest.kt} (91%)
 create mode 100644 openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProviderTest.kt

diff --git a/README.md b/README.md
index 964b085..2170272 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ val openAI = initOpenAI(OpenAIConfig(
 ))
 
 // stream completions
-openAI.streamCompletions(
+openAI.streamChatCompletions(
     ChatCompletionRequest(
         messages = listOf(ChatMessage.UserMessage("hello")),
         maxTokens = 1024,
@@ -41,13 +41,23 @@ openAI.streamCompletions(
 }
 
 // chat completions
-val chatCompletion = openAI.completions(
+val chatCompletion = openAI.chatCompletions(
     ChatCompletionRequest(
         messages = listOf(ChatMessage.UserMessage("hello")),
         maxTokens = 1024,
         model = Model.GPT_3_5_TURBO
     )
 )
+
+// completions (legacy)
+val completion = openAI.completions(
+    CompletionRequest(
+        prompt = "Once upon a time",
+        suffix = "The end",
+        maxTokens = 10,
+        temperature = 0.5
+    )
+)
 ```
 
@@ -100,7 +110,7 @@ val openAIGateway = initOpenAIGateway(
 )
 
 // stream completions
-openAIGateway.streamCompletions(
+openAIGateway.streamChatCompletions(
     ChatCompletionRequest(
         messages = listOf(ChatMessage.UserMessage("hello")),
         maxTokens = 1024,
@@ -111,11 +121,22 @@
 }
 
 // chat completions
-val chatCompletion = openAIGateway.completions(
+val chatCompletion = openAIGateway.chatCompletions(
     ChatCompletionRequest(
         messages = listOf(ChatMessage.UserMessage("hello")),
         maxTokens = 1024,
         model = Model(Model.GPT_3_5_TURBO.value)
     )
 )
+
+// completions (legacy)
+val completion = openAIGateway.completions(
+    CompletionRequest(
+        prompt = "Once upon a time",
+        suffix = "The end",
+        maxTokens = 10,
+        temperature = 0.5
+    )
+)
+
 ```
\ No newline at end of file
diff --git a/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/OpenAI.kt b/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/OpenAI.kt
index b638c48..641fd97 100644
--- a/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/OpenAI.kt
+++ b/openai-client/openai-client-core/src/commonMain/kotlin/com/tddworks/openai/api/OpenAI.kt
@@ -1,14 +1,45 @@
 package com.tddworks.openai.api
 
+import com.tddworks.common.network.api.ktor.api.HttpRequester
+import com.tddworks.common.network.api.ktor.internal.createHttpClient
+import com.tddworks.common.network.api.ktor.internal.default
+import com.tddworks.di.createJson
 import com.tddworks.di.getInstance
 import com.tddworks.openai.api.chat.api.Chat
+import com.tddworks.openai.api.chat.internal.DefaultChatApi
 import com.tddworks.openai.api.images.api.Images
-import com.tddworks.openai.api.legacy.completions.api.Completion
+import com.tddworks.openai.api.images.internal.DefaultImagesApi
 import com.tddworks.openai.api.legacy.completions.api.Completions
+import com.tddworks.openai.api.legacy.completions.api.internal.DefaultCompletionsApi
 
 interface OpenAI : Chat, Images, Completions {
 
     companion object {
         const val BASE_URL = "api.openai.com"
+
+        fun create(config: OpenAIConfig): OpenAI {
+            val requester = HttpRequester.default(
+                createHttpClient(
+                    host = config.baseUrl,
+                    authToken = config.apiKey,
+                    // get from commonModule
+                    json = createJson(),
+                )
+            )
+            val chatApi = DefaultChatApi(
+                requester = requester
+            )
+
+            val imagesApi = DefaultImagesApi(
+                requester = requester
+            )
+
+            val completionsApi = DefaultCompletionsApi(
+                requester = requester
+            )
+
+            return object : OpenAI, Chat by chatApi, Images by imagesApi,
+                Completions by completionsApi {}
+        }
     }
 }
diff --git a/openai-client/openai-client-core/src/jvmTest/kotlin/com/tddworks/openai/api/OpenAITest.kt b/openai-client/openai-client-core/src/jvmTest/kotlin/com/tddworks/openai/api/OpenAITest.kt
index 2516280..f5e7a6b 100644
--- a/openai-client/openai-client-core/src/jvmTest/kotlin/com/tddworks/openai/api/OpenAITest.kt
+++ b/openai-client/openai-client-core/src/jvmTest/kotlin/com/tddworks/openai/api/OpenAITest.kt
@@ -9,14 +9,18 @@ import com.tddworks.openai.api.images.api.ImageCreate
 import kotlinx.coroutines.flow.toList
 import kotlinx.coroutines.runBlocking
 import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertNull
 import org.junit.jupiter.api.Disabled
 import org.junit.jupiter.api.Test
+import kotlin.test.assertNotNull
 
 class OpenAITest {
 
     @Test
-    fun `should return correct base url`() {
-        assertEquals("api.openai.com", OpenAI.BASE_URL)
+    fun `should create openai instance`() {
+        val openAI = OpenAI.create(OpenAIConfig())
+
+        assertNotNull(openAI)
     }
 
     @Disabled
diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProvider.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProvider.kt
new file mode 100644
index 0000000..d6e674d
--- /dev/null
+++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProvider.kt
@@ -0,0 +1,32 @@
+@file:OptIn(ExperimentalSerializationApi::class)
+
+package com.tddworks.openai.gateway.api
+
+import com.tddworks.openai.api.OpenAI
+import com.tddworks.openai.api.OpenAIConfig
+import com.tddworks.openai.api.chat.api.ChatCompletion
+import com.tddworks.openai.api.chat.api.ChatCompletionChunk
+import com.tddworks.openai.api.chat.api.ChatCompletionRequest
+import com.tddworks.openai.api.chat.api.Model
+import kotlinx.coroutines.flow.Flow
+import kotlinx.serialization.ExperimentalSerializationApi
+
+class DefaultOpenAIProvider(
+    config: OpenAIConfig,
+    models: List<Model>,
+    private val openAI: OpenAI = OpenAI.create(config)
+) : OpenAIProvider {
+    private val availableModels: MutableList<Model> = models.toMutableList()
+
+    override fun supports(model: Model): Boolean {
+        return availableModels.any { it.value == model.value }
+    }
+
+    override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion {
+        return openAI.chatCompletions(request)
+    }
+
+    override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
+        return openAI.streamChatCompletions(request)
+    }
+}
diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIGateway.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIGateway.kt
index 15b620c..bced450 100644
--- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIGateway.kt
+++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/OpenAIGateway.kt
@@ -6,4 +6,7 @@ import com.tddworks.openai.api.legacy.completions.api.Completions
 /**
  * Interface for connecting to the OpenAI Gateway to chat.
  */
-interface OpenAIGateway : Chat, Completions
\ No newline at end of file
+interface OpenAIGateway : Chat, Completions {
+    fun addProvider(provider: OpenAIProvider): OpenAIGateway
+    fun getProviders(): List<OpenAIProvider>
+}
\ No newline at end of file
diff --git a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt
index 60a2b37..52df85c 100644
--- a/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt
+++ b/openai-gateway/openai-gateway-core/src/commonMain/kotlin/com/tddworks/openai/gateway/api/internal/DefaultOpenAIGateway.kt
@@ -18,9 +18,20 @@ import kotlinx.serialization.ExperimentalSerializationApi
  */
 @ExperimentalSerializationApi
 class DefaultOpenAIGateway(
-    private val providers: List<OpenAIProvider>,
+    providers: List<OpenAIProvider>,
     private val openAI: OpenAI,
 ) : OpenAIGateway {
+    private val availableProviders: MutableList<OpenAIProvider> =
+        providers.toMutableList()
+
+    override fun addProvider(provider: OpenAIProvider): OpenAIGateway {
+        availableProviders.add(provider)
+        return this
+    }
+
+    override fun getProviders(): List<OpenAIProvider> {
+        return availableProviders.toList()
+    }
 
     /**
      * This function is called to get completions for a chat based on the given request.
     *
@@ -29,7 +40,7 @@
      * @return A ChatCompletion object containing the completions for the provided request.
     */
    override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion {
-        return providers.firstOrNull {
+        return availableProviders.firstOrNull {
            it.supports(request.model)
        }?.chatCompletions(request) ?: openAI.chatCompletions(request)
    }
@@ -42,7 +53,7 @@
     * @return a Flow of ChatCompletionChunk objects representing the completions for the input model
     */
    override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
-        return providers.firstOrNull {
+        return availableProviders.firstOrNull {
            it.supports(request.model)
        }?.streamChatCompletions(request) ?: openAI.streamChatCompletions(request)
    }
diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/OpenAIGatewayTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIGatewayTest.kt
similarity index 91%
rename from openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/OpenAIGatewayTest.kt
rename to openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIGatewayTest.kt
index 1b9c903..24e050a 100644
--- a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/OpenAIGatewayTest.kt
+++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIGatewayTest.kt
@@ -20,7 +20,7 @@ import kotlin.test.assertEquals
 import com.tddworks.anthropic.api.Model as AnthropicModel
 
 @OptIn(ExperimentalSerializationApi::class)
-class OpenAIGatewayTest {
+class DefaultOpenAIGatewayTest {
    private val anthropic = mock<OpenAIProvider> {
        on(it.supports(Model(AnthropicModel.CLAUDE_3_HAIKU.value))).thenReturn(true)
    }
@@ -41,6 +41,24 @@
        openAI = openAI
    )
 
+    @Test
+    fun `should be able to add new provider`() {
+        // Given
+        val provider = mock<OpenAIProvider>()
+
+        // When
+        val gateway = DefaultOpenAIGateway(
+            providers,
+            openAI = openAI
+        ).run {
+            addProvider(provider)
+        }
+
+        // Then
+        assertEquals(3, gateway.getProviders().size)
+        assertEquals(provider, gateway.getProviders().last())
+    }
+
    @Test
    fun `should use ollama client to get chat completions`() = runTest {
        // Given
diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProviderTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProviderTest.kt
new file mode 100644
index 0000000..fe15ce6
--- /dev/null
+++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/DefaultOpenAIProviderTest.kt
@@ -0,0 +1,103 @@
+package com.tddworks.openai.gateway.api
+
+import app.cash.turbine.test
+import com.tddworks.ollama.api.OllamaModel
+import com.tddworks.openai.api.OpenAI
+import com.tddworks.openai.api.OpenAIConfig
+import com.tddworks.openai.api.chat.api.ChatCompletion
+import com.tddworks.openai.api.chat.api.ChatCompletionChunk
+import com.tddworks.openai.api.chat.api.ChatCompletionRequest
+import com.tddworks.openai.api.chat.api.Model
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.test.runTest
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertTrue
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.extension.ExtendWith
+import org.mockito.Mock
+import org.mockito.Mockito.mock
+import org.mockito.junit.jupiter.MockitoExtension
+import org.mockito.kotlin.whenever
+
+@ExtendWith(MockitoExtension::class)
+class DefaultOpenAIProviderTest {
+    @Mock
+    lateinit var client: OpenAI
+
+
+    private lateinit var provider: DefaultOpenAIProvider
+
+    @BeforeEach
+    fun setUp() {
+        provider =
+            DefaultOpenAIProvider(OpenAIConfig(), listOf(Model.GPT_3_5_TURBO), client)
+    }
+
+    @Test
+    fun `should return completions from OpenAI API`() = runTest {
+        // given
+        val request = ChatCompletionRequest.dummy(Model(OllamaModel.LLAMA2.value))
+        val response = ChatCompletion.dummy()
+        whenever(client.chatCompletions(request)).thenReturn(response)
+
+        // when
+        val completions = provider.chatCompletions(request)
+
+        // then
+        assertEquals(response, completions)
+    }
+
+    @Test
+    fun `should stream completions for chat`() = runTest {
+        // given
+        val request = ChatCompletionRequest.dummy(Model(OllamaModel.LLAMA2.value))
+
+        val response = ChatCompletionChunk.dummy()
+        whenever(client.streamChatCompletions(request)).thenReturn(flow {
+            emit(
+                response
+            )
+        })
+
+        // when
+        provider.streamChatCompletions(request).test {
+            // then
+            assertEquals(
+                response,
+                awaitItem()
+            )
+            awaitComplete()
+        }
+
+    }
+
+    @Test
+    fun `should return false when model is not supported`() {
+        // Given
+        val openAI = mock<OpenAI>()
+        val model = Model("gpt-3.5-turbo")
+        val provider = DefaultOpenAIProvider(OpenAIConfig(), emptyList(), openAI)
+
+        // When
+        val result = provider.supports(model)
+
+        // Then
+        assertTrue(!result)
+    }
+
+    @Test
+    fun `should return true when model is supported`() {
+        // Given
+        val openAI = mock<OpenAI>()
+        val model = Model("gpt-3.5-turbo")
+        val provider = DefaultOpenAIProvider(OpenAIConfig(), listOf(model), openAI)
+
+        // When
+        val result = provider.supports(model)
+
+        // Then
+        assertTrue(result)
+    }
+
+}
\ No newline at end of file
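
Reviewer note: below is a minimal usage sketch of the new surface, not part of the patch. Constructor and method shapes are taken from the diff above; the `suspend fun main` wrapper and the assumption that the default `OpenAIConfig()` resolves the API key (as the new tests do) are illustrative only.

```kotlin
import com.tddworks.openai.api.OpenAI
import com.tddworks.openai.api.OpenAIConfig
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.ChatMessage
import com.tddworks.openai.api.chat.api.Model
import com.tddworks.openai.gateway.api.DefaultOpenAIProvider
import com.tddworks.openai.gateway.api.internal.DefaultOpenAIGateway
import kotlinx.serialization.ExperimentalSerializationApi

@OptIn(ExperimentalSerializationApi::class) // DefaultOpenAIGateway is marked experimental
suspend fun main() {
    val config = OpenAIConfig() // assumes defaults resolve the API key, as in the tests

    // Fallback client: handles any model no registered provider supports.
    val openAI = OpenAI.create(config)

    // Start with no providers, then register one at runtime via the new addProvider();
    // it returns the gateway, so registrations chain fluently.
    val gateway = DefaultOpenAIGateway(providers = emptyList(), openAI = openAI)
        .addProvider(
            DefaultOpenAIProvider(
                config = config,
                models = listOf(Model.GPT_3_5_TURBO),
            )
        )

    // Routed to the provider registered above because supports() matches the model;
    // an unsupported model would fall through to the fallback OpenAI client.
    val completion = gateway.chatCompletions(
        ChatCompletionRequest(
            messages = listOf(ChatMessage.UserMessage("hello")),
            maxTokens = 1024,
            model = Model.GPT_3_5_TURBO,
        )
    )
    println(completion)
}
```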