Skip to content

Commit

Permalink
feat(gateway): introduce OpenAIProviderConfig for gateway
Browse files Browse the repository at this point in the history
  • Loading branch information
hanrw committed Aug 21, 2024
1 parent 5d00ae0 commit 8752e8f
Show file tree
Hide file tree
Showing 16 changed files with 200 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,16 @@ class AnthropicApi(
return anthropicVersion
}

}

/**
 * Creates an [Anthropic] client backed by the [AnthropicApi] implementation.
 *
 * @param apiKey key used to authenticate requests against the Anthropic API
 * @param apiURL base URL of the Anthropic API endpoint
 * @param anthropicVersion the Anthropic API version string the client reports
 */
fun Anthropic.Companion.create(
    apiKey: String,
    apiURL: String,
    anthropicVersion: String,
): Anthropic = AnthropicApi(
    apiKey = apiKey,
    apiURL = apiURL,
    anthropicVersion = anthropicVersion,
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,28 @@ class OllamaApi(
return protocol
}

}

/**
 * Top-level factory that builds an [Ollama] client backed by [OllamaApi].
 *
 * Each parameter is a supplier so values can be resolved at creation time;
 * every supplier is invoked exactly once. Defaults come from the [Ollama]
 * companion constants.
 *
 * @param baseUrl supplier of the host name, defaults to [Ollama.BASE_URL]
 * @param port supplier of the port, defaults to [Ollama.PORT]
 * @param protocol supplier of the scheme, defaults to [Ollama.PROTOCOL]
 */
fun Ollama(
    baseUrl: () -> String = { Ollama.BASE_URL },
    port: () -> Int = { Ollama.PORT },
    protocol: () -> String = { Ollama.PROTOCOL },
): Ollama {
    val resolvedBaseUrl = baseUrl()
    val resolvedPort = port()
    val resolvedProtocol = protocol()
    return OllamaApi(
        baseUrl = resolvedBaseUrl,
        port = resolvedPort,
        protocol = resolvedProtocol,
    )
}

/**
 * Companion-scoped factory for an [Ollama] client.
 *
 * Delegates to the top-level [Ollama] factory function so the construction
 * logic lives in exactly one place instead of being duplicated here.
 *
 * @param baseUrl supplier of the host name, defaults to [Ollama.BASE_URL]
 * @param port supplier of the port, defaults to [Ollama.PORT]
 * @param protocol supplier of the scheme, defaults to [Ollama.PROTOCOL]
 */
fun Ollama.Companion.create(
    baseUrl: () -> String = { BASE_URL },
    port: () -> Int = { PORT },
    protocol: () -> String = { PROTOCOL },
): Ollama = Ollama(baseUrl, port, protocol)
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package com.tddworks.openai.gateway.api

import com.tddworks.anthropic.api.Anthropic
import com.tddworks.anthropic.api.AnthropicModel
import com.tddworks.anthropic.api.internal.create
import com.tddworks.anthropic.api.messages.api.*
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.legacy.completions.api.Completion
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest
import com.tddworks.openai.gateway.api.internal.AnthropicOpenAIProviderConfig
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
import kotlinx.coroutines.flow.transform
Expand All @@ -16,11 +18,18 @@ import com.tddworks.openai.api.chat.api.OpenAIModel as OpenAIModel

@OptIn(ExperimentalSerializationApi::class)
class AnthropicOpenAIProvider(
private val client: Anthropic,
override val name: String = "Anthropic",
override val models: List<OpenAIModel> = AnthropicModel.availableModels.map {
OpenAIModel(it.value)
}
},
override val config: AnthropicOpenAIProviderConfig,

private val client: Anthropic = Anthropic.create(
apiKey = config.apiKey(),
apiURL = config.baseUrl(),
anthropicVersion = config.anthropicVersion()
)

) : OpenAIProvider {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ import com.tddworks.openai.api.chat.api.OpenAIModel
import com.tddworks.openai.api.chat.api.OpenAIModel.Companion.availableModels
import com.tddworks.openai.api.legacy.completions.api.Completion
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest
import com.tddworks.openai.gateway.api.internal.toOpenAIConfig
import kotlinx.coroutines.flow.Flow
import kotlinx.serialization.ExperimentalSerializationApi

class DefaultOpenAIProvider(
config: OpenAIConfig,
private val openAI: OpenAI = OpenAI.create(config),
override val name: String = "OpenAI",
override val models: List<OpenAIModel> = availableModels
override val models: List<OpenAIModel> = availableModels,
override val config: OpenAIProviderConfig,
private val openAI: OpenAI = OpenAI.create(config.toOpenAIConfig())
) : OpenAIProvider {

override fun supports(model: OpenAIModel): Boolean {
Expand All @@ -39,9 +40,9 @@ class DefaultOpenAIProvider(
}

fun OpenAIProvider.Companion.openAI(
config: OpenAIConfig,
config: OpenAIProviderConfig,
models: List<OpenAIModel>,
openAI: OpenAI = OpenAI.create(config)
openAI: OpenAI = OpenAI.create(config.toOpenAIConfig())
): OpenAIProvider {
return DefaultOpenAIProvider(config = config, openAI = openAI, models = models)
return DefaultOpenAIProvider(config = config, models = models, openAI = openAI)
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,30 @@ package com.tddworks.openai.gateway.api
import com.tddworks.ollama.api.Ollama
import com.tddworks.ollama.api.OllamaModel
import com.tddworks.ollama.api.chat.api.*
import com.tddworks.ollama.api.internal.create
import com.tddworks.openai.api.chat.api.ChatCompletion
import com.tddworks.openai.api.chat.api.ChatCompletionChunk
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.OpenAIModel
import com.tddworks.openai.api.legacy.completions.api.Completion
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest
import com.tddworks.openai.gateway.api.internal.OllamaOpenAIProviderConfig
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.transform
import kotlinx.serialization.ExperimentalSerializationApi

@OptIn(ExperimentalSerializationApi::class)
class OllamaOpenAIProvider(
private val client: Ollama,
override val name: String = "Ollama",
override val config: OllamaOpenAIProviderConfig,
override val models: List<OpenAIModel> = OllamaModel.availableModels.map {
OpenAIModel(it.value)
}
},
private val client: Ollama = Ollama.create(
baseUrl = config.baseUrl,
port = config.port,
protocol = config.protocol
)
) : OpenAIProvider {
/**
* Check if the given OpenAIModel is supported by the available models.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,21 @@ import com.tddworks.openai.api.legacy.completions.api.Completions
* Represents a provider for the OpenAI chat functionality.
*/
interface OpenAIProvider : Chat, Completions {
/**
* The name of the provider.
*/
val name: String

/**
* The models supported by the provider.
*/
val models: List<OpenAIModel>

/**
* The configuration for the provider.
*/
val config: OpenAIProviderConfig

/**
* Determines if the provided model is supported or not.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.tddworks.openai.gateway.api

/**
 * Base configuration contract shared by every gateway provider backend
 * (OpenAI, Anthropic, Ollama, ...). Concrete implementations add any
 * provider-specific settings on top of these two suppliers.
 */
interface OpenAIProviderConfig {
    /** Supplier of the API key used to authenticate requests to the provider. */
    val apiKey: () -> String
    /** Supplier of the base URL (host) the provider client talks to. */
    val baseUrl: () -> String
    /** Empty companion so provider factories can be attached as extensions (e.g. `default`, `anthropic`, `ollama`). */
    companion object
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.tddworks.openai.gateway.api.internal

import com.tddworks.anthropic.api.AnthropicConfig
import com.tddworks.openai.gateway.api.OpenAIProviderConfig

/**
 * [OpenAIProviderConfig] for the Anthropic backend.
 *
 * Declared as a `data class` for value semantics (equals/hashCode/copy),
 * consistent with the sibling Default and Ollama provider configs. Default
 * values are hoisted into companion constants so they are defined once and
 * shared with the [anthropic] factory below.
 *
 * NOTE: parameter order (version first, then apiKey, baseUrl) is preserved
 * for any positional callers.
 *
 * @param anthropicVersion supplier of the Anthropic API version string
 * @param apiKey supplier of the Anthropic API key (required)
 * @param baseUrl supplier of the Anthropic API host
 */
data class AnthropicOpenAIProviderConfig(
    val anthropicVersion: () -> String = { DEFAULT_ANTHROPIC_VERSION },
    override val apiKey: () -> String,
    override val baseUrl: () -> String = { DEFAULT_BASE_URL }
) : OpenAIProviderConfig {
    companion object {
        /** Official Anthropic API host. */
        const val DEFAULT_BASE_URL = "api.anthropic.com"

        /** Default Anthropic API version sent by the client. */
        const val DEFAULT_ANTHROPIC_VERSION = "2023-06-01"
    }
}

/** Adapts this provider config to the Anthropic client's [AnthropicConfig]. */
fun AnthropicOpenAIProviderConfig.toAnthropicOpenAIConfig() =
    AnthropicConfig(anthropicVersion, apiKey, baseUrl)

/**
 * Factory for an [AnthropicOpenAIProviderConfig], attached to the
 * [OpenAIProviderConfig] companion for discoverability.
 *
 * @param apiKey supplier of the Anthropic API key (required)
 * @param baseUrl supplier of the API host, defaults to [AnthropicOpenAIProviderConfig.DEFAULT_BASE_URL]
 * @param anthropicVersion supplier of the API version, defaults to [AnthropicOpenAIProviderConfig.DEFAULT_ANTHROPIC_VERSION]
 */
fun OpenAIProviderConfig.Companion.anthropic(
    apiKey: () -> String,
    baseUrl: () -> String = { AnthropicOpenAIProviderConfig.DEFAULT_BASE_URL },
    anthropicVersion: () -> String = { AnthropicOpenAIProviderConfig.DEFAULT_ANTHROPIC_VERSION }
) = AnthropicOpenAIProviderConfig(anthropicVersion, apiKey, baseUrl)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.tddworks.openai.gateway.api.internal

import com.tddworks.openai.api.OpenAIConfig
import com.tddworks.openai.gateway.api.OpenAIProviderConfig

/**
 * Default [OpenAIProviderConfig] used to talk to the official OpenAI API.
 */
data class DefaultOpenAIProviderConfig(
    /** Supplier of the OpenAI API key (required). */
    override val apiKey: () -> String,
    /** Supplier of the API host; defaults to [DEFAULT_BASE_URL]. */
    override val baseUrl: () -> String = { DEFAULT_BASE_URL }
) : OpenAIProviderConfig {
    companion object {
        /** Official OpenAI API host. */
        const val DEFAULT_BASE_URL = "api.openai.com"
    }
}

/** Adapts any [OpenAIProviderConfig] to the OpenAI client's [OpenAIConfig]. */
fun OpenAIProviderConfig.toOpenAIConfig() = OpenAIConfig(apiKey, baseUrl)

/**
 * Factory for a [DefaultOpenAIProviderConfig], attached to the
 * [OpenAIProviderConfig] companion for discoverability.
 */
fun OpenAIProviderConfig.Companion.default(
    apiKey: () -> String,
    baseUrl: () -> String = { DefaultOpenAIProviderConfig.DEFAULT_BASE_URL }
) = DefaultOpenAIProviderConfig(apiKey, baseUrl)
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.tddworks.openai.gateway.api.internal

import com.tddworks.ollama.api.OllamaConfig
import com.tddworks.openai.gateway.api.OpenAIProviderConfig

/**
 * [OpenAIProviderConfig] for a local Ollama backend.
 *
 * Default values are hoisted into companion constants so each literal is
 * defined once and shared with the [ollama] factory below, matching the
 * pattern used by the Default provider config.
 *
 * @param port supplier of the Ollama port
 * @param protocol supplier of the scheme ("http" by default)
 * @param baseUrl supplier of the host name
 * @param apiKey supplier of a placeholder key — Ollama does not authenticate,
 *               but the [OpenAIProviderConfig] contract requires one
 */
data class OllamaOpenAIProviderConfig(
    val port: () -> Int = { DEFAULT_PORT },
    val protocol: () -> String = { DEFAULT_PROTOCOL },
    override val baseUrl: () -> String = { DEFAULT_BASE_URL },
    override val apiKey: () -> String = { DEFAULT_API_KEY }
) : OpenAIProviderConfig {
    companion object {
        /** Default Ollama port. */
        const val DEFAULT_PORT = 11434

        /** Default scheme for a local Ollama instance. */
        const val DEFAULT_PROTOCOL = "http"

        /** Default host for a local Ollama instance. */
        const val DEFAULT_BASE_URL = "localhost"

        /** Placeholder key — ignored by Ollama. */
        const val DEFAULT_API_KEY = "ollama-ignore-this"
    }
}

/** Adapts this provider config to the Ollama client's [OllamaConfig]. */
fun OllamaOpenAIProviderConfig.toOllamaConfig() =
    OllamaConfig(baseUrl = baseUrl, protocol = protocol, port = port)

/**
 * Factory for an [OllamaOpenAIProviderConfig], attached to the
 * [OpenAIProviderConfig] companion for discoverability.
 */
fun OpenAIProviderConfig.Companion.ollama(
    apiKey: () -> String = { OllamaOpenAIProviderConfig.DEFAULT_API_KEY },
    baseUrl: () -> String = { OllamaOpenAIProviderConfig.DEFAULT_BASE_URL },
    protocol: () -> String = { OllamaOpenAIProviderConfig.DEFAULT_PROTOCOL },
    port: () -> Int = { OllamaOpenAIProviderConfig.DEFAULT_PORT }
) = OllamaOpenAIProviderConfig(port = port, protocol = protocol, baseUrl = baseUrl, apiKey = apiKey)
Original file line number Diff line number Diff line change
@@ -1,45 +1,54 @@
package com.tddworks.openai.gateway.di

import com.tddworks.anthropic.api.AnthropicConfig
import com.tddworks.anthropic.di.anthropicModules
import com.tddworks.di.commonModule
import com.tddworks.ollama.api.OllamaConfig
import com.tddworks.di.getInstance
import com.tddworks.ollama.di.ollamaModules
import com.tddworks.openai.api.OpenAIConfig
import com.tddworks.openai.di.openAIModules
import com.tddworks.openai.gateway.api.AnthropicOpenAIProvider
import com.tddworks.openai.gateway.api.OllamaOpenAIProvider
import com.tddworks.openai.gateway.api.OpenAIGateway
import com.tddworks.openai.gateway.api.OpenAIProvider
import com.tddworks.openai.gateway.api.internal.DefaultOpenAIGateway
import com.tddworks.openai.gateway.api.internal.*
import kotlinx.serialization.ExperimentalSerializationApi
import org.koin.core.context.startKoin
import org.koin.dsl.KoinAppDeclaration
import org.koin.dsl.module

@ExperimentalSerializationApi
fun initOpenAIGateway(
openAIConfig: OpenAIConfig,
anthropicConfig: AnthropicConfig,
ollamaConfig: OllamaConfig,
openAIConfig: DefaultOpenAIProviderConfig,
anthropicConfig: AnthropicOpenAIProviderConfig,
ollamaConfig: OllamaOpenAIProviderConfig,
appDeclaration: KoinAppDeclaration = {},
) = startKoin {
appDeclaration()
modules(
commonModule(false) +
anthropicModules(anthropicConfig) +
openAIModules(openAIConfig) +
ollamaModules(ollamaConfig) +
openAIGatewayModules()
commonModule(false)
+ openAIModules(openAIConfig.toOpenAIConfig())
+ anthropicModules(anthropicConfig.toAnthropicOpenAIConfig())
+ ollamaModules(ollamaConfig.toOllamaConfig())
+ openAIProviderConfigsModule(openAIConfig, anthropicConfig, ollamaConfig)
+ openAIGatewayModules()
)
}.koin.get<OpenAIGateway>()


private fun openAIProviderConfigsModule(
openAIConfig: DefaultOpenAIProviderConfig,
anthropicConfig: AnthropicOpenAIProviderConfig,
ollamaConfig: OllamaOpenAIProviderConfig
) = module {
single { openAIConfig }
single { anthropicConfig }
single { ollamaConfig }
}


@ExperimentalSerializationApi
fun createOpenAIGateway(providers: List<OpenAIProvider>) = startKoin {
modules(
commonModule(false) +
openAIGatewayModules(providers)
commonModule(false) + openAIGatewayModules(providers)
)
}.koin.get<OpenAIGateway>()

Expand All @@ -52,13 +61,12 @@ fun openAIGatewayModules(providers: List<OpenAIProvider>) = module {

@ExperimentalSerializationApi
fun openAIGatewayModules() = module {
single<AnthropicOpenAIProvider> { AnthropicOpenAIProvider(get()) }
single<OllamaOpenAIProvider> { OllamaOpenAIProvider(get()) }
single<AnthropicOpenAIProvider> { AnthropicOpenAIProvider(config = get()) }
single<OllamaOpenAIProvider> { OllamaOpenAIProvider(config = get()) }

single {
listOf(
get<AnthropicOpenAIProvider>(),
get<OllamaOpenAIProvider>()
get<AnthropicOpenAIProvider>(), get<OllamaOpenAIProvider>()
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.tddworks.anthropic.api.messages.api.*
import com.tddworks.openai.api.chat.api.ChatCompletionRequest
import com.tddworks.openai.api.chat.api.OpenAIModel
import com.tddworks.openai.api.legacy.completions.api.CompletionRequest
import com.tddworks.openai.gateway.api.internal.AnthropicOpenAIProviderConfig
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Assertions.assertEquals
Expand All @@ -19,13 +20,15 @@ import kotlin.test.assertTrue

class AnthropicOpenAIProviderTest {
private lateinit var client: Anthropic
private lateinit var config: AnthropicOpenAIProviderConfig

private lateinit var provider: AnthropicOpenAIProvider

@BeforeEach
fun setUp() {
client = mock()
provider = AnthropicOpenAIProvider(client)
config = AnthropicOpenAIProviderConfig(apiKey = { "" })
provider = AnthropicOpenAIProvider(config = config, client = client)
}

@Test
Expand Down
Loading

0 comments on commit 8752e8f

Please sign in to comment.