Skip to content

Commit

Permalink
feat(BE-190): As a user, I want to be able to use the ollama-client
Browse files Browse the repository at this point in the history
  • Loading branch information
hanrw committed Apr 15, 2024
1 parent 3fca36c commit 5276e8b
Show file tree
Hide file tree
Showing 14 changed files with 1,572 additions and 0 deletions.
1,040 changes: 1,040 additions & 0 deletions ollama-client/api.md

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions ollama-client/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Umbrella module for the Ollama client: contains no code of its own and
// simply republishes the core artifact for consumers.
plugins {
    `maven-publish`
}

kotlin {
    jvm()
    sourceSets {
        commonMain {
            dependencies {
                // `api` (not `implementation`) so consumers of this module see
                // the core client's types transitively.
                api(projects.ollamaClient.ollamaClientCore)
            }
        }
    }
}
Empty file added ollama-client/jvm/.gitkeep
Empty file.
51 changes: 51 additions & 0 deletions ollama-client/ollama-client-core/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Build script for ollama-client-core: the Kotlin Multiplatform module that
// contains the actual Ollama HTTP client implementation.
plugins {
    alias(libs.plugins.kotlinx.serialization)
    alias(libs.plugins.kover)
    `maven-publish`
}

kotlin {
    // Supported targets: JVM plus both macOS architectures.
    jvm()
    macosArm64()
    macosX64()

    sourceSets {
        commonMain.dependencies {
            // put your Multiplatform dependencies here
            implementation(libs.kotlinx.coroutines.core)
            api(libs.kotlinx.serialization.json)
            api(libs.bundles.ktor.client)
            api(projects.common)
        }

        commonTest.dependencies {
            implementation(libs.ktor.client.mock)
            api(projects.common)
        }

        macosMain.dependencies {
            // Darwin engine backs the ktor client on macOS targets.
            api(libs.ktor.client.darwin)
        }

        jvmMain.dependencies {
            // CIO engine backs the ktor client on the JVM.
            api(libs.ktor.client.cio)
        }

        jvmTest.dependencies {
            implementation(project.dependencies.platform(libs.junit.bom))
            implementation(libs.bundles.jvm.test)
            implementation(libs.kotlinx.coroutines.test)
            implementation(libs.koin.test)
            implementation(libs.koin.test.junit5)
            implementation(libs.app.cash.turbine)
            // NOTE(review): the two inline coordinates below bypass the
            // version catalog used for every other dependency — consider
            // moving them into libs.versions.toml.
            implementation("com.tngtech.archunit:archunit-junit5:1.1.0")
            implementation("org.reflections:reflections:0.10.2")
        }
    }
}

tasks {
    // JUnit 5 platform is required by the junit-bom / koin-test-junit5
    // dependencies declared above.
    named<Test>("jvmTest") {
        useJUnitPlatform()
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.tddworks.ollama.api

/**
 * Top-level constants for the Ollama API client.
 *
 * @author hanrw
 * @date 2024/4/14 17:32
 */
class Ollama {
    companion object {
        // NOTE(review): "https://ollama.com" is Ollama's website, not an API
        // endpoint; a locally running Ollama server listens on
        // http://localhost:11434 by default — confirm which endpoint callers
        // expect before changing this value.
        const val BASE_URL = "https://ollama.com"

        /** Version string reported by this client library. */
        const val OLLAMA_VERSION = "1.0.0"

        // Copy-paste leftover from the Anthropic client; kept only for
        // backward compatibility with existing callers.
        @Deprecated(
            "Leftover from the Anthropic client; this is an Ollama client",
            ReplaceWith("OLLAMA_VERSION")
        )
        const val ANTHROPIC_VERSION = OLLAMA_VERSION
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.tddworks.ollama.api.chat

import kotlinx.coroutines.flow.Flow

/**
 * Client-side contract for Ollama's chat endpoint.
 */
interface OllamaChatApi {
    /**
     * Opens a streaming chat completion: each emitted [OllamaChatResponse]
     * is one incremental chunk of the model's reply.
     */
    suspend fun stream(request: OllamaChatRequest): Flow<OllamaChatResponse>

    /** Performs a single, non-streaming chat completion. */
    suspend fun request(request: OllamaChatRequest): OllamaChatResponse
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.tddworks.ollama.api.chat

import com.tddworks.common.network.api.StreamableRequest
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable


/**
 * Request payload for the Ollama chat endpoint.
 *
 * Implements [StreamableRequest] so it participates in the polymorphic
 * serialization registered by the ollama serializers module.
 *
 * @param model name of the model to chat with (e.g. "llama2").
 * @param messages the conversation so far, oldest first.
 * @param format optional response-format hint; null omits the field.
 * @param keepAlive optional duration to keep the model loaded; null omits it.
 */
@Serializable
data class OllamaChatRequest(
    @SerialName("model") val model: String,
    @SerialName("messages") val messages: List<OllamaChatMessage>,
    @SerialName("format") val format: String? = null,
    // Fields below are not yet mapped; re-enable when needed.
    // @SerialName("options") val options: Map<String, Any>? = null,
    // @SerialName("stream") val stream: Boolean? = null,
    @SerialName("keep_alive") val keepAlive: String? = null,
) : StreamableRequest


/**
 * One message in a chat conversation.
 *
 * @param role speaker role string (sample payloads use "user"/"assistant").
 * @param content plain-text body of the message.
 * @param images optional image attachments (encoding not shown here —
 *   presumably base64, TODO confirm against the Ollama API docs); null omits
 *   the field.
 */
@Serializable
data class OllamaChatMessage(
    @SerialName("role") val role: String,
    @SerialName("content") val content: String,
    @SerialName("images") val images: List<String>? = null,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.tddworks.ollama.api.chat

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

/**
 * One chunk of a chat completion. Intermediate streamed chunks carry a
 * [message] and `done = false`; the final chunk carries the timing/counter
 * fields, e.g.:
 * {
 *   "model": "llama2",
 *   "created_at": "2023-08-04T08:52:19.385406455-07:00",
 *   "message": {
 *     "role": "assistant",
 *     "content": "The"
 *   },
 *   "done": false
 * }
 */
@Serializable
data class OllamaChatResponse(
    @SerialName("model") val model: String,
    @SerialName("created_at") val createdAt: String,
    // Absent on the final chunk of a stream.
    @SerialName("message") val message: OllamaChatMessage? = null,
    @SerialName("done") val done: Boolean?,
    // Timing/counter fields are only present on the final chunk.
    @SerialName("total_duration") val totalDuration: Long? = null,
    @SerialName("load_duration") val loadDuration: Long? = null,
    @SerialName("prompt_eval_count") val promptEvalCount: Int? = null,
    @SerialName("prompt_eval_duration") val promptEvalDuration: Long? = null,
    @SerialName("eval_count") val evalCount: Int? = null,
    @SerialName("eval_duration") val evalDuration: Long? = null,
)

/**
 * The terminal chunk of a streamed chat completion, e.g.:
 * {
 *   "model": "llama2",
 *   "created_at": "2023-08-04T19:22:45.499127Z",
 *   "done": true,
 *   "total_duration": 8113331500,
 *   "load_duration": 6396458,
 *   "prompt_eval_count": 61,
 *   "prompt_eval_duration": 398801000,
 *   "eval_count": 468,
 *   "eval_duration": 7701267000
 * }
 *
 * NOTE(review): this duplicates [OllamaChatResponse] minus [OllamaChatResponse.message]
 * and is not referenced by [OllamaChatApi] in this change — confirm it is
 * actually used, otherwise consider removing it.
 */
@Serializable
data class FinalOllamaChatResponse(
    @SerialName("model") val model: String,
    @SerialName("created_at") val createdAt: String,
    @SerialName("done") val done: Boolean?,
    @SerialName("total_duration") val totalDuration: Long?,
    @SerialName("load_duration") val loadDuration: Long?,
    @SerialName("prompt_eval_count") val promptEvalCount: Int?,
    @SerialName("prompt_eval_duration") val promptEvalDuration: Long?,
    @SerialName("eval_count") val evalCount: Int?,
    @SerialName("eval_duration") val evalDuration: Long?,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.tddworks.ollama.api.chat.internal

import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.api.performRequest
import com.tddworks.common.network.api.ktor.api.streamRequest
import com.tddworks.ollama.api.chat.OllamaChatApi
import com.tddworks.ollama.api.chat.OllamaChatRequest
import com.tddworks.ollama.api.chat.OllamaChatResponse
import io.ktor.client.request.*
import io.ktor.http.*
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.json.Json

/**
 * Default [OllamaChatApi] implementation that delegates transport to an
 * injected [HttpRequester].
 *
 * @param requester performs the actual HTTP calls.
 * @param jsonLenient JSON configuration used when serializing the polymorphic
 *   stream request body.
 */
class DefaultOllamaChatApi(
    private val requester: HttpRequester,
    private val jsonLenient: Json = JsonLenient,
) : OllamaChatApi {

    /**
     * Streams chat chunks from the server; requests `text/event-stream` and
     * disables caching so the connection stays live.
     */
    override suspend fun stream(request: OllamaChatRequest): Flow<OllamaChatResponse> =
        requester.streamRequest<OllamaChatResponse> {
            method = HttpMethod.Post
            url(path = CHAT_API_PATH)
            // Serialize through the StreamableRequest view — presumably this
            // adds the stream marker to the payload (TODO confirm against
            // StreamableRequest.asStreamRequest).
            setBody(request.asStreamRequest(jsonLenient))
            contentType(ContentType.Application.Json)
            accept(ContentType.Text.EventStream)
            header(HttpHeaders.CacheControl, "no-cache")
            header(HttpHeaders.Connection, "keep-alive")
        }

    /** Performs a single non-streaming chat completion. */
    override suspend fun request(request: OllamaChatRequest): OllamaChatResponse =
        requester.performRequest {
            method = HttpMethod.Post
            url(path = CHAT_API_PATH)
            setBody(request)
            contentType(ContentType.Application.Json)
        }

    companion object {
        const val CHAT_API_PATH = "/api/chat"
    }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.tddworks.ollama.api.chat.internal

import com.tddworks.ollama.api.chat.internal.json.ollamaModule
import kotlinx.serialization.json.Json


/**
 * Shared JSON configuration for the Ollama client: lenient parsing, unknown
 * keys ignored, defaults encoded, and explicit `null`s omitted from output.
 *
 * The "#class" discriminator drives the polymorphic handling of
 * StreamableRequest registered in [ollamaModule]; see
 * https://github.com/Kotlin/kotlinx.serialization/blob/master/docs/json.md#class-discriminator-for-polymorphism
 */
val JsonLenient = Json {
    // Accept input that is slightly off-spec (RFC-4627 relaxations).
    isLenient = true
    // Don't throw on server fields this client doesn't model.
    ignoreUnknownKeys = true
    classDiscriminator = "#class"
    serializersModule = ollamaModule
    encodeDefaults = true
    explicitNulls = false
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.tddworks.ollama.api.chat.internal.json

import com.tddworks.common.network.api.StreamableRequest
import com.tddworks.ollama.api.chat.OllamaChatRequest
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic

/**
 * Serializers module wiring [StreamableRequest] polymorphism for the Ollama
 * client: within this module, any [StreamableRequest] is (de)serialized as an
 * [OllamaChatRequest].
 */
val ollamaModule = SerializersModule {
    polymorphic(StreamableRequest::class) {
        // Register OllamaChatRequest as a concrete subtype of
        // StreamableRequest.
        subclass(OllamaChatRequest::class, OllamaChatRequest.serializer())
        // Fall back to OllamaChatRequest when the input carries no class
        // discriminator.
        defaultDeserializer { OllamaChatRequest.serializer() }
    }
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package com.tddworks.ollama.api


import com.tddworks.common.network.api.ktor.internal.JsonLenient
import io.ktor.client.*
import io.ktor.client.engine.mock.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.*

/**
 * Builds an [HttpClient] backed by [MockEngine] that answers any request to
 * "/api/chat" with [mockResponse] (as application/json, 200 OK) and fails the
 * test on every other path.
 *
 * See https://ktor.io/docs/http-client-testing.html#usage
 *
 * @param mockResponse raw JSON body returned for /api/chat requests.
 */
fun mockHttpClient(mockResponse: String) = HttpClient(MockEngine) {

    val headers = headersOf("Content-Type" to listOf(ContentType.Application.Json.toString()))

    install(ContentNegotiation) {
        register(ContentType.Application.Json, KotlinxSerializationConverter(JsonLenient))
    }

    engine {
        addHandler { request ->
            // Match on path only — the host configured below is irrelevant to
            // the mock engine.
            if (request.url.encodedPath == "/api/chat") {
                respond(mockResponse, HttpStatusCode.OK, headers)
            } else {
                error("Unhandled ${request.url.encodedPath}")
            }
        }
    }

    defaultRequest {
        url {
            protocol = URLProtocol.HTTPS
            // Fixed: was "api.lemonsqueezy.com", a copy-paste leftover from an
            // unrelated client's tests; use a neutral host for the Ollama mock.
            host = "localhost"
        }

        header(HttpHeaders.ContentType, ContentType.Application.Json)
        contentType(ContentType.Application.Json)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package com.tddworks.ollama.api.internal

import com.tddworks.common.network.api.ktor.internal.DefaultHttpRequester
import com.tddworks.common.network.api.ktor.internal.createHttpClient
import com.tddworks.di.initKoin
import com.tddworks.ollama.api.Ollama
import com.tddworks.ollama.api.chat.OllamaChatMessage
import com.tddworks.ollama.api.chat.OllamaChatRequest
import com.tddworks.ollama.api.chat.internal.DefaultOllamaChatApi
import com.tddworks.ollama.api.chat.internal.JsonLenient
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.koin.test.junit5.AutoCloseKoinTest

/**
 * Integration tests for [DefaultOllamaChatApi]. The streaming test expects an
 * Ollama server reachable at "localhost" and only prints the chunks it
 * receives (asserting nothing beyond "no crash").
 */
class DefaultOllamaChatApiITest : AutoCloseKoinTest() {

    @BeforeEach
    fun setUp() {
        initKoin()
    }

    @Test
    fun `should return correct base url`() {
        // Fixed: previously asserted "api.anthropic.com" — a copy-paste
        // leftover from the Anthropic client tests that could never equal
        // Ollama.BASE_URL and made this test fail unconditionally.
        assertEquals("https://ollama.com", Ollama.BASE_URL)
    }

    @Test
    fun `should return stream response`() = runTest {
        val ollamaChatApi = DefaultOllamaChatApi(
            requester = DefaultHttpRequester(
                createHttpClient(
                    url = { "localhost" },
                    json = JsonLenient,
                )
            )
        )

        ollamaChatApi.stream(
            OllamaChatRequest(
                model = "llama2",
                messages = listOf(
                    OllamaChatMessage(
                        role = "user",
                        content = "hello"
                    )
                )
            )
        ).collect {
            println("stream response: $it")
        }
    }

    // NOTE(review): the commented-out test below still targets the Anthropic
    // client (another copy-paste leftover); either port it to the Ollama API
    // or delete it.
//    @Test
//    fun `should return create response`() = runTest {
//        //Client request(POST https://klaude.asusual.life/v1/messages) invalid: 401 Unauthorized. Text: "{"type":"error","error":{"type":"authentication_error","message":"invalid x-api-key"}}"
//        //Client request(POST https://klaude.asusual.life/v1/messages) invalid: 400 Bad Request. Text: "{"type":"error","error":{"type":"invalid_request_error","message":"anthropic-version: header is required"}}"
//        val anthropic = getInstance<Anthropic>()
//
//        val r = anthropic.create(
//            CreateMessageRequest(
//                messages = listOf(Message.user("hello")),
//                maxTokens = 1024,
//                model = Model.CLAUDE_3_HAIKU
//            )
//        )
//
//        assertNotNull(r.content[0].text)
//    }
}
Loading

0 comments on commit 5276e8b

Please sign in to comment.