feat(BE-215): ollama fim api support
- OllamaOpenAIProvider: support the Ollama generate API
- make the gateway support the Ollama generate API (payload sketch below)
- fix test failure in DefaultMessagesApiTest
hanrw committed Jun 19, 2024
1 parent 7125dec commit 4a33ccb
Showing 3 changed files with 19 additions and 10 deletions.
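For context, the commit maps OpenAI-style completion requests onto Ollama's generate endpoint. A rough sketch of the kind of /api/generate body the new mapping corresponds to (field names follow Ollama's documented generate API; the model, prompt, and option values here are illustrative and not taken from this commit):

// Illustrative payload only, written by hand for this note; not produced by code in this commit.
val exampleGenerateBody = """
    {
      "model": "codellama",
      "prompt": "fun add(a: Int, b: Int): Int {",
      "stream": false,
      "raw": true,
      "options": { "temperature": 0.5, "num_predict": 10, "stop": ["<EOT>"] }
    }
""".trimIndent()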
@@ -4,8 +4,7 @@ import app.cash.turbine.test
import com.tddworks.anthropic.api.messages.api.*
import com.tddworks.anthropic.api.mockHttpClient
import com.tddworks.common.network.api.ktor.internal.DefaultHttpRequester
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.UnconfinedTestDispatcher
import kotlinx.coroutines.test.StandardTestDispatcher
import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.Json
import org.junit.jupiter.api.Assertions.assertEquals
@@ -15,15 +14,14 @@ import org.koin.dsl.module
import org.koin.test.KoinTest
import org.koin.test.junit5.KoinTestExtension

@OptIn(ExperimentalCoroutinesApi::class)
class DefaultMessagesApiTest : KoinTest {
@JvmField
@RegisterExtension
// This extension sets the main dispatcher to a test dispatcher.
// StandardTestDispatcher does not launch coroutines eagerly; it queues them
// on the test scheduler, matching the scheduling behavior of a real app/production
val testKoinCoroutineExtension =
TestKoinCoroutineExtension(UnconfinedTestDispatcher())
TestKoinCoroutineExtension(StandardTestDispatcher())

@JvmField
@RegisterExtension
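The dispatcher swap above changes when coroutines launched inside the test actually run: UnconfinedTestDispatcher executes them eagerly, while StandardTestDispatcher queues them on the test scheduler until it is advanced (runTest drains the queue before finishing). A minimal, hypothetical sketch of that difference, not part of this commit:

import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test

class StandardDispatcherSketchTest {

    @Test
    fun `launched coroutines run only when the scheduler advances`() = runTest {
        // runTest uses a StandardTestDispatcher by default, so this coroutine
        // is queued on the test scheduler instead of running eagerly.
        var executed = false
        launch { executed = true }

        // With an UnconfinedTestDispatcher the flag would already be true here.
        assertFalse(executed)

        testScheduler.advanceUntilIdle()
        assertTrue(executed)
    }
}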
@@ -1,5 +1,6 @@
package com.tddworks.ollama.api.chat.api

import com.tddworks.common.network.api.ktor.api.AnySerial
import com.tddworks.ollama.api.chat.OllamaChatMessage
import com.tddworks.ollama.api.chat.OllamaChatRequest
import com.tddworks.ollama.api.chat.OllamaChatResponse
@@ -86,17 +87,29 @@ fun ChatCompletionRequest.toOllamaChatRequest(): OllamaChatRequest {
)
}

/**
* Convert CompletionRequest to OllamaGenerateRequest
*/
@OptIn(ExperimentalSerializationApi::class)
fun CompletionRequest.toOllamaGenerateRequest(): OllamaGenerateRequest {

val options = mutableMapOf<String, AnySerial>()
temperature?.let { options["temperature"] = it }
maxTokens?.let { options["num_predict"] = it }
stop?.let { options["stop"] = it.split(",").toTypedArray() }
return OllamaGenerateRequest(
model = model.value,
prompt = prompt,
stream = stream ?: false,
// The raw flag can only be adapted here; it is passed through via streamOptions
raw = (streamOptions?.get("raw") ?: false) as Boolean,
options = streamOptions?.filter { it.key != "raw" }
options = options
)
}

/**
* Convert OllamaGenerateResponse to OpenAI Completion
*/
fun OllamaGenerateResponse.toOpenAICompletion(): Completion {
return Completion(
id = createdAt,
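A usage sketch of the new conversion follows, mirroring the ExtensionsTest fixture below; the Model wrapper, the model and prompt values, and the omitted imports are assumptions rather than code from this repository:

// Hypothetical example; parameter names follow the test fixture in this commit.
val request = CompletionRequest(
    model = Model("codellama"),                 // assumed wrapper exposing .value
    prompt = "fun add(a: Int, b: Int): Int {",  // illustrative FIM-style prompt
    maxTokens = 10,
    temperature = 0.5,
    stream = false,
    stop = "<EOT>",
    streamOptions = mapOf("raw" to true)
)

val ollamaRequest = request.toOllamaGenerateRequest()
// Expected mapping, per the function above:
//   ollamaRequest.raw     == true   (lifted from streamOptions["raw"])
//   ollamaRequest.options == mapOf(
//       "temperature" to 0.5,
//       "num_predict" to 10,
//       "stop" to arrayOf("<EOT>"))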
@@ -24,11 +24,9 @@ class ExtensionsTest {
maxTokens = 10,
temperature = 0.5,
stream = false,
stop = "<EOT>",
streamOptions = mapOf(
"raw" to true,
"temperature" to 0.5,
"stop" to arrayOf("<EOT>"),
"num_predict" to 100
"raw" to true
)
)

@@ -42,7 +40,7 @@
with(ollamaGenerateRequest.options) {
assertNotNull(this)
assertEquals(0.5, this?.get("temperature"))
assertEquals(100, this?.get("num_predict"))
assertEquals(10, this?.get("num_predict"))
assertEquals("<EOT>", (this?.get("stop") as Array<*>)[0])
}
}
