Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support session params in oauth2_get_authorize_url #773

Merged
merged 1 commit into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Added
~~~~~

- ``AuthClient.oauth2_get_authorize_url`` now supports the following parameters
for session management: ``session_required_identities``,
``session_required_single_domain``, and ``session_required_policies``. Each
of these accept list inputs, as returned by
``ErrorInfo.authorization_parameters``. (:pr:`NUMBER`)
30 changes: 29 additions & 1 deletion src/globus_sdk/services/auth/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,27 @@ def get_projects(self) -> IterableResponse:
return GetProjectsResponse(self.get("/v2/api/projects"))

def oauth2_get_authorize_url(
self, *, query_params: dict[str, t.Any] | None = None
self,
*,
session_required_identities: UUIDLike | t.Iterable[UUIDLike] | None = None,
session_required_single_domain: str | t.Iterable[str] | None = None,
session_required_policies: UUIDLike | t.Iterable[UUIDLike] | None = None,
query_params: dict[str, t.Any] | None = None,
) -> str:
"""
Get the authorization URL to which users should be sent.
This method may only be called after ``oauth2_start_flow``
has been called on this ``AuthClient``.

:param session_required_identities: A list of identities must be
added to the session.
:type session_required_identities: str or uuid or list of str or uuid, optional
:param session_required_single_domain: A list of domain requirements
which must be satisfied by identities added to the session.
:type session_required_single_domain: str or list of str, optional
:param session_required_policies: A list of IDs for policies which must
be satisfied by the user.
:type session_required_policies: str or uuid or list of str or uuid, optional
:param query_params: Additional query parameters to include in the
authorize URL. Primarily for internal use
:type query_params: dict, optional
Expand All @@ -390,6 +404,20 @@ def oauth2_get_authorize_url(
"Call the oauth2_start_flow() method on this "
"AuthClient to resolve"
)
if query_params is None:
query_params = {}
if session_required_identities is not None:
query_params["session_required_identities"] = _commasep(
session_required_identities
)
if session_required_single_domain is not None:
query_params["session_required_single_domain"] = _commasep(
session_required_single_domain
)
if session_required_policies is not None:
query_params["session_required_policies"] = _commasep(
session_required_policies
)
auth_url = self.current_oauth2_flow_manager.get_authorize_url(
query_params=query_params
)
Expand Down
130 changes: 130 additions & 0 deletions tests/functional/services/auth/test_auth_client_flow.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import urllib.parse
import uuid

import pytest

Expand All @@ -18,6 +19,135 @@ class CustomAuthClient(globus_sdk.AuthClient):
return CustomAuthClient(client_id=CLIENT_ID)


@pytest.fixture
def native_client(no_retry_transport):
class CustomAuthClient(globus_sdk.NativeAppAuthClient):
transport_class = no_retry_transport

return CustomAuthClient(client_id=CLIENT_ID)


@pytest.fixture
def confidential_client(no_retry_transport):
class CustomAuthClient(globus_sdk.ConfidentialAppAuthClient):
transport_class = no_retry_transport

return CustomAuthClient(
client_id=CLIENT_ID, client_secret="SECRET_SECRET_HES_GOT_A_SECRET"
)


# build a partial matrix over
#
# domain: str | list[str]
# identities: uuid | str | list[uuid] | list[str] | list[str | uuid]
# policies: uuid | str | list[uuid] | list[str] | list[str | uuid]
#
# start by seeding a small set of combinations of params to test, then run some loops to
# fill in the rest
_ALL_SESSION_PARAM_COMBINATIONS = [
(None, None, None),
("foo-id", "bar-id", "baz-id"),
(["foo-id", "bar-id"], None, ["quux-id", "snork-id"]),
(None, ["foo-id", "bar-id"], ["quux-id", "snork-id"]),
]
for domain_option in (None, "example.edu", ["example.edu", "example.org"]):
_ALL_SESSION_PARAM_COMBINATIONS.append((domain_option, None, None))
for identity_option in (
None,
uuid.UUID(int=0),
"foo-id",
[uuid.UUID(int=0), uuid.UUID(int=1)],
["foo-id", "bar-id"],
["foo-id", uuid.UUID(int=2)],
):
_ALL_SESSION_PARAM_COMBINATIONS.append((None, identity_option, None))
for policy_option in (
None,
uuid.UUID(int=3),
"baz-id",
[uuid.UUID(int=3), uuid.UUID(int=4)],
["baz-id", "quux-id"],
["baz-id", uuid.UUID(int=5)],
):
_ALL_SESSION_PARAM_COMBINATIONS.append((None, None, policy_option))


@pytest.mark.parametrize("flow_type", ("native_app", "confidential_app"))
# parametrize over both what is and what *is not* passed as a parameter
@pytest.mark.parametrize(
"domain_option, identity_option, policy_option", _ALL_SESSION_PARAM_COMBINATIONS
)
def test_oauth2_get_authorize_url_supports_session_params(
native_client,
confidential_client,
flow_type,
domain_option,
identity_option,
policy_option,
):
if flow_type == "native_app":
client = native_client
elif flow_type == "confidential_app":
client = confidential_client
else:
raise NotImplementedError

# get the url...
client.oauth2_start_flow(redirect_uri="https://example.com", requested_scopes="foo")
url_res = client.oauth2_get_authorize_url(
session_required_single_domain=domain_option,
session_required_identities=identity_option,
session_required_policies=policy_option,
)

# parse the result..
parsed_url = urllib.parse.urlparse(url_res)
parsed_params = urllib.parse.parse_qs(parsed_url.query)

# prepare some helper data...
expected_params_keys = {
"session_required_single_domain" if domain_option else None,
"session_required_identities" if identity_option else None,
"session_required_policies" if policy_option else None,
}
expected_params_keys.discard(None)
unexpected_query_params = {
"session_required_single_domain",
"session_required_identities",
"session_required_policies",
} - expected_params_keys
parsed_params_keys = set(parsed_params.keys())

# ...and validate!
assert expected_params_keys <= parsed_params_keys
assert (unexpected_query_params - parsed_params_keys) == unexpected_query_params

if domain_option is not None:
strized_option = (
",".join(str(x) for x in domain_option)
if isinstance(domain_option, list)
else str(domain_option)
)
assert parsed_params["session_required_single_domain"] == [strized_option]

if identity_option is not None:
strized_option = (
",".join(str(x) for x in identity_option)
if isinstance(identity_option, list)
else str(identity_option)
)
assert parsed_params["session_required_identities"] == [strized_option]

if policy_option is not None:
strized_option = (
",".join(str(x) for x in policy_option)
if isinstance(policy_option, list)
else str(policy_option)
)
assert parsed_params["session_required_policies"] == [strized_option]


def test_oauth2_get_authorize_url_native_defaults(client):
# default parameters for starting auth flow
# should warn because scopes were not specified
Expand Down