Skip to content

Commit

Permalink
feat: embedding models working
Browse files Browse the repository at this point in the history
  • Loading branch information
xavidop committed Dec 26, 2024
1 parent d57ec20 commit a708504
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 100 deletions.
172 changes: 83 additions & 89 deletions src/aws_bedrock_embedders.ts
Original file line number Diff line number Diff line change
@@ -1,98 +1,92 @@
// /**
// * Copyright 2024 The Fire Company
// *
// * Licensed under the Apache License, Version 2.0 (the "License");
// * you may not use this file except in compliance with the License.
// * You may obtain a copy of the License at
// *
// * http://www.apache.org/licenses/LICENSE-2.0
// *
// * Unless required by applicable law or agreed to in writing, software
// * distributed under the License is distributed on an "AS IS" BASIS,
// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// * See the License for the specific language governing permissions and
// * limitations under the License.
// */
// /* eslint-disable @typescript-eslint/no-explicit-any */
/**
* Copyright 2024 The Fire Company
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* eslint-disable @typescript-eslint/no-explicit-any */

// import { embedderRef, Genkit } from "genkit";
// import ModelClient, {
// GetEmbeddings200Response,
// GetEmbeddingsParameters,
// } from "@azure-rest/ai-inference";
// import { z } from "zod";
// import { type PluginOptions } from "./index.js";
// import { AzureKeyCredential } from "@azure/core-auth";
import { embedderRef, Genkit } from "genkit";

// export const TextEmbeddingConfigSchema = z.object({
// dimensions: z.number().optional(),
// encodingFormat: z.union([z.literal("float"), z.literal("base64")]).optional(),
// });
import { z } from "zod";
import {
BedrockRuntimeClient,
InvokeModelCommand,
InvokeModelCommandInput,
InvokeModelCommandOutput,
} from "@aws-sdk/client-bedrock-runtime";

// export type TextEmbeddingGeckoConfig = z.infer<
// typeof TextEmbeddingConfigSchema
// >;
/**
 * Per-request configuration accepted by the Bedrock text embedders.
 *
 * `dimensions` is optional; when present it is forwarded as-is in the
 * InvokeModel request body built by `awsBedrockEmbedder`.
 */
export const TextEmbeddingConfigSchema = z.object({
  dimensions: z.optional(z.number()),
});

// export const TextEmbeddingInputSchema = z.string();
/**
 * Static type of the embedder configuration, derived from
 * `TextEmbeddingConfigSchema` so the two can never drift apart.
 *
 * NOTE(review): the "Gecko" in the name does not match any Bedrock model
 * visible in this file — presumably inherited from another plugin. Kept
 * unchanged because the alias is exported.
 */
export type TextEmbeddingGeckoConfig = z.infer<typeof TextEmbeddingConfigSchema>;

// export const openAITextEmbedding3Small = embedderRef({
// name: "github/text-embedding-3-small",
// configSchema: TextEmbeddingConfigSchema,
// info: {
// dimensions: 1536,
// label: "OpenAI - Text-embedding-3-small",
// supports: {
// input: ["text"],
// },
// },
// });
/**
 * Schema for a single piece of text to embed: a plain string.
 */
export const TextEmbeddingInputSchema = z.string();

// export const SUPPORTED_EMBEDDING_MODELS: Record<string, any> = {
// "text-embedding-3-small": openAITextEmbedding3Small,
// };
/**
 * Embedder reference for Amazon Titan Text Embeddings V2 on AWS Bedrock.
 *
 * `info.dimensions` advertises the size of the vectors produced when the
 * caller does not set `dimensions` in the config. The embedder only sends
 * `dimensions` to Bedrock when the caller provides it, and the model's own
 * default output size is 1024 (256 and 512 are available on request), so the
 * advertised value must be 1024 to match what is actually returned.
 */
export const amazonTitanEmbedTextV2 = embedderRef({
  name: "aws-bedrock/amazon.titan-embed-text-v2:0",
  configSchema: TextEmbeddingConfigSchema,
  info: {
    dimensions: 1024,
    label: "Amazon - titan-embed-text-v2:0",
    supports: {
      input: ["text"],
    },
  },
});

// export function awsBedrockEmbedder(
// name: string,
// ai: Genkit,
// options?: PluginOptions,
// ) {
// const token = options?.githubToken || process.env.GITHUB_TOKEN;
// let endpoint = options?.endpoint || process.env.GITHUB_ENDPOINT;
// if (!token) {
// throw new Error(
// "Please pass in the TOKEN key or set the GITHUB_TOKEN environment variable",
// );
// }
// if (!endpoint) {
// endpoint = "https://models.inference.ai.azure.com";
// }
/**
 * Registry of the embedding models this plugin can register.
 *
 * Keys are the model identifiers that `awsBedrockEmbedder` looks up and
 * sends as `modelId` in the InvokeModel request; values are the matching
 * embedder references (name, config schema and info).
 */
export const SUPPORTED_EMBEDDING_MODELS: Record<string, any> = {
"amazon.titan-embed-text-v2:0": amazonTitanEmbedTextV2,
};

// const client = ModelClient(endpoint, new AzureKeyCredential(token));
// const model = SUPPORTED_EMBEDDING_MODELS[name];
/**
 * Defines a Genkit embedder backed by an AWS Bedrock embedding model.
 *
 * @param name - Key of the model in `SUPPORTED_EMBEDDING_MODELS`; also sent
 *   to Bedrock as the `modelId`.
 * @param ai - Genkit instance on which the embedder is defined.
 * @param client - Bedrock runtime client used to invoke the model.
 * @returns The value returned by `ai.defineEmbedder`.
 * @throws Error if `name` is not a supported embedding model.
 */
export function awsBedrockEmbedder(
  name: string,
  ai: Genkit,
  client: BedrockRuntimeClient,
) {
  const model = SUPPORTED_EMBEDDING_MODELS[name];
  if (!model) {
    // Fail with a clear message instead of a TypeError on `model.info`.
    throw new Error(`Unsupported AWS Bedrock embedding model: ${name}`);
  }

  return ai.defineEmbedder(
    {
      info: model.info!,
      configSchema: TextEmbeddingConfigSchema,
      name: model.name,
    },
    async (input, options) => {
      const decoder = new TextDecoder();

      // The request body carries a single `inputText`, so each document gets
      // its own invocation. Joining the documents into one string would yield
      // a single vector for the whole batch instead of one per document.
      // Promise.all keeps the results in the same order as `input`.
      const embeddings = await Promise.all(
        input.map(async (d) => {
          const body: InvokeModelCommandInput = {
            modelId: name,
            contentType: "application/json",
            body: JSON.stringify({
              inputText: d.text,
              // Omitted from the JSON when undefined, so the model default applies.
              dimensions: options?.dimensions,
            }),
          };

          const response: InvokeModelCommandOutput = await client.send(
            new InvokeModelCommand(body),
          );

          // Decode once; an empty payload is reported explicitly rather than
          // being turned into an `undefined` embedding.
          const payload = decoder.decode(response.body);
          const parsed = (payload ? JSON.parse(payload) : undefined) as
            | { embedding?: unknown }
            | null
            | undefined;

          if (!Array.isArray(parsed?.embedding)) {
            throw new Error(
              `AWS Bedrock model ${name} returned a response without an embedding`,
            );
          }

          return { embedding: parsed.embedding as number[] };
        }),
      );

      return { embeddings };
    },
  );
}
3 changes: 1 addition & 2 deletions src/aws_bedrock_llms.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* limitations under the License.
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import * as fs from "fs";

import {
Message,
Expand Down Expand Up @@ -103,7 +102,7 @@ export function toAwsBedrockTextAndMedia(
text: part.text,
};
} else if (part.media) {
const imageBuffer = new Uint8Array(fs.readFileSync(part.media.url).buffer);
const imageBuffer = new Uint8Array(Buffer.from(part.media.url, "base64"));

return {
image: {
Expand Down
18 changes: 9 additions & 9 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@ import {
amazonNovaProV1,
SUPPORTED_AWS_BEDROCK_MODELS,
} from "./aws_bedrock_llms.js";
// import {
// awsBedrockEmbedder,
// openAITextEmbedding3Small,
// SUPPORTED_EMBEDDING_MODELS,
// } from "./aws_bedrock_embedders.js";
import {
awsBedrockEmbedder,
amazonTitanEmbedTextV2,
SUPPORTED_EMBEDDING_MODELS,
} from "./aws_bedrock_embedders.js";

export { amazonNovaProV1 };

// export { openAITextEmbedding3Small };
export { amazonTitanEmbedTextV2 };

export type PluginOptions = BedrockRuntimeClientConfig;

Expand All @@ -30,9 +30,9 @@ export function awsBedrock(options?: PluginOptions) {
awsBedrockModel(name, client, ai);
});

// Object.keys(SUPPORTED_EMBEDDING_MODELS).forEach((name) =>
// awsBedrockEmbedder(name, ai, options),
// );
Object.keys(SUPPORTED_EMBEDDING_MODELS).forEach((name) =>
awsBedrockEmbedder(name, ai, client),
);
});
}

Expand Down

0 comments on commit a708504

Please sign in to comment.