Skip to content

Commit ee8fd70

Browse files
fix(functions): Refresh credentials before enqueueing first task (#907)
* fix(functions): Refresh credentials before enqueueing task This change addresses an issue where enqueueing a task from a Cloud Function would fail with a InvalidArgumentError error. This was caused by uninitialized credentials being used to in the task payload. The fix explicitly refreshes the credential before accessing the credential, ensuring a valid token or service account email is used in the in the task payload. This also includes a correction for an f-string typo in the Authorization header construction. * fix(functions): Move credential refresh to functions service init * fix(functions): Moved credential refresh to run on task payload update with freshness guard --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
1 parent de713d2 commit ee8fd70

File tree

3 files changed

+101
-3
lines changed

3 files changed

+101
-3
lines changed

firebase_admin/functions.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
from base64 import b64encode
2323
from typing import Any, Optional, Dict
2424
from dataclasses import dataclass
25+
2526
from google.auth.compute_engine import Credentials as ComputeEngineCredentials
27+
from google.auth.credentials import TokenState
28+
from google.auth.exceptions import RefreshError
29+
from google.auth.transport import requests as google_auth_requests
2630

2731
import requests
2832
import firebase_admin
@@ -285,14 +289,22 @@ def _update_task_payload(self, task: Task, resource: Resource, extension_id: str
285289
# Get function url from task or generate from resources
286290
if not _Validators.is_non_empty_string(task.http_request['url']):
287291
task.http_request['url'] = self._get_url(resource, _FIREBASE_FUNCTION_URL_FORMAT)
292+
293+
# Refresh the credential to ensure all attributes (e.g. service_account_email, id_token)
294+
# are populated, preventing cold start errors.
295+
if self._credential.token_state != TokenState.FRESH:
296+
try:
297+
self._credential.refresh(google_auth_requests.Request())
298+
except RefreshError as err:
299+
raise ValueError(f'Initial task payload credential refresh failed: {err}') from err
300+
288301
# If extension id is provided, it emplies that it is being run from a deployed extension.
289302
# Meaning that it's credential should be a Compute Engine Credential.
290303
if _Validators.is_non_empty_string(extension_id) and \
291304
isinstance(self._credential, ComputeEngineCredentials):
292-
293305
id_token = self._credential.token
294306
task.http_request['headers'] = \
295-
{**task.http_request['headers'], 'Authorization': f'Bearer ${id_token}'}
307+
{**task.http_request['headers'], 'Authorization': f'Bearer {id_token}'}
296308
# Delete oidc token
297309
del task.http_request['oidc_token']
298310
else:

tests/test_functions.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ def test_task_enqueue(self):
124124
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
125125
assert task_id == 'test-task-id'
126126

127+
task = json.loads(recorder[0].body.decode())['task']
128+
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
129+
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
130+
127131
def test_task_enqueue_with_extension(self):
128132
resource_name = (
129133
'projects/test-project/locations/us-central1/queues/'
@@ -142,6 +146,59 @@ def test_task_enqueue_with_extension(self):
142146
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
143147
assert task_id == 'test-task-id'
144148

149+
task = json.loads(recorder[0].body.decode())['task']
150+
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-email'}
151+
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
152+
153+
def test_task_enqueue_compute_engine(self):
154+
app = firebase_admin.initialize_app(
155+
testutils.MockComputeEngineCredential(),
156+
options={'projectId': 'test-project'},
157+
name='test-project-gce')
158+
_, recorder = self._instrument_functions_service(app)
159+
queue = functions.task_queue('test-function-name', app=app)
160+
task_id = queue.enqueue(_DEFAULT_DATA)
161+
assert len(recorder) == 1
162+
assert recorder[0].method == 'POST'
163+
assert recorder[0].url == _DEFAULT_REQUEST_URL
164+
assert recorder[0].headers['Content-Type'] == 'application/json'
165+
assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token'
166+
expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag'
167+
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
168+
assert task_id == 'test-task-id'
169+
170+
task = json.loads(recorder[0].body.decode())['task']
171+
assert task['http_request']['oidc_token'] == {'service_account_email': 'mock-gce-email'}
172+
assert task['http_request']['headers'] == {'Content-Type': 'application/json'}
173+
174+
def test_task_enqueue_with_extension_compute_engine(self):
175+
resource_name = (
176+
'projects/test-project/locations/us-central1/queues/'
177+
'ext-test-extension-id-test-function-name/tasks'
178+
)
179+
extension_response = json.dumps({'name': resource_name + '/test-task-id'})
180+
app = firebase_admin.initialize_app(
181+
testutils.MockComputeEngineCredential(),
182+
options={'projectId': 'test-project'},
183+
name='test-project-gce-extensions')
184+
_, recorder = self._instrument_functions_service(app, payload=extension_response)
185+
queue = functions.task_queue('test-function-name', 'test-extension-id', app)
186+
task_id = queue.enqueue(_DEFAULT_DATA)
187+
assert len(recorder) == 1
188+
assert recorder[0].method == 'POST'
189+
assert recorder[0].url == _CLOUD_TASKS_URL + resource_name
190+
assert recorder[0].headers['Content-Type'] == 'application/json'
191+
assert recorder[0].headers['Authorization'] == 'Bearer mock-compute-engine-token'
192+
expected_metrics_header = _utils.get_metrics_header() + ' mock-gce-cred-metric-tag'
193+
assert recorder[0].headers['x-goog-api-client'] == expected_metrics_header
194+
assert task_id == 'test-task-id'
195+
196+
task = json.loads(recorder[0].body.decode())['task']
197+
assert 'oidc_token' not in task['http_request']
198+
assert task['http_request']['headers'] == {
199+
'Content-Type': 'application/json',
200+
'Authorization': 'Bearer mock-compute-engine-token'}
201+
145202
def test_task_delete(self):
146203
_, recorder = self._instrument_functions_service()
147204
queue = functions.task_queue('test-function-name')

tests/testutils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,25 @@ def __call__(self, *args, **kwargs): # pylint: disable=arguments-differ
116116
# pylint: disable=abstract-method
117117
class MockGoogleCredential(credentials.Credentials):
118118
"""A mock Google authentication credential."""
119+
120+
def __init__(self):
121+
super().__init__()
122+
self.token = None
123+
self._service_account_email = None
124+
self._token_state = credentials.TokenState.INVALID
125+
119126
def refresh(self, request):
120127
self.token = 'mock-token'
128+
self._service_account_email = 'mock-email'
129+
self._token_state = credentials.TokenState.FRESH
130+
131+
@property
132+
def token_state(self):
133+
return self._token_state
121134

122135
@property
123136
def service_account_email(self):
124-
return 'mock-email'
137+
return self._service_account_email
125138

126139
# Simulate x-goog-api-client modification in credential refresh
127140
def _metric_header_for_usage(self):
@@ -139,8 +152,24 @@ def get_credential(self):
139152

140153
class MockGoogleComputeEngineCredential(compute_engine.Credentials):
141154
"""A mock Compute Engine credential"""
155+
156+
def __init__(self):
157+
super().__init__()
158+
self.token = None
159+
self._service_account_email = None
160+
self._token_state = credentials.TokenState.INVALID
161+
142162
def refresh(self, request):
143163
self.token = 'mock-compute-engine-token'
164+
self._service_account_email = 'mock-gce-email'
165+
self._token_state = credentials.TokenState.FRESH
166+
167+
@property
168+
def token_state(self):
169+
return self._token_state
170+
171+
def _metric_header_for_usage(self):
172+
return 'mock-gce-cred-metric-tag'
144173

145174
class MockComputeEngineCredential(firebase_admin.credentials.Base):
146175
"""A mock Firebase credential implementation."""

0 commit comments

Comments
 (0)