From 315ce7a11317cea82ac45c4f4ed9cba99657b95d Mon Sep 17 00:00:00 2001 From: Stefaan Lippens Date: Fri, 17 Jan 2025 18:07:35 +0100 Subject: [PATCH] Issue #254/#691 introduce _on_auth_update handler - to make sure all cases are covered - include authenticate_oidc_access_token --- CHANGELOG.md | 2 + openeo/rest/connection.py | 21 +++++- tests/rest/test_connection.py | 126 ++++++++++++++++++++-------------- 3 files changed, 93 insertions(+), 56 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c985ae28c..997979a4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Clear capabilities cache on login ([#254](https://github.com/Open-EO/openeo-python-client/issues/254)) + ## [0.36.0] - 2024-12-10 diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index 04a5a0242..fa4725822 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -113,6 +113,7 @@ def __init__( slow_response_threshold: Optional[float] = None, ): self._root_url = root_url + self._auth = None self.auth = auth or NullAuth() self.session = session or requests.Session() self.default_timeout = default_timeout or DEFAULT_TIMEOUT @@ -129,6 +130,18 @@ def __init__( def root_url(self): return self._root_url + @property + def auth(self) -> Union[AuthBase, None]: + return self._auth + + @auth.setter + def auth(self, auth: Union[AuthBase, None]): + self._auth = auth + self._on_auth_update() + + def _on_auth_update(self): + pass + def build_url(self, path: str): return url_join(self._root_url, path) @@ -340,12 +353,12 @@ def __init__( if "://" not in url: url = "https://" + url self._orig_url = url + self._capabilities_cache = LazyLoadCache() super().__init__( root_url=self.version_discovery(url, session=session, timeout=default_timeout), auth=auth, session=session, default_timeout=default_timeout, slow_response_threshold=slow_response_threshold, ) - self._capabilities_cache = LazyLoadCache() # Initial API version check. self._api_version.require_at_least(self._MINIMUM_API_VERSION) @@ -380,6 +393,10 @@ def version_discovery( # Be very lenient about failing on the well-known URI strategy. return url + def _on_auth_update(self): + super()._on_auth_update() + self._capabilities_cache.clear() + def _get_auth_config(self) -> AuthConfig: if self._auth_config is None: self._auth_config = AuthConfig() @@ -411,7 +428,6 @@ def authenticate_basic(self, username: Optional[str] = None, password: Optional[ ).json() # Switch to bearer based authentication in further requests. self.auth = BasicBearerAuth(access_token=resp["access_token"]) - self._capabilities_cache.clear() return self def _get_oidc_provider( @@ -546,7 +562,6 @@ def _authenticate_oidc( _log.warning("No OIDC refresh token to store.") token = tokens.access_token self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token) - self._capabilities_cache.clear() self._oidc_auth_renewer = oidc_auth_renewer return self diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 02d55e1d7..643f734c3 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -49,6 +49,7 @@ API_URL = "https://oeo.test/" +# TODO: eliminate this and replace with `build_capabilities` usage BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}] @@ -551,83 +552,102 @@ def test_capabilities_caching(requests_mock): assert con.capabilities().api_version() == "1.0.0" assert m.call_count == 1 -def test_capabilities_caching_after_authenticate_basic(requests_mock): - user, pwd = "john262", "J0hndo3" - def get_capabilities(request, context): - endpoints = BASIC_ENDPOINTS.copy() - if "Authorization" in request.headers: - endpoints.append({"path": "/account/status", "methods": ["GET"]}) - return {"api_version": "1.0.0", "endpoints": endpoints} +def _get_capabilities_auth_dependent(request, context): + capabilities = build_capabilities() + capabilities["endpoints"] = [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + if "Authorization" in request.headers: + capabilities["endpoints"].append({"methods": ["GET"], "path": "/me"}) + return capabilities + - get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities) +def test_capabilities_caching_after_authenticate_basic(requests_mock): + user, pwd = "john262", "J0hndo3" + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) requests_mock.get(API_URL + 'credentials/basic', text=_credentials_basic_handler(user, pwd)) con = Connection(API_URL) - assert con.capabilities().capabilities == { - "api_version": "1.0.0", - "endpoints": [ - {"methods": ["GET"], "path": "/credentials/basic"}, - ], - } + assert con.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] assert get_capabilities_mock.call_count == 1 con.capabilities() assert get_capabilities_mock.call_count == 1 - con.authenticate_basic(user, pwd) + con.authenticate_basic(username=user, password=pwd) assert get_capabilities_mock.call_count == 1 - assert con.capabilities().capabilities == { - "api_version": "1.0.0", - "endpoints": [ - {"methods": ["GET"], "path": "/credentials/basic"}, - {"methods": ["GET"], "path": "/account/status"}, - ], - } - assert get_capabilities_mock.call_count == 2 + assert con.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] + assert get_capabilities_mock.call_count == 2 -def test_capabilities_caching_after_authenticate_oidc(requests_mock): +def test_capabilities_caching_after_authenticate_oidc_refresh_token(requests_mock): client_id = "myclient" - - def get_capabilities(request, context): - endpoints = BASIC_ENDPOINTS.copy() - if "Authorization" in request.headers: - endpoints.append({"path": "/account/status", "methods": ["GET"]}) - return {"api_version": "1.0.0", "endpoints": endpoints} - - get_capabilities_mock = requests_mock.get(API_URL, json=get_capabilities) - requests_mock.get(API_URL + 'credentials/oidc', json={ - "providers": [{"id": "fauth", "issuer": "https://fauth.test", "title": "Foo Auth", "scopes": ["openid", "im"]}] - }) + refresh_token = "fr65h!" + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) + requests_mock.get( + API_URL + "credentials/oidc", + json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]}, + ) oidc_mock = OidcMock( requests_mock=requests_mock, - expected_grant_type="authorization_code", + expected_grant_type="refresh_token", expected_client_id=client_id, - expected_fields={"scope": "im openid"}, - oidc_issuer="https://fauth.test", - scopes_supported=["openid", "im"], + expected_fields={"refresh_token": refresh_token}, ) + conn = Connection(API_URL) - assert conn.capabilities().capabilities == { - "api_version": "1.0.0", - "endpoints": [ - {"methods": ["GET"], "path": "/credentials/basic"}, - ], - } + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + assert get_capabilities_mock.call_count == 1 conn.capabilities() assert get_capabilities_mock.call_count == 1 - conn.authenticate_oidc_authorization_code(client_id=client_id, webbrowser_open=oidc_mock.webbrowser_open) + conn.authenticate_oidc_refresh_token(client_id=client_id, refresh_token=refresh_token) assert get_capabilities_mock.call_count == 1 - assert conn.capabilities().capabilities == { - "api_version": "1.0.0", - "endpoints": [ - {"methods": ["GET"], "path": "/credentials/basic"}, - {"methods": ["GET"], "path": "/account/status"}, - ], - } + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] + assert get_capabilities_mock.call_count == 2 + + +def test_capabilities_caching_after_authenticate_oidc_access_token(requests_mock): + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) + requests_mock.get( + API_URL + "credentials/oidc", + json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]}, + ) + + conn = Connection(API_URL) + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + + assert get_capabilities_mock.call_count == 1 + conn.capabilities() + assert get_capabilities_mock.call_count == 1 + + conn.authenticate_oidc_access_token(access_token="6cc355!") + assert get_capabilities_mock.call_count == 1 + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] assert get_capabilities_mock.call_count == 2