diff --git a/astrapy/db.py b/astrapy/db.py index 83904678..dbf89872 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -42,7 +42,14 @@ DEFAULT_JSON_API_VERSION, DEFAULT_KEYSPACE_NAME, ) -from astrapy.utils import make_payload, make_request, http_methods, amake_request +from astrapy.utils import ( + convert_vector_to_floats, + make_payload, + make_request, + http_methods, + amake_request, + preprocess_insert, +) from astrapy.types import ( API_DOC, API_RESPONSE, @@ -237,7 +244,7 @@ def vector_find( # Pre-process the included arguments sort, projection = self._pre_process_find( - vector, + convert_vector_to_floats(vector), fields=fields, ) @@ -399,7 +406,7 @@ def push( def find_one_and_replace( self, - replacement: Optional[Dict[str, Any]] = None, + replacement: Dict[str, Any], *, sort: Optional[Dict[str, Any]] = {}, filter: Optional[Dict[str, Any]] = None, @@ -415,6 +422,8 @@ def find_one_and_replace( Returns: dict: The result of the find and replace operation. """ + replacement = preprocess_insert(replacement) + json_query = make_payload( top_level="findOneAndReplace", filter=filter, @@ -447,9 +456,11 @@ def vector_find_one_and_replace( Returns: dict or None: either the matched document or None if nothing found """ + replacement = preprocess_insert(replacement) + # Pre-process the included arguments sort, _ = self._pre_process_find( - vector, + convert_vector_to_floats(vector), fields=fields, ) @@ -464,21 +475,23 @@ def vector_find_one_and_replace( def find_one_and_update( self, + update: Dict[str, Any], sort: Optional[Dict[str, Any]] = {}, - update: Optional[Dict[str, Any]] = None, filter: Optional[Dict[str, Any]] = None, options: Optional[Dict[str, Any]] = None, ) -> API_RESPONSE: """ Find a single document and update it. Args: + update (dict): The update to apply to the document. sort (dict, optional): Specifies the order in which to find the document. - update (dict, optional): The update to apply to the document. 
filter (dict, optional): Criteria to filter documents. options (dict, optional): Additional options for the operation. Returns: dict: The result of the find and update operation. """ + update = preprocess_insert(update) + json_query = make_payload( top_level="findOneAndUpdate", filter=filter, @@ -514,9 +527,11 @@ def vector_find_one_and_update( dict or None: The result of the vector-based find and update operation, or None if nothing found """ + update = preprocess_insert(update) + # Pre-process the included arguments sort, _ = self._pre_process_find( - vector, + convert_vector_to_floats(vector), fields=fields, ) @@ -609,7 +624,7 @@ def vector_find_one( """ # Pre-process the included arguments sort, projection = self._pre_process_find( - vector, + convert_vector_to_floats(vector), fields=fields, ) @@ -634,6 +649,8 @@ def insert_one( Returns: dict: The response from the database after the insert operation. """ + document = preprocess_insert(document) + json_query = make_payload(top_level="insertOne", document=document) response = self._request( @@ -660,6 +677,10 @@ def insert_many( Returns: dict: The response from the database after the insert operation. """ + # Check if the vector is a list of floats + for i, document in enumerate(documents): + documents[i] = preprocess_insert(document) + json_query = make_payload( top_level="insertMany", documents=documents, options=options ) @@ -784,6 +805,7 @@ def upsert(self, document: API_DOC) -> str: str: The _id of the inserted or updated document. 
""" # Build the payload for the insert attempt + document = preprocess_insert(document) result = self.insert_one(document, failures_allowed=True) # If the call failed, then we replace the existing doc diff --git a/astrapy/utils.py b/astrapy/utils.py index 15fa72b3..5034c279 100644 --- a/astrapy/utils.py +++ b/astrapy/utils.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, List, Optional import logging import httpx @@ -169,3 +169,38 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: json_query[top_level][key] = value return json_query + + +def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]: + """ + Convert a vector of strings to a vector of floats. + + Args: + vector (list): A vector of objects. + + Returns: + list: A vector of floats. + """ + return [float(value) for value in vector] + + +def preprocess_insert(document: Dict[str, Any]) -> Dict[str, Any]: + """ + Perform preprocessing operations before an insertion + + Args: + document (dict): The document to preprocess. 
+ + Returns: + dict: The same document, with its $vector field coerced to floats + """ + + # Process each field of the document + for key, value in document.items(): + # Vector coercion + if key == "$vector" and not isinstance(document["$vector"][0], float): + document[key] = convert_vector_to_floats(value) + + # TODO: More pre-processing operations + + return document diff --git a/tests/astrapy/test_db_dml.py b/tests/astrapy/test_db_dml.py index 30234b6e..5bf1d24a 100644 --- a/tests/astrapy/test_db_dml.py +++ b/tests/astrapy/test_db_dml.py @@ -277,6 +277,23 @@ def test_create_document(writable_vector_collection: AstraDBCollection) -> None: ) +@pytest.mark.describe("should truncate a nonvector collection") +def test_insert_float32( + writable_vector_collection: AstraDBCollection, N: int = 2 +) -> None: + _id0 = str(uuid.uuid4()) + document = { + "_id": _id0, + "name": "Coerce", + "$vector": [f"{(i+1)/N+2:.4f}" for i in range(N)], + } + response = writable_vector_collection.insert_one(document) + assert response is not None + inserted_ids = response["status"]["insertedIds"] + assert len(inserted_ids) == 1 + assert inserted_ids[0] == _id0 + + @pytest.mark.describe("insert_many") def test_insert_many(writable_vector_collection: AstraDBCollection) -> None: _id0 = str(uuid.uuid4())