Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OAuth cache improvements #435

Merged
merged 3 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.h

A callback to handle the redirect url can be provided via param `redirect_auth_url_handler` of the `trino.auth.OAuth2Authentication` class. By default, it will try to launch a web browser (`trino.auth.WebBrowserRedirectHandler`) to go through the authentication flow and output the redirect url to stdout (`trino.auth.ConsoleRedirectHandler`). Multiple redirect handlers are combined using the `trino.auth.CompositeRedirectHandler` class.

The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance or, when keyring is installed, it will be cached within a secure backend (MacOS keychain, Windows credential locker, etc) under a key including host of the Trino connection. Keyring can be installed using `pip install 'trino[external-authentication-token-cache]'`.
The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance and username or, when keyring is installed, it will be cached within a secure backend (MacOS keychain, Windows credential locker, etc) under a key including host of the Trino connection. Keyring can be installed using `pip install 'trino[external-authentication-token-cache]'`.

> [!WARNING]
> If username is not specified then the OAuth2 token cache is shared and stored per host.

- DBAPI

Expand Down
64 changes: 42 additions & 22 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import threading
import webbrowser
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from urllib.parse import urlparse

from requests import PreparedRequest, Request, Response, Session
Expand All @@ -26,6 +26,7 @@

import trino.logging
from trino.client import exceptions
from trino.constants import HEADER_USER

logger = trino.logging.get_logger(__name__)

Expand Down Expand Up @@ -208,32 +209,33 @@ class _OAuth2TokenCache(metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
pass

@abc.abstractmethod
def store_token_to_cache(self, host: Optional[str], token: str) -> None:
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
pass


class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
"""
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
Multiple clients can share the same cache only if each connection explicitly specifies
a user otherwise the first cached token will be used to authenticate all other users.
"""

def __init__(self) -> None:
self._cache: Dict[Optional[str], str] = {}

def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
return self._cache.get(host)
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you extract the host -> key rename to it's own commit functional changes will be more visible and easier to see.

return self._cache.get(key)

def store_token_to_cache(self, host: Optional[str], token: str) -> None:
self._cache[host] = token
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
self._cache[key] = token


class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
"""
Keyring Token Cache implementation
Keyring token cache implementation
"""

def __init__(self) -> None:
Expand All @@ -245,20 +247,21 @@ def __init__(self) -> None:
logger.info("keyring module not found. OAuth2 token will not be stored in keyring.")

def is_keyring_available(self) -> bool:
return self._keyring is not None
return self._keyring is not None \
and not isinstance(self._keyring.get_keyring(), self._keyring.backends.fail.Keyring)

def get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
try:
return self._keyring.get_password(host, "token")
return self._keyring.get_password(key, "token")
except self._keyring.errors.NoKeyringError as e:
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
"information.") from e

def store_token_to_cache(self, host: Optional[str], token: str) -> None:
def store_token_to_cache(self, key: Optional[str], token: str) -> None:
try:
# keyring is installed, so we can store the token for reuse within multiple threads
self._keyring.set_password(host, "token", token)
self._keyring.set_password(key, "token", token)
except self._keyring.errors.NoKeyringError as e:
raise trino.exceptions.NotSupportedError("Although keyring module is installed no backend has been "
"detected, check https://pypi.org/project/keyring/ for more "
Expand All @@ -267,7 +270,7 @@ def store_token_to_cache(self, host: Optional[str], token: str) -> None:

class _OAuth2TokenBearer(AuthBase):
"""
Custom implementation of Trino Oauth2 based authorization to get the token
Custom implementation of Trino OAuth2 based authentication to get the token
"""
MAX_OAUTH_ATTEMPTS = 5
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)
Expand All @@ -282,7 +285,9 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]):

def __call__(self, r: PreparedRequest) -> PreparedRequest:
host = self._determine_host(r.url)
token = self._get_token_from_cache(host)
user = self._determine_user(r.headers)
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)

if token is not None:
r.headers['Authorization'] = "Bearer " + token
Expand Down Expand Up @@ -340,15 +345,19 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:

request = response.request
host = self._determine_host(request.url)
self._store_token_to_cache(host, token)
user = self._determine_user(request.headers)
key = self._construct_cache_key(host, user)
self._store_token_to_cache(key, token)

def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]:
request = response.request.copy()
extract_cookies_to_jar(request._cookies, response.request, response.raw) # type: ignore
request.prepare_cookies(request._cookies) # type: ignore

host = self._determine_host(response.request.url)
token = self._get_token_from_cache(host)
user = self._determine_user(request.headers)
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)
if token is not None:
request.headers['Authorization'] = "Bearer " + token
retry_response = response.connection.send(request, **kwargs) # type: ignore
Expand Down Expand Up @@ -381,18 +390,29 @@ def _get_token(self, token_server: str, response: Response, **kwargs: Any) -> st

raise exceptions.TrinoAuthError("Exceeded max attempts while getting the token")

def _get_token_from_cache(self, host: Optional[str]) -> Optional[str]:
def _get_token_from_cache(self, key: Optional[str]) -> Optional[str]:
with self._token_lock:
return self._token_cache.get_token_from_cache(host)
return self._token_cache.get_token_from_cache(key)

def _store_token_to_cache(self, host: Optional[str], token: str) -> None:
def _store_token_to_cache(self, key: Optional[str], token: str) -> None:
with self._token_lock:
self._token_cache.store_token_to_cache(host, token)
self._token_cache.store_token_to_cache(key, token)

@staticmethod
def _determine_host(url: Optional[str]) -> Any:
return urlparse(url).hostname

@staticmethod
def _determine_user(headers: Mapping[Any, Any]) -> Optional[Any]:
return headers.get(HEADER_USER)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

X-Trino-User header is optional

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, is there any other way to get username?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the actual user might come from user mapping then getting it from the server is probably the only option.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can use the client provided principal/user instead of user which it gets resolved to on the server to bypass this problem entirely?


@staticmethod
def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[str]:
if user is None:
return host
else:
return f"{host}@{user}"


class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
Expand Down
Loading