From d2a5a19d1df4f30858a027d78f60c6e979a0e850 Mon Sep 17 00:00:00 2001 From: slam Date: Thu, 18 Apr 2024 18:39:13 +0800 Subject: [PATCH] refactor: - update com.tddworks.central-portal-publisher - add unit tests --- build.gradle.kts | 2 +- .../ollama/api/chat/OllamaChatRequest.kt | 12 +++ .../ollama/api/chat/OllamaChatResponse.kt | 14 ++- .../ollama/api/chat/OllamaChatRequestTest.kt | 20 ++++ .../gateway/api/OllamaOpenAIProviderTest.kt | 93 +++++++++++++++++++ 5 files changed, 139 insertions(+), 2 deletions(-) create mode 100644 openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/OllamaChatRequestTest.kt create mode 100644 openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/OllamaOpenAIProviderTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index eb9f3cc..ec7f34e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -8,7 +8,7 @@ plugins { alias(libs.plugins.build.dokka.plugin) alias(libs.plugins.kotlinx.binary.validator) apply false - id("com.tddworks.sonatype-portal-publisher") version "0.0.1" + id("com.tddworks.central-portal-publisher") version "0.0.2" } sonatypePortalPublisher { diff --git a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/OllamaChatRequest.kt b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/OllamaChatRequest.kt index 2afff53..9eb7c25 100644 --- a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/OllamaChatRequest.kt +++ b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/OllamaChatRequest.kt @@ -23,6 +23,18 @@ data class OllamaChatRequest( } .let { JsonObject(it) } } + + companion object { + fun dummy() = OllamaChatRequest( + model = "llama2", + messages = listOf( + OllamaChatMessage( + role = "user", + content = "Hello!" + ) + ) + ) + } } diff --git a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/OllamaChatResponse.kt b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/OllamaChatResponse.kt index 264ca0b..5393bdd 100644 --- a/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/OllamaChatResponse.kt +++ b/ollama-client/ollama-client-core/src/commonMain/kotlin/com/tddworks/ollama/api/chat/OllamaChatResponse.kt @@ -56,4 +56,16 @@ data class OllamaChatResponse( @SerialName("prompt_eval_duration") val promptEvalDuration: Long? = null, @SerialName("eval_count") val evalCount: Int? = null, @SerialName("eval_duration") val evalDuration: Long? = null, -) \ No newline at end of file +) { + companion object { + fun dummy() = OllamaChatResponse( + model = "llama2", + createdAt = "2023-08-04T08:52:19.385406455-07:00", + message = OllamaChatMessage( + role = "assistant", + content = "The" + ), + done = false + ) + } +} \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/OllamaChatRequestTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/OllamaChatRequestTest.kt new file mode 100644 index 0000000..955074d --- /dev/null +++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/ollama/api/chat/OllamaChatRequestTest.kt @@ -0,0 +1,20 @@ +package com.tddworks.ollama.api.chat + +import org.junit.jupiter.api.Test +import kotlin.test.assertEquals + +class OllamaChatRequestTest { + + @Test + fun `should return dummy request`() { + // given + val request = OllamaChatRequest.dummy() + + // then + assertEquals("llama2", request.model) + assertEquals(1, request.messages.size) + assertEquals("user", request.messages[0].role) + assertEquals("Hello!", request.messages[0].content) + } + +} \ No newline at end of file diff --git a/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/OllamaOpenAIProviderTest.kt b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/OllamaOpenAIProviderTest.kt new file mode 100644 index 0000000..bb6ebc6 --- /dev/null +++ b/openai-gateway/openai-gateway-core/src/jvmTest/kotlin/com/tddworks/openai/gateway/api/OllamaOpenAIProviderTest.kt @@ -0,0 +1,93 @@ +package com.tddworks.openai.gateway.api + +import app.cash.turbine.test +import com.tddworks.ollama.api.Ollama +import com.tddworks.ollama.api.OllamaModel +import com.tddworks.ollama.api.chat.OllamaChatResponse +import com.tddworks.ollama.api.chat.api.toOllamaChatRequest +import com.tddworks.ollama.api.chat.api.toOpenAIChatCompletion +import com.tddworks.ollama.api.chat.api.toOpenAIChatCompletionChunk +import com.tddworks.openai.api.chat.api.ChatCompletionRequest +import com.tddworks.openai.api.chat.api.Model +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.ExperimentalSerializationApi +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.extension.ExtendWith +import org.mockito.InjectMocks +import org.mockito.Mock +import org.mockito.junit.jupiter.MockitoExtension +import org.mockito.kotlin.whenever + +@ExperimentalSerializationApi +@ExtendWith(MockitoExtension::class) +class OllamaOpenAIProviderTest { + @Mock + lateinit var client: Ollama + + @InjectMocks + lateinit var provider: OllamaOpenAIProvider + + @Test + fun `should return true when model is supported`() { + // given + val supportedModel = Model(OllamaModel.LLAMA2.value) + + // when + val isSupported = provider.supports(supportedModel) + + // then + kotlin.test.assertTrue(isSupported) + } + + @Test + fun `should return false when model is not supported`() { + // given + val unsupportedModel = Model.GPT_3_5_TURBO + + // when + val isSupported = provider.supports(unsupportedModel) + + // then + kotlin.test.assertFalse(isSupported) + } + + @Test + fun `should fetch completions from OpenAI API`() = runTest { + // given + val request = ChatCompletionRequest.dummy(Model(OllamaModel.LLAMA2.value)) + val response = OllamaChatResponse.dummy() + whenever(client.request(request.toOllamaChatRequest())).thenReturn(response) + + // when + val completions = provider.completions(request) + + // then + assertEquals(response.toOpenAIChatCompletion(), completions) + } + + @Test + fun `should stream completions for chat`() = runTest { + // given + val request = ChatCompletionRequest.dummy(Model(OllamaModel.LLAMA2.value)) + + val response = OllamaChatResponse.dummy() + whenever(client.stream(request.toOllamaChatRequest())).thenReturn(flow { + emit( + response + ) + }) + + // when + provider.streamCompletions(request).test { + // then + assertEquals( + response.toOpenAIChatCompletionChunk(), + awaitItem() + ) + awaitComplete() + } + + } +} \ No newline at end of file