Skip to content

Commit

Permalink
distinct can work with nonhashable entries incl. 'EJSON types'
Browse files Browse the repository at this point in the history
  • Loading branch information
hemidactylus committed Mar 10, 2024
1 parent 7c21256 commit 31bb77e
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 12 deletions.
64 changes: 53 additions & 11 deletions astrapy/idiomatic/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,23 @@

from __future__ import annotations

import hashlib
import json
from collections.abc import Iterator, AsyncIterator
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
TypeVar,
Union,
TYPE_CHECKING,
)

from astrapy.utils import _normalize_payload_value
from astrapy.idiomatic.types import (
DocumentType,
ProjectionType,
Expand All @@ -42,6 +47,20 @@
FIND_PREFETCH = 20


def _create_document_key_extractor(
key: str,
) -> Callable[[Dict[str, Any]], Iterable[Any]]:
if "." in key:
raise NotImplementedError

def _item_extractor(document: Dict[str, Any]) -> Any:
# TEMPORARY
if key in document:
yield document[key]

return _item_extractor


class BaseCursor:
"""
Represents a generic Cursor over query results, regardless of whether
Expand Down Expand Up @@ -119,6 +138,7 @@ def _ensure_not_started(self) -> None:
def _copy(
self: BC,
*,
projection: Optional[ProjectionType] = None,
limit: Optional[int] = None,
skip: Optional[int] = None,
started: Optional[bool] = None,
Expand All @@ -127,7 +147,7 @@ def _copy(
new_cursor = self.__class__(
collection=self._collection,
filter=self._filter,
projection=self._projection,
projection=projection or self._projection,
)
# Cursor treated as mutable within this function scope:
new_cursor._limit = limit if limit is not None else self._limit
Expand Down Expand Up @@ -371,9 +391,22 @@ def distinct(self, key: str) -> List[Any]:
network traffic and possibly billing.
"""

return list(
{document[key] for document in self._copy(started=False) if key in document}
)
_item_hashes = set()
distinct_items = []

_extractor = _create_document_key_extractor(key)

d_cursor = self._copy(projection={key: True}, started=False)
for document in d_cursor:
for item in _extractor(document):
_normalized_item = _normalize_payload_value(path=[], value=item)
_normalized_json = json.dumps(_normalized_item, separators=(",", ":"))
_item_hash = hashlib.md5(_normalized_json.encode()).hexdigest()
if _item_hash not in _item_hashes:
_item_hashes.add(_item_hash)
distinct_items.append(item)

return distinct_items


class AsyncCursor(BaseCursor):
Expand Down Expand Up @@ -515,13 +548,22 @@ async def distinct(self, key: str) -> List[Any]:
network traffic and possibly billing.
"""

return list(
{
document[key]
async for document in self._copy(started=False)
if key in document
}
)
_item_hashes = set()
distinct_items = []

_extractor = _create_document_key_extractor(key)

d_cursor = self._copy(projection={key: True}, started=False)
async for document in d_cursor:
for item in _extractor(document):
_normalized_item = _normalize_payload_value(path=[], value=item)
_normalized_json = json.dumps(_normalized_item, separators=(",", ":"))
_item_hash = hashlib.md5(_normalized_json.encode()).hexdigest()
if _item_hash not in _item_hashes:
_item_hashes.add(_item_hash)
distinct_items.append(item)

return distinct_items


class CommandCursor(Generic[T]):
Expand Down
30 changes: 29 additions & 1 deletion tests/idiomatic/integration/test_dml_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import datetime

from typing import Any, Dict, List

import pytest

Expand Down Expand Up @@ -438,6 +440,32 @@ async def _alist(acursor: AsyncCursor) -> List[DocumentType]:
cursor7.rewind()
cursor7["wrong"]

@pytest.mark.describe("test of distinct with non-hashable items, async")
async def test_collection_distinct_nonhashable_async(
self,
async_empty_collection: AsyncCollection,
) -> None:
acol = async_empty_collection
documents: List[Dict[str, Any]] = [
{},
{"f": 1},
{"f": "a"},
{"f": {"subf": 99}},
{"f": {"subf": 99, "another": {"subsubf": [True, False]}}},
{"f": [10, 11]},
{"f": [11, 10]},
{"f": [10]},
{"f": datetime.datetime(2000, 1, 1, 12, 00, 00)},
{"f": None},
]
await acol.insert_many(documents * 2)

d_items = await acol.distinct("f")
assert len(d_items) == len([doc for doc in documents if "f" in doc])
for doc in documents:
if "f" in doc:
assert doc["f"] in d_items

@pytest.mark.describe("test of collection insert_many, async")
async def test_collection_insert_many_async(
self,
Expand Down
29 changes: 29 additions & 0 deletions tests/idiomatic/integration/test_dml_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime

import pytest
from typing import Any, Dict, List

from astrapy import Collection
from astrapy.results import DeleteResult, InsertOneResult
Expand Down Expand Up @@ -384,6 +387,32 @@ def test_collection_cursors_sync(
cursor7.rewind()
cursor7["wrong"]

@pytest.mark.describe("test of distinct with non-hashable items, sync")
def test_collection_distinct_nonhashable_sync(
self,
sync_empty_collection: Collection,
) -> None:
col = sync_empty_collection
documents: List[Dict[str, Any]] = [
{},
{"f": 1},
{"f": "a"},
{"f": {"subf": 99}},
{"f": {"subf": 99, "another": {"subsubf": [True, False]}}},
{"f": [10, 11]},
{"f": [11, 10]},
{"f": [10]},
{"f": datetime.datetime(2000, 1, 1, 12, 00, 00)},
{"f": None},
]
col.insert_many(documents)

d_items = col.distinct("f")
assert len(d_items) == len([doc for doc in documents if "f" in doc])
for doc in documents:
if "f" in doc:
assert doc["f"] in d_items

@pytest.mark.describe("test of collection insert_many, sync")
def test_collection_insert_many_sync(
self,
Expand Down

0 comments on commit 31bb77e

Please sign in to comment.