From 5905d8615887b22d277855e985e4d8da5b73e64e Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Sun, 6 Aug 2023 14:20:28 -0500 Subject: [PATCH 1/6] Make GARE validation as simple as possible Rather than trying to build a framework, this is the "no framework" solution, in which we define very concrete helpers for very concrete purposes, including a dedicated one for the "ConsentRequired" Literal. All of the GARE classes now inherit from an internal helper class, Serializable, which defines from_dict, to_dict, and _supported_fields in a generic way. The last of these *does* use some signature inspection magic, but nothing too abstruse. The basic transformation is to replace each combination of a class-level annotation + a `SUPPORTED_FIELDS` entry with relevant assignment in `__init__`. Some tangentially related simplifications and minor improvements are included: - `extra_fields -> extra` - Removal of unnecessary str splitting (after object is initialized) - `isinstance` checking can also handle deserialization of dicts - checking for non-null values will not accept `session_message`-only GARE data -- at least one of the semantic fields is required by the rewritten check - GlobusAuthRequirementsError does not inherit from `GlobusError` or `Exception` -- it is not clear that this inheritance is useful or instructive to any user, since it mixes Exception hierarchies with data-representing hierarchies - replace direct `ValueError` usage with a custom ValidationError class -- this avoids any messy scenarios in which a ValueError is accidentally introduced and incorrectly caught (e.g. from a stdlib call) There are also some more significant improvements included. Most notably: - Annotations explicitly do not accept `None` unless it is a valid value (i.e. annotations align with validation requirements) - to_dict can now look for a concrete type (`Serializable`) and therefore can automatically invoke `to_dict` down a tree of objects Although brevity is a non-goal of this changeset -- more verbose but clearer would be acceptable -- the result is almost 200 LOC lighter in the `src/` tree. The primary ways in which things became shorter appear to be: - explicit version is often much shorter than the framework-ized version (e.g. `LegacyConsentRequiredAPError`), and even where the two are close, the explicit version is shorter - `to_dict` and `from_dict` centralization --- .../auth_requirements_error/_serializable.py | 60 +++++ .../auth_requirements_error/_validators.py | 77 ++++++ .../auth_requirements_error/_variants.py | 239 +++++------------- .../auth_requirements_error.py | 185 +++----------- .../auth_requirements_error/utils.py | 7 +- .../auth_requirements_error/validators.py | 103 -------- .../test_auth_requirements_error.py | 57 ++--- 7 files changed, 269 insertions(+), 459 deletions(-) create mode 100644 src/globus_sdk/experimental/auth_requirements_error/_serializable.py create mode 100644 src/globus_sdk/experimental/auth_requirements_error/_validators.py delete mode 100644 src/globus_sdk/experimental/auth_requirements_error/validators.py diff --git a/src/globus_sdk/experimental/auth_requirements_error/_serializable.py b/src/globus_sdk/experimental/auth_requirements_error/_serializable.py new file mode 100644 index 000000000..c1931fc1f --- /dev/null +++ b/src/globus_sdk/experimental/auth_requirements_error/_serializable.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import inspect +import typing as t + +T = t.TypeVar("T", bound="Serializable") + + +class Serializable: + _EXLUDE_VARS: t.ClassVar[tuple[str, ...]] = ("self", "extra_fields", "extra") + extra: dict[str, t.Any] + + @classmethod + def _supported_fields(cls) -> list[str]: + signature = inspect.signature(cls.__init__) + return [ + name for name in signature.parameters.keys() if name not in cls._EXLUDE_VARS + ] + + @classmethod + def from_dict(cls: type[T], data: dict[str, t.Any]) -> T: + """ + Instantiate from a dictionary. + + :param data: The dictionary to create the error from. + :type data: dict + """ + + # Extract any extra fields + extras = {k: v for k, v in data.items() if k not in cls._supported_fields()} + kwargs: dict[str, t.Any] = {"extra": extras} + # Ensure required fields are supplied + for field_name in cls._supported_fields(): + kwargs[field_name] = data.get(field_name) + + return cls(**kwargs) + + def to_dict(self, include_extra: bool = False) -> dict[str, t.Any]: + """ + Render to a dictionary. + + :param include_extra: Whether to include stored extra (non-standard) fields in + the returned dictionary. + :type include_extra: bool + """ + result = {} + + # Set any authorization parameters + for field in self._supported_fields(): + value = getattr(self, field) + if value is not None: + if isinstance(value, Serializable): + value = value.to_dict(include_extra=include_extra) + result[field] = value + + # Set any extra fields + if include_extra: + result.update(self.extra) + + return result diff --git a/src/globus_sdk/experimental/auth_requirements_error/_validators.py b/src/globus_sdk/experimental/auth_requirements_error/_validators.py new file mode 100644 index 000000000..69a054cb6 --- /dev/null +++ b/src/globus_sdk/experimental/auth_requirements_error/_validators.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import sys +import typing as t + +from ._serializable import Serializable + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + +S = t.TypeVar("S", bound=Serializable) + + +class ValidationError(ValueError): + pass + + +def str_(name: str, value: t.Any) -> str: + if not isinstance(value, str): + raise ValidationError(f"'{name}' must be a string") + return value + + +def opt_str(name: str, value: t.Any) -> str | None: + if value is None: + return None + if not isinstance(value, str): + raise ValidationError(f"'{name}' must be a string") + return value + + +def consent_required_literal(name: str, value: t.Any) -> Literal["ConsentRequired"]: + if not isinstance(value, str) or value != "ConsentRequired": + raise ValidationError(f"'{name}' must be the string 'ConsentRequired'") + return t.cast(Literal["ConsentRequired"], value) + + +def opt_bool(name: str, value: t.Any) -> bool | None: + if value is None: + return None + if not isinstance(value, bool): + raise ValidationError(f"'{name}' must be a bool") + return value + + +def str_list(name: str, value: t.Any) -> list[str]: + if not (isinstance(value, list) and all(isinstance(s, str) for s in value)): + raise ValidationError(f"'{name}' must be a list of strings") + return value + + +def opt_str_list(name: str, value: t.Any) -> list[str] | None: + if value is None: + return None + if not (isinstance(value, list) and all(isinstance(s, str) for s in value)): + raise ValidationError(f"'{name}' must be a list of strings") + return value + + +def opt_str_list_or_commasep(name: str, value: t.Any) -> list[str] | None: + if value is None: + return None + if isinstance(value, str): + value = value.split(",") + if not (isinstance(value, list) and all(isinstance(s, str) for s in value)): + raise ValidationError(f"'{name}' must be a list of strings") + return value + + +def instance_or_dict(name: str, value: t.Any, cls: type[S]) -> S: + if isinstance(value, cls): + return value + if isinstance(value, dict): + return cls.from_dict(value) + raise ValidationError(f"'{name}' must be a '{cls.__name__}' object or a dictionary") diff --git a/src/globus_sdk/experimental/auth_requirements_error/_variants.py b/src/globus_sdk/experimental/auth_requirements_error/_variants.py index c1254d5ee..3bf21ec40 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/_variants.py +++ b/src/globus_sdk/experimental/auth_requirements_error/_variants.py @@ -1,76 +1,50 @@ from __future__ import annotations +import sys import typing as t -from . import validators +from . import _serializable, _validators from .auth_requirements_error import ( GlobusAuthorizationParameters, GlobusAuthRequirementsError, ) -T = t.TypeVar("T", bound="LegacyAuthRequirementsErrorVariant") +if sys.version_info >= (3, 8): + from typing import Literal, Protocol +else: + from typing_extensions import Literal, Protocol +V = t.TypeVar("V", bound="LegacyAuthRequirementsErrorVariant") -class LegacyAuthRequirementsErrorVariant: + +class LegacyAuthRequirementsErrorVariant(Protocol): """ - Abstract base class for errors which can be converted to a - Globus Auth Requirements Error. + Protocol for errors which can be converted to a Globus Auth Requirements Error. """ - SUPPORTED_FIELDS: dict[str, t.Callable[[t.Any], t.Any]] = {} - @classmethod - def from_dict(cls: t.Type[T], error_dict: dict[str, t.Any]) -> T: - """ - Instantiate from an error dictionary. - - :param error_dict: The dictionary to instantiate the error from. - :type error_dict: dict - """ - # Extract any extra fields - extras = {k: v for k, v in error_dict.items() if k not in cls.SUPPORTED_FIELDS} - kwargs: dict[str, t.Any] = {"extra": extras} - # Ensure required fields are supplied - for field_name in cls.SUPPORTED_FIELDS.keys(): - kwargs[field_name] = error_dict.get(field_name) - - return cls(**kwargs) + def from_dict(cls: type[V], data: dict[str, t.Any]) -> V: + pass def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: - raise NotImplementedError() + ... -class LegacyConsentRequiredTransferError(LegacyAuthRequirementsErrorVariant): +class LegacyConsentRequiredTransferError(_serializable.Serializable): """ The ConsentRequired error format emitted by the Globus Transfer service. """ - code: str - required_scopes: list[str] - extra_fields: dict[str, t.Any] - - SUPPORTED_FIELDS = { - "code": validators.StringLiteral("ConsentRequired"), - "required_scopes": validators.ListOfStrings, - } - def __init__( self, *, - code: str, + code: Literal["ConsentRequired"], required_scopes: list[str], extra: dict[str, t.Any] | None = None, - ): # pylint: disable=unused-argument - # Validate and assign supported fields - for field_name, validator in self.SUPPORTED_FIELDS.items(): - try: - field_value = validator(locals()[field_name]) - except ValueError as e: - raise ValueError(f"Error validating field '{field_name}': {e}") from e - - setattr(self, field_name, field_value) - - self.extra_fields = extra or {} + ): + self.code = _validators.consent_required_literal("code", code) + self.required_scopes = _validators.str_list("required_scopes", required_scopes) + self.extra = extra or {} def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: """ @@ -80,44 +54,28 @@ def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: code=self.code, authorization_parameters=GlobusAuthorizationParameters( required_scopes=self.required_scopes, - session_message=self.extra_fields.get("message"), + session_message=self.extra.get("message"), ), - extra=self.extra_fields, + extra=self.extra, ) -class LegacyConsentRequiredAPError(LegacyAuthRequirementsErrorVariant): +class LegacyConsentRequiredAPError(_serializable.Serializable): """ The ConsentRequired error format emitted by the legacy Globus Transfer Action Providers. """ - code: str - required_scope: str - extra_fields: dict[str, t.Any] - - SUPPORTED_FIELDS = { - "code": validators.StringLiteral("ConsentRequired"), - "required_scope": validators.String, - } - def __init__( self, *, - code: str, + code: Literal["ConsentRequired"], required_scope: str, extra: dict[str, t.Any] | None, - ): # pylint: disable=unused-argument - # Validate and assign supported fields - for field_name, validator in self.SUPPORTED_FIELDS.items(): - try: - field_value = validator(locals()[field_name]) - except ValueError as e: - raise ValueError(f"Error validating field '{field_name}': {e}") from e - - setattr(self, field_name, field_value) - - self.extra_fields = extra or {} + ): + self.code = _validators.consent_required_literal("code", code) + self.required_scope = _validators.str_("required_scope", required_scope) + self.extra = extra or {} def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: """ @@ -130,46 +88,21 @@ def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: code=self.code, authorization_parameters=GlobusAuthorizationParameters( required_scopes=[self.required_scope], - session_message=self.extra_fields.get("description"), - extra=self.extra_fields.get("authorization_parameters"), + session_message=self.extra.get("description"), + extra=self.extra.get("authorization_parameters"), ), extra={ - k: v - for k, v in self.extra_fields.items() - if k != "authorization_parameters" + k: v for k, v in self.extra.items() if k != "authorization_parameters" }, ) -class LegacyAuthorizationParameters: +class LegacyAuthorizationParameters(_serializable.Serializable): """ An Authorization Parameters object that describes all known variants in use by Globus services. """ - session_message: str | None - session_required_identities: list[str] | None - session_required_policies: str | list[str] | None - session_required_single_domain: str | list[str] | None - session_required_mfa: bool | None - # Declared here for compatibility with mixed legacy payloads - required_scopes: list[str] | None - extra_fields: dict[str, t.Any] - - DEFAULT_CODE = "AuthorizationRequired" - - SUPPORTED_FIELDS = { - "session_message": validators.OptionalString, - "session_required_identities": validators.OptionalListOfStrings, - "session_required_policies": ( - validators.OptionalListOfStringsOrCommaDelimitedStrings - ), - "session_required_single_domain": ( - validators.OptionalListOfStringsOrCommaDelimitedStrings - ), - "session_required_mfa": validators.OptionalBool, - } - def __init__( self, *, @@ -179,26 +112,32 @@ def __init__( session_required_single_domain: str | list[str] | None = None, session_required_mfa: bool | None = None, extra: dict[str, t.Any] | None = None, - ): # pylint: disable=unused-argument - # Validate and assign supported fields - for field_name, validator in self.SUPPORTED_FIELDS.items(): - try: - field_value = validator(locals()[field_name]) - except ValueError as e: - raise ValueError(f"Error validating field '{field_name}': {e}") from e - - setattr(self, field_name, field_value) - - # Retain any additional fields - self.extra_fields = extra or {} + ): + self.session_message = _validators.opt_str("session_message", session_message) + self.session_required_identities = _validators.opt_str_list( + "session_required_identities", session_required_identities + ) + self.session_required_policies = _validators.opt_str_list_or_commasep( + "session_required_policies", session_required_policies + ) + self.session_required_single_domain = _validators.opt_str_list_or_commasep( + "session_required_single_domain", session_required_single_domain + ) + self.session_required_mfa = _validators.opt_bool( + "session_required_mfa", session_required_mfa + ) + self.extra = extra or {} # Enforce that the error contains at least one of the fields we expect + requires_at_least_one = [ + name for name in self._supported_fields() if name != "session_message" + ] if all( - getattr(self, field_name) is None for field_name in self.SUPPORTED_FIELDS + getattr(self, field_name) is None for field_name in requires_at_least_one ): - raise ValueError( + raise _validators.ValidationError( "Must include at least one supported authorization parameter: " - ", ".join(self.SUPPORTED_FIELDS.keys()) + + ", ".join(requires_at_least_one) ) def to_authorization_parameters( @@ -211,62 +150,24 @@ def to_authorization_parameters( Normalizes fields that may have been provided as comma-delimited strings to lists of strings. """ - required_policies = self.session_required_policies - if isinstance(required_policies, str): - required_policies = required_policies.split(",") - - # TODO: Unnecessary? - required_single_domain = self.session_required_single_domain - if isinstance(required_single_domain, str): - required_single_domain = required_single_domain.split(",") - return GlobusAuthorizationParameters( session_message=self.session_message, session_required_identities=self.session_required_identities, session_required_mfa=self.session_required_mfa, - session_required_policies=required_policies, - session_required_single_domain=required_single_domain, - extra=self.extra_fields, + session_required_policies=self.session_required_policies, + session_required_single_domain=self.session_required_single_domain, + extra=self.extra, ) - @classmethod - def from_dict(cls, param_dict: dict[str, t.Any]) -> LegacyAuthorizationParameters: - """ - Instantiate from an authorization_parameters dictionary. - - :param param_dict: The dictionary to create the AuthorizationParameters from. - :type param_dict: dict - """ - - # Extract any extra fields - extras = {k: v for k, v in param_dict.items() if k not in cls.SUPPORTED_FIELDS} - kwargs: dict[str, t.Any] = {"extra": extras} - # Ensure required fields are supplied - for field_name in cls.SUPPORTED_FIELDS.keys(): - kwargs[field_name] = param_dict.get(field_name) - - return cls(**kwargs) - -class LegacyAuthorizationParametersError(LegacyAuthRequirementsErrorVariant): +class LegacyAuthorizationParametersError(_serializable.Serializable): """ Defines an Authorization Parameters error that describes all known variants in use by Globus services. """ - authorization_parameters: LegacyAuthorizationParameters - code: str - extra_fields: dict[str, t.Any] - DEFAULT_CODE = "AuthorizationRequired" - SUPPORTED_FIELDS = { - "code": validators.String, - "authorization_parameters": validators.ClassInstance( - LegacyAuthorizationParameters - ), - } - def __init__( self, *, @@ -275,25 +176,13 @@ def __init__( extra: dict[str, t.Any] | None = None, ): # Apply default, if necessary - code = code or self.DEFAULT_CODE - - # Convert authorization_parameters if it's a dict - if isinstance(authorization_parameters, dict): - authorization_parameters = LegacyAuthorizationParameters.from_dict( - param_dict=authorization_parameters - ) - - # Validate and assign supported fields - for field_name, validator in self.SUPPORTED_FIELDS.items(): - try: - field_value = validator(locals()[field_name]) - except ValueError as e: - raise ValueError(f"Error validating field '{field_name}': {e}") from e - - setattr(self, field_name, field_value) - - # Retain any additional fields - self.extra_fields = extra or {} + self.code = _validators.str_("code", code or self.DEFAULT_CODE) + self.authorization_parameters = _validators.instance_or_dict( + "authorization_parameters", + authorization_parameters, + LegacyAuthorizationParameters, + ) + self.extra = extra or {} def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: """ @@ -305,5 +194,5 @@ def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: return GlobusAuthRequirementsError( authorization_parameters=authorization_parameters, code=self.code, - extra=self.extra_fields, + extra=self.extra, ) diff --git a/src/globus_sdk/experimental/auth_requirements_error/auth_requirements_error.py b/src/globus_sdk/experimental/auth_requirements_error/auth_requirements_error.py index 42118413b..f47781d38 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/auth_requirements_error.py +++ b/src/globus_sdk/experimental/auth_requirements_error/auth_requirements_error.py @@ -2,12 +2,10 @@ import typing as t -from globus_sdk.exc import GlobusError +from . import _serializable, _validators -from . import validators - -class GlobusAuthorizationParameters: +class GlobusAuthorizationParameters(_serializable.Serializable): """ Represents authorization parameters that can be used to instruct a client which additional authorizations are needed in order to complete a request. @@ -33,28 +31,11 @@ class GlobusAuthorizationParameters: :ivar required_scopes: A list of scopes for which consent is required. :vartype required_scopes: list of str, optional - :ivar extra_fields: A dictionary of additional fields that were provided. May + :ivar extra: A dictionary of additional fields that were provided. May be used for forward/backward compatibility. - :vartype extra_fields: dict + :vartype extra: dict """ - session_message: str | None - session_required_identities: list[str] | None - session_required_policies: list[str] | None - session_required_single_domain: list[str] | None - session_required_mfa: bool | None - required_scopes: list[str] | None - extra_fields: dict[str, t.Any] - - SUPPORTED_FIELDS = { - "session_message": validators.OptionalString, - "session_required_identities": validators.OptionalListOfStrings, - "session_required_policies": validators.OptionalListOfStrings, - "session_required_single_domain": validators.OptionalListOfStrings, - "session_required_mfa": validators.OptionalBool, - "required_scopes": validators.OptionalListOfStrings, - } - def __init__( self, *, @@ -65,68 +46,39 @@ def __init__( session_required_mfa: bool | None = None, required_scopes: list[str] | None = None, extra: dict[str, t.Any] | None = None, - ): # pylint: disable=unused-argument - # Validate and assign supported fields - for field_name, validator in self.SUPPORTED_FIELDS.items(): - try: - field_value = validator(locals()[field_name]) - except ValueError as e: - raise ValueError(f"Error validating field '{field_name}': {e}") from e - - setattr(self, field_name, field_value) - - self.extra_fields = extra or {} + ): + self.session_message = _validators.opt_str("session_message", session_message) + self.session_required_identities = _validators.opt_str_list( + "session_required_identities", session_required_identities + ) + self.session_required_policies = _validators.opt_str_list( + "session_required_policies", session_required_policies + ) + self.session_required_single_domain = _validators.opt_str_list( + "session_required_single_domain", session_required_single_domain + ) + self.session_required_mfa = _validators.opt_bool( + "session_required_mfa", session_required_mfa + ) + self.required_scopes = _validators.opt_str_list( + "required_scopes", required_scopes + ) + self.extra = extra or {} # Enforce that the error contains at least one of the fields we expect + requires_at_least_one = [ + name for name in self._supported_fields() if name != "session_message" + ] if all( - getattr(self, field_name) is None for field_name in self.SUPPORTED_FIELDS + getattr(self, field_name) is None for field_name in requires_at_least_one ): - raise ValueError( + raise _validators.ValidationError( "Must include at least one supported authorization parameter: " - + ", ".join(self.SUPPORTED_FIELDS.keys()) + + ", ".join(requires_at_least_one) ) - @classmethod - def from_dict(cls, param_dict: dict[str, t.Any]) -> GlobusAuthorizationParameters: - """ - Instantiate from an authorization parameters dictionary. - - :param param_dict: The dictionary to create the error from. - :type param_dict: dict - """ - - # Extract any extra fields - extras = {k: v for k, v in param_dict.items() if k not in cls.SUPPORTED_FIELDS} - kwargs: dict[str, t.Any] = {"extra": extras} - # Ensure required fields are supplied - for field_name in cls.SUPPORTED_FIELDS.keys(): - kwargs[field_name] = param_dict.get(field_name) - - return cls(**kwargs) - - def to_dict(self, include_extra: bool = False) -> dict[str, t.Any]: - """ - Return an authorization parameters dictionary. - :param include_extra: Whether to include stored extra (non-standard) fields in - the returned dictionary. - :type include_extra: bool - """ - error_dict = {} - - # Set any authorization parameters - for field in self.SUPPORTED_FIELDS.keys(): - if getattr(self, field) is not None: - error_dict[field] = getattr(self, field) - - # Set any extra fields - if include_extra: - error_dict.update(self.extra_fields) - - return error_dict - - -class GlobusAuthRequirementsError(GlobusError): +class GlobusAuthRequirementsError(_serializable.Serializable): """ Represents a Globus Auth Requirements Error. @@ -141,81 +93,22 @@ class GlobusAuthRequirementsError(GlobusError): :ivar authorization_parameters: The authorization parameters for this error. :vartype authorization_parameters: GlobusAuthorizationParameters - :ivar extra_fields: A dictionary of additional fields that were provided. May + :ivar extra: A dictionary of additional fields that were provided. May be used for forward/backward compatibility. - :vartype extra_fields: dict + :vartype extra: dict """ - code: str - authorization_parameters: GlobusAuthorizationParameters - extra_fields: dict[str, t.Any] - - SUPPORTED_FIELDS = { - "code": validators.String, - "authorization_parameters": validators.ClassInstance( - GlobusAuthorizationParameters - ), - } - def __init__( self, - code: str, # pylint: disable=unused-argument + code: str, authorization_parameters: dict[str, t.Any] | GlobusAuthorizationParameters, *, extra: dict[str, t.Any] | None = None, ): - # Convert authorization_parameters if it's a dict - if isinstance(authorization_parameters, dict): - authorization_parameters = GlobusAuthorizationParameters.from_dict( - param_dict=authorization_parameters - ) - - # Validate and assign supported fields - for field_name, validator in self.SUPPORTED_FIELDS.items(): - try: - field_value = validator(locals()[field_name]) - except ValueError as e: - raise ValueError(f"Error validating field '{field_name}': {e}") from e - - setattr(self, field_name, field_value) - - self.extra_fields = extra or {} - - @classmethod - def from_dict(cls, error_dict: dict[str, t.Any]) -> GlobusAuthRequirementsError: - """ - Instantiate a GlobusAuthRequirementsError from a dictionary. - - :param error_dict: The dictionary to create the error from. - :type error_dict: dict - """ - - # Extract any extra fields - extras = {k: v for k, v in error_dict.items() if k not in cls.SUPPORTED_FIELDS} - kwargs: dict[str, t.Any] = {"extra": extras} - # Ensure required fields are supplied - for field_name in cls.SUPPORTED_FIELDS.keys(): - kwargs[field_name] = error_dict.get(field_name) - - return cls(**kwargs) - - def to_dict(self, include_extra: bool = False) -> dict[str, t.Any]: - """ - Return a Globus Auth Requirements Error response dictionary. - - :param include_extra: Whether to include stored extra (non-standard) fields - in the dictionary. - :type include_extra: bool, optional (default: False) - """ - error_dict = { - "code": self.code, - "authorization_parameters": self.authorization_parameters.to_dict( - include_extra=include_extra - ), - } - - # Set any extra fields - if include_extra: - error_dict.update(self.extra_fields) - - return error_dict + self.code = _validators.str_("code", code) + self.authorization_parameters = _validators.instance_or_dict( + "authorization_parameters", + authorization_parameters, + GlobusAuthorizationParameters, + ) + self.extra = extra or {} diff --git a/src/globus_sdk/experimental/auth_requirements_error/utils.py b/src/globus_sdk/experimental/auth_requirements_error/utils.py index 92d9c6bb1..8d2536599 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/utils.py +++ b/src/globus_sdk/experimental/auth_requirements_error/utils.py @@ -5,6 +5,7 @@ from globus_sdk.exc import ErrorSubdocument, GlobusAPIError +from . import _validators from ._variants import ( LegacyAuthorizationParametersError, LegacyAuthRequirementsErrorVariant, @@ -56,10 +57,10 @@ def to_auth_requirements_error( # Prefer a proper auth requirements error, if possible try: return GlobusAuthRequirementsError.from_dict(error_dict) - except ValueError as err: + except _validators.ValidationError as err: log.debug(f"Failed to parse error as 'GlobusAuthRequirementsError' ({err})") - supported_variants: list[t.Type[LegacyAuthRequirementsErrorVariant]] = [ + supported_variants: list[type[LegacyAuthRequirementsErrorVariant]] = [ LegacyAuthorizationParametersError, LegacyConsentRequiredTransferError, LegacyConsentRequiredAPError, @@ -67,7 +68,7 @@ def to_auth_requirements_error( for variant in supported_variants: try: return variant.from_dict(error_dict).to_auth_requirements_error() - except ValueError as err: + except _validators.ValidationError as err: log.debug(f"Failed to parse error as '{variant.__name__}' ({err})") return None diff --git a/src/globus_sdk/experimental/auth_requirements_error/validators.py b/src/globus_sdk/experimental/auth_requirements_error/validators.py deleted file mode 100644 index 3e824d363..000000000 --- a/src/globus_sdk/experimental/auth_requirements_error/validators.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import typing as t -from contextlib import suppress - -T = t.TypeVar("T") - - -def _string(value: t.Any) -> str: - if not isinstance(value, str): - raise ValueError("Must be a string") - - return value - - -def _string_literal(literal: str) -> t.Callable[[t.Any], str]: - def validator(value: t.Any) -> str: - value = _string(value) - if value != literal: - raise ValueError(f"Must be '{literal}'") - - return t.cast(str, value) - - return validator - - -def _class_instance(cls: t.Type[T]) -> t.Callable[[t.Any], T]: - def validator(value: t.Any) -> T: - if not isinstance(value, cls): - raise ValueError(f"Must be an instance of {cls.__name__}") - - return value - - return validator - - -def _list_of_strings(value: t.Any) -> list[str]: - if not (isinstance(value, list) and all(isinstance(v, str) for v in value)): - raise ValueError("Must be a list of strings") - - return value - - -def _comma_delimited_strings(value: t.Any) -> list[str]: - if not isinstance(value, str): - raise ValueError("Must be a comma-delimited string") - - return value.split(",") - - -def _boolean(value: t.Any) -> bool: - if not isinstance(value, bool): - raise ValueError("Must be a bool") - - return value - - -def _null(value: t.Any) -> None: - if value is not None: - raise ValueError("Must be null") - - return None - - -def _anyof( - validators: tuple[t.Callable[..., t.Any], ...], - description: str, -) -> t.Callable[..., t.Any]: - def _anyof_validator(value: t.Any) -> t.Any: - for validator in validators: - with suppress(ValueError): - return validator(value) - - raise ValueError(f"Must be {description}") - - return _anyof_validator - - -String: t.Callable[[t.Any], str] = _string -StringLiteral: t.Callable[[str], t.Callable[[t.Any], str]] = _string_literal -ClassInstance: t.Callable[[t.Any], t.Any] = _class_instance -ListOfStrings: t.Callable[[t.Any], list[str]] = _list_of_strings -CommaDelimitedStrings: t.Callable[[t.Any], list[str]] = _comma_delimited_strings -Bool: t.Callable[[t.Any], bool] = _boolean -Null: t.Callable[[t.Any], None] = _null -OptionalString: t.Callable[[t.Any], str | None] = _anyof( - (_string, _null), description="a string or null" -) -OptionalListOfStrings: t.Callable[[t.Any], list[str] | None] = _anyof( - (_list_of_strings, _null), description="a list of strings or null" -) -OptionalCommaDelimitedStrings: t.Callable[[t.Any], list[str] | None] = _anyof( - (_comma_delimited_strings, _null), description="a comma-delimited string or null" -) -OptionalListOfStringsOrCommaDelimitedStrings: t.Callable[ - [t.Any], list[str] | None -] = _anyof( - (_list_of_strings, _comma_delimited_strings, _null), - description="a list of strings, a comma-delimited string, or null", -) -OptionalBool: t.Callable[[t.Any], bool | None] = _anyof( - (_boolean, _null), description="a bool or null" -) diff --git a/tests/unit/experimental/test_auth_requirements_error.py b/tests/unit/experimental/test_auth_requirements_error.py index 9c0463bd6..31298b64f 100644 --- a/tests/unit/experimental/test_auth_requirements_error.py +++ b/tests/unit/experimental/test_auth_requirements_error.py @@ -1,5 +1,3 @@ -import inspect - import pytest from globus_sdk._testing import construct_error @@ -408,42 +406,37 @@ def test_backward_compatibility_consent_required_error(): @pytest.mark.parametrize( - "target_class", - [ - GlobusAuthRequirementsError, - GlobusAuthorizationParameters, - _variants.LegacyAuthorizationParameters, - _variants.LegacyAuthorizationParametersError, - _variants.LegacyConsentRequiredTransferError, - _variants.LegacyConsentRequiredAPError, - ], -) -def test_constructors_include_all_supported_fields(target_class): - """ - Test that all supported fields are included in the constructors. - """ - - method_sig = inspect.signature(target_class.__init__) - for field_name in target_class.SUPPORTED_FIELDS: - # Make sure the constructor has a parameter for this field - assert field_name in method_sig.parameters - - -@pytest.mark.parametrize( - "target_class, field_name", + "target_class, data, expect_message", [ - (GlobusAuthRequirementsError, "code"), - (_variants.LegacyAuthorizationParametersError, "authorization_parameters"), - (_variants.LegacyConsentRequiredTransferError, "code"), - (_variants.LegacyConsentRequiredAPError, "code"), + ( # missing 'code' + GlobusAuthRequirementsError, + {"authorization_parameters": {"session_required_policies": "foo"}}, + "'code' must be a string", + ), + ( # missing 'authorization_parameters' + _variants.LegacyAuthorizationParametersError, + {}, + "'authorization_parameters' must be a 'LegacyAuthorizationParameters' " + "object or a dictionary", + ), + ( # missing 'code' + _variants.LegacyConsentRequiredTransferError, + {"required_scopes": []}, + "'code' must be the string 'ConsentRequired'", + ), + ( # missing 'code' + _variants.LegacyConsentRequiredAPError, + {"required_scope": "foo"}, + "'code' must be the string 'ConsentRequired'", + ), ], ) -def test_error_from_dict_insufficient_input(target_class, field_name): +def test_error_from_dict_insufficient_input(target_class, data, expect_message): """ """ with pytest.raises(ValueError) as exc_info: - target_class.from_dict({}) + target_class.from_dict(data) - assert f"Error validating field '{field_name}'" in str(exc_info.value) + assert str(exc_info.value) == expect_message @pytest.mark.parametrize( From bd735e1d2a2c7a5a9f33fd6f098dd7699b8a39e2 Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Mon, 7 Aug 2023 13:49:56 -0500 Subject: [PATCH 2/6] Update src/globus_sdk/experimental/auth_requirements_error/_serializable.py Co-authored-by: Ada <107940310+ada-globus@users.noreply.github.com> --- .../experimental/auth_requirements_error/_serializable.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/globus_sdk/experimental/auth_requirements_error/_serializable.py b/src/globus_sdk/experimental/auth_requirements_error/_serializable.py index c1931fc1f..4ae2a7034 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/_serializable.py +++ b/src/globus_sdk/experimental/auth_requirements_error/_serializable.py @@ -7,14 +7,14 @@ class Serializable: - _EXLUDE_VARS: t.ClassVar[tuple[str, ...]] = ("self", "extra_fields", "extra") + _EXCLUDE_VARS: t.ClassVar[tuple[str, ...]] = ("self", "extra") extra: dict[str, t.Any] @classmethod def _supported_fields(cls) -> list[str]: signature = inspect.signature(cls.__init__) return [ - name for name in signature.parameters.keys() if name not in cls._EXLUDE_VARS + name for name in signature.parameters.keys() if name not in cls._EXCLUDE_VARS ] @classmethod From ac86722acd2f17d9a92ff23eb3a9aed69e209aaf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Aug 2023 18:50:10 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../experimental/auth_requirements_error/_serializable.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/globus_sdk/experimental/auth_requirements_error/_serializable.py b/src/globus_sdk/experimental/auth_requirements_error/_serializable.py index 4ae2a7034..c976e7e61 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/_serializable.py +++ b/src/globus_sdk/experimental/auth_requirements_error/_serializable.py @@ -14,7 +14,9 @@ class Serializable: def _supported_fields(cls) -> list[str]: signature = inspect.signature(cls.__init__) return [ - name for name in signature.parameters.keys() if name not in cls._EXCLUDE_VARS + name + for name in signature.parameters.keys() + if name not in cls._EXCLUDE_VARS ] @classmethod From 8b6ca104405de782a37555716d9c4c369c9fbad9 Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Mon, 7 Aug 2023 14:29:48 -0500 Subject: [PATCH 4/6] Minor doc adjustment extra_fields->extra --- docs/experimental/auth_requirements_errors.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/experimental/auth_requirements_errors.rst b/docs/experimental/auth_requirements_errors.rst index 1809b8a3c..b2605bde6 100644 --- a/docs/experimental/auth_requirements_errors.rst +++ b/docs/experimental/auth_requirements_errors.rst @@ -62,7 +62,7 @@ by specifying ``include_extra=True`` when calling ``to_dict()``. error.to_dict(include_extra=True) These fields are stored by both the ``GlobusAuthRequirementsError`` and -``GlobusAuthenticationParameters`` classes in an ``extra_fields`` attribute. +``GlobusAuthenticationParameters`` classes in an ``extra`` attribute. .. note:: @@ -136,4 +136,4 @@ supported field is supplied with a value of the wrong type. any logic specific to the Globus Auth service with regard to what represents a valid combination of fields (e.g., ``session_required_mfa`` requires either ``session_required_identities`` or ``session_required_single_domain`` -in order to be properly handled). \ No newline at end of file +in order to be properly handled). From ca2b0b36886e47c2a7ef282256a911b1ccfb513b Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Mon, 7 Aug 2023 14:38:09 -0500 Subject: [PATCH 5/6] Relocate validation helper --- .../auth_requirements_error/_validators.py | 12 ------------ .../auth_requirements_error/_variants.py | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/globus_sdk/experimental/auth_requirements_error/_validators.py b/src/globus_sdk/experimental/auth_requirements_error/_validators.py index 69a054cb6..04987a2af 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/_validators.py +++ b/src/globus_sdk/experimental/auth_requirements_error/_validators.py @@ -1,15 +1,9 @@ from __future__ import annotations -import sys import typing as t from ._serializable import Serializable -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - S = t.TypeVar("S", bound=Serializable) @@ -31,12 +25,6 @@ def opt_str(name: str, value: t.Any) -> str | None: return value -def consent_required_literal(name: str, value: t.Any) -> Literal["ConsentRequired"]: - if not isinstance(value, str) or value != "ConsentRequired": - raise ValidationError(f"'{name}' must be the string 'ConsentRequired'") - return t.cast(Literal["ConsentRequired"], value) - - def opt_bool(name: str, value: t.Any) -> bool | None: if value is None: return None diff --git a/src/globus_sdk/experimental/auth_requirements_error/_variants.py b/src/globus_sdk/experimental/auth_requirements_error/_variants.py index 3bf21ec40..cd81da1c9 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/_variants.py +++ b/src/globus_sdk/experimental/auth_requirements_error/_variants.py @@ -42,7 +42,7 @@ def __init__( required_scopes: list[str], extra: dict[str, t.Any] | None = None, ): - self.code = _validators.consent_required_literal("code", code) + self.code = _validate_consent_required_literal("code", code) self.required_scopes = _validators.str_list("required_scopes", required_scopes) self.extra = extra or {} @@ -73,7 +73,7 @@ def __init__( required_scope: str, extra: dict[str, t.Any] | None, ): - self.code = _validators.consent_required_literal("code", code) + self.code = _validate_consent_required_literal("code", code) self.required_scope = _validators.str_("required_scope", required_scope) self.extra = extra or {} @@ -196,3 +196,13 @@ def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: code=self.code, extra=self.extra, ) + + +def _validate_consent_required_literal( + name: str, value: t.Any +) -> Literal["ConsentRequired"]: + if not isinstance(value, str) or value != "ConsentRequired": + raise _validators.ValidationError( + f"'{name}' must be the string 'ConsentRequired'" + ) + return t.cast(Literal["ConsentRequired"], value) From 9c2fe1acd64da5dd0322e767091b1f5c595acbc2 Mon Sep 17 00:00:00 2001 From: Stephen Rosen Date: Mon, 7 Aug 2023 15:06:01 -0500 Subject: [PATCH 6/6] Invert ordering of GARE validators This makes these more uniform. --- .../auth_requirements_error/__init__.py | 2 + .../auth_requirements_error/_validators.py | 40 +++++++++---------- .../auth_requirements_error/_variants.py | 8 ++-- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/globus_sdk/experimental/auth_requirements_error/__init__.py b/src/globus_sdk/experimental/auth_requirements_error/__init__.py index 7376a9b2d..012feedf7 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/__init__.py +++ b/src/globus_sdk/experimental/auth_requirements_error/__init__.py @@ -1,3 +1,4 @@ +from ._validators import ValidationError from .auth_requirements_error import ( GlobusAuthorizationParameters, GlobusAuthRequirementsError, @@ -10,6 +11,7 @@ ) __all__ = [ + "ValidationError", "GlobusAuthRequirementsError", "GlobusAuthorizationParameters", "to_auth_requirements_error", diff --git a/src/globus_sdk/experimental/auth_requirements_error/_validators.py b/src/globus_sdk/experimental/auth_requirements_error/_validators.py index 04987a2af..c1873e62d 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/_validators.py +++ b/src/globus_sdk/experimental/auth_requirements_error/_validators.py @@ -12,39 +12,37 @@ class ValidationError(ValueError): def str_(name: str, value: t.Any) -> str: - if not isinstance(value, str): - raise ValidationError(f"'{name}' must be a string") - return value + if isinstance(value, str): + return value + raise ValidationError(f"'{name}' must be a string") def opt_str(name: str, value: t.Any) -> str | None: if value is None: return None - if not isinstance(value, str): - raise ValidationError(f"'{name}' must be a string") - return value + if isinstance(value, str): + return value + raise ValidationError(f"'{name}' must be a string or null") def opt_bool(name: str, value: t.Any) -> bool | None: - if value is None: - return None - if not isinstance(value, bool): - raise ValidationError(f"'{name}' must be a bool") - return value + if value is None or isinstance(value, bool): + return value + raise ValidationError(f"'{name}' must be a bool or null") def str_list(name: str, value: t.Any) -> list[str]: - if not (isinstance(value, list) and all(isinstance(s, str) for s in value)): - raise ValidationError(f"'{name}' must be a list of strings") - return value + if isinstance(value, list) and all(isinstance(s, str) for s in value): + return value + raise ValidationError(f"'{name}' must be a list of strings") def opt_str_list(name: str, value: t.Any) -> list[str] | None: if value is None: return None - if not (isinstance(value, list) and all(isinstance(s, str) for s in value)): - raise ValidationError(f"'{name}' must be a list of strings") - return value + if isinstance(value, list) and all(isinstance(s, str) for s in value): + return value + raise ValidationError(f"'{name}' must be a list of strings or null") def opt_str_list_or_commasep(name: str, value: t.Any) -> list[str] | None: @@ -52,9 +50,11 @@ def opt_str_list_or_commasep(name: str, value: t.Any) -> list[str] | None: return None if isinstance(value, str): value = value.split(",") - if not (isinstance(value, list) and all(isinstance(s, str) for s in value)): - raise ValidationError(f"'{name}' must be a list of strings") - return value + if isinstance(value, list) and all(isinstance(s, str) for s in value): + return value + raise ValidationError( + f"'{name}' must be a list of strings or a comma-delimited string or null" + ) def instance_or_dict(name: str, value: t.Any, cls: type[S]) -> S: diff --git a/src/globus_sdk/experimental/auth_requirements_error/_variants.py b/src/globus_sdk/experimental/auth_requirements_error/_variants.py index cd81da1c9..488b74476 100644 --- a/src/globus_sdk/experimental/auth_requirements_error/_variants.py +++ b/src/globus_sdk/experimental/auth_requirements_error/_variants.py @@ -201,8 +201,6 @@ def to_auth_requirements_error(self) -> GlobusAuthRequirementsError: def _validate_consent_required_literal( name: str, value: t.Any ) -> Literal["ConsentRequired"]: - if not isinstance(value, str) or value != "ConsentRequired": - raise _validators.ValidationError( - f"'{name}' must be the string 'ConsentRequired'" - ) - return t.cast(Literal["ConsentRequired"], value) + if value == "ConsentRequired": + return "ConsentRequired" + raise _validators.ValidationError(f"'{name}' must be the string 'ConsentRequired'")