From 8d7ef28a5b931b48a0203ca28bb36cb872425624 Mon Sep 17 00:00:00 2001 From: "Nickolaus D. Saint" Date: Wed, 31 Oct 2018 11:53:25 -0700 Subject: [PATCH] Added features targeted for developers writing custom clients These changes make it easier for developers to create their own clients by adding simple built-in token storage, and a customized auth flow with a built-in local server. A built-in client_id was also added to reduce the barrier of entry required for creating a new client. Developers writing simple scripts can simply do: from globus_sdk import native_auth my_tokens = native_auth(save_tokens=True) No functionality here is meant to replace existing SDK functionality. Instead, the changes are intended to address commonly copied and re-implemented code. --- globus_sdk/__init__.py | 8 +- globus_sdk/auth/__init__.py | 2 + globus_sdk/auth/oauth2_native_app_shortcut.py | 183 +++++++++++++ globus_sdk/config.py | 90 +++++++ globus_sdk/exc.py | 16 ++ globus_sdk/utils/local_server.py | 153 +++++++++++ globus_sdk/utils/safeio.py | 34 +++ globus_sdk/utils/token_storage.py | 217 +++++++++++++++ tests/files/sample_configs/set_test.cfg | 2 + tests/unit/conftest.py | 15 ++ tests/unit/test_config.py | 36 +++ tests/unit/test_config_saving.py | 51 ++++ tests/unit/test_native_auth.py | 254 ++++++++++++++++++ tests/unit/test_utils_local_server.py | 59 ++++ tests/unit/test_utils_safe_io.py | 27 ++ tests/unit/test_utils_token_storage.py | 117 ++++++++ 16 files changed, 1263 insertions(+), 1 deletion(-) create mode 100644 globus_sdk/auth/oauth2_native_app_shortcut.py create mode 100644 globus_sdk/utils/local_server.py create mode 100644 globus_sdk/utils/safeio.py create mode 100644 globus_sdk/utils/token_storage.py create mode 100644 tests/files/sample_configs/set_test.cfg create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/test_config_saving.py create mode 100644 tests/unit/test_native_auth.py create mode 100644 tests/unit/test_utils_local_server.py create mode 100644 tests/unit/test_utils_safe_io.py create mode 100644 tests/unit/test_utils_token_storage.py diff --git a/globus_sdk/__init__.py b/globus_sdk/__init__.py index 3b1792453..30bc09729 100644 --- a/globus_sdk/__init__.py +++ b/globus_sdk/__init__.py @@ -1,6 +1,11 @@ import logging -from globus_sdk.auth import AuthClient, ConfidentialAppAuthClient, NativeAppAuthClient +from globus_sdk.auth import ( + AuthClient, + ConfidentialAppAuthClient, + NativeAppAuthClient, + native_auth, +) from globus_sdk.authorizers import ( AccessTokenAuthorizer, BasicAuthorizer, @@ -48,6 +53,7 @@ "ClientCredentialsAuthorizer", "AuthClient", "NativeAppAuthClient", + "native_auth", "ConfidentialAppAuthClient", "TransferClient", "TransferData", diff --git a/globus_sdk/auth/__init__.py b/globus_sdk/auth/__init__.py index 6e292c2bc..f022eed02 100644 --- a/globus_sdk/auth/__init__.py +++ b/globus_sdk/auth/__init__.py @@ -5,6 +5,7 @@ ) from globus_sdk.auth.oauth2_authorization_code import GlobusAuthorizationCodeFlowManager from globus_sdk.auth.oauth2_native_app import GlobusNativeAppFlowManager +from globus_sdk.auth.oauth2_native_app_shortcut import native_auth __all__ = [ "AuthClient", @@ -12,4 +13,5 @@ "ConfidentialAppAuthClient", "GlobusNativeAppFlowManager", "GlobusAuthorizationCodeFlowManager", + "native_auth", ] diff --git a/globus_sdk/auth/oauth2_native_app_shortcut.py b/globus_sdk/auth/oauth2_native_app_shortcut.py new file mode 100644 index 000000000..e1299a9e7 --- /dev/null +++ b/globus_sdk/auth/oauth2_native_app_shortcut.py @@ -0,0 +1,183 @@ +import logging +import os +import webbrowser +from socket import gethostname + +from six.moves import input + +from globus_sdk.auth.client_types.native_client import NativeAppAuthClient +from globus_sdk.auth.oauth2_constants import DEFAULT_REQUESTED_SCOPES +from globus_sdk.exc import ConfigError +from globus_sdk.utils.local_server import is_remote_session, start_local_server +from globus_sdk.utils.safeio import safe_print +from globus_sdk.utils.token_storage import clear_tokens, load_tokens, save_tokens + +logger = logging.getLogger(__name__) + +AUTH_CODE_REDIRECT = "https://auth.globus.org/v2/web/auth-code" + +NATIVE_AUTH_DEFAULTS = { + "config_filename": os.path.expanduser("~/.globus-native-apps.cfg"), + "config_section": None, # Defaults to client_id if not set + "client_id": "0af96eea-fec8-4d6e-aad2-c87feed8151c", + "requested_scopes": DEFAULT_REQUESTED_SCOPES, + "refresh_tokens": False, + "prefill_named_grant": gethostname(), + "additional_auth_params": {}, + "save_tokens": False, + "check_tokens_expired": True, + "force_login": False, + "no_local_server": False, + "no_browser": False, + "server_hostname": "127.0.0.1", + "server_port": 8890, + "redirect_uri": "http://localhost:8890/", +} + + +def native_auth(**kwargs): + """ + Provides a simple shortcut for doing a native auth flow for most use-cases + by setting common defaults for frequently used fields. Although a default + client id is provided, production apps should define their own at + https://developers.globus.org. See `NativeAppAuthClient` for constructing + a more fine-tuned native auth flow. Returns tokens organized by resource + server. + + **Native App Parameters** + ``client_id`` (*string*) + Client App id registered at https://developers.globus.org. Defaults + to a built-in one for testing. + + ``requested_scopes`` (*iterable* or *string*) + The scopes on the token(s) being requested, as a space-separated + string or iterable of strings. Defaults to ``openid profile email + urn:globus:auth:scope:transfer.api.globus.org:all`` + + ``redirect_uri`` (*string*) + The page that users should be directed to after authenticating at + the authorize URL. Defaults to + 'https://auth.globus.org/v2/web/auth-code', which displays the + resulting ``auth_code`` for users to copy-paste back into your + application (and thereby be passed back to the + ``GlobusNativeAppFlowManager``) + + ``refresh_tokens`` (*bool*) + When True, request refresh tokens in addition to access tokens + + ``prefill_named_grant`` (*string*) + Optionally prefill the named grant label on the consent page + + ``additional_auth_params`` (*dict*) + Set ``additional_parameters`` in + NativeAppAuthClient.oauth2_get_authorize_url() + + **Login Parameters** + ``save_tokens`` (*bool*) + Save user tokens to disk and reload them on repeated calls. + Defaults to False. + + ``check_tokens_expired`` (*bool*) + Check if loaded access tokens have expired since the last login. + You should set this to False if using Refresh Tokens. + Defaults to True. + + ``force_login`` (*bool*) + Do not attempt to load save tokens, and complete a new auth flow + instead. Defaults to False. + + ``no_local_server`` (*bool*) + Do not start a local server for fetching the auth_code. Setting + this to false will require the user to copy paste a code into + the console. Defaults to False. + + ``no_browser`` (*bool*) + Do not automatically attempt to open a browser for the auth flow. + Defaults to False. + + ``server_hostname`` (*string*) + Hostname for the local server to use. No effect if + ``no_local_server`` is set. MUST be specified in ``redirect_uri``. + Defaults to 127.0.0.1. + + ``server_port`` (*string*) + Port for the local server to use. No effect if ``no_local_server`` + is set. MUST be specified in ``redirect_uri``. Defaults to 8890. + + **Configfile Parameters** + ``config_filename`` (*string*) + Filename to use for reading and writing values. + + ``config_section`` (*string*) + Section within the config file to store information (like tokens). + + **Examples** + + ``native_auth()`` + + Or to save tokens: ``native_auth(save_tokens=True)`` + """ + unaccepted = [k for k in kwargs.keys() if k not in NATIVE_AUTH_DEFAULTS.keys()] + if any(unaccepted): + raise ValueError("Invalid args: {}".format(unaccepted)) + + opts = {k: kwargs.get(k, v) for k, v in NATIVE_AUTH_DEFAULTS.items()} + + # Default to the auth-code page redirect if the user is copy-pasting + if ( + opts["no_local_server"] is True + and opts["redirect_uri"] == NATIVE_AUTH_DEFAULTS["redirect_uri"] + ): + opts["redirect_uri"] = AUTH_CODE_REDIRECT + + config_section = opts["config_section"] or opts["client_id"] + + if opts["force_login"] is False: + try: + return load_tokens( + config_section, opts["requested_scopes"], opts["check_tokens_expired"] + ) + except ConfigError as ce: + logger.debug( + "Loading Tokens Failed, doing auth flow instead. " + "Error: {}".format(ce) + ) + + # Clear previous tokens to ensure no previously saved scopes remain. + clear_tokens(config_section=config_section, client_id=opts["client_id"]) + + client = NativeAppAuthClient(client_id=opts["client_id"]) + client.oauth2_start_flow( + requested_scopes=opts["requested_scopes"], + redirect_uri=opts["redirect_uri"], + refresh_tokens=opts["refresh_tokens"], + prefill_named_grant=opts["prefill_named_grant"], + ) + url = client.oauth2_get_authorize_url( + additional_params=opts["additional_auth_params"] + ) + + if opts["no_local_server"] is False: + server_address = (opts["server_hostname"], opts["server_port"]) + with start_local_server(listen=server_address) as server: + _prompt_login(url, opts["no_browser"]) + auth_code = server.wait_for_code() + else: + _prompt_login(url, opts["no_browser"]) + safe_print("Enter the resulting Authorization Code here: ", end="") + auth_code = input() + + token_response = client.oauth2_exchange_code_for_tokens(auth_code) + tokens_by_resource_server = token_response.by_resource_server + if opts["save_tokens"] is True: + save_tokens(tokens_by_resource_server, config_section) + # return a set of tokens, organized by resource server name + + return tokens_by_resource_server + + +def _prompt_login(url, no_browser): + if no_browser is False and not is_remote_session(): + webbrowser.open(url, new=1) + else: + safe_print("Please paste the following URL in a browser: " "\n{}".format(url)) diff --git a/globus_sdk/config.py b/globus_sdk/config.py index 62c83f15d..350550be5 100644 --- a/globus_sdk/config.py +++ b/globus_sdk/config.py @@ -44,11 +44,13 @@ class GlobusConfigParser(object): """ _GENERAL_CONF_SECTION = "general" + DEFAULT_WRITE_PATH = os.path.expanduser("~/.globus-native-apps.cfg") def __init__(self): logger.debug("Loading SDK Config parser") self._parser = ConfigParser() self._load_config() + self._write_path = self.DEFAULT_WRITE_PATH logger.debug("Config load succeeded") def _load_config(self): @@ -134,6 +136,86 @@ def get( return value + def get_section(self, section): + """Attempt to lookup a section in the config file. Returns None + if no section is found.""" + if self._parser.has_section(section): + return dict(self._parser.items(section)) + + def set_write_config_file(self, filename): + """Set a new config file for writing to disk. Attempts to load + the new config if one exists but will not raise an error if this + fails. Future config values will be written to this location.""" + logger.debug("New config file set to: {}".format(filename)) + try: + self._parser.read([filename]) + except Exception: + logger.debug("New config failed to load: {}".format(filename)) + self._write_path = filename + + def _get_write_config(self): + """Get the config for the current configured self._write_path. If it + does not exist, an empty config is returned instead. General config + values will not be included in the returned config so they aren't + copied and written to disk. + """ + if self._write_path is None: + raise GlobusError( + "Failed to write to the config file {}, please ensure you " + "have write access.".format(self._write_path) + ) + + cfg = ConfigParser() + if not os.path.exists(self._write_path): + cfg[self._GENERAL_CONF_SECTION] = {} + else: + cfg.read(self._write_path) + + return cfg + + def _save(self, cfg): + """Saves config options to disk at the file self._write_path. The + config file permissions are also always set to only allow User access + to the config file for a little bit of added security.""" + + # deny rwx to Group and World -- don't bother storing the returned + # old mask value, since we'll never restore it anyway + # do this on every call to ensure that we're always consistent about it + os.umask(0o077) + with open(self._write_path, "w") as configfile: + cfg.write(configfile) + + def set(self, option, value, section): + """ + Write an option to disk using the previously configured config + at set_config_file() or .globus.cfg. Creates the section if it does + not exist. + """ + cfg = self._get_write_config() + + # add the section if absent + if section not in cfg.sections(): + cfg.add_section(section) + + cfg.set(section, option, value) + self._save(cfg) + + # Update the Global config + if section not in self._parser.sections(): + self._parser.add_section(section) + self._parser.set(section, option, value) + + def remove(self, option, section): + """ + Remove an option from the config. True if option previously existed, + false otherwise. + """ + cfg = self._get_write_config() + removed = cfg.remove_option(section, option) + self._save(cfg) + self._parser.remove_option(section, option) + return removed + def _get_parser(): """ @@ -146,6 +228,14 @@ def _get_parser(): return _parser +def get_parser(): + """ + Historically components only needed read-only access. Since token storage, + new components may need to lookup config values or occasionally save data + """ + return _get_parser() + + # at import-time, it's None _parser = None diff --git a/globus_sdk/exc.py b/globus_sdk/exc.py index 1547ec7d1..b10537cdb 100644 --- a/globus_sdk/exc.py +++ b/globus_sdk/exc.py @@ -271,6 +271,22 @@ class GlobusConnectionError(NetworkError): """A connection error occured while making a REST request.""" +class ConfigError(GlobusError): + """An error reading or writing from the configuration file.""" + + +class RequestedScopesMismatch(ConfigError): + """Requested scopes differ from scopes saved to config.""" + + +class LoadedTokensExpired(ConfigError): + """Tokens loaded from disk have expired since last login.""" + + +class LocalServerError(GlobusError): + """Error encountered with local server used for native auth.""" + + def convert_request_exception(exc): """Converts incoming requests.Exception to a Globus NetworkError""" diff --git a/globus_sdk/utils/local_server.py b/globus_sdk/utils/local_server.py new file mode 100644 index 000000000..f7ba00673 --- /dev/null +++ b/globus_sdk/utils/local_server.py @@ -0,0 +1,153 @@ +import logging +import os +import sys +import threading +from contextlib import contextmanager +from string import Template + +import six +from six.moves import http_client, queue +from six.moves.urllib.parse import parse_qsl, urlparse + +from globus_sdk.exc import LocalServerError + +try: + from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler +except ImportError: + from http.server import HTTPServer, BaseHTTPRequestHandler + + +HTML_TEMPLATE = Template( + """ + + + + + + Globus SDK Login + + + +
+ Globus +
+ +
+

Globus SDK

+

+ $login_result. You may close this tab. +

+ $post_login_message +

+
+ + +""" +) + +DOC_URL = """ +SDK Documentation +""" + + +def enable_requests_logging(): + http_client.HTTPConnection.debuglevel = 4 + + logging.basicConfig() + logging.getLogger().setLevel(logging.DEBUG) + requests_log = logging.getLogger("requests.packages.urllib3") + requests_log.setLevel(logging.DEBUG) + requests_log.propagate = True + + +def is_remote_session(): + return os.environ.get("SSH_TTY", os.environ.get("SSH_CONNECTION")) + + +class RedirectHandler(BaseHTTPRequestHandler): + def do_GET(self): # noqa + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + + query_params = dict(parse_qsl(urlparse(self.path).query)) + code = query_params.get("code") + if code: + self.wfile.write( + six.b( + HTML_TEMPLATE.substitute( + post_login_message=DOC_URL, login_result="Login successful" + ) + ) + ) + self.server.return_code(code) + else: + msg = query_params.get("error_description", query_params.get("error")) + + self.wfile.write( + six.b( + HTML_TEMPLATE.substitute( + post_login_message=msg, login_result="Login failed" + ) + ) + ) + + self.server.return_code(LocalServerError(msg)) + + def log_message(self, format, *args): + return + + +class RedirectHTTPServer(HTTPServer, object): + def __init__(self, listen, handler_class): + super(RedirectHTTPServer, self).__init__(listen, handler_class) + + self._auth_code_queue = queue.Queue() + + def handle_error(self, request, client_address): + exctype, excval, exctb = sys.exc_info() + self._auth_code_queue.put(excval) + + def return_code(self, code): + self._auth_code_queue.put_nowait(code) + + def wait_for_code(self): + # workaround for handling control-c interrupt. + # relevant Python issue discussing this behavior: + # https://bugs.python.org/issue1360 + try: + return self._auth_code_queue.get(block=True, timeout=3600) + except (queue.Empty, KeyboardInterrupt): + raise LocalServerError() + finally: + # shutdown() stops the server thread + # https://github.com/python/cpython/blob/3.7/Lib/socketserver.py#L241 + self.shutdown() + # server_close() closes the socket: + # https://github.com/python/cpython/blob/3.7/Lib/socketserver.py#L474 + self.server_close() + + +@contextmanager +def start_local_server(listen=("", 0)): + server = RedirectHTTPServer(listen, RedirectHandler) + thread = threading.Thread(target=server.serve_forever) + thread.daemon = True + thread.start() + + yield server diff --git a/globus_sdk/utils/safeio.py b/globus_sdk/utils/safeio.py new file mode 100644 index 000000000..627fbafd8 --- /dev/null +++ b/globus_sdk/utils/safeio.py @@ -0,0 +1,34 @@ +from __future__ import print_function + +import sys + + +class SafeIO(object): + """SafeIO allows developers to change how the SDK prints output strings + to the user. By default, it provides a generic 'write()' method for + printing strings to stdout, but can be changed if needed.""" + + def write(self, message, *args, **kwargs): + print_kwargs = { + k: arg for k, arg in kwargs.items() if k in ("sep", "end", "file", "flush") + } + print_kwargs["file"] = print_kwargs.get("file") or sys.stderr + messages = [message] + list(args) + print(*messages, **print_kwargs) + + def set_write_function(self, func): + setattr(self, "write", func) + + +_safe_io = None + + +def get_safe_io(): + global _safe_io + if _safe_io is None: + _safe_io = SafeIO() + return _safe_io + + +def safe_print(message, *args, **kwargs): + get_safe_io().write(message, *args, **kwargs) diff --git a/globus_sdk/utils/token_storage.py b/globus_sdk/utils/token_storage.py new file mode 100644 index 000000000..518b3986b --- /dev/null +++ b/globus_sdk/utils/token_storage.py @@ -0,0 +1,217 @@ +import logging +import time + +from globus_sdk.auth.client_types.native_client import NativeAppAuthClient +from globus_sdk.config import get_parser +from globus_sdk.exc import ConfigError, LoadedTokensExpired, RequestedScopesMismatch + +TOKEN_KEYS = [ + "scope", + "access_token", + "refresh_token", + "token_type", + "expires_at_seconds", + "resource_server", +] +REQUIRED_KEYS = ["scope", "access_token", "expires_at_seconds", "resource_server"] +CONFIG_TOKEN_GROUPS = "token_groups" + +logger = logging.getLogger(__name__) + + +def save_tokens(tokens, config_section=None): + """ + Save a dict of tokens in config_section, for the current configfile. + Tokens should be formatted like the following: + { + "auth.globus.org": { + "scope": "profile openid email", + "access_token": "", + "refresh_token": None, + "token_type": "Bearer", + "expires_at_seconds": 1539984535, + "resource_server": "auth.globus.org" + }, ... + } + """ + config = get_parser() + + cfg_tokens = _serialize_token_groups(tokens) + for key, value in cfg_tokens.items(): + config.set(key, value, section=config_section) + + +def load_tokens(config_section=None, requested_scopes=(), check_expired=True): + """ + Load Tokens from a config section in the configfile. If requested_scopes + is given, it will match against the loaded scopes and raise a + RequestedScopesMismatch exception if they differ from one another. + + check_expired will check the expires_at_seconds number against the time + the user last logged in, and raise LoadedTokensExpired if it is greater. + check_expired should be set to false if you want to use refresh tokens. + + Returns tokens in a similar format to token_response.by_resource_server: + { + "auth.globus.org": { + "scope": "profile openid email", + "access_token": "", + "refresh_token": None, + "token_type": "Bearer", + "expires_at_seconds": 1539984535, + "resource_server": "auth.globus.org" + }, ... + } + """ + + config = get_parser() + try: + cfg_tokens = config.get_section(config_section) + loaded_tokens = _deserialize_token_groups(cfg_tokens) + except Exception: + raise ConfigError("Error loading tokens from: {}".format(config_section)) + + for tok_set in loaded_tokens.values(): + missing = [mk for mk in REQUIRED_KEYS if not tok_set.get(mk)] + if any(missing): + raise ConfigError("Missing {} from loaded tokens".format(missing)) + + if requested_scopes: + scope_lists = [t["scope"].split() for t in loaded_tokens.values()] + loaded_scopes = {s for slist in scope_lists for s in slist} + if loaded_scopes.difference(set(requested_scopes)): + raise RequestedScopesMismatch( + "Requested Scopes differ from loaded scopes. Requested: " + "{}, Loaded: {}".format(requested_scopes, list(loaded_scopes)) + ) + + if check_expired is True: + expired = [ + time.time() >= t["expires_at_seconds"] for t in loaded_tokens.values() + ] + if any(expired): + raise LoadedTokensExpired() + + return loaded_tokens + + +def clear_tokens(config_section=None, client_id=None): + """Revokes and deletes tokens saved to disk. ``config_section`` is the + section where the tokens are stored, ``client_id`` must be a valid Globus + App. Returns True if tokens were revoked (or expired) and deleted, false + otherwise. Raises globus_sdk.exc.AuthAPIError if tokens are live and + client_id is invalid. + """ + tokens = [] + try: + naac = NativeAppAuthClient(client_id) + tokens = load_tokens(config_section=config_section, check_expired=True) + for tok_set in tokens.values(): + logger.debug("Revoking: {}".format(tok_set["resource_server"])) + naac.oauth2_revoke_token(tok_set["access_token"]) + except LoadedTokensExpired: + # If they expired, no need to revoke but fetch again for deletion + tokens = load_tokens(config_section=config_section, check_expired=False) + except ConfigError as ce: + logger.debug(ce) + + if not tokens: + return False + + cfg_tsets = _serialize_token_groups(tokens) + config = get_parser() + for cfg_token_name in cfg_tsets.keys(): + config.remove(cfg_token_name, section=config_section) + config.remove(CONFIG_TOKEN_GROUPS, section=config_section) + return True + + +def _serialize_token_groups(tokens): + """ + Take a dict of tokens organized by resource server and return a dict + that can be easily saved to the config file. + + Resource servers containing '.' in their name will automatically be + converted to '_' (auth.globus.org == auth_globus_org). This is only for + cosmetic reasons. A resource server named "foo=;# = !@#$%^&*()" will have + funky looking config keys, but saving/loading will behave normally. + + Int values are converted to string, None values are converted to empty + string. *No other types are checked*. + + `tokens` should be formatted: + { + "auth.globus.org": { + "scope": "profile openid email", + "access_token": "", + "refresh_token": None, + "token_type": "Bearer", + "expires_at_seconds": 1539984535, + "resource_server": "auth.globus.org" + }, ... + } + Returns a flat dict of tokens prefixed by resource server. + { + "auth_globus_org_scope": "profile openid email", + "auth_globus_org_access_token": "", + "auth_globus_org_refresh_token": "", + "auth_globus_org_token_type": "Bearer", + "auth_globus_org_expires_at_seconds": "1540051101", + "auth_globus_org_resource_server": "auth.globus.org", + "token_groups": "auth_globus_org" + }""" + serialized_items = {} + token_groups = [] + for token_set in tokens.values(): + token_groups.append(_serialize_token(token_set["resource_server"])) + for key, value in token_set.items(): + key_name = _serialize_token(token_set["resource_server"], key) + if isinstance(value, int): + value = str(value) + if value is None: + value = "" + serialized_items[key_name] = value + + serialized_items[CONFIG_TOKEN_GROUPS] = ",".join(token_groups) + return serialized_items + + +def _deserialize_token_groups(config_items): + """ + Takes a dict from a config section and returns a dict of tokens by + resource server. `config_items` is a raw dict of config options returned + from get_parser().get_section(). + + Returns tokens in the format: + { + "auth.globus.org": { + "scope": "profile openid email", + "access_token": "", + "refresh_token": None, + "token_type": "Bearer", + "expires_at_seconds": 1539984535, + "resource_server": "auth.globus.org" + }, ... + } + """ + token_groups = {} + + tsets = config_items.get(CONFIG_TOKEN_GROUPS) + config_token_groups = tsets.split(",") + for group in config_token_groups: + tset = {k: config_items.get(_deserialize_token(group, k)) for k in TOKEN_KEYS} + tset["expires_at_seconds"] = int(tset["expires_at_seconds"]) + # Config loaded 'null' values will be an empty string. Set these to + # None for consistency + tset = {k: v if v else None for k, v in tset.items()} + token_groups[tset["resource_server"]] = tset + + return token_groups + + +def _deserialize_token(grouping, token): + return "{}{}".format(grouping, token) + + +def _serialize_token(resource_server, token=""): + return "{}_{}".format(resource_server.replace(".", "_"), token) diff --git a/tests/files/sample_configs/set_test.cfg b/tests/files/sample_configs/set_test.cfg new file mode 100644 index 000000000..3585e0f21 --- /dev/null +++ b/tests/files/sample_configs/set_test.cfg @@ -0,0 +1,2 @@ +[default] +option = general_value diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 000000000..319f38df5 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,15 @@ +import tempfile + +import pytest + +import globus_sdk + + +@pytest.yield_fixture +def temp_config(): + globus_sdk.config._parser = None + temp_config = tempfile.NamedTemporaryFile() + cfg = globus_sdk.config.get_parser() + cfg.set_write_config_file(temp_config.name) + yield cfg + temp_config.close() diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index a84c353c2..a388ebec2 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,5 +1,6 @@ import os from contextlib import contextmanager +from tempfile import NamedTemporaryFile import pytest import six @@ -282,3 +283,38 @@ def test_get_globus_environ_production(): del os.environ["GLOBUS_SDK_ENVIRONMENT"] # ensure that passing a value returns that value assert globus_sdk.config.get_globus_environ("production") == "default" + + +def test_verify_set_config_file(): + new_config = NamedTemporaryFile() + with custom_config(""): + cfg = globus_sdk.config.get_parser() + cfg.set_write_config_file(new_config.name) + assert cfg._write_path == new_config.name + + +def test_verify_load_from_new_config_file(): + with custom_config(""), NamedTemporaryFile(mode="w+") as new_cfg: + new_cfg.file.write("[default]\noption = general_value\n") + new_cfg.file.flush() + + cfg = globus_sdk.config.get_parser() + cfg.set_write_config_file(new_cfg.name) + assert cfg.get("option", "default") == "general_value" + + +def test_verify_write_config_option(temp_config): + with custom_config(""): + temp_config.set("foo", "bar", "mysec") + assert temp_config.get("foo", "mysec") == "bar" + + temp_config.set("baz", "car", "new_sec") + assert temp_config.get("baz", "new_sec") == "car" + + +def test_verify_remove_config_option(temp_config): + with custom_config(""): + temp_config.set("foo", "bar", "mysec") + assert temp_config.get("foo", "mysec") == "bar" + temp_config.remove("foo", "mysec") + assert temp_config.get("foo", "mysec") is None diff --git a/tests/unit/test_config_saving.py b/tests/unit/test_config_saving.py new file mode 100644 index 000000000..8f78b0436 --- /dev/null +++ b/tests/unit/test_config_saving.py @@ -0,0 +1,51 @@ +# """ +# Test Config Saving produces expected results. +# """ +# +# import os +# from tempfile import NamedTemporaryFile +# +# import globus_sdk.config +# +# from tests.framework import get_fixture_file_dir, CapturedIOTestCase +# +# SET_CFG = os.path.join(get_fixture_file_dir(), 'sample_configs', +# 'set_test.cfg') +# +# +# class ConfigSaveTests(CapturedIOTestCase): +# +# def setUp(self): +# self.parser = globus_sdk.config.get_parser() +# self.config = NamedTemporaryFile() +# +# def tearDown(self): +# globus_sdk.config._parser = None +# self.config.close() +# +# def test_verify_set_config_file(self): +# self.parser.set_write_config_file(self.config.name) +# assert self.parser._write_path == self.config.name +# +# def test_verify_load_from_new_config_file(self): +# with open(SET_CFG) as ch, NamedTemporaryFile(mode='w+') as new_cfg: +# new_cfg.file.write(ch.read()) +# new_cfg.file.flush() +# +# self.parser.set_write_config_file(new_cfg.name) +# assert self.parser.get('option', 'default') == 'general_value' +# +# def test_verify_write_config_option(self): +# self.parser.set_write_config_file(self.config.name) +# self.parser.set('foo', 'bar', 'mysec') +# assert self.parser.get('foo', 'mysec') == 'bar' +# +# self.parser.set('baz', 'car', 'new_sec') +# assert self.parser.get('baz', 'new_sec') == 'car' +# +# def test_verify_remove_config_option(self): +# self.parser.set_write_config_file(self.config.name) +# self.parser.set('foo', 'bar', 'mysec') +# assert self.parser.get('foo', 'mysec') == 'bar' +# self.parser.remove('foo', 'mysec') +# assert self.parser.get('foo', 'mysec') is None diff --git a/tests/unit/test_native_auth.py b/tests/unit/test_native_auth.py new file mode 100644 index 000000000..868dc0eab --- /dev/null +++ b/tests/unit/test_native_auth.py @@ -0,0 +1,254 @@ +import copy +import uuid +import webbrowser +from time import time + +import pytest + +import globus_sdk +from globus_sdk.auth import oauth2_native_app_shortcut +from globus_sdk.auth.oauth2_native_app_shortcut import ( + AUTH_CODE_REDIRECT, + NATIVE_AUTH_DEFAULTS as NA_DEF, + native_auth, +) +from globus_sdk.utils import token_storage + +try: + import mock +except ImportError: + from unittest import mock + +MOCK_TOKENS = { + "auth.globus.org": { + "scope": "profile openid email", + "access_token": "9d0e6f2a21917cc3e04602838e0ba4f7df3399bbd49f1" + "5db3cf0af34d52c928f34f639444af0b28695086d97b1", + "refresh_token": None, + "token_type": "Bearer", + "expires_at_seconds": int(time()) + 60 * 60, + "resource_server": "auth.globus.org", + } +} + +MOCK_AUTH_CODE = "foobarbaz" + + +@pytest.fixture +def mock_webbrowser(monkeypatch): + monkeypatch.setattr(webbrowser, "open", mock.Mock()) + + +@pytest.fixture +def mock_save_tokens(monkeypatch): + mock_save = mock.Mock() + monkeypatch.setattr(oauth2_native_app_shortcut, "save_tokens", mock_save) + return mock_save + + +@pytest.fixture +def mock_clear_tokens(monkeypatch): + mocked_clear_tokens = mock.Mock() + monkeypatch.setattr(oauth2_native_app_shortcut, "clear_tokens", mocked_clear_tokens) + return mocked_clear_tokens + + +@pytest.yield_fixture +def mock_local_server(monkeypatch): + mock_server = mock.Mock() + mock_server.__enter__ = mock.Mock(return_value=mock.Mock()) + mock_server.__exit__ = mock.Mock(return_value=mock.Mock()) + mock_start = mock.Mock(return_value=mock_server) + monkeypatch.setattr( + "globus_sdk.auth.oauth2_native_app_shortcut.start_local_server", mock_start + ) + return mock_start + + +@pytest.fixture +def mock_native_client(monkeypatch): + mock_client = mock.Mock() + mock_class = mock.Mock(return_value=mock_client) + token_response = mock.Mock() + mock_client.oauth2_exchange_code_for_tokens.return_value = token_response + token_response.by_resource_server = MOCK_TOKENS + monkeypatch.setattr( + "globus_sdk.auth.oauth2_native_app_shortcut.NativeAppAuthClient", mock_class + ) + return mock_class, mock_client + + +@pytest.fixture +def mock_native_client_simple(mock_native_client): + client_class, _ = mock_native_client + return client_class + + +@pytest.yield_fixture +def saved_tokens(temp_config, mock_native_client): + token_storage.save_tokens(MOCK_TOKENS, NA_DEF["client_id"]) + + +@pytest.fixture +def expired_saved_tokens(temp_config): + expired = copy.deepcopy(MOCK_TOKENS) + expired["auth.globus.org"]["expires_at_seconds"] = int(time()) - 1 + token_storage.save_tokens(expired, NA_DEF["client_id"]) + + +@pytest.fixture +def mock_input(monkeypatch): + mock_input = mock.Mock() + monkeypatch.setattr("globus_sdk.auth.oauth2_native_app_shortcut.input", mock_input) + return mock_input + + +@pytest.fixture +def mock_safe_print(monkeypatch): + mock_print = mock.Mock() + monkeypatch.setattr("globus_sdk.auth.oauth2_native_app_shortcut.input", mock_print) + return mock_print + + +def test_native_auth( + mock_webbrowser, mock_local_server, mock_native_client, mock_save_tokens +): + native_app_class, native_app_client = mock_native_client + + native_auth() + + native_app_class.assert_called_with(client_id=NA_DEF["client_id"]) + native_app_client.oauth2_start_flow.assert_called_with( + requested_scopes=NA_DEF["requested_scopes"], + redirect_uri=NA_DEF["redirect_uri"], + refresh_tokens=NA_DEF["refresh_tokens"], + prefill_named_grant=NA_DEF["prefill_named_grant"], + ) + native_app_client.oauth2_get_authorize_url.assert_called_with(additional_params={}) + mock_local_server.assert_called_with( + listen=(NA_DEF["server_hostname"], NA_DEF["server_port"]) + ) + assert native_app_client.oauth2_exchange_code_for_tokens.called + assert not mock_save_tokens.called + assert webbrowser.open.called + + +def test_invalid_option_raises_error(mock_native_client, mock_local_server): + with pytest.raises(ValueError): + native_auth(conquer_the_world=True) + + +def test_native_auth_saving_tokens( + mock_save_tokens, mock_native_client, mock_local_server, temp_config +): + native_auth(save_tokens=True) + assert mock_save_tokens.called + + +def test_native_auth_loading_tokens( + mock_native_client_simple, mock_local_server, saved_tokens +): + native_auth() + # assert tokens were loaded and a native flow was not started + assert not mock_native_client_simple.called + + +def test_native_auth_force_login( + mock_native_client_simple, mock_local_server, saved_tokens, mock_clear_tokens +): + # Should disregard previously saved tokens + native_auth(force_login=True) + assert mock_native_client_simple.called + assert mock_clear_tokens.called + + +def test_native_auth_requested_scope_check( + mock_native_client_simple, mock_local_server, saved_tokens, mock_clear_tokens +): + # Scopes here are different than what was saved, so this should + # trigger an auth flow + native_auth(requested_scopes=("foo",)) + assert mock_native_client_simple.called + assert mock_clear_tokens.called + + +def test_native_auth_expired_token_check( + mock_native_client_simple, mock_local_server, expired_saved_tokens +): + native_auth() + assert mock_native_client_simple.called + + +def test_native_auth_expired_token_no_check( + mock_native_client_simple, mock_local_server, expired_saved_tokens +): + native_auth(check_tokens_expired=False) + assert not mock_native_client_simple.called + + +def test_native_auth_no_local_server( + mock_local_server, mock_native_client, temp_config, mock_input, mock_safe_print +): + native_app_class, native_app_client = mock_native_client + + native_auth(no_local_server=True) + native_app_client.oauth2_start_flow.assert_called_with( + requested_scopes=NA_DEF["requested_scopes"], + redirect_uri=AUTH_CODE_REDIRECT, + refresh_tokens=NA_DEF["refresh_tokens"], + prefill_named_grant=NA_DEF["prefill_named_grant"], + ) + assert native_app_class.called + assert not mock_local_server.called + assert mock_safe_print.called + assert not mock_input.called + + +def test_native_auth_no_browser( + mock_webbrowser, mock_local_server, mock_native_client_simple, mock_safe_print +): + native_auth(no_browser=True) + # Assert a native flow was not started + assert mock_native_client_simple.called + assert mock_local_server.called + assert not webbrowser.open.called + + +def test_native_auth_custom_config_section( + mock_native_client_simple, mock_local_server, temp_config +): + my_section = "my_section" + native_auth(config_section="my_section", save_tokens=True) + assert my_section in globus_sdk.config.get_parser()._parser.sections() + + +def test_native_auth_ancillary_options( + mock_webbrowser, mock_local_server, mock_native_client, mock_save_tokens +): + """Options here don't change the control flow and should not affect + one another. This test asserts they're present in expected places""" + native_class, native_client = mock_native_client + options = { + "client_id": str(uuid.uuid4()), + "redirect_uri": "http://example.com/login", + "requested_scopes": ("myscope", "myotherscope"), + "refresh_tokens": True, + "prefill_named_grant": "Captain Hammer's Lenovo", + "additional_auth_params": {"session_message": "hello!"}, + "server_hostname": "localhost", + "server_port": 9999, + } + native_auth(**options) + native_class.assert_called_with(client_id=options["client_id"]) + native_client.oauth2_start_flow.assert_called_with( + requested_scopes=options["requested_scopes"], + redirect_uri=options["redirect_uri"], + refresh_tokens=options["refresh_tokens"], + prefill_named_grant=options["prefill_named_grant"], + ) + native_client.oauth2_get_authorize_url.assert_called_with( + additional_params=options["additional_auth_params"] + ) + mock_local_server.assert_called_with( + listen=(options["server_hostname"], options["server_port"]) + ) diff --git a/tests/unit/test_utils_local_server.py b/tests/unit/test_utils_local_server.py new file mode 100644 index 000000000..f1cc74cb4 --- /dev/null +++ b/tests/unit/test_utils_local_server.py @@ -0,0 +1,59 @@ +import threading + +import httpretty +import pytest +import requests +from six.moves.urllib.parse import urlencode + +from globus_sdk.exc import LocalServerError +from globus_sdk.utils.local_server import start_local_server + + +class LocalServerTester: + def __init__(self): + self.server_response = None + + def _wait_for_code(self, server): + try: + self.server_response = server.wait_for_code() + except Exception as e: + self.server_response = e + + def test(self, response_params): + """ + Start a local server to wait for an 'auth_code'. Usually the user's + browser will redirect to this location, but in this case the user is + mocked with a separate request in another thread. + + Waits for threads to complete and returns the local_server response. + """ + with start_local_server() as server: + thread = threading.Thread(target=self._wait_for_code, args=(server,)) + thread.start() + host, port = server.server_address + url = "http://{}:{}/?{}".format( + "127.0.0.1", port, urlencode(response_params) + ) + requests.get(url) + thread.join() + return self.server_response + + +@pytest.yield_fixture +def test_server(): + httpretty.disable() + yield LocalServerTester() + httpretty.enable() + + +def test_local_server_with_auth_code(test_server): + MOCK_AUTH_CODE = ( + "V2UgY2FuJ3Qgd2FpdCBmb3IgY29kZXMgZm9yZXZlci4g" + "V2VsbCwgd2UgY2FuIGJ1dCBJIGRvbid0IHdhbnQgdG8u" + ) + assert test_server.test({"code": MOCK_AUTH_CODE}) == MOCK_AUTH_CODE + + +def test_local_server_with_error(test_server): + response = test_server.test({"error": "bad things happened"}) + assert isinstance(response, LocalServerError) diff --git a/tests/unit/test_utils_safe_io.py b/tests/unit/test_utils_safe_io.py new file mode 100644 index 000000000..592b0ac7d --- /dev/null +++ b/tests/unit/test_utils_safe_io.py @@ -0,0 +1,27 @@ +from tempfile import NamedTemporaryFile + +from globus_sdk.utils import safeio + +try: + import mock +except ImportError: + from unittest import mock + + +def test_safe_print_custom_output(): + my_log_file = NamedTemporaryFile() + + def my_logger(message): + with open(my_log_file.name, "w+") as lfh: + lfh.write(message) + + safeio.get_safe_io().set_write_function(my_logger) + safeio.safe_print("The hamsters are attacking!") + with open(my_log_file.name) as lfh: + assert lfh.read() == "The hamsters are attacking!" + + +def test_safe_print_normally(): + with mock.patch("globus_sdk.utils.safeio._safe_io") as sio: + safeio.safe_print("foo") + sio.write.assert_called_once_with("foo") diff --git a/tests/unit/test_utils_token_storage.py b/tests/unit/test_utils_token_storage.py new file mode 100644 index 000000000..4f66b8c9f --- /dev/null +++ b/tests/unit/test_utils_token_storage.py @@ -0,0 +1,117 @@ +import copy +from time import time + +import pytest + +from globus_sdk import config +from globus_sdk.exc import ConfigError, LoadedTokensExpired, RequestedScopesMismatch +from globus_sdk.utils.token_storage import clear_tokens, load_tokens, save_tokens + +try: + import mock +except ImportError: + from unittest import mock + +TOKEN_LIFETIME = 60 * 60 * 24 + +MOCK_TOKENS = { + "auth.globus.org": { + "scope": "profile openid email", + "access_token": "9d0e6f2a21917cc3e04602838e0ba4f7df3399bbd49f1" + "5db3cf0af34d52c928f34f639444af0b28695086d97b1", + "refresh_token": None, + "token_type": "Bearer", + "expires_at_seconds": int(time()) + TOKEN_LIFETIME, + "resource_server": "auth.globus.org", + }, + "workhorse.org": { + "scope": "all", + "access_token": "QmFkIEhvcnNlLCBCYWQgSG9yc2UuIEJhZCBIb3JzZSwg" + "QmFkIEhvcnNlLiBIZSByaWRlcyBhY3Jvc3MgdGhlIG5h", + "refresh_token": "VGhlIGV2aWwgbGVhZ3VlIG9mIGV2aWwsIGlzIHdhdGNo" + "aW5nIHNvIGJld2FyZS4gVGhlIGdyYWRlIHRoYXQgeW8=", + "token_type": "Bearer", + "expires_at_seconds": int(time()) + TOKEN_LIFETIME, + "resource_server": "workhorse.org", + }, +} + + +@pytest.fixture +def mock_expired_tokens(): + expired_tokens = copy.deepcopy(MOCK_TOKENS) + for _, token_set in expired_tokens.items(): + token_set["expires_at_seconds"] = int(time()) - 1 + return expired_tokens + + +@pytest.fixture +def mock_native_app(monkeypatch): + mock_client = mock.MagicMock() + monkeypatch.setattr( + "globus_sdk.utils.token_storage.NativeAppAuthClient", mock_client + ) + return mock_client + + +def test_save_and_load_tokens_matches_original(temp_config): + save_tokens(MOCK_TOKENS, "test") + tokens = load_tokens("test") + for set_name, set_values in tokens.items(): + loaded_set = set(set_values.values()) + mock_set = set(MOCK_TOKENS[set_name].values()) + assert not loaded_set.difference(mock_set) + + +def test_loading_bad_tokens_raises_error(temp_config): + save_tokens(MOCK_TOKENS, "test") + temp_config.remove("workhorse_org_access_token", "test") + with pytest.raises(ConfigError): + load_tokens("test") + + +def test_loading_raises_error_if_tokens_expire(temp_config, mock_expired_tokens): + save_tokens(mock_expired_tokens, "test") + with pytest.raises(LoadedTokensExpired): + load_tokens("test") + + +def test_loading_raises_error_if_scopes_differ(temp_config): + save_tokens(MOCK_TOKENS, "test") + transfer_scope = ("urn:globus:auth:scope:transfer.api.globus.org:all",) + with pytest.raises(RequestedScopesMismatch): + load_tokens("test", requested_scopes=transfer_scope) + + +def test_verify_clear_tokens(temp_config, mock_native_app): + save_tokens(MOCK_TOKENS, "test") + section = config.get_parser().get_section("test") + assert len(section.values()) == 13 + return_value = clear_tokens("test", "my_client_id") + section = config.get_parser().get_section("test") + assert len(section.values()) == 0 + mock_native_app.assert_called_with("my_client_id") + assert return_value is True + + +def test_clear_tokens_with_no_saved_tokens(temp_config, mock_native_app): + return_value = clear_tokens("test", "my_client_id") + assert return_value is False + + +def test_clear_expired_tokens(temp_config, mock_expired_tokens, mock_native_app): + save_tokens(mock_expired_tokens, "test") + section = config.get_parser().get_section("test") + assert len(section.values()) == 13 + return_value = clear_tokens("test", "my_client_id") + section = config.get_parser().get_section("test") + assert len(section.values()) == 0 + mock_native_app.assert_called_with("my_client_id") + assert return_value is True + + +def test_clear_tokens_with_invalid_client_raises_error(temp_config): + save_tokens(MOCK_TOKENS, "test") + config.get_parser().remove("workhorse_org_access_token", "test") + with pytest.raises(ConfigError): + load_tokens("test")