Skip to content

Commit

Permalink
Merge pull request #391 from TeskaLabs/fix/credential-search-performance
Browse files Browse the repository at this point in the history
Fix credential search performance
  • Loading branch information
byewokko authored Jun 27, 2024
2 parents 8a4dbbd + 13683f0 commit b17cb70
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 169 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
- `v24.20-alpha1`

### Breaking changes
- Removed total item count from credential list response (#391, `v24.20-alpha12`, Compatible with Seacat Admin WebUI v24.19-alpha28 and later)
- External login endpoints changed (#384, `v24.20-alpha9`)
- Default password criteria are more restrictive (#372, `v24.20-alpha1`, Compatible with Seacat Auth Webui v24.19-alpha and later, Seacat Account Webui v24.08-beta and later)

### Fix
- Fix credential search performance (#391, `v24.20-alpha12`)
- Fix AttributeError in credentials update (#399, `v24.20-alpha11`)
- Catch token decoding errors when finding sessions (#397, `v24.20-alpha10`)
- Properly encrypt cookie value in session update (#394, `v24.20-alpha8`)
Expand Down
5 changes: 5 additions & 0 deletions seacatauth/authz/role/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,8 @@ async def get_role_tenant(self, role_id):
await self.TenantService.get_tenant(tenant)

return tenant


async def get_assigned_role(self, credentials_id: str, role_id: str):
assignment_id = "{} {}".format(credentials_id, role_id)
return await self.StorageService.get(self.CredentialsRolesCollection, assignment_id)
6 changes: 1 addition & 5 deletions seacatauth/credentials/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,6 @@ async def list_credentials(self, request):
elif mode == "default":
search.SimpleFilter = request.query.get("f")

if len(search.AdvancedFilter) > 1:
raise asab.exceptions.ValidationError("No more than one advanced filter at a time is supported.")

try_global_search = asab.utils.string_to_boolean(request.query.get("global", "false"))

try:
Expand All @@ -215,9 +212,8 @@ async def list_credentials(self, request):
"result": "ACCESS-DENIED",
})
return asab.web.rest.json_response(request, {
"data": result["data"],
"count": result["count"],
"result": "OK",
**result
})


Expand Down
2 changes: 2 additions & 0 deletions seacatauth/credentials/providers/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ async def iterate(self, offset: int = 0, limit: int = -1, filtr: str = None):
if limit >= 0:
cursor.limit(limit)

cursor.sort("username", 1)

async for d in cursor:
yield self._normalize_credentials(d)

Expand Down
281 changes: 120 additions & 161 deletions seacatauth/credentials/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,194 +179,89 @@ async def get_by(self, key: str, value):
break
return credentials

async def iterate(self):

async def iterate(self, filter: typing.Optional[str] = None):
"""
This iterates over all providers and combines their results
Iterate over all providers and combine their results.
Fully asynchronous, but does not preserve the order of documents.
"""
pending = [provider.iterate() for provider in self.CredentialProviders.values()]
pending = [provider.iterate(filtr=filter) for provider in self.CredentialProviders.values()]
pending_tasks = {
asyncio.ensure_future(g.__anext__()): g for g in pending
}
while len(pending_tasks) > 0:
done, _ = await asyncio.wait(pending_tasks.keys(), return_when="FIRST_COMPLETED")
for d in done:
dg = pending_tasks.pop(d)
tasks_done, _ = await asyncio.wait(pending_tasks.keys(), return_when="FIRST_COMPLETED")
for task in tasks_done:
provider_generator = pending_tasks.pop(task)

try:
r = await d
credentials_data = await task
except StopAsyncIteration:
continue

pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
yield r
pending_tasks[asyncio.ensure_future(provider_generator.__anext__())] = provider_generator
yield credentials_data


async def list(self, session: SessionAdapter, search_params: generic.SearchParams, try_global_search: bool = False):
async def iterate_stable(self, offset: int = 0, filter: typing.Optional[str] = None):
"""
List credentials that are members of currently authorized tenants.
Global_search lists all credentials, regardless of tenants, but this requires superuser authorization.
Iterate over all providers and combine their results.
Preserves the order of results.
"""
# TODO: Searching with filters is very inefficient and needs serious optimization
if len(search_params.AdvancedFilter) > 1:
raise asab.exceptions.ValidationError("No more than one advanced filter at a time is supported.")

if try_global_search and not session.is_superuser():
# Return only tenant members
L.info("Not authorized to list credentials across all tenants", struct_data={
"cid": session.Credentials.Id, "resource_id": "authz:superuser"})
try_global_search = False

authorized_tenants = [tenant for tenant in session.Authorization.Authz if tenant != "*"]

# Authorize searched tenants
if "tenant" in search_params.AdvancedFilter:
# Search only requested tenant
tenant_id = search_params.AdvancedFilter["tenant"]
# Check tenant access
if not (tenant_id in authorized_tenants or session.is_superuser()):
raise exceptions.AccessDeniedError(
"Not authorized to access tenant members",
subject=session.Credentials.Id,
resource={"tenant_id": tenant_id}
)
searched_tenants = [tenant_id]
elif try_global_search:
# Search all credentials, ignore tenants
searched_tenants = None
else:
# Search currently authorized tenants
searched_tenants = authorized_tenants

if "role" in search_params.AdvancedFilter:
# Authorize searched roles
role_id = search_params.AdvancedFilter["role"]
tenant_id = role_id.split("/")[0]
# Check tenant access
# - global role is always accessible
# - role in my authorized tenants is accessible
# - superuser can access anything
if not (tenant_id == "*" or tenant_id in authorized_tenants or session.is_superuser()):
raise exceptions.AccessDeniedError(
"Not authorized to access tenant members",
subject=session.Credentials.Id,
resource={"tenant_id": tenant_id}
)
searched_roles = [role_id]
else:
# Do not filter by roles
searched_roles = None

# Get credential IDs that match both the tenant filter and the role filter
filtered_cids = None
estimated_count = None
if searched_roles:
role_svc = self.App.get_service("seacatauth.RoleService")
assignments = await role_svc.list_role_assignments(role_id=searched_roles)
if filtered_cids is None:
filtered_cids = set(a["c"] for a in assignments["data"])
estimated_count = assignments["count"]
else:
filtered_cids.intersection_update(a["c"] for a in assignments["data"])
estimated_count = min(estimated_count, assignments["count"])
if searched_tenants:
tenant_svc = self.App.get_service("seacatauth.TenantService")
provider = tenant_svc.get_provider()
assignments = await provider.list_tenant_assignments(searched_tenants)
if filtered_cids is None:
filtered_cids = set(a["c"] for a in assignments["data"])
estimated_count = assignments["count"]
else:
filtered_cids.intersection_update(a["c"] for a in assignments["data"])
estimated_count = min(estimated_count, assignments["count"])

if filtered_cids is not None:
if len(filtered_cids) == 0:
return {"count": 0, "data": []}

credentials = []
filtered_cids = sorted(filtered_cids)

offset = search_params.Page * search_params.ItemsPerPage
for cid in filtered_cids:
_, provider_id, _ = cid.split(":", 2)
try:
provider = self.CredentialProviders[provider_id]
cred_data = await provider.get(cid)
except KeyError:
L.info("Found an assignment of nonexisting credentials", struct_data={
"cid": cid, "role_ids": searched_roles, "tenant_ids": searched_tenants})
for provider in self.CredentialProviders.values():
async for credobj in provider.iterate(offset=offset, filtr=filter):
if offset > 0:
offset -= 1
continue
if not search_params.SimpleFilter or (
cred_data.get("username", "").startswith(search_params.SimpleFilter)
or cred_data.get("email", "").startswith(search_params.SimpleFilter)
):
if offset > 0:
# Skip until offset is reached
offset -= 1
continue
credentials.append(cred_data)
if len(credentials) >= search_params.ItemsPerPage:
# Page is full
break
yield credobj

return {
"data": credentials,
"count": estimated_count,
}

# Search without external filters
return await self._list(search_params)
async def _tenant_filter(self, credentials_id: str, tenant_ids: typing.Iterable):
if tenant_ids is None:
return True
tenant_svc = self.App.get_service("seacatauth.TenantService")
for tenant_id in tenant_ids:
try:
await tenant_svc.get_assigned_tenant(credentials_id, tenant_id)
return True
except KeyError:
return False

async def _role_filter(self, credentials_id: str, role_ids: typing.Iterable):
if role_ids is None:
return True
role_svc = self.App.get_service("seacatauth.RoleService")
for role_id in role_ids:
try:
await role_svc.get_assigned_role(credentials_id, role_id)
return True
except KeyError:
return False


async def _list(self, search_params: generic.SearchParams):
async def list(self, session: SessionAdapter, search_params: generic.SearchParams, try_global_search: bool = False):
"""
List credentials
List credentials that are members of currently authorized tenants.
Global_search lists all credentials, regardless of tenants, but this requires superuser authorization.
"""
provider_stack = []
total_count = 0 # If -1, then total count cannot be determined
for provider in self.CredentialProviders.values():
try:
count = await provider.count(filtr=search_params.SimpleFilter)
except Exception as e:
L.exception("Exception when getting count from a credentials provider: {}".format(e))
continue

provider_stack.append((count, provider))
if count >= 0 and total_count >= 0:
total_count += count
else:
total_count = -1
searched_tenants = _authorize_searched_tenants(session, search_params, try_global_search)
searched_roles = _authorize_searched_roles(session, search_params)

credentials = []
offset = search_params.Page * search_params.ItemsPerPage
remaining_items = search_params.ItemsPerPage
for count, provider in provider_stack:
if count >= 0:
if offset > count:
# The offset is beyond the count of the provider, skip to the next one
offset -= count
continue

async for credobj in provider.iterate(
offset=offset, limit=remaining_items, filtr=search_params.SimpleFilter
):
credentials.append(credobj)
remaining_items -= 1

if remaining_items <= 0:
break

offset = 0

else:
# TODO: Uncountable branch
L.error("Not implemented: Uncountable branch.", struct_data={"provider_id": provider.ProviderID})
async for credentials_data in self.iterate_stable(filter=search_params.SimpleFilter):
if not await self._role_filter(credentials_data["_id"], searched_roles):
continue
if not await self._tenant_filter(credentials_data["_id"], searched_tenants):
continue
if offset > 0:
offset -= 1
continue
credentials.append(credentials_data)
if len(credentials) >= search_params.ItemsPerPage:
break

return {
"data": credentials,
"count": total_count,
}
return {"data": credentials}


def get_provider(self, credentials_id):
Expand Down Expand Up @@ -746,3 +641,67 @@ async def can_access_credentials(self, session, credentials_id: str) -> bool:
return True
# The request and the target credentials have no tenant in common
return False


def _authorize_searched_tenants(
session: SessionAdapter,
search_params: generic.SearchParams,
try_global_search: bool = False
) -> typing.Optional[typing.Iterable[str]]:
"""
Authorize and return a list of tenants to filter by.
"""
if not session.is_superuser():
# Return only tenant members
try_global_search = False

authorized_tenants = [tenant for tenant in session.Authorization.Authz if tenant != "*"]

# Authorize searched tenants
if "tenant" in search_params.AdvancedFilter:
# Search only requested tenant
tenant_id = search_params.AdvancedFilter["tenant"]
# Check tenant access
if not (tenant_id in authorized_tenants or session.is_superuser()):
raise exceptions.AccessDeniedError(
"Not authorized to access tenant members",
subject=session.Credentials.Id,
resource={"tenant_id": tenant_id}
)
searched_tenants = [tenant_id]
elif try_global_search:
# Search all credentials, ignore tenants
searched_tenants = None
else:
# Search currently authorized tenants
searched_tenants = authorized_tenants

return searched_tenants


def _authorize_searched_roles(
session: SessionAdapter,
search_params: generic.SearchParams
) -> typing.Optional[typing.Iterable[str]]:
"""
Authorize and return a list of roles to filter by.
"""
authorized_tenants = [tenant for tenant in session.Authorization.Authz if tenant != "*"]
role_id = search_params.AdvancedFilter.get("role")
if not role_id:
return None

role_id = search_params.AdvancedFilter["role"]
tenant_id = role_id.split("/")[0]

# Check tenant access
# - global role is always accessible
# - role in my authorized tenants is accessible
# - superuser can access anything
if not (tenant_id == "*" or tenant_id in authorized_tenants or session.is_superuser()):
raise exceptions.AccessDeniedError(
"Not authorized to access role.",
subject=session.Credentials.Id,
resource={"role_id": role_id}
)
return [role_id]
2 changes: 1 addition & 1 deletion seacatauth/tenant/providers/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def delete_tenant_assignments(self, tenant):
})


async def get_assignment(self, credatials_id: str, tenant: str):
async def get_assignment(self, credatials_id: str, tenant: str) -> dict:
collection = await self.MongoDBStorageService.collection(self.AssignCollection)
query = {"c": credatials_id, "t": tenant}
result = await collection.find_one(query)
Expand Down
8 changes: 6 additions & 2 deletions seacatauth/tenant/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,13 @@ async def get_tenants_by_scope(self, scope: list, credential_id: str, has_access
return tenants


async def has_tenant_assigned(self, credatials_id: str, tenant: str):
async def has_tenant_assigned(self, credatials_id: str, tenant: str) -> bool:
try:
await self.TenantsProvider.get_assignment(credatials_id, tenant)
await self.get_assigned_tenant(credatials_id, tenant)
except KeyError:
return False
return True


async def get_assigned_tenant(self, credatials_id: str, tenant: str) -> dict:
return await self.TenantsProvider.get_assignment(credatials_id, tenant)

0 comments on commit b17cb70

Please sign in to comment.