From 3582c17afdb27ff256b0f9702cafbf811e92007e Mon Sep 17 00:00:00 2001 From: Chandeep Date: Wed, 31 Jan 2024 19:08:24 +0530 Subject: [PATCH 01/19] completion without stream in progress --- src/providers/ollama/api.ts | 36 ++++++ src/providers/ollama/chatComplete.ts | 92 ++++++++++++++++ src/providers/ollama/complete.ts | 159 +++++++++++++++++++++++++++ src/providers/ollama/embed.ts | 19 ++++ src/providers/ollama/index.ts | 20 ++++ 5 files changed, 326 insertions(+) create mode 100644 src/providers/ollama/api.ts create mode 100644 src/providers/ollama/chatComplete.ts create mode 100644 src/providers/ollama/complete.ts create mode 100644 src/providers/ollama/embed.ts create mode 100644 src/providers/ollama/index.ts diff --git a/src/providers/ollama/api.ts b/src/providers/ollama/api.ts new file mode 100644 index 000000000..0c757a149 --- /dev/null +++ b/src/providers/ollama/api.ts @@ -0,0 +1,36 @@ +import { ProviderAPIConfig } from "../types"; + +const OllamaAPIConfig: ProviderAPIConfig = { + // getBaseURL: (RESOURCE_NAME:string, DEPLOYMENT_ID:string) => `https://${RESOURCE_NAME}.openai.azure.com/openai/deployments/${DEPLOYMENT_ID}`, + baseURL: "http://localhost:11434", + headers: () => { + return null; + }, + chatComplete: "/api/chat", + complete:"/api/generate", + embed:"/api/embeddings" + // getEndpoint: (fn: string, API_VERSION: string, url?: string) => { + // let mappedFn = fn; + // if (fn === "proxy" && url && url?.indexOf("/chat/completions") > -1) { + // mappedFn = "chatComplete"; + // } else if (fn === "proxy" && url && url?.indexOf("/completions") > -1) { + // mappedFn = "complete"; + // } else if (fn === "proxy" && url && url?.indexOf("/embeddings") > -1) { + // mappedFn = "embed"; + // } + + // switch (mappedFn) { + // case "complete": { + // return `/completions?api-version=${API_VERSION}`; + // } + // case "chatComplete": { + // return `/chat/completions?api-version=${API_VERSION}`; + // } + // case "embed": { + // return `/embeddings?api-version=${API_VERSION}`; + // } + // } + // }, +}; + +export default OllamaAPIConfig; diff --git a/src/providers/ollama/chatComplete.ts b/src/providers/ollama/chatComplete.ts new file mode 100644 index 000000000..d85d61422 --- /dev/null +++ b/src/providers/ollama/chatComplete.ts @@ -0,0 +1,92 @@ +import { ChatCompletionResponse, ProviderConfig } from "../types"; +import { Params } from "../../types/requestBody"; +// TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. 
+ +const transformOptions = (params: Params) => { + const options: Record = {}; + if (params["temperature"]) { + options["temperature"] = params["temperature"]; + } + if (params["top_p"]) { + options["top_p"] = params["top_p"]; + } + if (params["top_k"]) { + options["top_k"] = params["top_k"]; + } + if (params["stop"]) { + options["stop"] = params["stop"]; + } + if (params["presence_penalty"]) { + options["presence_penalty"] = params["presence_penalty"]; + } + if (params["frequency_penalty"]) { + options["frequency_penalty"] = params["frequency_penalty"]; + } + if (params["max_tokens"]) { + options["num_predict"] = params["max_tokens"]; + } + return options; +}; + +export const OllamaChatCompleteConfig: ProviderConfig = { + model: { + param: "model", + }, + prompt: { + param: "prompt", + default: "", + }, + max_tokens: { + param: "options", + transform: (params: Params) => transformOptions(params), + default: 128, + min: -2, + }, + temperature: { + param: "options", + transform: (params: Params) => transformOptions(params), + default: 0.8, + min: 0, + max: 2, + }, + top_p: { + param: "options", + transform: (params: Params) => transformOptions(params), + default: 0.9, + min: 0, + max: 1, + }, + top_k: { + param: "options", + transform: (params: Params) => transformOptions(params), + default: 40, + min: 0, + max: 100, + }, + stream: { + param: "stream", + default: false, + }, + stop: { + param: "options", + transform: (params: Params) => transformOptions(params), + }, + presence_penalty: { + param: "options", + transform: (params: Params) => transformOptions(params), + min: -2, + max: 2, + }, + frequency_penalty: { + param: "options", + transform: (params: Params) => transformOptions(params), + min: -2, + max: 2, + }, +}; + +interface OllamaChatCompleteResponse extends ChatCompletionResponse {} + +export const OllamaChatCompleteResponseTransform: ( + response: OllamaChatCompleteResponse +) => ChatCompletionResponse = (response) => response; diff --git a/src/providers/ollama/complete.ts b/src/providers/ollama/complete.ts new file mode 100644 index 000000000..ab6b87c40 --- /dev/null +++ b/src/providers/ollama/complete.ts @@ -0,0 +1,159 @@ +import { CompletionResponse, ErrorResponse, ProviderConfig } from "../types"; +import { Params } from "../../types/requestBody"; +import { OLLAMA } from "../../globals"; +// TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. 
+ +const transformOptions = (params: Params) => { + const options: Record = {}; + if (params["temperature"]) { + options["temperature"] = params["temperature"]; + } + if (params["top_p"]) { + options["top_p"] = params["top_p"]; + } + if (params["top_k"]) { + options["top_k"] = params["top_k"]; + } + if (params["stop"]) { + options["stop"] = params["stop"]; + } + if (params["presence_penalty"]) { + options["presence_penalty"] = params["presence_penalty"]; + } + if (params["frequency_penalty"]) { + options["frequency_penalty"] = params["frequency_penalty"]; + } + if (params["max_tokens"]) { + options["num_predict"] = params["max_tokens"]; + } + console.log(options); + + return options; +}; + +export const OllamaCompleteConfig: ProviderConfig = { + model: { + param: "model", + }, + prompt: { + param: "prompt", + default: "", + }, + max_tokens: { + param: "options", + transform: (params: Params) => transformOptions(params), + default: 128, + min: -2, + }, + temperature: { + param: "options", + transform: (params: Params) => transformOptions(params), + default: 0.8, + min: 0, + max: 2, + }, + top_p: { + param: "options", + transform: (params: Params) => transformOptions(params), + default: 0.9, + min: 0, + max: 1, + }, + top_k: { + param: "options", + transform: (params: Params) => transformOptions(params), + default: 40, + min: 0, + max: 100, + }, + stream: { + param: "stream", + default: false, + }, + stop: { + param: "options", + transform: (params: Params) => transformOptions(params), + }, + presence_penalty: { + param: "options", + transform: (params: Params) => transformOptions(params), + min: -2, + max: 2, + }, + frequency_penalty: { + param: "options", + transform: (params: Params) => transformOptions(params), + min: -2, + max: 2, + }, +}; + +interface OllamaCompleteResponse { + model: string; + created_at: string; + response: string; + done: boolean; + context: number[]; + total_duration: number; + load_duration: number; + prompt_eval_count: number; + prompt_eval_duration: number; + eval_count: number; + eval_duration: number; + error?: string +} +interface OllamaErrorResponse { + error: string; +} + +export const OllamaCompleteResponseTransform: ( + response: OllamaCompleteResponse | OllamaErrorResponse, + responseStatus: number +) => CompletionResponse | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && "error" in response) { + return { + error: { + message: response.error, + type: null, + param: null, + code: null, + }, + provider: OLLAMA, + } as ErrorResponse; + } + + if ('response' in response) { + return { + id: Date.now().toString(), + object: "text_completion", + created: Date.now(), + model: response.model, + provider: OLLAMA, + choices: [ + { + text: response.response, + index: 0, + logprobs: null, + finish_reason: "length", + }, + ], + usage: { + prompt_tokens: response.prompt_eval_count, + completion_tokens: response.eval_count, + total_tokens: response.prompt_eval_count + response.eval_count, + }, + } + } + + return { + error: { + message: `Invalid response recieved from ${OLLAMA}: ${JSON.stringify( + response + )}`, + type: null, + param: null, + code: null, + }, + provider: OLLAMA, + } as ErrorResponse; +}; diff --git a/src/providers/ollama/embed.ts b/src/providers/ollama/embed.ts new file mode 100644 index 000000000..3d20bf3dc --- /dev/null +++ b/src/providers/ollama/embed.ts @@ -0,0 +1,19 @@ +import { EmbedResponse } from "../../types/embedRequestBody"; +import { ProviderConfig } from "../types"; + +// TODOS: this configuration does not 
enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. + +export const OllamaEmbedConfig: ProviderConfig = { + model: { + param: "model", + }, + input: { + param: "prompt", + required: true, + } +}; + +interface OllamaEmbedResponse extends EmbedResponse {} + +export const OllamaEmbedResponseTransform: (response: OllamaEmbedResponse) => EmbedResponse = (response) => response; + diff --git a/src/providers/ollama/index.ts b/src/providers/ollama/index.ts new file mode 100644 index 000000000..771a59340 --- /dev/null +++ b/src/providers/ollama/index.ts @@ -0,0 +1,20 @@ +import { ProviderConfigs } from "../types"; +import { OllamaCompleteConfig, OllamaCompleteResponseTransform } from "./complete"; +import { OllamaEmbedConfig, OllamaEmbedResponseTransform } from "./embed"; +import OllamaAPIConfig from "./api"; +import { OllamaChatCompleteConfig, OllamaChatCompleteResponseTransform } from "./chatComplete"; + +const OllamaConfig: ProviderConfigs = { + + complete: OllamaCompleteConfig, + embed: OllamaEmbedConfig, + api: OllamaAPIConfig, + chatComplete: OllamaChatCompleteConfig, + responseTransforms: { + complete: OllamaCompleteResponseTransform, + chatComplete: OllamaChatCompleteResponseTransform, + embed: OllamaEmbedResponseTransform + } +}; + +export default OllamaConfig; From c8bad18d69f532dbfa939ad46294d65dbf530cc6 Mon Sep 17 00:00:00 2001 From: Chandeep Date: Thu, 1 Feb 2024 16:30:40 +0530 Subject: [PATCH 02/19] ollama completion with and without stream integrated --- src/globals.ts | 1 + src/middlewares/requestValidator/index.ts | 3 +- src/providers/index.ts | 4 ++- src/providers/ollama/complete.ts | 42 ++++++++++++++++++++--- src/providers/ollama/index.ts | 3 +- src/utils.ts | 5 ++- 6 files changed, 49 insertions(+), 9 deletions(-) diff --git a/src/globals.ts b/src/globals.ts index a867ef9e4..5a739a710 100644 --- a/src/globals.ts +++ b/src/globals.ts @@ -30,6 +30,7 @@ export const GOOGLE: string = "google"; export const PERPLEXITY_AI: string = "perplexity-ai"; export const MISTRAL_AI: string = "mistral-ai"; export const DEEPINFRA: string = "deepinfra"; +export const OLLAMA: string = "ollama"; export const providersWithStreamingSupport = [OPEN_AI, AZURE_OPEN_AI, ANTHROPIC, COHERE]; export const allowedProxyProviders = [OPEN_AI, COHERE, AZURE_OPEN_AI, ANTHROPIC]; diff --git a/src/middlewares/requestValidator/index.ts b/src/middlewares/requestValidator/index.ts index 531f6fbb5..14af6d170 100644 --- a/src/middlewares/requestValidator/index.ts +++ b/src/middlewares/requestValidator/index.ts @@ -13,6 +13,7 @@ import { POWERED_BY, TOGETHER_AI, DEEPINFRA, + OLLAMA, } from "../../globals"; import { configSchema } from "./schema/config"; @@ -61,7 +62,7 @@ export const requestValidator = (c: Context, next: any) => { } if ( requestHeaders[`x-${POWERED_BY}-provider`] && - ![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA].includes( + ![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, OLLAMA].includes( requestHeaders[`x-${POWERED_BY}-provider`] ) ) { diff --git a/src/providers/index.ts b/src/providers/index.ts index 775b71df7..00ccfb242 100644 --- 
a/src/providers/index.ts +++ b/src/providers/index.ts @@ -9,6 +9,7 @@ import OpenAIConfig from "./openai"; import PalmAIConfig from "./palm"; import PerplexityAIConfig from "./perplexity-ai"; import TogetherAIConfig from "./together-ai"; +import OllamaAPIConfig from "./ollama"; import { ProviderConfigs } from "./types"; const Providers: { [key: string]: ProviderConfigs } = { @@ -22,7 +23,8 @@ const Providers: { [key: string]: ProviderConfigs } = { google: GoogleConfig, 'perplexity-ai': PerplexityAIConfig, 'mistral-ai': MistralAIConfig, - 'deepinfra': DeepInfraConfig + 'deepinfra': DeepInfraConfig, + 'ollama': OllamaAPIConfig }; export default Providers; diff --git a/src/providers/ollama/complete.ts b/src/providers/ollama/complete.ts index ab6b87c40..070f5f7c9 100644 --- a/src/providers/ollama/complete.ts +++ b/src/providers/ollama/complete.ts @@ -26,8 +26,6 @@ const transformOptions = (params: Params) => { if (params["max_tokens"]) { options["num_predict"] = params["max_tokens"]; } - console.log(options); - return options; }; @@ -100,12 +98,19 @@ interface OllamaCompleteResponse { prompt_eval_duration: number; eval_count: number; eval_duration: number; - error?: string } interface OllamaErrorResponse { error: string; } +interface OllamaCompleteStreamChunk { + model: string; + create_at: number; + response: string; + done: boolean; + context: number[]; +} + export const OllamaCompleteResponseTransform: ( response: OllamaCompleteResponse | OllamaErrorResponse, responseStatus: number @@ -122,7 +127,7 @@ export const OllamaCompleteResponseTransform: ( } as ErrorResponse; } - if ('response' in response) { + if ("response" in response) { return { id: Date.now().toString(), object: "text_completion", @@ -142,7 +147,7 @@ export const OllamaCompleteResponseTransform: ( completion_tokens: response.eval_count, total_tokens: response.prompt_eval_count + response.eval_count, }, - } + }; } return { @@ -157,3 +162,30 @@ export const OllamaCompleteResponseTransform: ( provider: OLLAMA, } as ErrorResponse; }; + +export const OllamaCompleteStreamChunkResponseTransform: ( + response: string +) => string = (responseChunk) => { + let chunk = responseChunk.trim(); + if (chunk.includes("context")) { + return `data: [DONE]` + `\n\n`; + } + const parsedChunk: OllamaCompleteResponse = JSON.parse(chunk); + return ( + `data: ${JSON.stringify({ + id: Date.now(), + object: "text_completion", + created: Date.now(), + model: parsedChunk.model, + provider: OLLAMA, + choices: [ + { + text: parsedChunk.response, + index: 0, + logprobs: null, + finish_reason: null, + }, + ], + })}` + "\n\n" + ); +}; diff --git a/src/providers/ollama/index.ts b/src/providers/ollama/index.ts index 771a59340..f5faec3b6 100644 --- a/src/providers/ollama/index.ts +++ b/src/providers/ollama/index.ts @@ -1,5 +1,5 @@ import { ProviderConfigs } from "../types"; -import { OllamaCompleteConfig, OllamaCompleteResponseTransform } from "./complete"; +import { OllamaCompleteConfig, OllamaCompleteResponseTransform, OllamaCompleteStreamChunkResponseTransform } from "./complete"; import { OllamaEmbedConfig, OllamaEmbedResponseTransform } from "./embed"; import OllamaAPIConfig from "./api"; import { OllamaChatCompleteConfig, OllamaChatCompleteResponseTransform } from "./chatComplete"; @@ -11,6 +11,7 @@ const OllamaConfig: ProviderConfigs = { api: OllamaAPIConfig, chatComplete: OllamaChatCompleteConfig, responseTransforms: { + 'stream-complete': OllamaCompleteStreamChunkResponseTransform, complete: OllamaCompleteResponseTransform, chatComplete: 
OllamaChatCompleteResponseTransform, embed: OllamaEmbedResponseTransform diff --git a/src/utils.ts b/src/utils.ts index 3d6068729..7952b287c 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -1,4 +1,4 @@ -import { ANTHROPIC, COHERE, GOOGLE, PERPLEXITY_AI, DEEPINFRA } from "./globals"; +import { ANTHROPIC, COHERE, GOOGLE, PERPLEXITY_AI, DEEPINFRA, OLLAMA } from "./globals"; import { Params } from "./types/requestBody"; export const getStreamModeSplitPattern = (proxyProvider: string, requestURL: string) => { @@ -21,6 +21,9 @@ export const getStreamModeSplitPattern = (proxyProvider: string, requestURL: str { splitPattern = '\r\n\r\n'; } + if(proxyProvider === OLLAMA){ + splitPattern ='\n'; + } return splitPattern; } From 4932fafb3d92e3384b81c320dcc4b69f60f678b5 Mon Sep 17 00:00:00 2001 From: Chandeep Date: Thu, 1 Feb 2024 19:45:59 +0530 Subject: [PATCH 03/19] ollama chat with and without stream integrated + small fix in cohere and google providers --- src/providers/cohere/chatComplete.ts | 2 +- src/providers/google/chatComplete.ts | 3 - src/providers/ollama/chatComplete.ts | 130 +++++++++++++++++++++++++-- src/providers/ollama/complete.ts | 2 +- src/providers/ollama/index.ts | 5 +- 5 files changed, 129 insertions(+), 13 deletions(-) diff --git a/src/providers/cohere/chatComplete.ts b/src/providers/cohere/chatComplete.ts index cc3e9d311..8f2011736 100644 --- a/src/providers/cohere/chatComplete.ts +++ b/src/providers/cohere/chatComplete.ts @@ -136,7 +136,7 @@ export const CohereChatCompleteStreamChunkTransform: (response: string, fallback return `data: ${JSON.stringify({ id: parsedChunk.id ?? fallbackId, - object: "text_completion", + object: "chat.completion.chunk", created: Math.floor(Date.now() / 1000), model: "", provider: COHERE, diff --git a/src/providers/google/chatComplete.ts b/src/providers/google/chatComplete.ts index 0ab240681..db48e2d3f 100644 --- a/src/providers/google/chatComplete.ts +++ b/src/providers/google/chatComplete.ts @@ -17,9 +17,6 @@ const transformGenerationConfig = (params: Params) => { if (params["top_k"]) { generationConfig["topK"] = params["top_k"]; } - if (params["top_k"]) { - generationConfig["topK"] = params["top_k"]; - } if (params["max_tokens"]) { generationConfig["maxOutputTokens"] = params["max_tokens"]; } diff --git a/src/providers/ollama/chatComplete.ts b/src/providers/ollama/chatComplete.ts index d85d61422..c3b23cfef 100644 --- a/src/providers/ollama/chatComplete.ts +++ b/src/providers/ollama/chatComplete.ts @@ -1,5 +1,11 @@ -import { ChatCompletionResponse, ProviderConfig } from "../types"; +import { + ChatCompletionResponse, + CompletionResponse, + ErrorResponse, + ProviderConfig, +} from "../types"; import { Params } from "../../types/requestBody"; +import { OLLAMA } from "../../globals"; // TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. 
const transformOptions = (params: Params) => { @@ -32,8 +38,8 @@ export const OllamaChatCompleteConfig: ProviderConfig = { model: { param: "model", }, - prompt: { - param: "prompt", + messages: { + param: "messages", default: "", }, max_tokens: { @@ -85,8 +91,120 @@ export const OllamaChatCompleteConfig: ProviderConfig = { }, }; -interface OllamaChatCompleteResponse extends ChatCompletionResponse {} +interface OllamaChatCompleteResponse { + model: string; + created_at: number; + message: { + role: string; + content: string; + }; + done: boolean; + total_duration: number; + load_duration: number; + prompt_eval_count: number; + prompt_eval_duration: number; + eval_count: number; + eval_duration: number; +} + +interface OllamaErrorResponse { + error: string; +} export const OllamaChatCompleteResponseTransform: ( - response: OllamaChatCompleteResponse -) => ChatCompletionResponse = (response) => response; + response: OllamaChatCompleteResponse | OllamaErrorResponse, + responseStatus: number +) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { + if (responseStatus !== 200 && "error" in response) { + return { + error: { + message: response.error, + type: null, + param: null, + code: null, + }, + provider: OLLAMA, + } as ErrorResponse; + } + if ("model" in response) { + return { + id: Date.now().toString(), + object: "chat.completion", + created: Date.now(), + model: response.model, + provider: OLLAMA, + choices: [ + { + index: 0, + message: { + role: "assistant", + content: response.message.content, + }, + finish_reason: "stop", + logprobs: null, + }, + ], + usage: { + prompt_tokens: response.prompt_eval_count, + completion_tokens: response.eval_count, + total_tokens: response.prompt_eval_count + response.eval_count, + }, + }; + } + return { + error: { + message: `Invalid response recieved from ${OLLAMA}: ${JSON.stringify( + response + )}`, + type: null, + param: null, + code: null, + }, + provider: OLLAMA, + } as ErrorResponse; +}; + +interface OllamaCompleteStreamChunk { + model: string; + created_at: string; + message: { + role: string; + content: string; + }; + done: boolean, + total_duration: number; + load_duration: number; + prompt_eval_count: number; + prompt_eval_duration: number; + eval_count: number; + eval_duration: number; +} + +export const OllamaChatCompleteStreamChunkTransform: (response: string, fallbackId: string) => string = (resposeChunk, fallbackId) =>{ + let chunk = resposeChunk.trim() + console.log(chunk); + + if(chunk.includes('total_duration')){ + return `data: [DONE]` + `\n\n`; + } + const parsedChunk : OllamaCompleteStreamChunk = JSON.parse(chunk); + return ( + `data: ${JSON.stringify({ + id: Date.now() ?? 
fallbackId, + object: "chat.completion.chunk", + created: Date.now(), + model: parsedChunk.model, + provider: OLLAMA, + choices: [ + { + delta: { + content: parsedChunk.message.content + }, + index: 0, + logprobs: null, + finish_reason: null, + }, + ] + })}` + '\n\n' + ) +} \ No newline at end of file diff --git a/src/providers/ollama/complete.ts b/src/providers/ollama/complete.ts index 070f5f7c9..202fbef3c 100644 --- a/src/providers/ollama/complete.ts +++ b/src/providers/ollama/complete.ts @@ -170,7 +170,7 @@ export const OllamaCompleteStreamChunkResponseTransform: ( if (chunk.includes("context")) { return `data: [DONE]` + `\n\n`; } - const parsedChunk: OllamaCompleteResponse = JSON.parse(chunk); + const parsedChunk: OllamaCompleteStreamChunk = JSON.parse(chunk); return ( `data: ${JSON.stringify({ id: Date.now(), diff --git a/src/providers/ollama/index.ts b/src/providers/ollama/index.ts index f5faec3b6..5a59212fd 100644 --- a/src/providers/ollama/index.ts +++ b/src/providers/ollama/index.ts @@ -2,7 +2,7 @@ import { ProviderConfigs } from "../types"; import { OllamaCompleteConfig, OllamaCompleteResponseTransform, OllamaCompleteStreamChunkResponseTransform } from "./complete"; import { OllamaEmbedConfig, OllamaEmbedResponseTransform } from "./embed"; import OllamaAPIConfig from "./api"; -import { OllamaChatCompleteConfig, OllamaChatCompleteResponseTransform } from "./chatComplete"; +import { OllamaChatCompleteConfig, OllamaChatCompleteResponseTransform, OllamaChatCompleteStreamChunkTransform } from "./chatComplete"; const OllamaConfig: ProviderConfigs = { @@ -11,9 +11,10 @@ const OllamaConfig: ProviderConfigs = { api: OllamaAPIConfig, chatComplete: OllamaChatCompleteConfig, responseTransforms: { - 'stream-complete': OllamaCompleteStreamChunkResponseTransform, complete: OllamaCompleteResponseTransform, + 'stream-complete': OllamaCompleteStreamChunkResponseTransform, chatComplete: OllamaChatCompleteResponseTransform, + 'stream-chatComplete': OllamaChatCompleteStreamChunkTransform, embed: OllamaEmbedResponseTransform } }; From cf8f223dd3c93b3a82c8f74695488e148827032b Mon Sep 17 00:00:00 2001 From: Chandeep Date: Fri, 2 Feb 2024 16:46:14 +0530 Subject: [PATCH 04/19] Ollama embed integration + ASK: baseUrl for Ollama --- src/handlers/handlerUtils.ts | 9 ++- src/handlers/streamHandler.ts | 6 +- .../requestValidator/schema/config.ts | 4 +- src/providers/ollama/api.ts | 46 +++++++-------- src/providers/ollama/chatComplete.ts | 1 - src/providers/ollama/embed.ts | 57 +++++++++++++++++-- src/types/requestBody.ts | 2 + 7 files changed, 90 insertions(+), 35 deletions(-) diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index dc75bd3ff..ba9a3df3d 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -1,5 +1,5 @@ import { Context } from "hono"; -import { AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES } from "../globals"; +import { AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, OLLAMA, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES } from "../globals"; import Providers from "../providers"; import { ProviderAPIConfig, endpointStrings } from "../providers/types"; import transformToProviderRequest from "../services/transformToProviderRequest"; @@ -329,7 +329,12 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P fetchOptions = constructRequest(apiConfig.headers(), provider); baseUrl = apiConfig.baseURL; endpoint = apiConfig.getEndpoint(fn, 
providerOption.apiKey, transformedRequestBody.model, params.stream); - } else { + } else if (provider === OLLAMA && apiConfig.getEndpoint) { + fetchOptions = constructRequest(apiConfig.headers(), provider); + baseUrl = providerOption.baseUrl || "" + endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream); + } + else { // Construct the base object for the POST request fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider); diff --git a/src/handlers/streamHandler.ts b/src/handlers/streamHandler.ts index fb8cac2f4..084318e3d 100644 --- a/src/handlers/streamHandler.ts +++ b/src/handlers/streamHandler.ts @@ -1,4 +1,4 @@ -import { AZURE_OPEN_AI, CONTENT_TYPES, COHERE, GOOGLE } from "../globals"; +import { AZURE_OPEN_AI, CONTENT_TYPES, COHERE, GOOGLE, OLLAMA } from "../globals"; import { OpenAIChatCompleteResponse } from "../providers/openai/chatComplete"; import { OpenAICompleteResponse } from "../providers/openai/complete"; import { getStreamModeSplitPattern } from "../utils"; @@ -92,8 +92,8 @@ export async function handleStreamingMode(response: Response, proxyProvider: str writer.close(); })(); - // Convert GEMINI/COHERE json stream to text/event-stream for non-proxy calls - if ([GOOGLE, COHERE].includes(proxyProvider) && responseTransformer) { + // Convert GEMINI/COHERE/OLLAMA json stream to text/event-stream for non-proxy calls + if ([GOOGLE, COHERE, OLLAMA].includes(proxyProvider) && responseTransformer) { return new Response(readable, { ...response, headers: new Headers({ diff --git a/src/middlewares/requestValidator/schema/config.ts b/src/middlewares/requestValidator/schema/config.ts index ac629d981..27beceef3 100644 --- a/src/middlewares/requestValidator/schema/config.ts +++ b/src/middlewares/requestValidator/schema/config.ts @@ -10,6 +10,7 @@ import { PERPLEXITY_AI, TOGETHER_AI, DEEPINFRA, + OLLAMA } from "../../../globals"; export const configSchema: any = z @@ -45,7 +46,8 @@ export const configSchema: any = z GOOGLE, PERPLEXITY_AI, MISTRAL_AI, - DEEPINFRA + DEEPINFRA, + OLLAMA ].includes(value), { message: diff --git a/src/providers/ollama/api.ts b/src/providers/ollama/api.ts index 0c757a149..7f85952ba 100644 --- a/src/providers/ollama/api.ts +++ b/src/providers/ollama/api.ts @@ -1,36 +1,34 @@ import { ProviderAPIConfig } from "../types"; const OllamaAPIConfig: ProviderAPIConfig = { - // getBaseURL: (RESOURCE_NAME:string, DEPLOYMENT_ID:string) => `https://${RESOURCE_NAME}.openai.azure.com/openai/deployments/${DEPLOYMENT_ID}`, - baseURL: "http://localhost:11434", headers: () => { return null; }, chatComplete: "/api/chat", complete:"/api/generate", - embed:"/api/embeddings" - // getEndpoint: (fn: string, API_VERSION: string, url?: string) => { - // let mappedFn = fn; - // if (fn === "proxy" && url && url?.indexOf("/chat/completions") > -1) { - // mappedFn = "chatComplete"; - // } else if (fn === "proxy" && url && url?.indexOf("/completions") > -1) { - // mappedFn = "complete"; - // } else if (fn === "proxy" && url && url?.indexOf("/embeddings") > -1) { - // mappedFn = "embed"; - // } + embed:"/api/embeddings", + getEndpoint: (fn: string, API_VERSION: string, url?: string) => { + let mappedFn = fn; + if (fn === "proxy" && url && url?.indexOf("/chat/completions") > -1) { + mappedFn = "chatComplete"; + } else if (fn === "proxy" && url && url?.indexOf("/completions") > -1) { + mappedFn = "complete"; + } else if (fn === "proxy" && url && url?.indexOf("/embeddings") > -1) { + mappedFn = "embed"; + } - // switch 
(mappedFn) { - // case "complete": { - // return `/completions?api-version=${API_VERSION}`; - // } - // case "chatComplete": { - // return `/chat/completions?api-version=${API_VERSION}`; - // } - // case "embed": { - // return `/embeddings?api-version=${API_VERSION}`; - // } - // } - // }, + switch (mappedFn) { + case "complete": { + return `/api/generate`; + } + case "chatComplete": { + return `/api/chat`; + } + case "embed": { + return `/api/embeddings`; + } + } + }, }; export default OllamaAPIConfig; diff --git a/src/providers/ollama/chatComplete.ts b/src/providers/ollama/chatComplete.ts index c3b23cfef..c9de889d8 100644 --- a/src/providers/ollama/chatComplete.ts +++ b/src/providers/ollama/chatComplete.ts @@ -182,7 +182,6 @@ interface OllamaCompleteStreamChunk { export const OllamaChatCompleteStreamChunkTransform: (response: string, fallbackId: string) => string = (resposeChunk, fallbackId) =>{ let chunk = resposeChunk.trim() - console.log(chunk); if(chunk.includes('total_duration')){ return `data: [DONE]` + `\n\n`; diff --git a/src/providers/ollama/embed.ts b/src/providers/ollama/embed.ts index 3d20bf3dc..5b33fd0d9 100644 --- a/src/providers/ollama/embed.ts +++ b/src/providers/ollama/embed.ts @@ -1,5 +1,6 @@ +import { OLLAMA } from "../../globals"; import { EmbedResponse } from "../../types/embedRequestBody"; -import { ProviderConfig } from "../types"; +import { ErrorResponse, ProviderConfig } from "../types"; // TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. 
@@ -13,7 +14,55 @@ export const OllamaEmbedConfig: ProviderConfig = { } }; -interface OllamaEmbedResponse extends EmbedResponse {} - -export const OllamaEmbedResponseTransform: (response: OllamaEmbedResponse) => EmbedResponse = (response) => response; +interface OllamaEmbedResponse extends EmbedResponse { + embedding: number[]; +} +interface OllamaErrorResponse { + error: string; +} +export const OllamaEmbedResponseTransform: ( + response: OllamaEmbedResponse | OllamaErrorResponse, + responseStatus: number +) => EmbedResponse | ErrorResponse = (response, responseStatus) => { + + if ("error" in response) { + return { + error: { + message: response.error, + type: null, + param: null, + code: null, + }, + provider: "cohere", + } as ErrorResponse; + } + if ("embedding" in response) { + return { + object: "list", + data: [ + { + object: "embedding", + embedding: response.embedding, + index: 0, + }, + ], + model: "", // Todo: find a way to send the ollama embedding model name back + usage: { + prompt_tokens: -1, + total_tokens: -1, + }, + }; + } + return { + error: { + message: `Invalid response recieved from ${OLLAMA}: ${JSON.stringify( + response + )}`, + type: null, + param: null, + code: null, + }, + provider: OLLAMA, + } as ErrorResponse; +}; diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts index 3bb7c25aa..8ccf351dc 100644 --- a/src/types/requestBody.ts +++ b/src/types/requestBody.ts @@ -43,6 +43,8 @@ export interface Options { deploymentId?: string; apiVersion?: string; adAuth?:string; + /** Ollama specific */ + baseUrl?: string; /** provider option index picked based on weight in loadbalance mode */ index?: number; cache?: CacheSettings | string; From 139b8daa54a318b5b85b012c4f25c0d176545397 Mon Sep 17 00:00:00 2001 From: Chandeep Date: Mon, 5 Feb 2024 14:04:18 +0530 Subject: [PATCH 05/19] proxy path URL handled --- src/handlers/proxyHandler.ts | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/handlers/proxyHandler.ts b/src/handlers/proxyHandler.ts index ab4d21f3b..6950414c8 100644 --- a/src/handlers/proxyHandler.ts +++ b/src/handlers/proxyHandler.ts @@ -1,7 +1,7 @@ import { Context } from "hono"; import { retryRequest } from "./retryHandler"; import Providers from "../providers"; -import { ANTHROPIC, MAX_RETRIES, HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, CONTENT_TYPES } from "../globals"; +import { ANTHROPIC, MAX_RETRIES, HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, CONTENT_TYPES, OLLAMA } from "../globals"; import { fetchProviderOptionsFromConfig, responseHandler, tryProvidersInSequence, updateResponseHeaders } from "./handlerUtils"; import { getStreamingMode } from "../utils"; import { Config, ShortConfig } from "../types/requestBody"; @@ -21,6 +21,10 @@ function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath if (proxyProvider === AZURE_OPEN_AI) { return `https:/${reqPath}${reqQuery}`; } + + if (proxyProvider === OLLAMA) { + return `https:/${reqPath}`; + } let proxyPath = `${providerBasePath}${reqPath}${reqQuery}`; // Fix specific for Anthropic SDK calls. Is this needed? 
- Yes From 2b235b62daf0fbb5396ab42a07907b003ed4d731 Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Tue, 13 Feb 2024 19:03:37 +0530 Subject: [PATCH 06/19] feat: chatCompletion from openai compatible, embedding from native way, removed completion route --- src/providers/ollama/api.ts | 10 +- src/providers/ollama/chatComplete.ts | 307 +++++++++++++-------------- src/providers/ollama/complete.ts | 191 ----------------- src/providers/ollama/embed.ts | 4 +- src/providers/ollama/index.ts | 4 - src/utils.ts | 2 +- 6 files changed, 158 insertions(+), 360 deletions(-) delete mode 100644 src/providers/ollama/complete.ts diff --git a/src/providers/ollama/api.ts b/src/providers/ollama/api.ts index 7f85952ba..82bddd6cf 100644 --- a/src/providers/ollama/api.ts +++ b/src/providers/ollama/api.ts @@ -4,25 +4,19 @@ const OllamaAPIConfig: ProviderAPIConfig = { headers: () => { return null; }, - chatComplete: "/api/chat", - complete:"/api/generate", + chatComplete: "/v1/chat/completions", embed:"/api/embeddings", getEndpoint: (fn: string, API_VERSION: string, url?: string) => { let mappedFn = fn; if (fn === "proxy" && url && url?.indexOf("/chat/completions") > -1) { mappedFn = "chatComplete"; - } else if (fn === "proxy" && url && url?.indexOf("/completions") > -1) { - mappedFn = "complete"; } else if (fn === "proxy" && url && url?.indexOf("/embeddings") > -1) { mappedFn = "embed"; } switch (mappedFn) { - case "complete": { - return `/api/generate`; - } case "chatComplete": { - return `/api/chat`; + return `/v1/chat/completions`; } case "embed": { return `/api/embeddings`; diff --git a/src/providers/ollama/chatComplete.ts b/src/providers/ollama/chatComplete.ts index c9de889d8..a399fdc14 100644 --- a/src/providers/ollama/chatComplete.ts +++ b/src/providers/ollama/chatComplete.ts @@ -4,206 +4,205 @@ import { ErrorResponse, ProviderConfig, } from "../types"; -import { Params } from "../../types/requestBody"; import { OLLAMA } from "../../globals"; -// TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. 
- -const transformOptions = (params: Params) => { - const options: Record = {}; - if (params["temperature"]) { - options["temperature"] = params["temperature"]; - } - if (params["top_p"]) { - options["top_p"] = params["top_p"]; - } - if (params["top_k"]) { - options["top_k"] = params["top_k"]; - } - if (params["stop"]) { - options["stop"] = params["stop"]; - } - if (params["presence_penalty"]) { - options["presence_penalty"] = params["presence_penalty"]; - } - if (params["frequency_penalty"]) { - options["frequency_penalty"] = params["frequency_penalty"]; - } - if (params["max_tokens"]) { - options["num_predict"] = params["max_tokens"]; - } - return options; -}; export const OllamaChatCompleteConfig: ProviderConfig = { model: { param: "model", + required: true, + default: "llama2", }, messages: { param: "messages", default: "", }, - max_tokens: { - param: "options", - transform: (params: Params) => transformOptions(params), - default: 128, + frequency_penalty: { + param: "frequency_penalty", min: -2, + max: 2, }, - temperature: { - param: "options", - transform: (params: Params) => transformOptions(params), - default: 0.8, - min: 0, + presence_penalty: { + param: "presence_penalty", + min: -2, max: 2, }, - top_p: { - param: "options", - transform: (params: Params) => transformOptions(params), - default: 0.9, - min: 0, - max: 1, + response_format: { + param: "response_format", }, - top_k: { - param: "options", - transform: (params: Params) => transformOptions(params), - default: 40, - min: 0, - max: 100, + seed: { + param: "seed", + }, + stop: { + param: "stop", }, stream: { param: "stream", default: false, }, - stop: { - param: "options", - transform: (params: Params) => transformOptions(params), - }, - presence_penalty: { - param: "options", - transform: (params: Params) => transformOptions(params), - min: -2, + temperature: { + param: "temperature", + default: 1, + min: 0, max: 2, }, - frequency_penalty: { - param: "options", - transform: (params: Params) => transformOptions(params), - min: -2, - max: 2, + top_p: { + param: "top_p", + default: 1, + min: 0, + max: 1, + }, + max_tokens: { + param: "max_tokens", + default: 100, + min: 0, }, }; -interface OllamaChatCompleteResponse { - model: string; - created_at: number; - message: { - role: string; - content: string; - }; - done: boolean; - total_duration: number; - load_duration: number; - prompt_eval_count: number; - prompt_eval_duration: number; - eval_count: number; - eval_duration: number; +export interface OllamaChatCompleteResponse + extends ChatCompletionResponse, + ErrorResponse { + system_fingerprint: string; } -interface OllamaErrorResponse { - error: string; +export interface OllamaStreamChunk { + id: string; + object: string; + created: number; + model: string; + system_fingerprint: string; + choices: { + delta: { + role: string; + content?: string; + }; + index: number; + finish_reason: string | null; + }[]; } export const OllamaChatCompleteResponseTransform: ( - response: OllamaChatCompleteResponse | OllamaErrorResponse, + response: OllamaChatCompleteResponse, responseStatus: number ) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { - if (responseStatus !== 200 && "error" in response) { + + if (responseStatus !== 200) { return { error: { - message: response.error, - type: null, + message: response.error?.message, + type: response.error?.type, param: null, code: null, }, provider: OLLAMA, } as ErrorResponse; } - if ("model" in response) { - return { - id: Date.now().toString(), - object: 
"chat.completion", - created: Date.now(), - model: response.model, - provider: OLLAMA, - choices: [ - { - index: 0, - message: { - role: "assistant", - content: response.message.content, - }, - finish_reason: "stop", - logprobs: null, - }, - ], - usage: { - prompt_tokens: response.prompt_eval_count, - completion_tokens: response.eval_count, - total_tokens: response.prompt_eval_count + response.eval_count, - }, - }; - } + return { - error: { - message: `Invalid response recieved from ${OLLAMA}: ${JSON.stringify( - response - )}`, - type: null, - param: null, - code: null, - }, + id: response.id, + object: response.object, + created: response.created, + model: response.model, provider: OLLAMA, - } as ErrorResponse; -}; - -interface OllamaCompleteStreamChunk { - model: string; - created_at: string; - message: { - role: string; - content: string; + choices: response.choices, + usage: response.usage, }; - done: boolean, - total_duration: number; - load_duration: number; - prompt_eval_count: number; - prompt_eval_duration: number; - eval_count: number; - eval_duration: number; -} +}; -export const OllamaChatCompleteStreamChunkTransform: (response: string, fallbackId: string) => string = (resposeChunk, fallbackId) =>{ - let chunk = resposeChunk.trim() - - if(chunk.includes('total_duration')){ - return `data: [DONE]` + `\n\n`; +export const OllamaChatCompleteStreamChunkTransform: ( + reponse: string +) => string = (responseChunk) => { + let chunk = responseChunk.trim(); + chunk = chunk.replace(/^data: /, ""); + chunk = chunk.trim(); + if (chunk === "[DONE]") { + return `data: ${chunk}\n\n`; } - const parsedChunk : OllamaCompleteStreamChunk = JSON.parse(chunk); + const parsedChunk: OllamaStreamChunk = JSON.parse(chunk); return ( `data: ${JSON.stringify({ - id: Date.now() ?? 
fallbackId, - object: "chat.completion.chunk", - created: Date.now(), + id: parsedChunk.id, + object: parsedChunk.object, + created: parsedChunk.created, model: parsedChunk.model, provider: OLLAMA, - choices: [ - { - delta: { - content: parsedChunk.message.content - }, - index: 0, - logprobs: null, - finish_reason: null, - }, - ] - })}` + '\n\n' - ) -} \ No newline at end of file + choices: parsedChunk.choices, + })}` + "\n\n" + ); +}; + +// export interface OllamaChatCompleteResponse extends ChatCompletionResponse { +// system_fingerprint: string; +// } + +// export const OllamaChatCompleteResponseTransform: ( +// response: OllamaChatCompleteResponse +// ) => ChatCompletionResponse = (response) => response; + +// export const OllamaChatCompleteResponseTransform: ( +// response: OllamaChatCompleteResponse | OllamaErrorResponse, +// responseStatus: number +// ) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { +// if (responseStatus !== 200 && "error" in response) { +// return { +// error: { +// message: response.error, +// type: null, +// param: null, +// code: null, +// }, +// provider: OLLAMA, +// } as ErrorResponse; +// } +// if ("model" in response) { +// return { +// id: Date.now().toString(), +// object: "chat.completion", +// created: Date.now(), +// model: response.model, +// provider: OLLAMA, +// choices: [ +// { +// index: 0, +// message: { +// role: "assistant", +// content: response.message.content, +// }, +// finish_reason: "stop", +// logprobs: null, +// }, +// ], +// usage: { +// prompt_tokens: response.prompt_eval_count, +// completion_tokens: response.eval_count, +// total_tokens: response.prompt_eval_count + response.eval_count, +// }, +// }; +// } +// return { +// error: { +// message: `Invalid response recieved from ${OLLAMA}: ${JSON.stringify( +// response +// )}`, +// type: null, +// param: null, +// code: null, +// }, +// provider: OLLAMA, +// } as ErrorResponse; +// }; + +// interface OllamaCompleteStreamChunk { +// model: string; +// created_at: string; +// message: { +// role: string; +// content: string; +// }; +// done: boolean; +// total_duration: number; +// load_duration: number; +// prompt_eval_count: number; +// prompt_eval_duration: number; +// eval_count: number; +// eval_duration: number; +// } diff --git a/src/providers/ollama/complete.ts b/src/providers/ollama/complete.ts deleted file mode 100644 index 202fbef3c..000000000 --- a/src/providers/ollama/complete.ts +++ /dev/null @@ -1,191 +0,0 @@ -import { CompletionResponse, ErrorResponse, ProviderConfig } from "../types"; -import { Params } from "../../types/requestBody"; -import { OLLAMA } from "../../globals"; -// TODOS: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model. 
- -const transformOptions = (params: Params) => { - const options: Record = {}; - if (params["temperature"]) { - options["temperature"] = params["temperature"]; - } - if (params["top_p"]) { - options["top_p"] = params["top_p"]; - } - if (params["top_k"]) { - options["top_k"] = params["top_k"]; - } - if (params["stop"]) { - options["stop"] = params["stop"]; - } - if (params["presence_penalty"]) { - options["presence_penalty"] = params["presence_penalty"]; - } - if (params["frequency_penalty"]) { - options["frequency_penalty"] = params["frequency_penalty"]; - } - if (params["max_tokens"]) { - options["num_predict"] = params["max_tokens"]; - } - return options; -}; - -export const OllamaCompleteConfig: ProviderConfig = { - model: { - param: "model", - }, - prompt: { - param: "prompt", - default: "", - }, - max_tokens: { - param: "options", - transform: (params: Params) => transformOptions(params), - default: 128, - min: -2, - }, - temperature: { - param: "options", - transform: (params: Params) => transformOptions(params), - default: 0.8, - min: 0, - max: 2, - }, - top_p: { - param: "options", - transform: (params: Params) => transformOptions(params), - default: 0.9, - min: 0, - max: 1, - }, - top_k: { - param: "options", - transform: (params: Params) => transformOptions(params), - default: 40, - min: 0, - max: 100, - }, - stream: { - param: "stream", - default: false, - }, - stop: { - param: "options", - transform: (params: Params) => transformOptions(params), - }, - presence_penalty: { - param: "options", - transform: (params: Params) => transformOptions(params), - min: -2, - max: 2, - }, - frequency_penalty: { - param: "options", - transform: (params: Params) => transformOptions(params), - min: -2, - max: 2, - }, -}; - -interface OllamaCompleteResponse { - model: string; - created_at: string; - response: string; - done: boolean; - context: number[]; - total_duration: number; - load_duration: number; - prompt_eval_count: number; - prompt_eval_duration: number; - eval_count: number; - eval_duration: number; -} -interface OllamaErrorResponse { - error: string; -} - -interface OllamaCompleteStreamChunk { - model: string; - create_at: number; - response: string; - done: boolean; - context: number[]; -} - -export const OllamaCompleteResponseTransform: ( - response: OllamaCompleteResponse | OllamaErrorResponse, - responseStatus: number -) => CompletionResponse | ErrorResponse = (response, responseStatus) => { - if (responseStatus !== 200 && "error" in response) { - return { - error: { - message: response.error, - type: null, - param: null, - code: null, - }, - provider: OLLAMA, - } as ErrorResponse; - } - - if ("response" in response) { - return { - id: Date.now().toString(), - object: "text_completion", - created: Date.now(), - model: response.model, - provider: OLLAMA, - choices: [ - { - text: response.response, - index: 0, - logprobs: null, - finish_reason: "length", - }, - ], - usage: { - prompt_tokens: response.prompt_eval_count, - completion_tokens: response.eval_count, - total_tokens: response.prompt_eval_count + response.eval_count, - }, - }; - } - - return { - error: { - message: `Invalid response recieved from ${OLLAMA}: ${JSON.stringify( - response - )}`, - type: null, - param: null, - code: null, - }, - provider: OLLAMA, - } as ErrorResponse; -}; - -export const OllamaCompleteStreamChunkResponseTransform: ( - response: string -) => string = (responseChunk) => { - let chunk = responseChunk.trim(); - if (chunk.includes("context")) { - return `data: [DONE]` + `\n\n`; - } - const 
parsedChunk: OllamaCompleteStreamChunk = JSON.parse(chunk); - return ( - `data: ${JSON.stringify({ - id: Date.now(), - object: "text_completion", - created: Date.now(), - model: parsedChunk.model, - provider: OLLAMA, - choices: [ - { - text: parsedChunk.response, - index: 0, - logprobs: null, - finish_reason: null, - }, - ], - })}` + "\n\n" - ); -}; diff --git a/src/providers/ollama/embed.ts b/src/providers/ollama/embed.ts index 5b33fd0d9..6566104fd 100644 --- a/src/providers/ollama/embed.ts +++ b/src/providers/ollama/embed.ts @@ -34,7 +34,7 @@ export const OllamaEmbedResponseTransform: ( param: null, code: null, }, - provider: "cohere", + provider: OLLAMA, } as ErrorResponse; } if ("embedding" in response) { @@ -47,7 +47,7 @@ export const OllamaEmbedResponseTransform: ( index: 0, }, ], - model: "", // Todo: find a way to send the ollama embedding model name back + model: "", usage: { prompt_tokens: -1, total_tokens: -1, diff --git a/src/providers/ollama/index.ts b/src/providers/ollama/index.ts index 5a59212fd..aaa3ccc26 100644 --- a/src/providers/ollama/index.ts +++ b/src/providers/ollama/index.ts @@ -1,18 +1,14 @@ import { ProviderConfigs } from "../types"; -import { OllamaCompleteConfig, OllamaCompleteResponseTransform, OllamaCompleteStreamChunkResponseTransform } from "./complete"; import { OllamaEmbedConfig, OllamaEmbedResponseTransform } from "./embed"; import OllamaAPIConfig from "./api"; import { OllamaChatCompleteConfig, OllamaChatCompleteResponseTransform, OllamaChatCompleteStreamChunkTransform } from "./chatComplete"; const OllamaConfig: ProviderConfigs = { - complete: OllamaCompleteConfig, embed: OllamaEmbedConfig, api: OllamaAPIConfig, chatComplete: OllamaChatCompleteConfig, responseTransforms: { - complete: OllamaCompleteResponseTransform, - 'stream-complete': OllamaCompleteStreamChunkResponseTransform, chatComplete: OllamaChatCompleteResponseTransform, 'stream-chatComplete': OllamaChatCompleteStreamChunkTransform, embed: OllamaEmbedResponseTransform diff --git a/src/utils.ts b/src/utils.ts index 7952b287c..396ebaca0 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -22,7 +22,7 @@ export const getStreamModeSplitPattern = (proxyProvider: string, requestURL: str splitPattern = '\r\n\r\n'; } if(proxyProvider === OLLAMA){ - splitPattern ='\n'; + splitPattern ='\n\n'; } return splitPattern; From 40aff7f6a706a71f3d9318a8ea5a50fb44b8d74e Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Wed, 14 Feb 2024 15:20:52 +0530 Subject: [PATCH 07/19] fix: removed unwanted comments --- src/providers/ollama/chatComplete.ts | 77 ---------------------------- 1 file changed, 77 deletions(-) diff --git a/src/providers/ollama/chatComplete.ts b/src/providers/ollama/chatComplete.ts index a399fdc14..067b68019 100644 --- a/src/providers/ollama/chatComplete.ts +++ b/src/providers/ollama/chatComplete.ts @@ -129,80 +129,3 @@ export const OllamaChatCompleteStreamChunkTransform: ( })}` + "\n\n" ); }; - -// export interface OllamaChatCompleteResponse extends ChatCompletionResponse { -// system_fingerprint: string; -// } - -// export const OllamaChatCompleteResponseTransform: ( -// response: OllamaChatCompleteResponse -// ) => ChatCompletionResponse = (response) => response; - -// export const OllamaChatCompleteResponseTransform: ( -// response: OllamaChatCompleteResponse | OllamaErrorResponse, -// responseStatus: number -// ) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => { -// if (responseStatus !== 200 && "error" in response) { -// return { -// error: { -// message: response.error, 
-// type: null, -// param: null, -// code: null, -// }, -// provider: OLLAMA, -// } as ErrorResponse; -// } -// if ("model" in response) { -// return { -// id: Date.now().toString(), -// object: "chat.completion", -// created: Date.now(), -// model: response.model, -// provider: OLLAMA, -// choices: [ -// { -// index: 0, -// message: { -// role: "assistant", -// content: response.message.content, -// }, -// finish_reason: "stop", -// logprobs: null, -// }, -// ], -// usage: { -// prompt_tokens: response.prompt_eval_count, -// completion_tokens: response.eval_count, -// total_tokens: response.prompt_eval_count + response.eval_count, -// }, -// }; -// } -// return { -// error: { -// message: `Invalid response recieved from ${OLLAMA}: ${JSON.stringify( -// response -// )}`, -// type: null, -// param: null, -// code: null, -// }, -// provider: OLLAMA, -// } as ErrorResponse; -// }; - -// interface OllamaCompleteStreamChunk { -// model: string; -// created_at: string; -// message: { -// role: string; -// content: string; -// }; -// done: boolean; -// total_duration: number; -// load_duration: number; -// prompt_eval_count: number; -// prompt_eval_duration: number; -// eval_count: number; -// eval_duration: number; -// } From 106da25a08a4c04ab77de57fa63b4f8f20cba307 Mon Sep 17 00:00:00 2001 From: csgulati09 Date: Wed, 14 Feb 2024 16:14:37 +0530 Subject: [PATCH 08/19] fix: removed unused imports, getEndpoint to /api/chat, removed Ollama from stream handler, default split pattern --- src/handlers/streamHandler.ts | 4 ++-- src/providers/ollama/api.ts | 4 ++-- src/providers/ollama/chatComplete.ts | 1 - src/utils.ts | 3 --- 4 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/handlers/streamHandler.ts b/src/handlers/streamHandler.ts index 10c5427b0..71ae9f574 100644 --- a/src/handlers/streamHandler.ts +++ b/src/handlers/streamHandler.ts @@ -99,8 +99,8 @@ export async function handleStreamingMode(response: Response, proxyProvider: str writer.close(); })(); - // Convert GEMINI/COHERE/OLLAMA json stream to text/event-stream for non-proxy calls - if ([GOOGLE, COHERE, OLLAMA].includes(proxyProvider) && responseTransformer) { + // Convert GEMINI/COHERE json stream to text/event-stream for non-proxy calls + if ([GOOGLE, COHERE].includes(proxyProvider) && responseTransformer) { return new Response(readable, { ...response, headers: new Headers({ diff --git a/src/providers/ollama/api.ts b/src/providers/ollama/api.ts index 82bddd6cf..8a32afd3a 100644 --- a/src/providers/ollama/api.ts +++ b/src/providers/ollama/api.ts @@ -6,9 +6,9 @@ const OllamaAPIConfig: ProviderAPIConfig = { }, chatComplete: "/v1/chat/completions", embed:"/api/embeddings", - getEndpoint: (fn: string, API_VERSION: string, url?: string) => { + getEndpoint: (fn: string, url?: string) => { let mappedFn = fn; - if (fn === "proxy" && url && url?.indexOf("/chat/completions") > -1) { + if (fn === "proxy" && url && url?.indexOf("/api/chat") > -1) { mappedFn = "chatComplete"; } else if (fn === "proxy" && url && url?.indexOf("/embeddings") > -1) { mappedFn = "embed"; diff --git a/src/providers/ollama/chatComplete.ts b/src/providers/ollama/chatComplete.ts index 067b68019..36ab67b05 100644 --- a/src/providers/ollama/chatComplete.ts +++ b/src/providers/ollama/chatComplete.ts @@ -1,6 +1,5 @@ import { ChatCompletionResponse, - CompletionResponse, ErrorResponse, ProviderConfig, } from "../types"; diff --git a/src/utils.ts b/src/utils.ts index 396ebaca0..b8e8cdbc4 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -21,9 +21,6 @@ export const 
getStreamModeSplitPattern = (proxyProvider: string, requestURL: str { splitPattern = '\r\n\r\n'; } - if(proxyProvider === OLLAMA){ - splitPattern ='\n\n'; - } return splitPattern; } From 915cf725e85175777b7ecbf4ba43a24cd6f5dec6 Mon Sep 17 00:00:00 2001 From: visargD Date: Thu, 15 Feb 2024 17:00:25 +0530 Subject: [PATCH 09/19] feat: add new header constants --- src/globals.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/globals.ts b/src/globals.ts index a156960ee..6bf4d6691 100644 --- a/src/globals.ts +++ b/src/globals.ts @@ -5,7 +5,9 @@ export const HEADER_KEYS: Record = { RETRIES: `x-${POWERED_BY}-retry-count`, PROVIDER: `x-${POWERED_BY}-provider`, TRACE_ID: `x-${POWERED_BY}-trace-id`, - CACHE: `x-${POWERED_BY}-cache` + CACHE: `x-${POWERED_BY}-cache`, + FORWARD_HEADERS: `x-${POWERED_BY}-forward-headers`, + CUSTOM_HOST: `x-${POWERED_BY}-custom-host` } export const RESPONSE_HEADER_KEYS: Record = { From f4675a081d166c99063faf78b3c931cf93016e0e Mon Sep 17 00:00:00 2001 From: visargD Date: Thu, 15 Feb 2024 17:02:23 +0530 Subject: [PATCH 10/19] feat: add custom host header validator --- src/middlewares/requestValidator/index.ts | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/middlewares/requestValidator/index.ts b/src/middlewares/requestValidator/index.ts index 12f8582ca..b9f1340ee 100644 --- a/src/middlewares/requestValidator/index.ts +++ b/src/middlewares/requestValidator/index.ts @@ -82,6 +82,22 @@ export const requestValidator = (c: Context, next: any) => { ); } + const customHostHeader = requestHeaders[`x-${POWERED_BY}-custom-host`]; + if (customHostHeader && customHostHeader.indexOf("api.portkey") > -1) { + return new Response( + JSON.stringify({ + status: "failure", + message: `Invalid custom host`, + }), + { + status: 400, + headers: { + "content-type": "application/json", + }, + } + );; + } + if (requestHeaders[`x-${POWERED_BY}-config`]) { try { From 2a2eb6c7a1104d5161b4a75a0309c0f1c67493c0 Mon Sep 17 00:00:00 2001 From: visargD Date: Thu, 15 Feb 2024 17:05:59 +0530 Subject: [PATCH 11/19] feat: add custom_host and forward_headers config schema props --- .../requestValidator/schema/config.ts | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/middlewares/requestValidator/schema/config.ts b/src/middlewares/requestValidator/schema/config.ts index 76de3e868..212a3fa8f 100644 --- a/src/middlewares/requestValidator/schema/config.ts +++ b/src/middlewares/requestValidator/schema/config.ts @@ -87,6 +87,8 @@ export const configSchema: any = z on_status_codes: z.array(z.number()).optional(), targets: z.array(z.lazy(() => configSchema)).optional(), request_timeout: z.number().optional(), + custom_host: z.string().optional(), + forward_headers: z.array(z.string()).optional() }) .refine( (value) => { @@ -94,16 +96,32 @@ export const configSchema: any = z value.provider !== undefined && value.api_key !== undefined; const hasModeTargets = value.strategy !== undefined && value.targets !== undefined; + const isOllamaProvider = value.provider === OLLAMA; + return ( hasProviderApiKey || hasModeTargets || value.cache || value.retry || - value.request_timeout + value.request_timeout || + isOllamaProvider ); }, { message: "Invalid configuration. 
It must have either 'provider' and 'api_key', or 'strategy' and 'targets', or 'cache', or 'retry', or 'request_timeout'", } + ) + .refine( + (value) => { + const customHost = value.custom_host; + if (customHost && (customHost.indexOf("api.portkey"))) { + return false; + } + return true; + }, + { + message: + "Invalid custom host", + } ); From 67c24cbb563b9a1352c2097925532f1d5622946f Mon Sep 17 00:00:00 2001 From: visargD Date: Thu, 15 Feb 2024 17:06:35 +0530 Subject: [PATCH 12/19] fix: return empty object instead of null for ollama headers --- src/providers/ollama/api.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/providers/ollama/api.ts b/src/providers/ollama/api.ts index 82bddd6cf..47e92b198 100644 --- a/src/providers/ollama/api.ts +++ b/src/providers/ollama/api.ts @@ -2,7 +2,7 @@ import { ProviderAPIConfig } from "../types"; const OllamaAPIConfig: ProviderAPIConfig = { headers: () => { - return null; + return {}; }, chatComplete: "/v1/chat/completions", embed:"/api/embeddings", From b24160b1cf6b422fed46949119f6d2672ba98fa5 Mon Sep 17 00:00:00 2001 From: visargD Date: Thu, 15 Feb 2024 17:07:31 +0530 Subject: [PATCH 13/19] feat: add customHost and forwardHeaders to interface --- src/types/requestBody.ts | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts index 881dc59fb..63688c4c9 100644 --- a/src/types/requestBody.ts +++ b/src/types/requestBody.ts @@ -43,8 +43,10 @@ export interface Options { deploymentId?: string; apiVersion?: string; adAuth?:string; - /** Ollama specific */ - baseUrl?: string; + /** The parameter to set custom base url */ + customHost?: string; + /** The parameter to set list of headers to be forwarded as-is to the provider */ + forwardHeaders?: string[]; /** provider option index picked based on weight in loadbalance mode */ index?: number; cache?: CacheSettings | string; @@ -96,6 +98,7 @@ export interface Config { cache?: CacheSettings; retry?: RetrySettings; strategy?: Strategy; + customHost?: string; } /** @@ -232,6 +235,7 @@ export interface ShortConfig { resourceName?: string; deploymentId?: string; apiVersion?: string; + customHost?: string; } /** From 51c4d569e3fba5bfcdbccd4eeb43ecb97d16add2 Mon Sep 17 00:00:00 2001 From: visargD Date: Thu, 15 Feb 2024 17:10:50 +0530 Subject: [PATCH 14/19] feat: support custom host in proxy routes --- src/handlers/proxyGetHandler.ts | 11 +++++++++-- src/handlers/proxyHandler.ts | 10 ++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/handlers/proxyGetHandler.ts b/src/handlers/proxyGetHandler.ts index 275e50fc3..33d030c5c 100644 --- a/src/handlers/proxyGetHandler.ts +++ b/src/handlers/proxyGetHandler.ts @@ -10,11 +10,16 @@ function proxyProvider(proxyModeHeader:string, providerHeader: string) { return proxyProvider; } -function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string) { +function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string, customHost: string) { let reqURL = new URL(requestURL); let reqPath = reqURL.pathname; const reqQuery = reqURL.search; reqPath = reqPath.replace(proxyEndpointPath, ""); + + if (customHost) { + return `${customHost}/${reqPath}${reqQuery}` + } + const providerBasePath = Providers[proxyProvider].api.baseURL; if (proxyProvider === AZURE_OPEN_AI) { return `https:/${reqPath}${reqQuery}`; @@ -55,7 +60,9 @@ export async function proxyGetHandler(c: Context): Promise { proxyPath: c.req.url.indexOf("/v1/proxy") > -1 ? 
"/v1/proxy" : "/v1" } - let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath); + const customHost = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || ""; + + let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost); let fetchOptions = { headers: headersToSend(requestHeaders, store.customHeadersToAvoid), diff --git a/src/handlers/proxyHandler.ts b/src/handlers/proxyHandler.ts index d56b7c84f..b90d18a31 100644 --- a/src/handlers/proxyHandler.ts +++ b/src/handlers/proxyHandler.ts @@ -12,11 +12,16 @@ function proxyProvider(proxyModeHeader:string, providerHeader: string) { return proxyProvider; } -function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string) { +function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string, customHost: string) { let reqURL = new URL(requestURL); let reqPath = reqURL.pathname; const reqQuery = reqURL.search; reqPath = reqPath.replace(proxyEndpointPath, ""); + + if (customHost) { + return `${customHost}/${reqPath}${reqQuery}` + } + const providerBasePath = Providers[proxyProvider].api.baseURL; if (proxyProvider === AZURE_OPEN_AI) { return `https:/${reqPath}${reqQuery}`; @@ -94,7 +99,8 @@ export async function proxyHandler(c: Context): Promise { } }; - let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath); + const customHost = requestConfig?.customHost || requestHeaders[HEADER_KEYS.CUSTOM_HOST] || ""; + let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost); store.isStreamingMode = getStreamingMode(store.reqBody, store.proxyProvider, urlToFetch) if (requestConfig && From c3ca279a6518231ff2f4793c2085517fdb41e18d Mon Sep 17 00:00:00 2001 From: visargD Date: Thu, 15 Feb 2024 17:14:40 +0530 Subject: [PATCH 15/19] feat: support custom host and forward headers for unified routes --- src/handlers/handlerUtils.ts | 111 +++++++++++++++++++++++------------ 1 file changed, 75 insertions(+), 36 deletions(-) diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index b014a48e9..9842e11f8 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -1,5 +1,5 @@ import { Context } from "hono"; -import { AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, OLLAMA, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals"; +import { ANTHROPIC, AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, OLLAMA, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals"; import Providers from "../providers"; import { ProviderAPIConfig, endpointStrings } from "../providers/types"; import transformToProviderRequest from "../services/transformToProviderRequest"; @@ -19,13 +19,23 @@ import { OpenAICompleteJSONToStreamResponseTransform } from "../providers/openai * @param {string} method - The HTTP method for the request. * @returns {RequestInit} - The fetch options for the request. 
*/ -export function constructRequest(headers: any, provider: string = "", method: string = "POST") { +export function constructRequest(providerConfigMappedHeaders: any, provider: string, method: string, forwardHeaders: string[], requestHeaders: Record) { let baseHeaders: any = { "content-type": "application/json" }; + let headers: Record = { + ...providerConfigMappedHeaders + }; + + const forwardHeadersMap: Record = {}; + + forwardHeaders.forEach((h: string) => { + if (requestHeaders[h]) forwardHeadersMap[h] = requestHeaders[h]; + }) + // Add any headers that the model might need - headers = {...baseHeaders, ...headers} + headers = {...baseHeaders, ...headers, ...forwardHeadersMap} let fetchOptions: RequestInit = { method, @@ -122,7 +132,8 @@ export const fetchProviderOptionsFromConfig = (config: Config | ShortConfig): Op virtualKey: camelCaseConfig.virtualKey, apiKey: camelCaseConfig.apiKey, cache: camelCaseConfig.cache, - retry: camelCaseConfig.retry + retry: camelCaseConfig.retry, + customHost: camelCaseConfig.customHost }]; if (camelCaseConfig.resourceName) providerOptions[0].resourceName = camelCaseConfig.resourceName; if (camelCaseConfig.deploymentId) providerOptions[0].deploymentId = camelCaseConfig.deploymentId; @@ -162,40 +173,47 @@ export async function tryPostProxy(c: Context, providerOption:Options, inputPara const apiConfig: ProviderAPIConfig = Providers[provider].api; let fetchOptions; let url = providerOption.urlToFetch as string; - + let baseUrl:string, endpoint:string; + const forwardHeaders: string[] = []; + baseUrl = (providerOption.customHost || "") || (requestHeaders[HEADER_KEYS.CUSTOM_HOST] || "") || ""; + if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) { // Construct the base object for the request if(!!providerOption.apiKey) { - fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, method); + fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, method, forwardHeaders, requestHeaders); } else { - fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, method); + fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, method, forwardHeaders, requestHeaders); } - baseUrl = apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId); + baseUrl = baseUrl || apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId); endpoint = apiConfig.getEndpoint(fn, providerOption.apiVersion, url); url = `${baseUrl}${endpoint}`; } else if (provider === PALM && apiConfig.baseURL && apiConfig.getEndpoint) { - fetchOptions = constructRequest(apiConfig.headers(), provider, method); - baseUrl = apiConfig.baseURL; + fetchOptions = constructRequest(apiConfig.headers(), provider, method, forwardHeaders, requestHeaders); + baseUrl = baseUrl || apiConfig.baseURL; endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params?.model); url = `${baseUrl}${endpoint}`; - } else if (provider === "anthropic" && apiConfig.baseURL) { + } else if (provider === ANTHROPIC && apiConfig.baseURL) { // Construct the base object for the POST request - fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider); - baseUrl = apiConfig.baseURL; + fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider, "POST", forwardHeaders, requestHeaders); + baseUrl = baseUrl || apiConfig.baseURL; endpoint = apiConfig[fn] || ""; + url = 
`${baseUrl}${endpoint}`; } else if (provider === GOOGLE && apiConfig.baseURL && apiConfig.getEndpoint) { - fetchOptions = constructRequest(apiConfig.headers(), provider); - baseUrl = apiConfig.baseURL; + fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders); + baseUrl = baseUrl || apiConfig.baseURL; endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params.model, params.stream); + url = `${baseUrl}${endpoint}`; } else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) { - fetchOptions = constructRequest(apiConfig.headers(), provider); - baseUrl = apiConfig.baseURL; + fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders); + baseUrl = baseUrl || apiConfig.baseURL; endpoint = apiConfig.getEndpoint(fn, params.model, url); + url = `${baseUrl}${endpoint}`; } else { // Construct the base object for the request - fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method); + fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method, forwardHeaders, requestHeaders); } + if (method === "POST") { fetchOptions.body = JSON.stringify(params) } @@ -314,42 +332,45 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P let baseUrl:string, endpoint:string, fetchOptions; + const forwardHeaders = providerOption.forwardHeaders || requestHeaders[HEADER_KEYS.FORWARD_HEADERS]?.split(",").map(h => h.trim()) || []; + baseUrl = (providerOption.customHost || "") || (requestHeaders[HEADER_KEYS.CUSTOM_HOST] || "") || ""; + if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) { // Construct the base object for the POST request if(!!providerOption.apiKey) { - fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider); + fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, "POST", forwardHeaders, requestHeaders); } else { - fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider); + fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, "POST", forwardHeaders, requestHeaders); } - baseUrl = apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId); + baseUrl = baseUrl || apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId); endpoint = apiConfig.getEndpoint(fn, providerOption.apiVersion); } else if (provider === PALM && apiConfig.baseURL && apiConfig.getEndpoint) { - fetchOptions = constructRequest(apiConfig.headers(), provider); - baseUrl = apiConfig.baseURL; + fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders); + baseUrl = baseUrl || apiConfig.baseURL; endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, providerOption.overrideParams?.model || params?.model); - } else if (provider === "anthropic" && apiConfig.baseURL) { + } else if (provider === ANTHROPIC && apiConfig.baseURL) { // Construct the base object for the POST request - fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider); - baseUrl = apiConfig.baseURL; + fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider, "POST", forwardHeaders, requestHeaders); + baseUrl = baseUrl || apiConfig.baseURL; endpoint = apiConfig[fn] || ""; } else if (provider === GOOGLE && apiConfig.baseURL && 
apiConfig.getEndpoint) { - fetchOptions = constructRequest(apiConfig.headers(), provider); - baseUrl = apiConfig.baseURL; + fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders); + baseUrl = baseUrl || apiConfig.baseURL; endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream); } else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) { - fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider); - baseUrl = apiConfig.baseURL; + fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, "POST", forwardHeaders, requestHeaders); + baseUrl = baseUrl || apiConfig.baseURL; endpoint = apiConfig.getEndpoint(fn, params.model); } else if (provider === OLLAMA && apiConfig.getEndpoint) { - fetchOptions = constructRequest(apiConfig.headers(), provider); - baseUrl = providerOption.baseUrl || "" + fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders); + baseUrl = baseUrl; endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream); } else { // Construct the base object for the POST request - fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider); + fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, "POST", forwardHeaders, requestHeaders); - baseUrl = apiConfig.baseURL || ""; + baseUrl = baseUrl || apiConfig.baseURL || ""; endpoint = apiConfig[fn] || ""; } @@ -363,7 +384,7 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P providerOption.retry = { attempts: providerOption.retry?.attempts ?? 0, - onStatusCodes: providerOption.retry?.onStatusCodes ?? [] + onStatusCodes: providerOption.retry?.onStatusCodes ?? 
RETRY_STATUS_CODES } const [getFromCacheFunction, cacheIdentifier, requestOptions] = [ @@ -548,7 +569,7 @@ export async function tryTargetsRecursively( const strategyMode = currentTarget.strategy?.mode; // start: merge inherited config with current target config (preference given to current) - const currentInheritedConfig = { + const currentInheritedConfig: Record = { overrideParams : { ...inheritedConfig.overrideParams, ...currentTarget.overrideParams @@ -558,11 +579,29 @@ export async function tryTargetsRecursively( requestTimeout: null } + if (currentTarget.forwardHeaders) { + currentInheritedConfig.forwardHeaders = [...currentTarget.forwardHeaders]; + } else if (inheritedConfig.forwardHeaders) { + currentInheritedConfig.forwardHeaders = [...inheritedConfig.forwardHeaders]; + currentTarget.forwardHeaders + } + + if (currentTarget.customHost) { + currentInheritedConfig.customHost = currentTarget.customHost + } else if (inheritedConfig.customHost) { + currentInheritedConfig.customHost = inheritedConfig.customHost; + currentTarget.customHost = inheritedConfig.customHost; + } + if (currentTarget.requestTimeout) { currentInheritedConfig.requestTimeout = currentTarget.requestTimeout } else if (inheritedConfig.requestTimeout) { currentInheritedConfig.requestTimeout = inheritedConfig.requestTimeout; + currentTarget.requestTimeout = inheritedConfig.requestTimeout; } + + + currentTarget.overrideParams = { ...currentInheritedConfig.overrideParams } From 59d80eb82868fa53cc24d53550334d0766b79dae Mon Sep 17 00:00:00 2001 From: visargD Date: Fri, 16 Feb 2024 15:44:17 +0530 Subject: [PATCH 16/19] fix: custom host config schema --- src/middlewares/requestValidator/schema/config.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/middlewares/requestValidator/schema/config.ts b/src/middlewares/requestValidator/schema/config.ts index 212a3fa8f..377f8a41a 100644 --- a/src/middlewares/requestValidator/schema/config.ts +++ b/src/middlewares/requestValidator/schema/config.ts @@ -115,7 +115,7 @@ export const configSchema: any = z .refine( (value) => { const customHost = value.custom_host; - if (customHost && (customHost.indexOf("api.portkey"))) { + if (customHost && (customHost.indexOf("api.portkey") > -1)) { return false; } return true; From 6aa71246a2223d7aea5ecbca3618629784bf69ed Mon Sep 17 00:00:00 2001 From: visargD Date: Fri, 16 Feb 2024 15:45:28 +0530 Subject: [PATCH 17/19] fix: forward headers config inherit logic --- src/handlers/handlerUtils.ts | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts index 9428988a3..267b96c07 100644 --- a/src/handlers/handlerUtils.ts +++ b/src/handlers/handlerUtils.ts @@ -175,8 +175,9 @@ export async function tryPostProxy(c: Context, providerOption:Options, inputPara let url = providerOption.urlToFetch as string; let baseUrl:string, endpoint:string; + const forwardHeaders: string[] = []; - baseUrl = (providerOption.customHost || "") || (requestHeaders[HEADER_KEYS.CUSTOM_HOST] || "") || ""; + baseUrl = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || ""; if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) { // Construct the base object for the request @@ -332,8 +333,8 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P let baseUrl:string, endpoint:string, fetchOptions; - const forwardHeaders = providerOption.forwardHeaders || requestHeaders[HEADER_KEYS.FORWARD_HEADERS]?.split(",").map(h 
=> h.trim()) || []; - baseUrl = (providerOption.customHost || "") || (requestHeaders[HEADER_KEYS.CUSTOM_HOST] || "") || ""; + const forwardHeaders = requestHeaders[HEADER_KEYS.FORWARD_HEADERS]?.split(",").map(h => h.trim()) || providerOption.forwardHeaders || []; + baseUrl = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || ""; if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) { // Construct the base object for the POST request @@ -583,7 +584,7 @@ export async function tryTargetsRecursively( currentInheritedConfig.forwardHeaders = [...currentTarget.forwardHeaders]; } else if (inheritedConfig.forwardHeaders) { currentInheritedConfig.forwardHeaders = [...inheritedConfig.forwardHeaders]; - currentTarget.forwardHeaders + currentTarget.forwardHeaders = [...inheritedConfig.forwardHeaders]; } if (currentTarget.customHost) { From 45db52f0a94177c60fe89a985f43654b648d1bbc Mon Sep 17 00:00:00 2001 From: visargD Date: Sat, 17 Feb 2024 15:42:15 +0530 Subject: [PATCH 18/19] fix: change custom host selection preference --- src/handlers/proxyHandler.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/handlers/proxyHandler.ts b/src/handlers/proxyHandler.ts index b90d18a31..dbb7c1e6d 100644 --- a/src/handlers/proxyHandler.ts +++ b/src/handlers/proxyHandler.ts @@ -99,7 +99,7 @@ export async function proxyHandler(c: Context): Promise { } }; - const customHost = requestConfig?.customHost || requestHeaders[HEADER_KEYS.CUSTOM_HOST] || ""; + const customHost = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || requestConfig?.customHost || ""; let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost); store.isStreamingMode = getStreamingMode(store.reqBody, store.proxyProvider, urlToFetch) From 2a25274c3e740be8141ef18b713a92e22ca3dd65 Mon Sep 17 00:00:00 2001 From: visargD Date: Sat, 17 Feb 2024 17:40:02 +0530 Subject: [PATCH 19/19] fix: custom host based url creation for proxy routes --- src/handlers/proxyGetHandler.ts | 2 +- src/handlers/proxyHandler.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/handlers/proxyGetHandler.ts b/src/handlers/proxyGetHandler.ts index 33d030c5c..20d6f0c0d 100644 --- a/src/handlers/proxyGetHandler.ts +++ b/src/handlers/proxyGetHandler.ts @@ -17,7 +17,7 @@ function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath reqPath = reqPath.replace(proxyEndpointPath, ""); if (customHost) { - return `${customHost}/${reqPath}${reqQuery}` + return `${customHost}${reqPath}${reqQuery}` } const providerBasePath = Providers[proxyProvider].api.baseURL; diff --git a/src/handlers/proxyHandler.ts b/src/handlers/proxyHandler.ts index dbb7c1e6d..7cfadfe2d 100644 --- a/src/handlers/proxyHandler.ts +++ b/src/handlers/proxyHandler.ts @@ -19,7 +19,7 @@ function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath reqPath = reqPath.replace(proxyEndpointPath, ""); if (customHost) { - return `${customHost}/${reqPath}${reqQuery}` + return `${customHost}${reqPath}${reqQuery}` } const providerBasePath = Providers[proxyProvider].api.baseURL;
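
For reference, the snippet below is an illustrative sketch, not part of any patch in this series, of the URL construction the custom-host changes converge on after the final fix: when a custom host is supplied via config or the x-portkey-custom-host header, it replaces the provider base URL and is concatenated with the stripped request path without an extra slash, so the supplied host should not end in a trailing slash. The helper name buildTargetUrl and the base-URL map are hypothetical stand-ins for getProxyPath and the provider registry.

// Illustrative sketch only; names here are hypothetical stand-ins.
const PROVIDER_BASE_URLS: Record<string, string> = {
  openai: "https://api.openai.com/v1", // assumed value, for illustration only
  ollama: "http://localhost:11434",    // matches the Ollama default used earlier in the series
};

function buildTargetUrl(
  requestUrl: string,
  provider: string,
  proxyEndpointPath: string,
  customHost: string
): string {
  const reqUrl = new URL(requestUrl);
  const reqPath = reqUrl.pathname.replace(proxyEndpointPath, "");
  const reqQuery = reqUrl.search;

  // Custom host wins over the provider base URL; no extra "/" is inserted
  // (matching the final `${customHost}${reqPath}${reqQuery}` fix), so the
  // header value should not carry a trailing slash.
  if (customHost) {
    return `${customHost}${reqPath}${reqQuery}`;
  }
  return `${PROVIDER_BASE_URLS[provider]}${reqPath}${reqQuery}`;
}

// Hypothetical example: proxying to a self-hosted Ollama instance.
console.log(
  buildTargetUrl(
    "https://gateway.example.com/v1/proxy/api/chat",
    "ollama",
    "/v1/proxy",
    "http://10.0.0.5:11434"
  )
); // -> http://10.0.0.5:11434/api/chat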
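
Along the same lines, a minimal sketch of the forward-headers behaviour added to constructRequest in patch 15: only the incoming request headers named in the allow list (from config or x-portkey-forward-headers) are copied onto the outbound provider request, layered over the base and provider-specific headers. mergeForwardHeaders is a hypothetical helper written only to show the merge order.

// Illustrative sketch only; mergeForwardHeaders is a hypothetical helper.
function mergeForwardHeaders(
  providerHeaders: Record<string, string>,
  forwardHeaders: string[],
  requestHeaders: Record<string, string>
): Record<string, string> {
  const forwarded: Record<string, string> = {};
  for (const name of forwardHeaders) {
    // Only headers actually present on the incoming request are forwarded.
    if (requestHeaders[name]) forwarded[name] = requestHeaders[name];
  }
  // Base headers first, then provider headers, then the forwarded allow list,
  // mirroring the spread order used in constructRequest.
  return { "content-type": "application/json", ...providerHeaders, ...forwarded };
}

// Hypothetical example: forwarding a tracing header to the provider.
const outbound = mergeForwardHeaders(
  { authorization: "Bearer sk-..." },
  ["x-request-id"],
  { "x-request-id": "abc-123", "x-unrelated": "ignored" }
);
console.log(outbound);
// -> { "content-type": "application/json", authorization: "Bearer sk-...", "x-request-id": "abc-123" }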