Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(expand auth certs to service account) #511

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions trailblazer/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class Container(containers.DeclarativeContainer):
oauth_client_secret: str | None = os.environ.get("GOOGLE_CLIENT_SECRET")
oauth_redirect_uri: str | None = os.environ.get("GOOGLE_REDIRECT_URI")
google_oauth_base_url: str | None = os.environ.get("GOOGLE_OAUTH_BASE_URL")
google_service_account_cert_url: str | None = os.environ.get("GOOGLE_SERVICE_ACCOUNT_CERTS_URL")
encryption_key: str | None = os.environ.get("ENCRYPTION_KEY")
google_api_base_url: str | None = os.environ.get("GOOGLE_API_BASE_URL")
slurm_jwt_token: str | None = os.environ.get("SLURM_JWT")
Expand Down Expand Up @@ -86,6 +87,7 @@ class Container(containers.DeclarativeContainer):
store=store,
google_client_id=oauth_client_id,
google_api_base_url=google_api_base_url,
google_service_account_cert_url=google_service_account_cert_url,
)
auth_service = providers.Singleton(
AuthenticationService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@
class UserVerificationService:
"""Service to verify the user."""

def __init__(self, store: Store, google_client_id: str, google_api_base_url: str):
def __init__(
self,
store: Store,
google_client_id: str,
google_api_base_url: str,
google_service_account_cert_url: str,
):
self.store: Store = store
self.google_client_id: str = google_client_id
self.google_api_base_url: str = google_api_base_url
self.google_service_account_cert_url: str = google_service_account_cert_url

def verify_user(self, authorization_header: str) -> User:
"""Verify the user by checking if the JWT token provided is valid."""
Expand All @@ -27,8 +34,8 @@ def verify_user(self, authorization_header: str) -> User:
payload: Mapping = jwt.decode(
token=jwt_token,
certs=google_certs,
verify=True,
audience=self.google_client_id,
verify=True,
)
except Exception as error:
raise UserTokenVerificationError(f"{error}") from error
Expand All @@ -48,13 +55,20 @@ def _extract_token_from_header(authorization_header: str) -> str:

def _get_google_certs(self) -> Mapping:
"""Get the Google certificates."""
try:
# Fetch the Google public keys. Google oauth uses v1 certs.
response = requests.get(self.google_api_base_url + "/oauth2/v1/certs")
response.raise_for_status()
return response.json()
except requests.RequestException as e:
raise GoogleCertsError("Failed to fetch Google public keys") from e
certs = {}
# Get the google certs for the public oauth2 endpoint and the service account cert url
cert_urls: list[str] = [
self.google_api_base_url + "/oauth2/v1/certs",
self.google_service_account_cert_url,
]
for url in cert_urls:
try:
response = requests.get(url)
response.raise_for_status()
certs.update(response.json())
except requests.RequestException as e:
raise GoogleCertsError("Failed to fetch Google public keys") from e
return certs

def _get_user(self, user_email: str) -> User:
"""Check if the user is known."""
Expand Down
Loading