diff --git a/globus_sdk/__init__.py b/globus_sdk/__init__.py
index 3b1792453..30bc09729 100644
--- a/globus_sdk/__init__.py
+++ b/globus_sdk/__init__.py
@@ -1,6 +1,11 @@
import logging
-from globus_sdk.auth import AuthClient, ConfidentialAppAuthClient, NativeAppAuthClient
+from globus_sdk.auth import (
+ AuthClient,
+ ConfidentialAppAuthClient,
+ NativeAppAuthClient,
+ native_auth,
+)
from globus_sdk.authorizers import (
AccessTokenAuthorizer,
BasicAuthorizer,
@@ -48,6 +53,7 @@
"ClientCredentialsAuthorizer",
"AuthClient",
"NativeAppAuthClient",
+ "native_auth",
"ConfidentialAppAuthClient",
"TransferClient",
"TransferData",
diff --git a/globus_sdk/auth/__init__.py b/globus_sdk/auth/__init__.py
index 6e292c2bc..f022eed02 100644
--- a/globus_sdk/auth/__init__.py
+++ b/globus_sdk/auth/__init__.py
@@ -5,6 +5,7 @@
)
from globus_sdk.auth.oauth2_authorization_code import GlobusAuthorizationCodeFlowManager
from globus_sdk.auth.oauth2_native_app import GlobusNativeAppFlowManager
+from globus_sdk.auth.oauth2_native_app_shortcut import native_auth
__all__ = [
"AuthClient",
@@ -12,4 +13,5 @@
"ConfidentialAppAuthClient",
"GlobusNativeAppFlowManager",
"GlobusAuthorizationCodeFlowManager",
+ "native_auth",
]
diff --git a/globus_sdk/auth/oauth2_native_app_shortcut.py b/globus_sdk/auth/oauth2_native_app_shortcut.py
new file mode 100644
index 000000000..e1299a9e7
--- /dev/null
+++ b/globus_sdk/auth/oauth2_native_app_shortcut.py
@@ -0,0 +1,183 @@
+import logging
+import os
+import webbrowser
+from socket import gethostname
+
+from six.moves import input
+
+from globus_sdk.auth.client_types.native_client import NativeAppAuthClient
+from globus_sdk.auth.oauth2_constants import DEFAULT_REQUESTED_SCOPES
+from globus_sdk.exc import ConfigError
+from globus_sdk.utils.local_server import is_remote_session, start_local_server
+from globus_sdk.utils.safeio import safe_print
+from globus_sdk.utils.token_storage import clear_tokens, load_tokens, save_tokens
+
+logger = logging.getLogger(__name__)
+
+AUTH_CODE_REDIRECT = "https://auth.globus.org/v2/web/auth-code"
+
+NATIVE_AUTH_DEFAULTS = {
+ "config_filename": os.path.expanduser("~/.globus-native-apps.cfg"),
+ "config_section": None, # Defaults to client_id if not set
+ "client_id": "0af96eea-fec8-4d6e-aad2-c87feed8151c",
+ "requested_scopes": DEFAULT_REQUESTED_SCOPES,
+ "refresh_tokens": False,
+ "prefill_named_grant": gethostname(),
+ "additional_auth_params": {},
+ "save_tokens": False,
+ "check_tokens_expired": True,
+ "force_login": False,
+ "no_local_server": False,
+ "no_browser": False,
+ "server_hostname": "127.0.0.1",
+ "server_port": 8890,
+ "redirect_uri": "http://localhost:8890/",
+}
+
+
+def native_auth(**kwargs):
+ """
+ Provides a simple shortcut for doing a native auth flow for most use-cases
+ by setting common defaults for frequently used fields. Although a default
+ client id is provided, production apps should define their own at
+ https://developers.globus.org. See `NativeAppAuthClient` for constructing
+ a more fine-tuned native auth flow. Returns tokens organized by resource
+ server.
+
+ **Native App Parameters**
+ ``client_id`` (*string*)
+ Client App id registered at https://developers.globus.org. Defaults
+ to a built-in one for testing.
+
+ ``requested_scopes`` (*iterable* or *string*)
+ The scopes on the token(s) being requested, as a space-separated
+ string or iterable of strings. Defaults to ``openid profile email
+ urn:globus:auth:scope:transfer.api.globus.org:all``
+
+ ``redirect_uri`` (*string*)
+ The page that users should be directed to after authenticating at
+ the authorize URL. Defaults to
+ 'https://auth.globus.org/v2/web/auth-code', which displays the
+ resulting ``auth_code`` for users to copy-paste back into your
+ application (and thereby be passed back to the
+ ``GlobusNativeAppFlowManager``)
+
+ ``refresh_tokens`` (*bool*)
+ When True, request refresh tokens in addition to access tokens
+
+ ``prefill_named_grant`` (*string*)
+ Optionally prefill the named grant label on the consent page
+
+ ``additional_auth_params`` (*dict*)
+ Set ``additional_parameters`` in
+ NativeAppAuthClient.oauth2_get_authorize_url()
+
+ **Login Parameters**
+ ``save_tokens`` (*bool*)
+ Save user tokens to disk and reload them on repeated calls.
+ Defaults to False.
+
+ ``check_tokens_expired`` (*bool*)
+ Check if loaded access tokens have expired since the last login.
+ You should set this to False if using Refresh Tokens.
+ Defaults to True.
+
+ ``force_login`` (*bool*)
+ Do not attempt to load save tokens, and complete a new auth flow
+ instead. Defaults to False.
+
+ ``no_local_server`` (*bool*)
+ Do not start a local server for fetching the auth_code. Setting
+ this to false will require the user to copy paste a code into
+ the console. Defaults to False.
+
+ ``no_browser`` (*bool*)
+ Do not automatically attempt to open a browser for the auth flow.
+ Defaults to False.
+
+ ``server_hostname`` (*string*)
+ Hostname for the local server to use. No effect if
+ ``no_local_server`` is set. MUST be specified in ``redirect_uri``.
+ Defaults to 127.0.0.1.
+
+ ``server_port`` (*string*)
+ Port for the local server to use. No effect if ``no_local_server``
+ is set. MUST be specified in ``redirect_uri``. Defaults to 8890.
+
+ **Configfile Parameters**
+ ``config_filename`` (*string*)
+ Filename to use for reading and writing values.
+
+ ``config_section`` (*string*)
+ Section within the config file to store information (like tokens).
+
+ **Examples**
+
+ ``native_auth()``
+
+ Or to save tokens: ``native_auth(save_tokens=True)``
+ """
+ unaccepted = [k for k in kwargs.keys() if k not in NATIVE_AUTH_DEFAULTS.keys()]
+ if any(unaccepted):
+ raise ValueError("Invalid args: {}".format(unaccepted))
+
+ opts = {k: kwargs.get(k, v) for k, v in NATIVE_AUTH_DEFAULTS.items()}
+
+ # Default to the auth-code page redirect if the user is copy-pasting
+ if (
+ opts["no_local_server"] is True
+ and opts["redirect_uri"] == NATIVE_AUTH_DEFAULTS["redirect_uri"]
+ ):
+ opts["redirect_uri"] = AUTH_CODE_REDIRECT
+
+ config_section = opts["config_section"] or opts["client_id"]
+
+ if opts["force_login"] is False:
+ try:
+ return load_tokens(
+ config_section, opts["requested_scopes"], opts["check_tokens_expired"]
+ )
+ except ConfigError as ce:
+ logger.debug(
+ "Loading Tokens Failed, doing auth flow instead. "
+ "Error: {}".format(ce)
+ )
+
+ # Clear previous tokens to ensure no previously saved scopes remain.
+ clear_tokens(config_section=config_section, client_id=opts["client_id"])
+
+ client = NativeAppAuthClient(client_id=opts["client_id"])
+ client.oauth2_start_flow(
+ requested_scopes=opts["requested_scopes"],
+ redirect_uri=opts["redirect_uri"],
+ refresh_tokens=opts["refresh_tokens"],
+ prefill_named_grant=opts["prefill_named_grant"],
+ )
+ url = client.oauth2_get_authorize_url(
+ additional_params=opts["additional_auth_params"]
+ )
+
+ if opts["no_local_server"] is False:
+ server_address = (opts["server_hostname"], opts["server_port"])
+ with start_local_server(listen=server_address) as server:
+ _prompt_login(url, opts["no_browser"])
+ auth_code = server.wait_for_code()
+ else:
+ _prompt_login(url, opts["no_browser"])
+ safe_print("Enter the resulting Authorization Code here: ", end="")
+ auth_code = input()
+
+ token_response = client.oauth2_exchange_code_for_tokens(auth_code)
+ tokens_by_resource_server = token_response.by_resource_server
+ if opts["save_tokens"] is True:
+ save_tokens(tokens_by_resource_server, config_section)
+ # return a set of tokens, organized by resource server name
+
+ return tokens_by_resource_server
+
+
+def _prompt_login(url, no_browser):
+ if no_browser is False and not is_remote_session():
+ webbrowser.open(url, new=1)
+ else:
+ safe_print("Please paste the following URL in a browser: " "\n{}".format(url))
diff --git a/globus_sdk/config.py b/globus_sdk/config.py
index 62c83f15d..350550be5 100644
--- a/globus_sdk/config.py
+++ b/globus_sdk/config.py
@@ -44,11 +44,13 @@ class GlobusConfigParser(object):
"""
_GENERAL_CONF_SECTION = "general"
+ DEFAULT_WRITE_PATH = os.path.expanduser("~/.globus-native-apps.cfg")
def __init__(self):
logger.debug("Loading SDK Config parser")
self._parser = ConfigParser()
self._load_config()
+ self._write_path = self.DEFAULT_WRITE_PATH
logger.debug("Config load succeeded")
def _load_config(self):
@@ -134,6 +136,86 @@ def get(
return value
+ def get_section(self, section):
+ """Attempt to lookup a section in the config file. Returns None
+ if no section is found."""
+ if self._parser.has_section(section):
+ return dict(self._parser.items(section))
+
+ def set_write_config_file(self, filename):
+ """Set a new config file for writing to disk. Attempts to load
+ the new config if one exists but will not raise an error if this
+ fails. Future config values will be written to this location."""
+ logger.debug("New config file set to: {}".format(filename))
+ try:
+ self._parser.read([filename])
+ except Exception:
+ logger.debug("New config failed to load: {}".format(filename))
+ self._write_path = filename
+
+ def _get_write_config(self):
+ """Get the config for the current configured self._write_path. If it
+ does not exist, an empty config is returned instead. General config
+ values will not be included in the returned config so they aren't
+ copied and written to disk.
+ """
+ if self._write_path is None:
+ raise GlobusError(
+ "Failed to write to the config file {}, please ensure you "
+ "have write access.".format(self._write_path)
+ )
+
+ cfg = ConfigParser()
+ if not os.path.exists(self._write_path):
+ cfg[self._GENERAL_CONF_SECTION] = {}
+ else:
+ cfg.read(self._write_path)
+
+ return cfg
+
+ def _save(self, cfg):
+ """Saves config options to disk at the file self._write_path. The
+ config file permissions are also always set to only allow User access
+ to the config file for a little bit of added security."""
+
+ # deny rwx to Group and World -- don't bother storing the returned
+ # old mask value, since we'll never restore it anyway
+ # do this on every call to ensure that we're always consistent about it
+ os.umask(0o077)
+ with open(self._write_path, "w") as configfile:
+ cfg.write(configfile)
+
+ def set(self, option, value, section):
+ """
+ Write an option to disk using the previously configured config
+ at set_config_file() or .globus.cfg. Creates the section if it does
+ not exist.
+ """
+ cfg = self._get_write_config()
+
+ # add the section if absent
+ if section not in cfg.sections():
+ cfg.add_section(section)
+
+ cfg.set(section, option, value)
+ self._save(cfg)
+
+ # Update the Global config
+ if section not in self._parser.sections():
+ self._parser.add_section(section)
+ self._parser.set(section, option, value)
+
+ def remove(self, option, section):
+ """
+ Remove an option from the config. True if option previously existed,
+ false otherwise.
+ """
+ cfg = self._get_write_config()
+ removed = cfg.remove_option(section, option)
+ self._save(cfg)
+ self._parser.remove_option(section, option)
+ return removed
+
def _get_parser():
"""
@@ -146,6 +228,14 @@ def _get_parser():
return _parser
+def get_parser():
+ """
+ Historically components only needed read-only access. Since token storage,
+ new components may need to lookup config values or occasionally save data
+ """
+ return _get_parser()
+
+
# at import-time, it's None
_parser = None
diff --git a/globus_sdk/exc.py b/globus_sdk/exc.py
index 1547ec7d1..b10537cdb 100644
--- a/globus_sdk/exc.py
+++ b/globus_sdk/exc.py
@@ -271,6 +271,22 @@ class GlobusConnectionError(NetworkError):
"""A connection error occured while making a REST request."""
+class ConfigError(GlobusError):
+ """An error reading or writing from the configuration file."""
+
+
+class RequestedScopesMismatch(ConfigError):
+ """Requested scopes differ from scopes saved to config."""
+
+
+class LoadedTokensExpired(ConfigError):
+ """Tokens loaded from disk have expired since last login."""
+
+
+class LocalServerError(GlobusError):
+ """Error encountered with local server used for native auth."""
+
+
def convert_request_exception(exc):
"""Converts incoming requests.Exception to a Globus NetworkError"""
diff --git a/globus_sdk/utils/local_server.py b/globus_sdk/utils/local_server.py
new file mode 100644
index 000000000..f7ba00673
--- /dev/null
+++ b/globus_sdk/utils/local_server.py
@@ -0,0 +1,153 @@
+import logging
+import os
+import sys
+import threading
+from contextlib import contextmanager
+from string import Template
+
+import six
+from six.moves import http_client, queue
+from six.moves.urllib.parse import parse_qsl, urlparse
+
+from globus_sdk.exc import LocalServerError
+
+try:
+ from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
+except ImportError:
+ from http.server import HTTPServer, BaseHTTPRequestHandler
+
+
+HTML_TEMPLATE = Template(
+ """
+
+
+
+
+
+ Globus SDK Login
+
+
+
+
+
+
+ Globus SDK
+
+ $login_result. You may close this tab.
+
+ $post_login_message
+
+
+
+
+"""
+)
+
+DOC_URL = """
+SDK Documentation
+"""
+
+
+def enable_requests_logging():
+ http_client.HTTPConnection.debuglevel = 4
+
+ logging.basicConfig()
+ logging.getLogger().setLevel(logging.DEBUG)
+ requests_log = logging.getLogger("requests.packages.urllib3")
+ requests_log.setLevel(logging.DEBUG)
+ requests_log.propagate = True
+
+
+def is_remote_session():
+ return os.environ.get("SSH_TTY", os.environ.get("SSH_CONNECTION"))
+
+
+class RedirectHandler(BaseHTTPRequestHandler):
+ def do_GET(self): # noqa
+ self.send_response(200)
+ self.send_header("Content-type", "text/html")
+ self.end_headers()
+
+ query_params = dict(parse_qsl(urlparse(self.path).query))
+ code = query_params.get("code")
+ if code:
+ self.wfile.write(
+ six.b(
+ HTML_TEMPLATE.substitute(
+ post_login_message=DOC_URL, login_result="Login successful"
+ )
+ )
+ )
+ self.server.return_code(code)
+ else:
+ msg = query_params.get("error_description", query_params.get("error"))
+
+ self.wfile.write(
+ six.b(
+ HTML_TEMPLATE.substitute(
+ post_login_message=msg, login_result="Login failed"
+ )
+ )
+ )
+
+ self.server.return_code(LocalServerError(msg))
+
+ def log_message(self, format, *args):
+ return
+
+
+class RedirectHTTPServer(HTTPServer, object):
+ def __init__(self, listen, handler_class):
+ super(RedirectHTTPServer, self).__init__(listen, handler_class)
+
+ self._auth_code_queue = queue.Queue()
+
+ def handle_error(self, request, client_address):
+ exctype, excval, exctb = sys.exc_info()
+ self._auth_code_queue.put(excval)
+
+ def return_code(self, code):
+ self._auth_code_queue.put_nowait(code)
+
+ def wait_for_code(self):
+ # workaround for handling control-c interrupt.
+ # relevant Python issue discussing this behavior:
+ # https://bugs.python.org/issue1360
+ try:
+ return self._auth_code_queue.get(block=True, timeout=3600)
+ except (queue.Empty, KeyboardInterrupt):
+ raise LocalServerError()
+ finally:
+ # shutdown() stops the server thread
+ # https://github.com/python/cpython/blob/3.7/Lib/socketserver.py#L241
+ self.shutdown()
+ # server_close() closes the socket:
+ # https://github.com/python/cpython/blob/3.7/Lib/socketserver.py#L474
+ self.server_close()
+
+
+@contextmanager
+def start_local_server(listen=("", 0)):
+ server = RedirectHTTPServer(listen, RedirectHandler)
+ thread = threading.Thread(target=server.serve_forever)
+ thread.daemon = True
+ thread.start()
+
+ yield server
diff --git a/globus_sdk/utils/safeio.py b/globus_sdk/utils/safeio.py
new file mode 100644
index 000000000..627fbafd8
--- /dev/null
+++ b/globus_sdk/utils/safeio.py
@@ -0,0 +1,34 @@
+from __future__ import print_function
+
+import sys
+
+
+class SafeIO(object):
+ """SafeIO allows developers to change how the SDK prints output strings
+ to the user. By default, it provides a generic 'write()' method for
+ printing strings to stdout, but can be changed if needed."""
+
+ def write(self, message, *args, **kwargs):
+ print_kwargs = {
+ k: arg for k, arg in kwargs.items() if k in ("sep", "end", "file", "flush")
+ }
+ print_kwargs["file"] = print_kwargs.get("file") or sys.stderr
+ messages = [message] + list(args)
+ print(*messages, **print_kwargs)
+
+ def set_write_function(self, func):
+ setattr(self, "write", func)
+
+
+_safe_io = None
+
+
+def get_safe_io():
+ global _safe_io
+ if _safe_io is None:
+ _safe_io = SafeIO()
+ return _safe_io
+
+
+def safe_print(message, *args, **kwargs):
+ get_safe_io().write(message, *args, **kwargs)
diff --git a/globus_sdk/utils/token_storage.py b/globus_sdk/utils/token_storage.py
new file mode 100644
index 000000000..518b3986b
--- /dev/null
+++ b/globus_sdk/utils/token_storage.py
@@ -0,0 +1,217 @@
+import logging
+import time
+
+from globus_sdk.auth.client_types.native_client import NativeAppAuthClient
+from globus_sdk.config import get_parser
+from globus_sdk.exc import ConfigError, LoadedTokensExpired, RequestedScopesMismatch
+
+TOKEN_KEYS = [
+ "scope",
+ "access_token",
+ "refresh_token",
+ "token_type",
+ "expires_at_seconds",
+ "resource_server",
+]
+REQUIRED_KEYS = ["scope", "access_token", "expires_at_seconds", "resource_server"]
+CONFIG_TOKEN_GROUPS = "token_groups"
+
+logger = logging.getLogger(__name__)
+
+
+def save_tokens(tokens, config_section=None):
+ """
+ Save a dict of tokens in config_section, for the current configfile.
+ Tokens should be formatted like the following:
+ {
+ "auth.globus.org": {
+ "scope": "profile openid email",
+ "access_token": "",
+ "refresh_token": None,
+ "token_type": "Bearer",
+ "expires_at_seconds": 1539984535,
+ "resource_server": "auth.globus.org"
+ }, ...
+ }
+ """
+ config = get_parser()
+
+ cfg_tokens = _serialize_token_groups(tokens)
+ for key, value in cfg_tokens.items():
+ config.set(key, value, section=config_section)
+
+
+def load_tokens(config_section=None, requested_scopes=(), check_expired=True):
+ """
+ Load Tokens from a config section in the configfile. If requested_scopes
+ is given, it will match against the loaded scopes and raise a
+ RequestedScopesMismatch exception if they differ from one another.
+
+ check_expired will check the expires_at_seconds number against the time
+ the user last logged in, and raise LoadedTokensExpired if it is greater.
+ check_expired should be set to false if you want to use refresh tokens.
+
+ Returns tokens in a similar format to token_response.by_resource_server:
+ {
+ "auth.globus.org": {
+ "scope": "profile openid email",
+ "access_token": "",
+ "refresh_token": None,
+ "token_type": "Bearer",
+ "expires_at_seconds": 1539984535,
+ "resource_server": "auth.globus.org"
+ }, ...
+ }
+ """
+
+ config = get_parser()
+ try:
+ cfg_tokens = config.get_section(config_section)
+ loaded_tokens = _deserialize_token_groups(cfg_tokens)
+ except Exception:
+ raise ConfigError("Error loading tokens from: {}".format(config_section))
+
+ for tok_set in loaded_tokens.values():
+ missing = [mk for mk in REQUIRED_KEYS if not tok_set.get(mk)]
+ if any(missing):
+ raise ConfigError("Missing {} from loaded tokens".format(missing))
+
+ if requested_scopes:
+ scope_lists = [t["scope"].split() for t in loaded_tokens.values()]
+ loaded_scopes = {s for slist in scope_lists for s in slist}
+ if loaded_scopes.difference(set(requested_scopes)):
+ raise RequestedScopesMismatch(
+ "Requested Scopes differ from loaded scopes. Requested: "
+ "{}, Loaded: {}".format(requested_scopes, list(loaded_scopes))
+ )
+
+ if check_expired is True:
+ expired = [
+ time.time() >= t["expires_at_seconds"] for t in loaded_tokens.values()
+ ]
+ if any(expired):
+ raise LoadedTokensExpired()
+
+ return loaded_tokens
+
+
+def clear_tokens(config_section=None, client_id=None):
+ """Revokes and deletes tokens saved to disk. ``config_section`` is the
+ section where the tokens are stored, ``client_id`` must be a valid Globus
+ App. Returns True if tokens were revoked (or expired) and deleted, false
+ otherwise. Raises globus_sdk.exc.AuthAPIError if tokens are live and
+ client_id is invalid.
+ """
+ tokens = []
+ try:
+ naac = NativeAppAuthClient(client_id)
+ tokens = load_tokens(config_section=config_section, check_expired=True)
+ for tok_set in tokens.values():
+ logger.debug("Revoking: {}".format(tok_set["resource_server"]))
+ naac.oauth2_revoke_token(tok_set["access_token"])
+ except LoadedTokensExpired:
+ # If they expired, no need to revoke but fetch again for deletion
+ tokens = load_tokens(config_section=config_section, check_expired=False)
+ except ConfigError as ce:
+ logger.debug(ce)
+
+ if not tokens:
+ return False
+
+ cfg_tsets = _serialize_token_groups(tokens)
+ config = get_parser()
+ for cfg_token_name in cfg_tsets.keys():
+ config.remove(cfg_token_name, section=config_section)
+ config.remove(CONFIG_TOKEN_GROUPS, section=config_section)
+ return True
+
+
+def _serialize_token_groups(tokens):
+ """
+ Take a dict of tokens organized by resource server and return a dict
+ that can be easily saved to the config file.
+
+ Resource servers containing '.' in their name will automatically be
+ converted to '_' (auth.globus.org == auth_globus_org). This is only for
+ cosmetic reasons. A resource server named "foo=;# = !@#$%^&*()" will have
+ funky looking config keys, but saving/loading will behave normally.
+
+ Int values are converted to string, None values are converted to empty
+ string. *No other types are checked*.
+
+ `tokens` should be formatted:
+ {
+ "auth.globus.org": {
+ "scope": "profile openid email",
+ "access_token": "",
+ "refresh_token": None,
+ "token_type": "Bearer",
+ "expires_at_seconds": 1539984535,
+ "resource_server": "auth.globus.org"
+ }, ...
+ }
+ Returns a flat dict of tokens prefixed by resource server.
+ {
+ "auth_globus_org_scope": "profile openid email",
+ "auth_globus_org_access_token": "",
+ "auth_globus_org_refresh_token": "",
+ "auth_globus_org_token_type": "Bearer",
+ "auth_globus_org_expires_at_seconds": "1540051101",
+ "auth_globus_org_resource_server": "auth.globus.org",
+ "token_groups": "auth_globus_org"
+ }"""
+ serialized_items = {}
+ token_groups = []
+ for token_set in tokens.values():
+ token_groups.append(_serialize_token(token_set["resource_server"]))
+ for key, value in token_set.items():
+ key_name = _serialize_token(token_set["resource_server"], key)
+ if isinstance(value, int):
+ value = str(value)
+ if value is None:
+ value = ""
+ serialized_items[key_name] = value
+
+ serialized_items[CONFIG_TOKEN_GROUPS] = ",".join(token_groups)
+ return serialized_items
+
+
+def _deserialize_token_groups(config_items):
+ """
+ Takes a dict from a config section and returns a dict of tokens by
+ resource server. `config_items` is a raw dict of config options returned
+ from get_parser().get_section().
+
+ Returns tokens in the format:
+ {
+ "auth.globus.org": {
+ "scope": "profile openid email",
+ "access_token": "",
+ "refresh_token": None,
+ "token_type": "Bearer",
+ "expires_at_seconds": 1539984535,
+ "resource_server": "auth.globus.org"
+ }, ...
+ }
+ """
+ token_groups = {}
+
+ tsets = config_items.get(CONFIG_TOKEN_GROUPS)
+ config_token_groups = tsets.split(",")
+ for group in config_token_groups:
+ tset = {k: config_items.get(_deserialize_token(group, k)) for k in TOKEN_KEYS}
+ tset["expires_at_seconds"] = int(tset["expires_at_seconds"])
+ # Config loaded 'null' values will be an empty string. Set these to
+ # None for consistency
+ tset = {k: v if v else None for k, v in tset.items()}
+ token_groups[tset["resource_server"]] = tset
+
+ return token_groups
+
+
+def _deserialize_token(grouping, token):
+ return "{}{}".format(grouping, token)
+
+
+def _serialize_token(resource_server, token=""):
+ return "{}_{}".format(resource_server.replace(".", "_"), token)
diff --git a/tests/files/sample_configs/set_test.cfg b/tests/files/sample_configs/set_test.cfg
new file mode 100644
index 000000000..3585e0f21
--- /dev/null
+++ b/tests/files/sample_configs/set_test.cfg
@@ -0,0 +1,2 @@
+[default]
+option = general_value
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 000000000..319f38df5
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,15 @@
+import tempfile
+
+import pytest
+
+import globus_sdk
+
+
+@pytest.yield_fixture
+def temp_config():
+ globus_sdk.config._parser = None
+ temp_config = tempfile.NamedTemporaryFile()
+ cfg = globus_sdk.config.get_parser()
+ cfg.set_write_config_file(temp_config.name)
+ yield cfg
+ temp_config.close()
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index a84c353c2..a388ebec2 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -1,5 +1,6 @@
import os
from contextlib import contextmanager
+from tempfile import NamedTemporaryFile
import pytest
import six
@@ -282,3 +283,38 @@ def test_get_globus_environ_production():
del os.environ["GLOBUS_SDK_ENVIRONMENT"]
# ensure that passing a value returns that value
assert globus_sdk.config.get_globus_environ("production") == "default"
+
+
+def test_verify_set_config_file():
+ new_config = NamedTemporaryFile()
+ with custom_config(""):
+ cfg = globus_sdk.config.get_parser()
+ cfg.set_write_config_file(new_config.name)
+ assert cfg._write_path == new_config.name
+
+
+def test_verify_load_from_new_config_file():
+ with custom_config(""), NamedTemporaryFile(mode="w+") as new_cfg:
+ new_cfg.file.write("[default]\noption = general_value\n")
+ new_cfg.file.flush()
+
+ cfg = globus_sdk.config.get_parser()
+ cfg.set_write_config_file(new_cfg.name)
+ assert cfg.get("option", "default") == "general_value"
+
+
+def test_verify_write_config_option(temp_config):
+ with custom_config(""):
+ temp_config.set("foo", "bar", "mysec")
+ assert temp_config.get("foo", "mysec") == "bar"
+
+ temp_config.set("baz", "car", "new_sec")
+ assert temp_config.get("baz", "new_sec") == "car"
+
+
+def test_verify_remove_config_option(temp_config):
+ with custom_config(""):
+ temp_config.set("foo", "bar", "mysec")
+ assert temp_config.get("foo", "mysec") == "bar"
+ temp_config.remove("foo", "mysec")
+ assert temp_config.get("foo", "mysec") is None
diff --git a/tests/unit/test_config_saving.py b/tests/unit/test_config_saving.py
new file mode 100644
index 000000000..8f78b0436
--- /dev/null
+++ b/tests/unit/test_config_saving.py
@@ -0,0 +1,51 @@
+# """
+# Test Config Saving produces expected results.
+# """
+#
+# import os
+# from tempfile import NamedTemporaryFile
+#
+# import globus_sdk.config
+#
+# from tests.framework import get_fixture_file_dir, CapturedIOTestCase
+#
+# SET_CFG = os.path.join(get_fixture_file_dir(), 'sample_configs',
+# 'set_test.cfg')
+#
+#
+# class ConfigSaveTests(CapturedIOTestCase):
+#
+# def setUp(self):
+# self.parser = globus_sdk.config.get_parser()
+# self.config = NamedTemporaryFile()
+#
+# def tearDown(self):
+# globus_sdk.config._parser = None
+# self.config.close()
+#
+# def test_verify_set_config_file(self):
+# self.parser.set_write_config_file(self.config.name)
+# assert self.parser._write_path == self.config.name
+#
+# def test_verify_load_from_new_config_file(self):
+# with open(SET_CFG) as ch, NamedTemporaryFile(mode='w+') as new_cfg:
+# new_cfg.file.write(ch.read())
+# new_cfg.file.flush()
+#
+# self.parser.set_write_config_file(new_cfg.name)
+# assert self.parser.get('option', 'default') == 'general_value'
+#
+# def test_verify_write_config_option(self):
+# self.parser.set_write_config_file(self.config.name)
+# self.parser.set('foo', 'bar', 'mysec')
+# assert self.parser.get('foo', 'mysec') == 'bar'
+#
+# self.parser.set('baz', 'car', 'new_sec')
+# assert self.parser.get('baz', 'new_sec') == 'car'
+#
+# def test_verify_remove_config_option(self):
+# self.parser.set_write_config_file(self.config.name)
+# self.parser.set('foo', 'bar', 'mysec')
+# assert self.parser.get('foo', 'mysec') == 'bar'
+# self.parser.remove('foo', 'mysec')
+# assert self.parser.get('foo', 'mysec') is None
diff --git a/tests/unit/test_native_auth.py b/tests/unit/test_native_auth.py
new file mode 100644
index 000000000..868dc0eab
--- /dev/null
+++ b/tests/unit/test_native_auth.py
@@ -0,0 +1,254 @@
+import copy
+import uuid
+import webbrowser
+from time import time
+
+import pytest
+
+import globus_sdk
+from globus_sdk.auth import oauth2_native_app_shortcut
+from globus_sdk.auth.oauth2_native_app_shortcut import (
+ AUTH_CODE_REDIRECT,
+ NATIVE_AUTH_DEFAULTS as NA_DEF,
+ native_auth,
+)
+from globus_sdk.utils import token_storage
+
+try:
+ import mock
+except ImportError:
+ from unittest import mock
+
+MOCK_TOKENS = {
+ "auth.globus.org": {
+ "scope": "profile openid email",
+ "access_token": "9d0e6f2a21917cc3e04602838e0ba4f7df3399bbd49f1"
+ "5db3cf0af34d52c928f34f639444af0b28695086d97b1",
+ "refresh_token": None,
+ "token_type": "Bearer",
+ "expires_at_seconds": int(time()) + 60 * 60,
+ "resource_server": "auth.globus.org",
+ }
+}
+
+MOCK_AUTH_CODE = "foobarbaz"
+
+
+@pytest.fixture
+def mock_webbrowser(monkeypatch):
+ monkeypatch.setattr(webbrowser, "open", mock.Mock())
+
+
+@pytest.fixture
+def mock_save_tokens(monkeypatch):
+ mock_save = mock.Mock()
+ monkeypatch.setattr(oauth2_native_app_shortcut, "save_tokens", mock_save)
+ return mock_save
+
+
+@pytest.fixture
+def mock_clear_tokens(monkeypatch):
+ mocked_clear_tokens = mock.Mock()
+ monkeypatch.setattr(oauth2_native_app_shortcut, "clear_tokens", mocked_clear_tokens)
+ return mocked_clear_tokens
+
+
+@pytest.yield_fixture
+def mock_local_server(monkeypatch):
+ mock_server = mock.Mock()
+ mock_server.__enter__ = mock.Mock(return_value=mock.Mock())
+ mock_server.__exit__ = mock.Mock(return_value=mock.Mock())
+ mock_start = mock.Mock(return_value=mock_server)
+ monkeypatch.setattr(
+ "globus_sdk.auth.oauth2_native_app_shortcut.start_local_server", mock_start
+ )
+ return mock_start
+
+
+@pytest.fixture
+def mock_native_client(monkeypatch):
+ mock_client = mock.Mock()
+ mock_class = mock.Mock(return_value=mock_client)
+ token_response = mock.Mock()
+ mock_client.oauth2_exchange_code_for_tokens.return_value = token_response
+ token_response.by_resource_server = MOCK_TOKENS
+ monkeypatch.setattr(
+ "globus_sdk.auth.oauth2_native_app_shortcut.NativeAppAuthClient", mock_class
+ )
+ return mock_class, mock_client
+
+
+@pytest.fixture
+def mock_native_client_simple(mock_native_client):
+ client_class, _ = mock_native_client
+ return client_class
+
+
+@pytest.yield_fixture
+def saved_tokens(temp_config, mock_native_client):
+ token_storage.save_tokens(MOCK_TOKENS, NA_DEF["client_id"])
+
+
+@pytest.fixture
+def expired_saved_tokens(temp_config):
+ expired = copy.deepcopy(MOCK_TOKENS)
+ expired["auth.globus.org"]["expires_at_seconds"] = int(time()) - 1
+ token_storage.save_tokens(expired, NA_DEF["client_id"])
+
+
+@pytest.fixture
+def mock_input(monkeypatch):
+ mock_input = mock.Mock()
+ monkeypatch.setattr("globus_sdk.auth.oauth2_native_app_shortcut.input", mock_input)
+ return mock_input
+
+
+@pytest.fixture
+def mock_safe_print(monkeypatch):
+ mock_print = mock.Mock()
+ monkeypatch.setattr("globus_sdk.auth.oauth2_native_app_shortcut.input", mock_print)
+ return mock_print
+
+
+def test_native_auth(
+ mock_webbrowser, mock_local_server, mock_native_client, mock_save_tokens
+):
+ native_app_class, native_app_client = mock_native_client
+
+ native_auth()
+
+ native_app_class.assert_called_with(client_id=NA_DEF["client_id"])
+ native_app_client.oauth2_start_flow.assert_called_with(
+ requested_scopes=NA_DEF["requested_scopes"],
+ redirect_uri=NA_DEF["redirect_uri"],
+ refresh_tokens=NA_DEF["refresh_tokens"],
+ prefill_named_grant=NA_DEF["prefill_named_grant"],
+ )
+ native_app_client.oauth2_get_authorize_url.assert_called_with(additional_params={})
+ mock_local_server.assert_called_with(
+ listen=(NA_DEF["server_hostname"], NA_DEF["server_port"])
+ )
+ assert native_app_client.oauth2_exchange_code_for_tokens.called
+ assert not mock_save_tokens.called
+ assert webbrowser.open.called
+
+
+def test_invalid_option_raises_error(mock_native_client, mock_local_server):
+ with pytest.raises(ValueError):
+ native_auth(conquer_the_world=True)
+
+
+def test_native_auth_saving_tokens(
+ mock_save_tokens, mock_native_client, mock_local_server, temp_config
+):
+ native_auth(save_tokens=True)
+ assert mock_save_tokens.called
+
+
+def test_native_auth_loading_tokens(
+ mock_native_client_simple, mock_local_server, saved_tokens
+):
+ native_auth()
+ # assert tokens were loaded and a native flow was not started
+ assert not mock_native_client_simple.called
+
+
+def test_native_auth_force_login(
+ mock_native_client_simple, mock_local_server, saved_tokens, mock_clear_tokens
+):
+ # Should disregard previously saved tokens
+ native_auth(force_login=True)
+ assert mock_native_client_simple.called
+ assert mock_clear_tokens.called
+
+
+def test_native_auth_requested_scope_check(
+ mock_native_client_simple, mock_local_server, saved_tokens, mock_clear_tokens
+):
+ # Scopes here are different than what was saved, so this should
+ # trigger an auth flow
+ native_auth(requested_scopes=("foo",))
+ assert mock_native_client_simple.called
+ assert mock_clear_tokens.called
+
+
+def test_native_auth_expired_token_check(
+ mock_native_client_simple, mock_local_server, expired_saved_tokens
+):
+ native_auth()
+ assert mock_native_client_simple.called
+
+
+def test_native_auth_expired_token_no_check(
+ mock_native_client_simple, mock_local_server, expired_saved_tokens
+):
+ native_auth(check_tokens_expired=False)
+ assert not mock_native_client_simple.called
+
+
+def test_native_auth_no_local_server(
+ mock_local_server, mock_native_client, temp_config, mock_input, mock_safe_print
+):
+ native_app_class, native_app_client = mock_native_client
+
+ native_auth(no_local_server=True)
+ native_app_client.oauth2_start_flow.assert_called_with(
+ requested_scopes=NA_DEF["requested_scopes"],
+ redirect_uri=AUTH_CODE_REDIRECT,
+ refresh_tokens=NA_DEF["refresh_tokens"],
+ prefill_named_grant=NA_DEF["prefill_named_grant"],
+ )
+ assert native_app_class.called
+ assert not mock_local_server.called
+ assert mock_safe_print.called
+ assert not mock_input.called
+
+
+def test_native_auth_no_browser(
+ mock_webbrowser, mock_local_server, mock_native_client_simple, mock_safe_print
+):
+ native_auth(no_browser=True)
+ # Assert a native flow was not started
+ assert mock_native_client_simple.called
+ assert mock_local_server.called
+ assert not webbrowser.open.called
+
+
+def test_native_auth_custom_config_section(
+ mock_native_client_simple, mock_local_server, temp_config
+):
+ my_section = "my_section"
+ native_auth(config_section="my_section", save_tokens=True)
+ assert my_section in globus_sdk.config.get_parser()._parser.sections()
+
+
+def test_native_auth_ancillary_options(
+ mock_webbrowser, mock_local_server, mock_native_client, mock_save_tokens
+):
+ """Options here don't change the control flow and should not affect
+ one another. This test asserts they're present in expected places"""
+ native_class, native_client = mock_native_client
+ options = {
+ "client_id": str(uuid.uuid4()),
+ "redirect_uri": "http://example.com/login",
+ "requested_scopes": ("myscope", "myotherscope"),
+ "refresh_tokens": True,
+ "prefill_named_grant": "Captain Hammer's Lenovo",
+ "additional_auth_params": {"session_message": "hello!"},
+ "server_hostname": "localhost",
+ "server_port": 9999,
+ }
+ native_auth(**options)
+ native_class.assert_called_with(client_id=options["client_id"])
+ native_client.oauth2_start_flow.assert_called_with(
+ requested_scopes=options["requested_scopes"],
+ redirect_uri=options["redirect_uri"],
+ refresh_tokens=options["refresh_tokens"],
+ prefill_named_grant=options["prefill_named_grant"],
+ )
+ native_client.oauth2_get_authorize_url.assert_called_with(
+ additional_params=options["additional_auth_params"]
+ )
+ mock_local_server.assert_called_with(
+ listen=(options["server_hostname"], options["server_port"])
+ )
diff --git a/tests/unit/test_utils_local_server.py b/tests/unit/test_utils_local_server.py
new file mode 100644
index 000000000..f1cc74cb4
--- /dev/null
+++ b/tests/unit/test_utils_local_server.py
@@ -0,0 +1,59 @@
+import threading
+
+import httpretty
+import pytest
+import requests
+from six.moves.urllib.parse import urlencode
+
+from globus_sdk.exc import LocalServerError
+from globus_sdk.utils.local_server import start_local_server
+
+
+class LocalServerTester:
+ def __init__(self):
+ self.server_response = None
+
+ def _wait_for_code(self, server):
+ try:
+ self.server_response = server.wait_for_code()
+ except Exception as e:
+ self.server_response = e
+
+ def test(self, response_params):
+ """
+ Start a local server to wait for an 'auth_code'. Usually the user's
+ browser will redirect to this location, but in this case the user is
+ mocked with a separate request in another thread.
+
+ Waits for threads to complete and returns the local_server response.
+ """
+ with start_local_server() as server:
+ thread = threading.Thread(target=self._wait_for_code, args=(server,))
+ thread.start()
+ host, port = server.server_address
+ url = "http://{}:{}/?{}".format(
+ "127.0.0.1", port, urlencode(response_params)
+ )
+ requests.get(url)
+ thread.join()
+ return self.server_response
+
+
+@pytest.yield_fixture
+def test_server():
+ httpretty.disable()
+ yield LocalServerTester()
+ httpretty.enable()
+
+
+def test_local_server_with_auth_code(test_server):
+ MOCK_AUTH_CODE = (
+ "V2UgY2FuJ3Qgd2FpdCBmb3IgY29kZXMgZm9yZXZlci4g"
+ "V2VsbCwgd2UgY2FuIGJ1dCBJIGRvbid0IHdhbnQgdG8u"
+ )
+ assert test_server.test({"code": MOCK_AUTH_CODE}) == MOCK_AUTH_CODE
+
+
+def test_local_server_with_error(test_server):
+ response = test_server.test({"error": "bad things happened"})
+ assert isinstance(response, LocalServerError)
diff --git a/tests/unit/test_utils_safe_io.py b/tests/unit/test_utils_safe_io.py
new file mode 100644
index 000000000..592b0ac7d
--- /dev/null
+++ b/tests/unit/test_utils_safe_io.py
@@ -0,0 +1,27 @@
+from tempfile import NamedTemporaryFile
+
+from globus_sdk.utils import safeio
+
+try:
+ import mock
+except ImportError:
+ from unittest import mock
+
+
+def test_safe_print_custom_output():
+ my_log_file = NamedTemporaryFile()
+
+ def my_logger(message):
+ with open(my_log_file.name, "w+") as lfh:
+ lfh.write(message)
+
+ safeio.get_safe_io().set_write_function(my_logger)
+ safeio.safe_print("The hamsters are attacking!")
+ with open(my_log_file.name) as lfh:
+ assert lfh.read() == "The hamsters are attacking!"
+
+
+def test_safe_print_normally():
+ with mock.patch("globus_sdk.utils.safeio._safe_io") as sio:
+ safeio.safe_print("foo")
+ sio.write.assert_called_once_with("foo")
diff --git a/tests/unit/test_utils_token_storage.py b/tests/unit/test_utils_token_storage.py
new file mode 100644
index 000000000..4f66b8c9f
--- /dev/null
+++ b/tests/unit/test_utils_token_storage.py
@@ -0,0 +1,117 @@
+import copy
+from time import time
+
+import pytest
+
+from globus_sdk import config
+from globus_sdk.exc import ConfigError, LoadedTokensExpired, RequestedScopesMismatch
+from globus_sdk.utils.token_storage import clear_tokens, load_tokens, save_tokens
+
+try:
+ import mock
+except ImportError:
+ from unittest import mock
+
+TOKEN_LIFETIME = 60 * 60 * 24
+
+MOCK_TOKENS = {
+ "auth.globus.org": {
+ "scope": "profile openid email",
+ "access_token": "9d0e6f2a21917cc3e04602838e0ba4f7df3399bbd49f1"
+ "5db3cf0af34d52c928f34f639444af0b28695086d97b1",
+ "refresh_token": None,
+ "token_type": "Bearer",
+ "expires_at_seconds": int(time()) + TOKEN_LIFETIME,
+ "resource_server": "auth.globus.org",
+ },
+ "workhorse.org": {
+ "scope": "all",
+ "access_token": "QmFkIEhvcnNlLCBCYWQgSG9yc2UuIEJhZCBIb3JzZSwg"
+ "QmFkIEhvcnNlLiBIZSByaWRlcyBhY3Jvc3MgdGhlIG5h",
+ "refresh_token": "VGhlIGV2aWwgbGVhZ3VlIG9mIGV2aWwsIGlzIHdhdGNo"
+ "aW5nIHNvIGJld2FyZS4gVGhlIGdyYWRlIHRoYXQgeW8=",
+ "token_type": "Bearer",
+ "expires_at_seconds": int(time()) + TOKEN_LIFETIME,
+ "resource_server": "workhorse.org",
+ },
+}
+
+
+@pytest.fixture
+def mock_expired_tokens():
+ expired_tokens = copy.deepcopy(MOCK_TOKENS)
+ for _, token_set in expired_tokens.items():
+ token_set["expires_at_seconds"] = int(time()) - 1
+ return expired_tokens
+
+
+@pytest.fixture
+def mock_native_app(monkeypatch):
+ mock_client = mock.MagicMock()
+ monkeypatch.setattr(
+ "globus_sdk.utils.token_storage.NativeAppAuthClient", mock_client
+ )
+ return mock_client
+
+
+def test_save_and_load_tokens_matches_original(temp_config):
+ save_tokens(MOCK_TOKENS, "test")
+ tokens = load_tokens("test")
+ for set_name, set_values in tokens.items():
+ loaded_set = set(set_values.values())
+ mock_set = set(MOCK_TOKENS[set_name].values())
+ assert not loaded_set.difference(mock_set)
+
+
+def test_loading_bad_tokens_raises_error(temp_config):
+ save_tokens(MOCK_TOKENS, "test")
+ temp_config.remove("workhorse_org_access_token", "test")
+ with pytest.raises(ConfigError):
+ load_tokens("test")
+
+
+def test_loading_raises_error_if_tokens_expire(temp_config, mock_expired_tokens):
+ save_tokens(mock_expired_tokens, "test")
+ with pytest.raises(LoadedTokensExpired):
+ load_tokens("test")
+
+
+def test_loading_raises_error_if_scopes_differ(temp_config):
+ save_tokens(MOCK_TOKENS, "test")
+ transfer_scope = ("urn:globus:auth:scope:transfer.api.globus.org:all",)
+ with pytest.raises(RequestedScopesMismatch):
+ load_tokens("test", requested_scopes=transfer_scope)
+
+
+def test_verify_clear_tokens(temp_config, mock_native_app):
+ save_tokens(MOCK_TOKENS, "test")
+ section = config.get_parser().get_section("test")
+ assert len(section.values()) == 13
+ return_value = clear_tokens("test", "my_client_id")
+ section = config.get_parser().get_section("test")
+ assert len(section.values()) == 0
+ mock_native_app.assert_called_with("my_client_id")
+ assert return_value is True
+
+
+def test_clear_tokens_with_no_saved_tokens(temp_config, mock_native_app):
+ return_value = clear_tokens("test", "my_client_id")
+ assert return_value is False
+
+
+def test_clear_expired_tokens(temp_config, mock_expired_tokens, mock_native_app):
+ save_tokens(mock_expired_tokens, "test")
+ section = config.get_parser().get_section("test")
+ assert len(section.values()) == 13
+ return_value = clear_tokens("test", "my_client_id")
+ section = config.get_parser().get_section("test")
+ assert len(section.values()) == 0
+ mock_native_app.assert_called_with("my_client_id")
+ assert return_value is True
+
+
+def test_clear_tokens_with_invalid_client_raises_error(temp_config):
+ save_tokens(MOCK_TOKENS, "test")
+ config.get_parser().remove("workhorse_org_access_token", "test")
+ with pytest.raises(ConfigError):
+ load_tokens("test")