diff --git a/.gitignore b/.gitignore
index 0f965e940..8618236ab 100644
--- a/.gitignore
+++ b/.gitignore
@@ -133,3 +133,6 @@ dist
 # Rollup build dir
 build
 .DS_Store
+
+# Wrangler temp directory
+.wrangler
\ No newline at end of file
diff --git a/src/globals.ts b/src/globals.ts
index f9a4474d4..aea933b53 100644
--- a/src/globals.ts
+++ b/src/globals.ts
@@ -5,7 +5,9 @@ export const HEADER_KEYS: Record<string, string> = {
     RETRIES: `x-${POWERED_BY}-retry-count`,
     PROVIDER: `x-${POWERED_BY}-provider`,
     TRACE_ID: `x-${POWERED_BY}-trace-id`,
-    CACHE: `x-${POWERED_BY}-cache`
+    CACHE: `x-${POWERED_BY}-cache`,
+    FORWARD_HEADERS: `x-${POWERED_BY}-forward-headers`,
+    CUSTOM_HOST: `x-${POWERED_BY}-custom-host`
 }
 
 export const RESPONSE_HEADER_KEYS: Record<string, string> = {
@@ -33,6 +35,7 @@ export const MISTRAL_AI: string = "mistral-ai";
 export const DEEPINFRA: string = "deepinfra";
 export const STABILITY_AI: string = "stability-ai";
 export const NOMIC: string = "nomic";
+export const OLLAMA: string = "ollama";
 
 export const providersWithStreamingSupport = [OPEN_AI, AZURE_OPEN_AI, ANTHROPIC, COHERE];
 export const allowedProxyProviders = [OPEN_AI, COHERE, AZURE_OPEN_AI, ANTHROPIC];
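Reviewer note: the two new header keys and the `OLLAMA` constant above are the surface area for everything that follows. A minimal sketch of how a client would exercise them — assuming `POWERED_BY` resolves to `portkey` and a gateway running locally on wrangler's default port 8787 (both assumptions, not part of this diff):

```ts
// Sketch: calling the gateway with the headers introduced in this diff.
const response = await fetch("http://localhost:8787/v1/chat/completions", {
    method: "POST",
    headers: {
        "content-type": "application/json",
        "x-portkey-provider": "ollama",
        // Routes the request to a self-hosted Ollama instance.
        "x-portkey-custom-host": "http://localhost:11434",
        // Comma-separated request headers forwarded as-is to the provider.
        "x-portkey-forward-headers": "x-my-tenant-id, x-my-trace-id",
        "x-my-tenant-id": "acme",
        "x-my-trace-id": "abc-123",
    },
    body: JSON.stringify({
        model: "llama2",
        messages: [{ role: "user", content: "Hello!" }],
    }),
});
console.log(await response.json());
```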
diff --git a/src/handlers/handlerUtils.ts b/src/handlers/handlerUtils.ts
index 797ef308d..267b96c07 100644
--- a/src/handlers/handlerUtils.ts
+++ b/src/handlers/handlerUtils.ts
@@ -1,5 +1,5 @@
 import { Context } from "hono";
-import { AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals";
+import { ANTHROPIC, AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, OLLAMA, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals";
 import Providers from "../providers";
 import { ProviderAPIConfig, endpointStrings } from "../providers/types";
 import transformToProviderRequest from "../services/transformToProviderRequest";
@@ -19,13 +19,23 @@ import { OpenAICompleteJSONToStreamResponseTransform } from "../providers/openai
  * @param {string} method - The HTTP method for the request.
  * @returns {RequestInit} - The fetch options for the request.
  */
-export function constructRequest(headers: any, provider: string = "", method: string = "POST") {
+export function constructRequest(providerConfigMappedHeaders: any, provider: string, method: string, forwardHeaders: string[], requestHeaders: Record<string, string>) {
     let baseHeaders: any = {
         "content-type": "application/json"
     };
 
+    let headers: Record<string, string> = {
+        ...providerConfigMappedHeaders
+    };
+
+    const forwardHeadersMap: Record<string, string> = {};
+
+    forwardHeaders.forEach((h: string) => {
+        if (requestHeaders[h]) forwardHeadersMap[h] = requestHeaders[h];
+    })
+
     // Add any headers that the model might need
-    headers = {...baseHeaders, ...headers}
+    headers = {...baseHeaders, ...headers, ...forwardHeadersMap}
 
     let fetchOptions: RequestInit = {
         method,
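Reviewer note: a distilled sketch of the merge order the new `constructRequest` applies — provider-config headers override the base `content-type`, and forwarded request headers override both (standalone re-statement, not an import of the function above):

```ts
// Header precedence: base < provider config < forwarded request headers.
function mergeHeaders(
    providerHeaders: Record<string, string>,
    forwardHeaders: string[],
    requestHeaders: Record<string, string>
): Record<string, string> {
    const base = { "content-type": "application/json" };
    const forwarded: Record<string, string> = {};
    for (const h of forwardHeaders) {
        if (requestHeaders[h]) forwarded[h] = requestHeaders[h];
    }
    return { ...base, ...providerHeaders, ...forwarded };
}

// mergeHeaders({ authorization: "Bearer k" }, ["x-tenant"], { "x-tenant": "acme" })
// -> { "content-type": "application/json", authorization: "Bearer k", "x-tenant": "acme" }
```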
@@ -122,7 +132,8 @@ export const fetchProviderOptionsFromConfig = (config: Config | ShortConfig): Op
             virtualKey: camelCaseConfig.virtualKey,
             apiKey: camelCaseConfig.apiKey,
             cache: camelCaseConfig.cache,
-            retry: camelCaseConfig.retry
+            retry: camelCaseConfig.retry,
+            customHost: camelCaseConfig.customHost
         }];
         if (camelCaseConfig.resourceName) providerOptions[0].resourceName = camelCaseConfig.resourceName;
         if (camelCaseConfig.deploymentId) providerOptions[0].deploymentId = camelCaseConfig.deploymentId;
@@ -162,40 +173,48 @@ export async function tryPostProxy(c: Context, providerOption:Options, inputParams
     const apiConfig: ProviderAPIConfig = Providers[provider].api;
     let fetchOptions;
     let url = providerOption.urlToFetch as string;
-
+    let baseUrl:string, endpoint:string;
+
+    const forwardHeaders: string[] = [];
+    baseUrl = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || "";
+
     if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) {
         // Construct the base object for the request
         if(!!providerOption.apiKey) {
-            fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, method);
+            fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, method, forwardHeaders, requestHeaders);
         } else {
-            fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, method);
+            fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, method, forwardHeaders, requestHeaders);
         }
-        baseUrl = apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
+        baseUrl = baseUrl || apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
         endpoint = apiConfig.getEndpoint(fn, providerOption.apiVersion, url);
         url = `${baseUrl}${endpoint}`;
     } else if (provider === PALM && apiConfig.baseURL && apiConfig.getEndpoint) {
-        fetchOptions = constructRequest(apiConfig.headers(), provider, method);
-        baseUrl = apiConfig.baseURL;
+        fetchOptions = constructRequest(apiConfig.headers(), provider, method, forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL;
         endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params?.model);
         url = `${baseUrl}${endpoint}`;
-    } else if (provider === "anthropic" && apiConfig.baseURL) {
+    } else if (provider === ANTHROPIC && apiConfig.baseURL) {
         // Construct the base object for the POST request
-        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider);
-        baseUrl = apiConfig.baseURL;
+        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL;
         endpoint = apiConfig[fn] || "";
+        url = `${baseUrl}${endpoint}`;
     } else if (provider === GOOGLE && apiConfig.baseURL && apiConfig.getEndpoint) {
-        fetchOptions = constructRequest(apiConfig.headers(), provider);
-        baseUrl = apiConfig.baseURL;
+        fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL;
         endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params.model, params.stream);
+        url = `${baseUrl}${endpoint}`;
     } else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) {
-        fetchOptions = constructRequest(apiConfig.headers(), provider);
-        baseUrl = apiConfig.baseURL;
+        fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL;
         endpoint = apiConfig.getEndpoint(fn, params.model, url);
+        url = `${baseUrl}${endpoint}`;
     } else {
         // Construct the base object for the request
-        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method);
+        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method, forwardHeaders, requestHeaders);
     }
+
     if (method === "POST") {
         fetchOptions.body = JSON.stringify(params)
     }
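Reviewer note: the pattern repeated through every branch above (and in `tryPost` below) is a single resolution order for the base URL; `tryPostProxy` passes an empty forward-header list, while `tryPost` reads it from the request. A standalone sketch of that order:

```ts
// Base-URL resolution introduced by this diff: an explicit
// x-<POWERED_BY>-custom-host header wins, then customHost from the
// config, then the provider's default base URL.
function resolveBaseUrl(
    headerCustomHost: string | undefined,
    configCustomHost: string | undefined,
    providerDefaultBaseUrl: string
): string {
    return headerCustomHost || configCustomHost || providerDefaultBaseUrl;
}

// resolveBaseUrl("http://localhost:11434", "https://my-proxy", "https://api.openai.com")
// -> "http://localhost:11434"
```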
@@ -314,37 +333,45 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P
     let baseUrl:string, endpoint:string, fetchOptions;
 
+    const forwardHeaders = requestHeaders[HEADER_KEYS.FORWARD_HEADERS]?.split(",").map(h => h.trim()) || providerOption.forwardHeaders || [];
+    baseUrl = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || "";
+
     if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) {
         // Construct the base object for the POST request
         if(!!providerOption.apiKey) {
-            fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider);
+            fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, "POST", forwardHeaders, requestHeaders);
         } else {
-            fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider);
+            fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, "POST", forwardHeaders, requestHeaders);
         }
-        baseUrl = apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
+        baseUrl = baseUrl || apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
         endpoint = apiConfig.getEndpoint(fn, providerOption.apiVersion);
     } else if (provider === PALM && apiConfig.baseURL && apiConfig.getEndpoint) {
-        fetchOptions = constructRequest(apiConfig.headers(), provider);
-        baseUrl = apiConfig.baseURL;
+        fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL;
         endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, providerOption.overrideParams?.model || params?.model);
-    } else if (provider === "anthropic" && apiConfig.baseURL) {
+    } else if (provider === ANTHROPIC && apiConfig.baseURL) {
         // Construct the base object for the POST request
-        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider);
-        baseUrl = apiConfig.baseURL;
+        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL;
         endpoint = apiConfig[fn] || "";
     } else if (provider === GOOGLE && apiConfig.baseURL && apiConfig.getEndpoint) {
-        fetchOptions = constructRequest(apiConfig.headers(), provider);
-        baseUrl = apiConfig.baseURL;
+        fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL;
         endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream);
     } else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) {
-        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider);
-        baseUrl = apiConfig.baseURL;
+        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL;
         endpoint = apiConfig.getEndpoint(fn, params.model);
-    } else {
+    } else if (provider === OLLAMA && apiConfig.getEndpoint) {
+        fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl;
+        endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream);
+    }
+    else {
         // Construct the base object for the POST request
-        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider);
-        baseUrl = apiConfig.baseURL || "";
+        fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, "POST", forwardHeaders, requestHeaders);
+        baseUrl = baseUrl || apiConfig.baseURL || "";
         endpoint = apiConfig[fn] || "";
     }
@@ -358,7 +385,7 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P
 
     providerOption.retry = {
         attempts: providerOption.retry?.attempts ?? 0,
-        onStatusCodes: providerOption.retry?.onStatusCodes ?? []
+        onStatusCodes: providerOption.retry?.onStatusCodes ?? RETRY_STATUS_CODES
     }
 
     const [getFromCacheFunction, cacheIdentifier, requestOptions] = [
@@ -543,7 +570,7 @@ export async function tryTargetsRecursively(
     const strategyMode = currentTarget.strategy?.mode;
 
     // start: merge inherited config with current target config (preference given to current)
-    const currentInheritedConfig = {
+    const currentInheritedConfig: Record<string, any> = {
         overrideParams : {
             ...inheritedConfig.overrideParams,
             ...currentTarget.overrideParams
@@ -553,11 +580,29 @@
         requestTimeout: null
     }
 
+    if (currentTarget.forwardHeaders) {
+        currentInheritedConfig.forwardHeaders = [...currentTarget.forwardHeaders];
+    } else if (inheritedConfig.forwardHeaders) {
+        currentInheritedConfig.forwardHeaders = [...inheritedConfig.forwardHeaders];
+        currentTarget.forwardHeaders = [...inheritedConfig.forwardHeaders];
+    }
+
+    if (currentTarget.customHost) {
+        currentInheritedConfig.customHost = currentTarget.customHost
+    } else if (inheritedConfig.customHost) {
+        currentInheritedConfig.customHost = inheritedConfig.customHost;
+        currentTarget.customHost = inheritedConfig.customHost;
+    }
+
     if (currentTarget.requestTimeout) {
         currentInheritedConfig.requestTimeout = currentTarget.requestTimeout
     } else if (inheritedConfig.requestTimeout) {
         currentInheritedConfig.requestTimeout = inheritedConfig.requestTimeout;
+        currentTarget.requestTimeout = inheritedConfig.requestTimeout;
     }
+
+
+
     currentTarget.overrideParams = {
         ...currentInheritedConfig.overrideParams
     }
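Reviewer note: the inheritance block above means `custom_host` and `forward_headers` set at the top of a config now cascade into every target unless a target overrides them. A hypothetical config illustrating the cascade (all values are placeholders):

```ts
// Top-level custom_host + forward_headers flow into each target;
// the second target overrides custom_host for a local Ollama instance.
const config = {
    strategy: { mode: "fallback" },
    custom_host: "https://llm-gateway.internal.example.com",
    forward_headers: ["x-tenant-id"],
    targets: [
        { provider: "openai", api_key: "sk-placeholder" },             // inherits both
        { provider: "ollama", custom_host: "http://localhost:11434" }, // overrides custom_host
    ],
};
```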
diff --git a/src/handlers/proxyGetHandler.ts b/src/handlers/proxyGetHandler.ts
index 275e50fc3..20d6f0c0d 100644
--- a/src/handlers/proxyGetHandler.ts
+++ b/src/handlers/proxyGetHandler.ts
@@ -10,11 +10,16 @@ function proxyProvider(proxyModeHeader:string, providerHeader: string) {
     return proxyProvider;
 }
 
-function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string) {
+function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string, customHost: string) {
     let reqURL = new URL(requestURL);
     let reqPath = reqURL.pathname;
     const reqQuery = reqURL.search;
     reqPath = reqPath.replace(proxyEndpointPath, "");
+
+    if (customHost) {
+        return `${customHost}${reqPath}${reqQuery}`
+    }
+
     const providerBasePath = Providers[proxyProvider].api.baseURL;
     if (proxyProvider === AZURE_OPEN_AI) {
         return `https:/${reqPath}${reqQuery}`;
@@ -55,7 +60,9 @@ export async function proxyGetHandler(c: Context): Promise<Response> {
         proxyPath: c.req.url.indexOf("/v1/proxy") > -1 ? "/v1/proxy" : "/v1"
     }
 
-    let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath);
+    const customHost = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || "";
+
+    let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost);
 
     let fetchOptions = {
         headers: headersToSend(requestHeaders, store.customHeadersToAvoid),
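Reviewer note: a worked example of the new `customHost` branch in `getProxyPath` (values illustrative) — the gateway path prefix is stripped and the remainder, query string included, is appended to the custom host verbatim:

```ts
//   request:    https://gateway.example.com/v1/proxy/api/chat?stream=true
//   customHost: http://localhost:11434
//   result:     http://localhost:11434/api/chat?stream=true
const url = new URL("https://gateway.example.com/v1/proxy/api/chat?stream=true");
const path = url.pathname.replace("/v1/proxy", "");
const urlToFetch = `http://localhost:11434${path}${url.search}`;
console.log(urlToFetch); // http://localhost:11434/api/chat?stream=true
```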
diff --git a/src/handlers/proxyHandler.ts b/src/handlers/proxyHandler.ts
index d806f5c9f..7cfadfe2d 100644
--- a/src/handlers/proxyHandler.ts
+++ b/src/handlers/proxyHandler.ts
@@ -1,7 +1,7 @@
 import { Context } from "hono";
 import { retryRequest } from "./retryHandler";
 import Providers from "../providers";
-import { ANTHROPIC, MAX_RETRIES, HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, CONTENT_TYPES } from "../globals";
+import { ANTHROPIC, MAX_RETRIES, HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, CONTENT_TYPES, OLLAMA } from "../globals";
 import { fetchProviderOptionsFromConfig, responseHandler, tryProvidersInSequence, updateResponseHeaders } from "./handlerUtils";
 import { convertKeysToCamelCase, getStreamingMode } from "../utils";
 import { Config, ShortConfig } from "../types/requestBody";
@@ -12,15 +12,24 @@ function proxyProvider(proxyModeHeader:string, providerHeader: string) {
     return proxyProvider;
 }
 
-function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string) {
+function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string, customHost: string) {
     let reqURL = new URL(requestURL);
     let reqPath = reqURL.pathname;
     const reqQuery = reqURL.search;
     reqPath = reqPath.replace(proxyEndpointPath, "");
+
+    if (customHost) {
+        return `${customHost}${reqPath}${reqQuery}`
+    }
+
     const providerBasePath = Providers[proxyProvider].api.baseURL;
     if (proxyProvider === AZURE_OPEN_AI) {
         return `https:/${reqPath}${reqQuery}`;
     }
+
+    if (proxyProvider === OLLAMA) {
+        return `https:/${reqPath}`;
+    }
 
     let proxyPath = `${providerBasePath}${reqPath}${reqQuery}`;
 
     // Fix specific for Anthropic SDK calls. Is this needed? - Yes
@@ -90,7 +99,8 @@ export async function proxyHandler(c: Context): Promise<Response> {
         }
     };
 
-    let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath);
+    const customHost = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || requestConfig?.customHost || "";
+    let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost);
 
     store.isStreamingMode = getStreamingMode(store.reqBody, store.proxyProvider, urlToFetch)
 
     if (requestConfig &&
diff --git a/src/handlers/streamHandler.ts b/src/handlers/streamHandler.ts
index 798fb15fa..19cca43ba 100644
--- a/src/handlers/streamHandler.ts
+++ b/src/handlers/streamHandler.ts
@@ -1,4 +1,4 @@
-import { AZURE_OPEN_AI, CONTENT_TYPES, COHERE, GOOGLE, REQUEST_TIMEOUT_STATUS_CODE } from "../globals";
+import { AZURE_OPEN_AI, CONTENT_TYPES, COHERE, GOOGLE, REQUEST_TIMEOUT_STATUS_CODE, OLLAMA } from "../globals";
 import { OpenAIChatCompleteResponse } from "../providers/openai/chatComplete";
 import { OpenAICompleteResponse } from "../providers/openai/complete";
 import { getStreamModeSplitPattern } from "../utils";
diff --git a/src/middlewares/requestValidator/index.ts b/src/middlewares/requestValidator/index.ts
index 0b9ecc380..b9f1340ee 100644
--- a/src/middlewares/requestValidator/index.ts
+++ b/src/middlewares/requestValidator/index.ts
@@ -15,6 +15,7 @@ import {
     DEEPINFRA,
     STABILITY_AI,
     NOMIC,
+    OLLAMA,
 } from "../../globals";
 import { configSchema } from "./schema/config";
 
@@ -63,7 +64,7 @@ export const requestValidator = (c: Context, next: any) => {
     }
     if (
         requestHeaders[`x-${POWERED_BY}-provider`] &&
-        ![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, NOMIC, STABILITY_AI].includes(
+        ![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, NOMIC, STABILITY_AI, OLLAMA].includes(
             requestHeaders[`x-${POWERED_BY}-provider`]
         )
     ) {
@@ -81,6 +82,22 @@ export const requestValidator = (c: Context, next: any) => {
         );
     }
 
+    const customHostHeader = requestHeaders[`x-${POWERED_BY}-custom-host`];
+    if (customHostHeader && customHostHeader.indexOf("api.portkey") > -1) {
+        return new Response(
+            JSON.stringify({
+                status: "failure",
+                message: `Invalid custom host`,
+            }),
+            {
+                status: 400,
+                headers: {
+                    "content-type": "application/json",
+                },
+            }
+        );
+    }
+
     if (requestHeaders[`x-${POWERED_BY}-config`]) {
         try {
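Reviewer note: the new guard above rejects any custom host that points back at Portkey's own API, failing fast with a 400 before any provider call is made. A quick illustration of the check:

```ts
// Same substring check as the validator (illustrative standalone copy).
const isInvalidCustomHost = (host: string) => host.indexOf("api.portkey") > -1;

console.log(isInvalidCustomHost("http://localhost:11434"));    // false -> allowed
console.log(isInvalidCustomHost("https://api.portkey.ai/v1")); // true  -> 400 response
```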
diff --git a/src/middlewares/requestValidator/schema/config.ts b/src/middlewares/requestValidator/schema/config.ts
index 8e3e60801..377f8a41a 100644
--- a/src/middlewares/requestValidator/schema/config.ts
+++ b/src/middlewares/requestValidator/schema/config.ts
@@ -12,6 +12,7 @@ import {
     DEEPINFRA,
     NOMIC,
     STABILITY_AI,
+    OLLAMA
 } from "../../../globals";
 
 export const configSchema: any = z
@@ -49,7 +50,8 @@
                     MISTRAL_AI,
                     DEEPINFRA,
                     NOMIC,
-                    STABILITY_AI
+                    STABILITY_AI,
+                    OLLAMA
                 ].includes(value),
             {
                 message:
@@ -85,6 +87,8 @@
         on_status_codes: z.array(z.number()).optional(),
         targets: z.array(z.lazy(() => configSchema)).optional(),
         request_timeout: z.number().optional(),
+        custom_host: z.string().optional(),
+        forward_headers: z.array(z.string()).optional()
     })
     .refine(
         (value) => {
@@ -92,16 +96,32 @@
                 value.provider !== undefined && value.api_key !== undefined;
             const hasModeTargets =
                 value.strategy !== undefined && value.targets !== undefined;
+            const isOllamaProvider = value.provider === OLLAMA;
+
             return (
                 hasProviderApiKey ||
                 hasModeTargets ||
                 value.cache ||
                 value.retry ||
-                value.request_timeout
+                value.request_timeout ||
+                isOllamaProvider
             );
         },
         {
             message:
                 "Invalid configuration. It must have either 'provider' and 'api_key', or 'strategy' and 'targets', or 'cache', or 'retry', or 'request_timeout'",
         }
+    )
+    .refine(
+        (value) => {
+            const customHost = value.custom_host;
+            if (customHost && (customHost.indexOf("api.portkey") > -1)) {
+                return false;
+            }
+            return true;
+        },
+        {
+            message:
+                "Invalid custom host",
+        }
     );
diff --git a/src/providers/cohere/chatComplete.ts b/src/providers/cohere/chatComplete.ts
index cc3e9d311..8f2011736 100644
--- a/src/providers/cohere/chatComplete.ts
+++ b/src/providers/cohere/chatComplete.ts
@@ -136,7 +136,7 @@ export const CohereChatCompleteStreamChunkTransform: (response: string, fallback
 
     return `data: ${JSON.stringify({
         id: parsedChunk.id ?? fallbackId,
-        object: "text_completion",
+        object: "chat.completion.chunk",
         created: Math.floor(Date.now() / 1000),
         model: "",
         provider: COHERE,
diff --git a/src/providers/google/chatComplete.ts b/src/providers/google/chatComplete.ts
index 0ab240681..db48e2d3f 100644
--- a/src/providers/google/chatComplete.ts
+++ b/src/providers/google/chatComplete.ts
@@ -17,9 +17,6 @@ const transformGenerationConfig = (params: Params) => {
     if (params["top_k"]) {
         generationConfig["topK"] = params["top_k"];
     }
-    if (params["top_k"]) {
-        generationConfig["topK"] = params["top_k"];
-    }
     if (params["max_tokens"]) {
         generationConfig["maxOutputTokens"] = params["max_tokens"];
     }
diff --git a/src/providers/index.ts b/src/providers/index.ts
index 8f7e76af5..d50fd9afd 100644
--- a/src/providers/index.ts
+++ b/src/providers/index.ts
@@ -11,6 +11,7 @@ import PalmAIConfig from "./palm";
 import PerplexityAIConfig from "./perplexity-ai";
 import TogetherAIConfig from "./together-ai";
 import StabilityAIConfig from "./stability-ai";
+import OllamaConfig from "./ollama";
 import { ProviderConfigs } from "./types";
 
 const Providers: { [key: string]: ProviderConfigs } = {
@@ -26,7 +27,8 @@ const Providers: { [key: string]: ProviderConfigs } = {
     'mistral-ai': MistralAIConfig,
     'deepinfra': DeepInfraConfig,
     'stability-ai': StabilityAIConfig,
-    nomic: NomicConfig
+    nomic: NomicConfig,
+    'ollama': OllamaConfig
 };
 
 export default Providers;
diff --git a/src/providers/ollama/api.ts b/src/providers/ollama/api.ts
new file mode 100644
index 000000000..5521348ed
--- /dev/null
+++ b/src/providers/ollama/api.ts
@@ -0,0 +1,28 @@
+import { ProviderAPIConfig } from "../types";
+
+const OllamaAPIConfig: ProviderAPIConfig = {
+    headers: () => {
+        return {};
+    },
+    chatComplete: "/v1/chat/completions",
+    embed: "/api/embeddings",
+    getEndpoint: (fn: string, url?: string) => {
+        let mappedFn = fn;
+        if (fn === "proxy" && url && url?.indexOf("/api/chat") > -1) {
+            mappedFn = "chatComplete";
+        } else if (fn === "proxy" && url && url?.indexOf("/embeddings") > -1) {
+            mappedFn = "embed";
+        }
+
+        switch (mappedFn) {
+            case "chatComplete": {
+                return `/v1/chat/completions`;
+            }
+            case "embed": {
+                return `/api/embeddings`;
+            }
+        }
+    },
+};
+
+export default OllamaAPIConfig;
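Reviewer note: `getEndpoint` above routes chat to Ollama's OpenAI-compatible route and embeddings to its native route, re-mapping proxy traffic by inspecting the incoming URL. A standalone re-statement of that mapping (not an import of the config above):

```ts
// Chat -> OpenAI-compatible route; embeddings -> native Ollama route.
const ollamaEndpoint = (fn: string, url?: string): string | undefined => {
    let mappedFn = fn;
    if (fn === "proxy" && url?.includes("/api/chat")) mappedFn = "chatComplete";
    else if (fn === "proxy" && url?.includes("/embeddings")) mappedFn = "embed";
    if (mappedFn === "chatComplete") return "/v1/chat/completions";
    if (mappedFn === "embed") return "/api/embeddings";
};

console.log(ollamaEndpoint("proxy", "http://localhost:11434/api/chat")); // "/v1/chat/completions"
console.log(ollamaEndpoint("embed"));                                    // "/api/embeddings"
```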
diff --git a/src/providers/ollama/chatComplete.ts b/src/providers/ollama/chatComplete.ts
new file mode 100644
index 000000000..36ab67b05
--- /dev/null
+++ b/src/providers/ollama/chatComplete.ts
@@ -0,0 +1,130 @@
+import {
+    ChatCompletionResponse,
+    ErrorResponse,
+    ProviderConfig,
+} from "../types";
+import { OLLAMA } from "../../globals";
+
+export const OllamaChatCompleteConfig: ProviderConfig = {
+    model: {
+        param: "model",
+        required: true,
+        default: "llama2",
+    },
+    messages: {
+        param: "messages",
+        default: "",
+    },
+    frequency_penalty: {
+        param: "frequency_penalty",
+        min: -2,
+        max: 2,
+    },
+    presence_penalty: {
+        param: "presence_penalty",
+        min: -2,
+        max: 2,
+    },
+    response_format: {
+        param: "response_format",
+    },
+    seed: {
+        param: "seed",
+    },
+    stop: {
+        param: "stop",
+    },
+    stream: {
+        param: "stream",
+        default: false,
+    },
+    temperature: {
+        param: "temperature",
+        default: 1,
+        min: 0,
+        max: 2,
+    },
+    top_p: {
+        param: "top_p",
+        default: 1,
+        min: 0,
+        max: 1,
+    },
+    max_tokens: {
+        param: "max_tokens",
+        default: 100,
+        min: 0,
+    },
+};
+
+export interface OllamaChatCompleteResponse
+    extends ChatCompletionResponse,
+        ErrorResponse {
+    system_fingerprint: string;
+}
+
+export interface OllamaStreamChunk {
+    id: string;
+    object: string;
+    created: number;
+    model: string;
+    system_fingerprint: string;
+    choices: {
+        delta: {
+            role: string;
+            content?: string;
+        };
+        index: number;
+        finish_reason: string | null;
+    }[];
+}
+
+export const OllamaChatCompleteResponseTransform: (
+    response: OllamaChatCompleteResponse,
+    responseStatus: number
+) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
+    if (responseStatus !== 200) {
+        return {
+            error: {
+                message: response.error?.message,
+                type: response.error?.type,
+                param: null,
+                code: null,
+            },
+            provider: OLLAMA,
+        } as ErrorResponse;
+    }
+
+    return {
+        id: response.id,
+        object: response.object,
+        created: response.created,
+        model: response.model,
+        provider: OLLAMA,
+        choices: response.choices,
+        usage: response.usage,
+    };
+};
+
+export const OllamaChatCompleteStreamChunkTransform: (
+    response: string
+) => string = (responseChunk) => {
+    let chunk = responseChunk.trim();
+    chunk = chunk.replace(/^data: /, "");
+    chunk = chunk.trim();
+    if (chunk === "[DONE]") {
+        return `data: ${chunk}\n\n`;
+    }
+    const parsedChunk: OllamaStreamChunk = JSON.parse(chunk);
+    return (
+        `data: ${JSON.stringify({
+            id: parsedChunk.id,
+            object: parsedChunk.object,
+            created: parsedChunk.created,
+            model: parsedChunk.model,
+            provider: OLLAMA,
+            choices: parsedChunk.choices,
+        })}` + "\n\n"
+    );
};
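Reviewer note: for review context, here is what the stream transform does to a single SSE line — the chunk payload below is fabricated for illustration, in the OpenAI-compatible shape Ollama's `/v1/chat/completions` route emits when streaming:

```ts
const incoming =
    'data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":1700000000,' +
    '"model":"llama2","system_fingerprint":"fp_0","choices":[{"delta":{"role":"assistant",' +
    '"content":"Hi"},"index":0,"finish_reason":null}]}';

// OllamaChatCompleteStreamChunkTransform re-serializes it with a provider tag:
// 'data: {"id":"chatcmpl-1",...,"provider":"ollama","choices":[...]}\n\n'
// A terminal "data: [DONE]" line is passed through unchanged.
```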
diff --git a/src/providers/ollama/embed.ts b/src/providers/ollama/embed.ts
new file mode 100644
index 000000000..6566104fd
--- /dev/null
+++ b/src/providers/ollama/embed.ts
@@ -0,0 +1,68 @@
+import { OLLAMA } from "../../globals";
+import { EmbedResponse } from "../../types/embedRequestBody";
+import { ErrorResponse, ProviderConfig } from "../types";
+
+// TODO: this configuration does not enforce the maximum token limit for the
+// input parameter. Enforcing it would require a custom validation function or
+// a max property on the ParameterConfig interface, used in the input
+// configuration. This is non-trivial because the token count is not a simple
+// length check but depends on the tokenization method used by the model.
+
+export const OllamaEmbedConfig: ProviderConfig = {
+    model: {
+        param: "model",
+    },
+    input: {
+        param: "prompt",
+        required: true,
+    }
+};
+
+interface OllamaEmbedResponse extends EmbedResponse {
+    embedding: number[];
+}
+
+interface OllamaErrorResponse {
+    error: string;
+}
+
+export const OllamaEmbedResponseTransform: (
+    response: OllamaEmbedResponse | OllamaErrorResponse,
+    responseStatus: number
+) => EmbedResponse | ErrorResponse = (response, responseStatus) => {
+    if ("error" in response) {
+        return {
+            error: {
+                message: response.error,
+                type: null,
+                param: null,
+                code: null,
+            },
+            provider: OLLAMA,
+        } as ErrorResponse;
+    }
+    if ("embedding" in response) {
+        return {
+            object: "list",
+            data: [
+                {
+                    object: "embedding",
+                    embedding: response.embedding,
+                    index: 0,
+                },
+            ],
+            model: "",
+            usage: {
+                prompt_tokens: -1,
+                total_tokens: -1,
+            },
+        };
+    }
+    return {
+        error: {
+            message: `Invalid response received from ${OLLAMA}: ${JSON.stringify(
+                response
+            )}`,
+            type: null,
+            param: null,
+            code: null,
+        },
+        provider: OLLAMA,
+    } as ErrorResponse;
+};
diff --git a/src/providers/ollama/index.ts b/src/providers/ollama/index.ts
new file mode 100644
index 000000000..aaa3ccc26
--- /dev/null
+++ b/src/providers/ollama/index.ts
@@ -0,0 +1,18 @@
+import { ProviderConfigs } from "../types";
+import { OllamaEmbedConfig, OllamaEmbedResponseTransform } from "./embed";
+import OllamaAPIConfig from "./api";
+import { OllamaChatCompleteConfig, OllamaChatCompleteResponseTransform, OllamaChatCompleteStreamChunkTransform } from "./chatComplete";
+
+const OllamaConfig: ProviderConfigs = {
+    embed: OllamaEmbedConfig,
+    api: OllamaAPIConfig,
+    chatComplete: OllamaChatCompleteConfig,
+    responseTransforms: {
+        chatComplete: OllamaChatCompleteResponseTransform,
+        'stream-chatComplete': OllamaChatCompleteStreamChunkTransform,
+        embed: OllamaEmbedResponseTransform
+    }
+};
+
+export default OllamaConfig;
diff --git a/src/types/requestBody.ts b/src/types/requestBody.ts
index 18793b1fd..63688c4c9 100644
--- a/src/types/requestBody.ts
+++ b/src/types/requestBody.ts
@@ -43,6 +43,10 @@ export interface Options {
     deploymentId?: string;
     apiVersion?: string;
     adAuth?:string;
+    /** The parameter to set a custom base URL */
+    customHost?: string;
+    /** The list of headers to be forwarded as-is to the provider */
+    forwardHeaders?: string[];
     /** provider option index picked based on weight in loadbalance mode */
     index?: number;
     cache?: CacheSettings | string;
@@ -94,6 +98,7 @@ export interface Config {
     cache?: CacheSettings;
     retry?: RetrySettings;
     strategy?: Strategy;
+    customHost?: string;
 }
 
 /**
@@ -230,6 +235,7 @@ export interface ShortConfig {
     resourceName?: string;
     deploymentId?: string;
     apiVersion?: string;
+    customHost?: string;
 }
 
 /**
diff --git a/src/utils.ts b/src/utils.ts
index 3d6068729..b8e8cdbc4 100644
--- a/src/utils.ts
+++ b/src/utils.ts
@@ -1,4 +1,4 @@
-import { ANTHROPIC, COHERE, GOOGLE, PERPLEXITY_AI, DEEPINFRA } from "./globals";
+import { ANTHROPIC, COHERE, GOOGLE, PERPLEXITY_AI, DEEPINFRA, OLLAMA } from "./globals";
 import { Params } from "./types/requestBody";
 
 export const getStreamModeSplitPattern = (proxyProvider: string, requestURL: string) => {