Skip to content
Open
14 changes: 13 additions & 1 deletion src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def send_message(
self,
request: Message,
*,
configuration: MessageSendConfiguration | None = None,
context: ClientCallContext | None = None,
) -> AsyncIterator[ClientEvent | Message]:
"""Sends a message to the agent.
Expand All @@ -56,12 +57,13 @@ async def send_message(

Args:
request: The message to send to the agent.
configuration: Optional per-call overrides for message sending behavior.
context: The client call context.

Yields:
An async iterator of `ClientEvent` or a final `Message` response.
"""
config = MessageSendConfiguration(
base_config = MessageSendConfiguration(
accepted_output_modes=self._config.accepted_output_modes,
blocking=not self._config.polling,
push_notification_config=(
Expand All @@ -70,6 +72,16 @@ async def send_message(
else None
),
)
if configuration is not None:
overrides = configuration.model_dump(
exclude_unset=True,
exclude_none=True,
by_alias=False,
)
config = base_config.model_copy(update=overrides)
else:
config = base_config

params = MessageSendParams(message=request, configuration=config)

if not self._config.streaming or not self._card.capabilities.streaming:
Expand Down
107 changes: 107 additions & 0 deletions tests/client/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AgentCapabilities,
AgentCard,
Message,
MessageSendConfiguration,
Part,
Role,
Task,
Expand Down Expand Up @@ -115,3 +116,109 @@ async def test_send_message_non_streaming_agent_capability_false(
assert not mock_transport.send_message_streaming.called
assert len(events) == 1
assert events[0][0].id == 'task-789'


@pytest.mark.asyncio
async def test_send_message_callsite_config_overrides_history_length_non_streaming(
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
base_client._config.streaming = False
mock_transport.send_message.return_value = Task(
id='task-cfg-ns-1',
context_id='ctx-cfg-ns-1',
status=TaskStatus(state=TaskState.completed),
)

cfg = MessageSendConfiguration(history_length=2)
events = [
event
async for event in base_client.send_message(
sample_message, configuration=cfg
)
]

mock_transport.send_message.assert_called_once()
assert not mock_transport.send_message_streaming.called
assert len(events) == 1
task, _ = events[0]
assert task.id == 'task-cfg-ns-1'

params = mock_transport.send_message.await_args.args[0]
assert params.configuration.history_length == 2
assert params.configuration.blocking == (not base_client._config.polling)
assert (
params.configuration.accepted_output_modes
== base_client._config.accepted_output_modes
)


@pytest.mark.asyncio
async def test_send_message_ignores_none_fields_in_callsite_configuration_non_streaming(
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
base_client._config.streaming = False
mock_transport.send_message.return_value = Task(
id='task-cfg-ns-2',
context_id='ctx-cfg-ns-2',
status=TaskStatus(state=TaskState.completed),
)

cfg = MessageSendConfiguration(history_length=None, blocking=None)
events = [
event
async for event in base_client.send_message(
sample_message, configuration=cfg
)
]

mock_transport.send_message.assert_called_once()
assert len(events) == 1
task, _ = events[0]
assert task.id == 'task-cfg-ns-2'

params = mock_transport.send_message.await_args.args[0]
assert params.configuration.history_length is None
assert params.configuration.blocking == (not base_client._config.polling)
assert (
params.configuration.accepted_output_modes
== base_client._config.accepted_output_modes
)


@pytest.mark.asyncio
async def test_send_message_callsite_config_overrides_history_length_streaming(
base_client: BaseClient, mock_transport: MagicMock, sample_message: Message
):
base_client._config.streaming = True
base_client._card.capabilities.streaming = True

async def create_stream(*args, **kwargs):
yield Task(
id='task-cfg-s-1',
context_id='ctx-cfg-s-1',
status=TaskStatus(state=TaskState.completed),
)

mock_transport.send_message_streaming.return_value = create_stream()

cfg = MessageSendConfiguration(history_length=0)
events = [
event
async for event in base_client.send_message(
sample_message, configuration=cfg
)
]

mock_transport.send_message_streaming.assert_called_once()
assert not mock_transport.send_message.called
assert len(events) == 1
task, _ = events[0]
assert task.id == 'task-cfg-s-1'

params = mock_transport.send_message_streaming.call_args.args[0]
assert params.configuration.history_length == 0
assert params.configuration.blocking == (not base_client._config.polling)
assert (
params.configuration.accepted_output_modes
== base_client._config.accepted_output_modes
)
Loading