From b807d62fe35a5de4e1efd37f38230edcf5f992ad Mon Sep 17 00:00:00 2001 From: Faraaz Ahmed Date: Tue, 2 Dec 2025 19:13:38 -0800 Subject: [PATCH] feat(bigquery): Add labels support to BigQueryToolConfig for job tracking and monitoring Merge https://github.com/google/adk-python/pull/3583 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: #3582 **2. Or, if no issue exists, describe the change:** _If applicable, please follow the issue templates to provide as much detail as possible._ **Problem:** Currently, the BigQuery tool in ADK does not provide a way for developers to add custom labels to BigQuery jobs created by their agents. This makes it difficult to: Track and monitor BigQuery costs associated with specific agents or use cases Organize and filter BigQuery jobs in the Google Cloud Console Implement billing attribution and resource organization strategies Differentiate between jobs from different environments (dev, staging, production) While the tool automatically adds an internal adk-bigquery-tool label with the caller_id, there's no mechanism for users to add their own custom labels for tracking and monitoring purposes. **Solution:** Add a labels configuration field to BigQueryToolConfig that allows users to specify custom key-value pairs to be applied to all BigQuery jobs executed by the agent. The solution should: Configuration Option: Add an optional labels parameter to BigQueryToolConfig accepting a dictionary of string key-value pairs Validation: Ensure labels follow BigQuery's requirements (non-empty string keys, string values) Job Application: Automatically apply configured labels to all BigQuery jobs alongside the existing internal labels Documentation: Provide clear documentation on how to use labels for tracking and monitoring ### Testing Plan _Please describe the tests that you ran to verify your changes. This is required for all PRs that are not small documentation or typo fixes._ **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. _Please include a summary of passed `pytest` results._ ``` pytest tests/unittests/tools/bigquery/test_bigquery_tool_config.py -v --tb=line -W ignore::UserWarning ========================================= test session starts ========================================== platform darwin -- Python 3.11.14, pytest-9.0.1, pluggy-1.6.0 -- *****redacted****** cachedir: .pytest_cache rootdir: *****redacted****** configfile: pyproject.toml plugins: mock-3.15.1, anyio-4.11.0, xdist-3.8.0, langsmith-0.4.43, asyncio-1.3.0 asyncio: mode=Mode.AUTO, debug=False, asyncio_default_fixture_loop_scope=function, asyncio_default_test_loop_scope=function collected 14 items tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_experimental_warning PASSED [ 7%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_property PASSED [ 14%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_application_name PASSED [ 21%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_max_query_result_rows_default PASSED [ 28%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_max_query_result_rows_custom PASSED [ 35%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_valid_maximum_bytes_billed PASSED [ 42%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_maximum_bytes_billed PASSED [ 50%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_valid_labels PASSED [ 57%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_empty_labels PASSED [ 64%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_none_labels PASSED [ 71%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_labels_type PASSED [ 78%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_label_key_type PASSED [ 85%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_label_value_type PASSED [ 92%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_empty_label_key PASSED [100%] ==================================================================================================== 14 passed in 2.02s ==================================================================================================== ``` **Manual End-to-End (E2E) Tests:** _Please provide instructions on how to manually test your changes, including any necessary setup or configuration. Please provide logs or screenshots to help reviewers better understand the fix._ ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. ### Additional context _Add any other context or screenshots about the feature request here._ COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3583 from Faraaz1994:feature/bq_label 0fd7fe6a3b1ee20a36f73562e425d007b8d7dc9d PiperOrigin-RevId: 839523588 --- src/google/adk/tools/bigquery/config.py | 20 ++ src/google/adk/tools/bigquery/query_tool.py | 5 +- .../bigquery/test_bigquery_query_tool.py | 237 ++++++++++++++++++ .../bigquery/test_bigquery_tool_config.py | 58 +++++ 4 files changed, 319 insertions(+), 1 deletion(-) diff --git a/src/google/adk/tools/bigquery/config.py b/src/google/adk/tools/bigquery/config.py index 7768f214ed..39b6a3d9b6 100644 --- a/src/google/adk/tools/bigquery/config.py +++ b/src/google/adk/tools/bigquery/config.py @@ -101,6 +101,16 @@ class BigQueryToolConfig(BaseModel): locations, see https://cloud.google.com/bigquery/docs/locations. """ + job_labels: Optional[dict[str, str]] = None + """Labels to apply to BigQuery jobs for tracking and monitoring. + + These labels will be added to all BigQuery jobs executed by the tools. + Labels must be key-value pairs where both keys and values are strings. + Labels can be used for billing, monitoring, and resource organization. + For more information about labels, see + https://cloud.google.com/bigquery/docs/labels-intro. + """ + @field_validator('maximum_bytes_billed') @classmethod def validate_maximum_bytes_billed(cls, v): @@ -121,3 +131,13 @@ def validate_application_name(cls, v): if v and ' ' in v: raise ValueError('Application name should not contain spaces.') return v + + @field_validator('job_labels') + @classmethod + def validate_job_labels(cls, v): + """Validate that job_labels keys are not empty.""" + if v is not None: + for key in v.keys(): + if not key: + raise ValueError('Label keys cannot be empty.') + return v diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 666dc3c5a1..5bcd734e70 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -68,7 +68,10 @@ def _execute_sql( bq_connection_properties = [] # BigQuery job labels if applicable - bq_job_labels = {} + bq_job_labels = ( + settings.job_labels.copy() if settings and settings.job_labels else {} + ) + if caller_id: bq_job_labels["adk-bigquery-tool"] = caller_id if settings and settings.application_name: diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index eef83a1f5e..1791100e1f 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -1709,6 +1709,65 @@ def test_execute_sql_job_labels( } +@pytest.mark.parametrize( + ("write_mode", "dry_run", "query_call_count", "query_and_wait_call_count"), + [ + pytest.param(WriteMode.ALLOWED, False, 0, 1, id="write-allowed"), + pytest.param(WriteMode.ALLOWED, True, 1, 0, id="write-allowed-dry-run"), + pytest.param(WriteMode.BLOCKED, False, 1, 1, id="write-blocked"), + pytest.param(WriteMode.BLOCKED, True, 2, 0, id="write-blocked-dry-run"), + pytest.param(WriteMode.PROTECTED, False, 2, 1, id="write-protected"), + pytest.param( + WriteMode.PROTECTED, True, 3, 0, id="write-protected-dry-run" + ), + ], +) +def test_execute_sql_user_job_labels_augment_internal_labels( + write_mode, dry_run, query_call_count, query_and_wait_call_count +): + """Test execute_sql tool augments user job_labels with internal labels.""" + project = "my_project" + query = "SELECT 123 AS num" + statement_type = "SELECT" + credentials = mock.create_autospec(Credentials, instance=True) + user_labels = {"environment": "test", "team": "data"} + tool_settings = BigQueryToolConfig( + write_mode=write_mode, + job_labels=user_labels, + ) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = None + + with mock.patch.object(bigquery, "Client", autospec=True) as Client: + bq_client = Client.return_value + + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + bq_client.query.return_value = query_job + + query_tool.execute_sql( + project, + query, + credentials, + tool_settings, + tool_context, + dry_run=dry_run, + ) + + assert bq_client.query.call_count == query_call_count + assert bq_client.query_and_wait.call_count == query_and_wait_call_count + # Build expected labels from user_labels + internal label + expected_labels = {**user_labels, "adk-bigquery-tool": "execute_sql"} + for call_args_list in [ + bq_client.query.call_args_list, + bq_client.query_and_wait.call_args_list, + ]: + for call_args in call_args_list: + _, mock_kwargs = call_args + # Verify user labels are preserved and internal label is added + assert mock_kwargs["job_config"].labels == expected_labels + + @pytest.mark.parametrize( ("tool_call", "expected_tool_label"), [ @@ -1850,6 +1909,94 @@ def test_ml_tool_job_labels_w_application_name(tool_call, expected_tool_label): assert mock_kwargs["job_config"].labels == expected_labels +@pytest.mark.parametrize( + ("tool_call", "expected_labels"), + [ + pytest.param( + lambda tool_context: query_tool.forecast( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + timestamp_col="ts_col", + data_col="data_col", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "forecaster"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "forecaster", + "adk-bigquery-tool": "forecast", + }, + id="forecast", + ), + pytest.param( + lambda tool_context: query_tool.analyze_contribution( + project_id="test-project", + input_data="test-dataset.test-table", + dimension_id_cols=["dim1", "dim2"], + contribution_metric="SUM(metric)", + is_test_col="is_test", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "analyzer"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "analyzer", + "adk-bigquery-tool": "analyze_contribution", + }, + id="analyze-contribution", + ), + pytest.param( + lambda tool_context: query_tool.detect_anomalies( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "detector"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "detector", + "adk-bigquery-tool": "detect_anomalies", + }, + id="detect-anomalies", + ), + ], +) +def test_ml_tool_user_job_labels_augment_internal_labels( + tool_call, expected_labels +): + """Test ML tools augment user job_labels with internal labels.""" + + with mock.patch.object(bigquery, "Client", autospec=True) as Client: + bq_client = Client.return_value + + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = None + tool_call(tool_context) + + for call_args_list in [ + bq_client.query.call_args_list, + bq_client.query_and_wait.call_args_list, + ]: + for call_args in call_args_list: + _, mock_kwargs = call_args + # Verify user labels are preserved and internal label is added + assert mock_kwargs["job_config"].labels == expected_labels + + def test_execute_sql_max_rows_config(): """Test execute_sql tool respects max_query_result_rows from config.""" project = "my_project" @@ -2014,3 +2161,93 @@ def test_tool_call_doesnt_change_global_settings(tool_call): # Test settings write mode after assert settings.write_mode == WriteMode.ALLOWED + + +@pytest.mark.parametrize( + ("tool_call",), + [ + pytest.param( + lambda settings, tool_context: query_tool.execute_sql( + project_id="test-project", + query="SELECT * FROM `test-dataset.test-table`", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="execute-sql", + ), + pytest.param( + lambda settings, tool_context: query_tool.forecast( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + timestamp_col="ts_col", + data_col="data_col", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="forecast", + ), + pytest.param( + lambda settings, tool_context: query_tool.analyze_contribution( + project_id="test-project", + input_data="test-dataset.test-table", + dimension_id_cols=["dim1", "dim2"], + contribution_metric="SUM(metric)", + is_test_col="is_test", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="analyze-contribution", + ), + pytest.param( + lambda settings, tool_context: query_tool.detect_anomalies( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="detect-anomalies", + ), + ], +) +def test_tool_call_doesnt_mutate_job_labels(tool_call): + """Test query tools don't mutate job_labels in global settings.""" + original_labels = {"environment": "test", "team": "data"} + settings = BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels=original_labels.copy(), + ) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = ( + "test-bq-session-id", + "_anonymous_dataset", + ) + + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + # The mock instance + bq_client = Client.return_value + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.destination.dataset_id = "_anonymous_dataset" + bq_client.query.return_value = query_job + bq_client.query_and_wait.return_value = [] + + # Test job_labels before + assert settings.job_labels == original_labels + assert "adk-bigquery-tool" not in settings.job_labels + + # Call the tool + result = tool_call(settings, tool_context) + + # Test successful execution of the tool + assert result == {"status": "SUCCESS", "rows": []} + + # Test job_labels remain unchanged after tool call + assert settings.job_labels == original_labels + assert "adk-bigquery-tool" not in settings.job_labels diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py index 5854c97797..072ccea7d0 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py +++ b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py @@ -77,3 +77,61 @@ def test_bigquery_tool_config_invalid_maximum_bytes_billed(): ), ): BigQueryToolConfig(maximum_bytes_billed=10_485_759) + + +@pytest.mark.parametrize( + "labels", + [ + pytest.param( + {"environment": "test", "team": "data"}, + id="valid-labels", + ), + pytest.param( + {}, + id="empty-labels", + ), + pytest.param( + None, + id="none-labels", + ), + ], +) +def test_bigquery_tool_config_valid_labels(labels): + """Test BigQueryToolConfig accepts valid labels.""" + with pytest.warns(UserWarning): + config = BigQueryToolConfig(job_labels=labels) + assert config.job_labels == labels + + +@pytest.mark.parametrize( + ("labels", "message"), + [ + pytest.param( + "invalid", + "Input should be a valid dictionary", + id="invalid-type", + ), + pytest.param( + {123: "value"}, + "Input should be a valid string", + id="non-str-key", + ), + pytest.param( + {"key": 123}, + "Input should be a valid string", + id="non-str-value", + ), + pytest.param( + {"": "value"}, + "Label keys cannot be empty", + id="empty-label-key", + ), + ], +) +def test_bigquery_tool_config_invalid_labels(labels, message): + """Test BigQueryToolConfig raises an exception with invalid labels.""" + with pytest.raises( + ValueError, + match=message, + ): + BigQueryToolConfig(job_labels=labels)