diff --git a/genkit-tools/cli/README.md b/genkit-tools/cli/README.md
index 6d7c7a3e0a..aa7b2bcb27 100644
--- a/genkit-tools/cli/README.md
+++ b/genkit-tools/cli/README.md
@@ -44,6 +44,8 @@ Available commands:
 
   evaluate a flow against configured evaluators using provided data as input
 
+**Parallelism tip:** for both `eval:run` (when your dataset already includes outputs) and `eval:flow`, setting `--batchSize` greater than 1 runs inference and evaluator actions in parallel (capped at 100). Higher values can speed up runs but may hit model/API rate limits or increase resource usage—tune according to your environment.
+
 - `config`
 
   set development environment configuration
diff --git a/genkit-tools/cli/src/commands/eval-flow.ts b/genkit-tools/cli/src/commands/eval-flow.ts
index 53c7eacff8..8eb27d15bf 100644
--- a/genkit-tools/cli/src/commands/eval-flow.ts
+++ b/genkit-tools/cli/src/commands/eval-flow.ts
@@ -154,6 +154,7 @@ export const evalFlow = new Command('eval:flow')
       actionRef,
       inferenceDataset,
       context: options.context,
+      batchSize: options.batchSize,
     });
 
     const evalRun = await runEvaluation({
diff --git a/genkit-tools/common/src/eval/evaluate.ts b/genkit-tools/common/src/eval/evaluate.ts
index 596f33cfc9..bfd506b8ed 100644
--- a/genkit-tools/common/src/eval/evaluate.ts
+++ b/genkit-tools/common/src/eval/evaluate.ts
@@ -67,6 +67,7 @@ interface FullInferenceSample {
 const SUPPORTED_ACTION_TYPES = ['flow', 'model', 'executable-prompt'] as const;
 type SupportedActionType = (typeof SUPPORTED_ACTION_TYPES)[number];
 const GENERATE_ACTION_UTIL = '/util/generate';
+const MAX_CONCURRENCY = 100;
 
 /**
  * Starts a new evaluation run. Intended to be used via the reflection API.
@@ -119,6 +120,7 @@ export async function runNewEvaluation(
     inferenceDataset,
     context: request.options?.context,
     actionConfig: request.options?.actionConfig,
+    batchSize: request.options?.batchSize,
   });
   const evaluatorActions = await getMatchingEvaluatorActions(
     manager,
@@ -146,11 +148,20 @@ export async function runInference(params: {
   inferenceDataset: Dataset;
   context?: string;
   actionConfig?: any;
+  batchSize?: number;
 }): Promise<EvalInput[]> {
-  const { manager, actionRef, inferenceDataset, context, actionConfig } =
-    params;
+  const {
+    manager,
+    actionRef,
+    inferenceDataset,
+    context,
+    actionConfig,
+    batchSize,
+  } = params;
   if (!isSupportedActionRef(actionRef)) {
-    throw new Error('Inference is only supported on flows and models');
+    throw new Error(
+      'Inference is only supported on flows, models, and executable prompts'
+    );
   }
 
   const evalDataset: EvalInput[] = await bulkRunAction({
@@ -159,6 +170,7 @@ export async function runInference(params: {
     inferenceDataset,
     context,
     actionConfig,
+    batchSize,
   });
   return evalDataset;
 }
@@ -183,21 +195,23 @@ export async function runEvaluation(params: {
 
   const runtime = manager.getMostRecentRuntime();
   const isNodeRuntime = runtime?.genkitVersion?.startsWith('nodejs') ?? false;
-  for (const action of evaluatorActions) {
-    const name = evaluatorName(action);
-    const response = await manager.runAction({
-      key: name,
-      input: {
-        dataset: evalDataset.filter((row) => !row.error),
-        evalRunId,
-        batchSize: isNodeRuntime ? batchSize : undefined,
-      },
-    });
-    scores[name] = response.result;
-    logger.info(
-      `Finished evaluator '${action.name}'. Trace ID: ${response.telemetry?.traceId}`
-    );
-  }
+  await Promise.all(
+    evaluatorActions.map(async (action) => {
+      const name = evaluatorName(action);
+      const response = await manager.runAction({
+        key: name,
+        input: {
+          dataset: evalDataset.filter((row) => !row.error),
+          evalRunId,
+          batchSize: isNodeRuntime ? batchSize : undefined,
+        },
+      });
+      scores[name] = response.result;
+      logger.info(
+        `Finished evaluator '${action.name}'. Trace ID: ${response.telemetry?.traceId}`
+      );
+    })
+  );
 
   const scoredResults = enrichResultsWithScoring(scores, evalDataset);
   const metadata = extractMetricsMetadata(evaluatorActions);
@@ -258,55 +272,87 @@ async function bulkRunAction(params: {
   inferenceDataset: Dataset;
   context?: string;
   actionConfig?: any;
+  batchSize?: number;
 }): Promise<EvalInput[]> {
-  const { manager, actionRef, inferenceDataset, context, actionConfig } =
-    params;
+  const {
+    manager,
+    actionRef,
+    inferenceDataset,
+    context,
+    actionConfig,
+    batchSize,
+  } = params;
   const actionType = getSupportedActionType(actionRef);
   if (inferenceDataset.length === 0) {
     throw new Error('Cannot run inference, no data provided');
   }
 
+  const desiredConcurrency = Math.max(1, batchSize ?? 1);
+  const resolvedConcurrency = Math.min(desiredConcurrency, MAX_CONCURRENCY);
+
   // Convert to satisfy TS checks. `input` is required in `Dataset` type, but
   // ZodAny also includes `undefined` in TS checks. This explcit conversion
   // works around this.
   const fullInferenceDataset = inferenceDataset as FullInferenceSample[];
 
-  const states: InferenceRunState[] = [];
+  const total = fullInferenceDataset.length;
+  logger.info(
+    `Running inference '${actionRef}' on ${total} samples with concurrency=${resolvedConcurrency}...`
+  );
+  let completed = 0;
+
+  const states: InferenceRunState[] = new Array(total);
   const evalInputs: EvalInput[] = [];
-  for (const sample of fullInferenceDataset) {
-    logger.info(`Running inference '${actionRef}' ...`);
-    if (actionType === 'model') {
-      states.push(
-        await runModelAction({
+  const runSample = async (sample: FullInferenceSample, index: number) => {
+    try {
+      logger.info(`Running inference '${actionRef}' ...`);
+      if (actionType === 'model') {
+        states[index] = await runModelAction({
           manager,
           actionRef,
           sample,
           modelConfig: actionConfig,
-        })
-      );
-    } else if (actionType === 'flow') {
-      states.push(
-        await runFlowAction({
+        });
+      } else if (actionType === 'flow') {
+        states[index] = await runFlowAction({
           manager,
           actionRef,
           sample,
           context,
-        })
-      );
-    } else {
-      // executable-prompt action
-      states.push(
-        await runPromptAction({
+        });
+      } else {
+        // executable-prompt action
+        states[index] = await runPromptAction({
           manager,
           actionRef,
           sample,
           context,
           promptConfig: actionConfig,
-        })
-      );
+        });
+      }
+      completed++;
+      if (completed % 10 === 0 || completed === total) {
+        logger.info(`Inference progress: ${completed}/${total} completed`);
+      }
+    } catch (error: any) {
+      completed++;
+      logger.error(`Inference failed for sample ${index}:`, error);
+      states[index] = {
+        testCaseId: sample.testCaseId,
+        input: sample.input,
+        reference: sample.reference,
+        traceIds: [],
+        evalError: error instanceof Error ? error.message : String(error),
+      };
     }
-  }
+  };
+  for (let i = 0; i < total; i += resolvedConcurrency) {
+    const batch = fullInferenceDataset.slice(i, i + resolvedConcurrency);
+    await Promise.all(
+      batch.map((sample, offset) => runSample(sample, i + offset))
+    );
+  }
 
   logger.info(`Gathering evalInputs...`);
   for (const state of states) {
     evalInputs.push(await gatherEvalInput({ manager, actionRef, state }));
diff --git a/genkit-tools/common/tests/eval/evaluate_test.ts b/genkit-tools/common/tests/eval/evaluate_test.ts
new file mode 100644
index 0000000000..2c6295e43d
--- /dev/null
+++ b/genkit-tools/common/tests/eval/evaluate_test.ts
@@ -0,0 +1,176 @@
+/**
+ * Copyright 2024 Google LLC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import { describe, expect, it, jest } from '@jest/globals';
+
+// Mock utils used inside evaluate.ts to avoid touching real traces/config.
+jest.mock('../../src/utils', () => {
+  const logger = {
+    info: jest.fn(),
+    warn: jest.fn(),
+    error: jest.fn(),
+  };
+  return {
+    evaluatorName: (action: any) => `/evaluator/${action.name}`,
+    generateTestCaseId: () => 'test-case',
+    getEvalExtractors: jest.fn(async () => ({
+      input: (trace: any) => trace.mockInput,
+      output: (trace: any) => trace.mockOutput,
+      context: () => [],
+    })),
+    getModelInput: (data: any) => data,
+    hasAction: jest.fn().mockResolvedValue(true),
+    isEvaluator: (key: string) => key.startsWith('/evaluator'),
+    logger,
+    stackTraceSpans: jest.fn(() => ({ attributes: {}, spans: [] })),
+  };
+});
+
+import type { Action, EvalInput } from '../../src/types';
+import * as evaluate from '../../src/eval/evaluate';
+
+const bulkRunAction = (evaluate as any)
+  .bulkRunAction as (args: any) => Promise<EvalInput[]>;
+
+function createMockManager() {
+  return {
+    runAction: jest.fn(),
+    getTrace: jest.fn(),
+    getMostRecentRuntime: jest.fn(() => ({ genkitVersion: 'nodejs-1.0' })),
+  };
+}
+
+function createAction(name: string): Action {
+  return {
+    key: `/evaluator/${name}`,
+    name,
+    description: '',
+    inputSchema: null,
+    outputSchema: null,
+    metadata: null,
+  };
+}
+
+describe('bulkRunAction', () => {
+  it('runs samples in batches respecting batchSize', async () => {
+    const manager = createMockManager();
+    const delayMs = 40;
+    manager.runAction.mockImplementation(async (_req: any) => {
+      await new Promise((resolve) => setTimeout(resolve, delayMs));
+      return {
+        result: 'ok',
+        telemetry: { traceId: 'trace' },
+      };
+    });
+    manager.getTrace.mockResolvedValue({
+      spans: {},
+      mockInput: 'input',
+      mockOutput: 'output',
+    });
+
+    const dataset = Array.from({ length: 4 }, (_, i) => ({
+      testCaseId: `case-${i}`,
+      input: { value: i },
+    }));
+
+    const start = Date.now();
+    const results: EvalInput[] = await bulkRunAction({
+      manager: manager as any,
+      actionRef: '/flow/test',
+      inferenceDataset: dataset as any,
+      batchSize: 2,
+    });
+    const duration = Date.now() - start;
+
+    expect(results).toHaveLength(4);
+    // With batchSize 2, the total time should be roughly two batches of delayMs.
+    expect(duration).toBeLessThan(delayMs * 4); // faster than fully sequential
+    expect(manager.runAction).toHaveBeenCalledTimes(4);
+  });
+
+  it('continues processing after an error', async () => {
+    const manager = createMockManager();
+    manager.runAction
+      .mockImplementationOnce(async () => {
+        throw new Error('boom');
+      })
+      .mockImplementation(async () => ({
+        result: 'ok',
+        telemetry: { traceId: 'trace' },
+      }));
+    manager.getTrace.mockResolvedValue({
+      spans: {},
+      mockInput: 'input',
+      mockOutput: 'output',
+    });
+
+    const dataset = [
+      { testCaseId: 'case-1', input: {} },
+      { testCaseId: 'case-2', input: {} },
+      { testCaseId: 'case-3', input: {} },
+    ];
+
+    const results: EvalInput[] = await bulkRunAction({
+      manager: manager as any,
+      actionRef: '/flow/test',
+      inferenceDataset: dataset as any,
+      batchSize: 2,
+    });
+
+    expect(results).toHaveLength(3);
+    expect(results.some((r) => r.error)).toBe(true);
+    expect(manager.runAction).toHaveBeenCalledTimes(3);
+  });
+});
+
+describe('runEvaluation', () => {
+  it('executes evaluator actions in parallel', async () => {
+    const manager = createMockManager();
+    let started = 0;
+    let release!: () => void;
+    const gate = new Promise<void>((resolve) => {
+      release = resolve;
+    });
+
+    manager.runAction.mockImplementation(async () => {
+      started++;
+      if (started === 2) {
+        release();
+      }
+      await gate;
+      return { result: { ok: true }, telemetry: { traceId: 'trace' } };
+    });
+
+    const actions = [createAction('a'), createAction('b')];
+    const evalDataset: EvalInput[] = [
+      { testCaseId: 't1', input: 'in', output: 'out', traceIds: ['trace'] },
+    ];
+
+    const evalPromise = evaluate.runEvaluation({
+      manager: manager as any,
+      evaluatorActions: actions,
+      evalDataset,
+    });
+
+    // Give both runAction calls a moment to start and block on the gate.
+    await new Promise((resolve) => setTimeout(resolve, 10));
+    expect(manager.runAction).toHaveBeenCalledTimes(2);
+
+    // Unblock both and finish.
+    release();
+    await evalPromise;
+  });
+});
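
Note (illustration only, not part of the patch): the chunked `Promise.all` loop added to `bulkRunAction` is the core mechanism of this change. It caps in-flight work at `min(max(1, batchSize), MAX_CONCURRENCY)` while keeping results in input order. A minimal standalone sketch of that pattern follows; `runInBatches`, `items`, and `work` are hypothetical names standing in for the dataset and `runSample`, not part of the genkit-tools API.

```ts
// Sketch of the batching pattern used in bulkRunAction above.
async function runInBatches<T, R>(
  items: T[],
  batchSize: number,
  work: (item: T, index: number) => Promise<R>
): Promise<R[]> {
  // Mirror the patch: at least 1 at a time, capped at 100 (MAX_CONCURRENCY).
  const concurrency = Math.min(Math.max(1, batchSize), 100);
  const results: R[] = new Array(items.length);
  for (let i = 0; i < items.length; i += concurrency) {
    const batch = items.slice(i, i + concurrency);
    // Each batch runs in parallel; batches run one after another, so at most
    // `concurrency` calls are in flight at any time.
    await Promise.all(
      batch.map(async (item, offset) => {
        results[i + offset] = await work(item, i + offset);
      })
    );
  }
  return results;
}
```

As with `runSample` in the patch, `work` should catch its own errors if a single failure must not reject the whole batch.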