Skip to content

Commit 8a11d73

Browse files
committed
Refactor code to better separate oidc and non-oidc providers
1 parent 0ff3451 commit 8a11d73

File tree

3 files changed

+207
-78
lines changed

3 files changed

+207
-78
lines changed

lib/galaxy/authnz/oidc_utils.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
OIDC-specific utility functions for token handling and verification.
3+
4+
This module contains helper functions that are specific to OpenID Connect (OIDC)
5+
authentication. These should not be used with non-OIDC backends like OAuth2.
6+
"""
7+
8+
import logging
9+
from typing import Optional
10+
11+
import jwt
12+
from jwt import InvalidTokenError
13+
from social_core.backends.open_id_connect import OpenIdConnectAuth
14+
15+
from galaxy.exceptions import MalformedContents
16+
17+
log = logging.getLogger(__name__)
18+
19+
20+
def is_oidc_backend(backend) -> bool:
21+
"""
22+
Check if a PSA backend is OIDC-based.
23+
24+
:param backend: A PSA backend instance
25+
:return: True if backend is OpenIdConnectAuth, False otherwise
26+
"""
27+
return isinstance(backend, OpenIdConnectAuth)
28+
29+
30+
def is_decodable_jwt(token_str: str) -> bool:
31+
"""
32+
Check if a token string looks like a decodable JWT.
33+
34+
We assume decodable JWTs are in the format header.payload.signature
35+
36+
:param token_str: Token string to check
37+
:return: True if token appears to be JWT format
38+
"""
39+
if not token_str:
40+
return False
41+
components = token_str.split(".")
42+
return len(components) == 3
43+
44+
45+
def decode_access_token(token_str: str, backend: OpenIdConnectAuth) -> dict:
46+
"""
47+
Decode and verify an OIDC access token.
48+
49+
This function verifies:
50+
- Signature using provider's public keys
51+
- Token expiration (exp claim)
52+
- Token not-before time (nbf claim)
53+
- Token issued-at time (iat claim)
54+
- Audience (aud claim) matches accepted_audiences
55+
- Issuer (iss claim) matches expected issuer
56+
57+
:param token_str: JWT access token string
58+
:param backend: OpenIdConnectAuth backend instance
59+
:return: Decoded JWT payload as dict
60+
:raises InvalidTokenError: If token is invalid or verification fails
61+
"""
62+
signing_key = backend.find_valid_key(token_str)
63+
jwk = jwt.PyJWK(signing_key)
64+
65+
decoded = jwt.decode(
66+
token_str,
67+
key=jwk,
68+
algorithms=[jwk.algorithm_name],
69+
audience=backend.strategy.config["accepted_audiences"],
70+
issuer=backend.id_token_issuer(),
71+
options={
72+
"verify_signature": True,
73+
"verify_exp": True,
74+
"verify_nbf": True,
75+
"verify_iat": True,
76+
"verify_aud": bool(backend.strategy.config["accepted_audiences"]),
77+
"verify_iss": True,
78+
},
79+
)
80+
return decoded
81+
82+
83+
def verify_oidc_response(response: dict) -> None:
84+
"""
85+
Verify that an OIDC authentication response contains required fields.
86+
87+
Checks for:
88+
- id_token presence
89+
- iat (issued at) claim in id_token
90+
91+
:param response: OIDC authentication response dict
92+
:raises MalformedContents: If required fields are missing
93+
"""
94+
if "id_token" not in response:
95+
raise MalformedContents("Missing id_token in OIDC response")
96+
97+
# Decode without verification to check structure
98+
decoded = jwt.decode(response["id_token"], options={"verify_signature": False})
99+
if "iat" not in decoded:
100+
raise MalformedContents("Missing iat claim in id_token")

lib/galaxy/authnz/psa_authnz.py

Lines changed: 90 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@
3636
requests,
3737
)
3838
from . import IdentityProvider
39+
from .oidc_utils import (
40+
decode_access_token as decode_access_token_oidc,
41+
is_decodable_jwt,
42+
is_oidc_backend,
43+
verify_oidc_response,
44+
)
3945
from ..config import GalaxyAppConfiguration
4046

4147
log = logging.getLogger(__name__)
@@ -160,23 +166,35 @@ def __init__(self, provider, oidc_config, oidc_backend_config, app_config: Galax
160166
if "SOCIAL_AUTH_SECONDARY_AUTH_ENDPOINT" in self.config:
161167
del self.config["SOCIAL_AUTH_SECONDARY_AUTH_ENDPOINT"]
162168

169+
def _is_oidc_backend(self) -> bool:
170+
"""
171+
Check if the current backend is OIDC-based.
172+
173+
:return: True if backend is OpenID Connect, False for OAuth2/other backends
174+
"""
175+
backend_class = BACKENDS.get(self.config["provider"], "")
176+
return "OpenIdConnect" in backend_class or "openidconnect" in backend_class.lower()
177+
163178
def _setup_idp(self, oidc_backend_config):
179+
"""
180+
Configure backend-specific settings from oidc_backends_config.xml.
181+
182+
Sets up both universal settings (that work for all backends) and
183+
OIDC-specific settings (only for OIDC backends).
184+
"""
185+
# Universal settings (work for all backends: OIDC + OAuth2)
164186
self.config[setting_name("AUTH_EXTRA_ARGUMENTS")] = {"access_type": "offline"}
165187
self.config["KEY"] = oidc_backend_config.get("client_id")
166188
self.config["SECRET"] = oidc_backend_config.get("client_secret")
167-
self.config["TENANT_ID"] = oidc_backend_config.get("tenant_id")
189+
self.config["TENANT_ID"] = oidc_backend_config.get("tenant_id") # Azure/Tapis
168190
self.config["redirect_uri"] = oidc_backend_config.get("redirect_uri")
169-
self.config["accepted_audiences"] = oidc_backend_config.get("accepted_audiences")
170191
self.config["EXTRA_SCOPES"] = oidc_backend_config.get("extra_scopes")
192+
self.config["LABEL"] = oidc_backend_config.get("label", self.config["provider"].capitalize())
171193

172-
# OIDC-specific configurations
173-
self.config["PKCE_SUPPORT"] = oidc_backend_config.get("pkce_support", False)
174-
self.config["IDPHINT"] = oidc_backend_config.get("idphint")
194+
# Galaxy-specific pipeline settings (affect all backends)
175195
self.config["REQUIRE_CREATE_CONFIRMATION"] = oidc_backend_config.get("require_create_confirmation", False)
176-
self.config["LABEL"] = oidc_backend_config.get("label", self.config["provider"].capitalize())
177196

178-
if oidc_backend_config.get("oidc_endpoint"):
179-
self.config["OIDC_ENDPOINT"] = oidc_backend_config["oidc_endpoint"]
197+
# Optional generic settings
180198
if oidc_backend_config.get("prompt") is not None:
181199
self.config[setting_name("AUTH_EXTRA_ARGUMENTS")]["prompt"] = oidc_backend_config.get("prompt")
182200
if oidc_backend_config.get("api_url") is not None:
@@ -186,6 +204,14 @@ def _setup_idp(self, oidc_backend_config):
186204
if oidc_backend_config.get("username_key") is not None:
187205
self.config[setting_name("USERNAME_KEY")] = oidc_backend_config.get("username_key")
188206

207+
# OIDC-specific settings (only set for OIDC backends)
208+
if self._is_oidc_backend():
209+
self.config["PKCE_SUPPORT"] = oidc_backend_config.get("pkce_support", False)
210+
self.config["IDPHINT"] = oidc_backend_config.get("idphint")
211+
self.config["accepted_audiences"] = oidc_backend_config.get("accepted_audiences")
212+
if oidc_backend_config.get("oidc_endpoint"):
213+
self.config["OIDC_ENDPOINT"] = oidc_backend_config["oidc_endpoint"]
214+
189215
def _get_helper(self, name, do_import=False):
190216
this_config = self.config.get(setting_name(name), DEFAULTS.get(name, None))
191217
return do_import and module_member(this_config) or this_config
@@ -345,7 +371,8 @@ def logout(self, trans, post_user_logout_href=None):
345371
"""
346372
Logout from the identity provider.
347373
348-
Constructs a logout URL using the OIDC end_session_endpoint if available.
374+
For OIDC backends, constructs a logout URL using the end_session_endpoint.
375+
For non-OIDC backends, returns the fallback URL.
349376
350377
:param trans: Galaxy transaction object
351378
:param post_user_logout_href: URL to redirect to after logout
@@ -355,50 +382,61 @@ def logout(self, trans, post_user_logout_href=None):
355382
strategy = Strategy(trans.request, trans.session, Storage, self.config)
356383
backend = self._load_backend(strategy, self.config["redirect_uri"])
357384

358-
# Get OIDC configuration to find end_session_endpoint
359-
try:
360-
oidc_config = backend.oidc_config()
361-
end_session_endpoint = oidc_config.get("end_session_endpoint")
385+
# Only OIDC backends support IDP logout
386+
if is_oidc_backend(backend):
387+
try:
388+
# Get end_session_endpoint from OIDC discovery document
389+
oidc_config = backend.oidc_config()
390+
end_session_endpoint = oidc_config.get("end_session_endpoint")
391+
392+
if end_session_endpoint:
393+
# Construct logout URL with optional redirect_uri
394+
if post_user_logout_href:
395+
logout_url = f"{end_session_endpoint}?redirect_uri={quote(post_user_logout_href)}"
396+
else:
397+
logout_url = end_session_endpoint
362398

363-
if end_session_endpoint:
364-
# Construct logout URL with optional redirect_uri
365-
if post_user_logout_href:
366-
logout_url = f"{end_session_endpoint}?redirect_uri={quote(post_user_logout_href)}"
399+
return logout_url
367400
else:
368-
logout_url = end_session_endpoint
401+
# No end_session_endpoint available
402+
log.warning(f"No end_session_endpoint found for {self.config['provider']}")
403+
return post_user_logout_href or "/"
369404

370-
return logout_url
371-
else:
372-
# No end_session_endpoint available
373-
log.warning(f"No end_session_endpoint found in OIDC configuration for {self.config['provider']}")
405+
except Exception as e:
406+
log.exception(f"Error getting logout URL for {self.config['provider']}: {e}")
374407
return post_user_logout_href or "/"
375-
376-
except Exception as e:
377-
log.exception(f"Error getting logout URL for {self.config['provider']}: {e}")
408+
else:
409+
# Non-OIDC backends don't have IDP logout
410+
log.debug(f"Backend {self.config['provider']} does not support IDP logout")
378411
return post_user_logout_href or "/"
379412

380413
def decode_user_access_token(self, sa_session, access_token):
381414
"""
382415
Verifies and decodes an access token against this provider, returning the user and
383416
a dict containing the decoded token data.
384417
385-
This is used for API authentication with Bearer tokens.
418+
This is used for API authentication with Bearer tokens. Only works for OIDC backends.
386419
387420
:param sa_session: SQLAlchemy database session
388421
:param access_token: An OIDC access token
389422
:return: A tuple containing the user and decoded jwt data, or (None, None) if token is for different provider
390423
:rtype: Tuple[User, dict]
391424
:raises Exception: If token is valid but user hasn't logged in, or token validation fails
425+
:raises NotImplementedError: If backend is not OIDC-based
392426
"""
427+
# Only OIDC backends support JWT access tokens
428+
if not self._is_oidc_backend():
429+
raise NotImplementedError(f"Access token decoding not supported for {self.config['provider']}")
430+
393431
try:
394432
on_the_fly_config(sa_session)
395433
# Create a minimal strategy and backend just for token verification
396434
strategy = Strategy(None, {}, Storage, self.config)
397435
backend = self._load_backend(strategy, self.config["redirect_uri"])
398436

399-
# Decode and verify the access token using the helper function
437+
# Decode and verify the access token using oidc_utils
400438
# This will raise exceptions for: expired tokens, invalid audience, invalid signature, etc.
401-
decoded_jwt = _decode_access_token_helper(access_token, backend)
439+
decoded_jwt = decode_access_token_oidc(access_token, backend)
402440

403441
# JWT verified, now fetch the user
404442
user_id = decoded_jwt["sub"]
@@ -568,7 +606,7 @@ def on_the_fly_config(sa_session):
568606
PSAAssociation.sa_session = sa_session
569607

570608

571-
def contains_required_data(response=None, is_new=False, **kwargs):
609+
def contains_required_data(response=None, is_new=False, backend=None, **kwargs):
572610
"""
573611
This function is called as part of authentication and authorization
574612
pipeline before user is authenticated or authorized (see AUTH_PIPELINE).
@@ -577,6 +615,9 @@ def contains_required_data(response=None, is_new=False, **kwargs):
577615
is provided. It raises an exception if any of the required data is missing,
578616
and returns void if otherwise.
579617
618+
For OIDC backends, verifies presence of id_token and iat claim.
619+
For OAuth2 backends, performs basic validation only.
620+
580621
:type response: dict
581622
:param response: a dictionary containing decoded response from
582623
OIDC backend that contain the following keys
@@ -593,11 +634,12 @@ def contains_required_data(response=None, is_new=False, **kwargs):
593634
:type is_new: bool
594635
:param is_new: has the user been authenticated?
595636
637+
:param backend: The PSA backend being used for authentication
638+
596639
:param kwargs: may contain the following keys among others:
597640
598641
- uid: user ID
599642
- user: Galaxy user; if user is already authenticated
600-
- backend: the backend that is used for user authentication.
601643
- storage: an instance of Storage class.
602644
- strategy: an instance of the Strategy class.
603645
- state: the state code received from identity provider.
@@ -619,10 +661,15 @@ def contains_required_data(response=None, is_new=False, **kwargs):
619661
# scenarios; however, this case is implemented to prevent uncaught
620662
# server-side errors.
621663
raise MalformedContents(err_msg=f"`response` not found. {hint_msg}")
622-
if not response.get("id_token"):
623-
# This can happen if a non-OIDC compliant backend is used;
624-
# e.g., an OAuth2.0-based backend that only generates access token.
625-
raise MalformedContents(err_msg=f"Missing identity token. {hint_msg}")
664+
665+
# OIDC-specific validation
666+
if backend and is_oidc_backend(backend):
667+
try:
668+
verify_oidc_response(response)
669+
except MalformedContents:
670+
# Re-raise with hint message
671+
raise MalformedContents(err_msg=f"Missing required OIDC data. {hint_msg}")
672+
626673
if is_new and not response.get("refresh_token"):
627674
# An identity provider (e.g., Google) sends a refresh token the first
628675
# time user consents Galaxy's access (i.e., the first time user logs in
@@ -731,11 +778,12 @@ def disconnect(
731778
sa_session.commit()
732779

733780

734-
def decode_access_token(social: UserAuthnzToken, backend: OpenIdConnectAuth, **kwargs):
781+
def decode_access_token(social: UserAuthnzToken, backend, **kwargs):
735782
"""
736783
Auth pipeline step to decode the OIDC access token, if possible.
784+
737785
Note that some OIDC providers return an opaque access token, which
738-
cannot be decoded.
786+
cannot be decoded. This step only works for OIDC backends.
739787
740788
Returns the access token, making it available as a new argument
741789
"access_token" that can be used in future pipeline steps. If
@@ -745,60 +793,28 @@ def decode_access_token(social: UserAuthnzToken, backend: OpenIdConnectAuth, **k
745793
which should be handled by social_core.pipeline.social_auth.load_extra_data, so
746794
this step should be placed after load_extra_data in the pipeline.
747795
"""
796+
# Only decode for OIDC backends
797+
if not is_oidc_backend(backend):
798+
return {"access_token": None}
799+
748800
if social.extra_data is None:
749801
return {"access_token": None}
750802
access_token_encoded = social.extra_data.get("access_token")
751803
if access_token_encoded is None:
752804
return {"access_token": None}
753-
if not _is_decodable_jwt(access_token_encoded):
805+
if not is_decodable_jwt(access_token_encoded):
754806
log.warning(
755807
"Access token is not in header.payload.signature format and can't be decoded (may be an opaque token)"
756808
)
757809
return {"access_token": None}
758810
try:
759-
access_token_data = _decode_access_token_helper(token_str=access_token_encoded, backend=backend)
811+
access_token_data = decode_access_token_oidc(token_str=access_token_encoded, backend=backend)
760812
except InvalidTokenError as e:
761813
log.warning(f"Access token couldn't be decoded: {e}")
762814
return {"access_token": None}
763815
return {"access_token": access_token_data}
764816

765817

766-
def _is_decodable_jwt(token_str: str) -> bool:
767-
"""
768-
Check if a token string looks like a decodable JWT.
769-
We assume decodable JWTs are in the format header.payload.signature
770-
"""
771-
components = token_str.split(".")
772-
return len(components) == 3
773-
774-
775-
def _decode_access_token_helper(token_str: str, backend: OpenIdConnectAuth) -> dict:
776-
"""
777-
Decode the access token (verifying that signature, expiry and
778-
audience are valid).
779-
780-
Requires accepted_audiences to be configured in the OIDC backend config
781-
"""
782-
signing_key = backend.find_valid_key(token_str)
783-
jwk = jwt.PyJWK(signing_key)
784-
decoded = jwt.decode(
785-
token_str,
786-
key=jwk,
787-
algorithms=[jwk.algorithm_name],
788-
audience=backend.strategy.config["accepted_audiences"],
789-
issuer=backend.id_token_issuer(),
790-
options={
791-
"verify_signature": True,
792-
"verify_exp": True,
793-
"verify_nbf": True,
794-
"verify_iat": True,
795-
"verify_aud": bool(backend.strategy.config["accepted_audiences"]),
796-
"verify_iss": True,
797-
},
798-
)
799-
return decoded
800-
801-
802818
def associate_by_email_if_logged_in(
803819
strategy=None, backend=None, details=None, user=None, social=None, is_new=False, *args, **kwargs
804820
):

0 commit comments

Comments
 (0)