diff --git a/.changeset/curvy-coins-boil.md b/.changeset/curvy-coins-boil.md new file mode 100644 index 000000000..07f78b1b9 --- /dev/null +++ b/.changeset/curvy-coins-boil.md @@ -0,0 +1,9 @@ +--- +"@livekit/agents": patch +"@livekit/agents-plugin-google": patch +"@livekit/agents-plugin-livekit": patch +"@livekit/agents-plugin-openai": patch +"livekit-agents-examples": patch +--- + +Implement AgentTask feature diff --git a/agents/src/cli.ts b/agents/src/cli.ts index 2cc354f7d..1e53c16c0 100644 --- a/agents/src/cli.ts +++ b/agents/src/cli.ts @@ -77,16 +77,16 @@ const runServer = async (args: CliArgs) => { * ``` */ export const runApp = (opts: ServerOptions) => { + const logLevelOption = (defaultLevel: string) => + new Option('--log-level ', 'Set the logging level') + .choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal']) + .default(defaultLevel) + .env('LOG_LEVEL'); + const program = new Command() .name('agents') .description('LiveKit Agents CLI') .version(version) - .addOption( - new Option('--log-level ', 'Set the logging level') - .choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal']) - .default('info') - .env('LOG_LEVEL'), - ) .addOption( new Option('--url ', 'LiveKit server or Cloud project websocket URL').env( 'LIVEKIT_URL', @@ -120,13 +120,15 @@ export const runApp = (opts: ServerOptions) => { program .command('start') .description('Start the worker in production mode') - .action(() => { - const options = program.optsWithGlobals(); - opts.wsURL = options.url || opts.wsURL; - opts.apiKey = options.apiKey || opts.apiKey; - opts.apiSecret = options.apiSecret || opts.apiSecret; - opts.logLevel = options.logLevel || opts.logLevel; - opts.workerToken = options.workerToken || opts.workerToken; + .addOption(logLevelOption('info')) + .action((...[, command]) => { + const globalOptions = program.optsWithGlobals(); + const commandOptions = command.opts(); + opts.wsURL = globalOptions.url || opts.wsURL; + opts.apiKey = globalOptions.apiKey || opts.apiKey; 
+ opts.apiSecret = globalOptions.apiSecret || opts.apiSecret; + opts.logLevel = commandOptions.logLevel; + opts.workerToken = globalOptions.workerToken || opts.workerToken; runServer({ opts, production: true, @@ -137,19 +139,14 @@ export const runApp = (opts: ServerOptions) => { program .command('dev') .description('Start the worker in development mode') - .addOption( - new Option('--log-level ', 'Set the logging level') - .choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal']) - .default('debug') - .env('LOG_LEVEL'), - ) + .addOption(logLevelOption('debug')) .action((...[, command]) => { const globalOptions = program.optsWithGlobals(); const commandOptions = command.opts(); opts.wsURL = globalOptions.url || opts.wsURL; opts.apiKey = globalOptions.apiKey || opts.apiKey; opts.apiSecret = globalOptions.apiSecret || opts.apiSecret; - opts.logLevel = commandOptions.logLevel || globalOptions.logLevel || opts.logLevel; + opts.logLevel = commandOptions.logLevel; opts.workerToken = globalOptions.workerToken || opts.workerToken; runServer({ opts, @@ -163,19 +160,14 @@ export const runApp = (opts: ServerOptions) => { .description('Connect to a specific room') .requiredOption('--room ', 'Room name to connect to') .option('--participant-identity ', 'Identity of user to listen to') - .addOption( - new Option('--log-level ', 'Set the logging level') - .choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal']) - .default('debug') - .env('LOG_LEVEL'), - ) + .addOption(logLevelOption('info')) .action((...[, command]) => { const globalOptions = program.optsWithGlobals(); const commandOptions = command.opts(); opts.wsURL = globalOptions.url || opts.wsURL; opts.apiKey = globalOptions.apiKey || opts.apiKey; opts.apiSecret = globalOptions.apiSecret || opts.apiSecret; - opts.logLevel = commandOptions.logLevel || globalOptions.logLevel || opts.logLevel; + opts.logLevel = commandOptions.logLevel; opts.workerToken = globalOptions.workerToken || opts.workerToken; runServer({ opts, 
@@ -189,12 +181,7 @@ export const runApp = (opts: ServerOptions) => { program .command('download-files') .description('Download plugin dependency files') - .addOption( - new Option('--log-level ', 'Set the logging level') - .choices(['trace', 'debug', 'info', 'warn', 'error', 'fatal']) - .default('debug') - .env('LOG_LEVEL'), - ) + .addOption(logLevelOption('debug')) .action((...[, command]) => { const commandOptions = command.opts(); initializeLogger({ pretty: true, level: commandOptions.logLevel }); diff --git a/agents/src/ipc/job_proc_lazy_main.ts b/agents/src/ipc/job_proc_lazy_main.ts index 11fe2a0c9..2448614ee 100644 --- a/agents/src/ipc/job_proc_lazy_main.ts +++ b/agents/src/ipc/job_proc_lazy_main.ts @@ -15,6 +15,14 @@ import type { IPCMessage } from './message.js'; const ORPHANED_TIMEOUT = 15 * 1000; +const safeSend = (msg: IPCMessage): boolean => { + if (process.connected && process.send) { + process.send(msg); + return true; + } + return false; +}; + type JobTask = { ctx: JobContext; task: Promise; @@ -50,7 +58,10 @@ class InfClient implements InferenceExecutor { async doInference(method: string, data: unknown): Promise { const requestId = shortuuid('inference_job_'); - process.send!({ case: 'inferenceRequest', value: { requestId, method, data } }); + if (!safeSend({ case: 'inferenceRequest', value: { requestId, method, data } })) { + throw new Error('IPC channel closed'); + } + this.#requests[requestId] = new PendingInference(); const resp = await this.#requests[requestId]!.promise; if (resp.error) { @@ -117,7 +128,7 @@ const startJob = ( await once(closeEvent, 'close').then((close) => { logger.debug('shutting down'); shutdown = true; - process.send!({ case: 'exiting', value: { reason: close[1] } }); + safeSend({ case: 'exiting', value: { reason: close[1] } }); }); // Close the primary agent session if it exists @@ -139,7 +150,7 @@ const startJob = ( logger.error({ error }, 'error while shutting down the job'), ); - process.send!({ case: 'done' }); + 
safeSend({ case: 'done', value: undefined }); joinFuture.resolve(); })(); @@ -199,7 +210,7 @@ const startJob = ( logger.debug('initializing job runner'); await agent.prewarm(proc); logger.debug('job runner initialized'); - process.send({ case: 'initializeResponse' }); + safeSend({ case: 'initializeResponse', value: undefined }); let job: JobTask | undefined = undefined; const closeEvent = new EventEmitter(); @@ -213,7 +224,7 @@ const startJob = ( switch (msg.case) { case 'pingRequest': { orphanedTimeout.refresh(); - process.send!({ + safeSend({ case: 'pongResponse', value: { lastTimestamp: msg.value.timestamp, timestamp: Date.now() }, }); diff --git a/agents/src/llm/chat_context.ts b/agents/src/llm/chat_context.ts index 0f4a644ef..ce36d92c6 100644 --- a/agents/src/llm/chat_context.ts +++ b/agents/src/llm/chat_context.ts @@ -510,6 +510,41 @@ export class ChatContext { return new ChatContext(items); } + merge( + other: ChatContext, + options: { + excludeFunctionCall?: boolean; + excludeInstructions?: boolean; + } = {}, + ): ChatContext { + const { excludeFunctionCall = false, excludeInstructions = false } = options; + const existingIds = new Set(this._items.map((item) => item.id)); + + for (const item of other.items) { + if (excludeFunctionCall && ['function_call', 'function_call_output'].includes(item.type)) { + continue; + } + + if ( + excludeInstructions && + item.type === 'message' && + (item.role === 'system' || item.role === 'developer') + ) { + continue; + } + + if (existingIds.has(item.id)) { + continue; + } + + const idx = this.findInsertionIndex(item.createdAt); + this._items.splice(idx, 0, item); + existingIds.add(item.id); + } + + return this; + } + truncate(maxItems: number): ChatContext { if (maxItems <= 0) return this; diff --git a/agents/src/llm/provider_format/utils.ts b/agents/src/llm/provider_format/utils.ts index 20dd8fe93..dea9e3abe 100644 --- a/agents/src/llm/provider_format/utils.ts +++ b/agents/src/llm/provider_format/utils.ts @@ -56,12 +56,14 
@@ class ChatItemGroup { } removeInvalidToolCalls() { - if (this.toolCalls.length === this.toolOutputs.length) { - return; - } - const toolCallIds = new Set(this.toolCalls.map((call) => call.callId)); const toolOutputIds = new Set(this.toolOutputs.map((output) => output.callId)); + const sameIds = + toolCallIds.size === toolOutputIds.size && + [...toolCallIds].every((id) => toolOutputIds.has(id)); + if (this.toolCalls.length === this.toolOutputs.length && sameIds) { + return; + } // intersection of tool call ids and tool output ids const validCallIds = intersection(toolCallIds, toolOutputIds); diff --git a/agents/src/llm/realtime.ts b/agents/src/llm/realtime.ts index bebeffcf4..5c132afd0 100644 --- a/agents/src/llm/realtime.ts +++ b/agents/src/llm/realtime.ts @@ -48,6 +48,7 @@ export interface RealtimeCapabilities { userTranscription: boolean; autoToolReplyGeneration: boolean; audioOutput: boolean; + manualFunctionCalls: boolean; } export interface InputTranscriptionCompleted { diff --git a/agents/src/stream/deferred_stream.ts b/agents/src/stream/deferred_stream.ts index 71a10c7e8..d1e09b9ce 100644 --- a/agents/src/stream/deferred_stream.ts +++ b/agents/src/stream/deferred_stream.ts @@ -59,16 +59,17 @@ export class DeferredReadableStream { throw new Error('Stream source already set'); } - this.sourceReader = source.getReader(); - this.pump(); + const sourceReader = source.getReader(); + this.sourceReader = sourceReader; + void this.pump(sourceReader); } - private async pump() { + private async pump(sourceReader: ReadableStreamDefaultReader) { let sourceError: unknown; try { while (true) { - const { done, value } = await this.sourceReader!.read(); + const { done, value } = await sourceReader.read(); if (done) break; await this.writer.write(value); } @@ -81,7 +82,7 @@ export class DeferredReadableStream { // any other error from source will be propagated to the consumer if (sourceError) { try { - this.writer.abort(sourceError); + await this.writer.abort(sourceError); 
} catch (e) { // ignore if writer is already closed } @@ -118,10 +119,20 @@ export class DeferredReadableStream { return; } + const sourceReader = this.sourceReader!; + // Clear source first so future setSource() calls can reattach cleanly. + this.sourceReader = undefined; + // release lock will make any pending read() throw TypeError // which are expected, and we intentionally catch those error // using isStreamReaderReleaseError // this will unblock any pending read() inside the async for loop - this.sourceReader!.releaseLock(); + try { + sourceReader.releaseLock(); + } catch (e) { + if (!isStreamReaderReleaseError(e)) { + throw e; + } + } } } diff --git a/agents/src/utils.test.ts b/agents/src/utils.test.ts index 6bab4d642..a44678d08 100644 --- a/agents/src/utils.test.ts +++ b/agents/src/utils.test.ts @@ -469,6 +469,93 @@ describe('utils', () => { expect((error as Error).name).toBe('TypeError'); } }); + + it('should return undefined for Task.current outside task context', () => { + expect(Task.current()).toBeUndefined(); + }); + + it('should preserve Task.current inside a task across awaits', async () => { + const task = Task.from( + async () => { + const currentAtStart = Task.current(); + await delay(5); + const currentAfterAwait = Task.current(); + + expect(currentAtStart).toBeDefined(); + expect(currentAfterAwait).toBe(currentAtStart); + + return currentAtStart; + }, + undefined, + 'current-context-test', + ); + + const currentFromResult = await task.result; + expect(currentFromResult).toBe(task); + }); + + it('should isolate nested Task.current context and restore parent context', async () => { + const parentTask = Task.from( + async (controller) => { + const parentCurrent = Task.current(); + expect(parentCurrent).toBeDefined(); + + const childTask = Task.from( + async () => { + const childCurrentStart = Task.current(); + await delay(5); + const childCurrentAfterAwait = Task.current(); + + expect(childCurrentStart).toBeDefined(); + 
expect(childCurrentAfterAwait).toBe(childCurrentStart); + expect(childCurrentStart).not.toBe(parentCurrent); + + return childCurrentStart; + }, + controller, + 'child-current-context-test', + ); + + const childCurrent = await childTask.result; + const parentCurrentAfterChild = Task.current(); + + expect(parentCurrentAfterChild).toBe(parentCurrent); + + return { parentCurrent, childCurrent }; + }, + undefined, + 'parent-current-context-test', + ); + + const { parentCurrent, childCurrent } = await parentTask.result; + expect(parentCurrent).toBe(parentTask); + expect(childCurrent).not.toBe(parentCurrent); + expect(Task.current()).toBeUndefined(); + }); + + it('should always expose Task.current for concurrent task callbacks', async () => { + const tasks = Array.from({ length: 25 }, (_, idx) => + Task.from( + async () => { + const currentAtStart = Task.current(); + await delay(1); + const currentAfterAwait = Task.current(); + + expect(currentAtStart).toBeDefined(); + expect(currentAfterAwait).toBe(currentAtStart); + + return currentAtStart; + }, + undefined, + `current-context-stress-${idx}`, + ), + ); + + const currentTasks = await Promise.all(tasks.map((task) => task.result)); + currentTasks.forEach((currentTask, idx) => { + expect(currentTask).toBe(tasks[idx]); + }); + }); }); describe('Event', () => { diff --git a/agents/src/utils.ts b/agents/src/utils.ts index 686728333..03202c60f 100644 --- a/agents/src/utils.ts +++ b/agents/src/utils.ts @@ -9,6 +9,7 @@ import type { TrackKind, } from '@livekit/rtc-node'; import { AudioFrame, AudioResampler, RoomEvent } from '@livekit/rtc-node'; +import { AsyncLocalStorage } from 'node:async_hooks'; import { EventEmitter, once } from 'node:events'; import type { ReadableStream } from 'node:stream/web'; import { TransformStream, type TransformStreamDefaultController } from 'node:stream/web'; @@ -434,7 +435,9 @@ export enum TaskResult { * @param T - The type of the task result */ export class Task { + private static readonly 
currentTaskStorage = new AsyncLocalStorage>(); private resultFuture: Future; + private doneCallbacks: Set<() => void> = new Set(); #logger = log(); @@ -444,6 +447,21 @@ export class Task { readonly name?: string, ) { this.resultFuture = new Future(); + void this.resultFuture.await + .then( + () => undefined, + () => undefined, + ) + .finally(() => { + for (const callback of this.doneCallbacks) { + try { + callback(); + } catch (error) { + this.#logger.error({ error }, 'Task done callback failed'); + } + } + this.doneCallbacks.clear(); + }); this.runTask(); } @@ -463,6 +481,13 @@ export class Task { return new Task(fn, abortController, name); } + /** + * Returns the currently running task in this async context, if available. + */ + static current(): Task | undefined { + return Task.currentTaskStorage.getStore(); + } + private async runTask() { const run = async () => { if (this.name) { @@ -471,7 +496,8 @@ export class Task { return await this.fn(this.controller); }; - return run() + return Task.currentTaskStorage + .run(this as Task, run) .then((value) => { this.resultFuture.resolve(value); return value; @@ -543,7 +569,15 @@ export class Task { } addDoneCallback(callback: () => void) { - this.resultFuture.await.finally(callback); + if (this.done) { + queueMicrotask(callback); + return; + } + this.doneCallbacks.add(callback); + } + + removeDoneCallback(callback: () => void) { + this.doneCallbacks.delete(callback); } } diff --git a/agents/src/voice/agent.test.ts b/agents/src/voice/agent.test.ts index cc620e26a..fd5f39183 100644 --- a/agents/src/voice/agent.test.ts +++ b/agents/src/voice/agent.test.ts @@ -1,10 +1,15 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. 
// // SPDX-License-Identifier: Apache-2.0 -import { describe, expect, it } from 'vitest'; +import { describe, expect, it, vi } from 'vitest'; import { z } from 'zod'; import { tool } from '../llm/index.js'; -import { Agent } from './agent.js'; +import { initializeLogger } from '../log.js'; +import { Task } from '../utils.js'; +import { Agent, AgentTask, _setActivityTaskInfo } from './agent.js'; +import { agentActivityStorage } from './agent_activity.js'; + +initializeLogger({ pretty: false, level: 'error' }); describe('Agent', () => { it('should create agent with basic instructions', () => { @@ -77,4 +82,137 @@ describe('Agent', () => { expect(tools1).toEqual(tools2); expect(tools1).toEqual(tools); }); + + it('should require AgentTask to run inside task context', async () => { + class TestTask extends AgentTask { + constructor() { + super({ instructions: 'test task' }); + } + } + + const task = new TestTask(); + await expect(task.run()).rejects.toThrow('must be executed inside a Task context'); + }); + + it('should require AgentTask to run inside inline task context', async () => { + class TestTask extends AgentTask { + constructor() { + super({ instructions: 'test task' }); + } + } + + const task = new TestTask(); + const wrapper = Task.from(async () => { + return await task.run(); + }); + + await expect(wrapper.result).rejects.toThrow( + 'should only be awaited inside function tools or the onEnter/onExit methods of an Agent', + ); + }); + + it('should allow AgentTask run from inline task context', async () => { + class TestTask extends AgentTask { + constructor() { + super({ instructions: 'test task' }); + } + } + + const task = new TestTask(); + const oldAgent = new Agent({ instructions: 'old agent' }); + const mockSession = { + currentAgent: oldAgent, + _globalRunState: undefined, + _updateActivity: async (agent: Agent) => { + if (agent === task) { + task.complete('ok'); + } + }, + }; + + const mockActivity = { + agent: oldAgent, + agentSession: mockSession, + 
_onEnterTask: undefined, + llm: undefined, + close: async () => {}, + }; + + const wrapper = Task.from(async () => { + const currentTask = Task.current(); + if (!currentTask) { + throw new Error('expected task context'); + } + _setActivityTaskInfo(currentTask, { inlineTask: true }); + return await agentActivityStorage.run(mockActivity as any, () => task.run()); + }); + + await expect(wrapper.result).resolves.toBe('ok'); + }); + + it('should require AgentTask to run inside AgentActivity context', async () => { + class TestTask extends AgentTask { + constructor() { + super({ instructions: 'test task' }); + } + } + + const task = new TestTask(); + const wrapper = Task.from(async () => { + const currentTask = Task.current(); + if (!currentTask) { + throw new Error('expected task context'); + } + _setActivityTaskInfo(currentTask, { inlineTask: true }); + return await task.run(); + }); + + await expect(wrapper.result).rejects.toThrow( + 'must be executed inside an AgentActivity context', + ); + }); + + it('should close old activity when current agent changes while AgentTask is pending', async () => { + class TestTask extends AgentTask { + constructor() { + super({ instructions: 'test task' }); + } + } + + const task = new TestTask(); + const oldAgent = new Agent({ instructions: 'old agent' }); + const switchedAgent = new Agent({ instructions: 'switched agent' }); + const closeOldActivity = vi.fn(async () => {}); + + const mockSession = { + currentAgent: oldAgent as Agent, + _globalRunState: undefined, + _updateActivity: async (agent: Agent) => { + if (agent === task) { + mockSession.currentAgent = switchedAgent; + task.complete('ok'); + } + }, + }; + + const mockActivity = { + agent: oldAgent, + agentSession: mockSession, + _onEnterTask: undefined, + llm: undefined, + close: closeOldActivity, + }; + + const wrapper = Task.from(async () => { + const currentTask = Task.current(); + if (!currentTask) { + throw new Error('expected task context'); + } + 
_setActivityTaskInfo(currentTask, { inlineTask: true }); + return await agentActivityStorage.run(mockActivity as any, () => task.run()); + }); + + await expect(wrapper.result).resolves.toBe('ok'); + expect(closeOldActivity).toHaveBeenCalledTimes(1); + }); }); diff --git a/agents/src/voice/agent.ts b/agents/src/voice/agent.ts index 1fb6664c2..06a59e8eb 100644 --- a/agents/src/voice/agent.ts +++ b/agents/src/voice/agent.ts @@ -13,26 +13,71 @@ import { type TTSModelString, } from '../inference/index.js'; import { ReadonlyChatContext } from '../llm/chat_context.js'; -import type { ChatMessage, FunctionCall, RealtimeModel } from '../llm/index.js'; +import type { ChatMessage, FunctionCall } from '../llm/index.js'; import { type ChatChunk, ChatContext, LLM, + RealtimeModel, type ToolChoice, type ToolContext, } from '../llm/index.js'; +import { log } from '../log.js'; import type { STT, SpeechEvent } from '../stt/index.js'; import { StreamAdapter as STTStreamAdapter } from '../stt/index.js'; import { SentenceTokenizer as BasicSentenceTokenizer } from '../tokenize/basic/index.js'; import type { TTS } from '../tts/index.js'; import { SynthesizeStream, StreamAdapter as TTSStreamAdapter } from '../tts/index.js'; import { USERDATA_TIMED_TRANSCRIPT } from '../types.js'; +import { Future, Task } from '../utils.js'; import type { VAD } from '../vad.js'; -import type { AgentActivity } from './agent_activity.js'; +import { type AgentActivity, agentActivityStorage } from './agent_activity.js'; import type { AgentSession, TurnDetectionMode } from './agent_session.js'; import type { TimedString } from './io.js'; +import type { SpeechHandle } from './speech_handle.js'; + +export const functionCallStorage = new AsyncLocalStorage<{ functionCall?: FunctionCall }>(); +export const speechHandleStorage = new AsyncLocalStorage(); +const activityTaskInfoStorage = new WeakMap, _ActivityTaskInfo>(); + +type _ActivityTaskInfo = { + functionCall: FunctionCall | null; + speechHandle: SpeechHandle | 
null; + inlineTask: boolean; +}; + +/** @internal */ +export function _setActivityTaskInfo( + task: Task, + options: { + functionCall?: FunctionCall | null; + speechHandle?: SpeechHandle | null; + inlineTask?: boolean; + }, +): void { + const info = activityTaskInfoStorage.get(task) ?? { + functionCall: null, + speechHandle: null, + inlineTask: false, + }; + + if (Object.hasOwn(options, 'functionCall')) { + info.functionCall = options.functionCall ?? null; + } + if (Object.hasOwn(options, 'speechHandle')) { + info.speechHandle = options.speechHandle ?? null; + } + if (Object.hasOwn(options, 'inlineTask')) { + info.inlineTask = options.inlineTask ?? false; + } + + activityTaskInfoStorage.set(task, info); +} -export const asyncLocalStorage = new AsyncLocalStorage<{ functionCall?: FunctionCall }>(); +/** @internal */ +export function _getActivityTaskInfo(task: Task): _ActivityTaskInfo | undefined { + return activityTaskInfoStorage.get(task); +} export const STOP_RESPONSE_SYMBOL = Symbol('StopResponse'); export class StopResponse extends Error { @@ -268,20 +313,20 @@ export class Agent { throw new Error('sttNode called but no STT node is available'); } - let wrapped_stt = activity.stt; + let wrappedStt = activity.stt; - if (!wrapped_stt.capabilities.streaming) { + if (!wrappedStt.capabilities.streaming) { const vad = agent.vad || activity.vad; if (!vad) { throw new Error( 'STT does not support streaming, add a VAD to the AgentTask/VoiceAgent to enable streaming', ); } - wrapped_stt = new STTStreamAdapter(wrapped_stt, vad); + wrappedStt = new STTStreamAdapter(wrappedStt, vad); } const connOptions = activity.agentSession.connOptions.sttConnOptions; - const stream = wrapped_stt.stream({ connOptions }); + const stream = wrappedStt.stream({ connOptions }); // Set startTimeOffset to provide linear timestamps across reconnections const audioInputStartedAt = @@ -382,14 +427,14 @@ export class Agent { throw new Error('ttsNode called but no TTS node is available'); } - let 
wrapped_tts = activity.tts; + let wrappedTts = activity.tts; if (!activity.tts.capabilities.streaming) { - wrapped_tts = new TTSStreamAdapter(wrapped_tts, new BasicSentenceTokenizer()); + wrappedTts = new TTSStreamAdapter(wrappedTts, new BasicSentenceTokenizer()); } const connOptions = activity.agentSession.connOptions.ttsConnOptions; - const stream = wrapped_tts.stream({ connOptions }); + const stream = wrappedTts.stream({ connOptions }); stream.updateInputStream(text); let cleaned = false; @@ -440,3 +485,137 @@ export class Agent { }, }; } + +export class AgentTask extends Agent { + private started = false; + private future = new Future(); + + #logger = log(); + + get done(): boolean { + return this.future.done; + } + + complete(result: ResultT | Error): void { + if (this.future.done) { + throw new Error(`${this.constructor.name} is already done`); + } + + if (result instanceof Error) { + this.future.reject(result); + } else { + this.future.resolve(result); + } + + const speechHandle = speechHandleStorage.getStore(); + if (speechHandle) { + speechHandle._maybeRunFinalOutput = result; + } + } + + async run(): Promise { + if (this.started) { + throw new Error( + `Task ${this.constructor.name} has already started and cannot be awaited multiple times`, + ); + } + this.started = true; + + const currentTask = Task.current(); + if (!currentTask) { + throw new Error(`${this.constructor.name} must be executed inside a Task context`); + } + + const taskInfo = _getActivityTaskInfo(currentTask); + if (!taskInfo || !taskInfo.inlineTask) { + throw new Error( + `${this.constructor.name} should only be awaited inside function tools or the onEnter/onExit methods of an Agent`, + ); + } + + const speechHandle = speechHandleStorage.getStore(); + const oldActivity = agentActivityStorage.getStore(); + if (!oldActivity) { + throw new Error(`${this.constructor.name} must be executed inside an AgentActivity context`); + } + + currentTask.addDoneCallback(() => { + if (this.future.done) 
return; + + // If the Task finished before the AgentTask was completed, complete the AgentTask with an error. + this.#logger.error(`The Task finished before ${this.constructor.name} was completed.`); + this.complete(new Error(`The Task finished before ${this.constructor.name} was completed.`)); + }); + + const oldAgent = oldActivity.agent; + const session = oldActivity.agentSession; + + const blockedTasks: Task[] = [currentTask]; + const onEnterTask = oldActivity._onEnterTask; + + if (onEnterTask && !onEnterTask.done && onEnterTask !== currentTask) { + blockedTasks.push(onEnterTask); + } + + if ( + taskInfo.functionCall && + oldActivity.llm instanceof RealtimeModel && + !oldActivity.llm.capabilities.manualFunctionCalls + ) { + this.#logger.error( + `Realtime model does not support resuming function calls from chat context, ` + + `using AgentTask inside a function tool may have unexpected behavior.`, + ); + } + + await session._updateActivity(this, { + previousActivity: 'pause', + newActivity: 'start', + blockedTasks, + }); + + let runState = session._globalRunState; + if (speechHandle && runState && !runState.done()) { + // Only unwatch the parent speech handle if there are other handles keeping the run alive. + // When watchedHandleCount is 1 (only the parent), unwatching would drop it to 0 and + // mark the run done prematurely — before function_call_output and assistant message arrive. + if (runState._watchedHandleCount() > 1) { + runState._unwatchHandle(speechHandle); + } + // it is OK to call _markDoneIfNeeded here, the above _updateActivity will call onEnter + // and newly added handles keep the run alive. + runState._markDoneIfNeeded(); + } + + try { + return await this.future.await; + } finally { + // runState could have changed after future resolved + runState = session._globalRunState; + + if (session.currentAgent !== this) { + this.#logger.warn( + `${this.constructor.name} completed, but the agent has changed in the meantime. 
` + + `Ignoring handoff to the previous agent, likely due to AgentSession.updateAgent being invoked.`, + ); + await oldActivity.close(); + } else { + if (speechHandle && runState && !runState.done()) { + runState._watchHandle(speechHandle); + } + + const mergedChatCtx = oldAgent._chatCtx.merge(this._chatCtx, { + excludeFunctionCall: true, + excludeInstructions: true, + }); + oldAgent._chatCtx.items = mergedChatCtx.items; + + await session._updateActivity(oldAgent, { + previousActivity: 'close', + newActivity: 'resume', + waitOnEnter: false, + }); + } + } + } +} diff --git a/agents/src/voice/agent_activity.ts b/agents/src/voice/agent_activity.ts index aa4e46fd9..9310459e5 100644 --- a/agents/src/voice/agent_activity.ts +++ b/agents/src/voice/agent_activity.ts @@ -35,7 +35,7 @@ import type { TTSMetrics, VADMetrics, } from '../metrics/base.js'; -import { DeferredReadableStream } from '../stream/deferred_stream.js'; +import { MultiInputStream } from '../stream/multi_input_stream.js'; import { STT, type STTError, type SpeechEvent } from '../stt/stt.js'; import { recordRealtimeMetrics, traceTypes, tracer } from '../telemetry/index.js'; import { splitWords } from '../tokenize/basic/word.js'; @@ -43,7 +43,13 @@ import { TTS, type TTSError } from '../tts/tts.js'; import { Future, Task, cancelAndWait, waitFor } from '../utils.js'; import { VAD, type VADEvent } from '../vad.js'; import type { Agent, ModelSettings } from './agent.js'; -import { StopResponse, asyncLocalStorage } from './agent.js'; +import { + StopResponse, + _getActivityTaskInfo, + _setActivityTaskInfo, + functionCallStorage, + speechHandleStorage, +} from './agent.js'; import { type AgentSession, type TurnDetectionMode } from './agent_session.js'; import { AudioRecognition, @@ -60,7 +66,7 @@ import { createSpeechCreatedEvent, createUserInputTranscribedEvent, } from './events.js'; -import type { ToolExecutionOutput, _TTSGenerationData } from './generation.js'; +import type { ToolExecutionOutput, ToolOutput, 
_TTSGenerationData } from './generation.js'; import { type _AudioOut, type _TextOut, @@ -76,7 +82,7 @@ import type { TimedString } from './io.js'; import { SpeechHandle } from './speech_handle.js'; import { setParticipantSpanAttributes } from './utils.js'; -const speechHandleStorage = new AsyncLocalStorage(); +export const agentActivityStorage = new AsyncLocalStorage(); interface PreemptiveGeneration { speechHandle: SpeechHandle; @@ -89,31 +95,47 @@ interface PreemptiveGeneration { } export class AgentActivity implements RecognitionHooks { + agent: Agent; + agentSession: AgentSession; + private static readonly REPLY_TASK_CANCEL_TIMEOUT = 5000; + private started = false; private audioRecognition?: AudioRecognition; private realtimeSession?: RealtimeSession; private realtimeSpans?: Map; // Maps response_id to OTEL span for metrics recording private turnDetectionMode?: Exclude; private logger = log(); - private _draining = false; + private _schedulingPaused = true; + private _drainBlockedTasks: Task[] = []; private _currentSpeech?: SpeechHandle; private speechQueue: Heap<[number, number, SpeechHandle]>; // [priority, timestamp, speechHandle] private q_updated: Future; private speechTasks: Set> = new Set(); private lock = new Mutex(); - private audioStream = new DeferredReadableStream(); + private audioStream = new MultiInputStream(); + private audioStreamId?: string; + // default to null as None, which maps to the default provider tool choice value private toolChoice: ToolChoice | null = null; private _preemptiveGeneration?: PreemptiveGeneration; - agent: Agent; - agentSession: AgentSession; - /** @internal */ _mainTask?: Task; - _userTurnCompletedTask?: Promise; - + _onEnterTask?: Task; + _onExitTask?: Task; + _userTurnCompletedTask?: Task; + + private readonly onRealtimeGenerationCreated = (ev: GenerationCreatedEvent) => + this.onGenerationCreated(ev); + private readonly onRealtimeInputSpeechStarted = (ev: InputSpeechStartedEvent) => + this.onInputSpeechStarted(ev); 
+ private readonly onRealtimeInputSpeechStopped = (ev: InputSpeechStoppedEvent) => + this.onInputSpeechStopped(ev); + private readonly onRealtimeInputAudioTranscriptionCompleted = (ev: InputTranscriptionCompleted) => + this.onInputAudioTranscriptionCompleted(ev); + private readonly onModelError = (ev: RealtimeModelError | STTError | TTSError | LLMError) => + this.onError(ev); constructor(agent: Agent, agentSession: AgentSession) { this.agent = agent; this.agentSession = agentSession; @@ -133,7 +155,7 @@ export class AgentActivity implements RecognitionHooks { if (this.turnDetectionMode === 'vad' && this.vad === undefined) { this.logger.warn( - 'turnDetection is set to "vad", but no VAD model is provided, ignoring the turnDdetection setting', + 'turnDetection is set to "vad", but no VAD model is provided, ignoring the turnDetection setting', ); this.turnDetectionMode = undefined; } @@ -211,120 +233,138 @@ export class AgentActivity implements RecognitionHooks { async start(): Promise { const unlock = await this.lock.lock(); try { - // Create start_agent_activity as a ROOT span (new trace) to match Python behavior - const startSpan = tracer.startSpan({ - name: 'start_agent_activity', - attributes: { [traceTypes.ATTR_AGENT_LABEL]: this.agent.id }, - context: ROOT_CONTEXT, - }); + await this._startSession({ spanName: 'start_agent_activity', runOnEnter: true }); + } finally { + unlock(); + } + } - this.agent._agentActivity = this; + async resume(): Promise { + const unlock = await this.lock.lock(); + try { + await this._startSession({ spanName: 'resume_agent_activity', runOnEnter: false }); + } finally { + unlock(); + } + } - if (this.llm instanceof RealtimeModel) { - this.realtimeSession = this.llm.session(); - this.realtimeSpans = new Map(); - this.realtimeSession.on('generation_created', (ev) => this.onGenerationCreated(ev)); - this.realtimeSession.on('input_speech_started', (ev) => this.onInputSpeechStarted(ev)); - this.realtimeSession.on('input_speech_stopped', 
(ev) => this.onInputSpeechStopped(ev)); - this.realtimeSession.on('input_audio_transcription_completed', (ev) => - this.onInputAudioTranscriptionCompleted(ev), - ); - this.realtimeSession.on('metrics_collected', (ev) => this.onMetricsCollected(ev)); - this.realtimeSession.on('error', (ev) => this.onError(ev)); - - removeInstructions(this.agent._chatCtx); - try { - await this.realtimeSession.updateInstructions(this.agent.instructions); - } catch (error) { - this.logger.error(error, 'failed to update the instructions'); - } + private async _startSession(options: { + spanName: 'start_agent_activity' | 'resume_agent_activity'; + runOnEnter: boolean; + }): Promise { + const { spanName, runOnEnter } = options; + const startSpan = tracer.startSpan({ + name: spanName, + attributes: { [traceTypes.ATTR_AGENT_LABEL]: this.agent.id }, + context: ROOT_CONTEXT, + }); - try { - await this.realtimeSession.updateChatCtx(this.agent.chatCtx); - } catch (error) { - this.logger.error(error, 'failed to update the chat context'); - } + this.agent._agentActivity = this; - try { - await this.realtimeSession.updateTools(this.tools); - } catch (error) { - this.logger.error(error, 'failed to update the tools'); - } + if (this.llm instanceof RealtimeModel) { + this.realtimeSession = this.llm.session(); + this.realtimeSpans = new Map(); + this.realtimeSession.on('generation_created', this.onRealtimeGenerationCreated); + this.realtimeSession.on('input_speech_started', this.onRealtimeInputSpeechStarted); + this.realtimeSession.on('input_speech_stopped', this.onRealtimeInputSpeechStopped); + this.realtimeSession.on( + 'input_audio_transcription_completed', + this.onRealtimeInputAudioTranscriptionCompleted, + ); + this.realtimeSession.on('metrics_collected', this.onMetricsCollected); + this.realtimeSession.on('error', this.onModelError); - if (!this.llm.capabilities.audioOutput && !this.tts && this.agentSession.output.audio) { - this.logger.error( - 'audio output is enabled but RealtimeModel has no 
audio modality ' + - 'and no TTS is set. Either enable audio modality in the RealtimeModel ' + - 'or set a TTS model.', - ); - } - } else if (this.llm instanceof LLM) { - try { - updateInstructions({ - chatCtx: this.agent._chatCtx, - instructions: this.agent.instructions, - addIfMissing: true, - }); - } catch (error) { - this.logger.error('failed to update the instructions', error); - } + removeInstructions(this.agent._chatCtx); + try { + await this.realtimeSession.updateInstructions(this.agent.instructions); + } catch (error) { + this.logger.error(error, 'failed to update the instructions'); } - // metrics and error handling - if (this.llm instanceof LLM) { - this.llm.on('metrics_collected', (ev) => this.onMetricsCollected(ev)); - this.llm.on('error', (ev) => this.onError(ev)); + try { + await this.realtimeSession.updateChatCtx(this.agent.chatCtx); + } catch (error) { + this.logger.error(error, 'failed to update the chat context'); } - if (this.stt instanceof STT) { - this.stt.on('metrics_collected', (ev) => this.onMetricsCollected(ev)); - this.stt.on('error', (ev) => this.onError(ev)); + try { + await this.realtimeSession.updateTools(this.tools); + } catch (error) { + this.logger.error(error, 'failed to update the tools'); } - if (this.tts instanceof TTS) { - this.tts.on('metrics_collected', (ev) => this.onMetricsCollected(ev)); - this.tts.on('error', (ev) => this.onError(ev)); + if (!this.llm.capabilities.audioOutput && !this.tts && this.agentSession.output.audio) { + this.logger.error( + 'audio output is enabled but RealtimeModel has no audio modality ' + + 'and no TTS is set. 
Either enable audio modality in the RealtimeModel ' + + 'or set a TTS model.', + ); } - - if (this.vad instanceof VAD) { - this.vad.on('metrics_collected', (ev) => this.onMetricsCollected(ev)); + } else if (this.llm instanceof LLM) { + try { + updateInstructions({ + chatCtx: this.agent._chatCtx, + instructions: this.agent.instructions, + addIfMissing: true, + }); + } catch (error) { + this.logger.error('failed to update the instructions', error); } + } - this.audioRecognition = new AudioRecognition({ - recognitionHooks: this, - // Disable stt node if stt is not provided - stt: this.stt ? (...args) => this.agent.sttNode(...args) : undefined, - vad: this.vad, - turnDetector: typeof this.turnDetection === 'string' ? undefined : this.turnDetection, - turnDetectionMode: this.turnDetectionMode, - minEndpointingDelay: this.agentSession.options.minEndpointingDelay, - maxEndpointingDelay: this.agentSession.options.maxEndpointingDelay, - rootSpanContext: this.agentSession.rootSpanContext, - sttModel: this.stt?.label, - sttProvider: this.getSttProvider(), - getLinkedParticipant: () => this.agentSession._roomIO?.linkedParticipant, - }); - this.audioRecognition.start(); - this.started = true; + // metrics and error handling + if (this.llm instanceof LLM) { + this.llm.on('metrics_collected', this.onMetricsCollected); + this.llm.on('error', this.onModelError); + } - this._mainTask = Task.from(({ signal }) => this.mainTask(signal)); + if (this.stt instanceof STT) { + this.stt.on('metrics_collected', this.onMetricsCollected); + this.stt.on('error', this.onModelError); + } - // Create on_enter as a child of start_agent_activity in the new trace - const onEnterTask = tracer.startActiveSpan(async () => this.agent.onEnter(), { - name: 'on_enter', - context: trace.setSpan(ROOT_CONTEXT, startSpan), - attributes: { [traceTypes.ATTR_AGENT_LABEL]: this.agent.id }, - }); + if (this.tts instanceof TTS) { + this.tts.on('metrics_collected', this.onMetricsCollected); + this.tts.on('error', 
this.onModelError); + } - this.createSpeechTask({ - task: Task.from(() => onEnterTask), + if (this.vad instanceof VAD) { + this.vad.on('metrics_collected', this.onMetricsCollected); + } + + this.audioRecognition = new AudioRecognition({ + recognitionHooks: this, + // Disable stt node if stt is not provided + stt: this.stt ? (...args) => this.agent.sttNode(...args) : undefined, + vad: this.vad, + turnDetector: typeof this.turnDetection === 'string' ? undefined : this.turnDetection, + turnDetectionMode: this.turnDetectionMode, + minEndpointingDelay: this.agentSession.options.minEndpointingDelay, + maxEndpointingDelay: this.agentSession.options.maxEndpointingDelay, + rootSpanContext: this.agentSession.rootSpanContext, + sttModel: this.stt?.label, + sttProvider: this.getSttProvider(), + getLinkedParticipant: () => this.agentSession._roomIO?.linkedParticipant, + }); + this.audioRecognition.start(); + this.started = true; + + this._resumeSchedulingTask(); + + if (runOnEnter) { + this._onEnterTask = this.createSpeechTask({ + taskFn: () => + tracer.startActiveSpan(async () => this.agent.onEnter(), { + name: 'on_enter', + context: trace.setSpan(ROOT_CONTEXT, startSpan), + attributes: { [traceTypes.ATTR_AGENT_LABEL]: this.agent.id }, + }), + inlineTask: true, name: 'AgentActivity_onEnter', }); - - startSpan.end(); - } finally { - unlock(); } + + startSpan.end(); } get currentSpeech(): SpeechHandle | undefined { @@ -362,8 +402,8 @@ export class AgentActivity implements RecognitionHooks { return this.agent.toolCtx; } - get draining(): boolean { - return this._draining; + get schedulingPaused(): boolean { + return this._schedulingPaused; } get realtimeLLMSession(): RealtimeSession | undefined { @@ -417,18 +457,10 @@ export class AgentActivity implements RecognitionHooks { } attachAudioInput(audioStream: ReadableStream): void { - if (this.audioStream.isSourceSet) { - this.logger.debug('detaching existing audio input in agent activity'); - this.audioStream.detachSource(); - } + 
void this.audioStream.close(); + this.audioStream = new MultiInputStream(); - /** - * We need to add a deferred ReadableStream layer on top of the audioStream from the agent session. - * The tee() operation should be applied to the deferred stream, not the original audioStream. - * This is important because teeing the original stream directly makes it very difficult—if not - * impossible—to implement stream unlock logic cleanly. - */ - this.audioStream.setSource(audioStream); + this.audioStreamId = this.audioStream.addInputStream(audioStream); const [realtimeAudioStream, recognitionAudioStream] = this.audioStream.stream.tee(); if (this.realtimeSession) { @@ -441,16 +473,29 @@ export class AgentActivity implements RecognitionHooks { } detachAudioInput(): void { - this.audioStream.detachSource(); + if (this.audioStreamId === undefined) { + return; + } + + void this.audioStream.close(); + this.audioStream = new MultiInputStream(); + this.audioStreamId = undefined; } - commitUserTurn() { + commitUserTurn( + options: { + audioDetached?: boolean; + throwIfNotReady?: boolean; + } = {}, + ) { + const { audioDetached = false, throwIfNotReady = true } = options; if (!this.audioRecognition) { - throw new Error('AudioRecognition is not initialized'); + if (throwIfNotReady) { + throw new Error('AudioRecognition is not initialized'); + } + return; } - // TODO(brian): add audio_detached flag - const audioDetached = false; this.audioRecognition.commitUserTurn(audioDetached); } @@ -508,14 +553,13 @@ export class AgentActivity implements RecognitionHooks { }), ); const task = this.createSpeechTask({ - task: Task.from((abortController: AbortController) => + taskFn: (abortController: AbortController) => this.ttsTask(handle, text, addToChatCtx, {}, abortController, audio), - ), ownedSpeechHandle: handle, name: 'AgentActivity.say_tts', }); - task.finally(() => this.onPipelineReplyDone()); + task.result.finally(() => this.onPipelineReplyDone()); this.scheduleSpeech(handle, 
SpeechHandle.SPEECH_PRIORITY_NORMAL); return handle; } @@ -628,9 +672,9 @@ export class AgentActivity implements RecognitionHooks { return; } - if (this.draining) { + if (this.schedulingPaused) { // TODO(shubhra): should we "forward" this new turn to the next agent? - this.logger.warn('skipping new realtime generation, the agent is draining'); + this.logger.warn('skipping new realtime generation, the speech scheduling is not running'); return; } @@ -648,9 +692,8 @@ export class AgentActivity implements RecognitionHooks { this.logger.info({ speech_id: handle.id }, 'Creating speech handle'); this.createSpeechTask({ - task: Task.from((abortController: AbortController) => + taskFn: (abortController: AbortController) => this.realtimeGenerationTask(handle, ev, {}, abortController), - ), ownedSpeechHandle: handle, name: 'AgentActivity.realtimeGeneration', }); @@ -782,7 +825,7 @@ export class AgentActivity implements RecognitionHooks { onPreemptiveGeneration(info: PreemptiveGenerationInfo): void { if ( !this.agentSession.options.preemptiveGeneration || - this.draining || + this.schedulingPaused || (this._currentSpeech !== undefined && !this._currentSpeech.interrupted) || !(this.llm instanceof LLM) ) { @@ -829,11 +872,32 @@ export class AgentActivity implements RecognitionHooks { } private createSpeechTask(options: { - task: Task; + taskFn: (controller: AbortController) => Promise; + controller?: AbortController; ownedSpeechHandle?: SpeechHandle; + inlineTask?: boolean; name?: string; - }): Promise { - const { task, ownedSpeechHandle } = options; + }): Task { + const { taskFn, controller, ownedSpeechHandle, inlineTask, name } = options; + + const wrappedFn = (ctrl: AbortController) => { + return agentActivityStorage.run(this, () => { + // Mark inline/speech metadata at task runtime to avoid a race where taskFn executes + // before post-construction metadata is attached to the Task instance. 
+ const currentTask = Task.current(); + if (currentTask) { + _setActivityTaskInfo(currentTask, { speechHandle: ownedSpeechHandle, inlineTask }); + } + + if (ownedSpeechHandle) { + return speechHandleStorage.run(ownedSpeechHandle, () => taskFn(ctrl)); + } + return taskFn(ctrl); + }); + }; + + const task = Task.from(wrappedFn, controller, name); + _setActivityTaskInfo(task, { speechHandle: ownedSpeechHandle, inlineTask }); this.speechTasks.add(task); task.addDoneCallback(() => { @@ -853,13 +917,16 @@ export class AgentActivity implements RecognitionHooks { this.wakeupMainTask(); }); - return task.result; + return task; } async onEndOfTurn(info: EndOfTurnInfo): Promise { - if (this.draining) { + if (this.schedulingPaused) { this.cancelPreemptiveGeneration(); - this.logger.warn({ user_input: info.newTranscript }, 'skipping user input, task is draining'); + this.logger.warn( + { user_input: info.newTranscript }, + 'skipping user input, speech scheduling is paused', + ); // TODO(shubhra): should we "forward" this new turn to the next agent/activity? return true; } @@ -892,7 +959,7 @@ export class AgentActivity implements RecognitionHooks { const oldTask = this._userTurnCompletedTask; this._userTurnCompletedTask = this.createSpeechTask({ - task: Task.from(() => this.userTurnCompleted(info, oldTask)), + taskFn: () => this.userTurnCompleted(info, oldTask), name: 'AgentActivity.userTurnCompleted', }); return true; @@ -928,10 +995,12 @@ export class AgentActivity implements RecognitionHooks { this._currentSpeech = undefined; } - // If we're draining and there are no more speech tasks, we can exit. - // Only speech tasks can bypass draining to create a tool response - if (this.draining && this.speechTasks.size === 0) { - this.logger.info('mainTask: draining and no more speech tasks'); + // if we're draining/pausing and there are no more speech tasks, we can exit. 
+ // only speech tasks can bypass draining to create a tool response (see scheduleSpeech) + const toWait = this.getDrainPendingSpeechTasks(); + + if (this._schedulingPaused && toWait.length === 0) { + this.logger.info('mainTask: scheduling paused and no more speech tasks to wait'); break; } @@ -941,6 +1010,39 @@ export class AgentActivity implements RecognitionHooks { this.logger.info('AgentActivity mainTask: exiting'); } + private getDrainPendingSpeechTasks(): Task[] { + const blockedHandles: SpeechHandle[] = []; + + for (const task of this._drainBlockedTasks) { + const info = _getActivityTaskInfo(task); + if (!info) { + this.logger.error('blocked task without activity info; skipping.'); + continue; + } + + if (!info.speechHandle) { + continue; // onEnter/onExit + } + + blockedHandles.push(info.speechHandle); + } + + const toWait: Task[] = []; + for (const task of this.speechTasks) { + if (this._drainBlockedTasks.includes(task)) { + continue; + } + + const info = _getActivityTaskInfo(task); + if (info && info.speechHandle && blockedHandles.includes(info.speechHandle)) { + continue; + } + + toWait.push(task); + } + return toWait; + } + private wakeupMainTask(): void { this.q_updated.resolve(); } @@ -982,7 +1084,7 @@ export class AgentActivity implements RecognitionHooks { throw new Error('trying to generate reply without an LLM model'); } - const functionCall = asyncLocalStorage.getStore()?.functionCall; + const functionCall = functionCallStorage.getStore()?.functionCall; if (toolChoice === undefined && functionCall !== undefined) { // when generateReply is called inside a tool, set toolChoice to 'none' by default toolChoice = 'none'; @@ -1004,7 +1106,7 @@ export class AgentActivity implements RecognitionHooks { if (this.llm instanceof RealtimeModel) { this.createSpeechTask({ - task: Task.from((abortController: AbortController) => + taskFn: (abortController: AbortController) => this.realtimeReplyTask({ speechHandle: handle, // TODO(brian): support llm.ChatMessage 
for the realtime model @@ -1016,7 +1118,6 @@ export class AgentActivity implements RecognitionHooks { }, abortController, }), - ), ownedSpeechHandle: handle, name: 'AgentActivity.realtimeReply', }); @@ -1029,7 +1130,7 @@ export class AgentActivity implements RecognitionHooks { } const task = this.createSpeechTask({ - task: Task.from((abortController: AbortController) => + taskFn: (abortController: AbortController) => this.pipelineReplyTask( handle, chatCtx ?? this.agent.chatCtx, @@ -1041,12 +1142,11 @@ export class AgentActivity implements RecognitionHooks { instructions, userMessage, ), - ), ownedSpeechHandle: handle, name: 'AgentActivity.pipelineReply', }); - task.finally(() => this.onPipelineReplyDone()); + task.result.finally(() => this.onPipelineReplyDone()); } if (scheduleSpeech) { @@ -1055,16 +1155,19 @@ export class AgentActivity implements RecognitionHooks { return handle; } - interrupt(): Future { + interrupt(options: { force?: boolean } = {}): Future { + const { force = false } = options; + this.cancelPreemptiveGeneration(); + const future = new Future(); const currentSpeech = this._currentSpeech; //TODO(AJS-273): add interrupt for background speeches - currentSpeech?.interrupt(); + currentSpeech?.interrupt(force); for (const [_, __, speech] of this.speechQueue) { - speech.interrupt(); + speech.interrupt(force); } this.realtimeSession?.interrupt(); @@ -1087,13 +1190,13 @@ export class AgentActivity implements RecognitionHooks { } } - private async userTurnCompleted(info: EndOfTurnInfo, oldTask?: Promise): Promise { + private async userTurnCompleted(info: EndOfTurnInfo, oldTask?: Task): Promise { if (oldTask) { // We never cancel user code as this is very confusing. // So we wait for the old execution of onUserTurnCompleted to finish. // In practice this is OK because most speeches will be interrupted if a new turn // is detected. So the previous execution should complete quickly. 
- await oldTask; + await oldTask.result; } // When the audio recognition detects the end of a user turn: @@ -1551,13 +1654,15 @@ export class AgentActivity implements RecognitionHooks { for (const msg of toolsMessages) { msg.createdAt = replyStartedAt; } - this.agent._chatCtx.insert(toolsMessages); - // Only add FunctionCallOutput items to session history since FunctionCall items - // were already added by onToolExecutionStarted when the tool execution began + // Only insert FunctionCallOutput items into agent._chatCtx since FunctionCall items + // were already added by onToolExecutionStarted when the tool execution began. + // Inserting function_calls again would create duplicates that break provider APIs + // (e.g. Google's "function response parts != function call parts" error). const toolCallOutputs = toolsMessages.filter( (m): m is FunctionCallOutput => m.type === 'function_call_output', ); if (toolCallOutputs.length > 0) { + this.agent._chatCtx.insert(toolCallOutputs); this.agentSession._toolItemsAdded(toolCallOutputs); } } @@ -1665,52 +1770,18 @@ export class AgentActivity implements RecognitionHooks { return; } - const functionToolsExecutedEvent = createFunctionToolsExecutedEvent({ - functionCalls: [], - functionCallOutputs: [], - }); - let shouldGenerateToolReply: boolean = false; - let newAgentTask: Agent | null = null; - let ignoreTaskSwitch: boolean = false; - - for (const sanitizedOut of toolOutput.output) { - if (sanitizedOut.toolCallOutput !== undefined) { - functionToolsExecutedEvent.functionCalls.push(sanitizedOut.toolCall); - functionToolsExecutedEvent.functionCallOutputs.push(sanitizedOut.toolCallOutput); - if (sanitizedOut.replyRequired) { - shouldGenerateToolReply = true; - } - } - - if (newAgentTask !== null && sanitizedOut.agentTask !== undefined) { - this.logger.error('expected to receive only one agent task from the tool executions'); - ignoreTaskSwitch = true; - // TODO(brian): should we mark the function call as failed to notify the LLM? 
- } - - newAgentTask = sanitizedOut.agentTask ?? null; - - this.logger.debug( - { - speechId: speechHandle.id, - name: sanitizedOut.toolCall?.name, - args: sanitizedOut.toolCall.args, - output: sanitizedOut.toolCallOutput?.output, - isError: sanitizedOut.toolCallOutput?.isError, - }, - 'Tool call execution finished', - ); - } + const { functionToolsExecutedEvent, shouldGenerateToolReply, newAgentTask, ignoreTaskSwitch } = + this.summarizeToolExecutionOutput(toolOutput, speechHandle); this.agentSession.emit( AgentSessionEventTypes.FunctionToolsExecuted, functionToolsExecutedEvent, ); - let draining = this.draining; + let schedulingPaused = this.schedulingPaused; if (!ignoreTaskSwitch && newAgentTask !== null) { this.agentSession.updateAgent(newAgentTask); - draining = true; + schedulingPaused = true; } const toolMessages = [ @@ -1725,11 +1796,12 @@ export class AgentActivity implements RecognitionHooks { // Avoid setting tool_choice to "required" or a specific function when // passing tool response back to the LLM - const respondToolChoice = draining || modelSettings.toolChoice === 'none' ? 'none' : 'auto'; + const respondToolChoice = + schedulingPaused || modelSettings.toolChoice === 'none' ? 
'none' : 'auto'; // Reuse same speechHandle for tool response (parity with Python agent_activity.py L2122-2140) const toolResponseTask = this.createSpeechTask({ - task: Task.from(() => + taskFn: () => this.pipelineReplyTask( speechHandle, chatCtx, @@ -1740,12 +1812,11 @@ export class AgentActivity implements RecognitionHooks { undefined, toolMessages, ), - ), ownedSpeechHandle: speechHandle, name: 'AgentActivity.pipelineReply', }); - toolResponseTask.finally(() => this.onPipelineReplyDone()); + toolResponseTask.result.finally(() => this.onPipelineReplyDone()); this.scheduleSpeech(speechHandle, SpeechHandle.SPEECH_PRIORITY_NORMAL, true); } else if (functionToolsExecutedEvent.functionCallOutputs.length > 0) { @@ -1753,15 +1824,12 @@ export class AgentActivity implements RecognitionHooks { msg.createdAt = replyStartedAt; } - this.agent._chatCtx.insert(toolMessages); - - // Only add FunctionCallOutput items to session history since FunctionCall items - // were already added by onToolExecutionStarted when the tool execution began const toolCallOutputs = toolMessages.filter( (m): m is FunctionCallOutput => m.type === 'function_call_output', ); if (toolCallOutputs.length > 0) { + this.agent._chatCtx.insert(toolCallOutputs); this.agentSession._toolItemsAdded(toolCallOutputs); } } @@ -2164,50 +2232,18 @@ export class AgentActivity implements RecognitionHooks { return; } - const functionToolsExecutedEvent = createFunctionToolsExecutedEvent({ - functionCalls: [], - functionCallOutputs: [], - }); - let shouldGenerateToolReply: boolean = false; - let newAgentTask: Agent | null = null; - let ignoreTaskSwitch: boolean = false; - - for (const sanitizedOut of toolOutput.output) { - if (sanitizedOut.toolCallOutput !== undefined) { - functionToolsExecutedEvent.functionCallOutputs.push(sanitizedOut.toolCallOutput); - if (sanitizedOut.replyRequired) { - shouldGenerateToolReply = true; - } - } - - if (newAgentTask !== null && sanitizedOut.agentTask !== undefined) { - 
this.logger.error('expected to receive only one agent task from the tool executions'); - ignoreTaskSwitch = true; - } - - newAgentTask = sanitizedOut.agentTask ?? null; - - this.logger.debug( - { - speechId: speechHandle.id, - name: sanitizedOut.toolCall?.name, - args: sanitizedOut.toolCall.args, - output: sanitizedOut.toolCallOutput?.output, - isError: sanitizedOut.toolCallOutput?.isError, - }, - 'Tool call execution finished', - ); - } + const { functionToolsExecutedEvent, shouldGenerateToolReply, newAgentTask, ignoreTaskSwitch } = + this.summarizeToolExecutionOutput(toolOutput, speechHandle); this.agentSession.emit( AgentSessionEventTypes.FunctionToolsExecuted, functionToolsExecutedEvent, ); - let draining = this.draining; + let schedulingPaused = this.schedulingPaused; if (!ignoreTaskSwitch && newAgentTask !== null) { this.agentSession.updateAgent(newAgentTask); - draining = true; + schedulingPaused = true; } if (functionToolsExecutedEvent.functionCallOutputs.length > 0) { @@ -2263,15 +2299,14 @@ export class AgentActivity implements RecognitionHooks { }), ); - const toolChoice = draining || modelSettings.toolChoice === 'none' ? 'none' : 'auto'; + const toolChoice = schedulingPaused || modelSettings.toolChoice === 'none' ? 
'none' : 'auto'; this.createSpeechTask({ - task: Task.from((abortController: AbortController) => + taskFn: (abortController: AbortController) => this.realtimeReplyTask({ speechHandle: replySpeechHandle, modelSettings: { toolChoice }, abortController, }), - ), ownedSpeechHandle: replySpeechHandle, name: 'AgentActivity.realtime_reply', }); @@ -2279,6 +2314,53 @@ export class AgentActivity implements RecognitionHooks { this.scheduleSpeech(replySpeechHandle, SpeechHandle.SPEECH_PRIORITY_NORMAL, true); } + private summarizeToolExecutionOutput(toolOutput: ToolOutput, speechHandle: SpeechHandle) { + const functionToolsExecutedEvent = createFunctionToolsExecutedEvent({ + functionCalls: [], + functionCallOutputs: [], + }); + + let shouldGenerateToolReply = false; + let newAgentTask: Agent | null = null; + let ignoreTaskSwitch = false; + + for (const sanitizedOut of toolOutput.output) { + if (sanitizedOut.toolCallOutput !== undefined) { + // Keep event payload symmetric for pipeline + realtime paths. + functionToolsExecutedEvent.functionCalls.push(sanitizedOut.toolCall); + functionToolsExecutedEvent.functionCallOutputs.push(sanitizedOut.toolCallOutput); + if (sanitizedOut.replyRequired) { + shouldGenerateToolReply = true; + } + } + + if (newAgentTask !== null && sanitizedOut.agentTask !== undefined) { + this.logger.error('expected to receive only one agent task from the tool executions'); + ignoreTaskSwitch = true; + } + + newAgentTask = sanitizedOut.agentTask ?? 
null; + + this.logger.debug( + { + speechId: speechHandle.id, + name: sanitizedOut.toolCall?.name, + args: sanitizedOut.toolCall.args, + output: sanitizedOut.toolCallOutput?.output, + isError: sanitizedOut.toolCallOutput?.isError, + }, + 'Tool call execution finished', + ); + } + + return { + functionToolsExecutedEvent, + shouldGenerateToolReply, + newAgentTask, + ignoreTaskSwitch, + }; + } + private async realtimeReplyTask({ speechHandle, modelSettings: { toolChoice }, @@ -2337,10 +2419,10 @@ export class AgentActivity implements RecognitionHooks { priority: number, force: boolean = false, ): void { - // when force=true, we allow tool responses to bypass draining + // when force=true, we allow tool responses to bypass scheduling pause // This allows for tool responses to be generated before the AgentActivity is finalized - if (this.draining && !force) { - throw new Error('cannot schedule new speech, the agent is draining'); + if (this.schedulingPaused && !force) { + throw new Error('cannot schedule new speech, the speech scheduling is draining/pausing'); } // Monotonic time to avoid near 0 collisions @@ -2349,6 +2431,48 @@ export class AgentActivity implements RecognitionHooks { this.wakeupMainTask(); } + private async _pauseSchedulingTask(blockedTasks: Task[]): Promise { + if (this._schedulingPaused) return; + + this._schedulingPaused = true; + this._drainBlockedTasks = blockedTasks; + this.wakeupMainTask(); + + if (this._mainTask) { + // When pausing/draining, we ensure that all speech_tasks complete fully. 
+ // This means that even if the SpeechHandle themselves have finished, + // we still wait for the entire execution (e.g function_tools) + await this._mainTask.result; + } + } + + private _resumeSchedulingTask(): void { + if (!this._schedulingPaused) return; + + this._schedulingPaused = false; + this._mainTask = Task.from(({ signal }) => this.mainTask(signal)); + } + + async pause(options: { blockedTasks?: Task[] } = {}): Promise { + const { blockedTasks = [] } = options; + const unlock = await this.lock.lock(); + + try { + const span = tracer.startSpan({ + name: 'pause_agent_activity', + attributes: { [traceTypes.ATTR_AGENT_LABEL]: this.agent.id }, + }); + try { + await this._pauseSchedulingTask(blockedTasks); + await this._closeSessionResources(); + } finally { + span.end(); + } + } finally { + unlock(); + } + } + async drain(): Promise { // Create drain_agent_activity as a ROOT span (new trace) to match Python behavior return tracer.startActiveSpan(async (span) => this._drainImpl(span), { @@ -2362,23 +2486,22 @@ export class AgentActivity implements RecognitionHooks { const unlock = await this.lock.lock(); try { - if (this._draining) return; + if (this._schedulingPaused) return; - this.cancelPreemptiveGeneration(); - - const onExitTask = tracer.startActiveSpan(async () => this.agent.onExit(), { - name: 'on_exit', - attributes: { [traceTypes.ATTR_AGENT_LABEL]: this.agent.id }, - }); - - this.createSpeechTask({ - task: Task.from(() => onExitTask), + this._onExitTask = this.createSpeechTask({ + taskFn: () => + tracer.startActiveSpan(async () => this.agent.onExit(), { + name: 'on_exit', + attributes: { [traceTypes.ATTR_AGENT_LABEL]: this.agent.id }, + }), + inlineTask: true, name: 'AgentActivity_onExit', }); - this.wakeupMainTask(); - this._draining = true; - await this._mainTask?.result; + this.cancelPreemptiveGeneration(); + + await this._onExitTask.result; + await this._pauseSchedulingTask([]); } finally { unlock(); } @@ -2387,44 +2510,59 @@ export class 
AgentActivity implements RecognitionHooks { async close(): Promise { const unlock = await this.lock.lock(); try { - if (!this._draining) { - this.logger.warn('task closing without draining'); - } - this.cancelPreemptiveGeneration(); - // Unregister event handlers to prevent duplicate metrics - if (this.llm instanceof LLM) { - this.llm.off('metrics_collected', this.onMetricsCollected); - } - if (this.realtimeSession) { - this.realtimeSession.off('generation_created', this.onGenerationCreated); - this.realtimeSession.off('input_speech_started', this.onInputSpeechStarted); - this.realtimeSession.off('input_speech_stopped', this.onInputSpeechStopped); - this.realtimeSession.off( - 'input_audio_transcription_completed', - this.onInputAudioTranscriptionCompleted, - ); - this.realtimeSession.off('metrics_collected', this.onMetricsCollected); - } - if (this.stt instanceof STT) { - this.stt.off('metrics_collected', this.onMetricsCollected); - } - if (this.tts instanceof TTS) { - this.tts.off('metrics_collected', this.onMetricsCollected); - } - if (this.vad instanceof VAD) { - this.vad.off('metrics_collected', this.onMetricsCollected); + await this._closeSessionResources(); + + if (this._mainTask) { + await this._mainTask.cancelAndWait(); } - this.detachAudioInput(); - this.realtimeSpans?.clear(); - await this.realtimeSession?.close(); - await this.audioRecognition?.close(); - await this._mainTask?.cancelAndWait(); + this.agent._agentActivity = undefined; } finally { unlock(); } } + + private async _closeSessionResources(): Promise { + // Unregister event handlers to prevent duplicate metrics + if (this.llm instanceof LLM) { + this.llm.off('metrics_collected', this.onMetricsCollected); + this.llm.off('error', this.onModelError); + } + + if (this.realtimeSession) { + this.realtimeSession.off('generation_created', this.onRealtimeGenerationCreated); + this.realtimeSession.off('input_speech_started', this.onRealtimeInputSpeechStarted); + 
this.realtimeSession.off('input_speech_stopped', this.onRealtimeInputSpeechStopped); + this.realtimeSession.off( + 'input_audio_transcription_completed', + this.onRealtimeInputAudioTranscriptionCompleted, + ); + this.realtimeSession.off('metrics_collected', this.onMetricsCollected); + this.realtimeSession.off('error', this.onModelError); + } + + if (this.stt instanceof STT) { + this.stt.off('metrics_collected', this.onMetricsCollected); + this.stt.off('error', this.onModelError); + } + + if (this.tts instanceof TTS) { + this.tts.off('metrics_collected', this.onMetricsCollected); + this.tts.off('error', this.onModelError); + } + + if (this.vad instanceof VAD) { + this.vad.off('metrics_collected', this.onMetricsCollected); + } + + this.detachAudioInput(); + this.realtimeSpans?.clear(); + await this.realtimeSession?.close(); + await this.audioRecognition?.close(); + this.realtimeSession = undefined; + this.audioRecognition = undefined; + } } function toOaiToolChoice(toolChoice: ToolChoice | null): ToolChoice | undefined { diff --git a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index b7d7826ed..efa414693 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -1,12 +1,14 @@ // SPDX-FileCopyrightText: 2024 LiveKit, Inc. 
// // SPDX-License-Identifier: Apache-2.0 +import { Mutex } from '@livekit/mutex'; import type { AudioFrame, Room } from '@livekit/rtc-node'; import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter'; import type { Context, Span } from '@opentelemetry/api'; import { ROOT_CONTEXT, context as otelContext, trace } from '@opentelemetry/api'; import { EventEmitter } from 'node:events'; import type { ReadableStream } from 'node:stream/web'; +import type { z } from 'zod'; import { LLM as InferenceLLM, STT as InferenceSTT, @@ -31,6 +33,7 @@ import { type ResolvedSessionConnectOptions, type SessionConnectOptions, } from '../types.js'; +import { Task } from '../utils.js'; import type { VAD } from '../vad.js'; import type { Agent } from './agent.js'; import { AgentActivity } from './agent_activity.js'; @@ -115,6 +118,13 @@ export type AgentSessionOptions = { connOptions?: SessionConnectOptions; }; +type ActivityTransitionOptions = { + previousActivity?: 'close' | 'pause'; + newActivity?: 'start' | 'resume'; + blockedTasks?: Task[]; + waitOnEnter?: boolean; +}; + export class AgentSession< UserData = UnknownUserData, > extends (EventEmitter as new () => TypedEmitter) { @@ -129,8 +139,10 @@ export class AgentSession< private agent?: Agent; private activity?: AgentActivity; private nextActivity?: AgentActivity; + private updateActivityTask?: Task; private started = false; private userState: UserState = 'listening'; + private readonly activityLock = new Mutex(); /** @internal */ _roomIO?: RoomIO; @@ -360,7 +372,8 @@ export class AgentSession< } // TODO(AJS-265): add shutdown callback to job context - tasks.push(this.updateActivity(this.agent)); + // Initial start does not wait on onEnter + tasks.push(this._updateActivity(this.agent, { waitOnEnter: false })); await Promise.allSettled(tasks); @@ -432,8 +445,34 @@ export class AgentSession< updateAgent(agent: Agent): void { this.agent = agent; - if (this.started) { - this.updateActivity(agent); + if 
(!this.started) { + return; + } + + const _updateActivityTask = async (oldTask: Task | undefined, agent: Agent) => { + if (oldTask) { + try { + await oldTask.result; + } catch (error) { + this.logger.error(error, 'previous updateAgent transition failed'); + } + } + + await this._updateActivity(agent); + }; + + const oldTask = this.updateActivityTask; + this.updateActivityTask = Task.from( + async () => _updateActivityTask(oldTask, agent), + undefined, + 'AgentSession_updateActivityTask', + ); + + const runState = this._globalRunState; + if (runState) { + // Don't mark the RunResult as done, if there is currently an agent transition happening. + // (used to make sure we're correctly adding the AgentHandoffResult before completion) + runState._watchHandle(this.updateActivityTask); } } @@ -464,24 +503,42 @@ export class AgentSession< throw new Error('AgentSession is not running'); } - const doSay = (activity: AgentActivity) => { + const doSay = (activity: AgentActivity, nextActivity?: AgentActivity) => { + if (activity.schedulingPaused) { + if (!nextActivity) { + throw new Error('AgentSession is closing, cannot use say()'); + } + return nextActivity.say(text, options); + } return activity.say(text, options); }; + const runState = this._globalRunState; + let handle: SpeechHandle; + // attach to the session span if called outside of the AgentSession const activeSpan = trace.getActiveSpan(); if (!activeSpan && this.rootSpanContext) { - return otelContext.with(this.rootSpanContext, () => doSay(this.activity!)); + handle = otelContext.with(this.rootSpanContext, () => + doSay(this.activity!, this.nextActivity), + ); + } else { + handle = doSay(this.activity, this.nextActivity); + } + + if (runState) { + runState._watchHandle(handle); } - return doSay(this.activity); + return handle; } - interrupt() { + interrupt(options?: { force?: boolean }) { if (!this.activity) { throw new Error('AgentSession is not running'); } - return this.activity.interrupt(); + + return 
this.activity.interrupt(options); } generateReply(options?: { @@ -502,7 +559,7 @@ export class AgentSession< : undefined; const doGenerateReply = (activity: AgentActivity, nextActivity?: AgentActivity) => { - if (activity.draining) { + if (activity.schedulingPaused) { if (!nextActivity) { throw new Error('AgentSession is closing, cannot use generateReply()'); } @@ -542,53 +599,128 @@ export class AgentSession< * result.expect.noMoreEvents(); * ``` * - * @param options - Run options including user input + * @param options - Run options including user input and optional output type * @returns A RunResult that resolves when the agent finishes responding - * - * TODO: Add outputType parameter for typed outputs (parity with Python) */ - run(options: { userInput: string }): RunResult { + run({ + userInput, + outputType, + }: { + userInput: string; + outputType?: z.ZodType; + }): RunResult { if (this._globalRunState && !this._globalRunState.done()) { throw new Error('nested runs are not supported'); } - const runState = new RunResult({ userInput: options.userInput }); + const runState = new RunResult({ + userInput, + outputType, + }); + this._globalRunState = runState; - this.generateReply({ userInput: options.userInput }); + + // Defer generateReply through the activityLock to ensure any in-progress + // activity transition (e.g. AgentTask started from onEnter) completes first. + // TS Task.from starts onEnter synchronously, so the transition may already be + // mid-flight by the time run() is called after session.start() resolves. + // Acquiring and immediately releasing the lock guarantees FIFO ordering: + // the transition's lock section finishes before we route generateReply. + (async () => { + try { + const unlock = await this.activityLock.lock(); + unlock(); + this.generateReply({ userInput }); + } catch (e) { + runState._reject(e instanceof Error ? 
e : new Error(String(e))); + } + })(); return runState; } - private async updateActivity(agent: Agent): Promise { + /** @internal */ + async _updateActivity(agent: Agent, options: ActivityTransitionOptions = {}): Promise { + const { previousActivity = 'close', newActivity = 'start', blockedTasks = [] } = options; + const waitOnEnter = options.waitOnEnter ?? newActivity === 'start'; + const runWithContext = async () => { - // TODO(AJS-129): add lock to agent activity core lifecycle - this.nextActivity = new AgentActivity(agent, this); + const unlock = await this.activityLock.lock(); + let onEnterTask: Task | undefined; - const previousActivity = this.activity; + try { + this.agent = agent; + const prevActivityObj = this.activity; + + if (newActivity === 'start') { + const prevAgent = prevActivityObj?.agent; + if ( + agent._agentActivity && + // allow updating the same agent that is running + (agent !== prevAgent || previousActivity !== 'close') + ) { + throw new Error('Cannot start agent: an activity is already running'); + } + this.nextActivity = new AgentActivity(agent, this); + } else if (newActivity === 'resume') { + if (!agent._agentActivity) { + throw new Error('Cannot resume agent: no existing activity to resume'); + } + this.nextActivity = agent._agentActivity; + } - if (this.activity) { - await this.activity.drain(); - await this.activity.close(); - } + if (prevActivityObj && prevActivityObj !== this.nextActivity) { + if (previousActivity === 'pause') { + await prevActivityObj.pause({ blockedTasks }); + } else { + await prevActivityObj.drain(); + await prevActivityObj.close(); + } + } - this.activity = this.nextActivity; - this.nextActivity = undefined; + this.activity = this.nextActivity; + this.nextActivity = undefined; - this._chatCtx.insert( - new AgentHandoffItem({ - oldAgentId: previousActivity?.agent.id, + const runState = this._globalRunState; + const handoffItem = new AgentHandoffItem({ + oldAgentId: prevActivityObj?.agent.id, newAgentId: agent.id, 
- }), - ); - this.logger.debug( - { previousAgentId: previousActivity?.agent.id, newAgentId: agent.id }, - 'Agent handoff inserted into chat context', - ); + }); - await this.activity.start(); + if (runState) { + runState._agentHandoff({ + item: handoffItem, + oldAgent: prevActivityObj?.agent, + newAgent: this.activity!.agent, + }); + } + + this._chatCtx.insert(handoffItem); + this.logger.debug( + { previousAgentId: prevActivityObj?.agent.id, newAgentId: agent.id }, + 'Agent handoff inserted into chat context', + ); + + if (newActivity === 'start') { + await this.activity!.start(); + } else { + await this.activity!.resume(); + } + + onEnterTask = this.activity!._onEnterTask; + + if (this._input.audio) { + this.activity!.attachAudioInput(this._input.audio.stream); + } + } finally { + unlock(); + } - if (this._input.audio) { - this.activity.attachAudioInput(this._input.audio.stream); + if (waitOnEnter) { + if (!onEnterTask) { + throw new Error('expected onEnter task to be available while waitOnEnter=true'); + } + await onEnterTask.result; } }; @@ -842,15 +974,21 @@ export class AgentSession< if (this.activity) { if (!drain) { try { - this.activity.interrupt(); + await this.activity.interrupt({ force: true }).await; } catch (error) { - // TODO(shubhra): force interrupt or wait for it to finish? - // it might be an audio played from the error callback + // Uninterruptible speech can throw during forced interruption. 
+ this.logger.warn({ error }, 'Error interrupting activity'); } } + await this.activity.drain(); // wait any uninterruptible speech to finish await this.activity.currentSpeech?.waitForPlayout(); + + if (reason !== CloseReason.ERROR) { + this.activity.commitUserTurn({ audioDetached: true, throwIfNotReady: false }); + } + try { this.activity.detachAudioInput(); } catch (error) { diff --git a/agents/src/voice/audio_recognition.ts b/agents/src/voice/audio_recognition.ts index 3af042d17..a564a842d 100644 --- a/agents/src/voice/audio_recognition.ts +++ b/agents/src/voice/audio_recognition.ts @@ -768,6 +768,10 @@ export class AudioRecognition { this.logger.debug('User turn committed'); }) .catch((err: unknown) => { + if (err instanceof Error && err.name === 'AbortError') { + this.logger.debug('User turn commit task cancelled'); + return; + } this.logger.error(err, 'Error in user turn commit task:'); }); } diff --git a/agents/src/voice/generation.ts b/agents/src/voice/generation.ts index fd274a66e..1f141ab37 100644 --- a/agents/src/voice/generation.ts +++ b/agents/src/voice/generation.ts @@ -26,7 +26,13 @@ import { IdentityTransform } from '../stream/identity_transform.js'; import { traceTypes, tracer } from '../telemetry/index.js'; import { USERDATA_TIMED_TRANSCRIPT } from '../types.js'; import { Future, Task, shortuuid, toError, waitForAbort } from '../utils.js'; -import { type Agent, type ModelSettings, asyncLocalStorage, isStopResponse } from './agent.js'; +import { + type Agent, + type ModelSettings, + _setActivityTaskInfo, + functionCallStorage, + isStopResponse, +} from './agent.js'; import type { AgentSession } from './agent_session.js'; import { AudioOutput, @@ -719,7 +725,7 @@ export interface _AudioOut { async function forwardAudio( ttsStream: ReadableStream, - audioOuput: AudioOutput, + audioOutput: AudioOutput, out: _AudioOut, signal?: AbortSignal, ): Promise { @@ -733,8 +739,8 @@ async function forwardAudio( }; try { - 
audioOuput.on(AudioOutput.EVENT_PLAYBACK_STARTED, onPlaybackStarted); - audioOuput.resume(); + audioOutput.on(AudioOutput.EVENT_PLAYBACK_STARTED, onPlaybackStarted); + audioOutput.resume(); while (true) { if (signal?.aborted) { @@ -748,36 +754,36 @@ async function forwardAudio( if ( !out.firstFrameFut.done && - audioOuput.sampleRate && - audioOuput.sampleRate !== frame.sampleRate && + audioOutput.sampleRate && + audioOutput.sampleRate !== frame.sampleRate && !resampler ) { - resampler = new AudioResampler(frame.sampleRate, audioOuput.sampleRate, 1); + resampler = new AudioResampler(frame.sampleRate, audioOutput.sampleRate, 1); } if (resampler) { for (const f of resampler.push(frame)) { - await audioOuput.captureFrame(f); + await audioOutput.captureFrame(f); } } else { - await audioOuput.captureFrame(frame); + await audioOutput.captureFrame(frame); } } if (resampler) { for (const f of resampler.flush()) { - await audioOuput.captureFrame(f); + await audioOutput.captureFrame(f); } } } finally { - audioOuput.off(AudioOutput.EVENT_PLAYBACK_STARTED, onPlaybackStarted); + audioOutput.off(AudioOutput.EVENT_PLAYBACK_STARTED, onPlaybackStarted); if (!out.firstFrameFut.done) { out.firstFrameFut.reject(new Error('audio forwarding cancelled before playback started')); } reader?.releaseLock(); - audioOuput.flush(); + audioOutput.flush(); } } @@ -836,7 +842,7 @@ export function performToolExecutions({ const signal = controller.signal; const reader = toolCallStream.getReader(); - const tasks: Promise[] = []; + const tasks: Task[] = []; while (!signal.aborted) { const { done, value: toolCall } = await reader.read(); if (signal.aborted) break; @@ -929,14 +935,6 @@ export function performToolExecutions({ 'Executing LLM tool call', ); - const toolExecution = asyncLocalStorage.run({ functionCall: toolCall }, async () => { - return await tool.execute(parsedArgs, { - ctx: new RunContext(session, speechHandle, toolCall), - toolCallId: toolCall.callId, - abortSignal: signal, - }); - }); - 
const _tracableToolExecutionImpl = async (toolExecTask: Promise, span: Span) => { span.setAttribute(traceTypes.ATTR_FUNCTION_TOOL_NAME, toolCall.name); span.setAttribute(traceTypes.ATTR_FUNCTION_TOOL_ARGS, toolCall.args); @@ -993,11 +991,42 @@ export function performToolExecutions({ name: 'function_tool', }); + const toolTask = Task.from( + async () => { + // Ensure this task is marked inline before user tool code executes. + const currentTask = Task.current(); + if (currentTask) { + _setActivityTaskInfo(currentTask, { + speechHandle, + functionCall: toolCall, + inlineTask: true, + }); + } + + const toolExecution = functionCallStorage.run({ functionCall: toolCall }, async () => { + return await tool.execute(parsedArgs, { + ctx: new RunContext(session, speechHandle, toolCall), + toolCallId: toolCall.callId, + abortSignal: signal, + }); + }); + + await tracableToolExecution(toolExecution); + }, + controller, + `performToolExecution:${toolCall.name}`, + ); + + _setActivityTaskInfo(toolTask, { + speechHandle, + functionCall: toolCall, + inlineTask: true, + }); // wait, not cancelling all tool calling tasks - tasks.push(tracableToolExecution(toolExecution)); + tasks.push(toolTask); } - await Promise.allSettled(tasks); + await Promise.allSettled(tasks.map((task) => task.result)); if (toolOutput.output.length > 0) { logger.debug( { diff --git a/agents/src/voice/index.ts b/agents/src/voice/index.ts index 655e846b6..947013336 100644 --- a/agents/src/voice/index.ts +++ b/agents/src/voice/index.ts @@ -1,7 +1,7 @@ // SPDX-FileCopyrightText: 2025 LiveKit, Inc. 
// // SPDX-License-Identifier: Apache-2.0 -export { Agent, StopResponse, type AgentOptions, type ModelSettings } from './agent.js'; +export { Agent, AgentTask, StopResponse, type AgentOptions, type ModelSettings } from './agent.js'; export { AgentSession, type AgentSessionOptions, type VoiceOptions } from './agent_session.js'; export * from './avatar/index.js'; export * from './background_audio.js'; diff --git a/agents/src/voice/speech_handle.ts b/agents/src/voice/speech_handle.ts index e491b3b99..a3cde5aa6 100644 --- a/agents/src/voice/speech_handle.ts +++ b/agents/src/voice/speech_handle.ts @@ -5,7 +5,7 @@ import type { Context } from '@opentelemetry/api'; import type { ChatItem } from '../llm/index.js'; import type { Task } from '../utils.js'; import { Event, Future, shortuuid } from '../utils.js'; -import { asyncLocalStorage } from './agent.js'; +import { functionCallStorage } from './agent.js'; /** Symbol used to identify SpeechHandle instances */ const SPEECH_HANDLE_SYMBOL = Symbol.for('livekit.agents.SpeechHandle'); @@ -46,6 +46,9 @@ export class SpeechHandle { /** @internal - OpenTelemetry context for the agent turn span */ _agentTurnContext?: Context; + /** @internal - used by AgentTask/RunResult final output plumbing */ + _maybeRunFinalOutput?: unknown; + private itemAddedCallbacks: Set<(item: ChatItem) => void> = new Set(); private doneCallbacks: Set<(sh: SpeechHandle) => void> = new Set(); @@ -148,7 +151,7 @@ export class SpeechHandle { * has entirely played out, including any tool calls and response follow-ups. */ async waitForPlayout(): Promise { - const store = asyncLocalStorage.getStore(); + const store = functionCallStorage.getStore(); if (store && store?.functionCall) { throw new Error( `Cannot call 'SpeechHandle.waitForPlayout()' from inside the function tool '${store.functionCall.name}'. 
` + @@ -167,6 +170,10 @@ export class SpeechHandle { } addDoneCallback(callback: (sh: SpeechHandle) => void) { + if (this.done()) { + queueMicrotask(() => callback(this)); + return; + } this.doneCallbacks.add(callback); } diff --git a/agents/src/voice/testing/run_result.ts b/agents/src/voice/testing/run_result.ts index ea9f1d994..4ee0ccc56 100644 --- a/agents/src/voice/testing/run_result.ts +++ b/agents/src/voice/testing/run_result.ts @@ -30,6 +30,8 @@ import { // Type for agent constructor (used in assertions) // eslint-disable-next-line @typescript-eslint/no-explicit-any type AgentConstructor = new (...args: any[]) => Agent; +// In JS we use a zod schema so runtime validation and TS generic inference stay aligned. +type OutputSchema = z.ZodType; // Environment variable for verbose output const evalsVerbose = parseInt(process.env.LIVEKIT_EVALS_VERBOSE || '0', 10); @@ -48,19 +50,21 @@ export class RunResult { private _events: RunEvent[] = []; private doneFut = new Future(); private userInput?: string; + private outputType?: OutputSchema; + private finalOutputValue?: T; + private hasFinalOutput = false; private handles: Set> = new Set(); private lastSpeechHandle?: SpeechHandle; private runAssert?: RunAssert; + // Store per-handle closures so _unwatchHandle can remove callbacks symmetrically. 
+ private doneCallbacks = new Map, () => void>(); - // TODO(brian): Add typed output support for parity with Python - // - Add outputType?: new (...args: unknown[]) => T - // - Add finalOutput?: T - // - Implement markDone() to extract final_output from SpeechHandle.maybeRunFinalOutput - // - See Python: run_result.py lines 182-201 + private readonly itemAddedCallback = (item: ChatItem) => this._itemAdded(item); - constructor(options?: { userInput?: string }) { + constructor(options?: { userInput?: string; outputType?: OutputSchema }) { this.userInput = options?.userInput; + this.outputType = options?.outputType; } /** @@ -92,12 +96,17 @@ export class RunResult { /** * Returns the final output of the run after completion. - * - * @throws Error - Not implemented yet. */ get finalOutput(): T { - // TODO(brian): Implement typed output support after AgentTask is implemented. - throw new Error('finalOutput is not yet implemented in JS.'); + if (!this.doneFut.done) { + throw new Error('cannot retrieve finalOutput, RunResult is not done'); + } + + if (!this.hasFinalOutput) { + throw new Error('no final output'); + } + + return this.finalOutputValue as T; } /** @@ -167,15 +176,18 @@ export class RunResult { * Watch a speech handle or task for completion. 
*/ _watchHandle(handle: SpeechHandle | Task): void { + if (this.handles.has(handle)) return; + this.handles.add(handle); if (isSpeechHandle(handle)) { - handle._addItemAddedCallback(this._itemAdded.bind(this)); + handle._addItemAddedCallback(this.itemAddedCallback); } - handle.addDoneCallback(() => { - this._markDoneIfNeeded(handle); - }); + const doneCallback = () => this._markDoneIfNeeded(handle); + + this.doneCallbacks.set(handle, doneCallback); + handle.addDoneCallback(doneCallback); } /** @@ -184,31 +196,77 @@ export class RunResult { */ _unwatchHandle(handle: SpeechHandle | Task): void { this.handles.delete(handle); + const doneCallback = this.doneCallbacks.get(handle); + + if (doneCallback) { + handle.removeDoneCallback(doneCallback); + this.doneCallbacks.delete(handle); + } if (isSpeechHandle(handle)) { - handle._removeItemAddedCallback(this._itemAdded.bind(this)); + handle._removeItemAddedCallback(this.itemAddedCallback); + } + } + + /** @internal */ + _watchedHandleCount(): number { + return this.handles.size; + } + + /** @internal – Reject the run with an error (e.g. when deferred generateReply fails). */ + _reject(error: Error): void { + if (!this.doneFut.done) { + this.doneFut.reject(error); } } - private _markDoneIfNeeded(handle: SpeechHandle | Task): void { + /** @internal */ + _markDoneIfNeeded(handle?: SpeechHandle | Task | null): void { if (isSpeechHandle(handle)) { this.lastSpeechHandle = handle; } - if ([...this.handles].every((h) => (isSpeechHandle(h) ? h.done() : h.done))) { + const allDone = [...this.handles].every((h) => (isSpeechHandle(h) ? h.done() : h.done)); + if (allDone) { this._markDone(); } } private _markDone(): void { - // TODO(brian): Implement final output support after AgentTask is implemented. 
- // See Python run_result.py _mark_done() for reference: - // - Check lastSpeechHandle._maybeRunFinalOutput - // - Validate output type matches expected type - // - Set exception or resolve based on output - if (!this.doneFut.done) { + if (this.doneFut.done) { + return; + } + + if (!this.lastSpeechHandle) { this.doneFut.resolve(); + return; + } + + const finalOutput = this.lastSpeechHandle._maybeRunFinalOutput; + if (finalOutput instanceof Error) { + this.doneFut.reject(finalOutput); + return; + } + + if (this.outputType) { + const result = this.outputType.safeParse(finalOutput); + if (!result.success) { + this.doneFut.reject( + new Error(`Expected output matching provided zod schema: ${result.error.message}`), + ); + return; + } + this.finalOutputValue = result.data; + this.hasFinalOutput = true; + this.doneFut.resolve(); + return; + } + + if (finalOutput !== undefined) { + this.finalOutputValue = finalOutput as T; + this.hasFinalOutput = true; } + this.doneFut.resolve(); } /** diff --git a/examples/src/basic_agent_task.ts b/examples/src/basic_agent_task.ts new file mode 100644 index 000000000..e01d24752 --- /dev/null +++ b/examples/src/basic_agent_task.ts @@ -0,0 +1,134 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { + type JobContext, + type JobProcess, + ServerOptions, + cli, + defineAgent, + inference, + llm, + voice, +} from '@livekit/agents'; +import * as openai from '@livekit/agents-plugin-openai'; +import * as silero from '@livekit/agents-plugin-silero'; +import { fileURLToPath } from 'node:url'; +import { z } from 'zod'; + +class InfoTask extends voice.AgentTask { + constructor(info: string) { + super({ + instructions: `Collect the user's information. around ${info}. Once you have the information, call the saveUserInfo tool to save the information to the database IMMEDIATELY. 
DO NOT have chitchat with the user, just collect the information and call the saveUserInfo tool.`, + tts: 'elevenlabs/eleven_turbo_v2_5', + tools: { + saveUserInfo: llm.tool({ + description: `Save the user's ${info} to the database`, + parameters: z.object({ + [info]: z.string(), + }), + execute: async (args) => { + this.complete(args[info] as string); + return `Thanks, collected ${info} successfully: ${args[info]}`; + }, + }), + }, + }); + } + + async onEnter() { + this.session.generateReply({ + userInput: `Ask the user for their ${info}`, + }); + } +} + +class SurveyAgent extends voice.Agent { + constructor() { + super({ + instructions: + 'You orchestrate a short intro survey. Speak naturally and keep the interaction brief.', + tools: { + collectUserInfo: llm.tool({ + description: 'Call this when the user wants to provide some information to you', + parameters: z.object({ + key: z + .string() + .describe( + 'The key of the information to collect, e.g. "name" or "role"; should contain no spaces and be underscore separated', + ), + }), + execute: async ({ key }) => { + const value = await new InfoTask(key).run(); + return `Collected ${key} successfully: ${value}`; + }, + }), + transferToWeatherAgent: llm.tool({ + description: 'Call this immediately when the user wants to know the weather', + execute: async () => { + const agent = new voice.Agent({ + instructions: + 'You are a weather agent. 
You are responsible for providing the weather information to the user.', + tts: 'deepgram/aura-2', + tools: { + getWeather: llm.tool({ + description: 'Get the weather for a given location', + parameters: z.object({ + location: z.string().describe('The location to get the weather for'), + }), + execute: async ({ location }) => { + return `The weather in ${location} is sunny today.`; + }, + }), + finishWeatherConversation: llm.tool({ + description: 'Call this when you want to finish the weather conversation', + execute: async () => { + return llm.handoff({ + agent: new SurveyAgent(), + returns: 'Transfer to survey agent successfully!', + }); + }, + }), + }, + }); + + return llm.handoff({ agent, returns: "Let's start the weather conversation!" }); + }, + }), + }, + }); + } + + async onEnter() { + const name = await new InfoTask('name').run(); + const role = await new InfoTask('role').run(); + + await this.session.say( + `Great to meet you ${name}. I noted your role as ${role}. We can continue now.`, + ); + } +} + +export default defineAgent({ + prewarm: async (proc: JobProcess) => { + proc.userData.vad = await silero.VAD.load(); + }, + entry: async (ctx: JobContext) => { + const session = new voice.AgentSession({ + vad: ctx.proc.userData.vad as silero.VAD, + stt: new inference.STT({ model: 'deepgram/nova-3' }), + llm: new openai.responses.LLM(), + tts: new inference.TTS({ + model: 'cartesia/sonic-3', + voice: '9626c31c-bec5-4cca-baa8-f8ba9e84c8bc', + }), + }); + + await session.start({ + room: ctx.room, + agent: new SurveyAgent(), + }); + }, +}); + +cli.runApp(new ServerOptions({ agent: fileURLToPath(import.meta.url) })); diff --git a/examples/src/testing/agent_task.test.ts b/examples/src/testing/agent_task.test.ts new file mode 100644 index 000000000..636a8e21d --- /dev/null +++ b/examples/src/testing/agent_task.test.ts @@ -0,0 +1,389 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { Future, initializeLogger, llm, voice } from '@livekit/agents'; +import * as openai from '@livekit/agents-plugin-openai'; +import { afterEach, describe, expect, it } from 'vitest'; +import { z } from 'zod'; + +initializeLogger({ pretty: true, level: 'warn' }); + +/** + * AgentTask scenario coverage: + * + * 1. Agent -> onEnter -> AgentTask -> onEnter -> self.complete + * COVERED: "agent calls a task in onEnter" (WelcomeTask) + * + * 2. Agent -> onEnter -> AgentTask -> onEnter -> generateReply -> User -> Tool -> self.complete + * NOT TESTABLE: session.run() rejects with "speech scheduling draining" when task is started + * from onEnter. Works in production (basic_agent_task.ts) with real voice/STT. + * Tool-triggered variant COVERED: "LLM-powered IntroTask", "LLM-powered GetEmailTask" + * + * 3. Agent -> Tool Call -> AgentTask -> User message -> Tool Call -> self.complete + * COVERED: "LLM-powered IntroTask", "LLM-powered GetEmailTask" + * + * 4. Agent -> Tool handoff -> onExit -> AgentTask -> self.complete -> handoff target + * DEADLOCK: AgentTask.run() from onExit during updateAgent transition holds activity lock. + * NOT COVERED in this suite due to known deadlock limitation. + */ + +function asError(error: unknown): Error { + return error instanceof Error ? 
error : new Error(String(error)); +} + +async function withFutureResolution(done: Future, fn: () => Promise): Promise { + try { + done.resolve(await fn()); + } catch (error) { + done.reject(asError(error)); + } +} + +function createOpenAILLM(): openai.LLM { + return new openai.LLM({ model: 'gpt-4o-mini', temperature: 0 }); +} + +async function runAndWait(session: voice.AgentSession, userInput: string) { + const result = session.run({ userInput }); + await result.wait(); + return result; +} + +describe('AgentTask examples', { timeout: 120_000 }, () => { + const sessions: voice.AgentSession[] = []; + + afterEach(async () => { + await Promise.allSettled(sessions.map((s) => s.close())); + sessions.length = 0; + }); + + async function startSession(agent: voice.Agent, options?: { llm?: openai.LLM }) { + const session = new voice.AgentSession({ llm: options?.llm }); + sessions.push(session); + await session.start({ agent }); + return session; + } + + it('agent calls a task in onEnter', async () => { + const done = new Future(); + + class WelcomeTask extends voice.AgentTask { + constructor() { + super({ instructions: 'Collect a welcome token and finish quickly.' }); + } + + async onEnter() { + this.complete('welcome-token'); + } + } + + class ParentAgent extends voice.Agent { + constructor() { + super({ instructions: 'Parent agent used for AgentTask lifecycle tests.' }); + } + + async onEnter() { + await withFutureResolution(done, async () => new WelcomeTask().run()); + } + } + + await startSession(new ParentAgent()); + await expect(done.await).resolves.toBe('welcome-token'); + }); + + it('agent calls two tasks in onEnter', async () => { + const done = new Future<{ first: number; second: number; order: string[] }>(); + + class FirstTask extends voice.AgentTask { + constructor() { + super({ instructions: 'Return first value.' 
}); + } + + async onEnter() { + this.complete(1); + } + } + + class SecondTask extends voice.AgentTask { + constructor() { + super({ instructions: 'Return second value.' }); + } + + async onEnter() { + this.complete(2); + } + } + + class ParentAgent extends voice.Agent { + constructor() { + super({ instructions: 'Parent agent for sequential task orchestration.' }); + } + + async onEnter() { + await withFutureResolution(done, async () => { + const order: string[] = []; + const first = await new FirstTask().run(); + order.push('first'); + const second = await new SecondTask().run(); + order.push('second'); + return { first, second, order }; + }); + } + } + + await startSession(new ParentAgent()); + await expect(done.await).resolves.toEqual({ + first: 1, + second: 2, + order: ['first', 'second'], + }); + }); + + const itIfOpenAI = process.env.OPENAI_API_KEY ? it : it.skip; + + // Scenario 2: Agent onEnter -> AgentTask -> onEnter -> generateReply -> User -> Tool -> self.complete + itIfOpenAI( + 'scenario 2: onEnter AgentTask with generateReply then user input via run()', + async () => { + const done = new Future<{ name: string; role: string }>(); + + class IntroTask extends voice.AgentTask<{ name: string; role: string }> { + constructor() { + super({ + instructions: + 'You are collecting a name and role. Extract both from user input and call recordIntro.', + tools: { + recordIntro: llm.tool({ + description: 'Record the name and role', + parameters: z.object({ + name: z.string().describe('User name'), + role: z.string().describe('User role'), + }), + execute: async ({ name, role }) => { + this.complete({ name, role }); + return 'recorded'; + }, + }), + }, + }); + } + + async onEnter() { + this.session.generateReply({ + instructions: 'Ask the user for their name and role.', + }); + } + } + + class ParentAgent extends voice.Agent { + constructor() { + super({ instructions: 'Parent agent that launches IntroTask on enter.' 
}); + } + + async onEnter() { + await withFutureResolution(done, async () => new IntroTask().run()); + } + } + + const llmModel = createOpenAILLM(); + const session = await startSession(new ParentAgent(), { llm: llmModel }); + + let result = await runAndWait(session, "I'm Sam and I'm a frontend engineer."); + + const taskResult = await done.await; + result.expect.containsFunctionCall({ name: 'recordIntro' }); + expect(taskResult.name.toLowerCase()).toContain('sam'); + expect(taskResult.role.toLowerCase()).toMatch(/frontend/); + + result = await runAndWait(session, 'What is my name and role?'); + result.expect + .nextEvent() + .isMessage({ role: 'assistant' }) + .judge(llmModel, { intent: 'should answer name as Sam and role as frontend engineer' }); + }, + ); + + itIfOpenAI( + 'agent calls a task in a tool; resuming previous activity does not execute onEnter again', + async () => { + let parentOnEnterCount = 0; + let taskOnEnterCount = 0; + let toolCallCount = 0; + + class GetEmailAddressTask extends voice.AgentTask { + constructor() { + super({ instructions: 'Capture an email address and complete.' 
}); + } + + async onEnter() { + taskOnEnterCount += 1; + this.complete('alice@example.com'); + } + } + + class ToolAgent extends voice.Agent { + constructor() { + super({ + instructions: + 'When asked to capture email, ALWAYS call captureEmail exactly once, then respond briefly.', + tools: { + captureEmail: llm.tool({ + description: 'Capture an email by running a nested AgentTask.', + parameters: z.object({}), + execute: async () => { + toolCallCount += 1; + try { + const email = await new GetEmailAddressTask().run(); + return `captured:${email}`; + } catch (error) { + throw error; + } + }, + }), + }, + }); + } + + async onEnter() { + parentOnEnterCount += 1; + } + } + + const llmModel = createOpenAILLM(); + const session = await startSession(new ToolAgent(), { llm: llmModel }); + const result = await runAndWait(session, 'Please capture my email using your tool.'); + + result.expect.containsFunctionCall({ name: 'captureEmail' }); + result.expect.containsAgentHandoff({ newAgentType: GetEmailAddressTask }); + result.expect.containsFunctionCallOutput({ + isError: false, + }); + result.expect.containsMessage({ role: 'assistant' }).judge(llmModel, { + intent: 'should answer email captured, not necessarily need to state the email address', + }); + + expect(toolCallCount).toBe(1); + expect(taskOnEnterCount).toBe(1); + expect(parentOnEnterCount).toBe(1); + }, + ); + + itIfOpenAI('IntroTask records intro details', async () => { + let introTaskResult: { name: string; intro: string } | undefined; + let runIntroTaskCalls = 0; + let recordIntroToolCalls = 0; + + class IntroTask extends voice.AgentTask<{ name: string; intro: string }> { + constructor() { + super({ + instructions: + 'You are Alex, an interviewer. Extract the candidate name and a short intro from the latest user input. 
' + + 'Use the tool recordIntro exactly once when both are available.', + tools: { + recordIntro: llm.tool({ + description: 'Record candidate name and intro summary.', + parameters: z.object({ + name: z.string().describe('Candidate name'), + introNotes: z.string().describe('A concise candidate intro summary'), + }), + execute: async ({ name, introNotes }) => { + recordIntroToolCalls += 1; + this.complete({ name, intro: introNotes }); + return 'Intro recorded.'; + }, + }), + }, + }); + } + + async onEnter() { + this.session.generateReply({ + instructions: + 'Ask the user for name and intro if missing, then call recordIntro with concise values.', + }); + } + } + + class ParentAgent extends voice.Agent { + constructor() { + super({ + instructions: + 'When the user asks to run the intro task, ALWAYS call collectIntroWithTask exactly once.', + tools: { + collectIntroWithTask: llm.tool({ + description: 'Launch the IntroTask and return the captured intro details.', + parameters: z.object({}), + execute: async () => { + runIntroTaskCalls += 1; + const result = await new IntroTask().run(); + introTaskResult = result; + return JSON.stringify(result); + }, + }), + }, + }); + } + } + + const llmModel = createOpenAILLM(); + const session = await startSession(new ParentAgent(), { llm: llmModel }); + const triggerRun = await runAndWait(session, 'Please run the intro task.'); + triggerRun.expect.containsFunctionCall({ name: 'collectIntroWithTask' }); + triggerRun.expect.containsMessage({ role: 'assistant' }).judge(llmModel, { + intent: 'Ask the user for name and intro', + }); + + const answerRun = await runAndWait( + session, + "I'm Morgan, and I'm a backend engineer focused on APIs.", + ); + answerRun.expect.containsAgentHandoff({ newAgentType: ParentAgent }); + + expect(runIntroTaskCalls).toBe(1); + expect(recordIntroToolCalls).toBeGreaterThanOrEqual(1); + expect(introTaskResult).toBeDefined(); + expect(introTaskResult!.name.toLowerCase()).toContain('morgan'); + 
expect(introTaskResult!.intro.toLowerCase()).toMatch(/backend|api/); + }); + + it('AgentTask instance is non-reentrant (edge case)', async () => { + const done = new Future<{ first: string; secondRunError: string }>(); + + class SingleUseTask extends voice.AgentTask { + constructor() { + super({ instructions: 'Single-use AgentTask edge case.' }); + } + + async onEnter() { + this.complete('ok'); + } + } + + class ParentAgent extends voice.Agent { + constructor() { + super({ instructions: 'Agent validating AgentTask re-entrancy behavior.' }); + } + + async onEnter() { + await withFutureResolution(done, async () => { + const task = new SingleUseTask(); + const first = await task.run(); + let secondRunError = ''; + + try { + await task.run(); + } catch (error) { + secondRunError = error instanceof Error ? error.message : String(error); + } + + return { first, secondRunError }; + }); + } + } + + await startSession(new ParentAgent()); + const result = await done.await; + expect(result.first).toBe('ok'); + expect(result.secondRunError).toContain('cannot be awaited multiple times'); + }); +}); diff --git a/examples/src/testing/run_result.test.ts b/examples/src/testing/run_result.test.ts index 66af0e322..583cbaffa 100644 --- a/examples/src/testing/run_result.test.ts +++ b/examples/src/testing/run_result.test.ts @@ -230,8 +230,9 @@ describe('RunResult', { timeout: 120_000 }, () => { const result = session.run({ userInput: "What's the weather in London?" 
}); await result.wait(); - // Skip function_call and function_call_output - result.expect.skipNext(2); + // Skip all events except the last (assistant message); LLM may emit 1+ function_call pairs + const n = result.events.length; + result.expect.skipNext(n - 1); result.expect.nextEvent().isMessage({ role: 'assistant' }); result.expect.noMoreEvents(); }); diff --git a/plugins/google/src/beta/realtime/realtime_api.ts b/plugins/google/src/beta/realtime/realtime_api.ts index 7f0b3d33c..90866729c 100644 --- a/plugins/google/src/beta/realtime/realtime_api.ts +++ b/plugins/google/src/beta/realtime/realtime_api.ts @@ -309,6 +309,7 @@ export class RealtimeModel extends llm.RealtimeModel { userTranscription: inputAudioTranscription !== null, autoToolReplyGeneration: true, audioOutput: options.modalities?.includes(Modality.AUDIO) ?? true, + manualFunctionCalls: false, }); // Environment variable fallbacks diff --git a/plugins/openai/src/realtime/realtime_model.ts b/plugins/openai/src/realtime/realtime_model.ts index 51b28afed..1aaffd014 100644 --- a/plugins/openai/src/realtime/realtime_model.ts +++ b/plugins/openai/src/realtime/realtime_model.ts @@ -175,6 +175,7 @@ export class RealtimeModel extends llm.RealtimeModel { userTranscription: options.inputAudioTranscription !== null, autoToolReplyGeneration: false, audioOutput: modalities.includes('audio'), + manualFunctionCalls: true, }); const isAzure = !!(options.apiVersion || options.entraToken || options.azureDeployment); diff --git a/plugins/openai/src/realtime/realtime_model_beta.ts b/plugins/openai/src/realtime/realtime_model_beta.ts index 2db65104f..19aee2aee 100644 --- a/plugins/openai/src/realtime/realtime_model_beta.ts +++ b/plugins/openai/src/realtime/realtime_model_beta.ts @@ -176,6 +176,7 @@ export class RealtimeModel extends llm.RealtimeModel { userTranscription: options.inputAudioTranscription !== null, autoToolReplyGeneration: false, audioOutput: modalities.includes('audio'), + manualFunctionCalls: true, }); 
const isAzure = !!(options.apiVersion || options.entraToken || options.azureDeployment); diff --git a/plugins/phonic/src/realtime/realtime_model.ts b/plugins/phonic/src/realtime/realtime_model.ts index 96c9a72d5..d17c9cb2a 100644 --- a/plugins/phonic/src/realtime/realtime_model.ts +++ b/plugins/phonic/src/realtime/realtime_model.ts @@ -128,6 +128,7 @@ export class RealtimeModel extends llm.RealtimeModel { // TODO @Phonic-Co: Implement tool support // Phonic has automatic tool reply generation, but tools are not supported with LiveKit Agents yet. autoToolReplyGeneration: true, + manualFunctionCalls: false, audioOutput: true, });