diff --git a/astrapy/db.py b/astrapy/db.py index b0ab9141..9e7f477a 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -15,7 +15,20 @@ import logging import json from functools import partial -from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Union +from types import TracebackType +from typing import ( + Any, + cast, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + Type, + AsyncIterable, + AsyncGenerator, +) import httpx @@ -25,9 +38,13 @@ DEFAULT_JSON_API_VERSION, DEFAULT_KEYSPACE_NAME, ) -from astrapy.utils import make_payload, make_request, http_methods -from astrapy.types import API_DOC, API_RESPONSE, PaginableRequestMethod - +from astrapy.utils import make_payload, make_request, http_methods, amake_request +from astrapy.types import ( + API_DOC, + API_RESPONSE, + PaginableRequestMethod, + AsyncPaginableRequestMethod, +) logger = logging.getLogger(__name__) @@ -748,72 +765,61 @@ def upsert(self, document: API_DOC) -> str: return upserted_id -class AstraDB: - # Initialize the shared httpx client as a class attribute - client = httpx.Client() - +class AsyncAstraDBCollection: def __init__( self, + collection_name: str, + astra_db: Optional[AsyncAstraDB] = None, token: Optional[str] = None, api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, namespace: Optional[str] = None, ) -> None: """ - Initialize an Astra DB instance. + Initialize an AstraDBCollection instance. Args: + collection_name (str): The name of the collection. + astra_db (AstraDB, optional): An instance of Astra DB. token (str, optional): Authentication token for Astra DB. api_endpoint (str, optional): API endpoint URL. namespace (str, optional): Namespace for the database. """ - if token is None or api_endpoint is None: - raise AssertionError("Must provide token and api_endpoint") + # Check for presence of the Astra DB object + if astra_db is None: + if token is None or api_endpoint is None: + raise AssertionError("Must provide token and api_endpoint") - if namespace is None: - logger.info( - f"ASTRA_DB_KEYSPACE is not set. Defaulting to '{DEFAULT_KEYSPACE_NAME}'" + astra_db = AsyncAstraDB( + token=token, api_endpoint=api_endpoint, namespace=namespace ) - namespace = DEFAULT_KEYSPACE_NAME - - # Store the API token - self.token = token - - # Set the Base URL for the API calls - self.base_url = api_endpoint.strip("/") - - # Set the API version and path from the call - self.api_path = (api_path or DEFAULT_JSON_API_PATH).strip("/") - self.api_version = (api_version or DEFAULT_JSON_API_VERSION).strip("/") - - # Set the namespace - self.namespace = namespace - # Finally, construct the full base path - self.base_path = f"/{self.api_path}/{self.api_version}/{self.namespace}" + # Set the remaining instance attributes + self.astra_db: AsyncAstraDB = astra_db + self.client = astra_db.client + self.collection_name = collection_name + self.base_path = f"{self.astra_db.base_path}/{self.collection_name}" def __repr__(self) -> str: - return f'Astra DB[endpoint="{self.base_url}"]' + return f'Astra DB Collection[name="{self.collection_name}", endpoint="{self.astra_db.base_url}"]' - def _request( + async def _request( self, method: str = http_methods.POST, path: Optional[str] = None, json_data: Optional[Dict[str, Any]] = None, url_params: Optional[Dict[str, Any]] = None, skip_error_check: bool = False, + **kwargs: Any, ) -> API_RESPONSE: - response = make_request( + response = await amake_request( client=self.client, - base_url=self.base_url, + base_url=self.astra_db.base_url, auth_header=DEFAULT_AUTH_HEADER, - token=self.token, + token=self.astra_db.token, method=method, path=path, json_data=json_data, url_params=url_params, ) - responsebody = cast(API_RESPONSE, response.json()) if not skip_error_check and "errors" in responsebody: @@ -821,128 +827,851 @@ def _request( else: return responsebody - def collection(self, collection_name: str) -> AstraDBCollection: + async def _get( + self, path: Optional[str] = None, options: Optional[Dict[str, Any]] = None + ) -> Optional[API_RESPONSE]: + full_path = f"{self.base_path}/{path}" if path else self.base_path + response = await self._request( + method=http_methods.GET, path=full_path, url_params=options + ) + if isinstance(response, dict): + return response + return None + + async def _put( + self, path: Optional[str] = None, document: Optional[API_RESPONSE] = None + ) -> API_RESPONSE: + full_path = f"{self.base_path}/{path}" if path else self.base_path + response = await self._request( + method=http_methods.PUT, path=full_path, json_data=document + ) + return response + + async def _post( + self, path: Optional[str] = None, document: Optional[API_DOC] = None + ) -> API_RESPONSE: + full_path = f"{self.base_path}/{path}" if path else self.base_path + response = await self._request( + method=http_methods.POST, path=full_path, json_data=document + ) + return response + + def _pre_process_find( + self, vector: List[float], fields: Optional[List[str]] = None + ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + # Must pass a vector + if not vector: + raise ValueError("Must pass a vector") + + # Edge case for field selection + if fields and "$similarity" in fields: + raise ValueError("Please use the `include_similarity` parameter") + + # Build the new vector parameter + sort: Dict[str, Any] = {"$vector": vector} + + # Build the new fields parameter + # Note: do not leave projection={}, make it None + # (or it will devour $similarity away in the API response) + if fields is not None and len(fields) > 0: + projection = {f: 1 for f in fields} + else: + projection = None + + return sort, projection + + async def get(self, path: Optional[str] = None) -> Optional[API_RESPONSE]: """ - Retrieve a collection from the database. + Retrieve a document from the collection by its path. Args: - collection_name (str): The name of the collection to retrieve. + path (str, optional): The path of the document to retrieve. Returns: - AstraDBCollection: The collection object. + dict: The retrieved document. """ - return AstraDBCollection(collection_name=collection_name, astra_db=self) + return await self._get(path=path) - def get_collections(self, options: Optional[Dict[str, Any]] = None) -> API_RESPONSE: + async def find( + self, + filter: Optional[Dict[str, Any]] = None, + projection: Optional[Dict[str, Any]] = None, + sort: Optional[Dict[str, Any]] = {}, + options: Optional[Dict[str, Any]] = None, + ) -> API_RESPONSE: """ - Retrieve a list of collections from the database. + Find documents in the collection that match the given filter. Args: - options (dict, optional): Options to get the collection list + filter (dict, optional): Criteria to filter documents. + projection (dict, optional): Specifies the fields to return. + sort (dict, optional): Specifies the order in which to return matching documents. + options (dict, optional): Additional options for the query. Returns: - dict: An object containing the list of collections in the database: - {"status": {"collections": [...]}} + dict: The query response containing matched documents. """ - # Parse the options parameter - if options is None: - options = {} - json_query = make_payload( - top_level="findCollections", + top_level="find", + filter=filter, + projection=projection, options=options, + sort=sort, ) - response = self._request( - method=http_methods.POST, - path=self.base_path, - json_data=json_query, + response = await self._post( + document=json_query, ) return response - def create_collection( + async def vector_find( self, - collection_name: str, + vector: List[float], *, - options: Optional[Dict[str, Any]] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - ) -> AstraDBCollection: + limit: int, + filter: Optional[Dict[str, Any]] = None, + fields: Optional[List[str]] = None, + include_similarity: bool = True, + ) -> List[API_DOC]: """ - Create a new collection in the database. + Perform a vector-based search in the collection. Args: - collection_name (str): The name of the collection to create. - options (dict, optional): Options for the collection. - dimension (int, optional): Dimension for vector search. - metric (str, optional): Metric choice for vector search. + vector (list): The vector to search with. + limit (int): The maximum number of documents to return. + filter (dict, optional): Criteria to filter documents. + fields (list, optional): Specifies the fields to return. + include_similarity (bool, optional): Whether to include similarity score in the result. Returns: - AstraDBCollection: The created collection object. + list: A list of documents matching the vector search criteria. """ - # options from named params - vector_options = { - k: v - for k, v in { - "dimension": dimension, - "metric": metric, - }.items() - if v is not None - } + # Must pass a limit + if not limit: + raise ValueError("Must pass a limit") - # overlap/merge with stuff in options.vector - dup_params = set((options or {}).get("vector", {}).keys()) & set( - vector_options.keys() + # Pre-process the included arguments + sort, projection = self._pre_process_find( + vector, + fields=fields, ) - # If any params are duplicated, we raise an error - if dup_params: - dups = ", ".join(sorted(dup_params)) - raise ValueError( - f"Parameter(s) {dups} passed both to the method and in the options" - ) - - # Build our options dictionary if we have vector options - if vector_options: - options = options or {} - options["vector"] = { - **options.get("vector", {}), - **vector_options, - } - if "dimension" not in options["vector"]: - raise ValueError("Must pass dimension for vector collections") - - # Build the final json payload - jsondata = { - k: v - for k, v in {"name": collection_name, "options": options}.items() - if v is not None - } - - # Make the request to the endpoint - self._request( - method=http_methods.POST, - path=f"{self.base_path}", - json_data={"createCollection": jsondata}, + # Call the underlying find() method to search + raw_find_result = await self.find( + filter=filter, + projection=projection, + sort=sort, + options={ + "limit": limit, + "includeSimilarity": include_similarity, + }, ) - # Get the instance object as the return of the call - return AstraDBCollection(astra_db=self, collection_name=collection_name) + return cast(List[API_DOC], raw_find_result["data"]["documents"]) - def delete_collection(self, collection_name: str) -> API_RESPONSE: + @staticmethod + async def paginate( + *, + request_method: AsyncPaginableRequestMethod, + options: Optional[Dict[str, Any]], + ) -> AsyncGenerator[API_DOC, None]: """ - Delete a collection from the database. + Generate paginated results for a given database query method. Args: - collection_name (str): The name of the collection to delete. - Returns: - dict: The response from the database. + request_method (function): The database query method to paginate. + options (dict): Options for the database query. + kwargs: Additional arguments to pass to the database query method. + Yields: + dict: The next document in the paginated result set. """ - # Make sure we provide a collection name - if not collection_name: - raise ValueError("Must provide a collection name") - - response = self._request( - method=http_methods.POST, - path=f"{self.base_path}", - json_data={"deleteCollection": {"name": collection_name}}, - ) + _options = options or {} + response0 = await request_method(options=_options) + next_page_state = response0["data"]["nextPageState"] + options0 = _options + for document in response0["data"]["documents"]: + yield document + while next_page_state is not None: + options1 = {**options0, **{"pageState": next_page_state}} + response1 = await request_method(options=options1) + for document in response1["data"]["documents"]: + yield document + next_page_state = response1["data"]["nextPageState"] - return response + def paginated_find( + self, + filter: Optional[Dict[str, Any]] = None, + projection: Optional[Dict[str, Any]] = None, + sort: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> AsyncIterable[API_DOC]: + """ + Perform a paginated search in the collection. + Args: + filter (dict, optional): Criteria to filter documents. + projection (dict, optional): Specifies the fields to return. + sort (dict, optional): Specifies the order in which to return matching documents. + options (dict, optional): Additional options for the query. + Returns: + generator: A generator yielding documents in the paginated result set. + """ + partialed_find = partial( + self.find, + filter=filter, + projection=projection, + sort=sort, + ) + return self.paginate( + request_method=partialed_find, + options=options, + ) + + async def pop( + self, filter: Dict[str, Any], pop: Dict[str, Any], options: Dict[str, Any] + ) -> API_RESPONSE: + """ + Pop the last data in the tags array + Args: + filter (dict): Criteria to identify the document to update. + pop (dict): The pop to apply to the tags. + options (dict): Additional options for the update operation. + Returns: + dict: The original document before the update. + """ + json_query = make_payload( + top_level="findOneAndUpdate", + filter=filter, + update={"$pop": pop}, + options=options, + ) + + response = await self._request( + method=http_methods.POST, + path=self.base_path, + json_data=json_query, + ) + + return response + + async def push( + self, filter: Dict[str, Any], push: Dict[str, Any], options: Dict[str, Any] + ) -> API_RESPONSE: + """ + Push new data to the tags array + Args: + filter (dict): Criteria to identify the document to update. + push (dict): The push to apply to the tags. + options (dict): Additional options for the update operation. + Returns: + dict: The result of the update operation. + """ + json_query = make_payload( + top_level="findOneAndUpdate", + filter=filter, + update={"$push": push}, + options=options, + ) + + response = await self._request( + method=http_methods.POST, + path=self.base_path, + json_data=json_query, + ) + + return response + + async def find_one_and_replace( + self, + replacement: Optional[Dict[str, Any]] = None, + *, + sort: Optional[Dict[str, Any]] = {}, + filter: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> API_RESPONSE: + """ + Find a single document and replace it. + Args: + replacement (dict): The new document to replace the existing one. + filter (dict, optional): Criteria to filter documents. + sort (dict, optional): Specifies the order in which to find the document. + options (dict, optional): Additional options for the operation. + Returns: + dict: The result of the find and replace operation. + """ + json_query = make_payload( + top_level="findOneAndReplace", + filter=filter, + replacement=replacement, + options=options, + sort=sort, + ) + + response = await self._request( + method=http_methods.POST, path=f"{self.base_path}", json_data=json_query + ) + + return response + + async def vector_find_one_and_replace( + self, + vector: List[float], + replacement: Dict[str, Any], + *, + filter: Optional[Dict[str, Any]] = None, + fields: Optional[List[str]] = None, + ) -> Union[API_DOC, None]: + """ + Perform a vector-based search and replace the first matched document. + Args: + vector (dict): The vector to search with. + replacement (dict): The new document to replace the existing one. + filter (dict, optional): Criteria to filter documents. + fields (list, optional): Specifies the fields to return in the result. + Returns: + dict or None: either the matched document or None if nothing found + """ + # Pre-process the included arguments + sort, _ = self._pre_process_find( + vector, + fields=fields, + ) + + # Call the underlying find() method to search + raw_find_result = await self.find_one_and_replace( + replacement=replacement, + filter=filter, + sort=sort, + ) + + return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) + + async def find_one_and_update( + self, + sort: Optional[Dict[str, Any]] = {}, + update: Optional[Dict[str, Any]] = None, + filter: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = None, + ) -> API_RESPONSE: + """ + Find a single document and update it. + Args: + sort (dict, optional): Specifies the order in which to find the document. + update (dict, optional): The update to apply to the document. + filter (dict, optional): Criteria to filter documents. + options (dict, optional): Additional options for the operation. + Returns: + dict: The result of the find and update operation. + """ + json_query = make_payload( + top_level="findOneAndUpdate", + filter=filter, + update=update, + options=options, + sort=sort, + ) + + response = await self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data=json_query, + ) + + return response + + async def vector_find_one_and_update( + self, + vector: List[float], + update: Dict[str, Any], + *, + filter: Optional[Dict[str, Any]] = None, + fields: Optional[List[str]] = None, + ) -> Union[API_DOC, None]: + """ + Perform a vector-based search and update the first matched document. + Args: + vector (list): The vector to search with. + update (dict): The update to apply to the matched document. + filter (dict, optional): Criteria to filter documents before applying the vector search. + fields (list, optional): Specifies the fields to return in the updated document. + Returns: + dict or None: The result of the vector-based find and + update operation, or None if nothing found + """ + # Pre-process the included arguments + sort, _ = self._pre_process_find( + vector, + fields=fields, + ) + + # Call the underlying find() method to search + raw_find_result = await self.find_one_and_update( + update=update, + filter=filter, + sort=sort, + ) + + return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) + + async def count_documents( + self, + filter: Dict[str, Any] = {}, + ) -> API_RESPONSE: + """ + Count documents matching a given predicate (expressed as filter). + Args: + filter (dict, defaults to {}): Criteria to filter documents. + Returns: + dict: the response, either + {"status": {"count": }} + or + {"errors": [...]} + """ + json_query = make_payload( + top_level="countDocuments", + filter=filter, + ) + + response = await self._post( + document=json_query, + ) + + return response + + async def find_one( + self, + filter: Optional[Dict[str, Any]] = {}, + projection: Optional[Dict[str, Any]] = {}, + sort: Optional[Dict[str, Any]] = {}, + options: Optional[Dict[str, Any]] = {}, + ) -> API_RESPONSE: + """ + Find a single document in the collection. + Args: + filter (dict, optional): Criteria to filter documents. + projection (dict, optional): Specifies the fields to return. + sort (dict, optional): Specifies the order in which to return the document. + options (dict, optional): Additional options for the query. + Returns: + dict: the response, either + {"data": {"document": }} + or + {"data": {"document": None}} + depending on whether a matching document is found or not. + """ + json_query = make_payload( + top_level="findOne", + filter=filter, + projection=projection, + options=options, + sort=sort, + ) + + response = await self._post( + document=json_query, + ) + + return response + + async def vector_find_one( + self, + vector: List[float], + *, + filter: Optional[Dict[str, Any]] = None, + fields: Optional[List[str]] = None, + include_similarity: bool = True, + ) -> Union[API_DOC, None]: + """ + Perform a vector-based search to find a single document in the collection. + Args: + vector (list): The vector to search with. + filter (dict, optional): Additional criteria to filter documents. + fields (list, optional): Specifies the fields to return in the result. + include_similarity (bool, optional): Whether to include similarity score in the result. + Returns: + dict or None: The found document or None if no matching document is found. + """ + # Pre-process the included arguments + sort, projection = self._pre_process_find( + vector, + fields=fields, + ) + + # Call the underlying find() method to search + raw_find_result = await self.find_one( + filter=filter, + projection=projection, + sort=sort, + options={"includeSimilarity": include_similarity}, + ) + + return cast(Union[API_DOC, None], raw_find_result["data"]["document"]) + + async def insert_one( + self, document: API_DOC, failures_allowed: bool = False + ) -> API_RESPONSE: + """ + Insert a single document into the collection. + Args: + document (dict): The document to insert. + failures_allowed (bool): Whether to allow failures in the insert operation. + Returns: + dict: The response from the database after the insert operation. + """ + json_query = make_payload(top_level="insertOne", document=document) + + response = await self._request( + method=http_methods.POST, + path=self.base_path, + json_data=json_query, + skip_error_check=failures_allowed, + ) + + return response + + async def insert_many( + self, + documents: List[API_DOC], + options: Optional[Dict[str, Any]] = None, + partial_failures_allowed: bool = False, + ) -> API_RESPONSE: + """ + Insert multiple documents into the collection. + Args: + documents (list): A list of documents to insert. + options (dict, optional): Additional options for the insert operation. + partial_failures_allowed (bool, optional): Whether to allow partial failures in the batch. + Returns: + dict: The response from the database after the insert operation. + """ + json_query = make_payload( + top_level="insertMany", documents=documents, options=options + ) + + response = await self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data=json_query, + skip_error_check=partial_failures_allowed, + ) + + return response + + async def update_one( + self, filter: Dict[str, Any], update: Dict[str, Any] + ) -> API_RESPONSE: + """ + Update a single document in the collection. + Args: + filter (dict): Criteria to identify the document to update. + update (dict): The update to apply to the document. + Returns: + dict: The response from the database after the update operation. + """ + json_query = make_payload(top_level="updateOne", filter=filter, update=update) + + response = await self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data=json_query, + ) + + return response + + async def replace(self, path: str, document: API_DOC) -> API_RESPONSE: + """ + Replace a document in the collection. + Args: + path (str): The path to the document to replace. + document (dict): The new document to replace the existing one. + Returns: + dict: The response from the database after the replace operation. + """ + return await self._put(path=path, document=document) + + async def delete_one(self, id: str) -> API_RESPONSE: + """ + Delete a single document from the collection based on its ID. + Args: + id (str): The ID of the document to delete. + Returns: + dict: The response from the database after the delete operation. + """ + json_query = { + "deleteOne": { + "filter": {"_id": id}, + } + } + + response = await self._request( + method=http_methods.POST, path=f"{self.base_path}", json_data=json_query + ) + + return response + + async def delete_many(self, filter: Dict[str, Any]) -> API_RESPONSE: + """ + Delete many documents from the collection based on a filter condition + Args: + filter (dict): Criteria to identify the documents to delete. + Returns: + dict: The response from the database after the delete operation. + """ + json_query = { + "deleteMany": { + "filter": filter, + } + } + + response = await self._request( + method=http_methods.POST, path=f"{self.base_path}", json_data=json_query + ) + + return response + + async def delete_subdocument(self, id: str, subdoc: str) -> API_RESPONSE: + """ + Delete a subdocument or field from a document in the collection. + Args: + id (str): The ID of the document containing the subdocument. + subdoc (str): The key of the subdocument or field to remove. + Returns: + dict: The response from the database after the update operation. + """ + json_query = { + "findOneAndUpdate": { + "filter": {"_id": id}, + "update": {"$unset": {subdoc: ""}}, + } + } + + response = await self._request( + method=http_methods.POST, path=f"{self.base_path}", json_data=json_query + ) + + return response + + async def upsert(self, document: API_DOC) -> str: + """ + Emulate an upsert operation for a single document in the collection. + + This method attempts to insert the document. If a document with the same _id exists, it updates the existing document. + + Args: + document (dict): The document to insert or update. + + Returns: + str: The _id of the inserted or updated document. + """ + # Build the payload for the insert attempt + result = await self.insert_one(document, failures_allowed=True) + + # If the call failed, then we replace the existing doc + if ( + "errors" in result + and "errorCode" in result["errors"][0] + and result["errors"][0]["errorCode"] == "DOCUMENT_ALREADY_EXISTS" + ): + # Now we attempt to update + result = await self.find_one_and_replace( + replacement=document, + filter={"_id": document["_id"]}, + ) + upserted_id = cast(str, result["data"]["document"]["_id"]) + else: + upserted_id = cast(str, result["status"]["insertedIds"][0]) + + return upserted_id + + +class AstraDB: + # Initialize the shared httpx client as a class attribute + client = httpx.Client() + + def __init__( + self, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + ) -> None: + """ + Initialize an Astra DB instance. + Args: + token (str, optional): Authentication token for Astra DB. + api_endpoint (str, optional): API endpoint URL. + namespace (str, optional): Namespace for the database. + """ + if token is None or api_endpoint is None: + raise AssertionError("Must provide token and api_endpoint") + + if namespace is None: + logger.info( + f"ASTRA_DB_KEYSPACE is not set. Defaulting to '{DEFAULT_KEYSPACE_NAME}'" + ) + namespace = DEFAULT_KEYSPACE_NAME + + # Store the API token + self.token = token + + # Set the Base URL for the API calls + self.base_url = api_endpoint.strip("/") + + # Set the API version and path from the call + self.api_path = (api_path or DEFAULT_JSON_API_PATH).strip("/") + self.api_version = (api_version or DEFAULT_JSON_API_VERSION).strip("/") + + # Set the namespace + self.namespace = namespace + + # Finally, construct the full base path + self.base_path = f"/{self.api_path}/{self.api_version}/{self.namespace}" + + def __repr__(self) -> str: + return f'Astra DB[endpoint="{self.base_url}"]' + + def _request( + self, + method: str = http_methods.POST, + path: Optional[str] = None, + json_data: Optional[Dict[str, Any]] = None, + url_params: Optional[Dict[str, Any]] = None, + skip_error_check: bool = False, + ) -> API_RESPONSE: + response = make_request( + client=self.client, + base_url=self.base_url, + auth_header=DEFAULT_AUTH_HEADER, + token=self.token, + method=method, + path=path, + json_data=json_data, + url_params=url_params, + ) + + responsebody = cast(API_RESPONSE, response.json()) + + if not skip_error_check and "errors" in responsebody: + raise ValueError(json.dumps(responsebody["errors"])) + else: + return responsebody + + def collection(self, collection_name: str) -> AstraDBCollection: + """ + Retrieve a collection from the database. + Args: + collection_name (str): The name of the collection to retrieve. + Returns: + AstraDBCollection: The collection object. + """ + return AstraDBCollection(collection_name=collection_name, astra_db=self) + + def get_collections(self, options: Optional[Dict[str, Any]] = None) -> API_RESPONSE: + """ + Retrieve a list of collections from the database. + Args: + options (dict, optional): Options to get the collection list + Returns: + dict: An object containing the list of collections in the database: + {"status": {"collections": [...]}} + """ + # Parse the options parameter + if options is None: + options = {} + + json_query = make_payload( + top_level="findCollections", + options=options, + ) + + response = self._request( + method=http_methods.POST, + path=self.base_path, + json_data=json_query, + ) + + return response + + def create_collection( + self, + collection_name: str, + *, + options: Optional[Dict[str, Any]] = None, + dimension: Optional[int] = None, + metric: Optional[str] = None, + ) -> AstraDBCollection: + """ + Create a new collection in the database. + Args: + collection_name (str): The name of the collection to create. + options (dict, optional): Options for the collection. + dimension (int, optional): Dimension for vector search. + metric (str, optional): Metric choice for vector search. + Returns: + AstraDBCollection: The created collection object. + """ + # options from named params + vector_options = { + k: v + for k, v in { + "dimension": dimension, + "metric": metric, + }.items() + if v is not None + } + + # overlap/merge with stuff in options.vector + dup_params = set((options or {}).get("vector", {}).keys()) & set( + vector_options.keys() + ) + + # If any params are duplicated, we raise an error + if dup_params: + dups = ", ".join(sorted(dup_params)) + raise ValueError( + f"Parameter(s) {dups} passed both to the method and in the options" + ) + + # Build our options dictionary if we have vector options + if vector_options: + options = options or {} + options["vector"] = { + **options.get("vector", {}), + **vector_options, + } + if "dimension" not in options["vector"]: + raise ValueError("Must pass dimension for vector collections") + + # Build the final json payload + jsondata = { + k: v + for k, v in {"name": collection_name, "options": options}.items() + if v is not None + } + + # Make the request to the endpoint + self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data={"createCollection": jsondata}, + ) + + # Get the instance object as the return of the call + return AstraDBCollection(astra_db=self, collection_name=collection_name) + + def delete_collection(self, collection_name: str) -> API_RESPONSE: + """ + Delete a collection from the database. + Args: + collection_name (str): The name of the collection to delete. + Returns: + dict: The response from the database. + """ + # Make sure we provide a collection name + if not collection_name: + raise ValueError("Must provide a collection name") + + response = self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data={"deleteCollection": {"name": collection_name}}, + ) + + return response def truncate_collection(self, collection_name: str) -> AstraDBCollection: """ @@ -979,3 +1708,247 @@ def truncate_collection(self, collection_name: str) -> AstraDBCollection: collection_name, options=existing_collection.get("options"), ) + + +class AsyncAstraDB: + def __init__( + self, + token: Optional[str] = None, + api_endpoint: Optional[str] = None, + api_path: Optional[str] = None, + api_version: Optional[str] = None, + namespace: Optional[str] = None, + ) -> None: + """ + Initialize an Astra DB instance. + Args: + token (str, optional): Authentication token for Astra DB. + api_endpoint (str, optional): API endpoint URL. + namespace (str, optional): Namespace for the database. + """ + self.client = httpx.AsyncClient() + if token is None or api_endpoint is None: + raise AssertionError("Must provide token and api_endpoint") + + if namespace is None: + logger.info( + f"ASTRA_DB_KEYSPACE is not set. Defaulting to '{DEFAULT_KEYSPACE_NAME}'" + ) + namespace = DEFAULT_KEYSPACE_NAME + + # Store the API token + self.token = token + + # Set the Base URL for the API calls + self.base_url = api_endpoint.strip("/") + + # Set the API version and path from the call + self.api_path = (api_path or DEFAULT_JSON_API_PATH).strip("/") + self.api_version = (api_version or DEFAULT_JSON_API_VERSION).strip("/") + + # Set the namespace + self.namespace = namespace + + # Finally, construct the full base path + self.base_path = f"/{self.api_path}/{self.api_version}/{self.namespace}" + + def __repr__(self) -> str: + return f'Async Astra DB[endpoint="{self.base_url}"]' + + async def __aenter__(self) -> AsyncAstraDB: + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + traceback: Optional[TracebackType] = None, + ) -> None: + await self.client.aclose() + + async def _request( + self, + method: str = http_methods.POST, + path: Optional[str] = None, + json_data: Optional[Dict[str, Any]] = None, + url_params: Optional[Dict[str, Any]] = None, + skip_error_check: bool = False, + ) -> API_RESPONSE: + response = await amake_request( + client=self.client, + base_url=self.base_url, + auth_header=DEFAULT_AUTH_HEADER, + token=self.token, + method=method, + path=path, + json_data=json_data, + url_params=url_params, + ) + + responsebody = cast(API_RESPONSE, response.json()) + + if not skip_error_check and "errors" in responsebody: + raise ValueError(json.dumps(responsebody["errors"])) + else: + return responsebody + + async def collection(self, collection_name: str) -> AsyncAstraDBCollection: + """ + Retrieve a collection from the database. + Args: + collection_name (str): The name of the collection to retrieve. + Returns: + AstraDBCollection: The collection object. + """ + return AsyncAstraDBCollection(collection_name=collection_name, astra_db=self) + + async def get_collections( + self, options: Optional[Dict[str, Any]] = None + ) -> API_RESPONSE: + """ + Retrieve a list of collections from the database. + Args: + options (dict, optional): Options to get the collection list + Returns: + dict: An object containing the list of collections in the database: + {"status": {"collections": [...]}} + """ + # Parse the options parameter + if options is None: + options = {} + + json_query = make_payload( + top_level="findCollections", + options=options, + ) + + response = await self._request( + method=http_methods.POST, + path=self.base_path, + json_data=json_query, + ) + + return response + + async def create_collection( + self, + collection_name: str, + *, + options: Optional[Dict[str, Any]] = None, + dimension: Optional[int] = None, + metric: Optional[str] = None, + ) -> AsyncAstraDBCollection: + """ + Create a new collection in the database. + Args: + collection_name (str): The name of the collection to create. + options (dict, optional): Options for the collection. + dimension (int, optional): Dimension for vector search. + metric (str, optional): Metric choice for vector search. + Returns: + AsyncAstraDBCollection: The created collection object. + """ + # options from named params + vector_options = { + k: v + for k, v in { + "dimension": dimension, + "metric": metric, + }.items() + if v is not None + } + + # overlap/merge with stuff in options.vector + dup_params = set((options or {}).get("vector", {}).keys()) & set( + vector_options.keys() + ) + + # If any params are duplicated, we raise an error + if dup_params: + dups = ", ".join(sorted(dup_params)) + raise ValueError( + f"Parameter(s) {dups} passed both to the method and in the options" + ) + + # Build our options dictionary if we have vector options + if vector_options: + options = options or {} + options["vector"] = { + **options.get("vector", {}), + **vector_options, + } + if "dimension" not in options["vector"]: + raise ValueError("Must pass dimension for vector collections") + + # Build the final json payload + jsondata = { + k: v + for k, v in {"name": collection_name, "options": options}.items() + if v is not None + } + + # Make the request to the endpoint + await self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data={"createCollection": jsondata}, + ) + + # Get the instance object as the return of the call + return AsyncAstraDBCollection(astra_db=self, collection_name=collection_name) + + async def delete_collection(self, collection_name: str) -> API_RESPONSE: + """ + Delete a collection from the database. + Args: + collection_name (str): The name of the collection to delete. + Returns: + dict: The response from the database. + """ + # Make sure we provide a collection name + if not collection_name: + raise ValueError("Must provide a collection name") + + response = await self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data={"deleteCollection": {"name": collection_name}}, + ) + + return response + + async def truncate_collection(self, collection_name: str) -> AsyncAstraDBCollection: + """ + Truncate a collection in the database. + Args: + collection_name (str): The name of the collection to truncate. + Returns: + dict: The response from the database. + """ + # Make sure we provide a collection name + if not collection_name: + raise ValueError("Must provide a collection name") + + # Retrieve the required collections from DB + collections = await self.get_collections(options={"explain": "true"}) + matches = [ + col + for col in collections["status"]["collections"] + if col["name"] == collection_name + ] + + # If we didn't find it, raise an error + if matches == []: + raise ValueError(f"Collection {collection_name} not found") + + # Otherwise we found it, so get the collection + existing_collection = matches[0] + + # We found it, so let's delete it + await self.delete_collection(collection_name) + + # End the function by returning the the new collection + return await self.create_collection( + collection_name, + options=existing_collection.get("options"), + ) diff --git a/astrapy/types.py b/astrapy/types.py index 673ccd51..633ce8a1 100644 --- a/astrapy/types.py +++ b/astrapy/types.py @@ -23,3 +23,9 @@ class PaginableRequestMethod(Protocol): def __call__(self, options: Dict[str, Any]) -> API_RESPONSE: ... + + +# This is for the (partialed, if necessary) async functions that can be "paginated". +class AsyncPaginableRequestMethod(Protocol): + async def __call__(self, options: Dict[str, Any]) -> API_RESPONSE: + ... diff --git a/astrapy/utils.py b/astrapy/utils.py index c3ee7372..7de380dc 100644 --- a/astrapy/utils.py +++ b/astrapy/utils.py @@ -88,6 +88,49 @@ def make_request( return r +async def amake_request( + client: httpx.AsyncClient, + base_url: str, + auth_header: str, + token: str, + method: str = http_methods.POST, + path: Optional[str] = None, + json_data: Optional[Dict[str, Any]] = None, + url_params: Optional[Dict[str, Any]] = None, +) -> httpx.Response: + """ + Make an HTTP request to a specified URL. + + Args: + client (httpx): The httpx client for the request. + base_url (str): The base URL for the request. + auth_header (str): The authentication header key. + token (str): The token used for authentication. + method (str, optional): The HTTP method to use for the request. Default is POST. + path (str, optional): The specific path to append to the base URL. + json_data (dict, optional): JSON payload to be sent with the request. + url_params (dict, optional): URL parameters to be sent with the request. + + Returns: + requests.Response: The response from the HTTP request. + """ + r = await client.request( + method=method, + url=f"{base_url}{path}", + params=url_params, + json=json_data, + timeout=DEFAULT_TIMEOUT, + headers={auth_header: token, "User-Agent": f"{package_name}/{__version__}"}, + ) + + if logger.isEnabledFor(logging.DEBUG): + log_request_response(r, json_data) + + r.raise_for_status() + + return r + + def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: """ Construct a JSON payload for an HTTP request with a specified top-level key. diff --git a/pytest.ini b/pytest.ini index e43ab08a..143be0e6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,6 @@ [pytest] filterwarnings = ignore::DeprecationWarning addopts = -v --cov=astrapy --testdox --cov-report term-missing +asyncio_mode = auto log_cli = 1 log_cli_level = INFO diff --git a/requirements-dev.txt b/requirements-dev.txt index cf28fb86..8fe067f5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,6 +2,7 @@ black~=23.11.0 faker~=20.0.0 mypy~=1.7.0 pre-commit~=3.5.0 +pytest-asyncio~=0.23.2 pytest-cov~=4.1.0 pytest-testdox~=3.1.0 pytest~=7.4.3 diff --git a/tests/astrapy/test_async_db_ddl.py b/tests/astrapy/test_async_db_ddl.py new file mode 100644 index 00000000..828629af --- /dev/null +++ b/tests/astrapy/test_async_db_ddl.py @@ -0,0 +1,116 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the `db.py` parts related to DML & client creation +""" + +import logging +from typing import Dict, Optional + +import pytest + +from astrapy.db import AsyncAstraDB, AsyncAstraDBCollection +from astrapy.defaults import DEFAULT_KEYSPACE_NAME + +TEST_CREATE_DELETE_VECTOR_COLLECTION_NAME = "ephemeral_v_col" +TEST_CREATE_DELETE_NONVECTOR_COLLECTION_NAME = "ephemeral_non_v_col" + +logger = logging.getLogger(__name__) + + +@pytest.mark.describe("should confirm path handling in constructor") +async def test_path_handling( + astra_db_credentials_kwargs: Dict[str, Optional[str]] +) -> None: + async with AsyncAstraDB(**astra_db_credentials_kwargs) as astra_db_1: + url_1 = astra_db_1.base_path + + async with AsyncAstraDB( + **astra_db_credentials_kwargs, + api_version="v1", + ) as astra_db_2: + url_2 = astra_db_2.base_path + + async with AsyncAstraDB( + **astra_db_credentials_kwargs, + api_version="/v1", + ) as astra_db_3: + url_3 = astra_db_3.base_path + + async with AsyncAstraDB( + **astra_db_credentials_kwargs, + api_version="/v1/", + ) as astra_db_4: + url_4 = astra_db_4.base_path + + assert url_1 == url_2 == url_3 == url_4 + + # autofill of the default keyspace name + async with AsyncAstraDB( + **{ + **astra_db_credentials_kwargs, + **{"namespace": DEFAULT_KEYSPACE_NAME}, + } + ) as unspecified_ks_client, AsyncAstraDB( + **{ + **astra_db_credentials_kwargs, + **{"namespace": None}, + } + ) as explicit_ks_client: + assert unspecified_ks_client.base_path == explicit_ks_client.base_path + + +@pytest.mark.describe("should create, use and destroy a non-vector collection") +async def test_create_use_destroy_nonvector_collection(async_db: AsyncAstraDB) -> None: + col = await async_db.create_collection(TEST_CREATE_DELETE_NONVECTOR_COLLECTION_NAME) + assert isinstance(col, AsyncAstraDBCollection) + await col.insert_one({"_id": "first", "name": "a"}) + await col.insert_many( + [ + {"_id": "second", "name": "b", "room": 7}, + {"name": "c", "room": 7}, + {"_id": "last", "type": "unnamed", "room": 7}, + ] + ) + docs = await col.find(filter={"room": 7}, projection={"name": 1}) + ids = [doc["_id"] for doc in docs["data"]["documents"]] + assert len(ids) == 3 + assert "second" in ids + assert "first" not in ids + auto_id = [id for id in ids if id not in {"second", "last"}][0] + await col.delete_one(auto_id) + assert (await col.find_one(filter={"name": "c"}))["data"]["document"] is None + del_res = await async_db.delete_collection( + TEST_CREATE_DELETE_NONVECTOR_COLLECTION_NAME + ) + assert del_res["status"]["ok"] == 1 + + +@pytest.mark.describe("should create and destroy a vector collection") +async def test_create_use_destroy_vector_collection(async_db: AsyncAstraDB) -> None: + col = await async_db.create_collection( + collection_name=TEST_CREATE_DELETE_VECTOR_COLLECTION_NAME, dimension=2 + ) + assert isinstance(col, AsyncAstraDBCollection) + del_res = await async_db.delete_collection( + collection_name=TEST_CREATE_DELETE_VECTOR_COLLECTION_NAME + ) + assert del_res["status"]["ok"] == 1 + + +@pytest.mark.describe("should get all collections") +async def test_get_collections(async_db: AsyncAstraDB) -> None: + res = await async_db.get_collections() + assert res["status"]["collections"] is not None diff --git a/tests/astrapy/test_async_db_dml.py b/tests/astrapy/test_async_db_dml.py new file mode 100644 index 00000000..ef7520cb --- /dev/null +++ b/tests/astrapy/test_async_db_dml.py @@ -0,0 +1,761 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the `db.py` parts on data manipulation "standard" methods +(i.e. non `vector_*` methods) +""" + +import uuid +import logging +from typing import List + +import pytest + +from astrapy.types import API_DOC +from astrapy.db import AsyncAstraDB, AsyncAstraDBCollection + +TEST_TRUNCATED_NONVECTOR_COLLECTION_NAME = "ephemeral_tr_non_v_col" +TEST_TRUNCATED_VECTOR_COLLECTION_NAME = "ephemeral_tr_v_col" + +logger = logging.getLogger(__name__) + + +@pytest.mark.describe("should fail truncating a non-existent collection") +async def test_truncate_collection_fail(async_db: AsyncAstraDB) -> None: + with pytest.raises(ValueError): + await async_db.truncate_collection("this$does%not exist!!!") + + +@pytest.mark.describe("should truncate a nonvector collection") +async def test_truncate_nonvector_collection(async_db: AsyncAstraDB) -> None: + col = await async_db.create_collection(TEST_TRUNCATED_NONVECTOR_COLLECTION_NAME) + try: + await col.insert_one({"a": 1}) + assert len((await col.find())["data"]["documents"]) == 1 + await async_db.truncate_collection(TEST_TRUNCATED_NONVECTOR_COLLECTION_NAME) + assert len((await col.find())["data"]["documents"]) == 0 + finally: + await async_db.delete_collection(TEST_TRUNCATED_NONVECTOR_COLLECTION_NAME) + + +@pytest.mark.describe("should truncate a collection") +async def test_truncate_vector_collection(async_db: AsyncAstraDB) -> None: + col = await async_db.create_collection( + TEST_TRUNCATED_VECTOR_COLLECTION_NAME, dimension=2 + ) + try: + await col.insert_one({"a": 1, "$vector": [0.1, 0.2]}) + assert len((await col.find())["data"]["documents"]) == 1 + await async_db.truncate_collection(TEST_TRUNCATED_VECTOR_COLLECTION_NAME) + assert len((await col.find())["data"]["documents"]) == 0 + finally: + await async_db.delete_collection(TEST_TRUNCATED_VECTOR_COLLECTION_NAME) + + +@pytest.mark.describe("find_one, not through vector") +async def test_find_one_filter_novector( + async_readonly_vector_collection: AsyncAstraDBCollection, cliff_uuid: str +) -> None: + response = await async_readonly_vector_collection.find_one( + filter={"_id": "1"}, + ) + document = response["data"]["document"] + assert document["text"] == "Sample entry number <1>" + assert ( + document.keys() ^ {"_id", "text", "otherfield", "anotherfield", "$vector"} + == set() + ) + + response_not_by_id = await async_readonly_vector_collection.find_one( + filter={"text": "Sample entry number <1>"}, + ) + document_not_by_id = response_not_by_id["data"]["document"] + assert document_not_by_id["_id"] == "1" + assert ( + document_not_by_id.keys() + ^ {"_id", "text", "otherfield", "anotherfield", "$vector"} + == set() + ) + + response_no = await async_readonly_vector_collection.find_one( + filter={"_id": "Z"}, + ) + document_no = response_no["data"]["document"] + assert document_no is None + + response_no_not_by_id = await async_readonly_vector_collection.find_one( + filter={"text": "No such text."}, + ) + document_no_not_by_id = response_no_not_by_id["data"]["document"] + assert document_no_not_by_id is None + + +@pytest.mark.describe("find, not through vector") +async def test_find_filter_novector( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + response_n2 = await async_readonly_vector_collection.find( + filter={"anotherfield": "alpha"}, + ) + documents_n2 = response_n2["data"]["documents"] + assert isinstance(documents_n2, list) + assert {document["_id"] for document in documents_n2} == {"1", "2"} + + response_n1 = await async_readonly_vector_collection.find( + filter={"anotherfield": "alpha"}, + options={"limit": 1}, + ) + documents_n1 = response_n1["data"]["documents"] + assert isinstance(documents_n1, list) + assert len(documents_n1) == 1 + assert documents_n1[0]["_id"] in {"1", "2"} + + +@pytest.mark.describe("obey projection in find and find_one") +async def test_find_find_one_projection( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + query = [0.2, 0.6] + sort = {"$vector": query} + options = {"limit": 1} + + projs = [ + None, + {}, + {"text": 1}, + {"$vector": 1}, + {"text": 1, "$vector": 1}, + ] + exp_fieldsets = [ + {"$vector", "_id", "otherfield", "anotherfield", "text"}, + {"$vector", "_id", "otherfield", "anotherfield", "text"}, + {"_id", "text"}, + {"$vector", "_id"}, + {"$vector", "_id", "text"}, + ] + for proj, exp_fields in zip(projs, exp_fieldsets): + response_n = await async_readonly_vector_collection.find( + sort=sort, options=options, projection=proj + ) + fields = set(response_n["data"]["documents"][0].keys()) + assert fields == exp_fields + # + response_1 = await async_readonly_vector_collection.find_one( + sort=sort, projection=proj + ) + fields = set(response_1["data"]["document"].keys()) + assert fields == exp_fields + + +@pytest.mark.describe("find through vector") +async def test_find(async_readonly_vector_collection: AsyncAstraDBCollection) -> None: + sort = {"$vector": [0.2, 0.6]} + options = {"limit": 100} + + response = await async_readonly_vector_collection.find(sort=sort, options=options) + assert isinstance(response["data"]["documents"], list) + + +@pytest.mark.describe("proper error raising in find") +async def test_find_error( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + """Wrong type of arguments should raise an API error (ValueError).""" + sort = {"$vector": "clearly not a list of floats!"} + options = {"limit": 100} + + with pytest.raises(ValueError): + await async_readonly_vector_collection.find(sort=sort, options=options) + + +@pytest.mark.describe("find through vector, without explicit limit") +async def test_find_limitless( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + sort = {"$vector": [0.2, 0.6]} + projection = {"$vector": 1} + + response = await async_readonly_vector_collection.find( + sort=sort, projection=projection + ) + assert response is not None + assert isinstance(response["data"]["documents"], list) + + +@pytest.mark.describe("correctly count documents according to predicate") +async def test_count_documents( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + c_all_response0 = await async_readonly_vector_collection.count_documents() + assert c_all_response0["status"]["count"] == 3 + + c_all_response1 = await async_readonly_vector_collection.count_documents(filter={}) + assert c_all_response1["status"]["count"] == 3 + + c_pred_response = await async_readonly_vector_collection.count_documents( + filter={"anotherfield": "alpha"} + ) + assert c_pred_response["status"]["count"] == 2 + + c_no_response = await async_readonly_vector_collection.count_documents( + filter={"false_field": 137} + ) + assert c_no_response["status"]["count"] == 0 + + +@pytest.mark.describe("insert_one, w/out _id, w/out vector") +async def test_create_document( + async_writable_vector_collection: AsyncAstraDBCollection, +) -> None: + i_vector = [0.3, 0.5] + id_v_i = str(uuid.uuid4()) + result_v_i = await async_writable_vector_collection.insert_one( + { + "_id": id_v_i, + "a": 1, + "$vector": i_vector, + } + ) + assert result_v_i["status"]["insertedIds"] == [id_v_i] + assert ( + await async_writable_vector_collection.find_one( + {"_id": result_v_i["status"]["insertedIds"][0]} + ) + )["data"]["document"]["a"] == 1 + + id_n_i = str(uuid.uuid4()) + result_n_i = await async_writable_vector_collection.insert_one( + { + "_id": id_n_i, + "a": 2, + } + ) + assert result_n_i["status"]["insertedIds"] == [id_n_i] + assert ( + await async_writable_vector_collection.find_one( + {"_id": result_n_i["status"]["insertedIds"][0]} + ) + )["data"]["document"]["a"] == 2 + + with pytest.raises(ValueError): + await async_writable_vector_collection.insert_one( + { + "_id": id_n_i, + "a": 3, + } + ) + + result_v_n = await async_writable_vector_collection.insert_one( + { + "a": 4, + "$vector": i_vector, + } + ) + assert isinstance(result_v_n["status"]["insertedIds"], list) + assert isinstance(result_v_n["status"]["insertedIds"][0], str) + assert len(result_v_n["status"]["insertedIds"]) == 1 + assert ( + await async_writable_vector_collection.find_one( + {"_id": result_v_n["status"]["insertedIds"][0]} + ) + )["data"]["document"]["a"] == 4 + + result_n_n = await async_writable_vector_collection.insert_one( + { + "a": 5, + } + ) + assert isinstance(result_n_n["status"]["insertedIds"], list) + assert isinstance(result_n_n["status"]["insertedIds"][0], str) + assert len(result_n_n["status"]["insertedIds"]) == 1 + assert ( + await async_writable_vector_collection.find_one( + {"_id": result_n_n["status"]["insertedIds"][0]} + ) + )["data"]["document"]["a"] == 5 + + +@pytest.mark.describe("insert_many") +async def test_insert_many( + async_writable_vector_collection: AsyncAstraDBCollection, +) -> None: + _id0 = str(uuid.uuid4()) + _id2 = str(uuid.uuid4()) + documents: List[API_DOC] = [ + { + "_id": _id0, + "name": "Abba", + "traits": [10, 9, 3], + "$vector": [0.6, 0.2], + }, + { + "name": "Bacchus", + "happy": True, + }, + { + "_id": _id2, + "name": "Ciccio", + "description": "The thid in this list", + "$vector": [0.4, 0.3], + }, + ] + + response = await async_writable_vector_collection.insert_many(documents) + assert response is not None + inserted_ids = set(response["status"]["insertedIds"]) + assert len(inserted_ids - {_id0, _id2}) == 1 + assert isinstance(list(inserted_ids - {_id0, _id2})[0], str) + + +@pytest.mark.describe("insert_many with 'ordered' set to False") +async def test_insert_many_ordered_false( + async_writable_vector_collection: AsyncAstraDBCollection, +) -> None: + _id0 = str(uuid.uuid4()) + _id1 = str(uuid.uuid4()) + _id2 = str(uuid.uuid4()) + documents_a = [ + { + "_id": _id0, + "first_name": "Dang", + "last_name": "Son", + }, + { + "_id": _id1, + "first_name": "Yep", + "last_name": "Boss", + }, + ] + response_a = await async_writable_vector_collection.insert_many(documents_a) + assert response_a is not None + assert response_a["status"]["insertedIds"] == [_id0, _id1] + + documents_b = [ + { + "_id": _id1, + "first_name": "Maureen", + "last_name": "Caloggero", + }, + { + "_id": _id2, + "first_name": "Miv", + "last_name": "Fuff", + }, + ] + response_b = await async_writable_vector_collection.insert_many( + documents_b, + partial_failures_allowed=True, + ) + assert response_b is not None + assert response_b["status"]["insertedIds"] == [] + + response_b2 = await async_writable_vector_collection.insert_many( + documents=documents_b, + options={"ordered": False}, + partial_failures_allowed=True, + ) + assert response_b2 is not None + assert response_b2["status"]["insertedIds"] == [_id2] + + check_response = await async_writable_vector_collection.find_one( + filter={"first_name": "Yep"} + ) + assert check_response is not None + assert check_response["data"]["document"]["_id"] == _id1 + + +@pytest.mark.describe("upsert") +async def test_upsert_document( + async_writable_vector_collection: AsyncAstraDBCollection, +) -> None: + _id = str(uuid.uuid4()) + + document0 = { + "_id": _id, + "addresses": { + "work": { + "city": "Seattle", + "state": "WA", + }, + }, + } + upsert_result0 = await async_writable_vector_collection.upsert(document0) + assert upsert_result0 == _id + + response0 = await async_writable_vector_collection.find_one(filter={"_id": _id}) + assert response0 is not None + assert response0["data"]["document"] == document0 + + document1 = { + "_id": _id, + "addresses": { + "work": { + "state": "MN", + "floor": 12, + }, + }, + "hobbies": [ + "ice skating", + "accounting", + ], + } + upsert_result1 = await async_writable_vector_collection.upsert(document1) + assert upsert_result1 == _id + + response1 = await async_writable_vector_collection.find_one(filter={"_id": _id}) + assert response1 is not None + assert response1["data"]["document"] == document1 + + +@pytest.mark.describe("update_one to create a subdocument, not through vector") +async def test_update_one_create_subdocument_novector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + update_one_response = await async_disposable_vector_collection.update_one( + filter={"_id": "1"}, + update={"$set": {"name": "Eric"}}, + ) + + assert update_one_response["status"]["matchedCount"] >= 1 + assert update_one_response["status"]["modifiedCount"] == 1 + + response = await async_disposable_vector_collection.find_one(filter={"_id": "1"}) + assert response["data"]["document"]["name"] == "Eric" + + +@pytest.mark.describe("delete_subdocument, not through vector") +async def test_delete_subdocument_novector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + delete_subdocument_response = ( + await async_disposable_vector_collection.delete_subdocument( + id="1", + subdoc="otherfield.subfield", + ) + ) + + assert delete_subdocument_response["status"]["matchedCount"] >= 1 + assert delete_subdocument_response["status"]["modifiedCount"] == 1 + + response = await async_disposable_vector_collection.find_one(filter={"_id": "1"}) + assert response["data"]["document"]["otherfield"] == {} + + +@pytest.mark.describe("find_one_and_update, through vector") +async def test_find_one_and_update_vector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + find_filter = {"status": {"$exists": True}} + response0 = await async_disposable_vector_collection.find_one(filter=find_filter) + assert response0["data"]["document"] is None + + sort = {"$vector": [0.2, 0.6]} + + update0 = {"$set": {"status": "active"}} + options0 = {"returnDocument": "after"} + + update_response0 = await async_disposable_vector_collection.find_one_and_update( + sort=sort, update=update0, options=options0 + ) + assert isinstance(update_response0["data"]["document"], dict) + assert update_response0["data"]["document"]["status"] == "active" + assert update_response0["status"]["matchedCount"] >= 1 + assert update_response0["status"]["modifiedCount"] >= 1 + + response1 = await async_disposable_vector_collection.find_one(filter=find_filter) + assert isinstance(response1["data"]["document"], dict) + assert response1["data"]["document"]["status"] == "active" + + update1 = {"$set": {"status": "inactive"}} + options1 = {"returnDocument": "before"} + + update_response1 = await async_disposable_vector_collection.find_one_and_update( + sort=sort, update=update1, options=options1 + ) + assert isinstance(update_response1["data"]["document"], dict) + assert update_response1["data"]["document"]["status"] == "active" + assert update_response1["status"]["matchedCount"] >= 1 + assert update_response1["status"]["modifiedCount"] >= 1 + + response2 = await async_disposable_vector_collection.find_one(filter=find_filter) + assert isinstance(response2["data"]["document"], dict) + assert response2["data"]["document"]["status"] == "inactive" + + filter2 = {"nonexistent_subfield": 10} + update2 = update1 + options2 = options1 + + update_response2 = await async_disposable_vector_collection.find_one_and_update( + sort=sort, update=update2, options=options2, filter=filter2 + ) + assert update_response2["data"]["document"] is None + assert update_response2["status"]["matchedCount"] == 0 + assert update_response2["status"]["modifiedCount"] == 0 + + +@pytest.mark.describe("find_one_and_update, not through vector") +async def test_find_one_and_update_novector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + find_filter = {"status": {"$exists": True}} + response0 = await async_disposable_vector_collection.find_one(filter=find_filter) + assert response0["data"]["document"] is None + + update_filter = {"anotherfield": "omega"} + + update0 = {"$set": {"status": "active"}} + options0 = {"returnDocument": "after"} + + update_response0 = await async_disposable_vector_collection.find_one_and_update( + filter=update_filter, update=update0, options=options0 + ) + assert isinstance(update_response0["data"]["document"], dict) + assert update_response0["data"]["document"]["status"] == "active" + assert update_response0["status"]["matchedCount"] >= 1 + assert update_response0["status"]["modifiedCount"] >= 1 + + response1 = await async_disposable_vector_collection.find_one(filter=find_filter) + assert isinstance(response1["data"]["document"], dict) + assert response1["data"]["document"]["status"] == "active" + + update1 = {"$set": {"status": "inactive"}} + options1 = {"returnDocument": "before"} + + update_response1 = await async_disposable_vector_collection.find_one_and_update( + filter=update_filter, update=update1, options=options1 + ) + assert isinstance(update_response1["data"]["document"], dict) + assert update_response1["data"]["document"]["status"] == "active" + assert update_response1["status"]["matchedCount"] >= 1 + assert update_response1["status"]["modifiedCount"] >= 1 + + response2 = await async_disposable_vector_collection.find_one(filter=find_filter) + assert isinstance(response2["data"]["document"], dict) + assert response2["data"]["document"]["status"] == "inactive" + + filter2 = {**update_filter, **{"nonexistent_subfield": 10}} + update2 = update1 + options2 = options1 + + update_response2 = await async_disposable_vector_collection.find_one_and_update( + filter=filter2, update=update2, options=options2 + ) + assert update_response2["data"]["document"] is None + assert update_response2["status"]["matchedCount"] == 0 + assert update_response2["status"]["modifiedCount"] == 0 + + +@pytest.mark.describe("find_one_and_replace, through vector") +async def test_find_one_and_replace_vector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + sort = {"$vector": [0.2, 0.6]} + + response0 = await async_disposable_vector_collection.find_one(sort=sort) + assert response0 is not None + assert "anotherfield" in response0["data"]["document"] + + doc0vector = response0["data"]["document"]["$vector"] + + replace_response0 = await async_disposable_vector_collection.find_one_and_replace( + sort=sort, + replacement={ + "phyla": ["Echinodermata", "Platelminta", "Chordata"], + "$vector": doc0vector, # to find this doc again below! + }, + ) + assert replace_response0 is not None + assert "anotherfield" in replace_response0["data"]["document"] + + response1 = await async_disposable_vector_collection.find_one(sort=sort) + assert response1 is not None + assert response1["data"]["document"]["phyla"] == [ + "Echinodermata", + "Platelminta", + "Chordata", + ] + assert "anotherfield" not in response1["data"]["document"] + + replace_response1 = await async_disposable_vector_collection.find_one_and_replace( + sort=sort, + replacement={ + "phone": "0123-4567", + "$vector": doc0vector, + }, + ) + assert replace_response1 is not None + assert replace_response1["data"]["document"]["phyla"] == [ + "Echinodermata", + "Platelminta", + "Chordata", + ] + assert "anotherfield" not in replace_response1["data"]["document"] + + response2 = await async_disposable_vector_collection.find_one(sort=sort) + assert response2 is not None + assert response2["data"]["document"]["phone"] == "0123-4567" + assert "phyla" not in response2["data"]["document"] + + # non-existing-doc case + filter_no = {"nonexisting_field": -123} + replace_response_no = await async_disposable_vector_collection.find_one_and_replace( + sort=sort, + filter=filter_no, + replacement={ + "whatever": -123, + "$vector": doc0vector, + }, + ) + assert replace_response_no is not None + assert replace_response_no["data"]["document"] is None + + +@pytest.mark.describe("find_one_and_replace, not through vector") +async def test_find_one_and_replace_novector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + response0 = await async_disposable_vector_collection.find_one(filter={"_id": "1"}) + assert response0 is not None + assert response0["data"]["document"]["anotherfield"] == "alpha" + + replace_response0 = await async_disposable_vector_collection.find_one_and_replace( + filter={"_id": "1"}, + replacement={ + "_id": "1", + "phyla": ["Echinodermata", "Platelminta", "Chordata"], + }, + ) + assert replace_response0 is not None + assert replace_response0["data"]["document"]["anotherfield"] == "alpha" + + response1 = await async_disposable_vector_collection.find_one(filter={"_id": "1"}) + assert response1 is not None + assert response1["data"]["document"]["phyla"] == [ + "Echinodermata", + "Platelminta", + "Chordata", + ] + assert "anotherfield" not in response1["data"]["document"] + + replace_response1 = await async_disposable_vector_collection.find_one_and_replace( + filter={"_id": "1"}, + replacement={ + "phone": "0123-4567", + }, + ) + assert replace_response1 is not None + assert replace_response1["data"]["document"]["phyla"] == [ + "Echinodermata", + "Platelminta", + "Chordata", + ] + assert "anotherfield" not in replace_response1["data"]["document"] + + response2 = await async_disposable_vector_collection.find_one(filter={"_id": "1"}) + assert response2 is not None + assert response2["data"]["document"]["phone"] == "0123-4567" + assert "phyla" not in response2["data"]["document"] + + # non-existing-doc case + replace_response_no = await async_disposable_vector_collection.find_one_and_replace( + filter={"_id": "z"}, + replacement={ + "whatever": -123, + }, + ) + assert replace_response_no is not None + assert replace_response_no["data"]["document"] is None + + +@pytest.mark.describe("delete_one, not through vector") +async def test_delete_one_novector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + delete_response = await async_disposable_vector_collection.delete_one(id="3") + assert delete_response["status"]["deletedCount"] == 1 + + response = await async_disposable_vector_collection.find_one(filter={"_id": "3"}) + assert response["data"]["document"] is None + + delete_response_no = await async_disposable_vector_collection.delete_one(id="3") + assert delete_response_no["status"]["deletedCount"] == 0 + + +@pytest.mark.describe("delete_many, not through vector") +async def test_delete_many_novector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + delete_response = await async_disposable_vector_collection.delete_many( + filter={"anotherfield": "alpha"} + ) + assert delete_response["status"]["deletedCount"] == 2 + + documents_no = await async_disposable_vector_collection.find( + filter={"anotherfield": "alpha"} + ) + assert documents_no["data"]["documents"] == [] + + delete_response_no = await async_disposable_vector_collection.delete_many( + filter={"anotherfield": "alpha"} + ) + assert delete_response_no["status"]["deletedCount"] == 0 + + +@pytest.mark.describe("pop, push functions, not through vector") +async def test_pop_push_novector( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + user_id = str(uuid.uuid4()) + await async_disposable_vector_collection.insert_one( + document={ + "_id": user_id, + "first_name": "Cliff", + "last_name": "Wicklow", + "roles": ["user", "admin"], + }, + ) + + pop = {"roles": 1} + options = {"returnDocument": "after"} + + pop_response = await async_disposable_vector_collection.pop( + filter={"_id": user_id}, pop=pop, options=options + ) + assert pop_response is not None + assert pop_response["data"]["document"]["roles"] == ["user"] + assert pop_response["status"]["matchedCount"] >= 1 + assert pop_response["status"]["modifiedCount"] == 1 + + response1 = await async_disposable_vector_collection.find_one( + filter={"_id": user_id} + ) + assert response1 is not None + assert response1["data"]["document"]["roles"] == ["user"] + + push = {"roles": "auditor"} + + push_response = await async_disposable_vector_collection.push( + filter={"_id": user_id}, push=push, options=options + ) + assert push_response is not None + assert push_response["data"]["document"]["roles"] == ["user", "auditor"] + assert push_response["status"]["matchedCount"] >= 1 + assert push_response["status"]["modifiedCount"] == 1 + + response2 = await async_disposable_vector_collection.find_one( + filter={"_id": user_id} + ) + assert response2 is not None + assert response2["data"]["document"]["roles"] == ["user", "auditor"] diff --git a/tests/astrapy/test_async_db_dml_pagination.py b/tests/astrapy/test_async_db_dml_pagination.py new file mode 100644 index 00000000..7048bd01 --- /dev/null +++ b/tests/astrapy/test_async_db_dml_pagination.py @@ -0,0 +1,96 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the `db.py` parts on pagination primitives +""" + +import math +import os +import logging +from typing import Dict, Iterable, List, Optional, Set, TypeVar, AsyncIterable +import pytest + +from astrapy.db import AsyncAstraDBCollection, AsyncAstraDB + +logger = logging.getLogger(__name__) + + +TEST_PAGINATION_COLLECTION_NAME = "pagination_v_col" +INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints +N = 200 # must be EVEN +FIND_LIMIT = 183 # Keep this > 20 and <= N to actually put pagination to test + +T = TypeVar("T") + + +def _mk_vector(index: int, n_total_steps: int) -> List[float]: + angle = 2 * math.pi * index / n_total_steps + return [math.cos(angle), math.sin(angle)] + + +def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]: + this_batch = [] + for entry in iterable: + this_batch.append(entry) + if len(this_batch) == batch_size: + yield this_batch + this_batch = [] + if this_batch: + yield this_batch + + +@pytest.fixture +async def pag_test_collection( + astra_db_credentials_kwargs: Dict[str, Optional[str]] +) -> AsyncIterable[AsyncAstraDBCollection]: + async with AsyncAstraDB(**astra_db_credentials_kwargs) as astra_db: + astra_db_collection = await astra_db.create_collection( + collection_name=TEST_PAGINATION_COLLECTION_NAME, dimension=2 + ) + + if int(os.getenv("TEST_PAGINATION_SKIP_INSERTION", "0")) == 0: + inserted_ids: Set[str] = set() + for i_batch in _batch_iterable(range(N), INSERT_BATCH_SIZE): + batch_ids = ( + await astra_db_collection.insert_many( + documents=[ + {"_id": str(i), "$vector": _mk_vector(i, N)} + for i in i_batch + ] + ) + )["status"]["insertedIds"] + inserted_ids = inserted_ids | set(batch_ids) + assert inserted_ids == {str(i) for i in range(N)} + yield astra_db_collection + if int(os.getenv("TEST_PAGINATION_SKIP_DELETE_COLLECTION", "0")) == 0: + _ = await astra_db.delete_collection( + collection_name=TEST_PAGINATION_COLLECTION_NAME + ) + + +@pytest.mark.describe( + "should retrieve the required amount of documents, all different, through pagination" +) +async def test_find_paginated(pag_test_collection: AsyncAstraDBCollection) -> None: + options = {"limit": FIND_LIMIT} + projection = {"$vector": 0} + + paginated_documents = pag_test_collection.paginated_find( + projection=projection, + options=options, + ) + paginated_ids = [doc["_id"] async for doc in paginated_documents] + assert len(paginated_ids) == FIND_LIMIT + assert len(paginated_ids) == len(set(paginated_ids)) diff --git a/tests/astrapy/test_async_db_dml_vector.py b/tests/astrapy/test_async_db_dml_vector.py new file mode 100644 index 00000000..a5025cf0 --- /dev/null +++ b/tests/astrapy/test_async_db_dml_vector.py @@ -0,0 +1,300 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the `db.py` parts on data manipulation `vector_*` methods +""" + +import logging +from typing import cast + +import pytest + +from astrapy.db import AsyncAstraDBCollection +from astrapy.types import API_DOC + +logger = logging.getLogger(__name__) + + +@pytest.mark.describe("vector_find and include_similarity parameter") +async def test_vector_find( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + documents_sim_1 = await async_readonly_vector_collection.vector_find( + vector=[0.2, 0.6], + limit=3, + ) + + assert documents_sim_1 is not None + assert isinstance(documents_sim_1, list) + assert len(documents_sim_1) > 0 + assert "_id" in documents_sim_1[0] + assert "$vector" in documents_sim_1[0] + assert "text" in documents_sim_1[0] + assert "$similarity" in documents_sim_1[0] + + documents_sim_2 = await async_readonly_vector_collection.vector_find( + vector=[0.2, 0.6], + limit=3, + include_similarity=True, + ) + + assert documents_sim_2 is not None + assert isinstance(documents_sim_2, list) + assert len(documents_sim_2) > 0 + assert "_id" in documents_sim_2[0] + assert "$vector" in documents_sim_2[0] + assert "text" in documents_sim_2[0] + assert "$similarity" in documents_sim_2[0] + + documents_no_sim = await async_readonly_vector_collection.vector_find( + vector=[0.2, 0.6], + limit=3, + fields=["_id", "$vector"], + include_similarity=False, + ) + + assert documents_no_sim is not None + assert isinstance(documents_no_sim, list) + assert len(documents_no_sim) > 0 + assert "_id" in documents_no_sim[0] + assert "$vector" in documents_no_sim[0] + assert "text" not in documents_no_sim[0] + assert "$similarity" not in documents_no_sim[0] + + +@pytest.mark.describe("vector_find, obey projection") +async def test_vector_find_projection( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + query = [0.2, 0.6] + + req_fieldsets = [ + None, + set(), + {"text"}, + {"$vector"}, + {"text", "$vector"}, + ] + exp_fieldsets = [ + {"$vector", "_id", "otherfield", "anotherfield", "text"}, + {"$vector", "_id", "otherfield", "anotherfield", "text"}, + {"_id", "text"}, + {"$vector", "_id"}, + {"$vector", "_id", "text"}, + ] + for include_similarity in [True, False]: + for req_fields, exp_fields0 in zip(req_fieldsets, exp_fieldsets): + vdocs = await async_readonly_vector_collection.vector_find( + query, + limit=1, + fields=list(req_fields) if req_fields is not None else req_fields, + include_similarity=include_similarity, + ) + if include_similarity: + exp_fields = exp_fields0 | {"$similarity"} + else: + exp_fields = exp_fields0 + assert set(vdocs[0].keys()) == exp_fields + + +@pytest.mark.describe("vector_find with filters") +async def test_vector_find_filters( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + documents = await async_readonly_vector_collection.vector_find( + vector=[0.2, 0.6], + filter={"anotherfield": "alpha"}, + limit=3, + ) + assert isinstance(documents, list) + assert len(documents) == 2 + assert {doc["otherfield"]["subfield"] for doc in documents} == {"x1y", "x2y"} + + documents_no = await async_readonly_vector_collection.vector_find( + vector=[0.2, 0.6], + filter={"anotherfield": "epsilon"}, + limit=3, + ) + assert isinstance(documents_no, list) + assert len(documents_no) == 0 + + +@pytest.mark.describe("vector_find_one and include_similarity parameter") +async def test_vector_find_one( + async_readonly_vector_collection: AsyncAstraDBCollection, +) -> None: + document0 = await async_readonly_vector_collection.vector_find_one( + [0.2, 0.6], + ) + + assert document0 is not None + assert "_id" in document0 + assert "$vector" in document0 + assert "text" in document0 + assert "$similarity" in document0 + + document_w_sim = await async_readonly_vector_collection.vector_find_one( + [0.2, 0.6], + include_similarity=True, + ) + + assert document_w_sim is not None + assert "_id" in document_w_sim + assert "$vector" in document_w_sim + assert "text" in document_w_sim + assert "$similarity" in document_w_sim + + document_no_sim = await async_readonly_vector_collection.vector_find_one( + [0.2, 0.6], + include_similarity=False, + ) + + assert document_no_sim is not None + assert "_id" in document_no_sim + assert "$vector" in document_no_sim + assert "text" in document_no_sim + assert "$similarity" not in document_no_sim + + document_w_fields = await async_readonly_vector_collection.vector_find_one( + [0.2, 0.6], fields=["text"] + ) + + assert document_w_fields is not None + assert "_id" in document_w_fields + assert "$vector" not in document_w_fields + assert "text" in document_w_fields + assert "$similarity" in document_w_fields + + document_no = await async_readonly_vector_collection.vector_find_one( + [0.2, 0.6], + filter={"nonexisting": "gotcha"}, + ) + + assert document_no is None + + +@pytest.mark.describe("vector_find_one_and_update") +async def test_vector_find_one_and_update( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + update = {"$set": {"status": "active"}} + + document0 = await async_disposable_vector_collection.vector_find_one( + vector=[0.1, 0.9], + filter={"status": "active"}, + ) + assert document0 is None + + update_response = ( + await async_disposable_vector_collection.vector_find_one_and_update( + vector=[0.1, 0.9], + update=update, + ) + ) + assert update_response is not None + assert update_response["_id"] == "1" + + document1 = await async_disposable_vector_collection.vector_find_one( + vector=[0.1, 0.9], + filter={"status": "active"}, + ) + + assert document1 is not None + assert document1["_id"] == update_response["_id"] + assert document1["status"] == "active" + + update_response_no = ( + await async_disposable_vector_collection.vector_find_one_and_update( + vector=[0.1, 0.9], + filter={"nonexisting": "gotcha"}, + update=update, + ) + ) + assert update_response_no is None + + +@pytest.mark.describe("vector_find_one_and_replace") +async def test_vector_find_one_and_replace( + async_disposable_vector_collection: AsyncAstraDBCollection, +) -> None: + replacement0 = { + "_id": "1", + "text": "Revised sample entry number <1>", + "added_field": True, + "$vector": [0.101, 0.899], + } + + document0 = await async_disposable_vector_collection.vector_find_one( + vector=[0.1, 0.9], + filter={"added_field": True}, + ) + assert document0 is None + + replace_response0 = ( + await async_disposable_vector_collection.vector_find_one_and_replace( + vector=[0.1, 0.9], + replacement=replacement0, + ) + ) + assert replace_response0 is not None + assert replace_response0["_id"] == "1" + + document1 = await async_disposable_vector_collection.vector_find_one( + vector=[0.1, 0.9], + filter={"added_field": True}, + ) + + assert document1 is not None + assert document1["_id"] == replace_response0["_id"] + assert "otherfield" not in document1 + assert "anotherfield" not in document1 + assert document1["text"] == replacement0["text"] + assert document1["added_field"] is True + + # no supplying the _id + replacement1 = { + "text": "Further revised sample entry number <1>", + "different_added_field": False, + "$vector": [0.101, 0.899], + } + + replace_response1 = ( + await async_disposable_vector_collection.vector_find_one_and_replace( + vector=[0.1, 0.9], + replacement=replacement1, + ) + ) + assert replace_response0 is not None + assert replace_response0["_id"] == "1" + + document2 = await async_disposable_vector_collection.vector_find_one( + vector=[0.1, 0.9], + filter={"different_added_field": False}, + ) + + assert document2 is not None + assert cast(API_DOC, document2)["_id"] == cast(API_DOC, replace_response1)["_id"] + assert cast(API_DOC, document2)["text"] == replacement1["text"] + assert "added_field" not in cast(API_DOC, document2) + assert cast(API_DOC, document2)["different_added_field"] is False + + replace_response_no = ( + await async_disposable_vector_collection.vector_find_one_and_replace( + vector=[0.1, 0.9], + filter={"nonexisting": "gotcha"}, + replacement=replacement1, + ) + ) + assert replace_response_no is None diff --git a/tests/conftest.py b/tests/conftest.py index 85cefd2e..383c7451 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,15 +2,16 @@ Test fixtures """ import os + import pytest import uuid -from typing import Dict, Iterable, Optional +from typing import Dict, Iterable, Optional, AsyncIterable +import pytest_asyncio from dotenv import load_dotenv from astrapy.defaults import DEFAULT_KEYSPACE_NAME -from astrapy.db import AstraDB, AstraDBCollection - +from astrapy.db import AstraDB, AstraDBCollection, AsyncAstraDB, AsyncAstraDBCollection load_dotenv() @@ -72,6 +73,14 @@ def db(astra_db_credentials_kwargs: Dict[str, Optional[str]]) -> AstraDB: return AstraDB(**astra_db_credentials_kwargs) +@pytest.fixture +async def async_db( + astra_db_credentials_kwargs: Dict[str, Optional[str]] +) -> AsyncIterable[AsyncAstraDB]: + async with AsyncAstraDB(**astra_db_credentials_kwargs) as db: + yield db + + @pytest.fixture(scope="module") def writable_vector_collection(db: AstraDB) -> Iterable[AstraDBCollection]: """ @@ -90,6 +99,26 @@ def writable_vector_collection(db: AstraDB) -> Iterable[AstraDBCollection]: db.delete_collection(TEST_WRITABLE_VECTOR_COLLECTION) +@pytest_asyncio.fixture +async def async_writable_vector_collection( + async_db: AsyncAstraDB, +) -> AsyncIterable[AsyncAstraDBCollection]: + """ + This is lasting for the whole test. Functions can write to it, + no guarantee (i.e. each test should use a different ID... + """ + collection = await async_db.create_collection( + TEST_WRITABLE_VECTOR_COLLECTION, + dimension=2, + ) + + await collection.insert_many(VECTOR_DOCUMENTS) + + yield collection + + await async_db.delete_collection(TEST_WRITABLE_VECTOR_COLLECTION) + + @pytest.fixture(scope="module") def readonly_vector_collection(db: AstraDB) -> Iterable[AstraDBCollection]: collection = db.create_collection( @@ -104,6 +133,22 @@ def readonly_vector_collection(db: AstraDB) -> Iterable[AstraDBCollection]: db.delete_collection(TEST_READONLY_VECTOR_COLLECTION) +@pytest.fixture +async def async_readonly_vector_collection( + async_db: AsyncAstraDB, +) -> AsyncIterable[AsyncAstraDBCollection]: + collection = await async_db.create_collection( + TEST_READONLY_VECTOR_COLLECTION, + dimension=2, + ) + + await collection.insert_many(VECTOR_DOCUMENTS) + + yield collection + + await async_db.delete_collection(TEST_READONLY_VECTOR_COLLECTION) + + @pytest.fixture(scope="function") def disposable_vector_collection(db: AstraDB) -> Iterable[AstraDBCollection]: collection = db.create_collection( @@ -116,3 +161,19 @@ def disposable_vector_collection(db: AstraDB) -> Iterable[AstraDBCollection]: yield collection db.delete_collection(TEST_DISPOSABLE_VECTOR_COLLECTION) + + +@pytest.fixture +async def async_disposable_vector_collection( + async_db: AsyncAstraDB, +) -> AsyncIterable[AsyncAstraDBCollection]: + collection = await async_db.create_collection( + TEST_DISPOSABLE_VECTOR_COLLECTION, + dimension=2, + ) + + await collection.insert_many(VECTOR_DOCUMENTS) + + yield collection + + await async_db.delete_collection(TEST_DISPOSABLE_VECTOR_COLLECTION)