diff --git a/docs/user-guide/database/pagination.rst b/docs/user-guide/database/pagination.rst index 2d3efca5..415eb603 100644 --- a/docs/user-guide/database/pagination.rst +++ b/docs/user-guide/database/pagination.rst @@ -264,10 +264,58 @@ Here is a very simplified example of a route handler that sets this header: Here, ``perform_query`` is a wrapper around `~safir.database.PaginatedQueryRunner` that constructs and runs the query. A real route handler would have more query parameters and more documentation. -Note that this example also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination. +Including result counts +----------------------- + +The example above also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination. `~safir.database.PaginatedQueryRunner.query_count` will return this information. There is no standard way to return this information to the client, but ``X-Total-Count`` is a widely-used informal standard. +If you always want to include the count, use `~safir.database.CountedPaginatedQueryRunner` instead. +Its `~safir.database.CountedPaginatedQueryRunner.query_object` and `~safir.database.CountedPaginatedQueryRunner.query_row` methods will return a `~safir.database.CountedPaginatedList`, which contains a ``count`` attribute holding the count. +This is equivalent to calling `~safir.database.PaginatedQueryRunner.query_object` or `~safir.database.PaginatedQueryRunner.query_row` followed by `~safir.database.PaginatedQueryRunner.query_count`, but the encapsulation into a data structure makes it easier to pass the results between components of the service. + +Here's the same code as above, but using that approach: + +.. 
code-block:: python + :emphasize-lines: 27, 34 + + @router.get("/query", response_class=Model) + async def query( + *, + cursor: Annotated[ + str | None, + Query( + title="Pagination cursor", + description="Cursor to navigate paginated results", + ), + ] = None, + limit: Annotated[ + int, + Query( + title="Row limit", + description="Maximum number of entries to return", + examples=[100], + ge=1, + le=100, + ), + ] = 100, + request: Request, + response: Response, + ) -> list[Model]: + parsed_cursor = None + if cursor: + parsed_cursor = ModelCursor.from_str(cursor) + runner = CountedPaginatedQueryRunner(Model, ModelCursor) + stmt = build_query(...) + results = await runner.query_object( + session, stmt, cursor=parsed_cursor, limit=limit + ) + if cursor or limit: + response.headers["Link"] = results.link_header(request.url) + response.headers["X-Total-Count"] = str(results.count) + return results.entries + Including links in the response ------------------------------- diff --git a/safir/src/safir/database/__init__.py b/safir/src/safir/database/__init__.py index 5fb86255..3f8d69d3 100644 --- a/safir/src/safir/database/__init__.py +++ b/safir/src/safir/database/__init__.py @@ -17,6 +17,8 @@ initialize_database, ) from ._pagination import ( + CountedPaginatedList, + CountedPaginatedQueryRunner, DatetimeIdCursor, InvalidCursorError, PaginatedList, @@ -28,6 +30,8 @@ __all__ = [ "AlembicConfigError", + "CountedPaginatedList", + "CountedPaginatedQueryRunner", "DatabaseInitializationError", "DatetimeIdCursor", "InvalidCursorError", diff --git a/safir/src/safir/database/_pagination.py b/safir/src/safir/database/_pagination.py index bcaa1eb9..bef0264d 100644 --- a/safir/src/safir/database/_pagination.py +++ b/safir/src/safir/database/_pagination.py @@ -34,6 +34,8 @@ """Type of an entry in a paginated list.""" __all__ = [ + "CountedPaginatedList", + "CountedPaginatedQueryRunner", "DatetimeIdCursor", "InvalidCursorError", "PaginatedList", @@ -321,19 +323,18 @@ def __str__(self) 
-> str: class PaginatedList(Generic[E, C]): """Paginated SQL results with accompanying pagination metadata. - Holds a paginated list of any Pydantic type, complete with a count and - cursors. Can hold any type of entry and any type of cursor, but implicitly - requires the entry type be one that is meaningfully paginated by that type - of cursor. + Holds a paginated list of any Pydantic type with pagination cursors. Can + hold any type of entry and any type of cursor, but implicitly requires the + entry type be one that is meaningfully paginated by that type of cursor. """ entries: list[E] """A batch of entries.""" - next_cursor: C | None = None + next_cursor: C | None """Cursor for the next batch of entries.""" - prev_cursor: C | None = None + prev_cursor: C | None """Cursor for the previous batch of entries.""" def first_url(self, current_url: URL) -> str: @@ -418,7 +419,7 @@ def link_header(self, current_url: URL) -> str: class PaginatedQueryRunner(Generic[E, C]): - """Construct and run database queries that return paginated results. + """Run database queries that return paginated results. This class implements the logic for keyset pagination based on arbitrary SQLAlchemy ORM where clauses. @@ -688,3 +689,148 @@ async def _paginated_query( return PaginatedList[E, C]( entries=entries, prev_cursor=prev_cursor, next_cursor=next_cursor ) + + +@dataclass +class CountedPaginatedList(PaginatedList[E, C]): + """Paginated SQL results with pagination metadata and total count. + + Holds a paginated list of any Pydantic type, complete with a count and + cursors. Can hold any type of entry and any type of cursor, but implicitly + requires the entry type be one that is meaningfully paginated by that type + of cursor. + """ + + count: int + """Total number of entries if queried without pagination.""" + + +class CountedPaginatedQueryRunner(PaginatedQueryRunner[E, C]): + """Run database queries that return paginated results with counts. 
+ + This variation of `PaginatedQueryRunner` always runs a second query to + count the total number of available entries if queried without pagination. + It should only be used on small tables or with queries that can be + satisfied from the table indices; otherwise, the count query could be + undesirably slow. + + Parameters + ---------- + entry_type + Type of each entry returned by the queries. This must be a Pydantic + model. + cursor_type + Type of the pagination cursor, which encapsulates the logic of how + entries are sorted and what set of keys is used to retrieve the next + or previous batch of entries. + """ + + async def query_object( + self, + session: async_scoped_session, + stmt: Select[tuple], + *, + cursor: C | None = None, + limit: int | None = None, + ) -> CountedPaginatedList[E, C]: + """Perform a query for objects with an optional cursor and limit. + + Perform the query provided in ``stmt`` with appropriate sorting and + pagination as determined by the cursor type. Also performs a second + query to get the total count of entries if retrieved without + pagination. + + This method should be used with queries that return a single + SQLAlchemy model. The provided query will be run with the session + `~sqlalchemy.ext.asyncio.async_scoped_session.scalars` method and the + resulting object passed to Pydantic's ``model_validate`` to convert to + ``entry_type``. For queries returning a tuple of attributes, use + `query_row` instead. + + Unfortunately, this distinction cannot be type-checked, so be careful + to use the correct method. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. Pagination and ordering will be + added, so this statement should not already have limits or order + clauses applied. This statement must return a list of SQLAlchemy + ORM models that can be converted to ``entry_type`` by Pydantic. + cursor + If present, continue from the provided keyset cursor. 
+ limit + If present, limit the result count to at most this number of rows. + + Returns + ------- + CountedPaginatedList + Results of the query wrapped with pagination information and a + count of the total number of entries. + """ + result = await super().query_object( + session, stmt, cursor=cursor, limit=limit + ) + count = await self.query_count(session, stmt) + return CountedPaginatedList[E, C]( + entries=result.entries, + next_cursor=result.next_cursor, + prev_cursor=result.prev_cursor, + count=count, + ) + + async def query_row( + self, + session: async_scoped_session, + stmt: Select[tuple], + *, + cursor: C | None = None, + limit: int | None = None, + ) -> CountedPaginatedList[E, C]: + """Perform a query for attributes with an optional cursor and limit. + + Perform the query provided in ``stmt`` with appropriate sorting and + pagination as determined by the cursor type. Also performs a second + query to get the total count of entries if retrieved without + pagination. + + This method should be used with queries that return a list of + attributes that can be converted to the ``entry_type`` Pydantic model. + For queries returning a single ORM object, use `query_object` instead. + + Unfortunately, this distinction cannot be type-checked, so be careful + to use the correct method. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. Pagination and ordering will be + added, so this statement should not already have limits or order + clauses applied. This statement must return a tuple of attributes + that can be converted to ``entry_type`` by Pydantic's + ``model_validate``. + cursor + If present, continue from the provided keyset cursor. + limit + If present, limit the result count to at most this number of rows. + + Returns + ------- + CountedPaginatedList + Results of the query wrapped with pagination information and a + count of the total number of entries. 
+ """ + result = await super().query_row( + session, stmt, cursor=cursor, limit=limit + ) + count = await self.query_count(session, stmt) + return CountedPaginatedList[E, C]( + entries=result.entries, + next_cursor=result.next_cursor, + prev_cursor=result.prev_cursor, + count=count, + ) diff --git a/safir/tests/database_test.py b/safir/tests/database_test.py index 4ddddbf8..0ac77d64 100644 --- a/safir/tests/database_test.py +++ b/safir/tests/database_test.py @@ -26,6 +26,7 @@ from starlette.datastructures import URL from safir.database import ( + CountedPaginatedQueryRunner, DatetimeIdCursor, PaginatedQueryRunner, PaginationLinkData, @@ -410,11 +411,12 @@ async def test_pagination(database_url: str, database_password: str) -> None: # Query by object and test the pagination cursors going backwards and # forwards. - builder = PaginatedQueryRunner(PaginationModel, TableCursor) + runner = PaginatedQueryRunner(PaginationModel, TableCursor) + counted_runner = CountedPaginatedQueryRunner(PaginationModel, TableCursor) async with session.begin(): stmt: Select[tuple] = select(PaginationTable) - assert await builder.query_count(session, stmt) == 7 - result = await builder.query_object(session, stmt, limit=2) + assert await runner.query_count(session, stmt) == 7 + result = await runner.query_object(session, stmt, limit=2) assert_model_lists_equal(result.entries, rows[:2]) assert not result.prev_cursor base_url = URL("https://example.com/query") @@ -427,7 +429,15 @@ async def test_pagination(database_url: str, database_password: str) -> None: assert result.prev_url(base_url) is None assert str(result.next_cursor) == "1600000000.5_1" - result = await builder.query_object( + counted_result = await counted_runner.query_object( + session, stmt, limit=2 + ) + assert counted_result.entries == result.entries + assert counted_result.prev_cursor == result.prev_cursor + assert counted_result.next_cursor == result.next_cursor + assert counted_result.count == 7 + + result = await 
runner.query_object( session, stmt, cursor=result.next_cursor, limit=3 ) assert_model_lists_equal(result.entries, rows[2:5]) @@ -447,7 +457,7 @@ async def test_pagination(database_url: str, database_password: str) -> None: assert result.prev_url(base_url) == prev_url next_cursor = result.next_cursor - result = await builder.query_object( + result = await runner.query_object( session, stmt, cursor=result.prev_cursor ) assert_model_lists_equal(result.entries, rows[:2]) @@ -457,7 +467,7 @@ async def test_pagination(database_url: str, database_password: str) -> None: f'<{base_url!s}&cursor={result.next_cursor}>; rel="next"' ) - result = await builder.query_object(session, stmt, cursor=next_cursor) + result = await runner.query_object(session, stmt, cursor=next_cursor) assert_model_lists_equal(result.entries, rows[5:]) assert not result.next_cursor base_url = URL("https://example.com/query") @@ -468,14 +478,14 @@ async def test_pagination(database_url: str, database_password: str) -> None: ) prev_cursor = result.prev_cursor - result = await builder.query_object(session, stmt, cursor=prev_cursor) + result = await runner.query_object(session, stmt, cursor=prev_cursor) assert_model_lists_equal(result.entries, rows[:5]) assert result.link_header(base_url) == ( f'<{base_url!s}>; rel="first", ' f'<{base_url!s}?cursor={result.next_cursor}>; rel="next"' ) - result = await builder.query_object( + result = await runner.query_object( session, stmt, cursor=prev_cursor, limit=2 ) assert_model_lists_equal(result.entries, rows[3:5]) @@ -490,26 +500,39 @@ async def test_pagination(database_url: str, database_password: str) -> None: # function. 
async with session.begin(): stmt = select(PaginationTable.time, PaginationTable.id) - result = await builder.query_row(session, stmt, limit=2) + result = await runner.query_row(session, stmt, limit=2) assert_model_lists_equal(result.entries, rows[:2]) - assert await builder.query_count(session, stmt) == 7 + assert await runner.query_count(session, stmt) == 7 + + counted_result = await counted_runner.query_row(session, stmt, limit=2) + assert counted_result.entries == result.entries + assert counted_result.prev_cursor == result.prev_cursor + assert counted_result.next_cursor == result.next_cursor + assert counted_result.count == 7 # Querying for the entire table should return the everything with no # pagination cursors. Try this with both an object query and an attribute # query. async with session.begin(): - result = await builder.query_object(session, select(PaginationTable)) + stmt = select(PaginationTable) + result = await runner.query_object(session, stmt) assert_model_lists_equal(result.entries, rows) assert not result.next_cursor assert not result.prev_cursor stmt = select(PaginationTable.id, PaginationTable.time) - result = await builder.query_row(session, stmt) + result = await runner.query_row(session, stmt) assert_model_lists_equal(result.entries, rows) assert not result.next_cursor assert not result.prev_cursor base_url = URL("https://example.com/query?foo=b") assert result.link_header(base_url) == (f'<{base_url!s}>; rel="first"') + counted_result = await counted_runner.query_row(session, stmt) + assert counted_result.entries == result.entries + assert not counted_result.next_cursor + assert not counted_result.prev_cursor + assert counted_result.count == len(counted_result.entries) + def test_link_data() -> None: header = (