diff --git a/pyproject.toml b/pyproject.toml
index fc50b24199..b331b821a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -123,7 +123,7 @@ test = [
   "kubernetes>=29.0.0", # For GkeCodeExecutor
   "langchain-community>=0.3.17",
   "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
-  "litellm>=1.75.5, <1.81.0", # For LiteLLM tests
+  "litellm>=1.75.5, <1.80.17", # For LiteLLM tests
   "llama-index-readers-file>=0.4.0", # For retrieval tests
   "openai>=1.100.2", # For LiteLLM
   "pytest-asyncio>=0.25.0",
@@ -153,7 +153,7 @@ extensions = [
   "docker>=7.0.0", # For ContainerCodeExecutor
   "kubernetes>=29.0.0", # For GkeCodeExecutor
   "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
-  "litellm>=1.75.5, <1.81.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
+  "litellm>=1.75.5, <1.80.17", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
   "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
   "llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex.
   "lxml>=5.3.0", # For load_web_page tool.
diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py
index 7c894d54a3..2da7a4faa0 100644
--- a/src/google/adk/agents/remote_a2a_agent.py
+++ b/src/google/adk/agents/remote_a2a_agent.py
@@ -134,7 +134,7 @@ def __init__(
     Args:
       name: Agent name (must be unique identifier)
       agent_card: AgentCard object, URL string, or file path string
-      description: Agent description (auto-populated from card if empty)
+      description: Agent description (autopopulated from card if empty)
       httpx_client: Optional shared HTTP client (will create own if not provided)
         [deprecated] Use a2a_client_factory instead.
       timeout: HTTP timeout in seconds
diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py
index 4404d62e4e..b97932d042 100644
--- a/src/google/adk/cli/adk_web_server.py
+++ b/src/google/adk/cli/adk_web_server.py
@@ -395,7 +395,7 @@ def _setup_gcp_telemetry(
         # TODO - use trace_to_cloud here as well once otel_to_cloud is no
         # longer experimental.
         enable_cloud_tracing=True,
-        # TODO - reenable metrics once errors during shutdown are fixed.
+        # TODO - re-enable metrics once errors during shutdown are fixed.
         enable_cloud_metrics=False,
         enable_cloud_logging=True,
         google_auth=(credentials, project_id),
diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py
index 6058a063db..c4f1d405ad 100644
--- a/src/google/adk/cli/cli_tools_click.py
+++ b/src/google/adk/cli/cli_tools_click.py
@@ -1299,7 +1299,7 @@ def cli_web(
 ):
   """Starts a FastAPI server with Web UI for agents.

-  AGENTS_DIR: The directory of agents, where each sub-directory is a single
+  AGENTS_DIR: The directory of agents, where each subdirectory is a single
   agent, containing at least `__init__.py` and `agent.py` files.

   Example:
@@ -1366,7 +1366,7 @@ async def _lifespan(app: FastAPI):

 @main.command("api_server")
 @feature_options()
-# The directory of agents, where each sub-directory is a single agent.
+# The directory of agents, where each subdirectory is a single agent.
 # By default, it is the current working directory
 @click.argument(
     "agents_dir",
@@ -1401,7 +1401,7 @@ def cli_api_server(
 ):
   """Starts a FastAPI server for agents.

-  AGENTS_DIR: The directory of agents, where each sub-directory is a single
+  AGENTS_DIR: The directory of agents, where each subdirectory is a single
   agent, containing at least `__init__.py` and `agent.py` files.

   Example:
diff --git a/src/google/adk/evaluation/eval_config.py b/src/google/adk/evaluation/eval_config.py
index 185f02cc20..ead2303ceb 100644
--- a/src/google/adk/evaluation/eval_config.py
+++ b/src/google/adk/evaluation/eval_config.py
@@ -89,7 +89,7 @@ class EvalConfig(BaseModel):
 In the sample below, `tool_trajectory_avg_score`, `response_match_score` and
 `final_response_match_v2` are the standard eval metric names, represented as
 keys in the dictionary. The values in the dictionary are the corresponding
-criterions. For the first two metrics, we use simple threshold as the criterion,
+criteria. For the first two metrics, we use simple threshold as the criterion,
 the third one uses `LlmAsAJudgeCriterion`.
 {
   "criteria": {
diff --git a/src/google/adk/evaluation/rubric_based_evaluator.py b/src/google/adk/evaluation/rubric_based_evaluator.py
index 63c1d47d00..451a14f1a5 100644
--- a/src/google/adk/evaluation/rubric_based_evaluator.py
+++ b/src/google/adk/evaluation/rubric_based_evaluator.py
@@ -93,7 +93,7 @@ class PerInvocationResultsAggregator(abc.ABC):
   """An interface for aggregating per invocation samples.

   AutoRaters that are backed by an LLM are known to have certain degree of
-  unreliabilty to their responses. In order to counter that we sample the
+  unreliability to their responses. In order to counter that we sample the
   autorater more than once for a single invocation. The aggregator helps
   convert those multiple samples into a single result.

@@ -419,7 +419,7 @@ def aggregate_per_invocation_samples(
     """Returns a combined result by aggregating multiple samples for the same invocation.

     AutoRaters that are backed by an LLM are known to have certain degree of
-    unreliabilty to their responses. In order to counter that we sample the
+    unreliability to their responses. In order to counter that we sample the
     autorater more than once for a single invocation. The aggregator helps
     convert those multiple samples into a single result.

diff --git a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py
index 5d6208a10a..ade65d72a7 100644
--- a/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py
+++ b/src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_v1.py
@@ -58,7 +58,7 @@

 # Definition of Conversation History
 The Conversation History is the actual dialogue between the User Simulator and the Agent.
-The Conversation History may not be complete, but the exsisting dialogue should adhere to the Conversation Plan.
+The Conversation History may not be complete, but the existing dialogue should adhere to the Conversation Plan.
 The Conversation History may contain instances where the User Simulator troubleshoots an incorrect/inappropriate response from the Agent in order to enforce the Conversation Plan.
 The Conversation History is finished only when the User Simulator outputs `{stop_signal}` in its response.
 If this token is missing, the conversation between the User Simulator and the Agent has not finished, and more turns can be generated.
@@ -171,7 +171,7 @@ def _parse_llm_response(response: str) -> Label:
       response,
   )

-  # If there was not match for "is_valid", return NOT_FOUND
+  # If there was no match for "is_valid", return NOT_FOUND
   if is_valid_match is None:
     return Label.NOT_FOUND

diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py
index 046a443a25..0a17ade66e 100644
--- a/src/google/adk/flows/llm_flows/contents.py
+++ b/src/google/adk/flows/llm_flows/contents.py
@@ -29,6 +29,7 @@
 from .functions import remove_client_function_call_id
 from .functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
 from .functions import REQUEST_EUC_FUNCTION_CALL_NAME
+from .functions import REQUEST_INPUT_FUNCTION_CALL_NAME

 logger = logging.getLogger('google_adk.' + __name__)

@@ -280,6 +281,7 @@ def _should_include_event_in_context(
       or _is_adk_framework_event(event)
       or _is_auth_event(event)
       or _is_request_confirmation_event(event)
+      or _is_request_input_event(event)
   )


@@ -675,6 +677,11 @@ def _is_adk_framework_event(event: Event) -> bool:
   return _is_function_call_event(event, 'adk_framework')


+def _is_request_input_event(event: Event) -> bool:
+  """Checks if the event is a request input event."""
+  return _is_function_call_event(event, REQUEST_INPUT_FUNCTION_CALL_NAME)
+
+
 def _is_live_model_audio_event_with_inline_data(event: Event) -> bool:
   """Check if the event is a live/bidi audio event with inline data.

diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py
index ec1cd2300a..e3120badbd 100644
--- a/src/google/adk/flows/llm_flows/functions.py
+++ b/src/google/adk/flows/llm_flows/functions.py
@@ -49,6 +49,7 @@
 AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
 REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
 REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation'
+REQUEST_INPUT_FUNCTION_CALL_NAME = 'adk_request_input'

 logger = logging.getLogger('google_adk.' + __name__)

@@ -410,7 +411,7 @@ async def _run_with_trace():
         function_response = altered_function_response

     if tool.is_long_running:
-      # Allow long running function to return None to not provide function
+      # Allow long-running function to return None to not provide function
       # response.
       if not function_response:
         return None
@@ -893,7 +894,7 @@ def find_matching_function_call(
   )
   for i in range(len(events) - 2, -1, -1):
     event = events[i]
-    # looking for the system long running request euc function call
+    # looking for the system long-running request euc function call
     function_calls = event.get_function_calls()
     if not function_calls:
       continue
diff --git a/src/google/adk/memory/vertex_ai_rag_memory_service.py b/src/google/adk/memory/vertex_ai_rag_memory_service.py
index bd6e9dc2b0..001b30b541 100644
--- a/src/google/adk/memory/vertex_ai_rag_memory_service.py
+++ b/src/google/adk/memory/vertex_ai_rag_memory_service.py
@@ -52,7 +52,7 @@ def __init__(
         or ``{rag_corpus_id}``
       similarity_top_k: The number of contexts to retrieve.
       vector_distance_threshold: Only returns contexts with vector distance
-        smaller than the threshold..
+        smaller than the threshold.
     """
     self._vertex_rag_store = types.VertexRagStore(
         rag_resources=[
diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
index 733719b422..9015e4b4ee 100644
--- a/src/google/adk/models/google_llm.py
+++ b/src/google/adk/models/google_llm.py
@@ -58,7 +58,7 @@


 class _ResourceExhaustedError(ClientError):
-  """Represents an resources exhausted error received from the Model."""
+  """Represents a resources exhausted error received from the Model."""

   def __init__(
       self,
diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py
index 4131e47b8c..79182d7b0a 100644
--- a/src/google/adk/models/lite_llm.py
+++ b/src/google/adk/models/lite_llm.py
@@ -152,7 +152,7 @@ def _ensure_litellm_imported() -> None:
   """Imports LiteLLM with safe defaults.

-  LiteLLM defaults to DEV mode, which auto-loads a local `.env` at import time.
+  LiteLLM defaults to DEV mode, which autoloads a local `.env` at import time.
   ADK should not implicitly load `.env` just because LiteLLM is installed.
   Users can opt into LiteLLM's default behavior by setting LITELLM_MODE=DEV.

diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py
index 2912ba0a7e..3639f61aa2 100644
--- a/src/google/adk/plugins/base_plugin.py
+++ b/src/google/adk/plugins/base_plugin.py
@@ -68,7 +68,7 @@ class BasePlugin(ABC):
   callback in the chain. For example, if a plugin modifies the tool input with
   before_tool_callback, the modified tool input will be passed to the
   before_tool_callback of the next plugin, and further passed to the agent
-  callbacks if not short circuited.
+  callbacks if not short-circuited.

   To use a plugin, implement the desired callback methods and pass an instance
   of your custom plugin class to the ADK Runner.
diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
index 0c12d39a9c..84fa66eb66 100644
--- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
+++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py
@@ -699,6 +699,13 @@ def __init__(
     self._batch_processor_task: Optional[asyncio.Task] = None
     self._shutdown = False

+  async def flush(self) -> None:
+    """Flushes the queue by waiting for it to be empty."""
+    if self._queue.empty():
+      return
+    # Wait for all items in the queue to be processed
+    await self._queue.join()
+
   async def start(self):
     """Starts the batch writer worker task."""
     if self._batch_processor_task is None:
@@ -1516,6 +1523,11 @@ def _format_content_safely(
       logger.warning("Content formatter failed: %s", e)
       return "[FORMATTING FAILED]", False

+  async def flush(self) -> None:
+    """Flushes any pending events to BigQuery."""
+    if self.batch_processor:
+      await self.batch_processor.flush()
+
   async def _lazy_setup(self, **kwargs) -> None:
     """Performs lazy initialization of BigQuery clients and resources."""
     if self._started:
@@ -1947,6 +1959,8 @@ async def after_run_callback(
     await self._log_event(
         "INVOCATION_COMPLETED", CallbackContext(invocation_context)
     )
+    # Ensure all logs are flushed before the agent returns
+    await self.flush()

   async def before_agent_callback(
       self, *, agent: Any, callback_context: CallbackContext, **kwargs
diff --git a/src/google/adk/plugins/logging_plugin.py b/src/google/adk/plugins/logging_plugin.py
index f44a75e04b..df37ee7ee4 100644
--- a/src/google/adk/plugins/logging_plugin.py
+++ b/src/google/adk/plugins/logging_plugin.py
@@ -36,7 +36,7 @@ class LoggingPlugin(BasePlugin):
   """A plugin that logs important information at each callback point.

-  This plugin helps printing all critical events in the console. It is not a
+  This plugin helps print all critical events in the console. It is not a
   replacement of existing logging in ADK. It rather helps terminal based
   debugging by showing all logs in the console, and serves as a simple demo
   for everyone to leverage when developing new plugins.
diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py
index b931561c4d..3aaa54e257 100644
--- a/src/google/adk/runners.py
+++ b/src/google/adk/runners.py
@@ -760,7 +760,7 @@ async def _exec_with_plugin(
     else:
       # Step 2: Otherwise continue with normal execution
       # Note for live/bidi:
-      # the transcription may arrive later then the action(function call
+      # the transcription may arrive later than the action(function call
       # event and thus function response event). In this case, the order of
       # transcription and function call event will be wrong if we just
       # append as it arrives. To address this, we should check if there is
@@ -770,7 +770,7 @@
       # identified by checking if the transcription event is partial. When
       # the next transcription event is not partial, it means the previous
       # transcription is finished. Then if there is any buffered function
-      # call event, we should append them after this finished(non-parital)
+      # call event, we should append them after this finished(non-partial)
       # transcription event.
       buffered_events: list[Event] = []
       is_transcribing: bool = False
@@ -789,7 +789,7 @@
             buffered_events.append(event)
             continue
           # Note for live/bidi: for audio response, it's considered as
-          # non-paritla event(event.partial=None)
+          # non-partial event(event.partial=None)
           # event.partial=False and event.partial=None are considered as
           # non-partial event; event.partial=True is considered as partial
           # event.
@@ -938,7 +938,7 @@ async def run_live(
       * **Live Model Audio Events with Inline Data:** Events containing raw
         audio `Blob` data(`inline_data`).
       * **Live Model Audio Events with File Data:** Both input and ouput audio
-        data are aggregated into a audio file saved into artifacts. The
+        data are aggregated into an audio file saved into artifacts. The
         reference to the file is saved in the event as `file_data`.
       * **Usage Metadata:** Events containing token usage.
       * **Transcription Events:** Both partial and non-partial transcription
@@ -948,7 +948,7 @@
     **Events Saved to the Session:**

       * **Live Model Audio Events with File Data:** Both input and ouput audio
-        data are aggregated into a audio file saved into artifacts. The
+        data are aggregated into an audio file saved into artifacts. The
         reference to the file is saved as event in the `file_data` to session
         if RunConfig.save_live_model_audio_to_session is True.
       * **Usage Metadata Events:** Saved to the session.
@@ -1099,7 +1099,7 @@ def _find_agent_to_run(
   # If the last event is a function response, should send this response to
   # the agent that returned the corresponding function call regardless the
   # type of the agent. e.g. a remote a2a agent may surface a credential
-  # request as a special long running function tool call.
+  # request as a special long-running function tool call.
   event = find_matching_function_call(session.events)
   if event and event.author:
     return root_agent.find_agent(event.author)
diff --git a/src/google/adk/sessions/migration/migration_runner.py b/src/google/adk/sessions/migration/migration_runner.py
index 56acd810fb..edb5c83bcb 100644
--- a/src/google/adk/sessions/migration/migration_runner.py
+++ b/src/google/adk/sessions/migration/migration_runner.py
@@ -49,7 +49,7 @@ def upgrade(source_db_url: str, dest_db_url: str):
   LATEST_VERSION.

   If multiple migration steps are required, intermediate results are stored in
-  temporary SQLite database files. This means a multi-step migration
+  temporary SQLite database files. This means a multistep migration
   between other database types (e.g. PostgreSQL to PostgreSQL) will use
   SQLite for intermediate steps.

diff --git a/src/google/adk/tools/google_search_tool.py b/src/google/adk/tools/google_search_tool.py
index 6f7d1e52c6..406ad2189e 100644
--- a/src/google/adk/tools/google_search_tool.py
+++ b/src/google/adk/tools/google_search_tool.py
@@ -35,17 +35,26 @@ class GoogleSearchTool(BaseTool):
   local code execution.
   """

-  def __init__(self, *, bypass_multi_tools_limit: bool = False):
+  def __init__(
+      self,
+      *,
+      bypass_multi_tools_limit: bool = False,
+      model: str | None = None,
+  ):
     """Initializes the Google search tool.

     Args:
       bypass_multi_tools_limit: Whether to bypass the multi tools limitation,
         so that the tool can be used with other tools in the same agent.
+      model: Optional model name to use for processing the LLM request. If
+        provided, this model will be used instead of the model from the
+        incoming llm_request.
     """
     # Name and description are not used because this is a model built-in tool.
     super().__init__(name='google_search', description='google_search')
     self.bypass_multi_tools_limit = bypass_multi_tools_limit
+    self.model = model

   @override
   async def process_llm_request(
@@ -54,6 +63,10 @@ async def process_llm_request(
       self,
       tool_context: ToolContext,
       llm_request: LlmRequest,
   ) -> None:
+    # If a custom model is specified, use it instead of the original model
+    if self.model is not None:
+      llm_request.model = self.model
+
     llm_request.config = llm_request.config or types.GenerateContentConfig()
     llm_request.config.tools = llm_request.config.tools or []
     if is_gemini_1_model(llm_request.model):
diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py
index 5ff541b834..2ebbc5dfe8 100644
--- a/tests/unittests/models/test_litellm.py
+++ b/tests/unittests/models/test_litellm.py
@@ -10,7 +10,7 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the Licens
+# limitations under the License

 import contextlib
 import json
diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
index 5f2474c0a1..77968a3196 100644
--- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
+++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
@@ -2061,3 +2061,28 @@ async def test_otel_integration_real_provider(self, callback_context):
     assert finished_spans[0].name == "test_span"
     assert format(finished_spans[0].context.span_id, "016x") == span_id
     assert format(finished_spans[0].context.trace_id, "032x") == trace_id
+
+  @pytest.mark.asyncio
+  async def test_flush_mechanism(
+      self,
+      bq_plugin_inst,
+      mock_write_client,
+      dummy_arrow_schema,
+      invocation_context,
+  ):
+    """Verifies that flush() forces pending events to be written."""
+    # Log an event
+    bigquery_agent_analytics_plugin.TraceManager.push_span(invocation_context)
+    await bq_plugin_inst.before_run_callback(
+        invocation_context=invocation_context
+    )
+
+    # Call flush - this should block until the event is written
+    await bq_plugin_inst.flush()
+
+    # Verify write called
+    mock_write_client.append_rows.assert_called_once()
+    log_entry = await _get_captured_event_dict_async(
+        mock_write_client, dummy_arrow_schema
+    )
+    assert log_entry["event_type"] == "INVOCATION_STARTING"
diff --git a/tests/unittests/runners/test_runner_rewind.py b/tests/unittests/runners/test_runner_rewind.py
index f8010238bd..035d28437b 100644
--- a/tests/unittests/runners/test_runner_rewind.py
+++ b/tests/unittests/runners/test_runner_rewind.py
@@ -130,14 +130,14 @@ async def test_rewind_async_with_state_and_artifacts(self):
         rewind_before_invocation_id="invocation2",
     )

-    # 3. Verify state and artifacts are rewinded
+    # 3. Verify state and artifacts are rewound
     session = await runner.session_service.get_session(
         app_name=runner.app_name, user_id=user_id, session_id=session_id
     )
     # After rewind before invocation2, only event1 state delta should apply.
     assert session.state["k1"] == "v1"
     assert not session.state["k2"]
-    # f1 should be rewinded to v0
+    # f1 should be rewound to v0
     assert await runner.artifact_service.load_artifact(
         app_name=runner.app_name,
         user_id=user_id,
@@ -226,7 +226,7 @@ async def test_rewind_async_not_first_invocation(self):
         rewind_before_invocation_id="invocation3",
     )

-    # 3. Verify state and artifacts are rewinded
+    # 3. Verify state and artifacts are rewound
     session = await runner.session_service.get_session(
         app_name=runner.app_name, user_id=user_id, session_id=session_id
     )
diff --git a/tests/unittests/tools/test_google_search_tool.py b/tests/unittests/tools/test_google_search_tool.py
index aaffb8d44a..ad5d46b59e 100644
--- a/tests/unittests/tools/test_google_search_tool.py
+++ b/tests/unittests/tools/test_google_search_tool.py
@@ -432,3 +432,46 @@ async def test_process_llm_request_gemini_version_specifics(self):
     assert len(llm_request.config.tools) == 1
     assert llm_request.config.tools[0].google_search is not None
     assert llm_request.config.tools[0].google_search_retrieval is None
+
+  @pytest.mark.asyncio
+  @pytest.mark.parametrize(
+      (
+          'tool_model',
+          'request_model',
+          'expected_model',
+      ),
+      [
+          (
+              'gemini-2.5-flash-lite',
+              'gemini-2.5-flash',
+              'gemini-2.5-flash-lite',
+          ),
+          (
+              None,
+              'gemini-2.5-flash',
+              'gemini-2.5-flash',
+          ),
+      ],
+      ids=['with_custom_model', 'without_custom_model'],
+  )
+  async def test_process_llm_request_custom_model_behavior(
+      self,
+      tool_model,
+      request_model,
+      expected_model,
+  ):
+    """Tests custom model parameter behavior in process_llm_request."""
+    tool = GoogleSearchTool(model=tool_model)
+    tool_context = await _create_tool_context()
+
+    llm_request = LlmRequest(
+        model=request_model, config=types.GenerateContentConfig()
+    )
+
+    await tool.process_llm_request(
+        tool_context=tool_context, llm_request=llm_request
+    )
+
+    assert llm_request.model == expected_model
+    assert llm_request.config.tools is not None
+    assert len(llm_request.config.tools) == 1
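# ----------------------------------------------------------------------------
# Usage sketch (not part of the patch above): how the new `model` parameter on
# GoogleSearchTool, added in src/google/adk/tools/google_search_tool.py, can pin
# the built-in google_search tool to a specific Gemini model. The agent name,
# instruction, and model ids below are illustrative placeholders, not values
# taken from this diff.
from google.adk.agents import Agent
from google.adk.tools.google_search_tool import GoogleSearchTool

# Route the built-in search tool through a dedicated model; process_llm_request
# (see the hunk above) overwrites llm_request.model with this value when set.
search_tool = GoogleSearchTool(model='gemini-2.5-flash-lite')

root_agent = Agent(
    name='search_assistant',  # placeholder
    model='gemini-2.5-flash',  # model used for the rest of the agent
    instruction='Answer questions using Google Search.',  # placeholder
    tools=[search_tool],
)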
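# ----------------------------------------------------------------------------
# Sketch (not part of the patch above): the flush() method added to the
# BigQuery agent-analytics plugin drains the internal write queue via
# queue.join(), and after_run_callback now awaits it automatically. A
# hypothetical manual call from a shutdown hook, assuming `analytics_plugin`
# is an already configured plugin instance created elsewhere:
async def drain_analytics(analytics_plugin) -> None:
  # Blocks until every queued event has been handed to the BigQuery writer.
  await analytics_plugin.flush()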