Skip to content

Commit

Permalink
feat: models with inferenceRegion
Browse files Browse the repository at this point in the history
  • Loading branch information
xavidop committed Dec 27, 2024
1 parent 9087bdf commit 723b461
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 39 deletions.
90 changes: 54 additions & 36 deletions src/aws_bedrock_llms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import {
Role,
ToolRequestPart,
Genkit,
ModelReference,
} from "genkit";

import {
Expand Down Expand Up @@ -68,42 +69,56 @@ export const amazonNovaProV1 = modelRef({
configSchema: GenerationCommonConfigSchema,
});

export const anthropicClaude35HaikuV1 = modelRef({
name: "aws-bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0",
info: {
versions: ["us.anthropic.claude-3-5-haiku-20241022-v1:0"],
label: "Anthropic - Claude 3.5 Haiku V1",
supports: {
multiturn: true,
tools: true,
media: false,
systemRole: true,
output: ["text", "json"],
export const anthropicClaude35HaikuV1 = (
infrenceRegion: string = "us",
): ModelReference<typeof GenerationCommonConfigSchema> => {
return modelRef({
name: `aws-bedrock/${infrenceRegion}.anthropic.claude-3-5-haiku-20241022-v1:0`,
info: {
versions: [`${infrenceRegion}.anthropic.claude-3-5-haiku-20241022-v1:0`],
label: "Anthropic - Claude 3.5 Haiku V1",
supports: {
multiturn: true,
tools: true,
media: false,
systemRole: true,
output: ["text", "json"],
},
},
},
configSchema: GenerationCommonConfigSchema,
});
configSchema: GenerationCommonConfigSchema,
});
};

export const anthropicClaude35SonnetV1 = modelRef({
name: "aws-bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
info: {
versions: ["us.anthropic.claude-3-5-sonnet-20241022-v2:0"],
label: "Anthropic - Claude 3.5 Haiku V1",
supports: {
multiturn: true,
tools: true,
media: true,
systemRole: true,
output: ["text", "json"],
export const anthropicClaude35SonnetV1 = (
infrenceRegion: string = "us",
): ModelReference<typeof GenerationCommonConfigSchema> => {
return modelRef({
name: `aws-bedrock/${infrenceRegion}.anthropic.claude-3-5-sonnet-20241022-v2:0`,
info: {
versions: [`${infrenceRegion}anthropic.claude-3-5-sonnet-20241022-v2:0`],
label: "Anthropic - Claude 3.5 Haiku V1",
supports: {
multiturn: true,
tools: true,
media: true,
systemRole: true,
output: ["text", "json"],
},
},
},
configSchema: GenerationCommonConfigSchema,
});
configSchema: GenerationCommonConfigSchema,
});
};

export const SUPPORTED_AWS_BEDROCK_MODELS: Record<string, any> = {
"amazon.nova-pro-v1:0": amazonNovaProV1,
"us.anthropic.claude-3-5-haiku-20241022-v1:0": anthropicClaude35HaikuV1,
"us.anthropic.claude-3-5-sonnet-20241022-v2:0": anthropicClaude35SonnetV1,
export const SUPPORTED_AWS_BEDROCK_MODELS = (
infrenceRegion: string = "us",
): Record<string, any> => {
return {
"amazon.nova-pro-v1:0": amazonNovaProV1,
[`${infrenceRegion}.anthropic.claude-3-5-haiku-20241022-v1:0`]:
anthropicClaude35HaikuV1(infrenceRegion),
[`${infrenceRegion}.anthropic.claude-3-5-sonnet-20241022-v2:0`]:
anthropicClaude35SonnetV1(infrenceRegion),
};
};

function toAwsBedrockbRole(role: Role): string {
Expand Down Expand Up @@ -356,8 +371,9 @@ function fromAwsBedrockChunkChoice(
export function toAwsBedrockRequestBody(
modelName: string,
request: GenerateRequest<typeof GenerationCommonConfigSchema>,
inferenceRegion: string,
) {
const model = SUPPORTED_AWS_BEDROCK_MODELS[modelName];
const model = SUPPORTED_AWS_BEDROCK_MODELS(inferenceRegion)[modelName];
if (!model) throw new Error(`Unsupported model: ${modelName}`);
const awsBedrockMessages = toAwsBedrockMessages(request.messages);

Expand Down Expand Up @@ -416,20 +432,22 @@ export function awsBedrockModel(
name: string,
client: BedrockRuntimeClient,
ai: Genkit,
inferenceRegion: string,
): ModelAction<typeof GenerationCommonConfigSchema> {
const modelId = `aws-bedrock/${name}`;
const model = SUPPORTED_AWS_BEDROCK_MODELS[name];
const model = SUPPORTED_AWS_BEDROCK_MODELS(inferenceRegion)[name];
if (!model) throw new Error(`Unsupported model: ${name}`);

return ai.defineModel(
{
name: modelId,
...model.info,
configSchema: SUPPORTED_AWS_BEDROCK_MODELS[name].configSchema,
configSchema:
SUPPORTED_AWS_BEDROCK_MODELS(inferenceRegion)[name].configSchema,
},
async (request, streamingCallback) => {
let response: ConverseStreamCommandOutput | ConverseCommandOutput;
const body = toAwsBedrockRequestBody(name, request);
const body = toAwsBedrockRequestBody(name, request, inferenceRegion);
if (streamingCallback) {
const command = new ConverseStreamCommand(body);
response = await client.send(command);
Expand Down
16 changes: 13 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,19 @@ export function awsBedrock(options?: PluginOptions) {
return genkitPlugin("aws-bedrock", async (ai: Genkit) => {
const client = new BedrockRuntimeClient(options || {});

Object.keys(SUPPORTED_AWS_BEDROCK_MODELS).forEach((name) => {
awsBedrockModel(name, client, ai);
});
const region =
typeof options?.region === "string"
? options.region
: typeof options?.region === "function"
? await options.region()
: undefined;
const inferenceRegion = region ? region.substring(0, 2) : "us";

Object.keys(SUPPORTED_AWS_BEDROCK_MODELS(inferenceRegion)).forEach(
(name) => {
awsBedrockModel(name, client, ai, inferenceRegion);
},
);

Object.keys(SUPPORTED_EMBEDDING_MODELS).forEach((name) =>
awsBedrockEmbedder(name, ai, client),
Expand Down

0 comments on commit 723b461

Please sign in to comment.