Skip to content

Commit

Permalink
[Internal] Implement async token refresh (#893)
Browse files Browse the repository at this point in the history
## What changes are proposed in this pull request?

This PR is a step towards enabling asynchronous refreshes of data plane
tokens.
This PR updates the existing `Refreshable` abstract token class to
support async token refresh.

Note: async refreshes are disabled at the moment and will be enabled in
a follow-up PR.

## How is this tested?

Added unit tests.

## Changelog
The changelog entry will be added when the feature is enabled.

NO_CHANGELOG=true
  • Loading branch information
hectorcast-db authored Feb 24, 2025
1 parent 3d3752a commit e550ca1
Show file tree
Hide file tree
Showing 2 changed files with 339 additions and 10 deletions.
133 changes: 123 additions & 10 deletions databricks/sdk/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import urllib.parse
import webbrowser
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime, timedelta
from enum import Enum
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, List, Optional

Expand Down Expand Up @@ -187,21 +189,132 @@ def retrieve_token(client_id,
raise NotImplementedError(f"Not supported yet: {e}")


class _TokenState(Enum):
    """
    Lifecycle state of a cached token.

    A token is always in exactly one of these states:
    - FRESH: the token is valid.
    - STALE: the token is still valid but will expire soon.
    - EXPIRED: the token has expired and cannot be used.
    """

    FRESH = 1
    STALE = 2
    EXPIRED = 3


class Refreshable(TokenSource):
"""A token source that supports refreshing expired tokens."""

_EXECUTOR = None
_EXECUTOR_LOCK = threading.Lock()
_DEFAULT_STALE_DURATION = timedelta(minutes=3)

@classmethod
def _get_executor(cls):
    """Return the class-wide ThreadPoolExecutor, creating it on first use."""
    pool = cls._EXECUTOR
    if pool is not None:
        return pool
    with cls._EXECUTOR_LOCK:
        # Check again while holding the lock: another thread may have created
        # the pool between the unlocked check above and acquiring the lock.
        if cls._EXECUTOR is None:
            # Several workers, because this one pool is shared by all
            # instances of Refreshable.
            cls._EXECUTOR = ThreadPoolExecutor(max_workers=10)
        return cls._EXECUTOR

def __init__(self,
             token: Optional[Token] = None,
             disable_async: bool = True,
             stale_duration: timedelta = _DEFAULT_STALE_DURATION):
    """
    Create a refreshable token source.

    :param token: initial token, or None (the default) to start without one.
    :param disable_async: when True (the default), token() only ever uses the
        blocking code path and never schedules a background refresh.
    :param stale_duration: remaining lifetime below which a still-valid token
        is considered stale.
    """
    # Config properties
    self._stale_duration = stale_duration
    self._disable_async = disable_async
    # Lock
    self._lock = threading.Lock()
    # Non thread-safe properties. They should be accessed only while holding
    # the lock above.
    self._token = token
    # True while a background refresh has been submitted and has not finished.
    self._is_refreshing = False
    # True after a background refresh failed; cleared by the blocking path.
    self._refresh_err = False

# This is the main entry point for the Token. Do not access the token
# using any of the internal functions.
def token(self) -> Token:
    """
    Returns a valid token, blocking if async refresh is disabled.

    The instance lock is held for the whole call, including any synchronous
    refresh, so the helpers invoked here may touch the lock-protected state
    directly.
    """
    with self._lock:
        if self._disable_async:
            return self._blocking_token()
        return self._async_token()

def _async_token(self) -> Token:
    """
    Returns a token, refreshing in the background when possible.

    A fresh token is returned as is. A stale token is returned immediately
    after an asynchronous refresh has been triggered. An expired token is
    refreshed synchronously, blocking until the refresh is complete.
    """
    status = self._token_state()
    current = self._token

    if status not in (_TokenState.FRESH, _TokenState.STALE):
        return self._blocking_token()
    if status == _TokenState.STALE:
        self._trigger_async_refresh()
    return current

def _token_state(self) -> _TokenState:
    """
    Returns the current state of the token.

    EXPIRED when there is no token, the token reports itself as not valid, or
    its expiry is in the past; STALE when less than the configured stale
    duration remains; FRESH otherwise, including tokens without an expiry.
    """
    current = self._token
    if not current or not current.valid:
        return _TokenState.EXPIRED
    expiry = current.expiry
    if not expiry:
        return _TokenState.FRESH

    # Build "now" with the same tzinfo as the expiry. For a naive expiry this
    # is exactly datetime.now(); for a timezone-aware expiry it avoids the
    # TypeError raised when subtracting a naive datetime from an aware one.
    lifespan = expiry - datetime.now(tz=expiry.tzinfo)
    if lifespan < timedelta(seconds=0):
        return _TokenState.EXPIRED
    if lifespan < self._stale_duration:
        return _TokenState.STALE
    return _TokenState.FRESH

def _blocking_token(self) -> Token:
    """
    Returns a token, blocking if necessary to refresh it.

    Only an expired (or missing) token is refreshed; a fresh or stale token
    is returned as is. token() calls this while holding the instance lock,
    so the lock must not be released here.
    """
    state = self._token_state()
    # This is important to recover from potential previous failed attempts
    # to refresh the token asynchronously.
    self._refresh_err = False
    self._is_refreshing = False

    # It's possible that the token got refreshed (either by another blocking
    # call or by a background refresh) while this particular call was waiting
    # to acquire the lock. This check avoids refreshing the token again in
    # such cases.
    if state != _TokenState.EXPIRED:
        return self._token

    self._token = self.refresh()
    return self._token

def _trigger_async_refresh(self):
    """
    Starts an asynchronous refresh if none is in progress.

    Nothing is submitted when the token is already fresh, when a refresh is
    already in flight, or when a previous asynchronous refresh has failed.
    """

    def _refresh_internal():
        # Runs on an executor thread. refresh() is called without holding the
        # lock; only the bookkeeping below is done under the lock.
        new_token: Token = None
        try:
            new_token = self.refresh()
        except Exception as e:
            # This happens on a thread, so we don't want to propagate the error.
            # Instead, if there is no new_token for any reason, we will disable async refresh below
            # But we will do it inside the lock.
            logger.warning(f'Tried to refresh token asynchronously, but failed: {e}')

        with self._lock:
            if new_token is not None:
                self._token = new_token
            else:
                # Remember the failure so that no further asynchronous refresh
                # is submitted while this flag is set.
                self._refresh_err = True
            self._is_refreshing = False

    # The token may have been refreshed by another thread.
    if self._token_state() == _TokenState.FRESH:
        return
    if not self._is_refreshing and not self._refresh_err:
        # Set the flag before submitting so that later calls do not submit a
        # second refresh while this one is pending.
        self._is_refreshing = True
        Refreshable._get_executor().submit(_refresh_internal)

@abstractmethod
def refresh(self) -> Token:
Expand Down Expand Up @@ -295,7 +408,7 @@ def __init__(self,
super().__init__(token)

def as_dict(self) -> dict:
    """
    Serialize this token source as ``{'token': <token dict>}``.

    The token is obtained through token() rather than by reading the _token
    attribute directly.
    """
    return {'token': self.token().as_dict()}

@staticmethod
def from_dict(raw: dict,
Expand Down
216 changes: 216 additions & 0 deletions tests/test_refreshable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
import threading
import time
from datetime import datetime, timedelta
from time import sleep
from typing import Callable, Tuple

from databricks.sdk.oauth import Refreshable, Token


class _MockRefreshable(Refreshable):
    """
    Refreshable test double with an injectable refresh outcome.

    refresh_effect, when given, is called on every refresh and its result is
    the refreshed token; an exception raised by it propagates out of
    refresh(). _refresh_count counts the refreshes that completed without
    raising.
    """

    def __init__(self,
                 disable_async,
                 token=None,
                 stale_duration=timedelta(seconds=60),
                 refresh_effect: Callable[[], Token] = None):
        super().__init__(token, disable_async, stale_duration)
        self._refresh_effect = refresh_effect
        self._refresh_count = 0

    def refresh(self) -> Token:
        # Only return the token; do not assign self._token here. Refreshable
        # stores the returned value itself, and on the asynchronous path this
        # method runs on a worker thread that does not hold the lock guarding
        # _token.
        if self._refresh_effect:
            refreshed = self._refresh_effect()
        else:
            refreshed = self._token
        self._refresh_count += 1
        return refreshed


def fail() -> Token:
    """Refresh effect that always raises, simulating a failed token refresh."""
    error = Exception("Simulated token refresh failure")
    raise error


def static_token(token: Token, wait: int = 0) -> Callable[[], Token]:
    """Build a refresh effect that sleeps for *wait* seconds and then returns *token*."""

    def delayed() -> Token:
        time.sleep(wait)
        return token

    return delayed


def blocking_refresh(token: Token) -> Tuple[Callable[[], Token], Callable[[], None]]:
    """
    Create a refresh function that blocks until unblock is called.

    Param:
        token: the token that will be returned
    Returns:
        A tuple containing the refresh function and the unblock function.
    """
    # An Event wakes the waiting refresh as soon as unblock() runs, instead of
    # polling a flag on a timer.
    released = threading.Event()

    def refresh() -> Token:
        released.wait()
        return token

    def unblock() -> None:
        released.set()

    return refresh, unblock


def test_disable_async_stale_does_not_refresh():
    """With async disabled, a stale token is returned without any refresh."""
    expires_soon = datetime.now() + timedelta(seconds=50)
    stale = Token(access_token="access_token", expiry=expires_soon)
    source = _MockRefreshable(token=stale, disable_async=True, refresh_effect=fail)
    returned = source.token()
    assert source._refresh_count == 0
    assert returned == stale


def test_disable_async_no_token_does_refresh():
    """With async disabled and no initial token, token() refreshes once."""
    expiry = datetime.now() + timedelta(seconds=50)
    fetched = Token(access_token="access_token", expiry=expiry)
    source = _MockRefreshable(token=None, disable_async=True, refresh_effect=static_token(fetched))
    returned = source.token()
    assert source._refresh_count == 1
    assert returned == fetched


def test_disable_async_no_expiration_does_not_refresh():
    """With async disabled, a token without an expiry is never refreshed."""
    eternal = Token(access_token="access_token")
    source = _MockRefreshable(token=eternal, disable_async=True, refresh_effect=fail)
    returned = source.token()
    assert source._refresh_count == 0
    assert returned == eternal


def test_disable_async_fresh_does_not_refresh():
    """With async disabled, a fresh token is returned without any refresh."""
    # 300 seconds of remaining lifetime is well above the 60-second stale
    # duration that _MockRefreshable uses by default, so this token is fresh.
    expiry = datetime.now() + timedelta(seconds=300)
    fresh = Token(access_token="access_token", expiry=expiry)
    source = _MockRefreshable(token=fresh, disable_async=True, refresh_effect=fail)
    returned = source.token()
    assert source._refresh_count == 0
    assert returned == fresh


def test_disable_async_expired_does_refresh():
    """With async disabled, an expired token is replaced by a blocking refresh."""
    expired = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300))
    replacement = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300))
    # The refresh takes one second. If token() did not block on it, the old
    # token would be returned and the assertions below would fail.
    slow_refresh = static_token(replacement, wait=1)
    source = _MockRefreshable(token=expired, disable_async=True, refresh_effect=slow_refresh)
    returned = source.token()
    assert source._refresh_count == 1
    assert returned == replacement


def test_expired_does_refresh():
    """With async enabled, an expired token is still refreshed synchronously."""
    expired = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300))
    replacement = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300))
    # The refresh takes one second. If token() did not block on it, the old
    # token would be returned and the assertions below would fail.
    slow_refresh = static_token(replacement, wait=1)
    source = _MockRefreshable(token=expired, disable_async=False, refresh_effect=slow_refresh)
    returned = source.token()
    assert source._refresh_count == 1
    assert returned == replacement


def test_stale_does_refresh_async():
    """A stale token is returned immediately while a background refresh replaces it."""
    stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=50), )
    new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
    # The refresh effect blocks until unblock() is called, so the background
    # refresh cannot complete before the first token() call returns.
    refresh, unblock = blocking_refresh(new_token)
    r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh)
    result = r.token()
    # The refresh is still blocked at this point, so the stale token is
    # returned and the refresh has not been counted yet.
    assert result == stale_token
    assert r._refresh_count == 0
    # Unblock the refresh and wait
    unblock()
    time.sleep(2)
    # Call again and check that you get the new token
    result = r.token()
    assert result == new_token
    # Ensure that all calls have completed
    time.sleep(0.1)
    assert r._refresh_count == 1


def test_no_token_does_refresh():
    """With async enabled and no initial token, token() blocks on a refresh."""
    expiry = datetime.now() + timedelta(seconds=300)
    fetched = Token(access_token="access_token", expiry=expiry)
    # The refresh takes one second. If token() did not block on it, no token
    # would be available to return.
    slow_refresh = static_token(fetched, wait=1)
    source = _MockRefreshable(token=None, disable_async=False, refresh_effect=slow_refresh)
    returned = source.token()
    assert source._refresh_count == 1
    assert returned == fetched


def test_fresh_does_not_refresh():
    """With async enabled, a fresh token is returned without any refresh."""
    expiry = datetime.now() + timedelta(seconds=300)
    fresh = Token(access_token="access_token", expiry=expiry)
    source = _MockRefreshable(token=fresh, disable_async=False, refresh_effect=fail)
    returned = source.token()
    assert source._refresh_count == 0
    assert returned == fresh


def test_multiple_calls_dont_start_many_threads():
    """Repeated token() calls on a stale token trigger only one background refresh."""
    # 59 seconds of remaining lifetime is just under the 60-second stale
    # duration that _MockRefreshable uses by default.
    stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), )
    new_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=300), )
    # The refresh stays pending until unblock() is called, so both calls
    # below happen while the first refresh is still in flight.
    refresh, unblock = blocking_refresh(new_token)
    r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=refresh)
    # Call twice. The second call should not start a new thread.
    result = r.token()
    assert result == stale_token
    result = r.token()
    assert result == stale_token
    unblock()
    # Wait for the refresh to complete
    time.sleep(1)
    result = r.token()
    # Check that only one refresh was called
    assert r._refresh_count == 1
    assert result == new_token


def test_async_failure_disables_async():
    """A failed background refresh stops further async refreshes until a blocking refresh runs."""
    stale_token = Token(access_token="access_token", expiry=datetime.now() + timedelta(seconds=59), )
    new_token = Token(access_token="new_token", expiry=datetime.now() + timedelta(seconds=300), )
    r = _MockRefreshable(token=stale_token, disable_async=False, refresh_effect=fail)
    # The background refresh should fail and disable async refresh,
    # but the exception will be caught inside the worker thread.
    result = r.token()
    assert result == stale_token
    # Give time to the async refresh to fail
    time.sleep(1)
    assert r._refresh_err
    # Now, the refresh should be blocking.
    # Blocking refresh only happens for expired, not stale.
    # Therefore, the next call should return the stale token.
    r._refresh_effect = static_token(new_token, wait=1)
    result = r.token()
    assert result == stale_token
    # Wait to be sure no async thread was started
    time.sleep(1)
    assert r._refresh_count == 0

    # Inject an expired token.
    expired_token = Token(access_token="access_token", expiry=datetime.now() - timedelta(seconds=300), )
    r._token = expired_token

    # This should be blocking and return the new token.
    result = r.token()
    assert r._refresh_count == 1
    assert result == new_token
    # The refresh error should be cleared.
    assert not r._refresh_err

0 comments on commit e550ca1

Please sign in to comment.