-
Notifications
You must be signed in to change notification settings - Fork 168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
OAuth cache improvements #435
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ | |
import re | ||
import threading | ||
import webbrowser | ||
from typing import Any, Callable, Dict, List, Optional, Tuple | ||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple | ||
from urllib.parse import urlparse | ||
|
||
from requests import PreparedRequest, Request, Response, Session | ||
|
@@ -26,6 +26,7 @@ | |
|
||
import trino.logging | ||
from trino.client import exceptions | ||
from trino.constants import HEADER_USER | ||
|
||
logger = trino.logging.get_logger(__name__) | ||
|
||
|
@@ -208,32 +209,33 @@ class _OAuth2TokenCache(metaclass=abc.ABCMeta): | |
""" | ||
|
||
@abc.abstractmethod | ||
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]: | ||
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: | ||
pass | ||
|
||
@abc.abstractmethod | ||
def store_token_to_cache(self, host: Optional[str], token: str) -> None: | ||
def store_token_to_cache(self, key: Optional[str], token: str) -> None: | ||
pass | ||
|
||
|
||
class _OAuth2TokenInMemoryCache(_OAuth2TokenCache): | ||
""" | ||
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache. | ||
Multiple clients can share the same cache only if each connection explicitly specifies | ||
a user otherwise the first cached token will be used to authenticate all other users. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._cache: Dict[Optional[str], str] = {} | ||
|
||
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]: | ||
return self._cache.get(host) | ||
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: | ||
return self._cache.get(key) | ||
|
||
def store_token_to_cache(self, host: Optional[str], token: str) -> None: | ||
self._cache[host] = token | ||
def store_token_to_cache(self, key: Optional[str], token: str) -> None: | ||
self._cache[key] = token | ||
|
||
|
||
class _OAuth2KeyRingTokenCache(_OAuth2TokenCache): | ||
""" | ||
Keyring Token Cache implementation | ||
Keyring token cache implementation | ||
""" | ||
|
||
def __init__(self) -> None: | ||
|
@@ -245,20 +247,21 @@ def __init__(self) -> None: | |
logger.info("keyring module not found. OAuth2 token will not be stored in keyring.") | ||
|
||
def is_keyring_available(self) -> bool: | ||
return self._keyring is not None | ||
return self._keyring is not None \ | ||
and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring) | ||
|
||
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]: | ||
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]: | ||
try: | ||
return self._keyring.get_password(host, "token") | ||
return self._keyring.get_password(key, "token") | ||
except self._keyring.errors.NoKeyringError as e: | ||
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " | ||
"detected, check https://pypi.org/project/keyring/ for more " | ||
"information.") from e | ||
|
||
def store_token_to_cache(self, host: Optional[str], token: str) -> None: | ||
def store_token_to_cache(self, key: Optional[str], token: str) -> None: | ||
try: | ||
# keyring is installed, so we can store the token for reuse within multiple threads | ||
self._keyring.set_password(host, "token", token) | ||
self._keyring.set_password(key, "token", token) | ||
except self._keyring.errors.NoKeyringError as e: | ||
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been " | ||
"detected, check https://pypi.org/project/keyring/ for more " | ||
|
@@ -267,7 +270,7 @@ def store_token_to_cache(self, host: Optional[str], token: str) -> None: | |
|
||
class _OAuth2TokenBearer(AuthBase): | ||
""" | ||
Custom implementation of Trino Oauth2 based authorization to get the token | ||
Custom implementation of Trino OAuth2 based authentication to get the token | ||
""" | ||
MAX_OAUTH_ATTEMPTS = 5 | ||
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE) | ||
|
@@ -282,7 +285,9 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]): | |
|
||
def __call__(self, r: PreparedRequest) -> PreparedRequest: | ||
host = self._determine_host(r.url) | ||
token = self._get_token_from_cache(host) | ||
user = self._determine_user(r.headers) | ||
key = self._construct_cache_key(host, user) | ||
token = self._get_token_from_cache(key) | ||
|
||
if token is not None: | ||
r.headers['Authorization'] = "Bearer " + token | ||
|
@@ -340,15 +345,19 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: | |
|
||
request = response.request | ||
host = self._determine_host(request.url) | ||
self._store_token_to_cache(host, token) | ||
user = self._determine_user(request.headers) | ||
key = self._construct_cache_key(host, user) | ||
self._store_token_to_cache(key, token) | ||
|
||
def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]: | ||
request = response.request.copy() | ||
extract_cookies_to_jar(request._cookies, response.request, response.raw) # type: ignore | ||
request.prepare_cookies(request._cookies) # type: ignore | ||
|
||
host = self._determine_host(response.request.url) | ||
token = self._get_token_from_cache(host) | ||
user = self._determine_user(request.headers) | ||
key = self._construct_cache_key(host, user) | ||
token = self._get_token_from_cache(key) | ||
if token is not None: | ||
request.headers['Authorization'] = "Bearer " + token | ||
retry_response = response.connection.send(request, **kwargs) # type: ignore | ||
|
@@ -381,18 +390,29 @@ def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> st | |
|
||
raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token") | ||
|
||
def _get_token_from_cache(self, host: Optional[str]) -> Optional[str]: | ||
def _get_token_from_cache(self, key: Optional[str]) -> Optional[str]: | ||
with self._token_lock: | ||
return self._token_cache.get_token_from_cache(host) | ||
return self._token_cache.get_token_from_cache(key) | ||
|
||
def _store_token_to_cache(self, host: Optional[str], token: str) -> None: | ||
def _store_token_to_cache(self, key: Optional[str], token: str) -> None: | ||
with self._token_lock: | ||
self._token_cache.store_token_to_cache(host, token) | ||
self._token_cache.store_token_to_cache(key, token) | ||
|
||
@staticmethod | ||
def _determine_host(url: Optional[str]) -> Any: | ||
return urlparse(url).hostname | ||
|
||
@staticmethod | ||
def _determine_user(headers: Mapping[Any, Any]) -> Optional[Any]: | ||
return headers.get(HEADER_USER) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You're right. Is there any other way to get the username? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since the actual user might come from user mapping, getting it from the server is probably the only option. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we can use the client-provided principal/user, instead of the user it gets resolved to on the server, to bypass this problem entirely? |
||
|
||
@staticmethod | ||
def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[str]: | ||
if user is None: | ||
return host | ||
else: | ||
return f"{host}@{user}" | ||
|
||
|
||
class OAuth2Authentication(Authentication): | ||
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you extract the `host` -> `key` rename into its own commit, the functional changes will be more visible and easier to see.