Skip to content

Commit

Permalink
Cursor/AsyncCursor, find and distinct (#234)
Browse files Browse the repository at this point in the history
* insert_many(sync)+test, find(sync)+test

* tests for find(sync); distinct in cursor(sync)+tests; distinct in collection(sync)+tests

* insert_many (sync/async) + tests

* refactor into base cursor and cursor

* async cursors

* distinct(s/a + tests), async cursors(+ tests)
  • Loading branch information
hemidactylus authored Mar 5, 2024
1 parent d0b97ac commit 89209de
Show file tree
Hide file tree
Showing 9 changed files with 1,186 additions and 27 deletions.
27 changes: 15 additions & 12 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
import json
import threading

from collections.abc import (
AsyncGenerator,
AsyncIterator,
Generator,
Iterator,
)
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from queue import Queue
Expand All @@ -28,14 +34,11 @@
Any,
cast,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
Type,
AsyncIterable,
AsyncGenerator,
)

from astrapy import __version__
Expand Down Expand Up @@ -119,7 +122,7 @@ def __init__(
self.caller_name = self.astra_db.caller_name
self.caller_version = self.astra_db.caller_version
self.collection_name = collection_name
self.base_path = f"{self.astra_db.base_path}/{self.collection_name}"
self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}"

def __repr__(self) -> str:
return f'AstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]'
Expand Down Expand Up @@ -280,7 +283,7 @@ def find(
self,
filter: Optional[Dict[str, Any]] = None,
projection: Optional[Dict[str, Any]] = None,
sort: Optional[Dict[str, Any]] = {},
sort: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
) -> API_RESPONSE:
"""
Expand Down Expand Up @@ -356,7 +359,7 @@ def paginate(
request_method: PaginableRequestMethod,
options: Optional[Dict[str, Any]],
prefetched: Optional[int] = None,
) -> Iterable[API_DOC]:
) -> Generator[API_DOC, None, None]:
"""
Generate paginated results for a given database query method.
Args:
Expand Down Expand Up @@ -415,7 +418,7 @@ def paginated_find(
sort: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
prefetched: Optional[int] = None,
) -> Iterable[API_DOC]:
) -> Iterator[API_DOC]:
"""
Perform a paginated search in the collection.
Args:
Expand Down Expand Up @@ -1156,7 +1159,7 @@ def __init__(
self.caller_version = self.astra_db.caller_version
self.client = astra_db.client
self.collection_name = collection_name
self.base_path = f"{self.astra_db.base_path}/{self.collection_name}"
self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}"

def __repr__(self) -> str:
return f'AsyncAstraDBCollection[astra_db="{self.astra_db}", collection_name="{self.collection_name}"]'
Expand Down Expand Up @@ -1318,7 +1321,7 @@ async def find(
self,
filter: Optional[Dict[str, Any]] = None,
projection: Optional[Dict[str, Any]] = None,
sort: Optional[Dict[str, Any]] = {},
sort: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
) -> API_RESPONSE:
"""
Expand Down Expand Up @@ -1449,7 +1452,7 @@ def paginated_find(
sort: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
prefetched: Optional[int] = None,
) -> AsyncIterable[API_DOC]:
) -> AsyncIterator[API_DOC]:
"""
Perform a paginated search in the collection.
Args:
Expand Down Expand Up @@ -2141,7 +2144,7 @@ def __init__(
self.namespace = namespace

# Finally, construct the full base path
self.base_path = f"/{self.api_path}/{self.api_version}/{self.namespace}"
self.base_path: str = f"/{self.api_path}/{self.api_version}/{self.namespace}"

def __repr__(self) -> str:
return f'AstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]'
Expand Down Expand Up @@ -2428,7 +2431,7 @@ def __init__(
self.namespace = namespace

# Finally, construct the full base path
self.base_path = f"/{self.api_path}/{self.api_version}/{self.namespace}"
self.base_path: str = f"/{self.api_path}/{self.api_version}/{self.namespace}"

def __repr__(self) -> str:
return f'AsyncAstraDB[endpoint="{self.base_url}", keyspace="{self.namespace}"]'
Expand Down
175 changes: 165 additions & 10 deletions astrapy/idiomatic/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@
from __future__ import annotations

import json
from typing import Any, Dict, Optional
from typing import Any, Dict, Iterable, List, Optional

from astrapy.db import AstraDBCollection, AsyncAstraDBCollection
from astrapy.idiomatic.types import DocumentType, ProjectionType
from astrapy.idiomatic.utils import raise_unsupported_parameter, unsupported
from astrapy.idiomatic.database import AsyncDatabase, Database
from astrapy.idiomatic.results import DeleteResult, InsertOneResult
from astrapy.idiomatic.results import DeleteResult, InsertManyResult, InsertOneResult
from astrapy.idiomatic.cursors import AsyncCursor, Cursor

# Value passed as `concurrency` to `chunked_insert_many` by the
# `insert_many` methods when an unordered insertion is requested.
INSERT_MANY_CONCURRENCY = 20


class Collection:
Expand Down Expand Up @@ -110,7 +114,7 @@ def set_caller(

def insert_one(
self,
document: Dict[str, Any],
document: DocumentType,
*,
bypass_document_validation: Optional[bool] = None,
) -> InsertOneResult:
Expand Down Expand Up @@ -138,6 +142,84 @@ def insert_one(
f"(gotten '${json.dumps(io_response)}')"
)

def insert_many(
    self,
    documents: Iterable[DocumentType],
    *,
    ordered: bool = True,
    bypass_document_validation: Optional[bool] = None,
) -> InsertManyResult:
    """
    Insert several documents via the wrapped collection's
    `chunked_insert_many` and summarize the outcome.

    Args:
        documents: the documents to insert; the iterable is turned into
            a list before being handed to `chunked_insert_many`.
        ordered: if truthy (default), the call is made with
            `{"ordered": True}`, no partial failures allowed and a
            concurrency of 1; otherwise with `{"ordered": False}`,
            partial failures allowed and a concurrency of
            `INSERT_MANY_CONCURRENCY`.
        bypass_document_validation: not supported; any truthy value is
            reported through `raise_unsupported_parameter`.

    Returns:
        an InsertManyResult whose `inserted_ids` collects the
        "insertedIds" entries found under "status" in each dict response.

    Raises:
        the first Exception instance found among the responses, if any;
        otherwise a ValueError wrapping the first item found under
        "errors" in the dict responses, if any.
    """
    if bypass_document_validation:
        raise_unsupported_parameter(
            class_name=self.__class__.__name__,
            method_name="insert_many",
            parameter_name="bypass_document_validation",
        )

    # Ordered and unordered insertions differ only in these three settings.
    is_ordered = bool(ordered)
    responses = self._astra_db_collection.chunked_insert_many(
        documents=list(documents),
        options={"ordered": is_ordered},
        partial_failures_allowed=not is_ordered,
        concurrency=1 if is_ordered else INSERT_MANY_CONCURRENCY,
    )

    # First pass: responses that are exception objects.
    raised = []
    for resp in responses:
        if isinstance(resp, Exception):
            raised.append(resp)

    # Second pass: errors reported inside dict responses.
    reported_errors = []
    for resp in responses:
        if isinstance(resp, dict):
            reported_errors.extend(resp.get("errors") or [])

    if raised:
        raise raised[0]
    if reported_errors:
        raise ValueError(str(reported_errors[0]))

    # No failures: gather the inserted IDs from every dict response.
    inserted_ids = []
    for resp in responses:
        if isinstance(resp, dict):
            status = resp.get("status") or {}
            inserted_ids.extend(status.get("insertedIds", []))
    return InsertManyResult(inserted_ids=inserted_ids)

def find(
    self,
    filter: Optional[Dict[str, Any]] = None,
    *,
    projection: Optional[ProjectionType] = None,
    skip: Optional[int] = None,
    limit: Optional[int] = None,
    sort: Optional[Dict[str, Any]] = None,
) -> Cursor:
    """
    Build a Cursor on this collection for the given query settings.

    Args:
        filter: passed to the Cursor constructor as `filter`.
        projection: passed to the Cursor constructor as `projection`.
        skip: passed to the cursor's `skip` method.
        limit: passed to the cursor's `limit` method.
        sort: passed to the cursor's `sort` method.

    Returns:
        the result of applying `skip`, `limit` and `sort`, in this
        order, to the newly created Cursor.
    """
    cursor = Cursor(
        collection=self,
        filter=filter,
        projection=projection,
    )
    cursor = cursor.skip(skip)
    cursor = cursor.limit(limit)
    return cursor.sort(sort)

def distinct(
    self,
    key: str,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Any]:
    """
    Return the result of `distinct(key)` on a cursor from `find`.

    Args:
        key: the field name; it is used both as the only entry of the
            projection (`{key: True}`) and as the argument to the
            cursor's `distinct` method.
        filter: forwarded to `find` as its `filter`.

    Returns:
        whatever the cursor's `distinct(key)` call returns.
    """
    cursor = self.find(filter=filter, projection={key: True})
    return cursor.distinct(key)

def count_documents(
self,
filter: Dict[str, Any],
Expand Down Expand Up @@ -278,9 +360,6 @@ def list_search_indexes(*pargs: Any, **kwargs: Any) -> Any: ...
@unsupported
def update_search_index(*pargs: Any, **kwargs: Any) -> Any: ...

@unsupported
def distinct(*pargs: Any, **kwargs: Any) -> Any: ...


class AsyncCollection:
def __init__(
Expand Down Expand Up @@ -369,7 +448,7 @@ def set_caller(

async def insert_one(
self,
document: Dict[str, Any],
document: DocumentType,
*,
bypass_document_validation: Optional[bool] = None,
) -> InsertOneResult:
Expand Down Expand Up @@ -397,6 +476,85 @@ async def insert_one(
f"(gotten '${json.dumps(io_response)}')"
)

async def insert_many(
    self,
    documents: Iterable[DocumentType],
    *,
    ordered: bool = True,
    bypass_document_validation: Optional[bool] = None,
) -> InsertManyResult:
    """
    Insert several documents by awaiting the wrapped collection's
    `chunked_insert_many` and summarize the outcome.

    Args:
        documents: the documents to insert; the iterable is turned into
            a list before being handed to `chunked_insert_many`.
        ordered: if truthy (default), the call is made with
            `{"ordered": True}`, no partial failures allowed and a
            concurrency of 1; otherwise with `{"ordered": False}`,
            partial failures allowed and a concurrency of
            `INSERT_MANY_CONCURRENCY`.
        bypass_document_validation: not supported; any truthy value is
            reported through `raise_unsupported_parameter`.

    Returns:
        an InsertManyResult whose `inserted_ids` collects the
        "insertedIds" entries found under "status" in each dict response.

    Raises:
        the first Exception instance found among the responses, if any;
        otherwise a ValueError wrapping the first item found under
        "errors" in the dict responses, if any.
    """
    if bypass_document_validation:
        raise_unsupported_parameter(
            class_name=self.__class__.__name__,
            method_name="insert_many",
            parameter_name="bypass_document_validation",
        )

    # Ordered and unordered insertions differ only in these three settings.
    is_ordered = bool(ordered)
    responses = await self._astra_db_collection.chunked_insert_many(
        documents=list(documents),
        options={"ordered": is_ordered},
        partial_failures_allowed=not is_ordered,
        concurrency=1 if is_ordered else INSERT_MANY_CONCURRENCY,
    )

    # First pass: responses that are exception objects.
    raised = []
    for resp in responses:
        if isinstance(resp, Exception):
            raised.append(resp)

    # Second pass: errors reported inside dict responses.
    reported_errors = []
    for resp in responses:
        if isinstance(resp, dict):
            reported_errors.extend(resp.get("errors") or [])

    if raised:
        raise raised[0]
    if reported_errors:
        raise ValueError(str(reported_errors[0]))

    # No failures: gather the inserted IDs from every dict response.
    inserted_ids = []
    for resp in responses:
        if isinstance(resp, dict):
            status = resp.get("status") or {}
            inserted_ids.extend(status.get("insertedIds", []))
    return InsertManyResult(inserted_ids=inserted_ids)

def find(
    self,
    filter: Optional[Dict[str, Any]] = None,
    *,
    projection: Optional[ProjectionType] = None,
    skip: Optional[int] = None,
    limit: Optional[int] = None,
    sort: Optional[Dict[str, Any]] = None,
) -> AsyncCursor:
    """
    Build an AsyncCursor on this collection for the given query settings.

    This is a regular (non-coroutine) method: it is called without `await`.

    Args:
        filter: passed to the AsyncCursor constructor as `filter`.
        projection: passed to the AsyncCursor constructor as `projection`.
        skip: passed to the cursor's `skip` method.
        limit: passed to the cursor's `limit` method.
        sort: passed to the cursor's `sort` method.

    Returns:
        the result of applying `skip`, `limit` and `sort`, in this
        order, to the newly created AsyncCursor.
    """
    cursor = AsyncCursor(
        collection=self,
        filter=filter,
        projection=projection,
    )
    cursor = cursor.skip(skip)
    cursor = cursor.limit(limit)
    return cursor.sort(sort)

async def distinct(
    self,
    key: str,
    filter: Optional[Dict[str, Any]] = None,
) -> List[Any]:
    """
    Await and return `distinct(key)` on a cursor obtained from `find`.

    Args:
        key: the field name; it is used both as the only entry of the
            projection (`{key: True}`) and as the argument to the
            cursor's `distinct` method.
        filter: forwarded to `find` as its `filter`.

    Returns:
        whatever awaiting the cursor's `distinct(key)` call yields.
    """
    return await self.find(
        filter=filter,
        projection={key: True},
    ).distinct(key)

async def count_documents(
self,
filter: Dict[str, Any],
Expand Down Expand Up @@ -538,6 +696,3 @@ async def list_search_indexes(*pargs: Any, **kwargs: Any) -> Any: ...

@unsupported
async def update_search_index(*pargs: Any, **kwargs: Any) -> Any: ...

@unsupported
async def distinct(*pargs: Any, **kwargs: Any) -> Any: ...
Loading

0 comments on commit 89209de

Please sign in to comment.