From 141c4f9affc58fe63c7f7368679439ea55d8ca79 Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Thu, 7 Dec 2023 15:23:47 -0800 Subject: [PATCH 1/2] Handling numpy vectors --- astrapy/db.py | 25 ++++++++++++++++++++++++- astrapy/utils.py | 15 ++++++++++++++- tests/astrapy/test_db_dml.py | 16 ++++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) diff --git a/astrapy/db.py b/astrapy/db.py index b0ab9141..77f228a1 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -25,7 +25,12 @@ DEFAULT_JSON_API_VERSION, DEFAULT_KEYSPACE_NAME, ) -from astrapy.utils import make_payload, make_request, http_methods +from astrapy.utils import ( + convert_vector_to_floats, + make_payload, + make_request, + http_methods, +) from astrapy.types import API_DOC, API_RESPONSE, PaginableRequestMethod @@ -578,6 +583,13 @@ def insert_one( Returns: dict: The response from the database after the insert operation. """ + if ( + "$vector" in document + and document["$vector"] + and not isinstance(document["$vector"][0], float) + ): + document["$vector"] = convert_vector_to_floats(document["$vector"]) + json_query = make_payload(top_level="insertOne", document=document) response = self._request( @@ -604,6 +616,17 @@ def insert_many( Returns: dict: The response from the database after the insert operation. 
""" + # Check if the vector is a list of floats + for i, document in enumerate(documents): + if ( + "$vector" in document + and document["$vector"] + and isinstance(document["$vector"][0], float) + ): + documents[i]["$vector"] = convert_vector_to_floats( + documents[i]["$vector"] + ) + json_query = make_payload( top_level="insertMany", documents=documents, options=options ) diff --git a/astrapy/utils.py b/astrapy/utils.py index c3ee7372..39ce9054 100644 --- a/astrapy/utils.py +++ b/astrapy/utils.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import logging import httpx @@ -111,3 +111,16 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: json_query[top_level][key] = value return json_query + + +def convert_vector_to_floats(vector: List[Any]) -> List[float]: + """ + Convert a vector of strings to a vector of floats. + + Args: + vector (list): A vector of objects. + + Returns: + list: A vector of floats. 
+ """ + return [float(value) for value in vector] diff --git a/tests/astrapy/test_db_dml.py b/tests/astrapy/test_db_dml.py index 30234b6e..b9749338 100644 --- a/tests/astrapy/test_db_dml.py +++ b/tests/astrapy/test_db_dml.py @@ -19,6 +19,7 @@ import uuid import logging +import numpy from typing import List import pytest @@ -277,6 +278,21 @@ def test_create_document(writable_vector_collection: AstraDBCollection) -> None: ) +@pytest.mark.describe("should truncate a nonvector collection") +def test_insert_float32(writable_vector_collection: AstraDBCollection) -> None: + _id0 = str(uuid.uuid4()) + document = { + "_id": _id0, + "name": "Numpy", + "$vector": [numpy.float32(0.1), numpy.float32(0.2)], + } + response = writable_vector_collection.insert_one(document) + assert response is not None + inserted_ids = response["status"]["insertedIds"] + assert len(inserted_ids) == 1 + assert inserted_ids[0] == _id0 + + @pytest.mark.describe("insert_many") def test_insert_many(writable_vector_collection: AstraDBCollection) -> None: _id0 = str(uuid.uuid4()) From a6defe4c4d86cbd25e4c6e23337edadad7d36c65 Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Mon, 11 Dec 2023 11:28:19 -0800 Subject: [PATCH 2/2] Updates based on feedback for coercion --- astrapy/db.py | 40 +++++++++++++++++------------------- astrapy/utils.py | 26 +++++++++++++++++++++-- tests/astrapy/test_db_dml.py | 9 ++++---- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/astrapy/db.py b/astrapy/db.py index 77f228a1..b0eec7e9 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -30,6 +30,7 @@ make_payload, make_request, http_methods, + preprocess_insert, ) from astrapy.types import API_DOC, API_RESPONSE, PaginableRequestMethod @@ -221,7 +222,7 @@ def vector_find( # Pre-process the included arguments sort, projection = self._pre_process_find( - vector, + convert_vector_to_floats(vector), fields=fields, ) @@ -348,7 +349,7 @@ def push( def find_one_and_replace( self, - replacement: Optional[Dict[str, Any]] = 
None, + replacement: Dict[str, Any], *, sort: Optional[Dict[str, Any]] = {}, filter: Optional[Dict[str, Any]] = None, @@ -364,6 +365,8 @@ def find_one_and_replace( Returns: dict: The result of the find and replace operation. """ + replacement = preprocess_insert(replacement) + json_query = make_payload( top_level="findOneAndReplace", filter=filter, @@ -396,9 +399,11 @@ def vector_find_one_and_replace( Returns: dict or None: either the matched document or None if nothing found """ + replacement = preprocess_insert(replacement) + # Pre-process the included arguments sort, _ = self._pre_process_find( - vector, + convert_vector_to_floats(vector), fields=fields, ) @@ -413,21 +418,23 @@ def vector_find_one_and_replace( def find_one_and_update( self, + update: Dict[str, Any], sort: Optional[Dict[str, Any]] = {}, - update: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, ) -> API_RESPONSE: """ Find a single document and update it. Args: + update (dict): The update to apply to the document. sort (dict, optional): Specifies the order in which to find the document. - update (dict, optional): The update to apply to the document. filter (dict, optional): Criteria to filter documents. options (dict, optional): Additional options for the operation. Returns: dict: The result of the find and update operation. 
""" + update = preprocess_insert(update) + json_query = make_payload( top_level="findOneAndUpdate", filter=filter, @@ -463,9 +470,11 @@ def vector_find_one_and_update( dict or None: The result of the vector-based find and update operation, or None if nothing found """ + update = preprocess_insert(update) + # Pre-process the included arguments sort, _ = self._pre_process_find( - vector, + convert_vector_to_floats(vector), fields=fields, ) @@ -558,7 +567,7 @@ def vector_find_one( """ # Pre-process the included arguments sort, projection = self._pre_process_find( - vector, + convert_vector_to_floats(vector), fields=fields, ) @@ -583,12 +592,7 @@ def insert_one( Returns: dict: The response from the database after the insert operation. """ - if ( - "$vector" in document - and document["$vector"] - and not isinstance(document["$vector"][0], float) - ): - document["$vector"] = convert_vector_to_floats(document["$vector"]) + document = preprocess_insert(document) json_query = make_payload(top_level="insertOne", document=document) @@ -618,14 +622,7 @@ def insert_many( """ # Check if the vector is a list of floats for i, document in enumerate(documents): - if ( - "$vector" in document - and document["$vector"] - and isinstance(document["$vector"][0], float) - ): - documents[i]["$vector"] = convert_vector_to_floats( - documents[i]["$vector"] - ) + documents[i] = preprocess_insert(document) json_query = make_payload( top_level="insertMany", documents=documents, options=options @@ -751,6 +748,7 @@ def upsert(self, document: API_DOC) -> str: str: The _id of the inserted or updated document. 
""" # Build the payload for the insert attempt + document = preprocess_insert(document) result = self.insert_one(document, failures_allowed=True) # If the call failed, then we replace the existing doc diff --git a/astrapy/utils.py b/astrapy/utils.py index 39ce9054..dd29123e 100644 --- a/astrapy/utils.py +++ b/astrapy/utils.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Iterable, List, Optional import logging import httpx @@ -113,7 +113,7 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: return json_query -def convert_vector_to_floats(vector: List[Any]) -> List[float]: +def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]: """ Convert a vector of strings to a vector of floats. @@ -124,3 +124,25 @@ def convert_vector_to_floats(vector: List[Any]) -> List[float]: list: A vector of floats. """ return [float(value) for value in vector] + + +def preprocess_insert(document: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform preprocessing operations before an insertion + + Args: + document (dict): The document to preprocess.
+ + Returns: + dict: The document after preprocessing + """ + + # Process each field of the document + for key, value in document.items(): + # Vector coercion + if key == "$vector" and not isinstance(document["$vector"][0], float): + document[key] = convert_vector_to_floats(value) + + # TODO: More pre-processing operations + + return document diff --git a/tests/astrapy/test_db_dml.py b/tests/astrapy/test_db_dml.py index b9749338..5bf1d24a 100644 --- a/tests/astrapy/test_db_dml.py +++ b/tests/astrapy/test_db_dml.py @@ -19,7 +19,6 @@ import uuid import logging -import numpy from typing import List import pytest @@ -279,12 +278,14 @@ def test_create_document(writable_vector_collection: AstraDBCollection) -> None: @pytest.mark.describe("should truncate a nonvector collection") -def test_insert_float32(writable_vector_collection: AstraDBCollection) -> None: +def test_insert_float32( + writable_vector_collection: AstraDBCollection, N: int = 2 +) -> None: _id0 = str(uuid.uuid4()) document = { "_id": _id0, - "name": "Numpy", - "$vector": [numpy.float32(0.1), numpy.float32(0.2)], + "name": "Coerce", + "$vector": [f"{(i+1)/N+2:.4f}" for i in range(N)], } response = writable_vector_collection.insert_one(document) assert response is not None