diff --git a/conda-store-server/conda_store_server/_internal/schema.py b/conda-store-server/conda_store_server/_internal/schema.py index 32603dfac..679ddd250 100644 --- a/conda-store-server/conda_store_server/_internal/schema.py +++ b/conda-store-server/conda_store_server/_internal/schema.py @@ -498,6 +498,14 @@ class APIPaginatedResponse(APIResponse): count: int +class APICursorPaginatedResponse(BaseModel): + data: Optional[Any] = None + status: APIStatus + message: Optional[str] = None + cursor: Optional[str] = None + count: int + + class APIAckResponse(BaseModel): status: APIStatus message: Optional[str] = None @@ -562,8 +570,8 @@ class APIDeleteNamespaceRole(BaseModel): # GET /api/v1/environment -class APIListEnvironment(APIPaginatedResponse): - data: List[Environment] +class APIListEnvironment(APICursorPaginatedResponse): + data: List[Environment] = [] # GET /api/v1/environment/{namespace}/{name} diff --git a/conda-store-server/conda_store_server/_internal/server/views/api.py b/conda-store-server/conda_store_server/_internal/server/views/api.py index b33c90ffd..603f612cc 100644 --- a/conda-store-server/conda_store_server/_internal/server/views/api.py +++ b/conda-store-server/conda_store_server/_internal/server/views/api.py @@ -14,12 +14,39 @@ from conda_store_server import __version__, api, app from conda_store_server._internal import orm, schema from conda_store_server._internal.environment import filter_environments -from conda_store_server._internal.schema import AuthenticationToken, Permissions +from conda_store_server._internal.schema import ( + AuthenticationToken, + Permissions, +) from conda_store_server._internal.server import dependencies +from conda_store_server._internal.server.views.pagination import ( + Cursor, + CursorPaginatedArgs, + Ordering, + OrderingMetadata, + paginate, +) from conda_store_server.exception import CondaStoreError from conda_store_server.server.auth import Authentication +def get_cursor(cursor: Optional[str] = None) -> Cursor: + return Cursor.load(cursor) + + +def get_cursor_paginated_args( + order: Optional[Ordering] = Ordering.ASCENDING, + limit: Optional[int] = None, + sort_by: List[str] = Query([]), + server=Depends(dependencies.get_server), +) -> CursorPaginatedArgs: + return CursorPaginatedArgs( + limit=server.max_page_size if limit is None else limit, + order=order, + sort_by=sort_by, + ) + + class PaginatedArgs(TypedDict): """Dictionary type holding information about paginated requests.""" @@ -632,10 +659,12 @@ async def api_delete_namespace( response_model=schema.APIListEnvironment, ) async def api_list_environments( + request: Request, auth: Authentication = Depends(dependencies.get_auth), conda_store: app.CondaStore = Depends(dependencies.get_conda_store), entity: AuthenticationToken = Depends(dependencies.get_entity), - paginated_args: PaginatedArgs = Depends(get_paginated_args), + paginated_args: CursorPaginatedArgs = Depends(get_cursor_paginated_args), + cursor: Cursor = Depends(get_cursor), artifact: Optional[schema.BuildArtifactType] = None, jwt: Optional[str] = None, name: Optional[str] = None, @@ -643,7 +672,7 @@ async def api_list_environments( packages: Optional[List[str]] = Query([]), search: Optional[str] = None, status: Optional[schema.BuildStatus] = None, -): +) -> schema.APIListEnvironment: """Retrieve a list of environments. Parameters @@ -654,7 +683,7 @@ async def api_list_environments( the request entity : AuthenticationToken Token of the user making the request - paginated_args : PaginatedArgs + paginated_args : CursorPaginatedArgs Arguments for controlling pagination of the response conda_store : app.CondaStore The running conda store application @@ -678,9 +707,11 @@ async def api_list_environments( Returns ------- - Dict - Paginated JSON response containing the requested environments - + schema.APIListEnvironment + Paginated JSON response containing the requested environments. Results are sorted by each + envrionment's build's scheduled_on time to ensure all results are returned when iterating + over pages in systems where the number of environments is changing while results are being + requested; see https://github.com/conda-incubator/conda-store/issues/859 for context """ with conda_store.get_db() as db: if jwt: @@ -693,7 +724,7 @@ async def api_list_environments( else: role_bindings = None - orm_environments = api.list_environments( + query = api.list_environments( db, search=search, namespace=namespace, @@ -706,21 +737,29 @@ async def api_list_environments( ) # Filter by environments that the user who made the query has access to - orm_environments = filter_environments( - query=orm_environments, + query = filter_environments( + query=query, role_bindings=auth.entity_bindings(entity), ) - return paginated_api_response( - orm_environments, - paginated_args, - schema.Environment, - exclude={"current_build"}, - allowed_sort_bys={ - "namespace": orm.Namespace.name, - "name": orm.Environment.name, - }, - default_sort_by=["namespace", "name"], + paginated, next_cursor = paginate( + query=query, + ordering_metadata=OrderingMetadata( + order_names=["namespace", "name"], + column_names=["namespace.name", "name"], + column_objects=[orm.Namespace.name, orm.Environment.name], + ), + cursor=cursor, + order_by=paginated_args.sort_by, + order=paginated_args.order, + limit=paginated_args.limit, + ) + + return schema.APIListEnvironment( + data=paginated, + status="ok", + cursor=next_cursor.dump(), + count=1000, ) diff --git a/conda-store-server/conda_store_server/_internal/server/views/pagination.py b/conda-store-server/conda_store_server/_internal/server/views/pagination.py new file mode 100644 index 000000000..f21b37278 --- /dev/null +++ b/conda-store-server/conda_store_server/_internal/server/views/pagination.py @@ -0,0 +1,324 @@ +from __future__ import annotations + +import base64 +import operator +from enum import Enum +from typing import Any, Optional + +import pydantic +from fastapi import HTTPException +from sqlalchemy import asc, desc, tuple_ +from sqlalchemy.orm import InstrumentedAttribute +from sqlalchemy.orm import Query as SqlQuery +from sqlalchemy.sql.expression import ColumnClause + +from conda_store_server._internal.orm import Base + + +class Ordering(Enum): + ASCENDING = "asc" + DESCENDING = "desc" + + +class Cursor(pydantic.BaseModel): + last_id: int | None = 0 + count: int | None = 0 + + # List query parameters to order by, and the last value of the ordered attribute + # { + # 'namespace': 'foo', + # 'environment': 'bar', + # } + last_value: dict[str, str] | None = {} + + def dump(self) -> str: + """Dump the cursor as a b64-encoded string. + + Returns + ------- + str + base64-encoded string containing the information needed + to retrieve the page of data following the location of the cursor + """ + return base64.b64encode(self.model_dump_json().encode("utf8")) + + @classmethod + def load(cls, b64_cursor: str | None = None) -> Cursor: + """Create a Cursor from a b64-encoded string. + + Parameters + ---------- + b64_cursor : str | None + base64-encoded string containing information about the cursor + + Returns + ------- + Cursor + Cursor representation of the b64-encoded string + """ + if b64_cursor is None: + return cls(last_id=None, count=0, last_value=None) + return cls.model_validate_json(base64.b64decode(b64_cursor).decode("utf8")) + + def get_last_values(self, order_names: list[str]) -> list[Any]: + """Get a list of the values corresponding to the order_names. + + Parameters + ---------- + order_names : list[str] + List of names of values stored in the cursor + + Returns + ------- + list[Any] + The last values pointed to by the cursor for the given order_names + """ + if order_names: + return [self.last_value[name] for name in order_names] + else: + return [] + + @classmethod + def end(cls) -> Cursor: + """Cursor representing the end of a set of paginated results. + + Returns + ------- + Cursor + An empty cursor + """ + return cls(last_id=None, count=0, last_value=None) + + @classmethod + def begin(cls) -> Cursor: + return cls(last_id=0, count=None, last_value=None) + + +def paginate( + query: SqlQuery, + ordering_metadata: OrderingMetadata, + cursor: Cursor | None = None, + order_by: list[str] | None = None, + order: Ordering = Ordering.ASCENDING, + limit: int = 10, +) -> tuple[SqlQuery, Cursor]: + """Paginate the query using the cursor and the requested sort_bys. + + This function assumes that the first column of the query contains + the type whose ID should be used to sort the results. + + Additionally, with cursor pagination all keys used to order the results + must be included in the call to query.filter(). + + https://medium.com/@george_16060/cursor-based-pagination-with-arbitrary-ordering-b4af6d5e22db + + Parameters + ---------- + query : SqlQuery + Query containing database results to paginate + cursor : Cursor | None + Cursor object containing information about the last item on the previous page. + If None, the first page is returned. + order_by : list[str] | None + List of sort_by query parameters + + Returns + ------- + tuple[SqlQuery, Cursor] + Query containing the paginated results, and Cursor for retrieving + the next page + """ + if order_by is None: + order_by = [] + + if order == Ordering.ASCENDING: + comparison = operator.gt + order_func = asc + elif order == Ordering.DESCENDING: + comparison = operator.lt + order_func = desc + else: + raise HTTPException( + status_code=400, + detail=f"Invalid query parameter: order = {order}; must be one of ['asc', 'desc']", + ) + + # Get the python type of the objects being queried + queried_type = query.column_descriptions[0]["type"] + columns = ordering_metadata.get_requested_columns(order_by) + + # If there's a cursor already, use the last attributes to filter + # the results by (*attributes, id) >/< (*last_values, last_id) + # Order by desc or asc + if cursor is not None and cursor != Cursor.end(): + last_values = cursor.get_last_values(order_by) + query = query.filter( + comparison( + tuple_(*columns, queried_type.id), + (*last_values, cursor.last_id), + ) + ) + + order_by_args = [order_func(col) for col in columns] + [order_func(queried_type.id)] + + query = query.order_by(*order_by_args) + data = query.limit(limit).all() + count = query.count() + + if count > 0: + last_result = data[-1] + last_value = ordering_metadata.get_attr_values(last_result, order_by) + + next_cursor = Cursor( + last_id=data[-1].id, last_value=last_value, count=query.count() + ) + else: + next_cursor = Cursor.end() + + return (data, next_cursor) + + +class CursorPaginatedArgs(pydantic.BaseModel): + limit: Optional[int] = 10 + order: Optional[Ordering] = Ordering.ASCENDING + sort_by: Optional[list[str]] = [] + + @pydantic.field_validator("sort_by") + def validate_sort_by(cls, v: list[str]) -> list[str]: + """Validate the columns to sort by. + + FastAPI doesn't support lists in query parameters, so if the + `sort_by` value is a single-element list, assume that this + could be a comma-separated list. No harm in attempting to split + this by commas. + + Parameters + ---------- + v : list[str] + + + Returns + ------- + list[str] + """ + if len(v) == 1: + v = v[0].split(",") + return v + + +class OrderingMetadata: + def __init__( + self, + order_names: list[str] | None = None, + column_names: list[str] | None = None, + column_objects: list[InstrumentedAttribute] | None = None, + ): + self.order_names = order_names + self.column_names = column_names + self.column_objects = column_objects + + def validate(self, model: Base): + if len(self.order_names) != len(self.column_names): + raise ValueError( + "Each name of a valid ordering available to the order_by query parameter" + "must have an associated column name to select in the table." + ) + + for col in self.column_names: + if not hasattr(model, col): + raise ValueError(f"No column named {col} found on model {model}.") + + def get_requested_columns( + self, + order_by: list[str] | None = None, + ) -> list[ColumnClause]: + """Get a list of sqlalchemy columns requested by the value of the order_by query param. + + Parameters + ---------- + order_by : list[str] | None + If specified, this should be a subset of self.order_names. If none, an + empty list is returned. + + Returns + ------- + list[ColumnClause] + A list of sqlalchemy columns corresponding to the order_by values passed + as a query parameter + """ + columns = [] + if order_by: + for order_name in order_by: + idx = self.order_names.index(order_name) + columns.append(self.column_objects[idx]) + + return columns + + def __str__(self) -> str: + return f"OrderingMetadata" + + def __repr__(self) -> str: + return str(self) + + def get_attr_values( + self, + obj: Base, + order_by: list[str] | None = None, + ) -> dict[str, Any]: + """Using the order_by values, get the corresponding attribute values on obj. + + Parameters + ---------- + obj : Any + sqlalchemy model containing attribute names that are contained in + `self.column_names` + order_by : list[str] | None + Values that the user wants to order by; these are used to look up the corresponding + column names that are used to access the attributes of `obj`. + + Returns + ------- + dict[str, Any] + A mapping between the `order_by` values and the attribute values on `obj` + + """ + values = {} + for order_name in order_by: + idx = self.order_names.index(order_name) + attr = self.column_names[idx] + values[order_name] = get_nested_attribute(obj, attr) + + return values + + +def get_nested_attribute(obj: Base, attr: str) -> str | int | float: + """Get a nested attribute from the given sqlalchemy model. + + Parameters + ---------- + obj : Base + A sqlalchemy model for which a (possibly nested) attribute is to be + retrieved + attr : str + String attribute; nested attributes should be separated with `.` + + Returns + ------- + str | int | float + Value of the attribute; strictly this can be any column type supported + by sqlalchemy, but for conda-store this is a str, an int, or a float + + Examples + -------- + >>> env = db.query(orm.Environment).join(orm.Namespace).first() + >>> get_nested_attribute(env, 'namespace.name') + 'namespace1' + >>> get_nested_attribute(env, 'name') + 'my_environment' + """ + attribute, *rest = attr.split(".") + while len(rest) > 0: + obj = getattr(obj, attribute) + attribute, *rest = rest + + return getattr(obj, attribute) diff --git a/conda-store-server/pyproject.toml b/conda-store-server/pyproject.toml index cc24b1c1f..fc437da64 100644 --- a/conda-store-server/pyproject.toml +++ b/conda-store-server/pyproject.toml @@ -125,9 +125,6 @@ user-journey-test = ["pytest -m user_journey"] conda-store-server = "conda_store_server._internal.server.__main__:main" conda-store-worker = "conda_store_server._internal.worker.__main__:main" -[tool.black] -line-length = 88 - [tool.isort] lines_between_types = 1 lines_after_imports = 2 @@ -142,7 +139,6 @@ exclude = [ [tool.ruff.lint] ignore = [ "E501", # line-length - "ANN001", # missing-type-function-argument "ANN002", # missing-type-args "ANN003", # missing-type-kwargs diff --git a/conda-store-server/tests/_internal/server/views/test_api.py b/conda-store-server/tests/_internal/server/views/test_api.py index dd46ad6a5..188006b12 100644 --- a/conda-store-server/tests/_internal/server/views/test_api.py +++ b/conda-store-server/tests/_internal/server/views/test_api.py @@ -16,6 +16,7 @@ from conda_store_server import CONDA_STORE_DIR, __version__ from conda_store_server._internal import schema from conda_store_server._internal.server import dependencies +from conda_store_server._internal.server.views.pagination import Cursor @contextlib.contextmanager @@ -1070,3 +1071,74 @@ def test_default_conda_store_dir(): assert dir == rf"C:\Users\{user}\AppData\Local\conda-store\conda-store" else: assert dir == f"/home/{user}/.local/share/conda-store" + + +@pytest.mark.parametrize( + "order", + [ + "asc", + "desc", + ], +) +@pytest.mark.parametrize( + ("sort_by_param", "attr_func"), + [ + ("name", lambda x: (x.name, x.id)), + ("namespace", lambda x: (x.namespace.name, x.id)), + ("name,namespace", lambda x: (x.name, x.namespace.name, x.id)), + ("namespace,name", lambda x: (x.namespace.name, x.name, x.id)), + ], +) +def test_api_list_environments( + conda_store_server, + testclient, + seed_conda_store_big, + authenticate, + order, + sort_by_param, + attr_func, +): + """Test the REST API lists the paginated envs when sorting by name.""" + response = testclient.get( + f"api/v1/environment/?sort_by={sort_by_param}&order={order}" + ) + response.raise_for_status() + + model = schema.APIListEnvironment.model_validate(response.json()) + assert model.status == schema.APIStatus.OK + + # Pull out the attributes that we are sorting on from each environment + envs = [attr_func(env) for env in model.data] + + # The environments should already be sorted; check that this is the case + assert sorted(envs, reverse=(order == "desc")) == envs + + +def test_api_list_environments_no_qparam( + conda_store_server, + testclient, + seed_conda_store_big, + authenticate, +): + """Test the REST API lists the envs by id when no query params are specified.""" + response = testclient.get("api/v1/environment/?limit=10") + response.raise_for_status() + + model = schema.APIListEnvironment.model_validate(response.json()) + assert model.status == schema.APIStatus.OK + + envs = model.data + + while Cursor.load(model.cursor).last_id is not None: + response = testclient.get(f"api/v1/environment/?limit=10&cursor={model.cursor}") + response.raise_for_status() + + model = schema.APIListEnvironment.model_validate(response.json()) + assert model.status == schema.APIStatus.OK + + envs.extend(model.data) + + env_ids = [env.id for env in model.data] + + # Check that the environments are sorted by ID + assert sorted(env_ids) == env_ids diff --git a/conda-store-server/tests/conftest.py b/conda-store-server/tests/conftest.py index f1cebe23b..4ccdfe2d3 100644 --- a/conda-store-server/tests/conftest.py +++ b/conda-store-server/tests/conftest.py @@ -5,9 +5,12 @@ import datetime import json import pathlib +import random +import string import sys import typing import uuid +from collections import defaultdict import pytest import yaml @@ -165,6 +168,46 @@ def seed_conda_store(db, conda_store): return db +@pytest.fixture +def seed_conda_store_big(db, conda_store): + """Seed the conda-store db with 150 randomly named envs in 5 random namespaces.""" + namespace_names = [str(uuid.uuid4()) for _ in range(5)] + namespaces = defaultdict(dict) + for i in range(50): + name = "".join(random.choices(string.ascii_letters, k=10)) + namespaces[random.choice(namespace_names)][name] = schema.CondaSpecification( + name=name, channels=["defaults"], dependencies=["numpy"] + ) + + name = "".join(random.choices(string.ascii_letters, k=11)) + namespaces[random.choice(namespace_names)][name] = schema.CondaSpecification( + name=name, + channels=["defaults"], + dependencies=["flask"], + ) + + name = "".join(random.choices(string.ascii_letters, k=12)) + namespaces[random.choice(namespace_names)][name] = schema.CondaSpecification( + name=name, + channels=["defaults"], + dependencies=["flask"], + ) + + _seed_conda_store( + db, + conda_store, + namespaces, + ) + + # for testing purposes make build 4 complete + build = api.get_build(db, build_id=4) + build.started_on = datetime.datetime.utcnow() + build.ended_on = datetime.datetime.utcnow() + build.status = schema.BuildStatus.COMPLETED + db.commit() + return db + + @pytest.fixture def conda_store(conda_store_config): _conda_store = app.CondaStore(config=conda_store_config)