Skip to content

Commit

Permalink
Merge pull request #194 from Portkey-AI/feat/ollama-integration
Browse files Browse the repository at this point in the history
Feat/ollama integration
  • Loading branch information
VisargD authored Feb 17, 2024
2 parents 70acd29 + e4e560a commit 0047d92
Show file tree
Hide file tree
Showing 17 changed files with 405 additions and 51 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,6 @@ dist
# Rollup build dir
build
.DS_Store

# Wrangler temp directory
.wrangler
5 changes: 4 additions & 1 deletion src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ export const HEADER_KEYS: Record<string, string> = {
RETRIES: `x-${POWERED_BY}-retry-count`,
PROVIDER: `x-${POWERED_BY}-provider`,
TRACE_ID: `x-${POWERED_BY}-trace-id`,
CACHE: `x-${POWERED_BY}-cache`
CACHE: `x-${POWERED_BY}-cache`,
FORWARD_HEADERS: `x-${POWERED_BY}-forward-headers`,
CUSTOM_HOST: `x-${POWERED_BY}-custom-host`
}

export const RESPONSE_HEADER_KEYS: Record<string, string> = {
Expand Down Expand Up @@ -33,6 +35,7 @@ export const MISTRAL_AI: string = "mistral-ai";
export const DEEPINFRA: string = "deepinfra";
export const STABILITY_AI: string = "stability-ai";
export const NOMIC: string = "nomic";
export const OLLAMA: string = "ollama";

export const providersWithStreamingSupport = [OPEN_AI, AZURE_OPEN_AI, ANTHROPIC, COHERE];
export const allowedProxyProviders = [OPEN_AI, COHERE, AZURE_OPEN_AI, ANTHROPIC];
Expand Down
115 changes: 80 additions & 35 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Context } from "hono";
import { AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals";
import { ANTHROPIC, AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, OLLAMA, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals";
import Providers from "../providers";
import { ProviderAPIConfig, endpointStrings } from "../providers/types";
import transformToProviderRequest from "../services/transformToProviderRequest";
Expand All @@ -19,13 +19,23 @@ import { OpenAICompleteJSONToStreamResponseTransform } from "../providers/openai
* @param {string} method - The HTTP method for the request.
* @returns {RequestInit} - The fetch options for the request.
*/
export function constructRequest(headers: any, provider: string = "", method: string = "POST") {
export function constructRequest(providerConfigMappedHeaders: any, provider: string, method: string, forwardHeaders: string[], requestHeaders: Record<string, string>) {
let baseHeaders: any = {
"content-type": "application/json"
};

let headers: Record<string, string> = {
...providerConfigMappedHeaders
};

const forwardHeadersMap: Record<string, string> = {};

forwardHeaders.forEach((h: string) => {
if (requestHeaders[h]) forwardHeadersMap[h] = requestHeaders[h];
})

// Add any headers that the model might need
headers = {...baseHeaders, ...headers}
headers = {...baseHeaders, ...headers, ...forwardHeadersMap}

let fetchOptions: RequestInit = {
method,
Expand Down Expand Up @@ -122,7 +132,8 @@ export const fetchProviderOptionsFromConfig = (config: Config | ShortConfig): Op
virtualKey: camelCaseConfig.virtualKey,
apiKey: camelCaseConfig.apiKey,
cache: camelCaseConfig.cache,
retry: camelCaseConfig.retry
retry: camelCaseConfig.retry,
customHost: camelCaseConfig.customHost
}];
if (camelCaseConfig.resourceName) providerOptions[0].resourceName = camelCaseConfig.resourceName;
if (camelCaseConfig.deploymentId) providerOptions[0].deploymentId = camelCaseConfig.deploymentId;
Expand Down Expand Up @@ -162,40 +173,48 @@ export async function tryPostProxy(c: Context, providerOption:Options, inputPara
const apiConfig: ProviderAPIConfig = Providers[provider].api;
let fetchOptions;
let url = providerOption.urlToFetch as string;

let baseUrl:string, endpoint:string;

const forwardHeaders: string[] = [];
baseUrl = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || "";

if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) {
// Construct the base object for the request
if(!!providerOption.apiKey) {
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, method);
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, method, forwardHeaders, requestHeaders);
} else {
fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, method);
fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, method, forwardHeaders, requestHeaders);
}
baseUrl = apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
baseUrl = baseUrl || apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
endpoint = apiConfig.getEndpoint(fn, providerOption.apiVersion, url);
url = `${baseUrl}${endpoint}`;
} else if (provider === PALM && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider, method);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, method, forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params?.model);
url = `${baseUrl}${endpoint}`;
} else if (provider === "anthropic" && apiConfig.baseURL) {
} else if (provider === ANTHROPIC && apiConfig.baseURL) {
// Construct the base object for the POST request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig[fn] || "";
url = `${baseUrl}${endpoint}`;
} else if (provider === GOOGLE && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params.model, params.stream);
url = `${baseUrl}${endpoint}`;
} else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, params.model, url);
url = `${baseUrl}${endpoint}`;
} else {
// Construct the base object for the request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method);
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method, forwardHeaders, requestHeaders);
}

if (method === "POST") {
fetchOptions.body = JSON.stringify(params)
}
Expand Down Expand Up @@ -314,37 +333,45 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P


let baseUrl:string, endpoint:string, fetchOptions;
const forwardHeaders = requestHeaders[HEADER_KEYS.FORWARD_HEADERS]?.split(",").map(h => h.trim()) || providerOption.forwardHeaders || [];
baseUrl = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || "";

if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) {
// Construct the base object for the POST request
if(!!providerOption.apiKey) {
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider);
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, "POST", forwardHeaders, requestHeaders);
} else {
fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider);
fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, "POST", forwardHeaders, requestHeaders);
}
baseUrl = apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
baseUrl = baseUrl || apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
endpoint = apiConfig.getEndpoint(fn, providerOption.apiVersion);
} else if (provider === PALM && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, providerOption.overrideParams?.model || params?.model);
} else if (provider === "anthropic" && apiConfig.baseURL) {
} else if (provider === ANTHROPIC && apiConfig.baseURL) {
// Construct the base object for the POST request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig[fn] || "";
} else if (provider === GOOGLE && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream);
} else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, params.model);
} else {
} else if (provider === OLLAMA && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream);
}
else {
// Construct the base object for the POST request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider);
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, "POST", forwardHeaders, requestHeaders);

baseUrl = apiConfig.baseURL || "";
baseUrl = baseUrl || apiConfig.baseURL || "";
endpoint = apiConfig[fn] || "";
}

Expand All @@ -358,7 +385,7 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P

providerOption.retry = {
attempts: providerOption.retry?.attempts ?? 0,
onStatusCodes: providerOption.retry?.onStatusCodes ?? []
onStatusCodes: providerOption.retry?.onStatusCodes ?? RETRY_STATUS_CODES
}

const [getFromCacheFunction, cacheIdentifier, requestOptions] = [
Expand Down Expand Up @@ -543,7 +570,7 @@ export async function tryTargetsRecursively(
const strategyMode = currentTarget.strategy?.mode;

// start: merge inherited config with current target config (preference given to current)
const currentInheritedConfig = {
const currentInheritedConfig: Record<string, any> = {
overrideParams : {
...inheritedConfig.overrideParams,
...currentTarget.overrideParams
Expand All @@ -553,11 +580,29 @@ export async function tryTargetsRecursively(
requestTimeout: null
}

if (currentTarget.forwardHeaders) {
currentInheritedConfig.forwardHeaders = [...currentTarget.forwardHeaders];
} else if (inheritedConfig.forwardHeaders) {
currentInheritedConfig.forwardHeaders = [...inheritedConfig.forwardHeaders];
currentTarget.forwardHeaders = [...inheritedConfig.forwardHeaders];
}

if (currentTarget.customHost) {
currentInheritedConfig.customHost = currentTarget.customHost
} else if (inheritedConfig.customHost) {
currentInheritedConfig.customHost = inheritedConfig.customHost;
currentTarget.customHost = inheritedConfig.customHost;
}

if (currentTarget.requestTimeout) {
currentInheritedConfig.requestTimeout = currentTarget.requestTimeout
} else if (inheritedConfig.requestTimeout) {
currentInheritedConfig.requestTimeout = inheritedConfig.requestTimeout;
currentTarget.requestTimeout = inheritedConfig.requestTimeout;
}



currentTarget.overrideParams = {
...currentInheritedConfig.overrideParams
}
Expand Down
11 changes: 9 additions & 2 deletions src/handlers/proxyGetHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@ function proxyProvider(proxyModeHeader:string, providerHeader: string) {
return proxyProvider;
}

function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string) {
function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string, customHost: string) {
let reqURL = new URL(requestURL);
let reqPath = reqURL.pathname;
const reqQuery = reqURL.search;
reqPath = reqPath.replace(proxyEndpointPath, "");

if (customHost) {
return `${customHost}${reqPath}${reqQuery}`
}

const providerBasePath = Providers[proxyProvider].api.baseURL;
if (proxyProvider === AZURE_OPEN_AI) {
return `https:/${reqPath}${reqQuery}`;
Expand Down Expand Up @@ -55,7 +60,9 @@ export async function proxyGetHandler(c: Context): Promise<Response> {
proxyPath: c.req.url.indexOf("/v1/proxy") > -1 ? "/v1/proxy" : "/v1"
}

let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath);
const customHost = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || "";

let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost);

let fetchOptions = {
headers: headersToSend(requestHeaders, store.customHeadersToAvoid),
Expand Down
16 changes: 13 additions & 3 deletions src/handlers/proxyHandler.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Context } from "hono";
import { retryRequest } from "./retryHandler";
import Providers from "../providers";
import { ANTHROPIC, MAX_RETRIES, HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, CONTENT_TYPES } from "../globals";
import { ANTHROPIC, MAX_RETRIES, HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, CONTENT_TYPES, OLLAMA } from "../globals";
import { fetchProviderOptionsFromConfig, responseHandler, tryProvidersInSequence, updateResponseHeaders } from "./handlerUtils";
import { convertKeysToCamelCase, getStreamingMode } from "../utils";
import { Config, ShortConfig } from "../types/requestBody";
Expand All @@ -12,15 +12,24 @@ function proxyProvider(proxyModeHeader:string, providerHeader: string) {
return proxyProvider;
}

function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string) {
function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string, customHost: string) {
let reqURL = new URL(requestURL);
let reqPath = reqURL.pathname;
const reqQuery = reqURL.search;
reqPath = reqPath.replace(proxyEndpointPath, "");

if (customHost) {
return `${customHost}${reqPath}${reqQuery}`
}

const providerBasePath = Providers[proxyProvider].api.baseURL;
if (proxyProvider === AZURE_OPEN_AI) {
return `https:/${reqPath}${reqQuery}`;
}

if (proxyProvider === OLLAMA) {
return `https:/${reqPath}`;
}
let proxyPath = `${providerBasePath}${reqPath}${reqQuery}`;

// Fix specific for Anthropic SDK calls. Is this needed? - Yes
Expand Down Expand Up @@ -90,7 +99,8 @@ export async function proxyHandler(c: Context): Promise<Response> {
}
};

let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath);
const customHost = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || requestConfig?.customHost || "";
let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost);
store.isStreamingMode = getStreamingMode(store.reqBody, store.proxyProvider, urlToFetch)

if (requestConfig &&
Expand Down
2 changes: 1 addition & 1 deletion src/handlers/streamHandler.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AZURE_OPEN_AI, CONTENT_TYPES, COHERE, GOOGLE, REQUEST_TIMEOUT_STATUS_CODE } from "../globals";
import { AZURE_OPEN_AI, CONTENT_TYPES, COHERE, GOOGLE, REQUEST_TIMEOUT_STATUS_CODE, OLLAMA } from "../globals";
import { OpenAIChatCompleteResponse } from "../providers/openai/chatComplete";
import { OpenAICompleteResponse } from "../providers/openai/complete";
import { getStreamModeSplitPattern } from "../utils";
Expand Down
19 changes: 18 additions & 1 deletion src/middlewares/requestValidator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
DEEPINFRA,
STABILITY_AI,
NOMIC,
OLLAMA,
} from "../../globals";
import { configSchema } from "./schema/config";

Expand Down Expand Up @@ -63,7 +64,7 @@ export const requestValidator = (c: Context, next: any) => {
}
if (
requestHeaders[`x-${POWERED_BY}-provider`] &&
![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, NOMIC, STABILITY_AI].includes(
![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, NOMIC, STABILITY_AI, OLLAMA].includes(
requestHeaders[`x-${POWERED_BY}-provider`]
)
) {
Expand All @@ -81,6 +82,22 @@ export const requestValidator = (c: Context, next: any) => {
);
}

const customHostHeader = requestHeaders[`x-${POWERED_BY}-custom-host`];
if (customHostHeader && customHostHeader.indexOf("api.portkey") > -1) {
return new Response(
JSON.stringify({
status: "failure",
message: `Invalid custom host`,
}),
{
status: 400,
headers: {
"content-type": "application/json",
},
}
);;
}


if (requestHeaders[`x-${POWERED_BY}-config`]) {
try {
Expand Down
Loading

0 comments on commit 0047d92

Please sign in to comment.