Skip to content

Commit

Permalink
Merge pull request #144 from datastax/bugfix/#139-numpy-float
Browse files Browse the repository at this point in the history
Ensure that data gets coerced to JSON-serializable floats
  • Loading branch information
erichare committed Jan 2, 2024
2 parents c5e5809 + a1fcba3 commit 6c94729
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 9 deletions.
38 changes: 30 additions & 8 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@
DEFAULT_JSON_API_VERSION,
DEFAULT_KEYSPACE_NAME,
)
from astrapy.utils import make_payload, make_request, http_methods, amake_request
from astrapy.utils import (
convert_vector_to_floats,
make_payload,
make_request,
http_methods,
amake_request,
preprocess_insert,
)
from astrapy.types import (
API_DOC,
API_RESPONSE,
Expand Down Expand Up @@ -237,7 +244,7 @@ def vector_find(

# Pre-process the included arguments
sort, projection = self._pre_process_find(
vector,
convert_vector_to_floats(vector),
fields=fields,
)

Expand Down Expand Up @@ -399,7 +406,7 @@ def push(

def find_one_and_replace(
self,
replacement: Optional[Dict[str, Any]] = None,
replacement: Dict[str, Any],
*,
sort: Optional[Dict[str, Any]] = {},
filter: Optional[Dict[str, Any]] = None,
Expand All @@ -415,6 +422,8 @@ def find_one_and_replace(
Returns:
dict: The result of the find and replace operation.
"""
replacement = preprocess_insert(replacement)

json_query = make_payload(
top_level="findOneAndReplace",
filter=filter,
Expand Down Expand Up @@ -447,9 +456,11 @@ def vector_find_one_and_replace(
Returns:
dict or None: either the matched document or None if nothing found
"""
replacement = preprocess_insert(replacement)

# Pre-process the included arguments
sort, _ = self._pre_process_find(
vector,
convert_vector_to_floats(vector),
fields=fields,
)

Expand All @@ -464,21 +475,23 @@ def vector_find_one_and_replace(

def find_one_and_update(
self,
update: Dict[str, Any],
sort: Optional[Dict[str, Any]] = {},
update: Optional[Dict[str, Any]] = None,
filter: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
) -> API_RESPONSE:
"""
Find a single document and update it.
Args:
update (dict): The update to apply to the document.
sort (dict, optional): Specifies the order in which to find the document.
update (dict, optional): The update to apply to the document.
filter (dict, optional): Criteria to filter documents.
options (dict, optional): Additional options for the operation.
Returns:
dict: The result of the find and update operation.
"""
update = preprocess_insert(update)

json_query = make_payload(
top_level="findOneAndUpdate",
filter=filter,
Expand Down Expand Up @@ -514,9 +527,11 @@ def vector_find_one_and_update(
dict or None: The result of the vector-based find and
update operation, or None if nothing found
"""
update = preprocess_insert(update)

# Pre-process the included arguments
sort, _ = self._pre_process_find(
vector,
convert_vector_to_floats(vector),
fields=fields,
)

Expand Down Expand Up @@ -609,7 +624,7 @@ def vector_find_one(
"""
# Pre-process the included arguments
sort, projection = self._pre_process_find(
vector,
convert_vector_to_floats(vector),
fields=fields,
)

Expand All @@ -634,6 +649,8 @@ def insert_one(
Returns:
dict: The response from the database after the insert operation.
"""
document = preprocess_insert(document)

json_query = make_payload(top_level="insertOne", document=document)

response = self._request(
Expand All @@ -660,6 +677,10 @@ def insert_many(
Returns:
dict: The response from the database after the insert operation.
"""
# Check if the vector is a list of floats
for i, document in enumerate(documents):
documents[i] = preprocess_insert(document)

json_query = make_payload(
top_level="insertMany", documents=documents, options=options
)
Expand Down Expand Up @@ -784,6 +805,7 @@ def upsert(self, document: API_DOC) -> str:
str: The _id of the inserted or updated document.
"""
# Build the payload for the insert attempt
document = preprocess_insert(document)
result = self.insert_one(document, failures_allowed=True)

# If the call failed, then we replace the existing doc
Expand Down
37 changes: 36 additions & 1 deletion astrapy/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from typing import Any, Dict, Iterable, List, Optional
import logging

import httpx
Expand Down Expand Up @@ -169,3 +169,38 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]:
json_query[top_level][key] = value

return json_query


def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]:
"""
Convert a vector of strings to a vector of floats.
Args:
vector (list): A vector of objects.
Returns:
list: A vector of floats.
"""
return [float(value) for value in vector]


def preprocess_insert(document: Dict[str, Any]) -> Dict[str, Any]:
"""
Perform preprocessing operations before an insertion
Args:
vector (list): A vector of objects.
Returns:
list: A vector of objects
"""

# Process each field of the cocument
for key, value in document.items():
# Vector coercision
if key == "$vector" and not isinstance(document["$vector"][0], float):
document[key] = convert_vector_to_floats(value)

# TODO: More pre-processing operations

return document
17 changes: 17 additions & 0 deletions tests/astrapy/test_db_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,23 @@ def test_create_document(writable_vector_collection: AstraDBCollection) -> None:
)


@pytest.mark.describe("should truncate a nonvector collection")
def test_insert_float32(
writable_vector_collection: AstraDBCollection, N: int = 2
) -> None:
_id0 = str(uuid.uuid4())
document = {
"_id": _id0,
"name": "Coerce",
"$vector": [f"{(i+1)/N+2:.4f}" for i in range(N)],
}
response = writable_vector_collection.insert_one(document)
assert response is not None
inserted_ids = response["status"]["insertedIds"]
assert len(inserted_ids) == 1
assert inserted_ids[0] == _id0


@pytest.mark.describe("insert_many")
def test_insert_many(writable_vector_collection: AstraDBCollection) -> None:
_id0 = str(uuid.uuid4())
Expand Down

0 comments on commit 6c94729

Please sign in to comment.