diff --git a/permit/api/role_assignments.py b/permit/api/role_assignments.py index 09223a9..863c675 100644 --- a/permit/api/role_assignments.py +++ b/permit/api/role_assignments.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union from ..utils.pydantic_version import PYDANTIC_VERSION @@ -44,7 +44,7 @@ def __role_assignments(self) -> SimpleHttpClient: async def list( self, user_key: Optional[str] = None, - role_key: Optional[str] = None, + role_key: Optional[Union[str, List[str]]] = None, tenant_key: Optional[str] = None, resource_instance_key: Optional[str] = None, page: int = 1, @@ -68,15 +68,19 @@ async def list( PermitApiError: If the API returns an error HTTP status code. PermitContextError: If the configured ApiContext does not match the required endpoint context. """ - params = pagination_params(page, per_page) + params = list(pagination_params(page, per_page).items()) if user_key is not None: - params.update(dict(user=user_key)) + params.append(("user", user_key)) if role_key is not None: - params.update(dict(role=role_key)) + if isinstance(role_key, list): + for role in role_key: + params.append(("role", role)) + else: + params.append(("role", role_key)) if tenant_key is not None: - params.update(dict(tenant=tenant_key)) + params.append(("tenant", tenant_key)) if resource_instance_key is not None: - params.update(dict(resource_instance=resource_instance_key)) + params.append(("resource_instance", resource_instance_key)) return await self.__role_assignments.get( "", model=List[RoleAssignmentRead], diff --git a/tests/endpoints/test_role_assignments.py b/tests/endpoints/test_role_assignments.py new file mode 100644 index 0000000..babe8b4 --- /dev/null +++ b/tests/endpoints/test_role_assignments.py @@ -0,0 +1,49 @@ +from contextlib import contextmanager + +from permit import Permit, PermitApiError, RoleAssignmentCreate, RoleCreate, UserCreate + + +@contextmanager +def suppress_409(): + try: + yield + except PermitApiError as e: + if e.status_code != 409: + raise e + + +async def create_role_assignments(permit: Permit, role_key: str, user_count: int = 10): + with suppress_409(): + await permit.api.roles.create(RoleCreate(key=role_key, name=role_key)) + with suppress_409(): + await permit.api.users.bulk_create( + [UserCreate(key=f"user-{index}") for index in range(user_count)] + ) + with suppress_409(): + await permit.api.role_assignments.bulk_assign( + [ + RoleAssignmentCreate( + role=role_key, user=f"user-{index}", tenant="default" + ) + for index in range(user_count) + ] + ) + + +async def test_list_filter_by_role(permit: Permit): + await create_role_assignments(permit, "role-1") + await create_role_assignments(permit, "role-2") + role_assignments = await permit.api.role_assignments.list(role_key="role-1") + assert len(role_assignments) == 10 + assert {ra.role for ra in role_assignments} == {"role-1"} + + +async def test_list_filter_by_role_multiple(permit: Permit): + await create_role_assignments(permit, "role-1") + await create_role_assignments(permit, "role-2") + await create_role_assignments(permit, "role-3") + role_assignments = await permit.api.role_assignments.list( + role_key=["role-1", "role-2"] + ) + assert len(role_assignments) == 20 + assert {ra.role for ra in role_assignments} == {"role-1", "role-2"}