Skip to content

Commit

Permalink
Support rerank on aws (#208)
Browse files Browse the repository at this point in the history
* Support rerank on aws

* Fix model names

* Rerank 3.5

* Skip
  • Loading branch information
billytrend-cohere authored Dec 3, 2024
1 parent c430bc8 commit 22d9664
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 28 deletions.
66 changes: 38 additions & 28 deletions src/aws-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,35 @@ const withTempEnv = async <R>(updateEnv: () => void, fn: () => Promise<R>): Prom
}
};

const streamingResponseParser: Record<string, any> = {
"chat": serializers.StreamedChatResponse,
"generate": serializers.GenerateStreamedResponse,
const streamingResponseParser: Record<string, Record<string, any>> = {
1: {
"chat": serializers.StreamedChatResponse,
"generate": serializers.GenerateStreamedResponse,
},
2: {
"chat": serializers.StreamedChatResponseV2,
"generate": serializers.GenerateStreamedResponse,
}
}

const nonStreamedResponseParser: Record<string, any> = {
"chat": serializers.NonStreamedChatResponse,
"embed": serializers.EmbedResponse,
"generate": serializers.Generation,
const nonStreamedResponseParser: Record<string, Record<string, any>> = {
1: {
"chat": serializers.NonStreamedChatResponse,
"embed": serializers.EmbedResponse,
"generate": serializers.Generation,
"rerank": serializers.RerankResponse,
},
2: {
"chat": serializers.ChatResponse,
"embed": serializers.EmbedByTypeResponse,
"generate": serializers.Generation,
"rerank": serializers.V2RerankResponse,
}
}

export const mapResponseFromBedrock = async (streaming: boolean, endpoint: string, obj: {}) => {
export const mapResponseFromBedrock = async (streaming: boolean, endpoint: string, version: 1 | 2, obj: {}) => {

const parser = streaming ? streamingResponseParser[endpoint] : nonStreamedResponseParser[endpoint];
const parser = streaming ? streamingResponseParser[version][endpoint] : nonStreamedResponseParser[version][endpoint];

const config = {
unrecognizedObjectKeys: "passthrough",
Expand All @@ -56,7 +71,7 @@ export type AwsProps = {

export type AwsPlatform = "sagemaker" | "bedrock"

export type AwsEndpoint = "chat" | "generate" | "embed"
export type AwsEndpoint = "chat" | "generate" | "embed" | "rerank"

export const getUrl = (
platform: "bedrock" | "sagemaker",
Expand Down Expand Up @@ -124,19 +139,6 @@ export const getAuthHeaders = async (url: URL, method: string, headers: Record<s
return signed.headers;
};

export const getEndpointFromUrl = (url: string, chatModel?: string, embedModel?: string, generateModel?: string): string => {
if (chatModel && url.includes(chatModel)) {
return "chat";
}
if (embedModel && url.includes(embedModel)) {
return "embed";
}
if (generateModel && url.includes(generateModel)) {
return "generate";
}
throw new Error(`Unknown endpoint in url: ${url}`);
}

export const parseAWSEvent = (line: string) => {
const regex = /{[^\}]*}/;
const match = line.match(regex);
Expand All @@ -152,14 +154,18 @@ export const parseAWSEvent = (line: string) => {
}
}

const getVersion = (version: string): 1 | 2 => (({"v1": 1, "v2": 2})[version] || 1) as 1 | 2

export const fetchOverride = (platform: AwsPlatform, {
awsRegion,
awsAccessKey,
awsSecretKey,
awsSessionToken,
}: AwsProps): FetchFunction => async (fetcherArgs: Fetcher.Args): Promise<APIResponse<any, Fetcher.Error>> => {
const endpoint = fetcherArgs.url.split('/').pop() as string;
const bodyJson = fetcherArgs.body as { model?: string, stream?: boolean };
const splittedUrl: string[] = fetcherArgs.url.split('/');
const endpoint = splittedUrl.pop()!;
const version = getVersion(splittedUrl.pop()!);
const bodyJson = fetcherArgs.body as { model?: string, stream?: boolean, api_version?: number };
console.assert(bodyJson.model, "model is required")

const isStreaming = Boolean(bodyJson.stream);
Expand All @@ -171,6 +177,10 @@ export const fetchOverride = (platform: AwsPlatform, {
isStreaming,
);

if (endpoint === "rerank") {
bodyJson["api_version"] = version;
}

delete bodyJson["stream"];
delete bodyJson["model"];
delete (fetcherArgs.headers as Record<string, string>)['Authorization'];
Expand Down Expand Up @@ -209,7 +219,7 @@ export const fetchOverride = (platform: AwsPlatform, {
for (const line of lineDecoder.decode(chunk as any)) {
const event = parseAWSEvent(line);
if (event) {
const obj = await mapResponseFromBedrock(isStreaming, endpoint, event);
const obj = await mapResponseFromBedrock(isStreaming, endpoint, version, event);
newBody.push(JSON.stringify(obj) + "\n");
}
}
Expand All @@ -218,7 +228,7 @@ export const fetchOverride = (platform: AwsPlatform, {
for (const line of lineDecoder.flush()) {
const event = parseAWSEvent(line);
if (event) {
const obj = await mapResponseFromBedrock(isStreaming, endpoint, event);
const obj = await mapResponseFromBedrock(isStreaming, endpoint, version, event);
newBody.push(JSON.stringify(obj) + "\n");
}
}
Expand All @@ -229,7 +239,7 @@ export const fetchOverride = (platform: AwsPlatform, {
}
} else {
const oldBody = await response.body as {};
const mappedResponse = await mapResponseFromBedrock(isStreaming, endpoint, oldBody);
const mappedResponse = await mapResponseFromBedrock(isStreaming, endpoint, version, oldBody);
return {
ok: true,
body: mappedResponse
Expand Down
18 changes: 18 additions & 0 deletions src/test/bedrock-tests.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ const models: Record<AwsPlatform, Record<AwsEndpoint, string>> = {
generate: "cohere.command-text-v14",
embed: "cohere.embed-multilingual-v3",
chat: "cohere.command-r-plus-v1:0",
rerank: "cohere.rerank-v3-5:0",
},
sagemaker: {
generate: "cohere-command-light",
embed: "cohere-embed-multilingual-v3",
chat: "cohere-command-plus",
rerank: "cohere.rerank-v1",
},
};

Expand Down Expand Up @@ -83,6 +85,22 @@ describe.each<AwsPlatform>(["bedrock"])(
});
});

test.skip("rerank works", async () => {
const rerank = await cohere.v2.rerank({
model: models[platform].rerank,
documents: [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. d (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
],
query: "What is the capital of the United States?",
topN: 3,
});

expect(rerank.results).toBeDefined();
});

test.skip("chat stream works", async () => {
const chat = await cohere.chatStream({
model: models[platform].chat,
Expand Down

0 comments on commit 22d9664

Please sign in to comment.