diff --git a/src/globus_sdk/_guards.py b/src/globus_sdk/_guards.py new file mode 100644 index 000000000..8d80e943e --- /dev/null +++ b/src/globus_sdk/_guards.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import sys +import typing as t + +if sys.version_info >= (3, 10): + from typing import TypeGuard +else: + from typing_extensions import TypeGuard + +T = t.TypeVar("T") + + +def is_list_of(data: t.Any, typ: type[T]) -> TypeGuard[list[T]]: + return isinstance(data, list) and all(isinstance(item, typ) for item in data) + + +def is_optional(data: t.Any, typ: type[T]) -> TypeGuard[T | None]: + return data is None or isinstance(data, typ) + + +def is_optional_list_of(data: t.Any, typ: type[T]) -> TypeGuard[list[T] | None]: + return data is None or ( + isinstance(data, list) and all(isinstance(item, typ) for item in data) + ) diff --git a/src/globus_sdk/exc/api.py b/src/globus_sdk/exc/api.py index 340bfed3c..2dc275896 100644 --- a/src/globus_sdk/exc/api.py +++ b/src/globus_sdk/exc/api.py @@ -6,6 +6,8 @@ import requests +from globus_sdk import _guards + from .base import GlobusError from .err_info import ErrorInfoContainer from .warnings import warn_deprecated @@ -273,7 +275,7 @@ def _detect_error_format(self) -> _ErrorFormat: # well-formed if self._jsonapi_mimetype(): errors = self._dict_data.get("errors") - if not isinstance(errors, list): + if not _guards.is_list_of(errors, dict): return _ErrorFormat.undefined elif len(errors) < 1: return _ErrorFormat.undefined @@ -321,9 +323,7 @@ def _parse_type_zero_error_format(self) -> bool: self.code = self._dict_data["code"] self.messages = [self._dict_data["message"]] self.request_id = self._dict_data.get("request_id") - if isinstance(self._dict_data.get("errors"), list) and all( - isinstance(subdoc, dict) for subdoc in self._dict_data["errors"] - ): + if _guards.is_list_of(self._dict_data.get("errors"), dict): raw_errors = self._dict_data["errors"] else: raw_errors = [self._dict_data] @@ -341,9 +341,7 @@ def _parse_undefined_error_format(self) -> bool: """ # attempt to pull out errors if possible and valid - if isinstance(self._dict_data.get("errors"), list) and all( - isinstance(subdoc, dict) for subdoc in self._dict_data["errors"] - ): + if _guards.is_list_of(self._dict_data.get("errors"), dict): raw_errors = self._dict_data["errors"] # if no 'errors' were found, or 'errors' is invalid, then # 'errors' should be set to contain the root document diff --git a/src/globus_sdk/exc/err_info.py b/src/globus_sdk/exc/err_info.py index 7e0bd002a..3ccc1a2d8 100644 --- a/src/globus_sdk/exc/err_info.py +++ b/src/globus_sdk/exc/err_info.py @@ -1,21 +1,13 @@ from __future__ import annotations import logging -import sys import typing as t -if sys.version_info < (3, 10): - from typing_extensions import TypeGuard -else: - from typing import TypeGuard +from globus_sdk import _guards log = logging.getLogger(__name__) -def _is_list_of_strs(obj: t.Any) -> TypeGuard[list[str]]: - return isinstance(obj, list) and all(isinstance(item, str) for item in obj) - - class ErrorInfo: """ Errors may contain "containers" of data which are testable (define ``__bool__``). @@ -98,7 +90,7 @@ def _parse_session_required_identities( self, data: dict[str, t.Any] ) -> list[str] | None: session_required_identities = data.get("session_required_identities") - if _is_list_of_strs(session_required_identities): + if _guards.is_list_of(session_required_identities, str): return session_required_identities elif session_required_identities is not None: self._warn_type( @@ -112,7 +104,7 @@ def _parse_session_required_single_domain( self, data: dict[str, t.Any] ) -> list[str] | None: session_required_single_domain = data.get("session_required_single_domain") - if _is_list_of_strs(session_required_single_domain): + if _guards.is_list_of(session_required_single_domain, str): return session_required_single_domain elif session_required_single_domain is not None: self._warn_type( @@ -128,7 +120,7 @@ def _parse_session_required_policies( session_required_policies = data.get("session_required_policies") if isinstance(session_required_policies, str): return session_required_policies.split(",") - elif _is_list_of_strs(session_required_policies): + elif _guards.is_list_of(session_required_policies, str): return session_required_policies elif session_required_policies is not None: self._warn_type( @@ -162,7 +154,7 @@ def __init__(self, error_data: dict[str, t.Any]) -> None: self._has_data = has_code and bool(self.required_scopes) def _parse_required_scopes(self, data: dict[str, t.Any]) -> list[str]: - if _is_list_of_strs(data.get("required_scopes")): + if _guards.is_list_of(data.get("required_scopes"), str): return t.cast("list[str]", data["required_scopes"]) elif isinstance(data.get("required_scope"), str): return [data["required_scope"]] diff --git a/src/globus_sdk/experimental/auth_requirements_error/_validators.py b/src/globus_sdk/experimental/auth_requirements_error/_validators.py index c1873e62d..e14289f8a 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/_validators.py +++ b/src/globus_sdk/experimental/auth_requirements_error/_validators.py @@ -2,6 +2,8 @@ import typing as t +from globus_sdk import _guards + from ._serializable import Serializable S = t.TypeVar("S", bound=Serializable) @@ -18,40 +20,34 @@ def str_(name: str, value: t.Any) -> str: def opt_str(name: str, value: t.Any) -> str | None: - if value is None: - return None - if isinstance(value, str): + if _guards.is_optional(value, str): return value raise ValidationError(f"'{name}' must be a string or null") def opt_bool(name: str, value: t.Any) -> bool | None: - if value is None or isinstance(value, bool): + if _guards.is_optional(value, bool): return value raise ValidationError(f"'{name}' must be a bool or null") def str_list(name: str, value: t.Any) -> list[str]: - if isinstance(value, list) and all(isinstance(s, str) for s in value): + if _guards.is_list_of(value, str): return value raise ValidationError(f"'{name}' must be a list of strings") def opt_str_list(name: str, value: t.Any) -> list[str] | None: - if value is None: - return None - if isinstance(value, list) and all(isinstance(s, str) for s in value): + if _guards.is_optional_list_of(value, str): return value raise ValidationError(f"'{name}' must be a list of strings or null") def opt_str_list_or_commasep(name: str, value: t.Any) -> list[str] | None: - if value is None: - return None - if isinstance(value, str): - value = value.split(",") - if isinstance(value, list) and all(isinstance(s, str) for s in value): + if _guards.is_optional_list_of(value, str): return value + if isinstance(value, str): + return value.split(",") raise ValidationError( f"'{name}' must be a list of strings or a comma-delimited string or null" ) diff --git a/src/globus_sdk/response.py b/src/globus_sdk/response.py index 0f144f78f..76dcf4e28 100644 --- a/src/globus_sdk/response.py +++ b/src/globus_sdk/response.py @@ -7,6 +7,8 @@ from requests import Response +from globus_sdk import _guards + log = logging.getLogger(__name__) if t.TYPE_CHECKING: @@ -134,7 +136,7 @@ def get(self, key: str, default: t.Any = None) -> t.Any: ``get`` is just an alias for ``data.get(key, default)``, but with the added checks that if ``data`` is ``None`` or a list, it returns the default. """ - if self.data is None or isinstance(self.data, list): + if _guards.is_optional(self.data, list): return default # NB: `default` is provided as a positional because the native dict type # doesn't recognize a keyword argument `default` diff --git a/src/globus_sdk/services/auth/client/base.py b/src/globus_sdk/services/auth/client/base.py index a5d713477..e4660b309 100644 --- a/src/globus_sdk/services/auth/client/base.py +++ b/src/globus_sdk/services/auth/client/base.py @@ -9,6 +9,7 @@ import jwt from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey +from globus_sdk import _guards from globus_sdk.authorizers import GlobusAuthorizer if sys.version_info >= (3, 8): @@ -701,9 +702,7 @@ def oauth2_validate_token( # if this client has no way of authenticating itself but # it does have a client_id, we'll send that in the request - no_authentication = self.authorizer is None or isinstance( - self.authorizer, NullAuthorizer - ) + no_authentication = _guards.is_optional(self.authorizer, NullAuthorizer) if no_authentication and self.client_id: log.debug("Validating token with unauthenticated client") body.update({"client_id": self.client_id}) diff --git a/src/globus_sdk/services/flows/errors.py b/src/globus_sdk/services/flows/errors.py index e7ab08079..270a5adf0 100644 --- a/src/globus_sdk/services/flows/errors.py +++ b/src/globus_sdk/services/flows/errors.py @@ -1,5 +1,6 @@ from __future__ import annotations +from globus_sdk import _guards from globus_sdk.exc import ErrorSubdocument, GlobusAPIError @@ -27,7 +28,7 @@ def _parse_undefined_error_format(self) -> bool: self.code = self._extract_code_from_error_array(self.errors) details = self._dict_data["error"].get("detail") - if isinstance(details, list): + if _guards.is_list_of(details, dict): self.messages = [ error_detail["msg"] for error_detail in details diff --git a/src/globus_sdk/services/timer/errors.py b/src/globus_sdk/services/timer/errors.py index 5b04bcb8a..3482b602b 100644 --- a/src/globus_sdk/services/timer/errors.py +++ b/src/globus_sdk/services/timer/errors.py @@ -2,6 +2,7 @@ import typing as t +from globus_sdk import _guards from globus_sdk.exc import ErrorSubdocument, GlobusAPIError @@ -54,7 +55,7 @@ def _parse_undefined_error_format(self) -> bool: self.code = self._extract_code_from_error_array(self.errors) self.messages = self._extract_messages_from_error_array(self.errors) return True - elif isinstance(self._dict_data.get("detail"), list): + elif _guards.is_list_of(self._dict_data.get("detail"), dict): # FIXME: # the 'code' is currently being set explicitly by the # SDK in this case even though none was provided by @@ -63,11 +64,7 @@ def _parse_undefined_error_format(self) -> bool: self.code = "Validation Error" # collect the errors array from details - self.errors = [ - ErrorSubdocument(d) - for d in self._dict_data["detail"] - if isinstance(d, dict) - ] + self.errors = [ErrorSubdocument(d) for d in self._dict_data["detail"]] # drop error objects which don't have the relevant fields # and then build custom 'messages' for Globus Timers errors @@ -87,8 +84,6 @@ def _details_from_errors( if not isinstance(d.get("msg"), str): continue loc_list = d.get("loc") - if not isinstance(loc_list, list): - continue - if not all(isinstance(path_item, str) for path_item in loc_list): + if not _guards.is_list_of(loc_list, str): continue yield d.raw diff --git a/tests/non-pytest/mypy-ignore-tests/test_guards.py b/tests/non-pytest/mypy-ignore-tests/test_guards.py new file mode 100644 index 000000000..0dc761533 --- /dev/null +++ b/tests/non-pytest/mypy-ignore-tests/test_guards.py @@ -0,0 +1,31 @@ +# test that the internal _guards module provides valid and well-formed type-guards +import typing as t + +from globus_sdk import _guards + + +def get_any() -> t.Any: + return 1 + + +x = get_any() +t.assert_type(x, t.Any) + +# test is_list_of +if _guards.is_list_of(x, str): + t.assert_type(x, list[str]) +elif _guards.is_list_of(x, int): + t.assert_type(x, list[int]) + +# test is_optional +if _guards.is_optional(x, float): + t.assert_type(x, float | None) +elif _guards.is_optional(x, bytes): + t.assert_type(x, bytes | None) + + +# test is_optional_list_of +if _guards.is_optional_list_of(x, type(None)): + t.assert_type(x, list[None] | None) +elif _guards.is_optional_list_of(x, dict): + t.assert_type(x, list[dict[t.Any, t.Any]] | None) diff --git a/tests/unit/test_guards.py b/tests/unit/test_guards.py new file mode 100644 index 000000000..8b3b4d642 --- /dev/null +++ b/tests/unit/test_guards.py @@ -0,0 +1,63 @@ +import pytest + +from globus_sdk import _guards + + +@pytest.mark.parametrize( + "value, typ, ok", + [ + # passing + ([], str, True), + ([1, 2], int, True), + (["1", ""], str, True), + ([], list, True), + ([[], [1, 2], ["foo"]], list, True), + # failing + ([1], str, False), + (["foo"], int, False), + ((1, 2), int, False), + (list, list, False), + (list, str, False), + (["foo", 1], str, False), + ([1, 2], list, False), + ], +) +def test_list_of_guard(value, typ, ok): + assert _guards.is_list_of(value, typ) == ok + + +@pytest.mark.parametrize( + "value, typ, ok", + [ + # passing + (None, str, True), + ("foo", str, True), + # failing + (b"foo", str, False), + ("", int, False), + (type(None), str, False), + ], +) +def test_opt_guard(value, typ, ok): + assert _guards.is_optional(value, typ) == ok + + +@pytest.mark.parametrize( + "value, typ, ok", + [ + # passing + ([], str, True), + ([], int, True), + ([1, 2], int, True), + (["1", ""], str, True), + (None, str, True), + # failing + # NB: the guard checks `list[str] | None`, not `list[str | None]` + ([None], str, False), + (b"foo", str, False), + ("", str, False), + (type(None), str, False), + ], +) +def test_opt_list_guard(value, typ, ok): + assert _guards.is_optional_list_of(value, typ) == ok