diff --git a/CHANGELOG.md b/CHANGELOG.md index cd5b83df3d..fada470bb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog + +## [1.22.1](https://github.com/google/adk-python/compare/v1.22.0...v1.22.1) (2026-01-09) + +### Bug Fixes +* Add back `adk migrate session` CLI ([8fb2be2](https://github.com/google/adk-python/commit/8fb2be216f11dabe7fa361a0402e5e6316878ad8)). +* Escape database reserved keyword ([94d48fc](https://github.com/google/adk-python/commit/94d48fce32a1f07cef967d50e82f2b1975b4abd9)). + + ## [1.22.0](https://github.com/google/adk-python/compare/v1.21.0...v1.22.0) (2026-01-08) ### Features diff --git a/src/google/adk/a2a/utils/agent_card_builder.py b/src/google/adk/a2a/utils/agent_card_builder.py index c007870931..2855077704 100644 --- a/src/google/adk/a2a/utils/agent_card_builder.py +++ b/src/google/adk/a2a/utils/agent_card_builder.py @@ -114,7 +114,7 @@ async def _build_llm_agent_skills(agent: LlmAgent) -> List[AgentSkill]: id=agent.name, name='model', description=agent_description, - examples=agent_examples, + examples=_extract_inputs_from_examples(agent_examples), input_modes=_get_input_modes(agent), output_modes=_get_output_modes(agent), tags=['llm'], @@ -239,7 +239,7 @@ async def _build_non_llm_agent_skills(agent: BaseAgent) -> List[AgentSkill]: id=agent.name, name=agent_name, description=agent_description, - examples=agent_examples, + examples=_extract_inputs_from_examples(agent_examples), input_modes=_get_input_modes(agent), output_modes=_get_output_modes(agent), tags=[agent_type], @@ -350,6 +350,7 @@ def _build_llm_agent_description_with_instructions(agent: LlmAgent) -> str: def _replace_pronouns(text: str) -> str: """Replace pronouns and conjugate common verbs for agent description. + (e.g., "You are" -> "I am", "your" -> "my"). """ pronoun_map = { @@ -460,6 +461,33 @@ def _get_default_description(agent: BaseAgent) -> str: return 'A custom agent' +def _extract_inputs_from_examples(examples: Optional[list[dict]]) -> list[str]: + """Extracts only the input strings so they can be added to an AgentSkill.""" + if examples is None: + return [] + + extracted_inputs = [] + for example in examples: + example_input = example.get('input') + if not example_input: + continue + + parts = example_input.get('parts') + if parts is not None: + part_texts = [] + for part in parts: + text = part.get('text') + if text is not None: + part_texts.append(text) + extracted_inputs.append('\n'.join(part_texts)) + else: + text = example_input.get('text') + if text is not None: + extracted_inputs.append(text) + + return extracted_inputs + + async def _extract_examples_from_agent( agent: BaseAgent, ) -> Optional[List[Dict]]: diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index f707d6a0bc..d659bdaa50 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -61,6 +61,7 @@ class HttpAuth(BaseModelWithConfig): # Examples: 'basic', 'bearer' scheme: str credentials: HttpCredentials + additional_headers: Optional[Dict[str, str]] = None class OAuth2Auth(BaseModelWithConfig): diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index f0b8fba022..9984580244 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -36,6 +36,7 @@ from . import cli_deploy from .. import version from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from ..sessions.migration import migration_runner from .cli import run_cli from .fast_api import get_fast_api_app from .utils import envs diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 78b523ed9f..3e850ae207 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -303,7 +303,6 @@ def get_author_for_event(llm_response): else: return invocation_context.agent.name - assert invocation_context.live_request_queue try: while True: async with Aclosing(llm_connection.receive()) as agen: diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 9fb02d865d..384d76da88 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -181,6 +181,45 @@ def _infer_mime_type_from_uri(uri: str) -> Optional[str]: return None +def _looks_like_openai_file_id(file_uri: str) -> bool: + """Returns True when file_uri resembles an OpenAI/Azure file id.""" + return file_uri.startswith("file-") + + +def _redact_file_uri_for_log( + file_uri: str, *, display_name: str | None = None +) -> str: + """Returns a privacy-preserving identifier for logs.""" + if display_name: + return display_name + if _looks_like_openai_file_id(file_uri): + return "file-" + try: + parsed = urlparse(file_uri) + except ValueError: + return "" + if not parsed.scheme: + return "" + segments = [segment for segment in parsed.path.split("/") if segment] + tail = segments[-1] if segments else "" + if tail: + return f"{parsed.scheme}:///{tail}" + return f"{parsed.scheme}://" + + +def _requires_file_uri_fallback( + provider: str, model: str, file_uri: str +) -> bool: + """Returns True when `file_uri` should not be sent as a file content block.""" + if provider in _FILE_ID_REQUIRED_PROVIDERS: + return not _looks_like_openai_file_id(file_uri) + if provider == "anthropic": + return True + if provider == "vertex_ai" and not _is_litellm_gemini_model(model): + return True + return False + + def _decode_inline_text_data(raw_bytes: bytes) -> str: """Decodes inline file bytes that represent textual content.""" try: @@ -447,6 +486,7 @@ async def _content_to_message_param( content: types.Content, *, provider: str = "", + model: str = "", ) -> Union[Message, list[Message]]: """Converts a types.Content to a litellm Message or list of Messages. @@ -456,6 +496,7 @@ async def _content_to_message_param( Args: content: The content to convert. provider: The LLM provider name (e.g., "openai", "azure"). + model: The LiteLLM model string, used for provider-specific behavior. Returns: A litellm Message, a list of litellm Messages. @@ -499,7 +540,9 @@ async def _content_to_message_param( if role == "user": user_parts = [part for part in content.parts if not part.thought] - message_content = await _get_content(user_parts, provider=provider) or None + message_content = ( + await _get_content(user_parts, provider=provider, model=model) or None + ) return ChatCompletionUserMessage(role="user", content=message_content) else: # assistant/model tool_calls = [] @@ -523,7 +566,7 @@ async def _content_to_message_param( content_parts.append(part) final_content = ( - await _get_content(content_parts, provider=provider) + await _get_content(content_parts, provider=provider, model=model) if content_parts else None ) @@ -620,6 +663,7 @@ async def _get_content( parts: Iterable[types.Part], *, provider: str = "", + model: str = "", ) -> OpenAIMessageContent: """Converts a list of parts to litellm content. @@ -629,6 +673,8 @@ async def _get_content( Args: parts: The parts to convert. provider: The LLM provider name (e.g., "openai", "azure"). + model: The LiteLLM model string (e.g., "openai/gpt-4o", + "vertex_ai/gemini-2.5-flash"). Returns: The litellm content. @@ -709,6 +755,32 @@ async def _get_content( f"{part.inline_data.mime_type}." ) elif part.file_data and part.file_data.file_uri: + if ( + provider in _FILE_ID_REQUIRED_PROVIDERS + and _looks_like_openai_file_id(part.file_data.file_uri) + ): + content_objects.append({ + "type": "file", + "file": {"file_id": part.file_data.file_uri}, + }) + continue + + if _requires_file_uri_fallback(provider, model, part.file_data.file_uri): + logger.debug( + "File URI %s not supported for provider %s, using text fallback", + _redact_file_uri_for_log( + part.file_data.file_uri, + display_name=part.file_data.display_name, + ), + provider, + ) + identifier = part.file_data.display_name or part.file_data.file_uri + content_objects.append({ + "type": "text", + "text": f'[File reference: "{identifier}"]', + }) + continue + file_object: ChatCompletionFileUrlObject = { "file_id": part.file_data.file_uri, } @@ -1363,7 +1435,7 @@ async def _get_completion_inputs( messages: List[Message] = [] for content in llm_request.contents or []: message_param_or_list = await _content_to_message_param( - content, provider=provider + content, provider=provider, model=model ) if isinstance(message_param_or_list, list): messages.extend(message_param_or_list) diff --git a/src/google/adk/telemetry/tracing.py b/src/google/adk/telemetry/tracing.py index 386ae3b453..d98d12a43c 100644 --- a/src/google/adk/telemetry/tracing.py +++ b/src/google/adk/telemetry/tracing.py @@ -149,7 +149,7 @@ def trace_tool_call( _safe_json_serialize(args), ) else: - span.set_attribute('gcp.vertex.agent.tool_call_args', {}) + span.set_attribute('gcp.vertex.agent.tool_call_args', '{}') # Tracing tool response tool_call_id = '' @@ -179,7 +179,7 @@ def trace_tool_call( _safe_json_serialize(tool_response), ) else: - span.set_attribute('gcp.vertex.agent.tool_response', {}) + span.set_attribute('gcp.vertex.agent.tool_response', '{}') def trace_merged_tool_calls( @@ -219,7 +219,7 @@ def trace_merged_tool_calls( function_response_event_json, ) else: - span.set_attribute('gcp.vertex.agent.tool_response', {}) + span.set_attribute('gcp.vertex.agent.tool_response', '{}') # Setting empty llm request and response (as UI expect these) while not # applicable for tool_response. span.set_attribute('gcp.vertex.agent.llm_request', '{}') @@ -265,7 +265,7 @@ def trace_call_llm( _safe_json_serialize(_build_llm_request_for_trace(llm_request)), ) else: - span.set_attribute('gcp.vertex.agent.llm_request', {}) + span.set_attribute('gcp.vertex.agent.llm_request', '{}') # Consider removing once GenAI SDK provides a way to record this info. if llm_request.config: if llm_request.config.top_p: @@ -290,7 +290,7 @@ def trace_call_llm( llm_response_json, ) else: - span.set_attribute('gcp.vertex.agent.llm_response', {}) + span.set_attribute('gcp.vertex.agent.llm_response', '{}') if llm_response.usage_metadata is not None: span.set_attribute( @@ -346,7 +346,7 @@ def trace_send_data( ]), ) else: - span.set_attribute('gcp.vertex.agent.data', {}) + span.set_attribute('gcp.vertex.agent.data', '{}') def _build_llm_request_for_trace(llm_request: LlmRequest) -> dict[str, Any]: diff --git a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py index 4fdc87019b..6044cac2b1 100644 --- a/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +++ b/src/google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py @@ -74,14 +74,18 @@ def exchange_credential( try: if auth_credential.service_account.use_default_credential: - credentials, _ = google.auth.default( + credentials, project_id = google.auth.default( scopes=["https://www.googleapis.com/auth/cloud-platform"], ) + quota_project_id = ( + getattr(credentials, "quota_project_id", None) or project_id + ) else: config = auth_credential.service_account credentials = service_account.Credentials.from_service_account_info( config.service_account_credential.model_dump(), scopes=config.scopes ) + quota_project_id = None credentials.refresh(Request()) @@ -90,6 +94,11 @@ def exchange_credential( http=HttpAuth( scheme="bearer", credentials=HttpCredentials(token=credentials.token), + additional_headers={ + "x-goog-user-project": quota_project_id, + } + if quota_project_id + else None, ), ) return updated_credential diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py index 5c27b16851..27c6acdaeb 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py @@ -320,6 +320,13 @@ def _prepare_request_params( user_agent = f"google-adk/{adk_version} (tool: {self.name})" header_params["User-Agent"] = user_agent + if ( + self.auth_credential + and self.auth_credential.http + and self.auth_credential.http.additional_headers + ): + header_params.update(self.auth_credential.http.additional_headers) + params_map: Dict[str, ApiParameter] = {p.py_name: p for p in parameters} # Fill in path, query, header and cookie parameters to the request diff --git a/src/google/adk/version.py b/src/google/adk/version.py index faed20df96..df5633d43a 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.22.0" +__version__ = "1.22.1" diff --git a/tests/unittests/a2a/utils/test_agent_card_builder.py b/tests/unittests/a2a/utils/test_agent_card_builder.py index 3bf3202897..d8fbf1e9f9 100644 --- a/tests/unittests/a2a/utils/test_agent_card_builder.py +++ b/tests/unittests/a2a/utils/test_agent_card_builder.py @@ -28,6 +28,7 @@ from google.adk.a2a.utils.agent_card_builder import _build_sequential_description from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction +from google.adk.a2a.utils.agent_card_builder import _extract_inputs_from_examples from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name from google.adk.a2a.utils.agent_card_builder import _get_agent_type from google.adk.a2a.utils.agent_card_builder import _get_default_description @@ -41,6 +42,7 @@ from google.adk.agents.loop_agent import LoopAgent from google.adk.agents.parallel_agent import ParallelAgent from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.examples import Example from google.adk.tools.example_tool import ExampleTool import pytest @@ -1100,3 +1102,73 @@ def test_extract_examples_from_instruction_odd_number_of_matches(self): assert len(result) == 1 # Only complete pairs should be included assert result[0]["input"] == {"text": "What is the weather?"} assert result[0]["output"] == [{"text": "What time is it?"}] + + def test_extract_inputs_from_examples_from_plain_text_input(self): + """Test _extract_inputs_from_examples on plain text as input.""" + # Arrange + examples = [ + { + "input": {"text": "What is the weather?"}, + "output": [{"text": "What time is it?"}], + }, + { + "input": {"text": "The weather is sunny."}, + "output": [{"text": "It is 3 PM."}], + }, + ] + + # Act + result = _extract_inputs_from_examples(examples) + + # Assert + assert len(result) == 2 + assert result[0] == "What is the weather?" + assert result[1] == "The weather is sunny." + + def test_extract_inputs_from_examples_from_example_tool(self): + """Test _extract_inputs_from_examples as extracted from ExampleTool.""" + + # Arrange + # This is what would be extracted from an ExampleTool + examples = [ + { + "input": { + "role": "user", + "parts": [{"text": "What is the weather?"}], + }, + "output": [ + { + "role": "model", + "parts": [{"text": "What time is it?"}], + }, + ], + }, + { + "input": { + "role": "user", + "parts": [{"text": "The weather is sunny."}], + }, + "output": [ + { + "role": "model", + "parts": [{"text": "It is 3 PM."}], + }, + ], + }, + ] + + # Act + result = _extract_inputs_from_examples(examples) + + # Assert + assert len(result) == 2 + assert result[0] == "What is the weather?" + assert result[1] == "The weather is sunny." + + def test_extract_inputs_from_examples_none_input(self): + """Test _extract_inputs_from_examples on None as input.""" + # Act + result = _extract_inputs_from_examples(None) + + # Assert + assert len(result) == 0 diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index b3f4bd9e25..c687ceb0cb 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -2304,6 +2304,126 @@ async def test_get_content_file_uri(file_uri, mime_type): } +@pytest.mark.asyncio +@pytest.mark.parametrize( + "provider,model", + [ + ("openai", "openai/gpt-4o"), + ("azure", "azure/gpt-4"), + ], +) +async def test_get_content_file_uri_file_id_required_falls_back_to_text( + provider, model +): + parts = [ + types.Part( + file_data=types.FileData( + file_uri="gs://bucket/path/to/document.pdf", + mime_type="application/pdf", + display_name="document.pdf", + ) + ) + ] + content = await _get_content(parts, provider=provider, model=model) + assert content == [ + {"type": "text", "text": '[File reference: "document.pdf"]'} + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "provider,model", + [ + ("openai", "openai/gpt-4o"), + ("azure", "azure/gpt-4"), + ], +) +async def test_get_content_file_uri_file_id_required_preserves_file_id( + provider, model +): + parts = [ + types.Part( + file_data=types.FileData( + file_uri="file-abc123", + mime_type="application/pdf", + ) + ) + ] + content = await _get_content(parts, provider=provider, model=model) + assert content == [{"type": "file", "file": {"file_id": "file-abc123"}}] + + +@pytest.mark.asyncio +async def test_get_content_file_uri_anthropic_falls_back_to_text(): + parts = [ + types.Part( + file_data=types.FileData( + file_uri="gs://bucket/path/to/document.pdf", + mime_type="application/pdf", + display_name="document.pdf", + ) + ) + ] + content = await _get_content( + parts, provider="anthropic", model="anthropic/claude-3-5" + ) + assert content == [ + {"type": "text", "text": '[File reference: "document.pdf"]'} + ] + + +@pytest.mark.asyncio +async def test_get_content_file_uri_anthropic_openai_file_id_falls_back_to_text(): + parts = [types.Part(file_data=types.FileData(file_uri="file-abc123"))] + content = await _get_content( + parts, provider="anthropic", model="anthropic/claude-3-5" + ) + assert content == [ + {"type": "text", "text": '[File reference: "file-abc123"]'} + ] + + +@pytest.mark.asyncio +async def test_get_content_file_uri_vertex_ai_non_gemini_falls_back_to_text(): + parts = [ + types.Part( + file_data=types.FileData( + file_uri="gs://bucket/path/to/document.pdf", + mime_type="application/pdf", + display_name="document.pdf", + ) + ) + ] + content = await _get_content( + parts, provider="vertex_ai", model="vertex_ai/claude-3-5" + ) + assert content == [ + {"type": "text", "text": '[File reference: "document.pdf"]'} + ] + + +@pytest.mark.asyncio +async def test_get_content_file_uri_vertex_ai_gemini_keeps_file_block(): + parts = [ + types.Part( + file_data=types.FileData( + file_uri="gs://bucket/path/to/document.pdf", + mime_type="application/pdf", + ) + ) + ] + content = await _get_content( + parts, provider="vertex_ai", model="vertex_ai/gemini-2.5-flash" + ) + assert content == [{ + "type": "file", + "file": { + "file_id": "gs://bucket/path/to/document.pdf", + "format": "application/pdf", + }, + }] + + @pytest.mark.asyncio async def test_get_content_file_uri_infer_mime_type(): """Test MIME type inference from file_uri extension. diff --git a/tests/unittests/telemetry/test_spans.py b/tests/unittests/telemetry/test_spans.py index c87730a5e7..dd785daf7e 100644 --- a/tests/unittests/telemetry/test_spans.py +++ b/tests/unittests/telemetry/test_spans.py @@ -27,6 +27,7 @@ from google.adk.telemetry.tracing import trace_agent_invocation from google.adk.telemetry.tracing import trace_call_llm from google.adk.telemetry.tracing import trace_merged_tool_calls +from google.adk.telemetry.tracing import trace_send_data from google.adk.telemetry.tracing import trace_tool_call from google.adk.tools.base_tool import BaseTool from google.genai import types @@ -447,7 +448,7 @@ def test_trace_merged_tool_calls_sets_correct_attributes( async def test_call_llm_disabling_request_response_content( monkeypatch, mock_span_fixture ): - """Test trace_call_llm doesn't set request and response attributes if env is set to false""" + """Test trace_call_llm sets placeholders when capture is disabled.""" # Arrange monkeypatch.setenv(ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, 'false') monkeypatch.setattr( @@ -474,23 +475,19 @@ async def test_call_llm_disabling_request_response_content( trace_call_llm(invocation_context, 'test_event_id', llm_request, llm_response) # Assert - assert not any( - arg_name == 'gcp.vertex.agent.llm_request' and arg_value != {} - for arg_name, arg_value in ( - call_obj.args - for call_obj in mock_span_fixture.set_attribute.call_args_list - ) - ), "Attribute 'gcp.vertex.agent.llm_request' was incorrectly set on the span." - - assert not any( - arg_name == 'gcp.vertex.agent.llm_response' and arg_value != {} - for arg_name, arg_value in ( - call_obj.args - for call_obj in mock_span_fixture.set_attribute.call_args_list - ) - ), ( - "Attribute 'gcp.vertex.agent.llm_response' was incorrectly set on the" - ' span.' + assert ( + 'gcp.vertex.agent.llm_request', + '{}', + ) in ( + call_obj.args + for call_obj in mock_span_fixture.set_attribute.call_args_list + ) + assert ( + 'gcp.vertex.agent.llm_response', + '{}', + ) in ( + call_obj.args + for call_obj in mock_span_fixture.set_attribute.call_args_list ) @@ -500,7 +497,7 @@ def test_trace_tool_call_disabling_request_response_content( mock_tool_fixture, mock_event_fixture, ): - """Test trace_tool_call doesn't set request and response attributes if env is set to false""" + """Test trace_tool_call sets placeholders when capture is disabled.""" # Arrange monkeypatch.setenv(ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, 'false') monkeypatch.setattr( @@ -537,26 +534,19 @@ def test_trace_tool_call_disabling_request_response_content( ) # Assert - assert not any( - arg_name == 'gcp.vertex.agent.tool_call_args' and arg_value != {} - for arg_name, arg_value in ( - call_obj.args - for call_obj in mock_span_fixture.set_attribute.call_args_list - ) - ), ( - "Attribute 'gcp.vertex.agent.tool_call_args' was incorrectly set on the" - ' span.' + assert ( + 'gcp.vertex.agent.tool_call_args', + '{}', + ) in ( + call_obj.args + for call_obj in mock_span_fixture.set_attribute.call_args_list ) - - assert not any( - arg_name == 'gcp.vertex.agent.tool_response' and arg_value != {} - for arg_name, arg_value in ( - call_obj.args - for call_obj in mock_span_fixture.set_attribute.call_args_list - ) - ), ( - "Attribute 'gcp.vertex.agent.tool_response' was incorrectly set on the" - ' span.' + assert ( + 'gcp.vertex.agent.tool_response', + '{}', + ) in ( + call_obj.args + for call_obj in mock_span_fixture.set_attribute.call_args_list ) @@ -565,7 +555,7 @@ def test_trace_merged_tool_disabling_request_response_content( mock_span_fixture, mock_event_fixture, ): - """Test trace_merged_tool doesn't set request and response attributes if env is set to false""" + """Test trace_merged_tool_calls sets placeholders when capture is disabled.""" # Arrange monkeypatch.setenv(ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, 'false') monkeypatch.setattr( @@ -585,13 +575,40 @@ def test_trace_merged_tool_disabling_request_response_content( ) # Assert - assert not any( - arg_name == 'gcp.vertex.agent.tool_response' and arg_value != {} - for arg_name, arg_value in ( - call_obj.args - for call_obj in mock_span_fixture.set_attribute.call_args_list - ) - ), ( - "Attribute 'gcp.vertex.agent.tool_response' was incorrectly set on the" - ' span.' + assert ( + 'gcp.vertex.agent.tool_response', + '{}', + ) in ( + call_obj.args + for call_obj in mock_span_fixture.set_attribute.call_args_list + ) + + +@pytest.mark.asyncio +async def test_trace_send_data_disabling_request_response_content( + monkeypatch, mock_span_fixture +): + """Test trace_send_data sets placeholders when capture is disabled.""" + monkeypatch.setenv(ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS, 'false') + monkeypatch.setattr( + 'opentelemetry.trace.get_current_span', lambda: mock_span_fixture + ) + + agent = LlmAgent(name='test_agent') + invocation_context = await _create_invocation_context(agent) + + trace_send_data( + invocation_context=invocation_context, + event_id='test_event_id', + data=[ + types.Content( + role='user', + parts=[types.Part(text='hi')], + ) + ], + ) + + assert ('gcp.vertex.agent.data', '{}') in ( + call_obj.args + for call_obj in mock_span_fixture.set_attribute.call_args_list ) diff --git a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py index db929c8e99..4d930b3977 100644 --- a/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py +++ b/tests/unittests/tools/openapi_tool/auth/credential_exchangers/test_service_account_exchanger.py @@ -99,14 +99,28 @@ def test_exchange_credential_success( mock_credentials.refresh.assert_called_once() +@pytest.mark.parametrize( + "cred_quota_project_id, adc_project_id, expected_quota_project_id", + [ + ("test_project", "another_project", "test_project"), + (None, "adc_project", "adc_project"), + (None, None, None), + ], +) def test_exchange_credential_use_default_credential_success( - service_account_exchanger, auth_scheme, monkeypatch + service_account_exchanger, + auth_scheme, + monkeypatch, + cred_quota_project_id, + adc_project_id, + expected_quota_project_id, ): """Test successful exchange of service account credentials using default credential.""" mock_credentials = MagicMock() mock_credentials.token = "mock_access_token" + mock_credentials.quota_project_id = cred_quota_project_id mock_google_auth_default = MagicMock( - return_value=(mock_credentials, "test_project") + return_value=(mock_credentials, adc_project_id) ) monkeypatch.setattr(google.auth, "default", mock_google_auth_default) @@ -125,6 +139,13 @@ def test_exchange_credential_use_default_credential_success( assert result.auth_type == AuthCredentialTypes.HTTP assert result.http.scheme == "bearer" assert result.http.credentials.token == "mock_access_token" + if expected_quota_project_id: + assert ( + result.http.additional_headers["x-goog-user-project"] + == expected_quota_project_id + ) + else: + assert not result.http.additional_headers # Verify google.auth.default is called with the correct scopes parameter mock_google_auth_default.assert_called_once_with( scopes=["https://www.googleapis.com/auth/cloud-platform"] diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py index 560813e619..ddf09aeb4a 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_rest_api_tool.py @@ -25,6 +25,10 @@ from fastapi.openapi.models import Parameter as OpenAPIParameter from fastapi.openapi.models import RequestBody from fastapi.openapi.models import Schema as OpenAPISchema +from google.adk.auth.auth_credential import AuthCredential +from google.adk.auth.auth_credential import AuthCredentialTypes +from google.adk.auth.auth_credential import HttpAuth +from google.adk.auth.auth_credential import HttpCredentials from google.adk.sessions.state import State from google.adk.tools.openapi_tool.auth.auth_helpers import token_to_scheme_credential from google.adk.tools.openapi_tool.common.common import ApiParameter @@ -721,6 +725,35 @@ def test_prepare_request_params_cookie_param( assert request_params["cookies"]["session_id"] == "cookie_value" + def test_prepare_request_params_quota_project_id( + self, + sample_endpoint, + sample_operation, + sample_auth_scheme, + ): + auth_credential = AuthCredential( + auth_type=AuthCredentialTypes.HTTP, + http=HttpAuth( + scheme="bearer", + credentials=HttpCredentials(), + additional_headers={"x-goog-user-project": "test-project"}, + ), + ) + tool = RestApiTool( + name="test_tool", + description="Test Tool", + endpoint=sample_endpoint, + operation=sample_operation, + auth_credential=auth_credential, + auth_scheme=sample_auth_scheme, + ) + params = [] + kwargs = {} + + request_params = tool._prepare_request_params(params, kwargs) + + assert request_params["headers"]["x-goog-user-project"] == "test-project" + def test_prepare_request_params_multiple_mime_types( self, sample_endpoint, sample_auth_credential, sample_auth_scheme ):