diff --git a/openeo/rest/connection.py b/openeo/rest/connection.py index f687d5ca1..fa4725822 100644 --- a/openeo/rest/connection.py +++ b/openeo/rest/connection.py @@ -113,6 +113,7 @@ def __init__( slow_response_threshold: Optional[float] = None, ): self._root_url = root_url + self._auth = None self.auth = auth or NullAuth() self.session = session or requests.Session() self.default_timeout = default_timeout or DEFAULT_TIMEOUT @@ -129,6 +130,18 @@ def __init__( def root_url(self): return self._root_url + @property + def auth(self) -> Union[AuthBase, None]: + return self._auth + + @auth.setter + def auth(self, auth: Union[AuthBase, None]): + self._auth = auth + self._on_auth_update() + + def _on_auth_update(self): + pass + def build_url(self, path: str): return url_join(self._root_url, path) @@ -340,12 +353,12 @@ def __init__( if "://" not in url: url = "https://" + url self._orig_url = url + self._capabilities_cache = LazyLoadCache() super().__init__( root_url=self.version_discovery(url, session=session, timeout=default_timeout), auth=auth, session=session, default_timeout=default_timeout, slow_response_threshold=slow_response_threshold, ) - self._capabilities_cache = LazyLoadCache() # Initial API version check. self._api_version.require_at_least(self._MINIMUM_API_VERSION) @@ -380,6 +393,10 @@ def version_discovery( # Be very lenient about failing on the well-known URI strategy. return url + def _on_auth_update(self): + super()._on_auth_update() + self._capabilities_cache.clear() + def _get_auth_config(self) -> AuthConfig: if self._auth_config is None: self._auth_config = AuthConfig() diff --git a/openeo/util.py b/openeo/util.py index 44842124a..53550ccb8 100644 --- a/openeo/util.py +++ b/openeo/util.py @@ -476,6 +476,9 @@ def get(self, key: Union[str, tuple], load: Callable[[], Any]): self._cache[key] = load() return self._cache[key] + def clear(self): + self._cache = {} + def str_truncate(text: str, width: int = 64, ellipsis: str = "...") -> str: """Shorten a string (with an ellipsis) if it is longer than certain length.""" diff --git a/tests/rest/test_connection.py b/tests/rest/test_connection.py index 16f5cb894..643f734c3 100644 --- a/tests/rest/test_connection.py +++ b/tests/rest/test_connection.py @@ -49,6 +49,7 @@ API_URL = "https://oeo.test/" +# TODO: eliminate this and replace with `build_capabilities` usage BASIC_ENDPOINTS = [{"path": "/credentials/basic", "methods": ["GET"]}] @@ -552,6 +553,104 @@ def test_capabilities_caching(requests_mock): assert m.call_count == 1 +def _get_capabilities_auth_dependent(request, context): + capabilities = build_capabilities() + capabilities["endpoints"] = [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + if "Authorization" in request.headers: + capabilities["endpoints"].append({"methods": ["GET"], "path": "/me"}) + return capabilities + + +def test_capabilities_caching_after_authenticate_basic(requests_mock): + user, pwd = "john262", "J0hndo3" + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) + requests_mock.get(API_URL + 'credentials/basic', text=_credentials_basic_handler(user, pwd)) + + con = Connection(API_URL) + assert con.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + assert get_capabilities_mock.call_count == 1 + con.capabilities() + assert get_capabilities_mock.call_count == 1 + + con.authenticate_basic(username=user, password=pwd) + assert get_capabilities_mock.call_count == 1 + assert con.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] + + assert get_capabilities_mock.call_count == 2 + + +def test_capabilities_caching_after_authenticate_oidc_refresh_token(requests_mock): + client_id = "myclient" + refresh_token = "fr65h!" + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) + requests_mock.get( + API_URL + "credentials/oidc", + json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]}, + ) + oidc_mock = OidcMock( + requests_mock=requests_mock, + expected_grant_type="refresh_token", + expected_client_id=client_id, + expected_fields={"refresh_token": refresh_token}, + ) + + conn = Connection(API_URL) + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + + assert get_capabilities_mock.call_count == 1 + conn.capabilities() + assert get_capabilities_mock.call_count == 1 + + conn.authenticate_oidc_refresh_token(client_id=client_id, refresh_token=refresh_token) + assert get_capabilities_mock.call_count == 1 + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] + assert get_capabilities_mock.call_count == 2 + + +def test_capabilities_caching_after_authenticate_oidc_access_token(requests_mock): + get_capabilities_mock = requests_mock.get(API_URL, json=_get_capabilities_auth_dependent) + requests_mock.get( + API_URL + "credentials/oidc", + json={"providers": [{"id": "oi", "issuer": "https://oidc.test", "title": "OI!", "scopes": ["openid"]}]}, + ) + + conn = Connection(API_URL) + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + ] + + assert get_capabilities_mock.call_count == 1 + conn.capabilities() + assert get_capabilities_mock.call_count == 1 + + conn.authenticate_oidc_access_token(access_token="6cc355!") + assert get_capabilities_mock.call_count == 1 + assert conn.capabilities().capabilities["endpoints"] == [ + {"methods": ["GET"], "path": "/credentials/basic"}, + {"methods": ["GET"], "path": "/credentials/oidc"}, + {"methods": ["GET"], "path": "/me"}, + ] + assert get_capabilities_mock.call_count == 2 + + def test_file_formats(requests_mock): requests_mock.get("https://oeo.test/", json={"api_version": "1.0.0"}) m = requests_mock.get("https://oeo.test/file_formats", json={"output": {"GTiff": {"gis_data_types": ["raster"]}}})