Skip to content

Commit

Permalink
Merge pull request #240 from Portkey-AI/feat/anthropic-claude-3
Browse files Browse the repository at this point in the history
Feat: add claude 3 support for anthropic and bedrock
  • Loading branch information
VisargD authored Mar 5, 2024
2 parents c3b0533 + 23c29b2 commit c6bd75f
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 70 deletions.
36 changes: 35 additions & 1 deletion src/providers/anthropic/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,41 @@ export const AnthropicChatCompleteConfig: ProviderConfig = {
if (!!params.messages) {
params.messages.forEach(msg => {
if (msg.role !== "system") {
messages.push(msg);
if (msg.content && typeof msg.content === "object" && msg.content.length) {
const transformedMessage: Record<string, any> = {
role: msg.role,
content: [],
};
msg.content.forEach(item => {
if (item.type === "text") {
transformedMessage.content.push({ type: item.type, text: item.text });
} else if (item.type === "image_url" && item.image_url && item.image_url.url) {
const parts = item.image_url.url.split(";");
if (parts.length === 2) {
const base64ImageParts = parts[1].split(",");
const base64Image = base64ImageParts[1];
const mediaTypeParts = parts[0].split(":");
if (mediaTypeParts.length === 2 && base64Image) {
const mediaType = mediaTypeParts[1];
transformedMessage.content.push({
type: "image",
source: {
type: "base64",
media_type: mediaType,
data: base64Image,
},
});
}
}
}
});
messages.push(transformedMessage as Message);
} else {
messages.push({
role: msg.role,
content: msg.content
});
}
}
})
}
Expand Down
256 changes: 187 additions & 69 deletions src/providers/bedrock/chatComplete.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@ import { BEDROCK } from "../../globals";
import { Message, Params } from "../../types/requestBody";
import {
ChatCompletionResponse,
CompletionResponse,
ErrorResponse,
ProviderConfig,
} from "../types";
import {
BedrockAI21CompleteResponse,
BedrockAnthropicCompleteResponse,
BedrockAnthropicStreamChunk,
BedrockCohereCompleteResponse,
BedrockCohereStreamChunk,
BedrockLlamaCompleteResponse,
Expand All @@ -20,31 +17,100 @@ import {
import { BedrockErrorResponse } from "./embed";

export const BedrockAnthropicChatCompleteConfig: ProviderConfig = {
messages: {
param: "prompt",
required: true,
transform: (params: Params) => {
let prompt: string = "";
if (!!params.messages) {
let messages: Message[] = params.messages;
messages.forEach((msg, index) => {
if (index === 0 && msg.role === "system") {
prompt += `System: ${msg.content}\n`;
} else if (msg.role == "user") {
prompt += `\n\nHuman: ${msg.content}\n`;
} else if (msg.role == "assistant") {
prompt += `Assistant: ${msg.content}\n`;
} else {
prompt += `${msg.role}: ${msg.content}\n`;
}
});
prompt += "Assistant:";
}
return prompt;
messages: [
{
param: "messages",
required: true,
transform: (params: Params) => {
let messages: Message[] = [];
// Transform the chat messages into a simple prompt
if (!!params.messages) {
params.messages.forEach((msg) => {
if (msg.role !== "system") {
if (
msg.content &&
typeof msg.content === "object" &&
msg.content.length
) {
const transformedMessage: Record<string, any> =
{
role: msg.role,
content: [],
};
msg.content.forEach((item) => {
if (item.type === "text") {
transformedMessage.content.push({
type: item.type,
text: item.text,
});
} else if (
item.type === "image_url" &&
item.image_url &&
item.image_url.url
) {
const parts =
item.image_url.url.split(";");
if (parts.length === 2) {
const base64ImageParts =
parts[1].split(",");
const base64Image =
base64ImageParts[1];
const mediaTypeParts =
parts[0].split(":");
if (
mediaTypeParts.length === 2 &&
base64Image
) {
const mediaType =
mediaTypeParts[1];
transformedMessage.content.push(
{
type: "image",
source: {
type: "base64",
media_type:
mediaType,
data: base64Image,
},
}
);
}
}
}
});
messages.push(transformedMessage as Message);
} else {
messages.push({
role: msg.role,
content: msg.content,
});
}
}
});
}

return messages;
},
},
},
{
param: "system",
required: false,
transform: (params: Params) => {
let systemMessage: string = "";
// Transform the chat messages into a simple prompt
if (!!params.messages) {
params.messages.forEach((msg) => {
if (msg.role === "system") {
systemMessage = msg.content as string;
}
});
}
return systemMessage;
},
},
],
max_tokens: {
param: "max_tokens_to_sample",
param: "max_tokens",
required: true,
},
temperature: {
Expand Down Expand Up @@ -74,6 +140,11 @@ export const BedrockAnthropicChatCompleteConfig: ProviderConfig = {
user: {
param: "metadata.user_id",
},
anthropic_version: {
param: "anthropic_version",
required: true,
default: "bedrock-2023-05-31",
},
};

export const BedrockCohereChatCompleteConfig: ProviderConfig = {
Expand Down Expand Up @@ -610,8 +681,21 @@ export const BedrockAI21ChatCompleteResponseTransform: (
} as ErrorResponse;
};

interface BedrockAnthropicChatCompleteResponse {
id: string;
type: string;
role: string;
content: {
type: string;
text: string;
}[];
stop_reason: string;
model: string;
stop_sequence: null | string;
}

export const BedrockAnthropicChatCompleteResponseTransform: (
response: BedrockAnthropicCompleteResponse | BedrockErrorResponse,
response: BedrockAnthropicChatCompleteResponse | BedrockErrorResponse,
responseStatus: number,
responseHeaders: Headers
) => ChatCompletionResponse | ErrorResponse = (
Expand All @@ -631,26 +715,27 @@ export const BedrockAnthropicChatCompleteResponseTransform: (
} as ErrorResponse;
}

if ("completion" in response) {
if ("content" in response) {
const prompt_tokens =
Number(responseHeaders.get("X-Amzn-Bedrock-Input-Token-Count")) ||
0;
const completion_tokens =
Number(responseHeaders.get("X-Amzn-Bedrock-Output-Token-Count")) ||
0;
return {
id: Date.now().toString(),
id: response.id,
object: "chat.completion",
created: Math.floor(Date.now() / 1000),
model: "",
model: response.model,
provider: BEDROCK,
choices: [
{
index: 0,
message: {
role: "assistant",
content: response.completion,
content: response.content[0].text,
},
index: 0,
logprobs: null,
finish_reason: response.stop_reason,
},
],
Expand All @@ -675,32 +760,41 @@ export const BedrockAnthropicChatCompleteResponseTransform: (
} as ErrorResponse;
};

interface BedrockAnthropicChatCompleteStreamResponse {
type: string;
index: number;
delta: {
type: string;
text: string;
stop_reason?: string;
};
"amazon-bedrock-invocationMetrics": {
inputTokenCount: number;
outputTokenCount: number;
invocationLatency: number;
firstByteLatency: number;
};
}

export const BedrockAnthropicChatCompleteStreamChunkTransform: (
response: string,
fallbackId: string
) => string | string[] = (responseChunk, fallbackId) => {
let chunk = responseChunk.trim();

const parsedChunk: BedrockAnthropicStreamChunk = JSON.parse(chunk);
if (parsedChunk.stop_reason) {
const parsedChunk: BedrockAnthropicChatCompleteStreamResponse =
JSON.parse(chunk);
if (
parsedChunk.type === "ping" ||
parsedChunk.type === "message_start" ||
parsedChunk.type === "content_block_start" ||
parsedChunk.type === "content_block_stop"
) {
return [];
}

if (parsedChunk.type === "message_stop") {
return [
`data: ${JSON.stringify({
id: fallbackId,
object: "chat.completion.chunk",
created: Math.floor(Date.now() / 1000),
model: "",
provider: BEDROCK,
choices: [
{
index: 0,
delta: {
role: "assistant",
content: parsedChunk.completion,
},
finish_reason: null,
},
],
})}\n\n`,
`data: ${JSON.stringify({
id: fallbackId,
object: "chat.completion.chunk",
Expand All @@ -711,7 +805,7 @@ export const BedrockAnthropicChatCompleteStreamChunkTransform: (
{
index: 0,
delta: {},
finish_reason: parsedChunk.stop_reason,
finish_reason: parsedChunk.delta?.stop_reason,
},
],
usage: {
Expand All @@ -728,27 +822,51 @@ export const BedrockAnthropicChatCompleteStreamChunkTransform: (
.outputTokenCount,
},
})}\n\n`,
`data: [DONE]\n\n`,
"data: [DONE]\n\n",
];
}

return `data: ${JSON.stringify({
id: fallbackId,
object: "chat.completion.chunk",
created: Math.floor(Date.now() / 1000),
model: "",
provider: BEDROCK,
choices: [
{
index: 0,
delta: {
role: "assistant",
content: parsedChunk.completion,
if (parsedChunk.delta?.stop_reason) {
return [
`data: ${JSON.stringify({
id: fallbackId,
object: "chat.completion.chunk",
created: Math.floor(Date.now() / 1000),
model: "",
provider: BEDROCK,
choices: [
{
delta: {
content: parsedChunk.delta?.text,
},
index: 0,
logprobs: null,
finish_reason: parsedChunk.delta?.stop_reason ?? null,
},
],
})}\n\n`,
];
}

return (
`data: ${JSON.stringify({
id: fallbackId,
object: "chat.completion.chunk",
created: Math.floor(Date.now() / 1000),
model: "",
provider: BEDROCK,
choices: [
{
delta: {
content: parsedChunk.delta?.text,
},
index: 0,
logprobs: null,
finish_reason: parsedChunk.delta?.stop_reason ?? null,
},
finish_reason: null,
},
],
})}\n\n`;
],
})}\n\n`
);
};

export const BedrockCohereChatCompleteResponseTransform: (
Expand Down

0 comments on commit c6bd75f

Please sign in to comment.