Skip to content

Commit

Permalink
v2.0.0: fastapi extention pagination.
Browse files Browse the repository at this point in the history
  • Loading branch information
ALittleMoron committed May 29, 2024
1 parent b4c0ae8 commit d721bce
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 3 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ dev = [

[project]
name = "sqlrepo"
version = "1.5.2"
version = "2.0.0"
description = "sqlalchemy repositories with crud operations and other utils for it."
authors = [{ name = "Dmitriy Lunev", email = "dima.lunev14@gmail.com" }]
requires-python = ">=3.11"
readme = "README.md"
license = { text = "MIT" }
dependencies = [
"sqlalchemy>=2.0.29",
"python-dev-utils[sqlalchemy_filters]>=1.11.0",
"python-dev-utils[sqlalchemy_filters]>=2.2.0",
]

[project.optional-dependencies]
Expand Down
129 changes: 129 additions & 0 deletions sqlrepo/ext/fastapi/pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import Annotated, Generic, TypeVar

from fastapi import Query
from pydantic import BaseModel, ConfigDict

T = TypeVar("T")


class BaseSchema(BaseModel):
"""Base schema for pagination."""

model_config = ConfigDict(from_attributes=True)


class PaginationMeta(BaseSchema):
"""Metadata of pagination result."""

all_records_count: int
filtered_records_count: int
per_page: int
current_page: int
all_pages_count: int
filtered_pages_count: int
prev_page: int | None = None
next_page: int | None = None

@classmethod
def create( # noqa: D102
cls,
*,
pagination: "AbstractBasePagination",
all_records_count: int,
filtered_records_count: int | None = None,
) -> "PaginationMeta":
if filtered_records_count is None: # pragma: no coverage
filtered_records_count = all_records_count
current_page = pagination.current_page
per_page = pagination.per_page
all_pages_count = all_records_count // per_page
filtered_pages_count = filtered_records_count // per_page
prev_page = current_page - 1 if (current_page - 1) > 0 else None
next_page = current_page + 1 if (current_page + 1) <= filtered_pages_count else None
return cls(
all_records_count=all_records_count,
filtered_records_count=filtered_records_count,
per_page=per_page,
current_page=current_page,
all_pages_count=all_pages_count,
filtered_pages_count=filtered_pages_count,
prev_page=prev_page,
next_page=next_page,
)


class PaginatedResult(BaseSchema, Generic[T]):
"""Pagination result."""

meta: PaginationMeta
data: list[T]


class AbstractBasePagination:
"""Abstract base pagination depends."""

limit: int
offset: int
per_page: int
current_page: int

def __init__(self) -> None: # pragma: no coverage
raise NotImplementedError()


class LimitOffsetPagination(AbstractBasePagination):
"""Limit-Offset pagination depends."""

def __init__(
self,
limit: Annotated[
int,
Query(
ge=1,
le=100,
description="SQL limit.",
examples=[1, 50, 100],
),
] = 50,
offset: Annotated[
int,
Query(
ge=0,
description="SQL offset.",
examples=[0, 10, 1000],
),
] = 0,
) -> None:
self.limit = limit
self.offset = offset
self.per_page = limit
self.current_page = (offset // limit) + 1


class PageSizePagination(AbstractBasePagination):
"""Page-Size pagination depends."""

def __init__(
self,
per_page: Annotated[
int,
Query(
ge=1,
le=100,
description="Count of items in paginated result.",
examples=[1, 50, 100],
),
] = 50,
page: Annotated[
int,
Query(
ge=1,
description="Number of current page.",
examples=[0, 10, 1000],
),
] = 1,
) -> None:
self.per_page = per_page
self.current_page = page
self.limit = per_page
self.offset = (page - 1) * per_page
12 changes: 12 additions & 0 deletions sqlrepo/ext/fastapi/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from sqlalchemy.orm.decl_api import DeclarativeBase

from sqlrepo.ext.fastapi.helpers import NotSet, NotSetType
from sqlrepo.ext.fastapi.pagination import PaginatedResult, PaginationMeta
from sqlrepo.logging import logger

if TYPE_CHECKING:
Expand Down Expand Up @@ -154,6 +155,17 @@ def resolve_entity_list(self, entities: "Sequence[TModel]") -> "list[VListSchema
raise AttributeError(msg) # noqa: TRY004
return TypeAdapter(list[self.list_schema]).validate_python(entities, from_attributes=True)

def paginate_result(
self,
entities: "Sequence[TModel]",
meta: PaginationMeta,
) -> PaginatedResult["VListSchema"]:
"""Resolve list if entities and put them into pagination."""
return PaginatedResult(
meta=meta,
data=self.resolve_entity_list(entities=entities),
)


class BaseAsyncService(BaseService[TModel, TDetailSchema, VListSchema]):
"""Base service with async interface."""
Expand Down
87 changes: 86 additions & 1 deletion tests/test_fastapi_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,16 @@
from sqlalchemy.orm.session import Session

from sqlrepo.ext.fastapi import BaseSyncContainer, BaseSyncService, add_container_overrides
from sqlrepo.ext.fastapi.pagination import (
AbstractBasePagination,
LimitOffsetPagination,
PageSizePagination,
PaginatedResult,
PaginationMeta,
)
from sqlrepo.ext.fastapi.services import NotSet, ServiceClassIncorrectUseWarning
from sqlrepo.repositories import BaseSyncRepository
from tests.utils import MyModel
from tests.utils import MyModel, assert_compare_db_item_with_dict

if TYPE_CHECKING:
from collections.abc import Callable
Expand Down Expand Up @@ -71,6 +78,18 @@ def list(self) -> list[MyModelList]: # noqa: D102
entities = self.my_model_repo.list()
return self.resolve_entity_list(entities)

def list_paginated( # noqa: D102
self,
pagination: AbstractBasePagination,
) -> PaginatedResult[MyModelList]:
entities = self.my_model_repo.list(limit=pagination.limit, offset=pagination.offset)
total_count = self.my_model_repo.count()
meta = PaginationMeta.create(
all_records_count=total_count,
pagination=pagination,
)
return self.paginate_result(entities, meta)


class InvalidService(MyModelService): # noqa: D101
...
Expand Down Expand Up @@ -159,6 +178,20 @@ def get_one_python(my_model_id: int = Path(), container: Container = Depends()):
def get_all(container: Container = Depends()): # type: ignore # noqa: ANN202
return container.my_model_service.list()

@app.get('/get-limit-offset-paginated/')
def get_paginated_limit_offset( # type: ignore # noqa: ANN202
pagination: LimitOffsetPagination = Depends(),
container: Container = Depends(),
):
return container.my_model_service.list_paginated(pagination)

@app.get('/get-page-size-paginated/')
def get_paginated_page_size( # type: ignore # noqa: ANN202
pagination: PageSizePagination = Depends(),
container: Container = Depends(),
):
return container.my_model_service.list_paginated(pagination)

@app.get('/get-all-invalid/')
def get_all_invalid(container: Container = Depends()): # type: ignore # noqa: ANN202
return container.invalid_service.list()
Expand Down Expand Up @@ -222,6 +255,58 @@ def test_get_one_python_not_found(
app_with_sync_container.get('/get-one-python/1251251')


def test_limit_offset_pagination(
db_sync_session: "Session",
mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]",
app_with_sync_container: "TestClient",
) -> None:

items = [mymodel_sync_factory(db_sync_session, commit=False) for _ in range(3)]
items_map = {item.id: item for item in items}
db_sync_session.commit()
response = app_with_sync_container.get('/get-limit-offset-paginated/?limit=1')
assert response.status_code == status.HTTP_200_OK
response = response.json()
schema = TypeAdapter(PaginatedResult[MyModelList]).validate_python(response)
assert schema.meta.all_pages_count == len(items)
assert schema.meta.filtered_pages_count == len(items)
assert schema.meta.all_records_count == len(items)
assert schema.meta.filtered_records_count == len(items)
assert schema.meta.per_page == 1
assert schema.meta.current_page == 1
assert schema.meta.prev_page is None
assert schema.meta.next_page == 2 # noqa: PLR2004
for item in schema.data:
assert item.id in items_map
assert_compare_db_item_with_dict(items_map[item.id], item.model_dump())


def test_page_size_pagination(
db_sync_session: "Session",
mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]",
app_with_sync_container: "TestClient",
) -> None:

items = [mymodel_sync_factory(db_sync_session, commit=False) for _ in range(3)]
items_map = {item.id: item for item in items}
db_sync_session.commit()
response = app_with_sync_container.get('/get-page-size-paginated/?per_page=1')
assert response.status_code == status.HTTP_200_OK
response = response.json()
schema = TypeAdapter(PaginatedResult[MyModelList]).validate_python(response)
assert schema.meta.all_pages_count == len(items)
assert schema.meta.filtered_pages_count == len(items)
assert schema.meta.all_records_count == len(items)
assert schema.meta.filtered_records_count == len(items)
assert schema.meta.per_page == 1
assert schema.meta.current_page == 1
assert schema.meta.prev_page is None
assert schema.meta.next_page == 2 # noqa: PLR2004
for item in schema.data:
assert item.id in items_map
assert_compare_db_item_with_dict(items_map[item.id], item.model_dump())


def test_get_all(
db_sync_session: "Session",
mymodel_sync_factory: "SyncFactoryFunctionProtocol[MyModel]",
Expand Down

0 comments on commit d721bce

Please sign in to comment.