Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify GlobusAuthRequiremenetsError validation #795

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 82 additions & 137 deletions src/globus_sdk/experimental/auth_requirements_error/_variants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
import typing as t

from . import validators
Expand All @@ -8,17 +9,24 @@
GlobusAuthRequirementsError,
)

if sys.version_info >= (3, 8):
from typing import Literal, Protocol
else:
from typing_extensions import Literal, Protocol

T = t.TypeVar("T", bound="LegacyAuthRequirementsErrorVariant")


class LegacyAuthRequirementsErrorVariant:
class HasSupportedFields(Protocol):
SUPPORTED_FIELDS: t.ClassVar[set[str]]


class LegacyAuthRequirementsErrorVariant(HasSupportedFields):
"""
Abstract base class for errors which can be converted to a
Globus Auth Requirements Error.
"""

SUPPORTED_FIELDS: dict[str, t.Callable[[t.Any], t.Any]] = {}

@classmethod
def from_dict(cls: t.Type[T], error_dict: dict[str, t.Any]) -> T:
"""
Expand All @@ -31,7 +39,7 @@ def from_dict(cls: t.Type[T], error_dict: dict[str, t.Any]) -> T:
extras = {k: v for k, v in error_dict.items() if k not in cls.SUPPORTED_FIELDS}
kwargs: dict[str, t.Any] = {"extra": extras}
# Ensure required fields are supplied
for field_name in cls.SUPPORTED_FIELDS.keys():
for field_name in cls.SUPPORTED_FIELDS:
kwargs[field_name] = error_dict.get(field_name)

return cls(**kwargs)
Expand All @@ -45,32 +53,19 @@ class LegacyConsentRequiredTransferError(LegacyAuthRequirementsErrorVariant):
The ConsentRequired error format emitted by the Globus Transfer service.
"""

code: str
required_scopes: list[str]
extra_fields: dict[str, t.Any]

SUPPORTED_FIELDS = {
"code": validators.StringLiteral("ConsentRequired"),
"required_scopes": validators.ListOfStrings,
}

def __init__(
self,
*,
code: str | None,
code: Literal["ConsentRequired"],
required_scopes: list[str] | None,
extra: dict[str, t.Any] | None = None,
): # pylint: disable=unused-argument
# Validate and assign supported fields
for field_name, validator in self.SUPPORTED_FIELDS.items():
try:
field_value = validator(locals()[field_name])
except ValueError as e:
raise ValueError(f"Error validating field '{field_name}': {e}") from e
):
self.code = _consent_required_validator("code")
self.required_scopes = validators.ListOfStrings("required_scopes")

setattr(self, field_name, field_value)
self.extra = extra or {}

self.extra_fields = extra or {}
SUPPORTED_FIELDS: set[str] = validators.derive_supported_fields(__init__)

def to_auth_requirements_error(self) -> GlobusAuthRequirementsError:
"""
Expand All @@ -80,9 +75,9 @@ def to_auth_requirements_error(self) -> GlobusAuthRequirementsError:
code=self.code,
authorization_parameters=GlobusAuthorizationParameters(
required_scopes=self.required_scopes,
session_message=self.extra_fields.get("message"),
session_message=self.extra.get("message"),
),
extra=self.extra_fields,
extra=self.extra,
)


Expand All @@ -92,32 +87,18 @@ class LegacyConsentRequiredAPError(LegacyAuthRequirementsErrorVariant):
Action Providers.
"""

code: str
required_scope: str
extra_fields: dict[str, t.Any]

SUPPORTED_FIELDS = {
"code": validators.StringLiteral("ConsentRequired"),
"required_scope": validators.String,
}

def __init__(
self,
*,
code: str | None,
required_scope: str | None,
code: Literal["ConsentRequired"],
required_scope: str,
extra: dict[str, t.Any] | None,
): # pylint: disable=unused-argument
# Validate and assign supported fields
for field_name, validator in self.SUPPORTED_FIELDS.items():
try:
field_value = validator(locals()[field_name])
except ValueError as e:
raise ValueError(f"Error validating field '{field_name}': {e}") from e

setattr(self, field_name, field_value)
):
self.code = _consent_required_validator("code")
self.required_scope = validators.String("required_scope")
self.extra = extra or {}

self.extra_fields = extra or {}
SUPPORTED_FIELDS: set[str] = validators.derive_supported_fields(__init__)

def to_auth_requirements_error(self) -> GlobusAuthRequirementsError:
"""
Expand All @@ -130,13 +111,11 @@ def to_auth_requirements_error(self) -> GlobusAuthRequirementsError:
code=self.code,
authorization_parameters=GlobusAuthorizationParameters(
required_scopes=[self.required_scope],
session_message=self.extra_fields.get("description"),
extra=self.extra_fields.get("authorization_parameters"),
session_message=self.extra.get("description"),
extra=self.extra.get("authorization_parameters"),
),
extra={
k: v
for k, v in self.extra_fields.items()
if k != "authorization_parameters"
k: v for k, v in self.extra.items() if k != "authorization_parameters"
},
)

Expand All @@ -147,29 +126,6 @@ class LegacyAuthorizationParameters:
Globus services.
"""

session_message: str | None
session_required_identities: list[str] | None
session_required_policies: str | list[str] | None
session_required_single_domain: str | list[str] | None
session_required_mfa: bool | None
# Declared here for compatibility with mixed legacy payloads
required_scopes: list[str] | None
extra_fields: dict[str, t.Any]

DEFAULT_CODE = "AuthorizationRequired"

SUPPORTED_FIELDS = {
"session_message": validators.OptionalString,
"session_required_identities": validators.OptionalListOfStrings,
"session_required_policies": (
validators.OptionalListOfStringsOrCommaDelimitedStrings
),
"session_required_single_domain": (
validators.OptionalListOfStringsOrCommaDelimitedStrings
),
"session_required_mfa": validators.OptionalBool,
}

def __init__(
self,
*,
Expand All @@ -179,55 +135,53 @@ def __init__(
session_required_single_domain: str | list[str] | None = None,
session_required_mfa: bool | None = None,
extra: dict[str, t.Any] | None = None,
): # pylint: disable=unused-argument
# Validate and assign supported fields
for field_name, validator in self.SUPPORTED_FIELDS.items():
try:
field_value = validator(locals()[field_name])
except ValueError as e:
raise ValueError(f"Error validating field '{field_name}': {e}") from e

setattr(self, field_name, field_value)

# Retain any additional fields
self.extra_fields = extra or {}

# Enforce that the error contains at least one of the fields we expect
if not any(
(getattr(self, field_name) is not None)
for field_name in self.SUPPORTED_FIELDS.keys()
):
raise ValueError(
"Must include at least one supported authorization parameter: "
", ".join(self.SUPPORTED_FIELDS.keys())
)
):
self.session_message = validators.OptionalString("session_message")
self.session_required_identities = validators.OptionalListOfStrings(
"session_required_identities"
)
# note the types on these two for clarity; although they should be
# inferred correctly by most type checkers
#
# because the validator returns a list[str] from any input string,
# the type of the instance variables is narrower than the accepted
# type for the relevant __init__ parameters
self.session_required_policies: (
list[str] | None
) = validators.OptionalListOfStringsOrCommaDelimitedStrings(
"session_required_policies"
)
self.session_required_single_domain: (
list[str] | None
) = validators.OptionalListOfStringsOrCommaDelimitedStrings(
"session_required_single_domain"
)
self.session_required_mfa = validators.OptionalBool("session_required_mfa")
self.extra = extra or {}

def to_authorization_parameters(
self,
) -> GlobusAuthorizationParameters:
validators.require_at_least_one_field(
self,
[f for f in self.SUPPORTED_FIELDS if f != "session_message"],
"supported authorization parameter",
)

SUPPORTED_FIELDS: set[str] = validators.derive_supported_fields(__init__)

def to_authorization_parameters(self) -> GlobusAuthorizationParameters:
"""
Return a normalized GlobusAuthorizationParameters instance representing
these parameters.

Normalizes fields that may have been provided
as comma-delimited strings to lists of strings.
"""
required_policies = self.session_required_policies
if isinstance(required_policies, str):
required_policies = required_policies.split(",")

# TODO: Unnecessary?
required_single_domain = self.session_required_single_domain
if isinstance(required_single_domain, str):
required_single_domain = required_single_domain.split(",")

return GlobusAuthorizationParameters(
session_message=self.session_message,
session_required_identities=self.session_required_identities,
session_required_mfa=self.session_required_mfa,
session_required_policies=required_policies,
session_required_single_domain=required_single_domain,
extra=self.extra_fields,
session_required_policies=self.session_required_policies,
session_required_single_domain=self.session_required_single_domain,
extra=self.extra,
)

@classmethod
Expand All @@ -238,12 +192,11 @@ def from_dict(cls, param_dict: dict[str, t.Any]) -> LegacyAuthorizationParameter
:param param_dict: The dictionary to create the AuthorizationParameters from.
:type param_dict: dict
"""

# Extract any extra fields
extras = {k: v for k, v in param_dict.items() if k not in cls.SUPPORTED_FIELDS}
kwargs: dict[str, t.Any] = {"extra": extras}
# Ensure required fields are supplied
for field_name in cls.SUPPORTED_FIELDS.keys():
for field_name in cls.SUPPORTED_FIELDS:
kwargs[field_name] = param_dict.get(field_name)

return cls(**kwargs)
Expand All @@ -255,18 +208,10 @@ class LegacyAuthorizationParametersError(LegacyAuthRequirementsErrorVariant):
in use by Globus services.
"""

authorization_parameters: LegacyAuthorizationParameters
code: str
extra_fields: dict[str, t.Any]

DEFAULT_CODE = "AuthorizationRequired"

SUPPORTED_FIELDS = {
"code": validators.String,
"authorization_parameters": validators.ClassInstance(
LegacyAuthorizationParameters
),
}
_authz_param_validator: validators.IsInstance[
LegacyAuthorizationParameters
] = validators.IsInstance(LegacyAuthorizationParameters)

def __init__(
self,
Expand All @@ -278,25 +223,19 @@ def __init__(
extra: dict[str, t.Any] | None = None,
):
# Apply default, if necessary
code = code or self.DEFAULT_CODE
self.code: str = validators.String("code", code or self.DEFAULT_CODE)

# Convert authorization_parameters if it's a dict
if isinstance(authorization_parameters, dict):
authorization_parameters = LegacyAuthorizationParameters.from_dict(
param_dict=authorization_parameters
)
self.authorization_parameters = self._authz_param_validator(
"authorization_parameters"
)
self.extra = extra or {}

# Validate and assign supported fields
for field_name, validator in self.SUPPORTED_FIELDS.items():
try:
field_value = validator(locals()[field_name])
except ValueError as e:
raise ValueError(f"Error validating field '{field_name}': {e}") from e

setattr(self, field_name, field_value)

# Retain any additional fields
self.extra_fields = extra or {}
SUPPORTED_FIELDS: set[str] = validators.derive_supported_fields(__init__)

def to_auth_requirements_error(self) -> GlobusAuthRequirementsError:
"""
Expand All @@ -308,5 +247,11 @@ def to_auth_requirements_error(self) -> GlobusAuthRequirementsError:
return GlobusAuthRequirementsError(
authorization_parameters=authorization_parameters,
code=self.code,
extra=self.extra_fields,
extra=self.extra,
)


# construct with an explicit type to get the correct type for the validator
_consent_required_validator: validators.Validator[
Literal["ConsentRequired"]
] = validators.StringLiteral("ConsentRequired")
Loading
Loading