Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions js/plugins/google-genai/src/googleai/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import {
embedderActionMetadata,
EmbedderInfo,
EmbedderReference,
Genkit,
z,
} from 'genkit';
import { embedderRef } from 'genkit/embedder';
import { embedder as pluginEmbedder } from 'genkit/plugin';
import { embedContent } from './client.js';
import {
EmbedContentRequest,
Expand Down Expand Up @@ -122,43 +122,40 @@ export function listActions(models: Model[]): ActionMetadata[] {
);
}

export function defineKnownModels(ai: Genkit, options?: GoogleAIPluginOptions) {
for (const name of Object.keys(KNOWN_MODELS)) {
defineEmbedder(ai, name, options);
}
export function listKnownModels(options?: GoogleAIPluginOptions) {
return Object.keys(KNOWN_MODELS).map((name) => defineEmbedder(name, options));
}

export function defineEmbedder(
ai: Genkit,
name: string,
pluginOptions?: GoogleAIPluginOptions
): EmbedderAction {
checkApiKey(pluginOptions?.apiKey);
const ref = model(name);

return ai.defineEmbedder(
return pluginEmbedder(
{
name: ref.name,
configSchema: ref.configSchema,
info: ref.info,
},
async (input, reqOptions) => {
async (request, _) => {
const embedApiKey = calculateApiKey(
pluginOptions?.apiKey,
reqOptions?.apiKey
request.options?.apiKey
);
const embedVersion = reqOptions?.version || extractVersion(ref);
const embedVersion = request.options?.version || extractVersion(ref);

const embeddings = await Promise.all(
input.map(async (doc) => {
request.input.map(async (doc) => {
const response = await embedContent(embedApiKey, embedVersion, {
taskType: reqOptions?.taskType,
title: reqOptions?.title,
taskType: request.options?.taskType,
title: request.options?.title,
content: {
role: '',
parts: [{ text: doc.text }],
},
outputDimensionality: reqOptions?.outputDimensionality,
outputDimensionality: request.options?.outputDimensionality,
} as EmbedContentRequest);
const values = response.embedding.values;
return { embedding: values };
Expand Down
22 changes: 7 additions & 15 deletions js/plugins/google-genai/src/googleai/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,7 @@
* limitations under the License.
*/

import {
ActionMetadata,
Genkit,
GenkitError,
modelActionMetadata,
z,
} from 'genkit';
import { ActionMetadata, GenkitError, modelActionMetadata, z } from 'genkit';
import {
GenerationCommonConfigDescriptions,
GenerationCommonConfigSchema,
Expand All @@ -32,6 +26,7 @@ import {
modelRef,
} from 'genkit/model';
import { downloadRequestMedia } from 'genkit/model/middleware';
import { model as pluginModel } from 'genkit/plugin';
import { runInNewSpan } from 'genkit/tracing';
import {
fromGeminiCandidate,
Expand Down Expand Up @@ -434,17 +429,16 @@ export function listActions(models: Model[]): ActionMetadata[] {
);
}

export function defineKnownModels(ai: Genkit, options?: GoogleAIPluginOptions) {
for (const name of Object.keys(KNOWN_MODELS)) {
defineModel(ai, name, options);
}
export function listKnownModels(options?: GoogleAIPluginOptions) {
return Object.keys(KNOWN_MODELS).map((name: string) =>
defineModel(name, options)
);
}

/**
* Defines a new GoogleAI Gemini model.
*/
export function defineModel(
ai: Genkit,
name: string,
pluginOptions?: GoogleAIPluginOptions
): ModelAction {
Expand Down Expand Up @@ -482,9 +476,8 @@ export function defineModel(
);
}

return ai.defineModel(
return pluginModel(
{
apiVersion: 'v2',
name: ref.name,
...ref.info,
configSchema: ref.configSchema,
Expand Down Expand Up @@ -660,7 +653,6 @@ export function defineModel(
// API params as for input.
return pluginOptions?.experimental_debugTraces
? await runInNewSpan(
ai.registry,
{
metadata: {
name: streamingRequested ? 'sendMessageStream' : 'sendMessage',
Expand Down
14 changes: 6 additions & 8 deletions js/plugins/google-genai/src/googleai/imagen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import {
MessageData,
modelActionMetadata,
z,
type Genkit,
} from 'genkit';
import {
getBasicUsageStats,
Expand All @@ -30,6 +29,7 @@ import {
type ModelInfo,
type ModelReference,
} from 'genkit/model';
import { model as pluginModel } from 'genkit/plugin';
import { imagenPredict } from './client.js';
import type {
ClientOptions,
Expand Down Expand Up @@ -169,14 +169,13 @@ export function listActions(models: Model[]): ActionMetadata[] {
});
}

export function defineKnownModels(ai: Genkit, options?: GoogleAIPluginOptions) {
for (const name of Object.keys(KNOWN_MODELS)) {
defineModel(ai, name, options);
}
export function listKnownModels(options?: GoogleAIPluginOptions) {
return Object.keys(KNOWN_MODELS).map((name: string) =>
defineModel(name, options)
);
}

export function defineModel(
ai: Genkit,
name: string,
pluginOptions?: GoogleAIPluginOptions
): ModelAction {
Expand All @@ -187,9 +186,8 @@ export function defineModel(
baseUrl: pluginOptions?.baseUrl,
};

return ai.defineModel(
return pluginModel(
{
apiVersion: 'v2',
name: ref.name,
...ref.info,
configSchema: ref.configSchema,
Expand Down
64 changes: 32 additions & 32 deletions js/plugins/google-genai/src/googleai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
* limitations under the License.
*/

import {
ActionMetadata,
EmbedderReference,
Genkit,
ModelReference,
z,
} from 'genkit';
import { ActionMetadata, EmbedderReference, ModelReference, z } from 'genkit';
import { logger } from 'genkit/logging';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import {
GenkitPluginV2,
ResolvableAction,
genkitPluginV2,
} from 'genkit/plugin';
import { ActionType } from 'genkit/registry';
import { extractErrMsg } from '../common/utils.js';
import { listModels } from './client.js';
Expand All @@ -42,41 +40,41 @@ export { type GeminiConfig, type GeminiTtsConfig } from './gemini.js';
export { type ImagenConfig } from './imagen.js';
export { type GoogleAIPluginOptions };

async function initializer(ai: Genkit, options?: GoogleAIPluginOptions) {
imagen.defineKnownModels(ai, options);
gemini.defineKnownModels(ai, options);
embedder.defineKnownModels(ai, options);
veo.defineKnownModels(ai, options);
async function initializer(options?: GoogleAIPluginOptions) {
return [
...imagen.listKnownModels(options),
...gemini.listKnownModels(options),
...embedder.listKnownModels(options),
...veo.listKnownModels(options),
];
}

async function resolver(
ai: Genkit,
actionType: ActionType,
actionName: string,
options: GoogleAIPluginOptions
) {
): Promise<ResolvableAction | undefined> {
switch (actionType) {
case 'model':
if (veo.isVeoModelName(actionName)) {
// no-op (not gemini)
return undefined;
} else if (imagen.isImagenModelName(actionName)) {
imagen.defineModel(ai, actionName, options);
return await imagen.defineModel(actionName, options);
} else {
// gemini, tts, gemma, unknown models
gemini.defineModel(ai, actionName, options);
return await gemini.defineModel(actionName, options);
}
break;
case 'background-model':
if (veo.isVeoModelName(actionName)) {
veo.defineModel(ai, actionName, options);
return await veo.defineModel(actionName, options);
}
break;
case 'embedder':
embedder.defineEmbedder(ai, actionName, options);
return await embedder.defineEmbedder(actionName, options);
break;
default:
// no-op
}
return undefined;
}

async function listActions(
Expand Down Expand Up @@ -104,23 +102,25 @@ async function listActions(
/**
* Google Gemini Developer API plugin.
*/
export function googleAIPlugin(options?: GoogleAIPluginOptions): GenkitPlugin {
export function googleAIPlugin(
options?: GoogleAIPluginOptions
): GenkitPluginV2 {
let listActionsCache;
return genkitPlugin(
'googleai',
async (ai: Genkit) => await initializer(ai, options),
async (ai: Genkit, actionType: ActionType, actionName: string) =>
await resolver(ai, actionType, actionName, options || {}),
async () => {
return genkitPluginV2({
name: 'googleai',
init: async () => await initializer(options),
resolve: async (actionType: ActionType, actionName: string) =>
await resolver(actionType, actionName, options || {}),
list: async () => {
if (listActionsCache) return listActionsCache;
listActionsCache = await listActions(options);
return listActionsCache;
}
);
},
});
}

export type GoogleAIPlugin = {
(pluginOptions?: GoogleAIPluginOptions): GenkitPlugin;
(pluginOptions?: GoogleAIPluginOptions): GenkitPluginV2;
model(
name: gemini.KnownGemmaModels | (gemini.GemmaModelName & {}),
config: gemini.GemmaConfig
Expand Down
13 changes: 6 additions & 7 deletions js/plugins/google-genai/src/googleai/veo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import {
Operation,
modelActionMetadata,
z,
type Genkit,
} from 'genkit';
import {
BackgroundModelAction,
Expand All @@ -29,6 +28,7 @@ import {
type ModelInfo,
type ModelReference,
} from 'genkit/model';
import { backgroundModel as pluginBackgroundModel } from 'genkit/plugin';
import { veoCheckOperation, veoPredict } from './client.js';
import {
ClientOptions,
Expand Down Expand Up @@ -156,17 +156,16 @@ export function listActions(models: Model[]): ActionMetadata[] {
);
}

export function defineKnownModels(ai: Genkit, options?: GoogleAIPluginOptions) {
for (const name of Object.keys(KNOWN_MODELS)) {
defineModel(ai, name, options);
}
export function listKnownModels(options?: GoogleAIPluginOptions) {
return Object.keys(KNOWN_MODELS).map((name: string) =>
defineModel(name, options)
);
}

/**
* Defines a new GoogleAI Veo model.
*/
export function defineModel(
ai: Genkit,
name: string,
pluginOptions?: GoogleAIPluginOptions
): BackgroundModelAction<VeoConfigSchemaType> {
Expand All @@ -176,7 +175,7 @@ export function defineModel(
baseUrl: pluginOptions?.baseUrl,
};

return ai.defineBackgroundModel({
return pluginBackgroundModel({
name: ref.name,
...ref.info,
configSchema: ref.configSchema,
Expand Down
Loading