diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 6cac45af..c9a9d15c 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -9,8 +9,10 @@ import urllib.parse import webbrowser from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from datetime import datetime, timedelta +from enum import Enum from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any, Dict, List, Optional @@ -187,21 +189,132 @@ def retrieve_token(client_id, raise NotImplementedError(f"Not supported yet: {e}") +class _TokenState(Enum): + """ + Represents the state of a token. Each token can be in one of + the following three states: + - FRESH: The token is valid. + - STALE: The token is valid but will expire soon. + - EXPIRED: The token has expired and cannot be used. + """ + FRESH = 1 # The token is valid. + STALE = 2 # The token is valid but will expire soon. + EXPIRED = 3 # The token has expired and cannot be used. + + class Refreshable(TokenSource): + """A token source that supports refreshing expired tokens.""" + + _EXECUTOR = None + _EXECUTOR_LOCK = threading.Lock() + _DEFAULT_STALE_DURATION = timedelta(minutes=3) + + @classmethod + def _get_executor(cls): + """Lazy initialization of the ThreadPoolExecutor.""" + if cls._EXECUTOR is None: + with cls._EXECUTOR_LOCK: + if cls._EXECUTOR is None: + # This thread pool has multiple workers because it is shared by all instances of Refreshable. + cls._EXECUTOR = ThreadPoolExecutor(max_workers=10) + return cls._EXECUTOR - def __init__(self, token=None): - self._lock = threading.Lock() # to guard _token + def __init__(self, + token: Token = None, + disable_async: bool = True, + stale_duration: timedelta = _DEFAULT_STALE_DURATION): + # Config properties + self._stale_duration = stale_duration + self._disable_async = disable_async + # Lock + self._lock = threading.Lock() + # Non Thread safe properties. They should be accessed only when protected by the lock above. self._token = token + self._is_refreshing = False + self._refresh_err = False + # This is the main entry point for the Token. Do not access the token + # using any of the internal functions. def token(self) -> Token: - self._lock.acquire() - try: - if self._token and self._token.valid: - return self._token - self._token = self.refresh() + """Returns a valid token, blocking if async refresh is disabled.""" + with self._lock: + if self._disable_async: + return self._blocking_token() + return self._async_token() + + def _async_token(self) -> Token: + """ + Returns a token. + If the token is stale, triggers an asynchronous refresh. + If the token is expired, refreshes it synchronously, blocking until the refresh is complete. + """ + state = self._token_state() + token = self._token + + if state == _TokenState.FRESH: + return token + if state == _TokenState.STALE: + self._trigger_async_refresh() + return token + return self._blocking_token() + + def _token_state(self) -> _TokenState: + """Returns the current state of the token.""" + if not self._token or not self._token.valid: + return _TokenState.EXPIRED + if not self._token.expiry: + return _TokenState.FRESH + + lifespan = self._token.expiry - datetime.now() + if lifespan < timedelta(seconds=0): + return _TokenState.EXPIRED + if lifespan < self._stale_duration: + return _TokenState.STALE + return _TokenState.FRESH + + def _blocking_token(self) -> Token: + """Returns a token, blocking if necessary to refresh it.""" + state = self._token_state() + # This is important to recover from potential previous failed attempts + # to refresh the token asynchronously. + self._refresh_err = False + self._is_refreshing = False + + # It's possible that the token got refreshed (either by a _blocking_refresh or + # an _async_refresh call) while this particular call was waiting to acquire + # the lock. This check avoids refreshing the token again in such cases. + if state != _TokenState.EXPIRED: return self._token - finally: - self._lock.release() + + self._token = self.refresh() + return self._token + + def _trigger_async_refresh(self): + """Starts an asynchronous refresh if none is in progress.""" + + def _refresh_internal(): + new_token: Token = None + try: + new_token = self.refresh() + except Exception as e: + # This happens on a thread, so we don't want to propagate the error. + # Instead, if there is no new_token for any reason, we will disable async refresh below + # But we will do it inside the lock. + logger.warning(f'Tried to refresh token asynchronously, but failed: {e}') + + with self._lock: + if new_token is not None: + self._token = new_token + else: + self._refresh_err = True + self._is_refreshing = False + + # The token may have been refreshed by another thread. + if self._token_state() == _TokenState.FRESH: + return + if not self._is_refreshing and not self._refresh_err: + self._is_refreshing = True + Refreshable._get_executor().submit(_refresh_internal) @abstractmethod def refresh(self) -> Token: @@ -295,7 +408,7 @@ def __init__(self, super().__init__(token) def as_dict(self) -> dict: - return {'token': self._token.as_dict()} + return {'token': self.token().as_dict()} @staticmethod def from_dict(raw: dict, diff --git a/tests/test_refreshable.py b/tests/test_refreshable.py new file mode 100644 index 00000000..7265026e --- /dev/null +++ b/tests/test_refreshable.py @@ -0,0 +1,216 @@ +import time +from datetime import datetime, timedelta +from time import sleep +from typing import Callable + +from databricks.sdk.oauth import Refreshable, Token + + +class _MockRefreshable(Refreshable): + + def __init__(self, + disable_async, + token=None, + stale_duration=timedelta(seconds=60), + refresh_effect: Callable[[], Token] = None): + super().__init__(token, disable_async, stale_duration) + self._refresh_effect = refresh_effect + self._refresh_count = 0 + + def refresh(self) -> Token: + if self._refresh_effect: + self._token = self._refresh_effect() + self._refresh_count += 1 + return self._token + + +def fail() -> Token: + raise Exception("Simulated token refresh failure") + + +def static_token(token: Token, wait: int = 0) -> Callable[[], Token]: + + def f() -> Token: + time.sleep(wait) + return token + + return f + + +def blocking_refresh(token: Token) -> (Callable[[], Token], Callable[[], None]): + """ + Create a refresh function that blocks until unblock is called. + + Param: + token: the token that will be returned + + Returns: + A tuple containing the refresh function and the unblock function. + + """ + blocking = True + + def refresh(): + while blocking: + sleep(0.1) + return token + + def unblock(): + nonlocal blocking + blocking = False + + return refresh, unblock + + +def test_disable_async_stale_does_not_refresh(): + stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) + r = _MockRefreshable(token=stale_token, disable_async=True, refresh_effect=fail) + result = r.token() + assert r._refresh_count == 0 + assert result == stale_token + + +def test_disable_async_no_token_does_refresh(): + token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) + r = _MockRefreshable(token=None, disable_async=True, refresh_effect=static_token(token)) + result = r.token() + assert r._refresh_count == 1 + assert result == token + + +def test_disable_async_no_expiration_does_not_refresh(): + non_expiring_token = Token(access_token="access_token", ) + r = _MockRefreshable(token=non_expiring_token, disable_async=True, refresh_effect=fail) + result = r.token() + assert r._refresh_count == 0 + assert result == non_expiring_token + + +def test_disable_async_fresh_does_not_refresh(): + # Create a token that is already stale. If async is disabled, the token should not be refreshed. + token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) + r = _MockRefreshable(token=token, disable_async=True, refresh_effect=fail) + result = r.token() + assert r._refresh_count == 0 + assert result == token + + +def test_disable_async_expired_does_refresh(): + expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) + # Add one second to the refresh time to ensure that the call is blocking. + # If the call is not blocking, the wait time will ensure that the + # old token is returned. + r = _MockRefreshable(token=expired_token, + disable_async=True, + refresh_effect=static_token(new_token, wait=1)) + result = r.token() + assert r._refresh_count == 1 + assert result == new_token + + +def test_expired_does_refresh(): + expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) + # Add one second to the refresh time to ensure that the call is blocking. + # If the call is not blocking, the wait time will ensure that the + # old token is returned. + r = _MockRefreshable(token=expired_token, + disable_async=False, + refresh_effect=static_token(new_token, wait=1)) + result = r.token() + assert r._refresh_count == 1 + assert result == new_token + + +def test_stale_does_refresh_async(): + stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) + # Add one second to the refresh to avoid race conditions. + # Without it, the new token may be returned in some cases. + refresh, unblock = blocking_refresh(new_token) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) + result = r.token() + # NOTE: Do not check for refresh count here, since the + assert result == stale_token + assert r._refresh_count == 0 + # Unblock the refresh and wait + unblock() + time.sleep(2) + # Call again and check that you get the new token + result = r.token() + assert result == new_token + # Ensure that all calls have completed + time.sleep(0.1) + assert r._refresh_count == 1 + + +def test_no_token_does_refresh(): + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) + # Add one second to the refresh time to ensure that the call is blocking. + # If the call is not blocking, the wait time will ensure that the + # token is not returned. + r = _MockRefreshable(token=None, disable_async=False, refresh_effect=static_token(new_token, wait=1)) + result = r.token() + assert r._refresh_count == 1 + assert result == new_token + + +def test_fresh_does_not_refresh(): + fresh_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) + r = _MockRefreshable(token=fresh_token, disable_async=False, refresh_effect=fail) + result = r.token() + assert r._refresh_count == 0 + assert result == fresh_token + + +def test_multiple_calls_dont_start_many_threads(): + stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), ) + new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), ) + refresh, unblock = blocking_refresh(new_token) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh) + # Call twice. The second call should not start a new thread. + result = r.token() + assert result == stale_token + result = r.token() + assert result == stale_token + unblock() + # Wait for the refresh to complete + time.sleep(1) + result = r.token() + # Check that only one refresh was called + assert r._refresh_count == 1 + assert result == new_token + + +def test_async_failure_disables_async(): + stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), ) + new_token = Token(access_token="new_token", expiry=datetime.now() + timedelta(seconds=300), ) + r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail) + # The call should fail and disable async refresh, + # but the exception will be catch inside the tread. + result = r.token() + assert result == stale_token + # Give time to the async refresh to fail + time.sleep(1) + assert r._refresh_err + # Now, the refresh should be blocking. + # Blocking refresh only happens for expired, not stale. + # Therefore, the next call should return the stale token. + r._refresh_effect = static_token(new_token, wait=1) + result = r.token() + assert result == stale_token + # Wait to be sure no async thread was started + time.sleep(1) + assert r._refresh_count == 0 + + # Inject an expired token. + expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), ) + r._token = expired_token + + # This should be blocking and return the new token. + result = r.token() + assert r._refresh_count == 1 + assert result == new_token + # The refresh error should be cleared. + assert not r._refresh_err