
feat: support addProvider - OpenAIGateway
hanrw committed Jun 15, 2024
1 parent ba2b6fc commit 58e7ee9
Showing 8 changed files with 235 additions and 12 deletions.
README.md (29 changes: 25 additions & 4 deletions)
@@ -30,7 +30,7 @@ val openAI = initOpenAI(OpenAIConfig(
))

// stream completions
-openAI.streamCompletions(
+openAI.streamChatCompletions(
    ChatCompletionRequest(
        messages = listOf(ChatMessage.UserMessage("hello")),
        maxTokens = 1024,
@@ -41,13 +41,23 @@ openAI.streamCompletions(
}

// chat completions
-val chatCompletion = openAI.completions(
+val chatCompletion = openAI.chatCompletions(
    ChatCompletionRequest(
        messages = listOf(ChatMessage.UserMessage("hello")),
        maxTokens = 1024,
        model = Model.GPT_3_5_TURBO
    )
)

// completions(legacy)
val completion = openAI.completions(
    CompletionRequest(
        prompt = "Once upon a time",
        suffix = "The end",
        maxTokens = 10,
        temperature = 0.5
    )
)
```


@@ -100,7 +110,7 @@ val openAIGateway = initOpenAIGateway(
)

// stream completions
-openAIGateway.streamCompletions(
+openAIGateway.streamChatCompletions(
    ChatCompletionRequest(
        messages = listOf(ChatMessage.UserMessage("hello")),
        maxTokens = 1024,
@@ -111,11 +121,22 @@ openAIGateway.streamCompletions(
}

// chat completions
-val chatCompletion = openAIGateway.completions(
+val chatCompletion = openAIGateway.chatCompletions(
    ChatCompletionRequest(
        messages = listOf(ChatMessage.UserMessage("hello")),
        maxTokens = 1024,
        model = Model(Model.GPT_3_5_TURBO.value)
    )
)

// completions(legacy)
val completion = openAIGateway.completions(
    CompletionRequest(
        prompt = "Once upon a time",
        suffix = "The end",
        maxTokens = 10,
        temperature = 0.5
    )
)

```
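The `addProvider` API this commit introduces is not shown in the README snippets above, so here is a minimal usage sketch. It only uses calls visible in the diff below; `anthropicProvider` and `ollamaProvider` are hypothetical `OpenAIProvider` instances:

```kotlin
// addProvider returns the gateway itself, so registrations can be chained.
// anthropicProvider and ollamaProvider are hypothetical placeholder instances.
val gateway = openAIGateway
    .addProvider(anthropicProvider)
    .addProvider(ollamaProvider)

// Providers added after construction join model routing like the initial ones.
println(gateway.getProviders().size)
```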
---
@@ -1,14 +1,45 @@
package com.tddworks.openai.api

import com.tddworks.common.network.api.ktor.api.HttpRequester
import com.tddworks.common.network.api.ktor.internal.createHttpClient
import com.tddworks.common.network.api.ktor.internal.default
import com.tddworks.di.createJson
import com.tddworks.di.getInstance
import com.tddworks.openai.api.chat.api.Chat
import com.tddworks.openai.api.chat.internal.DefaultChatApi
import com.tddworks.openai.api.images.api.Images
import com.tddworks.openai.api.legacy.completions.api.Completion
import com.tddworks.openai.api.images.internal.DefaultImagesApi
import com.tddworks.openai.api.legacy.completions.api.Completions
import com.tddworks.openai.api.legacy.completions.api.internal.DefaultCompletionsApi

interface OpenAI : Chat, Images, Completions {
    companion object {
        const val BASE_URL = "api.openai.com"

        fun create(config: OpenAIConfig): OpenAI {
            val requester = HttpRequester.default(
                createHttpClient(
                    host = config.baseUrl,
                    authToken = config.apiKey,
                    // get from commonModule
                    json = createJson(),
                )
            )
            val chatApi = DefaultChatApi(
                requester = requester
            )

            val imagesApi = DefaultImagesApi(
                requester = requester
            )

            val completionsApi = DefaultCompletionsApi(
                requester = requester
            )

            return object : OpenAI, Chat by chatApi, Images by imagesApi,
                Completions by completionsApi {}
        }
    }
}
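For orientation, `create` wires chat, images, and legacy completions onto a single shared requester, so the returned object satisfies all three interfaces. A usage sketch; the no-arg `OpenAIConfig()` construction is taken from the test change below, and the `ChatMessage` import path is an assumption based on the package layout:

```kotlin
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.ChatMessage // path assumed
import com.tddworks.openai.api.chat.api.Model

// chatCompletions is a suspend function, so it must run inside a coroutine.
suspend fun demo() {
    val openAI = OpenAI.create(OpenAIConfig())
    val completion = openAI.chatCompletions(
        ChatCompletionRequest(
            messages = listOf(ChatMessage.UserMessage("hello")),
            maxTokens = 1024,
            model = Model.GPT_3_5_TURBO
        )
    )
    println(completion)
}
```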

---
@@ -9,14 +9,18 @@ import com.tddworks.openai.api.images.api.ImageCreate
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.runBlocking
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Disabled
import org.junit.jupiter.api.Test
import kotlin.test.assertNotNull

class OpenAITest {

    @Test
-    fun `should return correct base url`() {
-        assertEquals("api.openai.com", OpenAI.BASE_URL)
+    fun `should create openai instance`() {
+        val openAI = OpenAI.create(OpenAIConfig())
+
+        assertNotNull(openAI)
    }

    @Disabled
---
@@ -0,0 +1,32 @@
@file:OptIn(ExperimentalSerializationApi::class)

package com.tddworks.openai.gateway.api

import com.tddworks.openai.api.OpenAI
import com.tddworks.openai.api.OpenAIConfig
import com.tddworks.openai.api.chat.api.ChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.Model
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.ExperimentalSerializationApi

class DefaultOpenAIProvider(
    config: OpenAIConfig,
    models: List<Model>,
    private val openAI: OpenAI = OpenAI.create(config)
) : OpenAIProvider {
    private val availableModels: MutableList<Model> = models.toMutableList()

    override fun supports(model: Model): Boolean {
        return availableModels.any { it.value == model.value }
    }

    override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion {
        return openAI.chatCompletions(request)
    }

    override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
        return openAI.streamChatCompletions(request)
    }
}
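A construction sketch for the class above, wrapping the default client as a provider that claims only GPT-3.5 Turbo (the model list is illustrative):

```kotlin
val provider: OpenAIProvider = DefaultOpenAIProvider(
    config = OpenAIConfig(),
    models = listOf(Model.GPT_3_5_TURBO)
)

// supports() compares raw model values, so an equal-valued Model built
// elsewhere still matches this provider.
check(provider.supports(Model(Model.GPT_3_5_TURBO.value)))
```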
---
@@ -6,4 +6,7 @@ import com.tddworks.openai.api.legacy.completions.api.Completions
/**
* Interface for connecting to the OpenAI Gateway to chat.
*/
-interface OpenAIGateway : Chat, Completions
+interface OpenAIGateway : Chat, Completions {
+    fun addProvider(provider: OpenAIProvider): OpenAIGateway
+    fun getProviders(): List<OpenAIProvider>
+}
---
@@ -18,9 +18,20 @@ import kotlinx.serialization.ExperimentalSerializationApi
*/
@ExperimentalSerializationApi
class DefaultOpenAIGateway(
-    private val providers: List<OpenAIProvider>,
+    providers: List<OpenAIProvider>,
    private val openAI: OpenAI,
) : OpenAIGateway {
    private val availableProviders: MutableList<OpenAIProvider> =
        providers.toMutableList()

    override fun addProvider(provider: OpenAIProvider): OpenAIGateway {
        availableProviders.add(provider)
        return this
    }

    override fun getProviders(): List<OpenAIProvider> {
        return availableProviders.toList()
    }

    /**
     * This function is called to get completions for a chat based on the given request.
@@ -29,7 +40,7 @@ class DefaultOpenAIGateway(
     * @return A ChatCompletion object containing the completions for the provided request.
     */
    override suspend fun chatCompletions(request: ChatCompletionRequest): ChatCompletion {
-        return providers.firstOrNull {
+        return availableProviders.firstOrNull {
            it.supports(request.model)
        }?.chatCompletions(request) ?: openAI.chatCompletions(request)
    }
@@ -42,7 +53,7 @@
     * @return a Flow of ChatCompletionChunk objects representing the completions for the input model
     */
    override fun streamChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
-        return providers.firstOrNull {
+        return availableProviders.firstOrNull {
            it.supports(request.model)
        }?.streamChatCompletions(request) ?: openAI.streamChatCompletions(request)
    }
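Routing in both methods is first-match-wins with a fallback: the first registered provider whose `supports` returns true handles the request, and anything unclaimed goes to the default `openAI` client. A sketch under assumptions: `claudeProvider` is a hypothetical provider registered for Claude models, and the `ChatMessage` import path is assumed:

```kotlin
import com.tddworks.anthropic.api.Model as AnthropicModel
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.ChatMessage // path assumed
import com.tddworks.openai.api.chat.api.Model

suspend fun route(gateway: OpenAIGateway) {
    // Claimed by claudeProvider, assuming it supports CLAUDE_3_HAIKU.
    gateway.chatCompletions(
        ChatCompletionRequest(
            messages = listOf(ChatMessage.UserMessage("hello")),
            maxTokens = 1024,
            model = Model(AnthropicModel.CLAUDE_3_HAIKU.value)
        )
    )

    // No provider claims this model, so the gateway falls back to the
    // default OpenAI client.
    gateway.chatCompletions(
        ChatCompletionRequest(
            messages = listOf(ChatMessage.UserMessage("hello")),
            maxTokens = 1024,
            model = Model.GPT_3_5_TURBO
        )
    )
}
```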
---
@@ -20,7 +20,7 @@ import kotlin.test.assertEquals
import com.tddworks.anthropic.api.Model as AnthropicModel

@OptIn(ExperimentalSerializationApi::class)
-class OpenAIGatewayTest {
+class DefaultOpenAIGatewayTest {
    private val anthropic = mock<OpenAIProvider> {
        on(it.supports(Model(AnthropicModel.CLAUDE_3_HAIKU.value))).thenReturn(true)
    }
@@ -41,6 +41,24 @@ class OpenAIGatewayTest {
        openAI = openAI
    )

    @Test
    fun `should be able to add a new provider`() {
        // Given
        val provider = mock<OpenAIProvider>()

        // When
        val gateway = DefaultOpenAIGateway(
            providers,
            openAI = openAI
        ).run {
            addProvider(provider)
        }

        // Then
        assertEquals(3, gateway.getProviders().size)
        assertEquals(provider, gateway.getProviders().last())
    }

    @Test
    fun `should use ollama client to get chat completions`() = runTest {
        // Given
---
@@ -0,0 +1,103 @@
package com.tddworks.openai.gateway.api

import app.cash.turbine.test
import com.tddworks.ollama.api.OllamaModel
import com.tddworks.openai.api.OpenAI
import com.tddworks.openai.api.OpenAIConfig
import com.tddworks.openai.api.chat.api.ChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.Model
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.extension.ExtendWith
import org.mockito.Mock
import org.mockito.Mockito.mock
import org.mockito.junit.jupiter.MockitoExtension
import org.mockito.kotlin.whenever

@ExtendWith(MockitoExtension::class)
class DefaultOpenAIProviderTest {

    @Mock
    lateinit var client: OpenAI

    private lateinit var provider: DefaultOpenAIProvider

    @BeforeEach
    fun setUp() {
        provider =
            DefaultOpenAIProvider(OpenAIConfig(), listOf(Model.GPT_3_5_TURBO), client)
    }

    @Test
    fun `should return completions from OpenAI API`() = runTest {
        // given
        val request = ChatCompletionRequest.dummy(Model(OllamaModel.LLAMA2.value))
        val response = ChatCompletion.dummy()
        whenever(client.chatCompletions(request)).thenReturn(response)

        // when
        val completions = provider.chatCompletions(request)

        // then
        assertEquals(response, completions)
    }

    @Test
    fun `should stream completions for chat`() = runTest {
        // given
        val request = ChatCompletionRequest.dummy(Model(OllamaModel.LLAMA2.value))

        val response = ChatCompletionChunk.dummy()
        whenever(client.streamChatCompletions(request)).thenReturn(flow {
            emit(response)
        })

        // when
        provider.streamChatCompletions(request).test {
            // then
            assertEquals(response, awaitItem())
            awaitComplete()
        }
    }

    @Test
    fun `should return false when model is not supported`() {
        // Given
        val openAI = mock<OpenAI>()
        val model = Model("gpt-3.5-turbo")
        val provider = DefaultOpenAIProvider(OpenAIConfig(), emptyList(), openAI)

        // When
        val result = provider.supports(model)

        // Then
        assertTrue(!result)
    }

    @Test
    fun `should return true when model is supported`() {
        // Given
        val openAI = mock<OpenAI>()
        val model = Model("gpt-3.5-turbo")
        val provider = DefaultOpenAIProvider(OpenAIConfig(), listOf(model), openAI)

        // When
        val result = provider.supports(model)

        // Then
        assertTrue(result)
    }
}
