Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce an internal TypeGuard library, globus_sdk._guards #798

Merged
merged 6 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/globus_sdk/_guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

import sys
import typing as t

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard

T = t.TypeVar("T")


def is_list_of(data: t.Any, typ: type[T]) -> TypeGuard[list[T]]:
return isinstance(data, list) and all(isinstance(item, typ) for item in data)


def is_optional(data: t.Any, typ: type[T]) -> TypeGuard[T | None]:
return data is None or isinstance(data, typ)


def is_optional_list_of(data: t.Any, typ: type[T]) -> TypeGuard[list[T] | None]:
return data is None or (
isinstance(data, list) and all(isinstance(item, typ) for item in data)
)
12 changes: 5 additions & 7 deletions src/globus_sdk/exc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import requests

from globus_sdk import _guards

from .base import GlobusError
from .err_info import ErrorInfoContainer
from .warnings import warn_deprecated
Expand Down Expand Up @@ -273,7 +275,7 @@ def _detect_error_format(self) -> _ErrorFormat:
# well-formed
if self._jsonapi_mimetype():
errors = self._dict_data.get("errors")
if not isinstance(errors, list):
if not _guards.is_list_of(errors, dict):
return _ErrorFormat.undefined
elif len(errors) < 1:
return _ErrorFormat.undefined
Expand Down Expand Up @@ -321,9 +323,7 @@ def _parse_type_zero_error_format(self) -> bool:
self.code = self._dict_data["code"]
self.messages = [self._dict_data["message"]]
self.request_id = self._dict_data.get("request_id")
if isinstance(self._dict_data.get("errors"), list) and all(
isinstance(subdoc, dict) for subdoc in self._dict_data["errors"]
):
if _guards.is_list_of(self._dict_data.get("errors"), dict):
raw_errors = self._dict_data["errors"]
else:
raw_errors = [self._dict_data]
Expand All @@ -341,9 +341,7 @@ def _parse_undefined_error_format(self) -> bool:
"""

# attempt to pull out errors if possible and valid
if isinstance(self._dict_data.get("errors"), list) and all(
isinstance(subdoc, dict) for subdoc in self._dict_data["errors"]
):
if _guards.is_list_of(self._dict_data.get("errors"), dict):
raw_errors = self._dict_data["errors"]
# if no 'errors' were found, or 'errors' is invalid, then
# 'errors' should be set to contain the root document
Expand Down
18 changes: 5 additions & 13 deletions src/globus_sdk/exc/err_info.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
from __future__ import annotations

import logging
import sys
import typing as t

if sys.version_info < (3, 10):
from typing_extensions import TypeGuard
else:
from typing import TypeGuard
from globus_sdk import _guards

log = logging.getLogger(__name__)


def _is_list_of_strs(obj: t.Any) -> TypeGuard[list[str]]:
return isinstance(obj, list) and all(isinstance(item, str) for item in obj)


class ErrorInfo:
"""
Errors may contain "containers" of data which are testable (define ``__bool__``).
Expand Down Expand Up @@ -98,7 +90,7 @@ def _parse_session_required_identities(
self, data: dict[str, t.Any]
) -> list[str] | None:
session_required_identities = data.get("session_required_identities")
if _is_list_of_strs(session_required_identities):
if _guards.is_list_of(session_required_identities, str):
return session_required_identities
elif session_required_identities is not None:
self._warn_type(
Expand All @@ -112,7 +104,7 @@ def _parse_session_required_single_domain(
self, data: dict[str, t.Any]
) -> list[str] | None:
session_required_single_domain = data.get("session_required_single_domain")
if _is_list_of_strs(session_required_single_domain):
if _guards.is_list_of(session_required_single_domain, str):
return session_required_single_domain
elif session_required_single_domain is not None:
self._warn_type(
Expand All @@ -128,7 +120,7 @@ def _parse_session_required_policies(
session_required_policies = data.get("session_required_policies")
if isinstance(session_required_policies, str):
return session_required_policies.split(",")
elif _is_list_of_strs(session_required_policies):
elif _guards.is_list_of(session_required_policies, str):
return session_required_policies
elif session_required_policies is not None:
self._warn_type(
Expand Down Expand Up @@ -162,7 +154,7 @@ def __init__(self, error_data: dict[str, t.Any]) -> None:
self._has_data = has_code and bool(self.required_scopes)

def _parse_required_scopes(self, data: dict[str, t.Any]) -> list[str]:
if _is_list_of_strs(data.get("required_scopes")):
if _guards.is_list_of(data.get("required_scopes"), str):
return t.cast("list[str]", data["required_scopes"])
elif isinstance(data.get("required_scope"), str):
return [data["required_scope"]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import typing as t

from globus_sdk import _guards

from ._serializable import Serializable

S = t.TypeVar("S", bound=Serializable)
Expand All @@ -18,40 +20,34 @@ def str_(name: str, value: t.Any) -> str:


def opt_str(name: str, value: t.Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
if _guards.is_optional(value, str):
return value
raise ValidationError(f"'{name}' must be a string or null")


def opt_bool(name: str, value: t.Any) -> bool | None:
if value is None or isinstance(value, bool):
if _guards.is_optional(value, bool):
return value
raise ValidationError(f"'{name}' must be a bool or null")


def str_list(name: str, value: t.Any) -> list[str]:
if isinstance(value, list) and all(isinstance(s, str) for s in value):
if _guards.is_list_of(value, str):
return value
raise ValidationError(f"'{name}' must be a list of strings")


def opt_str_list(name: str, value: t.Any) -> list[str] | None:
if value is None:
return None
if isinstance(value, list) and all(isinstance(s, str) for s in value):
if _guards.is_optional_list_of(value, str):
return value
raise ValidationError(f"'{name}' must be a list of strings or null")


def opt_str_list_or_commasep(name: str, value: t.Any) -> list[str] | None:
if value is None:
return None
if isinstance(value, str):
value = value.split(",")
if isinstance(value, list) and all(isinstance(s, str) for s in value):
if _guards.is_optional_list_of(value, str):
return value
if isinstance(value, str):
return value.split(",")
raise ValidationError(
f"'{name}' must be a list of strings or a comma-delimited string or null"
)
Expand Down
4 changes: 3 additions & 1 deletion src/globus_sdk/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from requests import Response

from globus_sdk import _guards

log = logging.getLogger(__name__)

if t.TYPE_CHECKING:
Expand Down Expand Up @@ -134,7 +136,7 @@ def get(self, key: str, default: t.Any = None) -> t.Any:
``get`` is just an alias for ``data.get(key, default)``, but with the added
checks that if ``data`` is ``None`` or a list, it returns the default.
"""
if self.data is None or isinstance(self.data, list):
if _guards.is_optional(self.data, list):
return default
# NB: `default` is provided as a positional because the native dict type
# doesn't recognize a keyword argument `default`
Expand Down
5 changes: 2 additions & 3 deletions src/globus_sdk/services/auth/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jwt
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey

from globus_sdk import _guards
from globus_sdk.authorizers import GlobusAuthorizer

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -701,9 +702,7 @@ def oauth2_validate_token(

# if this client has no way of authenticating itself but
# it does have a client_id, we'll send that in the request
no_authentication = self.authorizer is None or isinstance(
self.authorizer, NullAuthorizer
)
no_authentication = _guards.is_optional(self.authorizer, NullAuthorizer)
if no_authentication and self.client_id:
log.debug("Validating token with unauthenticated client")
body.update({"client_id": self.client_id})
Expand Down
3 changes: 2 additions & 1 deletion src/globus_sdk/services/flows/errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from globus_sdk import _guards
from globus_sdk.exc import ErrorSubdocument, GlobusAPIError


Expand Down Expand Up @@ -27,7 +28,7 @@ def _parse_undefined_error_format(self) -> bool:
self.code = self._extract_code_from_error_array(self.errors)

details = self._dict_data["error"].get("detail")
if isinstance(details, list):
if _guards.is_list_of(details, dict):
self.messages = [
error_detail["msg"]
for error_detail in details
Expand Down
13 changes: 4 additions & 9 deletions src/globus_sdk/services/timer/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing as t

from globus_sdk import _guards
from globus_sdk.exc import ErrorSubdocument, GlobusAPIError


Expand Down Expand Up @@ -54,7 +55,7 @@ def _parse_undefined_error_format(self) -> bool:
self.code = self._extract_code_from_error_array(self.errors)
self.messages = self._extract_messages_from_error_array(self.errors)
return True
elif isinstance(self._dict_data.get("detail"), list):
elif _guards.is_list_of(self._dict_data.get("detail"), dict):
# FIXME:
# the 'code' is currently being set explicitly by the
# SDK in this case even though none was provided by
Expand All @@ -63,11 +64,7 @@ def _parse_undefined_error_format(self) -> bool:
self.code = "Validation Error"

# collect the errors array from details
self.errors = [
ErrorSubdocument(d)
for d in self._dict_data["detail"]
if isinstance(d, dict)
]
self.errors = [ErrorSubdocument(d) for d in self._dict_data["detail"]]

# drop error objects which don't have the relevant fields
# and then build custom 'messages' for Globus Timers errors
Expand All @@ -87,8 +84,6 @@ def _details_from_errors(
if not isinstance(d.get("msg"), str):
continue
loc_list = d.get("loc")
if not isinstance(loc_list, list):
continue
if not all(isinstance(path_item, str) for path_item in loc_list):
if not _guards.is_list_of(loc_list, str):
continue
yield d.raw
31 changes: 31 additions & 0 deletions tests/non-pytest/mypy-ignore-tests/test_guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# test that the internal _guards module provides valid and well-formed type-guards
import typing as t

from globus_sdk import _guards


def get_any() -> t.Any:
return 1


x = get_any()
t.assert_type(x, t.Any)

# test is_list_of
if _guards.is_list_of(x, str):
t.assert_type(x, list[str])
elif _guards.is_list_of(x, int):
t.assert_type(x, list[int])

# test is_optional
if _guards.is_optional(x, float):
t.assert_type(x, float | None)
elif _guards.is_optional(x, bytes):
t.assert_type(x, bytes | None)


# test is_optional_list_of
if _guards.is_optional_list_of(x, type(None)):
t.assert_type(x, list[None] | None)
elif _guards.is_optional_list_of(x, dict):
t.assert_type(x, list[dict[t.Any, t.Any]] | None)
63 changes: 63 additions & 0 deletions tests/unit/test_guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import pytest

from globus_sdk import _guards


@pytest.mark.parametrize(
"value, typ, ok",
[
# passing
([], str, True),
([1, 2], int, True),
(["1", ""], str, True),
([], list, True),
([[], [1, 2], ["foo"]], list, True),
# failing
([1], str, False),
(["foo"], int, False),
((1, 2), int, False),
(list, list, False),
(list, str, False),
(["foo", 1], str, False),
([1, 2], list, False),
],
)
def test_list_of_guard(value, typ, ok):
assert _guards.is_list_of(value, typ) == ok


@pytest.mark.parametrize(
"value, typ, ok",
[
# passing
(None, str, True),
("foo", str, True),
# failing
(b"foo", str, False),
("", int, False),
(type(None), str, False),
],
)
def test_opt_guard(value, typ, ok):
assert _guards.is_optional(value, typ) == ok


@pytest.mark.parametrize(
"value, typ, ok",
[
# passing
([], str, True),
([], int, True),
([1, 2], int, True),
(["1", ""], str, True),
(None, str, True),
# failing
# NB: the guard checks `list[str] | None`, not `list[str | None]`
([None], str, False),
(b"foo", str, False),
("", str, False),
(type(None), str, False),
],
)
def test_opt_list_guard(value, typ, ok):
assert _guards.is_optional_list_of(value, typ) == ok