Skip to content

Commit

Permalink
refactor:
Browse files Browse the repository at this point in the history
 - update com.tddworks.central-portal-publisher
 - add unit tests
  • Loading branch information
hanrw committed Apr 18, 2024
1 parent 7007cab commit d2a5a19
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 2 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ plugins {
alias(libs.plugins.build.dokka.plugin)

alias(libs.plugins.kotlinx.binary.validator) apply false
id("com.tddworks.sonatype-portal-publisher") version "0.0.1"
id("com.tddworks.central-portal-publisher") version "0.0.2"
}

sonatypePortalPublisher {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ data class OllamaChatRequest(
}
.let { JsonObject(it) }
}

companion object {
    // Canned single-message request fixture: one "user" greeting aimed at
    // the llama2 model. Intended for tests and examples.
    fun dummy(): OllamaChatRequest {
        val greeting = OllamaChatMessage(
            role = "user",
            content = "Hello!",
        )
        return OllamaChatRequest(
            model = "llama2",
            messages = listOf(greeting),
        )
    }
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,16 @@ data class OllamaChatResponse(
@SerialName("prompt_eval_duration") val promptEvalDuration: Long? = null,
@SerialName("eval_count") val evalCount: Int? = null,
@SerialName("eval_duration") val evalDuration: Long? = null,
)
) {
companion object {
    // Canned in-progress (done = false) response fixture carrying a partial
    // assistant reply, as produced mid-stream by the Ollama chat API.
    fun dummy(): OllamaChatResponse {
        val partialReply = OllamaChatMessage(
            role = "assistant",
            content = "The",
        )
        return OllamaChatResponse(
            model = "llama2",
            createdAt = "2023-08-04T08:52:19.385406455-07:00",
            message = partialReply,
            done = false,
        )
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.tddworks.ollama.api.chat

import org.junit.jupiter.api.Test
import kotlin.test.assertEquals

class OllamaChatRequestTest {

    @Test
    fun `should return dummy request`() {
        // given: the canned request fixture
        val request = OllamaChatRequest.dummy()

        // then: it targets llama2 and carries exactly one user greeting
        assertEquals("llama2", request.model)
        assertEquals(1, request.messages.size)
        val message = request.messages.first()
        assertEquals("user", message.role)
        assertEquals("Hello!", message.content)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package com.tddworks.openai.gateway.api

import app.cash.turbine.test
import com.tddworks.ollama.api.Ollama
import com.tddworks.ollama.api.OllamaModel
import com.tddworks.ollama.api.chat.OllamaChatResponse
import com.tddworks.ollama.api.chat.api.toOllamaChatRequest
import com.tddworks.ollama.api.chat.api.toOpenAIChatCompletion
import com.tddworks.ollama.api.chat.api.toOpenAIChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.Model
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.ExperimentalSerializationApi
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.mockito.InjectMocks
import org.mockito.Mock
import org.mockito.junit.jupiter.MockitoExtension
import org.mockito.kotlin.whenever

@ExperimentalSerializationApi
@ExtendWith(MockitoExtension::class)
class OllamaOpenAIProviderTest {
    // Mocked Ollama client; Mockito injects it into the provider under test.
    @Mock
    lateinit var client: Ollama

    @InjectMocks
    lateinit var provider: OllamaOpenAIProvider

    @Test
    fun `should return true when model is supported`() {
        // given an Ollama-hosted model id
        val supportedModel = Model(OllamaModel.LLAMA2.value)

        // when
        val isSupported = provider.supports(supportedModel)

        // then
        assertTrue(isSupported)
    }

    @Test
    fun `should return false when model is not supported`() {
        // given an OpenAI-only model
        val unsupportedModel = Model.GPT_3_5_TURBO

        // when
        val isSupported = provider.supports(unsupportedModel)

        // then
        assertFalse(isSupported)
    }

    @Test
    fun `should fetch completions from Ollama API`() = runTest {
        // given a request for an Ollama model and a stubbed client response
        val request = ChatCompletionRequest.dummy(Model(OllamaModel.LLAMA2.value))
        val response = OllamaChatResponse.dummy()
        whenever(client.request(request.toOllamaChatRequest())).thenReturn(response)

        // when
        val completions = provider.completions(request)

        // then: the Ollama response is adapted to the OpenAI completion shape
        assertEquals(response.toOpenAIChatCompletion(), completions)
    }

    @Test
    fun `should stream completions for chat`() = runTest {
        // given a stubbed single-element response stream from the client
        val request = ChatCompletionRequest.dummy(Model(OllamaModel.LLAMA2.value))
        val response = OllamaChatResponse.dummy()
        whenever(client.stream(request.toOllamaChatRequest())).thenReturn(
            flow { emit(response) }
        )

        // when
        provider.streamCompletions(request).test {
            // then: each Ollama chunk is adapted to an OpenAI chunk
            assertEquals(
                response.toOpenAIChatCompletionChunk(),
                awaitItem()
            )
            awaitComplete()
        }
    }
}

0 comments on commit d2a5a19

Please sign in to comment.