Skip to content

Commit

Permalink
Merge pull request #796 from sirosen/spartan-gare
Browse files Browse the repository at this point in the history
Make GARE validation as simple as possible
  • Loading branch information
sirosen authored Aug 7, 2023
2 parents 5562209 + 9c2fe1a commit ab6029e
Show file tree
Hide file tree
Showing 9 changed files with 271 additions and 461 deletions.
4 changes: 2 additions & 2 deletions docs/experimental/auth_requirements_errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ by specifying ``include_extra=True`` when calling ``to_dict()``.
error.to_dict(include_extra=True)
These fields are stored by both the ``GlobusAuthRequirementsError`` and
``GlobusAuthenticationParameters`` classes in an ``extra_fields`` attribute.
``GlobusAuthenticationParameters`` classes in an ``extra`` attribute.

.. note::

Expand Down Expand Up @@ -136,4 +136,4 @@ supported field is supplied with a value of the wrong type.
any logic specific to the Globus Auth service with regard to what represents a valid
combination of fields (e.g., ``session_required_mfa`` requires either
``session_required_identities`` or ``session_required_single_domain``
in order to be properly handled).
in order to be properly handled).
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._validators import ValidationError
from .auth_requirements_error import (
GlobusAuthorizationParameters,
GlobusAuthRequirementsError,
Expand All @@ -10,6 +11,7 @@
)

__all__ = [
"ValidationError",
"GlobusAuthRequirementsError",
"GlobusAuthorizationParameters",
"to_auth_requirements_error",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

import inspect
import typing as t

T = t.TypeVar("T", bound="Serializable")


class Serializable:
_EXCLUDE_VARS: t.ClassVar[tuple[str, ...]] = ("self", "extra")
extra: dict[str, t.Any]

@classmethod
def _supported_fields(cls) -> list[str]:
signature = inspect.signature(cls.__init__)
return [
name
for name in signature.parameters.keys()
if name not in cls._EXCLUDE_VARS
]

@classmethod
def from_dict(cls: type[T], data: dict[str, t.Any]) -> T:
"""
Instantiate from a dictionary.
:param data: The dictionary to create the error from.
:type data: dict
"""

# Extract any extra fields
extras = {k: v for k, v in data.items() if k not in cls._supported_fields()}
kwargs: dict[str, t.Any] = {"extra": extras}
# Ensure required fields are supplied
for field_name in cls._supported_fields():
kwargs[field_name] = data.get(field_name)

return cls(**kwargs)

def to_dict(self, include_extra: bool = False) -> dict[str, t.Any]:
"""
Render to a dictionary.
:param include_extra: Whether to include stored extra (non-standard) fields in
the returned dictionary.
:type include_extra: bool
"""
result = {}

# Set any authorization parameters
for field in self._supported_fields():
value = getattr(self, field)
if value is not None:
if isinstance(value, Serializable):
value = value.to_dict(include_extra=include_extra)
result[field] = value

# Set any extra fields
if include_extra:
result.update(self.extra)

return result
65 changes: 65 additions & 0 deletions src/globus_sdk/experimental/auth_requirements_error/_validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

import typing as t

from ._serializable import Serializable

S = t.TypeVar("S", bound=Serializable)


class ValidationError(ValueError):
pass


def str_(name: str, value: t.Any) -> str:
if isinstance(value, str):
return value
raise ValidationError(f"'{name}' must be a string")


def opt_str(name: str, value: t.Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
raise ValidationError(f"'{name}' must be a string or null")


def opt_bool(name: str, value: t.Any) -> bool | None:
if value is None or isinstance(value, bool):
return value
raise ValidationError(f"'{name}' must be a bool or null")


def str_list(name: str, value: t.Any) -> list[str]:
if isinstance(value, list) and all(isinstance(s, str) for s in value):
return value
raise ValidationError(f"'{name}' must be a list of strings")


def opt_str_list(name: str, value: t.Any) -> list[str] | None:
if value is None:
return None
if isinstance(value, list) and all(isinstance(s, str) for s in value):
return value
raise ValidationError(f"'{name}' must be a list of strings or null")


def opt_str_list_or_commasep(name: str, value: t.Any) -> list[str] | None:
if value is None:
return None
if isinstance(value, str):
value = value.split(",")
if isinstance(value, list) and all(isinstance(s, str) for s in value):
return value
raise ValidationError(
f"'{name}' must be a list of strings or a comma-delimited string or null"
)


def instance_or_dict(name: str, value: t.Any, cls: type[S]) -> S:
if isinstance(value, cls):
return value
if isinstance(value, dict):
return cls.from_dict(value)
raise ValidationError(f"'{name}' must be a '{cls.__name__}' object or a dictionary")
Loading

0 comments on commit ab6029e

Please sign in to comment.