Skip to content

Commit

Permalink
fix cohere rerank
Browse files Browse the repository at this point in the history
  • Loading branch information
pepesi committed Jan 23, 2025
1 parent 7025c0e commit 1094293
Show file tree
Hide file tree
Showing 9 changed files with 14 additions and 17 deletions.
4 changes: 4 additions & 0 deletions plugins/wasm-go/extensions/ai-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,5 +280,9 @@ func getOpenAiApiName(path string) provider.ApiName {
if strings.HasSuffix(path, "/v1/images/generations") {
return provider.ApiNameImageGeneration
}
// rerank
if strings.HasSuffix(path, "/v1/rerank") {
return provider.ApiNameCohereV1Rerank
}
return ""
}
3 changes: 1 addition & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ type azureProviderInitializer struct {

func (m *azureProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// azure 此配置无实质作用,只是为了保持和其他provider的一致性
// TODO: azure的模式和openai是一致的,只是需要处理前缀,可以在TransformRequestHeaders中处理,以支持通用能力
// TODO: azure's pattern is the same as openai, just need to handle the prefix, can be done in TransformRequestHeaders to support general capabilities
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
Expand Down
1 change: 0 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/baichuan.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func (m *baichuanProviderInitializer) ValidateConfig(config *ProviderConfig) err

func (m *baichuanProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// 百川AI的chat和embeddings接口和OpenAI的chat和embeddings接口一样
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
Expand Down
4 changes: 2 additions & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (

const (
cohereDomain = "api.cohere.com"
// TODO: 现在cohere有v2, 也有embeddings, 考虑更多支持: https://docs.cohere.com/v2/reference/rerank
// TODO: support more capabilities, upgrade to v2, docs: https://docs.cohere.com/v2/reference/chat
cohereChatCompletionPath = "/v1/chat"
cohereRerankPath = "/v1/rerank"
)
Expand Down Expand Up @@ -100,7 +100,7 @@ func (m *cohereProvider) buildCohereRequest(origin *chatCompletionRequest) *cohe
}

func (m *cohereProvider) TransformRequestHeaders(ctx wrapper.HttpContext, apiName ApiName, headers http.Header, log wrapper.Log) {
util.OverwriteRequestPathHeader(headers, cohereChatCompletionPath)
util.OverwriteRequestPathHeaderByCapability(headers, string(apiName), m.config.capabilities)
util.OverwriteRequestHostHeader(headers, cohereDomain)
util.OverwriteRequestAuthorizationHeader(headers, "Bearer "+m.config.GetApiTokenInUse(ctx))
headers.Del("Content-Length")
Expand Down
4 changes: 1 addition & 3 deletions plugins/wasm-go/extensions/ai-proxy/provider/coze.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ func (m *cozeProviderInitializer) ValidateConfig(config *ProviderConfig) error {
}

func (m *cozeProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// 此配置暂时无实质作用,只是为了保持和其他provider的一致性
}
return map[string]string{}
}

func (m *cozeProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
Expand Down
3 changes: 2 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/deepseek.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import (

const (
deepseekDomain = "api.deepseek.com"
// TODO: 根据文档 docs: https://api-docs.deepseek.com/api/create-chat-completion, path应该是 /chat/completions, 待验证
// TODO: docs: https://api-docs.deepseek.com/api/create-chat-completion
// accourding to the docs, the path should be /chat/completions, need to be verified
deepseekChatCompletionPath = "/v1/chat/completions"
)

Expand Down
6 changes: 1 addition & 5 deletions plugins/wasm-go/extensions/ai-proxy/provider/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,7 @@ func (g *geminiProviderInitializer) ValidateConfig(config *ProviderConfig) error
}

func (g *geminiProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// path在gemini中没有实际意义,只是为了保持和其他provider的一致性
string(ApiNameChatCompletion): "_",
string(ApiNameEmbeddings): "_",
}
return map[string]string{}
}

func (g *geminiProviderInitializer) CreateProvider(config ProviderConfig) (Provider, error) {
Expand Down
3 changes: 1 addition & 2 deletions plugins/wasm-go/extensions/ai-proxy/provider/mistral.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ func (m *mistralProviderInitializer) ValidateConfig(config *ProviderConfig) erro

func (m *mistralProviderInitializer) DefaultCapabilities() map[string]string {
return map[string]string{
// mistral的chat接口和OpenAI的chat接口一样
// docs: https://docs.mistral.ai/api/
// The chat interface of mistral is the same as that of OpenAI. docs: https://docs.mistral.ai/api/
string(ApiNameChatCompletion): PathOpenAIChatCompletions,
string(ApiNameEmbeddings): PathOpenAIEmbeddings,
}
Expand Down
3 changes: 2 additions & 1 deletion plugins/wasm-go/extensions/ai-proxy/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ func (c *ProviderConfig) FromJson(json gjson.Result) {
case string(ApiNameChatCompletion),
string(ApiNameEmbeddings),
string(ApiNameImageGeneration),
string(ApiNameAudioSpeech):
string(ApiNameAudioSpeech),
string(ApiNameCohereV1Rerank):
c.capabilities[capability] = pathJson.String()
}
}
Expand Down

0 comments on commit 1094293

Please sign in to comment.