diff --git a/README.md b/README.md index ffbb59c..8833575 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,8 @@ import com.agentclientprotocol.common.Event import com.agentclientprotocol.common.SessionParameters import com.agentclientprotocol.common.remoteSessionOperations import com.agentclientprotocol.model.AgentCapabilities +import com.agentclientprotocol.model.AvailableCommand +import com.agentclientprotocol.model.AvailableCommandInput import com.agentclientprotocol.model.ContentBlock import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION import com.agentclientprotocol.model.PromptResponse @@ -149,6 +151,18 @@ private class TerminalAgentSupport : AgentSupport { override suspend fun loadSession(sessionId: SessionId, sessionParameters: SessionParameters): AgentSession = // Rehydrate existing sessions with the provided identifier. TerminalAgentSession(sessionId) + + override suspend fun onSessionReady( + session: AgentSession, + sessionParameters: SessionParameters, + client: ClientSessionOperations + ) { + client.notify( + SessionUpdate.AvailableCommandsUpdate( + listOf(AvailableCommand("help", "Show available commands", AvailableCommandInput.Unstructured("topic"))) + ) + ) + } } fun main(): Unit = runBlocking { diff --git a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt index 37a2b5e..e70ef86 100644 --- a/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt +++ b/acp-ktor-test/src/commonTest/kotlin/com/agentclientprotocol/SimpleAgentTest.kt @@ -113,6 +113,66 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver assertEquals(result!!.stopReason, StopReason.END_TURN) } + @Test + @OptIn(UnstableApi::class) + fun `session ready notifies available commands`() = testWithProtocols { clientProtocol, agentProtocol -> + val readyUpdate = CompletableDeferred() + val client = Client(protocol = clientProtocol) + Agent(protocol = agentProtocol, agentSupport = object : AgentSupport { + override suspend fun initialize(clientInfo: ClientInfo): AgentInfo { + return AgentInfo(clientInfo.protocolVersion) + } + + override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession { + return object : AgentSession { + override val sessionId: SessionId = SessionId("ready-session-id") + + override suspend fun prompt( + content: List, + _meta: JsonElement?, + ): Flow = emptyFlow() + } + } + + override suspend fun onSessionReady( + session: AgentSession, + sessionParameters: SessionCreationParameters, + client: ClientSessionOperations, + ) { + client.notify( + SessionUpdate.AvailableCommandsUpdate( + listOf(AvailableCommand("help", "Show available commands", AvailableCommandInput.Unstructured("topic"))) + ) + ) + } + }) + + client.initialize(ClientInfo(protocolVersion = 10)) + client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ -> + object : ClientSessionOperations { + override suspend fun requestPermissions( + toolCall: SessionUpdate.ToolCallUpdate, + permissions: List, + _meta: JsonElement?, + ): RequestPermissionResponse { + return RequestPermissionResponse(RequestPermissionOutcome.Cancelled) + } + + override suspend fun notify( + notification: SessionUpdate, + _meta: JsonElement?, + ) { + readyUpdate.complete(notification) + } + } + } + + val update = withTimeout(1000) { readyUpdate.await() } + assertTrue(update is SessionUpdate.AvailableCommandsUpdate) + val command = (update as SessionUpdate.AvailableCommandsUpdate).availableCommands.single() + assertEquals("help", command.name) + } + @Test fun `prompt response and update have proper order`() = testWithProtocols { clientProtocol, agentProtocol -> val client = Client(protocol = clientProtocol) @@ -970,4 +1030,4 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver assertEquals(1, sessions.size) } -} \ No newline at end of file +} diff --git a/acp/api/acp.api b/acp/api/acp.api index 311b163..535ba65 100644 --- a/acp/api/acp.api +++ b/acp/api/acp.api @@ -83,6 +83,8 @@ public abstract interface class com/agentclientprotocol/agent/AgentSupport { public static synthetic fun listSessions$suspendImpl (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun loadSession-nk3TnMc (Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun loadSession-nk3TnMc$suspendImpl (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public fun onSessionReady (Lcom/agentclientprotocol/agent/AgentSession;Lcom/agentclientprotocol/common/SessionCreationParameters;Lcom/agentclientprotocol/common/ClientSessionOperations;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static synthetic fun onSessionReady$suspendImpl (Lcom/agentclientprotocol/agent/AgentSupport;Lcom/agentclientprotocol/agent/AgentSession;Lcom/agentclientprotocol/common/SessionCreationParameters;Lcom/agentclientprotocol/common/ClientSessionOperations;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun resumeSession-nk3TnMc (Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun resumeSession-nk3TnMc$suspendImpl (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } @@ -92,6 +94,7 @@ public final class com/agentclientprotocol/agent/AgentSupport$DefaultImpls { public static fun forkSession-nk3TnMc (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static fun listSessions (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static fun loadSession-nk3TnMc (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static fun onSessionReady (Lcom/agentclientprotocol/agent/AgentSupport;Lcom/agentclientprotocol/agent/AgentSession;Lcom/agentclientprotocol/common/SessionCreationParameters;Lcom/agentclientprotocol/common/ClientSessionOperations;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static fun resumeSession-nk3TnMc (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt index 5e876e0..eedae81 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt @@ -14,6 +14,7 @@ import kotlinx.atomicfu.update import kotlinx.collections.immutable.persistentMapOf import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.Job import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.withContext import kotlinx.serialization.json.JsonElement @@ -147,6 +148,8 @@ public class Agent( protocol.setRequestHandler(AcpMethod.AgentMethods.SessionNew) { params: NewSessionRequest -> val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta) val session = createSession(sessionParameters) { agentSupport.createSession(it) } + @OptIn(UnstableApi::class) + scheduleSessionReady(session, sessionParameters) @OptIn(UnstableApi::class) return@setRequestHandler NewSessionResponse( @@ -161,6 +164,8 @@ public class Agent( val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta) val session = createSession(sessionParameters) { agentSupport.loadSession(params.sessionId, sessionParameters) } @OptIn(UnstableApi::class) + scheduleSessionReady(session, sessionParameters) + @OptIn(UnstableApi::class) return@setRequestHandler LoadSessionResponse( // maybe unify result of these two methods to have sessionId in both // sessionId = session.sessionId, @@ -173,6 +178,8 @@ public class Agent( protocol.setRequestHandler(AcpMethod.AgentMethods.SessionResume) { params: ResumeSessionRequest -> val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta) val session = createSession(sessionParameters) { agentSupport.resumeSession(params.sessionId, sessionParameters) } + @OptIn(UnstableApi::class) + scheduleSessionReady(session, sessionParameters) return@setRequestHandler ResumeSessionResponse( modes = session.asModeState(), models = session.asModelState() @@ -211,6 +218,8 @@ public class Agent( protocol.setRequestHandler(AcpMethod.AgentMethods.SessionFork) { params: ForkSessionRequest -> val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta) val session = createSession(sessionParameters) { agentSupport.forkSession(params.sessionId, sessionParameters) } + @OptIn(UnstableApi::class) + scheduleSessionReady(session, sessionParameters) return@setRequestHandler ForkSessionResponse( sessionId = session.sessionId, modes = session.asModeState(), @@ -257,6 +266,23 @@ public class Agent( } private fun getSessionOrThrow(sessionId: SessionId): SessionWrapper = _sessions.value[sessionId] ?: acpFail("Session $sessionId not found") + + @OptIn(UnstableApi::class) + private suspend fun scheduleSessionReady( + session: AgentSession, + sessionParameters: SessionCreationParameters, + ) { + val requestJob = currentCoroutineContext()[Job] ?: return + val sessionWrapper = getSessionOrThrow(session.sessionId) + requestJob.invokeOnCompletion { cause -> + if (cause != null) return@invokeOnCompletion + protocol.launch { + sessionWrapper.executeWithSession { + agentSupport.onSessionReady(session, sessionParameters, currentCoroutineContext().client) + } + } + } + } } @@ -279,4 +305,3 @@ public val CoroutineContext.clientInfo: ClientInfo */ public val CoroutineContext.client: ClientSessionOperations get() = this[SessionWrapperContextElement.Key]?.sessionWrapper?.clientOperations ?: error("No remote client found in context") - diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSupport.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSupport.kt index 31c9db3..47ba2b6 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSupport.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/agent/AgentSupport.kt @@ -2,6 +2,7 @@ package com.agentclientprotocol.agent import com.agentclientprotocol.annotations.UnstableApi import com.agentclientprotocol.client.ClientInfo +import com.agentclientprotocol.common.ClientSessionOperations import com.agentclientprotocol.common.SessionCreationParameters import com.agentclientprotocol.model.AuthMethodId import com.agentclientprotocol.model.AuthenticateResponse @@ -65,6 +66,19 @@ public interface AgentSupport { throw NotImplementedError("loadSession is not implemented. The capability is declared in AgentCapabilities.loadSession") } + /** + * **UNSTABLE** + * + * Hook invoked after a session is created/loaded and bound to a client, before responding to the request. + * Use it to push initial session updates (for example, available commands) that must be sent before any prompt. + */ + @UnstableApi + public suspend fun onSessionReady( + session: AgentSession, + sessionParameters: SessionCreationParameters, + client: ClientSessionOperations, + ) {} + /** * **UNSTABLE** * @@ -96,4 +110,4 @@ public interface AgentSupport { public suspend fun resumeSession(sessionId: SessionId, sessionParameters: SessionCreationParameters): AgentSession { throw NotImplementedError("resumeSession is not implemented. The capability is declared in AgentCapabilities.sessionCapabilities.resume") } -} \ No newline at end of file +} diff --git a/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt b/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt index bb17d28..b7fc947 100644 --- a/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt +++ b/acp/src/commonMain/kotlin/com/agentclientprotocol/protocol/Protocol.kt @@ -147,6 +147,8 @@ public class Protocol( private val notificationHandlers: AtomicRef Unit>> = atomic(persistentMapOf()) + internal fun launch(block: suspend CoroutineScope.() -> Unit): Job = scope.launch(block = block) + /** * Connect to a transport and start processing messages. */ @@ -505,4 +507,4 @@ private fun convertJsonRpcExceptionIfPossible(jsonRpcException: JsonRpcException return jsonRpcException } } -} \ No newline at end of file +}