Skip to content

Commit

Permalink
Rerank 3.5
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Dec 3, 2024
1 parent cfc909a commit 6c6b234
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 17 deletions.
52 changes: 37 additions & 15 deletions src/aws-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,35 @@ const withTempEnv = async <R>(updateEnv: () => void, fn: () => Promise<R>): Prom
}
};

const streamingResponseParser: Record<string, any> = {
"chat": serializers.StreamedChatResponse,
"generate": serializers.GenerateStreamedResponse,
const streamingResponseParser: Record<string, Record<string, any>> = {
1: {
"chat": serializers.StreamedChatResponse,
"generate": serializers.GenerateStreamedResponse,
},
2: {
"chat": serializers.StreamedChatResponseV2,
"generate": serializers.GenerateStreamedResponse,
}
}

const nonStreamedResponseParser: Record<string, any> = {
"chat": serializers.NonStreamedChatResponse,
"embed": serializers.EmbedResponse,
"generate": serializers.Generation,
"rerank": serializers.RerankResponse,
const nonStreamedResponseParser: Record<string, Record<string, any>> = {
1: {
"chat": serializers.NonStreamedChatResponse,
"embed": serializers.EmbedResponse,
"generate": serializers.Generation,
"rerank": serializers.RerankResponse,
},
2: {
"chat": serializers.ChatResponse,
"embed": serializers.EmbedByTypeResponse,
"generate": serializers.Generation,
"rerank": serializers.V2RerankResponse,
}
}

export const mapResponseFromBedrock = async (streaming: boolean, endpoint: string, obj: {}) => {
export const mapResponseFromBedrock = async (streaming: boolean, endpoint: string, version: 1 | 2, obj: {}) => {

const parser = streaming ? streamingResponseParser[endpoint] : nonStreamedResponseParser[endpoint];
const parser = streaming ? streamingResponseParser[version][endpoint] : nonStreamedResponseParser[version][endpoint];

const config = {
unrecognizedObjectKeys: "passthrough",
Expand Down Expand Up @@ -140,14 +154,18 @@ export const parseAWSEvent = (line: string) => {
}
}

const getVersion = (version: string): 1 | 2 => (({"v1": 1, "v2": 2})[version] || 1) as 1 | 2

export const fetchOverride = (platform: AwsPlatform, {
awsRegion,
awsAccessKey,
awsSecretKey,
awsSessionToken,
}: AwsProps): FetchFunction => async (fetcherArgs: Fetcher.Args): Promise<APIResponse<any, Fetcher.Error>> => {
const endpoint = fetcherArgs.url.split('/').pop() as string;
const bodyJson = fetcherArgs.body as { model?: string, stream?: boolean };
const splittedUrl: string[] = fetcherArgs.url.split('/');
const endpoint = splittedUrl.pop()!;
const version = getVersion(splittedUrl.pop()!);
const bodyJson = fetcherArgs.body as { model?: string, stream?: boolean, api_version?: number };
console.assert(bodyJson.model, "model is required")

const isStreaming = Boolean(bodyJson.stream);
Expand All @@ -159,6 +177,10 @@ export const fetchOverride = (platform: AwsPlatform, {
isStreaming,
);

if (endpoint === "rerank") {
bodyJson["api_version"] = version;
}

delete bodyJson["stream"];
delete bodyJson["model"];
delete (fetcherArgs.headers as Record<string, string>)['Authorization'];
Expand Down Expand Up @@ -197,7 +219,7 @@ export const fetchOverride = (platform: AwsPlatform, {
for (const line of lineDecoder.decode(chunk as any)) {
const event = parseAWSEvent(line);
if (event) {
const obj = await mapResponseFromBedrock(isStreaming, endpoint, event);
const obj = await mapResponseFromBedrock(isStreaming, endpoint, version, event);
newBody.push(JSON.stringify(obj) + "\n");
}
}
Expand All @@ -206,7 +228,7 @@ export const fetchOverride = (platform: AwsPlatform, {
for (const line of lineDecoder.flush()) {
const event = parseAWSEvent(line);
if (event) {
const obj = await mapResponseFromBedrock(isStreaming, endpoint, event);
const obj = await mapResponseFromBedrock(isStreaming, endpoint, version, event);
newBody.push(JSON.stringify(obj) + "\n");
}
}
Expand All @@ -217,7 +239,7 @@ export const fetchOverride = (platform: AwsPlatform, {
}
} else {
const oldBody = await response.body as {};
const mappedResponse = await mapResponseFromBedrock(isStreaming, endpoint, oldBody);
const mappedResponse = await mapResponseFromBedrock(isStreaming, endpoint, version, oldBody);
return {
ok: true,
body: mappedResponse
Expand Down
20 changes: 18 additions & 2 deletions src/test/bedrock-tests.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const models: Record<AwsPlatform, Record<AwsEndpoint, string>> = {
generate: "cohere.command-text-v14",
embed: "cohere.embed-multilingual-v3",
chat: "cohere.command-r-plus-v1:0",
rerank: "cohere.rerank-v1",
rerank: "cohere.rerank-v3-5:0",
},
sagemaker: {
generate: "cohere-command-light",
Expand Down Expand Up @@ -85,7 +85,23 @@ describe.each<AwsPlatform>(["bedrock"])(
});
});

test.skip("chat stream works", async () => {
test("rerank works", async () => {
const rerank = await cohere.v2.rerank({
model: models[platform].rerank,
documents: [
"Carson City is the capital city of the American state of Nevada.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
"Washington, D.C. d (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
"Capital punishment (the death penalty) has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
],
query: "What is the capital of the United States?",
topN: 3,
});

expect(rerank.results).toBeDefined();
});

test("chat stream works", async () => {
const chat = await cohere.chatStream({
model: models[platform].chat,
message: "send me a short message",
Expand Down

0 comments on commit 6c6b234

Please sign in to comment.