Skip to content

Commit

Permalink
feat(BE-215): ollama fim api support
Browse files Browse the repository at this point in the history
 - add missing unit tests
 - code clean
  • Loading branch information
hanrw committed Jun 19, 2024
1 parent 6dc1b4f commit bc61824
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Test

class OllamaChatResponseTest {

@Test
fun `should decode response to non-streaming OllamaChatResponse`() {
val response = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,17 @@ import kotlinx.serialization.ExperimentalSerializationApi


/**
 * Maps an Ollama chat response to the OpenAI [ChatCompletion] shape.
 *
 * When the response carries no message, the completion has an empty
 * choices list rather than a blank assistant choice.
 *
 * NOTE(review): Ollama exposes no numeric id/timestamp here, so the
 * `createdAt` string is reused as the completion id and `created` is
 * pinned to 1L — confirm downstream consumers tolerate these values.
 */
fun OllamaChatResponse.toOpenAIChatCompletion(): ChatCompletion {
    return ChatCompletion(
        id = createdAt,
        created = 1L,
        model = model,
        choices = message?.let {
            listOf(
                ChatChoice(
                    message = ChatMessage.assistant(it.content),
                    index = 0,
                )
            )
        } ?: emptyList(),
    )
}

fun OllamaChatResponse.toOpenAIChatCompletionChunk(): ChatCompletionChunk {
Expand All @@ -49,8 +41,7 @@ fun OllamaChatResponse.toOpenAIChatCompletionChunk(): ChatCompletionChunk {
)
)

return ChatCompletionChunk(
id = id,
return ChatCompletionChunk(id = id,
`object` = "ollama-chunk",
created = created,
model = model,
Expand All @@ -60,31 +51,27 @@ fun OllamaChatResponse.toOpenAIChatCompletionChunk(): ChatCompletionChunk {
content = message?.content,
)
)
}
)
})
}

/**
 * Converts an OpenAI-style [ChatCompletionRequest] into an [OllamaChatRequest].
 *
 * Only the User/Assistant/System roles and the matching message types are
 * supported; anything else (e.g. tool/vision messages) fails fast with
 * [IllegalArgumentException] rather than being silently dropped.
 */
@OptIn(ExperimentalSerializationApi::class)
fun ChatCompletionRequest.toOllamaChatRequest(): OllamaChatRequest {
    return OllamaChatRequest(
        model = model.value,
        messages = messages.map {
            OllamaChatMessage(
                // Ollama expects lowercase role strings
                role = when (it.role) {
                    Role.User -> "user"
                    Role.Assistant -> "assistant"
                    Role.System -> "system"
                    else -> throw IllegalArgumentException("Unknown role: ${it.role}")
                },
                content = when (it) {
                    is UserMessage -> it.content
                    is AssistantMessage -> it.content
                    is SystemMessage -> it.content
                    else -> throw IllegalArgumentException("Unknown message type: $it")
                },
            )
        },
    )
}

/**
Expand All @@ -98,34 +85,27 @@ fun CompletionRequest.toOllamaGenerateRequest(): OllamaGenerateRequest {
maxTokens?.let { options["num_predict"] = it }
stop?.let { options["stop"] = it.split(",").toTypedArray() }
return OllamaGenerateRequest(
model = model.value,
prompt = prompt,
stream = stream ?: false,
model = model.value, prompt = prompt, stream = stream ?: false,
// Looks only here can adapt the raw option
raw = (streamOptions?.get("raw") ?: false) as Boolean,
options = options
raw = (streamOptions?.get("raw") ?: false) as Boolean, options = options
)
}

/**
* Convert OllamaGenerateResponse to OpenAI Completion
*/
/**
 * Maps an [OllamaGenerateResponse] to the OpenAI legacy [Completion] shape.
 *
 * Token accounting: totalTokens is promptEvalCount + evalCount when both are
 * present, evalCount alone when only it is present, and 0 when evalCount is
 * absent (even if promptEvalCount is set) — preserved exactly as shipped.
 *
 * NOTE(review): `createdAt` (a string) is reused as the completion id and
 * `created` is pinned to 1 — confirm consumers tolerate these placeholders.
 */
fun OllamaGenerateResponse.toOpenAICompletion(): Completion {
    return Completion(
        id = createdAt,
        model = model,
        created = 1,
        choices = listOf(
            CompletionChoice(
                text = response,
                index = 0,
                finishReason = doneReason ?: "",
            )
        ),
        usage = Usage(
            promptTokens = promptEvalCount,
            completionTokens = evalCount,
            totalTokens = evalCount?.let { promptEvalCount?.plus(it) ?: it } ?: 0,
        ),
    )
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.tddworks.ollama.api.chat.api

import com.tddworks.common.network.api.ktor.api.AnySerial
import com.tddworks.ollama.api.OllamaModel
import com.tddworks.ollama.api.chat.OllamaChatMessage
import com.tddworks.ollama.api.chat.OllamaChatResponse
import com.tddworks.ollama.api.generate.OllamaGenerateResponse
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.ChatMessage
import com.tddworks.openai.api.chat.api.Model
import com.tddworks.openai.api.chat.api.Role
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest
import com.tddworks.openai.api.legacy.completions.api.Usage
import kotlinx.serialization.ExperimentalSerializationApi
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Test
Expand All @@ -16,7 +19,24 @@ import org.junit.jupiter.api.Test
class ExtensionsTest {

@Test
fun `should convert CompletionRequest to OllamaGenerateRequest with required fields`() {
    // Given a request carrying only the mandatory fields
    val request = CompletionRequest(
        model = Model(OllamaModel.CODE_LLAMA.value),
        prompt = "Once upon a time",
    )

    // When converted to the Ollama generate request
    val result = request.toOllamaGenerateRequest()

    // Then required fields map through and all optional behavior defaults off
    assertEquals("codellama", result.model)
    assertEquals("Once upon a time", result.prompt)
    assertFalse(result.stream)
    assertFalse(result.raw)
    assertEquals(emptyMap<String, AnySerial>(), result.options)
}

@Test
fun `should convert CompletionRequest to OllamaGenerateRequest with all fields`() {
val completionRequest = CompletionRequest(
model = Model(OllamaModel.CODE_LLAMA.value),
prompt = "Once upon a time",
Expand Down Expand Up @@ -48,7 +68,55 @@ class ExtensionsTest {
}

@Test
fun `should convert OllamaGenerateResponse to OpenAICompletion with required fields`() {
    // Given a minimal generate response with no token counts at all
    val response = OllamaGenerateResponse(
        model = "some-model",
        createdAt = "createdAt",
        response = "response",
        done = false,
    )

    val completion = response.toOpenAICompletion()

    assertEquals("createdAt", completion.id)
    assertEquals(1, completion.created)
    assertEquals("some-model", completion.model)
    assertEquals(1, completion.choices.size)
    val choice = completion.choices[0]
    assertEquals("response", choice.text)
    assertEquals(0, choice.index)
    assertEquals("", choice.finishReason)
    // No evalCount means totalTokens falls back to 0
    assertEquals(Usage(totalTokens = 0), completion.usage)
}

@Test
fun `should convert OllamaGenerateResponse to OpenAICompletion without promptEvalCount`() {
    // Given a response whose evalCount is set but promptEvalCount is absent
    val response = OllamaGenerateResponse(
        model = "some-model",
        createdAt = "createdAt",
        response = "response",
        done = false,
        evalCount = 10,
        evalDuration = 1000,
        loadDuration = 1000,
        promptEvalDuration = 1000,
    )

    val completion = response.toOpenAICompletion()

    assertEquals("createdAt", completion.id)
    assertEquals(1, completion.created)
    assertEquals("some-model", completion.model)
    assertEquals(1, completion.choices.size)
    val choice = completion.choices[0]
    assertEquals("response", choice.text)
    assertEquals(0, choice.index)
    assertEquals("", choice.finishReason)
    // totalTokens falls back to evalCount alone when promptEvalCount is null
    val expectedUsage = Usage(
        promptTokens = null,
        completionTokens = 10,
        totalTokens = 10,
    )
    assertEquals(expectedUsage, completion.usage)
}

@Test
fun `should convert OllamaGenerateResponse to OpenAICompletion with all fields`() {
val ollamaGenerateResponse = OllamaGenerateResponse.dummy()
val openAICompletion = ollamaGenerateResponse.toOpenAICompletion()
assertEquals("createdAt", openAICompletion.id)
Expand All @@ -64,7 +132,21 @@ class ExtensionsTest {
}

@Test
fun `should convert OllamaChatResponse to OpenAIChatCompletion without message`() {
    // Given a chat response that carries no message payload
    val response = OllamaChatResponse(
        createdAt = "123",
        model = "llama2",
        done = false
    )

    val completion = response.toOpenAIChatCompletion()

    assertEquals("123", completion.id)
    assertEquals(1L, completion.created)
    assertEquals("llama2", completion.model)
    // An absent message must yield an empty choices list, not a blank choice
    assertEquals(0, completion.choices.size)
}

@Test
fun `should convert OllamaChatResponse to OpenAIChatCompletion with all fields`() {
val ollamaChatResponse = OllamaChatResponse(
createdAt = "123",
model = "llama2",
Expand All @@ -84,6 +166,37 @@ class ExtensionsTest {
assertEquals("assistant", openAIChatCompletion.choices[0].message.role.name)
}

@Test
fun `should throw IllegalArgumentException when message not recognized`() {
    // Vision messages have no Ollama equivalent, so conversion must fail fast
    val request = ChatCompletionRequest(
        model = Model(OllamaModel.LLAMA2.value),
        messages = listOf(ChatMessage.vision(emptyList())),
    )

    assertThrows(IllegalArgumentException::class.java) {
        request.toOllamaChatRequest()
    }
}

@Test
fun `should throw IllegalArgumentException when role not recognized`() {
    // Role.Tool falls outside the User/Assistant/System set Ollama understands
    val request = ChatCompletionRequest(
        model = Model(OllamaModel.LLAMA2.value),
        messages = listOf(
            ChatMessage.UserMessage(content = "Hello", role = Role.Tool)
        ),
    )

    assertThrows(IllegalArgumentException::class.java) {
        request.toOllamaChatRequest()
    }
}

@Test
fun `should convert ChatCompletionRequest to OllamaChatRequest`() {
val chatCompletionRequest = ChatCompletionRequest(
Expand Down

0 comments on commit bc61824

Please sign in to comment.