diff --git a/docs/user-guide/database/pagination.rst b/docs/user-guide/database/pagination.rst index b0f6f45a..21051f81 100644 --- a/docs/user-guide/database/pagination.rst +++ b/docs/user-guide/database/pagination.rst @@ -25,7 +25,7 @@ Third, your application passes the SQL query and any limit or cursor, along with This will apply the sort order and any restrictions from the limit or cursor and then execute the query in that session. It will return a `~safir.database.PaginatedList`, which holds the results along with pagination information. -Finally, the application will return, via the API handler, the list of entries included in the `~safir.database.PaginatedList` along with information about how to obtain the next or previous group of entries and the total number of records. +Finally, the application will return, via the API handler, the list of entries included in the `~safir.database.PaginatedList` along with information about how to obtain the next or previous group of entries and, optionally, the total number of records. This pagination information is generally returned in HTTP headers, although if you wish to return it in a data structure wrapper around the results, you can do that instead. Defining the cursor @@ -187,6 +187,15 @@ If the SQL query returns a tuple of individually selected attributes that corres Either way, the results will be a `~safir.database.PaginatedList` wrapping a list of Pydantic models of the appropriate type. +If you want to also return the total number of entries, run a separate ``COUNT`` query: + +.. code-block:: python + + count = await runner.query_count(session, stmt) + +This returns the total number of matching rows without regard to cursor or limit. +Best practice is to return this information in the response so that clients can estimate the total number of result pages, but this query will only be fast if it can be satisfied from the table indices or the table is small, so it is not run by default. + Returning paginated results =========================== @@ -226,14 +235,15 @@ Here is a very simplified example of a route handler that sets this header: ) if cursor or limit: response.headers["Link"] = results.link_header(request.url) - response.headers["X-Total-Count"] = str(results.count) + count = await runner.query_count(session, stmt) + response.headers["X-Total-Count"] = str(count) return results.entries Here, ``perform_query`` is a wrapper around `~safir.database.PaginatedQueryRunner` that constructs and runs the query. A real route handler would have more query parameters and more documentation. Note that this example also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination. -`~safir.database.PaginatedQueryRunner` obtains this information by default, since the count query is often fast for databases to perform. +`~safir.database.PaginatedQueryRunner.query_count` will return this information. There is no standard way to return this information to the client, but ``X-Total-Count`` is a widely-used informal standard. Including links in the response diff --git a/safir/src/safir/database/_pagination.py b/safir/src/safir/database/_pagination.py index 12ef3e67..8a053ee2 100644 --- a/safir/src/safir/database/_pagination.py +++ b/safir/src/safir/database/_pagination.py @@ -269,9 +269,6 @@ class PaginatedList(Generic[E, C]): entries: list[E] """A batch of entries.""" - count: int - """Total available entries.""" - next_cursor: C | None = None """Cursor for the next batch of entries.""" @@ -380,6 +377,30 @@ def __init__(self, entry_type: type[E], cursor_type: type[C]) -> None: self._entry_type = entry_type self._cursor_type = cursor_type + async def query_count( + self, session: async_scoped_session, stmt: Select[tuple] + ) -> int: + """Count the number of objects that match a query. + + There is nothing particular to pagination about this query, but it is + often used in conjunction with pagination to provide the total count + of matching entries, often in an ``X-Total-Count`` HTTP header. + + Parameters + ---------- + session + Database session within which to run the query. + stmt + Select statement to execute. + + Returns + ------- + int + Count of matching rows. + """ + count_stmt = select(func.count()).select_from(stmt.subquery()) + return await session.scalar(count_stmt) or 0 + async def query_object( self, session: async_scoped_session, @@ -506,10 +527,7 @@ async def _full_query( for r in result.all() ] return PaginatedList[E, C]( - entries=entries, - count=len(entries), - prev_cursor=None, - next_cursor=None, + entries=entries, prev_cursor=None, next_cursor=None ) async def _paginated_query( @@ -578,8 +596,6 @@ async def _paginated_query( self._entry_type.model_validate(r, from_attributes=True) for r in result.all() ] - count_stmt = select(func.count()).select_from(stmt.subquery()) - count = await session.scalar(count_stmt) or 0 # Calculate the cursors and remove the extra element we asked for. prev_cursor = None @@ -603,8 +619,5 @@ async def _paginated_query( # Return the results. return PaginatedList[E, C]( - entries=entries, - count=count, - prev_cursor=prev_cursor, - next_cursor=next_cursor, + entries=entries, prev_cursor=prev_cursor, next_cursor=next_cursor ) diff --git a/safir/tests/database_test.py b/safir/tests/database_test.py index 95647362..f56f0500 100644 --- a/safir/tests/database_test.py +++ b/safir/tests/database_test.py @@ -15,7 +15,7 @@ import structlog from pydantic import BaseModel, SecretStr from pydantic_core import Url -from sqlalchemy import Column, MetaData, String, Table, select +from sqlalchemy import Column, MetaData, Select, String, Table, select from sqlalchemy.exc import OperationalError, ProgrammingError from sqlalchemy.orm import ( DeclarativeBase, @@ -411,11 +411,10 @@ async def test_pagination(database_url: str, database_password: str) -> None: # forwards. builder = PaginatedQueryRunner(PaginationModel, TableCursor) async with session.begin(): - result = await builder.query_object( - session, select(PaginationTable), limit=2 - ) + stmt: Select[tuple] = select(PaginationTable) + assert await builder.query_count(session, stmt) == 7 + result = await builder.query_object(session, stmt, limit=2) assert_model_lists_equal(result.entries, rows[:2]) - assert result.count == 7 assert not result.prev_cursor base_url = URL("https://example.com/query") next_url = f"{base_url!s}?cursor={result.next_cursor}" @@ -428,13 +427,9 @@ async def test_pagination(database_url: str, database_password: str) -> None: assert str(result.next_cursor) == "1600000000.5_1" result = await builder.query_object( - session, - select(PaginationTable), - cursor=result.next_cursor, - limit=3, + session, stmt, cursor=result.next_cursor, limit=3 ) assert_model_lists_equal(result.entries, rows[2:5]) - assert result.count == 7 assert str(result.next_cursor) == "1510000000_2" assert str(result.prev_cursor) == "p1600000000.5_1" base_url = URL("https://example.com/query?foo=bar&foo=baz&cursor=xxxx") @@ -452,21 +447,17 @@ async def test_pagination(database_url: str, database_password: str) -> None: next_cursor = result.next_cursor result = await builder.query_object( - session, select(PaginationTable), cursor=result.prev_cursor + session, stmt, cursor=result.prev_cursor ) assert_model_lists_equal(result.entries, rows[:2]) - assert result.count == 7 base_url = URL("https://example.com/query?limit=2") assert result.link_header(base_url) == ( f'<{base_url!s}>; rel="first", ' f'<{base_url!s}&cursor={result.next_cursor}>; rel="next"' ) - result = await builder.query_object( - session, select(PaginationTable), cursor=next_cursor - ) + result = await builder.query_object(session, stmt, cursor=next_cursor) assert_model_lists_equal(result.entries, rows[5:]) - assert result.count == 7 assert not result.next_cursor base_url = URL("https://example.com/query") assert result.next_url(base_url) is None @@ -476,21 +467,17 @@ async def test_pagination(database_url: str, database_password: str) -> None: ) prev_cursor = result.prev_cursor - result = await builder.query_object( - session, select(PaginationTable), cursor=prev_cursor - ) + result = await builder.query_object(session, stmt, cursor=prev_cursor) assert_model_lists_equal(result.entries, rows[:5]) - assert result.count == 7 assert result.link_header(base_url) == ( f'<{base_url!s}>; rel="first", ' f'<{base_url!s}?cursor={result.next_cursor}>; rel="next"' ) result = await builder.query_object( - session, select(PaginationTable), cursor=prev_cursor, limit=2 + session, stmt, cursor=prev_cursor, limit=2 ) assert_model_lists_equal(result.entries, rows[3:5]) - assert result.count == 7 assert str(result.prev_cursor) == "p1520000000_5" assert result.link_header(base_url) == ( f'<{base_url!s}>; rel="first", ' @@ -501,11 +488,10 @@ async def test_pagination(database_url: str, database_password: str) -> None: # Perform one of the queries by attribute instead to test the query_row # function. async with session.begin(): - result = await builder.query_row( - session, select(PaginationTable.time, PaginationTable.id), limit=2 - ) + stmt = select(PaginationTable.time, PaginationTable.id) + result = await builder.query_row(session, stmt, limit=2) assert_model_lists_equal(result.entries, rows[:2]) - assert result.count == 7 + assert await builder.query_count(session, stmt) == 7 # Querying for the entire table should return the everything with no # pagination cursors. Try this with both an object query and an attribute @@ -513,14 +499,11 @@ async def test_pagination(database_url: str, database_password: str) -> None: async with session.begin(): result = await builder.query_object(session, select(PaginationTable)) assert_model_lists_equal(result.entries, rows) - assert result.count == 7 assert not result.next_cursor assert not result.prev_cursor - result = await builder.query_row( - session, select(PaginationTable.id, PaginationTable.time) - ) + stmt = select(PaginationTable.id, PaginationTable.time) + result = await builder.query_row(session, stmt) assert_model_lists_equal(result.entries, rows) - assert result.count == 7 assert not result.next_cursor assert not result.prev_cursor base_url = URL("https://example.com/query?foo=b")