Skip to content

Commit

Permalink
feat: aio_count method added
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed May 27, 2024
1 parent bc6bc52 commit 3089e5f
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 86 deletions.
27 changes: 25 additions & 2 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ async def aio_execute(self, query, fetch_results=None):
# To make `Database.aio_execute` compatible with peewee's sync queries we
# apply optional patching, it will do nothing for Aio-counterparts:
_patch_query_with_compat_methods(query, None)
sql, params = query.sql()
ctx = self.get_sql_context()
sql, params = ctx.sql(query).query()
fetch_results = fetch_results or getattr(query, 'fetch_results', None)
return await self.aio_execute_sql(sql, params, fetch_results=fetch_results)

Expand Down Expand Up @@ -694,7 +695,7 @@ async def fetch_results(self, cursor):
return await self.make_async_query_wrapper(cursor)


class AioModelSelect(peewee.ModelSelect, AioQueryMixin):
class AioSelectMixin(AioQueryMixin):

async def fetch_results(self, cursor):
return await self.make_async_query_wrapper(cursor)
Expand Down Expand Up @@ -723,6 +724,28 @@ async def aio_get(self, database=None):
'not exist:\nSQL: %s\nParams: %s' %
(clone.model, sql, params))

@peewee.database_required
async def aio_count(self, database, clear_limit=False):
clone = self.order_by().alias('_wrapped')
if clear_limit:
clone._limit = clone._offset = None
try:
if clone._having is None and clone._group_by is None and \
clone._windows is None and clone._distinct is None and \
clone._simple_distinct is not True:
clone = clone.select(peewee.SQL('1'))
except AttributeError:
pass
return await AioSelect([clone], [peewee.fn.COUNT(peewee.SQL('1'))]).aio_scalar(database)


class AioSelect(peewee.Select, AioSelectMixin):
pass


class AioModelSelect(peewee.ModelSelect, AioSelectMixin):
pass


class AioModel(peewee.Model):
"""Async version of **peewee.Model** that allows to execute queries asynchronously
Expand Down
27 changes: 8 additions & 19 deletions peewee_async_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def _patch_query_with_compat_methods(query, async_query_cls):
if async_query_cls is AioModelSelect:
query.aio_get = partial(async_query_cls.aio_get, query)
query.aio_scalar = partial(async_query_cls.aio_scalar, query)
query.aio_count = partial(async_query_cls.aio_count, query)


def _query_db(query):
Expand All @@ -94,25 +95,13 @@ async def count(query, clear_limit=False):
:return: number of objects in `select()` query
"""
database = _query_db(query)
clone = query.clone()
if query._distinct or query._group_by or query._limit or query._offset:
if clear_limit:
clone._limit = clone._offset = None
sql, params = clone.sql()
wrapped = 'SELECT COUNT(1) FROM (%s) AS wrapped_select' % sql
async def fetch_results(cursor):
row = await cursor.fetchone()
if row:
return row[0]
else:
return row
result = await database.aio_execute_sql(wrapped, params, fetch_results)
return result or 0
else:
clone._returning = [peewee.fn.Count(peewee.SQL('*'))]
clone._order_by = None
return (await scalar(clone)) or 0
from peewee_async import AioModelSelect # noqa
warnings.warn(
"`count` is deprecated, use `query.aio_count` method.",
DeprecationWarning
)
_patch_query_with_compat_methods(query, AioModelSelect)
return await query.aio_count(clear_limit=clear_limit)


async def prefetch(sq, *subqueries, prefetch_type):
Expand Down
18 changes: 18 additions & 0 deletions tests/aio_model/test_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,21 @@ async def test_aio_scalar(db):
).aio_scalar(as_tuple=True) == (2, 1)

assert await TestModel.select().aio_scalar() is None


@dbs_all
async def test_count_query(db):

for num in range(5):
await IntegerTestModel.aio_create(num=num)
count = await IntegerTestModel.select().limit(3).aio_count()
assert count == 3


@dbs_all
async def test_count_query_clear_limit(db):

for num in range(5):
await IntegerTestModel.aio_create(num=num)
count = await IntegerTestModel.select().limit(3).aio_count(clear_limit=True)
assert count == 5
71 changes: 71 additions & 0 deletions tests/compat/test_shortcuts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import uuid

import peewee

from tests.conftest import manager_for_all_dbs
from tests.models import CompatTestModel


@manager_for_all_dbs
async def test_get_or_none(manager):
"""Test get_or_none manager function."""
text1 = "Test %s" % uuid.uuid4()
text2 = "Test %s" % uuid.uuid4()

obj1 = await manager.create(CompatTestModel, text=text1)
obj2 = await manager.get_or_none(CompatTestModel, text=text1)
obj3 = await manager.get_or_none(CompatTestModel, text=text2)

assert obj1 == obj2
assert obj1 is not None
assert obj2 is not None
assert obj3 is None


@manager_for_all_dbs
async def test_count_query_with_limit(manager):
text = "Test %s" % uuid.uuid4()
await manager.create(CompatTestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(CompatTestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(CompatTestModel, text=text)

count = await manager.count(CompatTestModel.select().limit(1))
assert count == 1


@manager_for_all_dbs
async def test_count_query(manager):
text = "Test %s" % uuid.uuid4()
await manager.create(CompatTestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(CompatTestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(CompatTestModel, text=text)

count = await manager.count(CompatTestModel.select())
assert count == 3


@manager_for_all_dbs
async def test_scalar_query(manager):

text = "Test %s" % uuid.uuid4()
await manager.create(CompatTestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(CompatTestModel, text=text)

fn = peewee.fn.Count(CompatTestModel.id)
count = await manager.scalar(CompatTestModel.select(fn))

assert count == 2


@manager_for_all_dbs
async def test_create_obj(manager):

text = "Test %s" % uuid.uuid4()
obj = await manager.create(CompatTestModel, text=text)
assert obj is not None
assert obj.text == text
65 changes: 0 additions & 65 deletions tests/test_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,62 +43,6 @@ async def test_prefetch(manager, prefetch_type):
assert tuple(result[0].betas[0].gammas) == (gamma_111, gamma_112)


@manager_for_all_dbs
async def test_get_or_none(manager):
"""Test get_or_none manager function."""
text1 = "Test %s" % uuid.uuid4()
text2 = "Test %s" % uuid.uuid4()

obj1 = await manager.create(TestModel, text=text1)
obj2 = await manager.get_or_none(TestModel, text=text1)
obj3 = await manager.get_or_none(TestModel, text=text2)

assert obj1 == obj2
assert obj1 is not None
assert obj2 is not None
assert obj3 is None


@manager_for_all_dbs
async def test_count_query_with_limit(manager):
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)

count = await manager.count(TestModel.select().limit(1))
assert count == 1


@manager_for_all_dbs
async def test_count_query(manager):
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)

count = await manager.count(TestModel.select())
assert count == 3


@manager_for_all_dbs
async def test_scalar_query(manager):

text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)

fn = peewee.fn.Count(TestModel.id)
count = await manager.scalar(TestModel.select(fn))

assert count == 2


@manager_for_all_dbs
async def test_delete_obj(manager):
text = "Test %s" % uuid.uuid4()
Expand All @@ -124,15 +68,6 @@ async def test_update_obj(manager):
assert obj2.text == "Test update object"


@manager_for_all_dbs
async def test_create_obj(manager):

text = "Test %s" % uuid.uuid4()
obj = await manager.create(TestModel, text=text)
assert obj is not None
assert obj.text == text


@manager_for_all_dbs
async def test_create_or_get(manager):
text = "Test %s" % uuid.uuid4()
Expand Down

0 comments on commit 3089e5f

Please sign in to comment.