Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ import com.agentclientprotocol.common.Event
import com.agentclientprotocol.common.SessionParameters
import com.agentclientprotocol.common.remoteSessionOperations
import com.agentclientprotocol.model.AgentCapabilities
import com.agentclientprotocol.model.AvailableCommand
import com.agentclientprotocol.model.AvailableCommandInput
import com.agentclientprotocol.model.ContentBlock
import com.agentclientprotocol.model.LATEST_PROTOCOL_VERSION
import com.agentclientprotocol.model.PromptResponse
Expand Down Expand Up @@ -149,6 +151,18 @@ private class TerminalAgentSupport : AgentSupport {
override suspend fun loadSession(sessionId: SessionId, sessionParameters: SessionParameters): AgentSession =
// Rehydrate existing sessions with the provided identifier.
TerminalAgentSession(sessionId)

override suspend fun onSessionReady(
session: AgentSession,
sessionParameters: SessionParameters,
client: ClientSessionOperations
) {
client.notify(
SessionUpdate.AvailableCommandsUpdate(
listOf(AvailableCommand("help", "Show available commands", AvailableCommandInput.Unstructured("topic")))
)
)
}
}

fun main(): Unit = runBlocking {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,66 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver
assertEquals(result!!.stopReason, StopReason.END_TURN)
}

@Test
@OptIn(UnstableApi::class)
fun `session ready notifies available commands`() = testWithProtocols { clientProtocol, agentProtocol ->
val readyUpdate = CompletableDeferred<SessionUpdate>()
val client = Client(protocol = clientProtocol)
Agent(protocol = agentProtocol, agentSupport = object : AgentSupport {
override suspend fun initialize(clientInfo: ClientInfo): AgentInfo {
return AgentInfo(clientInfo.protocolVersion)
}

override suspend fun createSession(sessionParameters: SessionCreationParameters): AgentSession {
return object : AgentSession {
override val sessionId: SessionId = SessionId("ready-session-id")

override suspend fun prompt(
content: List<ContentBlock>,
_meta: JsonElement?,
): Flow<Event> = emptyFlow()
}
}

override suspend fun onSessionReady(
session: AgentSession,
sessionParameters: SessionCreationParameters,
client: ClientSessionOperations,
) {
client.notify(
SessionUpdate.AvailableCommandsUpdate(
listOf(AvailableCommand("help", "Show available commands", AvailableCommandInput.Unstructured("topic")))
)
)
}
})

client.initialize(ClientInfo(protocolVersion = 10))
client.newSession(SessionCreationParameters("/test/path", emptyList())) { _, _ ->
object : ClientSessionOperations {
override suspend fun requestPermissions(
toolCall: SessionUpdate.ToolCallUpdate,
permissions: List<PermissionOption>,
_meta: JsonElement?,
): RequestPermissionResponse {
return RequestPermissionResponse(RequestPermissionOutcome.Cancelled)
}

override suspend fun notify(
notification: SessionUpdate,
_meta: JsonElement?,
) {
readyUpdate.complete(notification)
}
}
}

val update = withTimeout(1000) { readyUpdate.await() }
assertTrue(update is SessionUpdate.AvailableCommandsUpdate)
val command = (update as SessionUpdate.AvailableCommandsUpdate).availableCommands.single()
assertEquals("help", command.name)
}

@Test
fun `prompt response and update have proper order`() = testWithProtocols { clientProtocol, agentProtocol ->
val client = Client(protocol = clientProtocol)
Expand Down Expand Up @@ -970,4 +1030,4 @@ abstract class SimpleAgentTest(protocolDriver: ProtocolDriver) : ProtocolDriver
assertEquals(1, sessions.size)
}

}
}
3 changes: 3 additions & 0 deletions acp/api/acp.api
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ public abstract interface class com/agentclientprotocol/agent/AgentSupport {
public static synthetic fun listSessions$suspendImpl (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun loadSession-nk3TnMc (Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun loadSession-nk3TnMc$suspendImpl (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun onSessionReady (Lcom/agentclientprotocol/agent/AgentSession;Lcom/agentclientprotocol/common/SessionCreationParameters;Lcom/agentclientprotocol/common/ClientSessionOperations;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun onSessionReady$suspendImpl (Lcom/agentclientprotocol/agent/AgentSupport;Lcom/agentclientprotocol/agent/AgentSession;Lcom/agentclientprotocol/common/SessionCreationParameters;Lcom/agentclientprotocol/common/ClientSessionOperations;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public fun resumeSession-nk3TnMc (Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun resumeSession-nk3TnMc$suspendImpl (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}
Expand All @@ -92,6 +94,7 @@ public final class com/agentclientprotocol/agent/AgentSupport$DefaultImpls {
public static fun forkSession-nk3TnMc (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static fun listSessions (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lkotlinx/serialization/json/JsonElement;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static fun loadSession-nk3TnMc (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static fun onSessionReady (Lcom/agentclientprotocol/agent/AgentSupport;Lcom/agentclientprotocol/agent/AgentSession;Lcom/agentclientprotocol/common/SessionCreationParameters;Lcom/agentclientprotocol/common/ClientSessionOperations;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static fun resumeSession-nk3TnMc (Lcom/agentclientprotocol/agent/AgentSupport;Ljava/lang/String;Lcom/agentclientprotocol/common/SessionCreationParameters;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

Expand Down
27 changes: 26 additions & 1 deletion acp/src/commonMain/kotlin/com/agentclientprotocol/agent/Agent.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import kotlinx.atomicfu.update
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.withContext
import kotlinx.serialization.json.JsonElement
Expand Down Expand Up @@ -147,6 +148,8 @@ public class Agent(
protocol.setRequestHandler(AcpMethod.AgentMethods.SessionNew) { params: NewSessionRequest ->
val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta)
val session = createSession(sessionParameters) { agentSupport.createSession(it) }
@OptIn(UnstableApi::class)
scheduleSessionReady(session, sessionParameters)

@OptIn(UnstableApi::class)
return@setRequestHandler NewSessionResponse(
Expand All @@ -161,6 +164,8 @@ public class Agent(
val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta)
val session = createSession(sessionParameters) { agentSupport.loadSession(params.sessionId, sessionParameters) }
@OptIn(UnstableApi::class)
scheduleSessionReady(session, sessionParameters)
@OptIn(UnstableApi::class)
return@setRequestHandler LoadSessionResponse(
// maybe unify result of these two methods to have sessionId in both
// sessionId = session.sessionId,
Expand All @@ -173,6 +178,8 @@ public class Agent(
protocol.setRequestHandler(AcpMethod.AgentMethods.SessionResume) { params: ResumeSessionRequest ->
val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta)
val session = createSession(sessionParameters) { agentSupport.resumeSession(params.sessionId, sessionParameters) }
@OptIn(UnstableApi::class)
scheduleSessionReady(session, sessionParameters)
return@setRequestHandler ResumeSessionResponse(
modes = session.asModeState(),
models = session.asModelState()
Expand Down Expand Up @@ -211,6 +218,8 @@ public class Agent(
protocol.setRequestHandler(AcpMethod.AgentMethods.SessionFork) { params: ForkSessionRequest ->
val sessionParameters = SessionCreationParameters(params.cwd, params.mcpServers, params._meta)
val session = createSession(sessionParameters) { agentSupport.forkSession(params.sessionId, sessionParameters) }
@OptIn(UnstableApi::class)
scheduleSessionReady(session, sessionParameters)
return@setRequestHandler ForkSessionResponse(
sessionId = session.sessionId,
modes = session.asModeState(),
Expand Down Expand Up @@ -257,6 +266,23 @@ public class Agent(
}

private fun getSessionOrThrow(sessionId: SessionId): SessionWrapper = _sessions.value[sessionId] ?: acpFail("Session $sessionId not found")

@OptIn(UnstableApi::class)
private suspend fun scheduleSessionReady(
session: AgentSession,
sessionParameters: SessionCreationParameters,
) {
val requestJob = currentCoroutineContext()[Job] ?: return
val sessionWrapper = getSessionOrThrow(session.sessionId)
requestJob.invokeOnCompletion { cause ->
if (cause != null) return@invokeOnCompletion
protocol.launch {
sessionWrapper.executeWithSession {
agentSupport.onSessionReady(session, sessionParameters, currentCoroutineContext().client)
}
}
}
}
}


Expand All @@ -279,4 +305,3 @@ public val CoroutineContext.clientInfo: ClientInfo
*/
public val CoroutineContext.client: ClientSessionOperations
get() = this[SessionWrapperContextElement.Key]?.sessionWrapper?.clientOperations ?: error("No remote client found in context")

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.agentclientprotocol.agent

import com.agentclientprotocol.annotations.UnstableApi
import com.agentclientprotocol.client.ClientInfo
import com.agentclientprotocol.common.ClientSessionOperations
import com.agentclientprotocol.common.SessionCreationParameters
import com.agentclientprotocol.model.AuthMethodId
import com.agentclientprotocol.model.AuthenticateResponse
Expand Down Expand Up @@ -65,6 +66,19 @@ public interface AgentSupport {
throw NotImplementedError("loadSession is not implemented. The capability is declared in AgentCapabilities.loadSession")
}

/**
* **UNSTABLE**
*
* Hook invoked after a session is created/loaded and bound to a client, before responding to the request.
* Use it to push initial session updates (for example, available commands) that must be sent before any prompt.
*/
@UnstableApi
public suspend fun onSessionReady(
session: AgentSession,
sessionParameters: SessionCreationParameters,
client: ClientSessionOperations,
) {}

/**
* **UNSTABLE**
*
Expand Down Expand Up @@ -96,4 +110,4 @@ public interface AgentSupport {
public suspend fun resumeSession(sessionId: SessionId, sessionParameters: SessionCreationParameters): AgentSession {
throw NotImplementedError("resumeSession is not implemented. The capability is declared in AgentCapabilities.sessionCapabilities.resume")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ public class Protocol(
private val notificationHandlers: AtomicRef<PersistentMap<MethodName, suspend (JsonRpcNotification) -> Unit>> =
atomic(persistentMapOf())

internal fun launch(block: suspend CoroutineScope.() -> Unit): Job = scope.launch(block = block)

/**
* Connect to a transport and start processing messages.
*/
Expand Down Expand Up @@ -505,4 +507,4 @@ private fun convertJsonRpcExceptionIfPossible(jsonRpcException: JsonRpcException
return jsonRpcException
}
}
}
}
Loading