diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index f4a8d03d..66a89ab6 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -46,6 +46,7 @@ async def send_message( self, request: Message, *, + configuration: MessageSendConfiguration | None = None, context: ClientCallContext | None = None, ) -> AsyncIterator[ClientEvent | Message]: """Sends a message to the agent. @@ -56,12 +57,13 @@ async def send_message( Args: request: The message to send to the agent. + configuration: Optional per-call overrides for message sending behavior. context: The client call context. Yields: An async iterator of `ClientEvent` or a final `Message` response. """ - config = MessageSendConfiguration( + base_config = MessageSendConfiguration( accepted_output_modes=self._config.accepted_output_modes, blocking=not self._config.polling, push_notification_config=( @@ -70,6 +72,16 @@ async def send_message( else None ), ) + if configuration is not None: + overrides = configuration.model_dump( + exclude_unset=True, + exclude_none=True, + by_alias=False, + ) + config = base_config.model_copy(update=overrides) + else: + config = base_config + params = MessageSendParams(message=request, configuration=config) if not self._config.streaming or not self._card.capabilities.streaming: diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index c1251f1c..9bfea3f6 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -9,6 +9,7 @@ AgentCapabilities, AgentCard, Message, + MessageSendConfiguration, Part, Role, Task, @@ -115,3 +116,109 @@ async def test_send_message_non_streaming_agent_capability_false( assert not mock_transport.send_message_streaming.called assert len(events) == 1 assert events[0][0].id == 'task-789' + + +@pytest.mark.asyncio +async def test_send_message_callsite_config_overrides_history_length_non_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._config.streaming = False + mock_transport.send_message.return_value = Task( + id='task-cfg-ns-1', + context_id='ctx-cfg-ns-1', + status=TaskStatus(state=TaskState.completed), + ) + + cfg = MessageSendConfiguration(history_length=2) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message.assert_called_once() + assert not mock_transport.send_message_streaming.called + assert len(events) == 1 + task, _ = events[0] + assert task.id == 'task-cfg-ns-1' + + params = mock_transport.send_message.await_args.args[0] + assert params.configuration.history_length == 2 + assert params.configuration.blocking == (not base_client._config.polling) + assert ( + params.configuration.accepted_output_modes + == base_client._config.accepted_output_modes + ) + + +@pytest.mark.asyncio +async def test_send_message_ignores_none_fields_in_callsite_configuration_non_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._config.streaming = False + mock_transport.send_message.return_value = Task( + id='task-cfg-ns-2', + context_id='ctx-cfg-ns-2', + status=TaskStatus(state=TaskState.completed), + ) + + cfg = MessageSendConfiguration(history_length=None, blocking=None) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message.assert_called_once() + assert len(events) == 1 + task, _ = events[0] + assert task.id == 'task-cfg-ns-2' + + params = mock_transport.send_message.await_args.args[0] + assert params.configuration.history_length is None + assert params.configuration.blocking == (not base_client._config.polling) + assert ( + params.configuration.accepted_output_modes + == base_client._config.accepted_output_modes + ) + + +@pytest.mark.asyncio +async def test_send_message_callsite_config_overrides_history_length_streaming( + base_client: BaseClient, mock_transport: MagicMock, sample_message: Message +): + base_client._config.streaming = True + base_client._card.capabilities.streaming = True + + async def create_stream(*args, **kwargs): + yield Task( + id='task-cfg-s-1', + context_id='ctx-cfg-s-1', + status=TaskStatus(state=TaskState.completed), + ) + + mock_transport.send_message_streaming.return_value = create_stream() + + cfg = MessageSendConfiguration(history_length=0) + events = [ + event + async for event in base_client.send_message( + sample_message, configuration=cfg + ) + ] + + mock_transport.send_message_streaming.assert_called_once() + assert not mock_transport.send_message.called + assert len(events) == 1 + task, _ = events[0] + assert task.id == 'task-cfg-s-1' + + params = mock_transport.send_message_streaming.call_args.args[0] + assert params.configuration.history_length == 0 + assert params.configuration.blocking == (not base_client._config.polling) + assert ( + params.configuration.accepted_output_modes + == base_client._config.accepted_output_modes + )