From c0ec96c6a49f4c5b328273268623f275db4b7702 Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Fri, 12 Jan 2024 20:38:18 +0100 Subject: [PATCH] implemented, all tests passing --- astrapy/db.py | 7 +- astrapy/utils.py | 66 ++++++++++++++--- tests/astrapy/test_async_db_dml.py | 89 ++++++++++++++++++++++- tests/astrapy/test_async_db_dml_vector.py | 7 +- tests/astrapy/test_db_dml.py | 86 +++++++++++++++++++++- tests/astrapy/test_db_dml_vector.py | 6 +- 6 files changed, 237 insertions(+), 24 deletions(-) diff --git a/astrapy/db.py b/astrapy/db.py index e8825c34..7d4f54d2 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -53,6 +53,7 @@ http_methods, amake_request, normalize_for_api, + restore_from_api, ) from astrapy.types import ( API_DOC, @@ -125,7 +126,8 @@ def _request( **kwargs, ) - response = request_handler.request() + direct_response = request_handler.request() + response = restore_from_api(direct_response) return response @@ -999,7 +1001,8 @@ async def _request( skip_error_check=skip_error_check, ) - response = await arequest_handler.request() + direct_response = await arequest_handler.request() + response = restore_from_api(direct_response) return response diff --git a/astrapy/utils.py b/astrapy/utils.py index d1edaed3..8a4d4858 100644 --- a/astrapy/utils.py +++ b/astrapy/utils.py @@ -1,11 +1,14 @@ from __future__ import annotations -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, cast, Dict, Iterable, List, Optional, Union +import time +import datetime import logging import httpx from astrapy import __version__ from astrapy.defaults import DEFAULT_TIMEOUT +from astrapy.types import API_RESPONSE class CustomLogger(logging.Logger): @@ -196,12 +199,24 @@ def is_list_of_floats(vector: Iterable[Any]) -> bool: return False +def convert_to_ejson_date_object( + date_value: Union[datetime.date, datetime.datetime] +) -> Dict[str, int]: + return {"$date": int(time.mktime(date_value.timetuple()) * 1000)} + + +def convert_ejson_date_object_to_datetime( + date_object: Dict[str, int] +) -> datetime.datetime: + return datetime.datetime.fromtimestamp(date_object["$date"] / 1000.0) + + def _normalize_payload_value(path: List[str], value: Any) -> Any: """ The path helps determining special treatments """ - _l2 = '.'.join(path[-2:]) - _l1 = '.'.join(path[-1:]) + _l2 = ".".join(path[-2:]) + _l1 = ".".join(path[-1:]) if _l1 == "$vector" and _l2 != "projection.$vector": if not is_list_of_floats(value): return convert_vector_to_floats(value) @@ -210,19 +225,22 @@ def _normalize_payload_value(path: List[str], value: Any) -> Any: else: if isinstance(value, dict): return { - k: _normalize_payload_value(path + [k], v) - for k, v in value.items() + k: _normalize_payload_value(path + [k], v) for k, v in value.items() } elif isinstance(value, list): return [ - _normalize_payload_value(path + [""], list_item) - for list_item in value + _normalize_payload_value(path + [""], list_item) for list_item in value ] else: - return value + if isinstance(value, datetime.datetime) or isinstance(value, datetime.date): + return convert_to_ejson_date_object(value) + else: + return value -def normalize_for_api(payload: Dict[str, Any]) -> Dict[str, Any]: +def normalize_for_api( + payload: Union[Dict[str, Any], None] +) -> Union[Dict[str, Any], None]: """ Normalize a payload for API calls. This includes e.g. ensuring values for "$vector" key @@ -235,4 +253,32 @@ def normalize_for_api(payload: Dict[str, Any]) -> Dict[str, Any]: Dict[str, Any]: a "normalized" payload dict """ - return _normalize_payload_value([], payload) + if payload: + return cast(Dict[str, Any], _normalize_payload_value([], payload)) + else: + return payload + + +def _restore_response_value(path: List[str], value: Any) -> Any: + """ + The path helps determining special treatments + """ + if isinstance(value, dict): + if len(value) == 1 and "$date" in value: + # this is `{"$date": 123456}`, restore to datetime.datetime + return convert_ejson_date_object_to_datetime(value) + else: + return {k: _restore_response_value(path + [k], v) for k, v in value.items()} + elif isinstance(value, list): + return [_restore_response_value(path + [""], list_item) for list_item in value] + else: + return value + + +def restore_from_api(response: API_RESPONSE) -> API_RESPONSE: + """ + Process a dictionary just returned from the API. + This is the place where e.g. `{"$date": 123}` is + converted back into a datetime object. + """ + return cast(API_RESPONSE, _restore_response_value([], response)) diff --git a/tests/astrapy/test_async_db_dml.py b/tests/astrapy/test_async_db_dml.py index fe2ef32d..cad6c85b 100644 --- a/tests/astrapy/test_async_db_dml.py +++ b/tests/astrapy/test_async_db_dml.py @@ -18,8 +18,9 @@ """ import uuid +import datetime import logging -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union import pytest @@ -164,10 +165,12 @@ async def test_find_find_one_projection( async def test_find_float32( async_readonly_vector_collection: AsyncAstraDBCollection, ) -> None: - def ite(): + def ite() -> Iterable[str]: for v in [0.1, 0.2]: yield f"{v}" - sort = {"$vector": ite()} + + # we surreptitously trick typing here + sort = {"$vector": cast(List[float], ite())} options = {"limit": 5} response = await async_readonly_vector_collection.find(sort=sort, options=options) @@ -1021,3 +1024,83 @@ async def test_pop_push_novector( ) assert response2 is not None assert response2["data"]["document"]["roles"] == ["user", "auditor"] + + +@pytest.mark.describe("store and retrieve dates and datetimes correctly") +async def test_insert_find_with_dates( + async_writable_vector_collection: AsyncAstraDBCollection, +) -> None: + date0 = datetime.date(2024, 1, 12) + datetime0 = datetime.datetime(2024, 1, 12, 0, 0) + date1 = datetime.date(2024, 1, 13) + datetime1 = datetime.datetime(2024, 1, 13, 0, 0) + + d_doc_id = str(uuid.uuid4()) + d_document = { + "_id": d_doc_id, + "my_date": date0, + "my_datetime": datetime0, + "nested": { + "n_date": date1, + "n_datetime": datetime1, + }, + "nested_list": { + "the_list": [ + date0, + datetime0, + date1, + datetime1, + ] + }, + } + expected_d_document = { + "_id": d_doc_id, + "my_date": datetime0, + "my_datetime": datetime0, + "nested": { + "n_date": datetime1, + "n_datetime": datetime1, + }, + "nested_list": { + "the_list": [ + datetime0, + datetime0, + datetime1, + datetime1, + ] + }, + } + + _ = await async_writable_vector_collection.insert_one(d_document) + + # retrieve it, simple + response0 = await async_writable_vector_collection.find_one( + filter={"_id": d_doc_id} + ) + assert response0 is not None + document0 = response0["data"]["document"] + assert document0 == expected_d_document + + # retrieve it, lt condition on a date + response1 = await async_writable_vector_collection.find_one( + filter={"nested_list.the_list.0": {"$lt": date1}} + ) + assert response1 is not None + document1 = response1["data"]["document"] + assert document1 == expected_d_document + + # retrieve it, gte condition on a datetime + response2 = await async_writable_vector_collection.find_one( + filter={"nested.n_date": {"$gte": datetime0}} + ) + assert response2 is not None + document2 = response2["data"]["document"] + assert document2 == expected_d_document + + # retrieve it, filter == condition on a datetime + response3 = await async_writable_vector_collection.find_one( + filter={"my_date": datetime0} + ) + assert response3 is not None + document3 = response3["data"]["document"] + assert document3 == expected_d_document diff --git a/tests/astrapy/test_async_db_dml_vector.py b/tests/astrapy/test_async_db_dml_vector.py index ee76e308..5cd085e2 100644 --- a/tests/astrapy/test_async_db_dml_vector.py +++ b/tests/astrapy/test_async_db_dml_vector.py @@ -17,7 +17,7 @@ """ import logging -from typing import cast +from typing import cast, Iterable, List import pytest @@ -78,12 +78,13 @@ async def test_vector_find( async def test_vector_find_float32( async_readonly_vector_collection: AsyncAstraDBCollection, ) -> None: - def ite(): + def ite() -> Iterable[str]: for v in [0.1, 0.2]: yield f"{v}" documents_sim_1 = await async_readonly_vector_collection.vector_find( - vector=ite(), + # we surreptitously trick typing here + vector=cast(List[float], ite()), limit=3, ) diff --git a/tests/astrapy/test_db_dml.py b/tests/astrapy/test_db_dml.py index 3d65d82f..d866384d 100644 --- a/tests/astrapy/test_db_dml.py +++ b/tests/astrapy/test_db_dml.py @@ -18,10 +18,11 @@ """ import uuid +import datetime import logging import json import httpx -from typing import Dict, List, Literal, Optional, Set +from typing import cast, Dict, Iterable, List, Literal, Optional, Set import pytest @@ -161,15 +162,18 @@ def test_find_find_one_projection( def test_find_float32( readonly_vector_collection: AstraDBCollection, ) -> None: - def ite(): + def ite() -> Iterable[str]: for v in [0.1, 0.2]: yield f"{v}" - sort = {"$vector": ite()} + + # we surreptitously trick typing here + sort = {"$vector": cast(List[float], ite())} options = {"limit": 5} response = readonly_vector_collection.find(sort=sort, options=options) assert isinstance(response["data"]["documents"], list) + @pytest.mark.describe("find through vector") def test_find(readonly_vector_collection: AstraDBCollection) -> None: sort = {"$vector": [0.2, 0.6]} @@ -1176,3 +1180,79 @@ def test_find_find_one_non_equality_operators( projection=projection, ) assert resp8["data"]["documents"][0]["marker"] == "abc" + + +@pytest.mark.describe("store and retrieve dates and datetimes correctly") +def test_insert_find_with_dates( + writable_vector_collection: AstraDBCollection, +) -> None: + date0 = datetime.date(2024, 1, 12) + datetime0 = datetime.datetime(2024, 1, 12, 0, 0) + date1 = datetime.date(2024, 1, 13) + datetime1 = datetime.datetime(2024, 1, 13, 0, 0) + + d_doc_id = str(uuid.uuid4()) + d_document = { + "_id": d_doc_id, + "my_date": date0, + "my_datetime": datetime0, + "nested": { + "n_date": date1, + "n_datetime": datetime1, + }, + "nested_list": { + "the_list": [ + date0, + datetime0, + date1, + datetime1, + ] + }, + } + expected_d_document = { + "_id": d_doc_id, + "my_date": datetime0, + "my_datetime": datetime0, + "nested": { + "n_date": datetime1, + "n_datetime": datetime1, + }, + "nested_list": { + "the_list": [ + datetime0, + datetime0, + datetime1, + datetime1, + ] + }, + } + + _ = writable_vector_collection.insert_one(d_document) + + # retrieve it, simple + response0 = writable_vector_collection.find_one(filter={"_id": d_doc_id}) + assert response0 is not None + document0 = response0["data"]["document"] + assert document0 == expected_d_document + + # retrieve it, lt condition on a date + response1 = writable_vector_collection.find_one( + filter={"nested_list.the_list.0": {"$lt": date1}} + ) + assert response1 is not None + document1 = response1["data"]["document"] + assert document1 == expected_d_document + + # retrieve it, gte condition on a datetime + response2 = writable_vector_collection.find_one( + filter={"nested.n_date": {"$gte": datetime0}} + ) + assert response2 is not None + document2 = response2["data"]["document"] + assert document2 == expected_d_document + + # retrieve it, filter == condition on a datetime + response3 = writable_vector_collection.find_one(filter={"my_date": datetime0}) + assert response3 is not None + document3 = response3["data"]["document"] + assert document3 == expected_d_document diff --git a/tests/astrapy/test_db_dml_vector.py b/tests/astrapy/test_db_dml_vector.py index 9eea04ab..351af589 100644 --- a/tests/astrapy/test_db_dml_vector.py +++ b/tests/astrapy/test_db_dml_vector.py @@ -17,7 +17,7 @@ """ import logging -from typing import cast +from typing import cast, Iterable, List import pytest @@ -76,12 +76,12 @@ def test_vector_find(readonly_vector_collection: AstraDBCollection) -> None: def test_vector_find_float32( readonly_vector_collection: AstraDBCollection, ) -> None: - def ite(): + def ite() -> Iterable[str]: for v in [0.1, 0.2]: yield f"{v}" documents_sim_1 = readonly_vector_collection.vector_find( - vector=ite(), + vector=cast(List[float], ite()), # we surreptitously trick typing here limit=3, )