Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/ollama integration #194

Merged
merged 27 commits into from
Feb 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
3582c17
completion without stream in progress
Jan 31, 2024
c8bad18
ollama completion with and without stream integrated
Feb 1, 2024
4932faf
ollama chat with and without stream integrated + small fix in cohere …
Feb 1, 2024
cf8f223
Ollama embed integration + ASK: baseUrl for Ollama
Feb 2, 2024
139b8da
proxy path URL handled
Feb 5, 2024
2b235b6
feat: chatCompletion from openai compatible, embedding from native wa…
csgulati09 Feb 13, 2024
25933c4
chore: merged main on feat/ollama-integration
csgulati09 Feb 13, 2024
6c87e25
merged main on feat/ollama-integration
csgulati09 Feb 14, 2024
40aff7f
fix: removed unwanted comments
csgulati09 Feb 14, 2024
106da25
fix: removed unused imports, getEndpoint to /api/chat, removed Ollama…
csgulati09 Feb 14, 2024
915cf72
feat: add new header constants
VisargD Feb 15, 2024
f4675a0
feat: add custom host header validator
VisargD Feb 15, 2024
2a2eb6c
feat: add custom_host and forward_headers config schema props
VisargD Feb 15, 2024
67c24cb
fix: return empty object instead of null for ollama headers
VisargD Feb 15, 2024
b24160b
feat: add customHost and forwardHeaders to interface
VisargD Feb 15, 2024
51c4d56
feat: support custom host in proxy routes
VisargD Feb 15, 2024
c3ca279
feat: support custom host and forward headers for unified routes
VisargD Feb 15, 2024
cff9c61
Merge branch 'main' into feat/custom-host-and-forward-headers
VisargD Feb 16, 2024
59d80eb
fix: custom host config schema
VisargD Feb 16, 2024
6aa7124
fix: forward headers config inherit logic
VisargD Feb 16, 2024
9936f34
Merge branch 'feat/ollama-integration' into feat/custom-host-and-forw…
VisargD Feb 16, 2024
455ef4f
Merge branch 'main' into feat/ollama-integration
VisargD Feb 16, 2024
9b1c507
Merge branch 'feat/ollama-integration' into feat/custom-host-and-forw…
VisargD Feb 16, 2024
45db52f
fix: change custom host selection preference
VisargD Feb 17, 2024
2a25274
fix: custom host based url creation for proxy routes
VisargD Feb 17, 2024
0a84993
Merge pull request #218 from Portkey-AI/feat/custom-host-and-forward-…
VisargD Feb 17, 2024
e4e560a
Merge branch 'main' into feat/ollama-integration
VisargD Feb 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,6 @@ dist
# Rollup build dir
build
.DS_Store

# Wrangler temp directory
.wrangler
5 changes: 4 additions & 1 deletion src/globals.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ export const HEADER_KEYS: Record<string, string> = {
RETRIES: `x-${POWERED_BY}-retry-count`,
PROVIDER: `x-${POWERED_BY}-provider`,
TRACE_ID: `x-${POWERED_BY}-trace-id`,
CACHE: `x-${POWERED_BY}-cache`
CACHE: `x-${POWERED_BY}-cache`,
FORWARD_HEADERS: `x-${POWERED_BY}-forward-headers`,
CUSTOM_HOST: `x-${POWERED_BY}-custom-host`
}

export const RESPONSE_HEADER_KEYS: Record<string, string> = {
Expand Down Expand Up @@ -33,6 +35,7 @@ export const MISTRAL_AI: string = "mistral-ai";
export const DEEPINFRA: string = "deepinfra";
export const STABILITY_AI: string = "stability-ai";
export const NOMIC: string = "nomic";
export const OLLAMA: string = "ollama";

export const providersWithStreamingSupport = [OPEN_AI, AZURE_OPEN_AI, ANTHROPIC, COHERE];
export const allowedProxyProviders = [OPEN_AI, COHERE, AZURE_OPEN_AI, ANTHROPIC];
Expand Down
115 changes: 80 additions & 35 deletions src/handlers/handlerUtils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Context } from "hono";
import { AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals";
import { ANTHROPIC, AZURE_OPEN_AI, CONTENT_TYPES, GOOGLE, HEADER_KEYS, OLLAMA, PALM, POWERED_BY, RESPONSE_HEADER_KEYS, RETRY_STATUS_CODES, STABILITY_AI } from "../globals";
import Providers from "../providers";
import { ProviderAPIConfig, endpointStrings } from "../providers/types";
import transformToProviderRequest from "../services/transformToProviderRequest";
Expand All @@ -19,13 +19,23 @@ import { OpenAICompleteJSONToStreamResponseTransform } from "../providers/openai
* @param {string} method - The HTTP method for the request.
* @returns {RequestInit} - The fetch options for the request.
*/
export function constructRequest(headers: any, provider: string = "", method: string = "POST") {
export function constructRequest(providerConfigMappedHeaders: any, provider: string, method: string, forwardHeaders: string[], requestHeaders: Record<string, string>) {
let baseHeaders: any = {
"content-type": "application/json"
};

let headers: Record<string, string> = {
...providerConfigMappedHeaders
};

const forwardHeadersMap: Record<string, string> = {};

forwardHeaders.forEach((h: string) => {
if (requestHeaders[h]) forwardHeadersMap[h] = requestHeaders[h];
})

// Add any headers that the model might need
headers = {...baseHeaders, ...headers}
headers = {...baseHeaders, ...headers, ...forwardHeadersMap}

let fetchOptions: RequestInit = {
method,
Expand Down Expand Up @@ -122,7 +132,8 @@ export const fetchProviderOptionsFromConfig = (config: Config | ShortConfig): Op
virtualKey: camelCaseConfig.virtualKey,
apiKey: camelCaseConfig.apiKey,
cache: camelCaseConfig.cache,
retry: camelCaseConfig.retry
retry: camelCaseConfig.retry,
customHost: camelCaseConfig.customHost
}];
if (camelCaseConfig.resourceName) providerOptions[0].resourceName = camelCaseConfig.resourceName;
if (camelCaseConfig.deploymentId) providerOptions[0].deploymentId = camelCaseConfig.deploymentId;
Expand Down Expand Up @@ -162,40 +173,48 @@ export async function tryPostProxy(c: Context, providerOption:Options, inputPara
const apiConfig: ProviderAPIConfig = Providers[provider].api;
let fetchOptions;
let url = providerOption.urlToFetch as string;

let baseUrl:string, endpoint:string;

const forwardHeaders: string[] = [];
baseUrl = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || "";

if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) {
// Construct the base object for the request
if(!!providerOption.apiKey) {
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, method);
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, method, forwardHeaders, requestHeaders);
} else {
fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, method);
fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, method, forwardHeaders, requestHeaders);
}
baseUrl = apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
baseUrl = baseUrl || apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
endpoint = apiConfig.getEndpoint(fn, providerOption.apiVersion, url);
url = `${baseUrl}${endpoint}`;
} else if (provider === PALM && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider, method);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, method, forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params?.model);
url = `${baseUrl}${endpoint}`;
} else if (provider === "anthropic" && apiConfig.baseURL) {
} else if (provider === ANTHROPIC && apiConfig.baseURL) {
// Construct the base object for the POST request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig[fn] || "";
url = `${baseUrl}${endpoint}`;
} else if (provider === GOOGLE && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, params.model, params.stream);
url = `${baseUrl}${endpoint}`;
} else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, params.model, url);
url = `${baseUrl}${endpoint}`;
} else {
// Construct the base object for the request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method);
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, method, forwardHeaders, requestHeaders);
}

if (method === "POST") {
fetchOptions.body = JSON.stringify(params)
}
Expand Down Expand Up @@ -314,37 +333,45 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P


let baseUrl:string, endpoint:string, fetchOptions;
const forwardHeaders = requestHeaders[HEADER_KEYS.FORWARD_HEADERS]?.split(",").map(h => h.trim()) || providerOption.forwardHeaders || [];
baseUrl = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || providerOption.customHost || "";

if (provider === AZURE_OPEN_AI && apiConfig.getBaseURL && apiConfig.getEndpoint) {
// Construct the base object for the POST request
if(!!providerOption.apiKey) {
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider);
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, "apiKey"), provider, "POST", forwardHeaders, requestHeaders);
} else {
fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider);
fetchOptions = constructRequest(apiConfig.headers(providerOption.adAuth, "adAuth"), provider, "POST", forwardHeaders, requestHeaders);
}
baseUrl = apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
baseUrl = baseUrl || apiConfig.getBaseURL(providerOption.resourceName, providerOption.deploymentId);
endpoint = apiConfig.getEndpoint(fn, providerOption.apiVersion);
} else if (provider === PALM && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, providerOption.overrideParams?.model || params?.model);
} else if (provider === "anthropic" && apiConfig.baseURL) {
} else if (provider === ANTHROPIC && apiConfig.baseURL) {
// Construct the base object for the POST request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey, fn), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig[fn] || "";
} else if (provider === GOOGLE && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream);
} else if (provider === STABILITY_AI && apiConfig.baseURL && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider);
baseUrl = apiConfig.baseURL;
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl || apiConfig.baseURL;
endpoint = apiConfig.getEndpoint(fn, params.model);
} else {
} else if (provider === OLLAMA && apiConfig.getEndpoint) {
fetchOptions = constructRequest(apiConfig.headers(), provider, "POST", forwardHeaders, requestHeaders);
baseUrl = baseUrl;
endpoint = apiConfig.getEndpoint(fn, providerOption.apiKey, transformedRequestBody.model, params.stream);
}
else {
// Construct the base object for the POST request
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider);
fetchOptions = constructRequest(apiConfig.headers(providerOption.apiKey), provider, "POST", forwardHeaders, requestHeaders);

baseUrl = apiConfig.baseURL || "";
baseUrl = baseUrl || apiConfig.baseURL || "";
endpoint = apiConfig[fn] || "";
}

Expand All @@ -358,7 +385,7 @@ export async function tryPost(c: Context, providerOption:Options, inputParams: P

providerOption.retry = {
attempts: providerOption.retry?.attempts ?? 0,
onStatusCodes: providerOption.retry?.onStatusCodes ?? []
onStatusCodes: providerOption.retry?.onStatusCodes ?? RETRY_STATUS_CODES
}

const [getFromCacheFunction, cacheIdentifier, requestOptions] = [
Expand Down Expand Up @@ -543,7 +570,7 @@ export async function tryTargetsRecursively(
const strategyMode = currentTarget.strategy?.mode;

// start: merge inherited config with current target config (preference given to current)
const currentInheritedConfig = {
const currentInheritedConfig: Record<string, any> = {
overrideParams : {
...inheritedConfig.overrideParams,
...currentTarget.overrideParams
Expand All @@ -553,11 +580,29 @@ export async function tryTargetsRecursively(
requestTimeout: null
}

if (currentTarget.forwardHeaders) {
currentInheritedConfig.forwardHeaders = [...currentTarget.forwardHeaders];
} else if (inheritedConfig.forwardHeaders) {
currentInheritedConfig.forwardHeaders = [...inheritedConfig.forwardHeaders];
currentTarget.forwardHeaders = [...inheritedConfig.forwardHeaders];
}

if (currentTarget.customHost) {
currentInheritedConfig.customHost = currentTarget.customHost
} else if (inheritedConfig.customHost) {
currentInheritedConfig.customHost = inheritedConfig.customHost;
currentTarget.customHost = inheritedConfig.customHost;
}

if (currentTarget.requestTimeout) {
currentInheritedConfig.requestTimeout = currentTarget.requestTimeout
} else if (inheritedConfig.requestTimeout) {
currentInheritedConfig.requestTimeout = inheritedConfig.requestTimeout;
currentTarget.requestTimeout = inheritedConfig.requestTimeout;
}



currentTarget.overrideParams = {
...currentInheritedConfig.overrideParams
}
Expand Down
11 changes: 9 additions & 2 deletions src/handlers/proxyGetHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@ function proxyProvider(proxyModeHeader:string, providerHeader: string) {
return proxyProvider;
}

function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string) {
function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string, customHost: string) {
let reqURL = new URL(requestURL);
let reqPath = reqURL.pathname;
const reqQuery = reqURL.search;
reqPath = reqPath.replace(proxyEndpointPath, "");

if (customHost) {
return `${customHost}${reqPath}${reqQuery}`
}

const providerBasePath = Providers[proxyProvider].api.baseURL;
if (proxyProvider === AZURE_OPEN_AI) {
return `https:/${reqPath}${reqQuery}`;
Expand Down Expand Up @@ -55,7 +60,9 @@ export async function proxyGetHandler(c: Context): Promise<Response> {
proxyPath: c.req.url.indexOf("/v1/proxy") > -1 ? "/v1/proxy" : "/v1"
}

let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath);
const customHost = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || "";

let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost);

let fetchOptions = {
headers: headersToSend(requestHeaders, store.customHeadersToAvoid),
Expand Down
16 changes: 13 additions & 3 deletions src/handlers/proxyHandler.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Context } from "hono";
import { retryRequest } from "./retryHandler";
import Providers from "../providers";
import { ANTHROPIC, MAX_RETRIES, HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, CONTENT_TYPES } from "../globals";
import { ANTHROPIC, MAX_RETRIES, HEADER_KEYS, RETRY_STATUS_CODES, POWERED_BY, RESPONSE_HEADER_KEYS, AZURE_OPEN_AI, CONTENT_TYPES, OLLAMA } from "../globals";
import { fetchProviderOptionsFromConfig, responseHandler, tryProvidersInSequence, updateResponseHeaders } from "./handlerUtils";
import { convertKeysToCamelCase, getStreamingMode } from "../utils";
import { Config, ShortConfig } from "../types/requestBody";
Expand All @@ -12,15 +12,24 @@ function proxyProvider(proxyModeHeader:string, providerHeader: string) {
return proxyProvider;
}

function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string) {
function getProxyPath(requestURL:string, proxyProvider:string, proxyEndpointPath:string, customHost: string) {
let reqURL = new URL(requestURL);
let reqPath = reqURL.pathname;
const reqQuery = reqURL.search;
reqPath = reqPath.replace(proxyEndpointPath, "");

if (customHost) {
return `${customHost}${reqPath}${reqQuery}`
}

const providerBasePath = Providers[proxyProvider].api.baseURL;
if (proxyProvider === AZURE_OPEN_AI) {
return `https:/${reqPath}${reqQuery}`;
}

if (proxyProvider === OLLAMA) {
return `https:/${reqPath}`;
}
let proxyPath = `${providerBasePath}${reqPath}${reqQuery}`;

// Fix specific for Anthropic SDK calls. Is this needed? - Yes
Expand Down Expand Up @@ -90,7 +99,8 @@ export async function proxyHandler(c: Context): Promise<Response> {
}
};

let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath);
const customHost = requestHeaders[HEADER_KEYS.CUSTOM_HOST] || requestConfig?.customHost || "";
let urlToFetch = getProxyPath(c.req.url, store.proxyProvider, store.proxyPath, customHost);
store.isStreamingMode = getStreamingMode(store.reqBody, store.proxyProvider, urlToFetch)

if (requestConfig &&
Expand Down
2 changes: 1 addition & 1 deletion src/handlers/streamHandler.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AZURE_OPEN_AI, CONTENT_TYPES, COHERE, GOOGLE, REQUEST_TIMEOUT_STATUS_CODE } from "../globals";
import { AZURE_OPEN_AI, CONTENT_TYPES, COHERE, GOOGLE, REQUEST_TIMEOUT_STATUS_CODE, OLLAMA } from "../globals";
import { OpenAIChatCompleteResponse } from "../providers/openai/chatComplete";
import { OpenAICompleteResponse } from "../providers/openai/complete";
import { getStreamModeSplitPattern } from "../utils";
Expand Down
19 changes: 18 additions & 1 deletion src/middlewares/requestValidator/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
DEEPINFRA,
STABILITY_AI,
NOMIC,
OLLAMA,
} from "../../globals";
import { configSchema } from "./schema/config";

Expand Down Expand Up @@ -63,7 +64,7 @@ export const requestValidator = (c: Context, next: any) => {
}
if (
requestHeaders[`x-${POWERED_BY}-provider`] &&
![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, NOMIC, STABILITY_AI].includes(
![OPEN_AI, AZURE_OPEN_AI, COHERE, ANTHROPIC, ANYSCALE, PALM, TOGETHER_AI, GOOGLE, MISTRAL_AI, PERPLEXITY_AI, DEEPINFRA, NOMIC, STABILITY_AI, OLLAMA].includes(
requestHeaders[`x-${POWERED_BY}-provider`]
)
) {
Expand All @@ -81,6 +82,22 @@ export const requestValidator = (c: Context, next: any) => {
);
}

const customHostHeader = requestHeaders[`x-${POWERED_BY}-custom-host`];
if (customHostHeader && customHostHeader.indexOf("api.portkey") > -1) {
return new Response(
JSON.stringify({
status: "failure",
message: `Invalid custom host`,
}),
{
status: 400,
headers: {
"content-type": "application/json",
},
}
);;
}


if (requestHeaders[`x-${POWERED_BY}-config`]) {
try {
Expand Down
Loading