Skip to content

Commit

Permalink
Cache OAuth access token per host and user pair
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco committed Dec 29, 2023
1 parent 0517c65 commit 5632003
Showing 1 changed file with 37 additions and 21 deletions.
58 changes: 37 additions & 21 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@

from requests import PreparedRequest, Request, Response, Session
from requests.auth import AuthBase, extract_cookies_to_jar
from requests.utils import parse_dict_header
from requests.utils import CaseInsensitiveDict, parse_dict_header

import trino.logging
from trino.client import exceptions
from trino.constants import HEADER_USER

logger = trino.logging.get_logger(__name__)

Expand Down Expand Up @@ -208,32 +209,33 @@ class _OAuth2TokenCache(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
pass

@abc.abstractmethod
def store_token_to_cache(self, host: Optional[str], token: str) -> None:
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
pass


class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
"""
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
In-memory token cache implementation. The token is stored per host and user pair,
so multiple clients can share the same cache.
"""

def __init__(self) -> None:
self._cache: Dict[Optional[str], str] = {}

def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
return self._cache.get(host)
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
return self._cache.get(key)

def store_token_to_cache(self, host: Optional[str], token: str) -> None:
self._cache[host] = token
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
self._cache[key] = token


class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
"""
Keyring Token Cache implementation
Keyring token cache implementation
"""

def __init__(self) -> None:
Expand All @@ -248,18 +250,18 @@ def is_keyring_available(self) -> bool:
return self._keyring is not None \
and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring)

def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
try:
return self._keyring.get_password(host, "token")
return self._keyring.get_password(key, "token")
except self._keyring.errors.NoKeyringError as e:
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
"information.") from e

def store_token_to_cache(self, host: Optional[str], token: str) -> None:
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
try:
# keyring is installed, so we can store the token for reuse within multiple threads
self._keyring.set_password(host, "token", token)
self._keyring.set_password(key, "token", token)
except self._keyring.errors.NoKeyringError as e:
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
Expand All @@ -268,7 +270,7 @@ def store_token_to_cache(self, host: Optional[str], token: str) -> None:

class _OAuth2TokenBearer(AuthBase):
"""
Custom implementation of Trino Oauth2 based authorization to get the token
Custom implementation of Trino OAuth2 based authentication to get the token
"""
MAX_OAUTH_ATTEMPTS = 5
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)
Expand All @@ -283,7 +285,9 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]):

def __call__(self, r: PreparedRequest) -> PreparedRequest:
host = self._determine_host(r.url)
token = self._get_token_from_cache(host)
user = self._determine_user(r.headers)
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)

if token is not None:
r.headers['Authorization'] = "Bearer " + token
Expand Down Expand Up @@ -341,15 +345,19 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:

request = response.request
host = self._determine_host(request.url)
self._store_token_to_cache(host, token)
user = self._determine_user(request.headers)
key = self._construct_cache_key(host, user)
self._store_token_to_cache(key, token)

def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]:
request = response.request.copy()
extract_cookies_to_jar(request._cookies, response.request, response.raw) # type: ignore
request.prepare_cookies(request._cookies) # type: ignore

host = self._determine_host(response.request.url)
token = self._get_token_from_cache(host)
user = self._determine_user(request.headers)
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)
if token is not None:
request.headers['Authorization'] = "Bearer " + token
retry_response = response.connection.send(request, **kwargs) # type: ignore
Expand Down Expand Up @@ -382,18 +390,26 @@ def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> st

raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token")

def _get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def _get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
with self._token_lock:
return self._token_cache.get_token_from_cache(host)
return self._token_cache.get_token_from_cache(key)

def _store_token_to_cache(self, host: Optional[str], token: str) -> None:
def _store_token_to_cache(self, key: Optional[str], token: str) -> None:
with self._token_lock:
self._token_cache.store_token_to_cache(host, token)
self._token_cache.store_token_to_cache(key, token)

@staticmethod
def _determine_host(url: Optional[str]) -> Any:
return urlparse(url).hostname

@staticmethod
def _determine_user(headers: CaseInsensitiveDict[Any]) -> Optional[Any]:
return headers.get(HEADER_USER)

@staticmethod
def _construct_cache_key(host: Optional[str], user: Optional[str]) -> str:
return f"{host}@{user}"


class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
Expand Down

0 comments on commit 5632003

Please sign in to comment.