Skip to content

Commit

Permalink
Merge pull request #332 from lsst-sqre/tickets/DM-47769
Browse files Browse the repository at this point in the history
DM-47769: Do not do COUNT queries for pagination by default
  • Loading branch information
rra authored Nov 25, 2024
2 parents f565a3c + adc46c3 commit 10c4f5a
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 47 deletions.
16 changes: 13 additions & 3 deletions docs/user-guide/database/pagination.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Third, your application passes the SQL query and any limit or cursor, along with
This will apply the sort order and any restrictions from the limit or cursor and then execute the query in that session.
It will return a `~safir.database.PaginatedList`, which holds the results along with pagination information.

Finally, the application will return, via the API handler, the list of entries included in the `~safir.database.PaginatedList` along with information about how to obtain the next or previous group of entries and the total number of records.
Finally, the application will return, via the API handler, the list of entries included in the `~safir.database.PaginatedList` along with information about how to obtain the next or previous group of entries and, optionally, the total number of records.
This pagination information is generally returned in HTTP headers, although if you wish to return it in a data structure wrapper around the results, you can do that instead.

Defining the cursor
Expand Down Expand Up @@ -187,6 +187,15 @@ If the SQL query returns a tuple of individually selected attributes that corres
Either way, the results will be a `~safir.database.PaginatedList` wrapping a list of Pydantic models of the appropriate type.

If you want to also return the total number of entries, run a separate ``COUNT`` query:

.. code-block:: python
count = await runner.query_count(session, stmt)
This returns the total number of matching rows without regard to cursor or limit.
Best practice is to return this information in the response so that clients can estimate the total number of result pages.
However, the count query will only be fast if it can be satisfied from the table indices or the table is small, so it is not run by default.

Returning paginated results
===========================

Expand Down Expand Up @@ -226,14 +235,15 @@ Here is a very simplified example of a route handler that sets this header:
)
if cursor or limit:
response.headers["Link"] = results.link_header(request.url)
response.headers["X-Total-Count"] = str(results.count)
count = await runner.query_count(session, stmt)
response.headers["X-Total-Count"] = str(count)
return results.entries
Here, ``perform_query`` is a wrapper around `~safir.database.PaginatedQueryRunner` that constructs and runs the query.
A real route handler would have more query parameters and more documentation.

Note that this example also sets a non-standard ``X-Total-Count`` header containing the total count of entries returned by the underlying query without pagination.
`~safir.database.PaginatedQueryRunner` obtains this information by default, since the count query is often fast for databases to perform.
`~safir.database.PaginatedQueryRunner.query_count` will return this information.
There is no standard way to return this information to the client, but ``X-Total-Count`` is a widely-used informal standard.

Including links in the response
Expand Down
39 changes: 26 additions & 13 deletions safir/src/safir/database/_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,6 @@ class PaginatedList(Generic[E, C]):
entries: list[E]
"""A batch of entries."""

count: int
"""Total available entries."""

next_cursor: C | None = None
"""Cursor for the next batch of entries."""

Expand Down Expand Up @@ -380,6 +377,30 @@ def __init__(self, entry_type: type[E], cursor_type: type[C]) -> None:
self._entry_type = entry_type
self._cursor_type = cursor_type

async def query_count(
    self, session: async_scoped_session, stmt: Select[tuple]
) -> int:
    """Return how many rows a select statement matches.

    The statement is wrapped in a subquery and counted as a whole, so the
    result is independent of any pagination cursor or limit that would
    otherwise be applied when fetching entries. This is typically used
    alongside a paginated query to fill in a total such as an
    ``X-Total-Count`` HTTP header.

    Parameters
    ----------
    session
        Database session in which the count query is executed.
    stmt
        Select statement whose matching rows should be counted.

    Returns
    -------
    int
        Number of matching rows, or 0 if the database returns no value.
    """
    # Count over the statement as a subquery so that whatever columns or
    # filters it carries are preserved exactly.
    wrapped = stmt.subquery()
    counting_query = select(func.count()).select_from(wrapped)
    total = await session.scalar(counting_query)
    return total or 0

async def query_object(
self,
session: async_scoped_session,
Expand Down Expand Up @@ -506,10 +527,7 @@ async def _full_query(
for r in result.all()
]
return PaginatedList[E, C](
entries=entries,
count=len(entries),
prev_cursor=None,
next_cursor=None,
entries=entries, prev_cursor=None, next_cursor=None
)

async def _paginated_query(
Expand Down Expand Up @@ -578,8 +596,6 @@ async def _paginated_query(
self._entry_type.model_validate(r, from_attributes=True)
for r in result.all()
]
count_stmt = select(func.count()).select_from(stmt.subquery())
count = await session.scalar(count_stmt) or 0

# Calculate the cursors and remove the extra element we asked for.
prev_cursor = None
Expand All @@ -603,8 +619,5 @@ async def _paginated_query(

# Return the results.
return PaginatedList[E, C](
entries=entries,
count=count,
prev_cursor=prev_cursor,
next_cursor=next_cursor,
entries=entries, prev_cursor=prev_cursor, next_cursor=next_cursor
)
45 changes: 14 additions & 31 deletions safir/tests/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import structlog
from pydantic import BaseModel, SecretStr
from pydantic_core import Url
from sqlalchemy import Column, MetaData, String, Table, select
from sqlalchemy import Column, MetaData, Select, String, Table, select
from sqlalchemy.exc import OperationalError, ProgrammingError
from sqlalchemy.orm import (
DeclarativeBase,
Expand Down Expand Up @@ -411,11 +411,10 @@ async def test_pagination(database_url: str, database_password: str) -> None:
# forwards.
builder = PaginatedQueryRunner(PaginationModel, TableCursor)
async with session.begin():
result = await builder.query_object(
session, select(PaginationTable), limit=2
)
stmt: Select[tuple] = select(PaginationTable)
assert await builder.query_count(session, stmt) == 7
result = await builder.query_object(session, stmt, limit=2)
assert_model_lists_equal(result.entries, rows[:2])
assert result.count == 7
assert not result.prev_cursor
base_url = URL("https://example.com/query")
next_url = f"{base_url!s}?cursor={result.next_cursor}"
Expand All @@ -428,13 +427,9 @@ async def test_pagination(database_url: str, database_password: str) -> None:
assert str(result.next_cursor) == "1600000000.5_1"

result = await builder.query_object(
session,
select(PaginationTable),
cursor=result.next_cursor,
limit=3,
session, stmt, cursor=result.next_cursor, limit=3
)
assert_model_lists_equal(result.entries, rows[2:5])
assert result.count == 7
assert str(result.next_cursor) == "1510000000_2"
assert str(result.prev_cursor) == "p1600000000.5_1"
base_url = URL("https://example.com/query?foo=bar&foo=baz&cursor=xxxx")
Expand All @@ -452,21 +447,17 @@ async def test_pagination(database_url: str, database_password: str) -> None:
next_cursor = result.next_cursor

result = await builder.query_object(
session, select(PaginationTable), cursor=result.prev_cursor
session, stmt, cursor=result.prev_cursor
)
assert_model_lists_equal(result.entries, rows[:2])
assert result.count == 7
base_url = URL("https://example.com/query?limit=2")
assert result.link_header(base_url) == (
f'<{base_url!s}>; rel="first", '
f'<{base_url!s}&cursor={result.next_cursor}>; rel="next"'
)

result = await builder.query_object(
session, select(PaginationTable), cursor=next_cursor
)
result = await builder.query_object(session, stmt, cursor=next_cursor)
assert_model_lists_equal(result.entries, rows[5:])
assert result.count == 7
assert not result.next_cursor
base_url = URL("https://example.com/query")
assert result.next_url(base_url) is None
Expand All @@ -476,21 +467,17 @@ async def test_pagination(database_url: str, database_password: str) -> None:
)
prev_cursor = result.prev_cursor

result = await builder.query_object(
session, select(PaginationTable), cursor=prev_cursor
)
result = await builder.query_object(session, stmt, cursor=prev_cursor)
assert_model_lists_equal(result.entries, rows[:5])
assert result.count == 7
assert result.link_header(base_url) == (
f'<{base_url!s}>; rel="first", '
f'<{base_url!s}?cursor={result.next_cursor}>; rel="next"'
)

result = await builder.query_object(
session, select(PaginationTable), cursor=prev_cursor, limit=2
session, stmt, cursor=prev_cursor, limit=2
)
assert_model_lists_equal(result.entries, rows[3:5])
assert result.count == 7
assert str(result.prev_cursor) == "p1520000000_5"
assert result.link_header(base_url) == (
f'<{base_url!s}>; rel="first", '
Expand All @@ -501,26 +488,22 @@ async def test_pagination(database_url: str, database_password: str) -> None:
# Perform one of the queries by attribute instead to test the query_row
# function.
async with session.begin():
result = await builder.query_row(
session, select(PaginationTable.time, PaginationTable.id), limit=2
)
stmt = select(PaginationTable.time, PaginationTable.id)
result = await builder.query_row(session, stmt, limit=2)
assert_model_lists_equal(result.entries, rows[:2])
assert result.count == 7
assert await builder.query_count(session, stmt) == 7

# Querying for the entire table should return everything with no
# pagination cursors. Try this with both an object query and an attribute
# query.
async with session.begin():
result = await builder.query_object(session, select(PaginationTable))
assert_model_lists_equal(result.entries, rows)
assert result.count == 7
assert not result.next_cursor
assert not result.prev_cursor
result = await builder.query_row(
session, select(PaginationTable.id, PaginationTable.time)
)
stmt = select(PaginationTable.id, PaginationTable.time)
result = await builder.query_row(session, stmt)
assert_model_lists_equal(result.entries, rows)
assert result.count == 7
assert not result.next_cursor
assert not result.prev_cursor
base_url = URL("https://example.com/query?foo=b")
Expand Down

0 comments on commit 10c4f5a

Please sign in to comment.