Skip to content

Commit

Permalink
implemented, all tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
hemidactylus committed Jan 12, 2024
1 parent c63ac81 commit c0ec96c
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 24 deletions.
7 changes: 5 additions & 2 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
http_methods,
amake_request,
normalize_for_api,
restore_from_api,
)
from astrapy.types import (
API_DOC,
Expand Down Expand Up @@ -125,7 +126,8 @@ def _request(
**kwargs,
)

response = request_handler.request()
direct_response = request_handler.request()
response = restore_from_api(direct_response)

return response

Expand Down Expand Up @@ -999,7 +1001,8 @@ async def _request(
skip_error_check=skip_error_check,
)

response = await arequest_handler.request()
direct_response = await arequest_handler.request()
response = restore_from_api(direct_response)

return response

Expand Down
66 changes: 56 additions & 10 deletions astrapy/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, cast, Dict, Iterable, List, Optional, Union
import time
import datetime
import logging

import httpx

from astrapy import __version__
from astrapy.defaults import DEFAULT_TIMEOUT
from astrapy.types import API_RESPONSE


class CustomLogger(logging.Logger):
Expand Down Expand Up @@ -196,12 +199,24 @@ def is_list_of_floats(vector: Iterable[Any]) -> bool:
return False


def convert_to_ejson_date_object(
date_value: Union[datetime.date, datetime.datetime]
) -> Dict[str, int]:
return {"$date": int(time.mktime(date_value.timetuple()) * 1000)}


def convert_ejson_date_object_to_datetime(
date_object: Dict[str, int]
) -> datetime.datetime:
return datetime.datetime.fromtimestamp(date_object["$date"] / 1000.0)


def _normalize_payload_value(path: List[str], value: Any) -> Any:
"""
The path helps determining special treatments
"""
_l2 = '.'.join(path[-2:])
_l1 = '.'.join(path[-1:])
_l2 = ".".join(path[-2:])
_l1 = ".".join(path[-1:])
if _l1 == "$vector" and _l2 != "projection.$vector":
if not is_list_of_floats(value):
return convert_vector_to_floats(value)
Expand All @@ -210,19 +225,22 @@ def _normalize_payload_value(path: List[str], value: Any) -> Any:
else:
if isinstance(value, dict):
return {
k: _normalize_payload_value(path + [k], v)
for k, v in value.items()
k: _normalize_payload_value(path + [k], v) for k, v in value.items()
}
elif isinstance(value, list):
return [
_normalize_payload_value(path + [""], list_item)
for list_item in value
_normalize_payload_value(path + [""], list_item) for list_item in value
]
else:
return value
if isinstance(value, datetime.datetime) or isinstance(value, datetime.date):
return convert_to_ejson_date_object(value)
else:
return value


def normalize_for_api(payload: Dict[str, Any]) -> Dict[str, Any]:
def normalize_for_api(
payload: Union[Dict[str, Any], None]
) -> Union[Dict[str, Any], None]:
"""
Normalize a payload for API calls.
This includes e.g. ensuring values for "$vector" key
Expand All @@ -235,4 +253,32 @@ def normalize_for_api(payload: Dict[str, Any]) -> Dict[str, Any]:
Dict[str, Any]: a "normalized" payload dict
"""

return _normalize_payload_value([], payload)
if payload:
return cast(Dict[str, Any], _normalize_payload_value([], payload))
else:
return payload


def _restore_response_value(path: List[str], value: Any) -> Any:
"""
The path helps determining special treatments
"""
if isinstance(value, dict):
if len(value) == 1 and "$date" in value:
# this is `{"$date": 123456}`, restore to datetime.datetime
return convert_ejson_date_object_to_datetime(value)
else:
return {k: _restore_response_value(path + [k], v) for k, v in value.items()}
elif isinstance(value, list):
return [_restore_response_value(path + [""], list_item) for list_item in value]
else:
return value


def restore_from_api(response: API_RESPONSE) -> API_RESPONSE:
"""
Process a dictionary just returned from the API.
This is the place where e.g. `{"$date": 123}` is
converted back into a datetime object.
"""
return cast(API_RESPONSE, _restore_response_value([], response))
89 changes: 86 additions & 3 deletions tests/astrapy/test_async_db_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
"""

import uuid
import datetime
import logging
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union

import pytest

Expand Down Expand Up @@ -164,10 +165,12 @@ async def test_find_find_one_projection(
async def test_find_float32(
async_readonly_vector_collection: AsyncAstraDBCollection,
) -> None:
def ite():
def ite() -> Iterable[str]:
for v in [0.1, 0.2]:
yield f"{v}"
sort = {"$vector": ite()}

# we surreptitously trick typing here
sort = {"$vector": cast(List[float], ite())}
options = {"limit": 5}

response = await async_readonly_vector_collection.find(sort=sort, options=options)
Expand Down Expand Up @@ -1021,3 +1024,83 @@ async def test_pop_push_novector(
)
assert response2 is not None
assert response2["data"]["document"]["roles"] == ["user", "auditor"]


@pytest.mark.describe("store and retrieve dates and datetimes correctly")
async def test_insert_find_with_dates(
async_writable_vector_collection: AsyncAstraDBCollection,
) -> None:
date0 = datetime.date(2024, 1, 12)
datetime0 = datetime.datetime(2024, 1, 12, 0, 0)
date1 = datetime.date(2024, 1, 13)
datetime1 = datetime.datetime(2024, 1, 13, 0, 0)

d_doc_id = str(uuid.uuid4())
d_document = {
"_id": d_doc_id,
"my_date": date0,
"my_datetime": datetime0,
"nested": {
"n_date": date1,
"n_datetime": datetime1,
},
"nested_list": {
"the_list": [
date0,
datetime0,
date1,
datetime1,
]
},
}
expected_d_document = {
"_id": d_doc_id,
"my_date": datetime0,
"my_datetime": datetime0,
"nested": {
"n_date": datetime1,
"n_datetime": datetime1,
},
"nested_list": {
"the_list": [
datetime0,
datetime0,
datetime1,
datetime1,
]
},
}

_ = await async_writable_vector_collection.insert_one(d_document)

# retrieve it, simple
response0 = await async_writable_vector_collection.find_one(
filter={"_id": d_doc_id}
)
assert response0 is not None
document0 = response0["data"]["document"]
assert document0 == expected_d_document

# retrieve it, lt condition on a date
response1 = await async_writable_vector_collection.find_one(
filter={"nested_list.the_list.0": {"$lt": date1}}
)
assert response1 is not None
document1 = response1["data"]["document"]
assert document1 == expected_d_document

# retrieve it, gte condition on a datetime
response2 = await async_writable_vector_collection.find_one(
filter={"nested.n_date": {"$gte": datetime0}}
)
assert response2 is not None
document2 = response2["data"]["document"]
assert document2 == expected_d_document

# retrieve it, filter == condition on a datetime
response3 = await async_writable_vector_collection.find_one(
filter={"my_date": datetime0}
)
assert response3 is not None
document3 = response3["data"]["document"]
assert document3 == expected_d_document
7 changes: 4 additions & 3 deletions tests/astrapy/test_async_db_dml_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

import logging
from typing import cast
from typing import cast, Iterable, List

import pytest

Expand Down Expand Up @@ -78,12 +78,13 @@ async def test_vector_find(
async def test_vector_find_float32(
async_readonly_vector_collection: AsyncAstraDBCollection,
) -> None:
def ite():
def ite() -> Iterable[str]:
for v in [0.1, 0.2]:
yield f"{v}"

documents_sim_1 = await async_readonly_vector_collection.vector_find(
vector=ite(),
# we surreptitously trick typing here
vector=cast(List[float], ite()),
limit=3,
)

Expand Down
86 changes: 83 additions & 3 deletions tests/astrapy/test_db_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
"""

import uuid
import datetime
import logging
import json
import httpx
from typing import Dict, List, Literal, Optional, Set
from typing import cast, Dict, Iterable, List, Literal, Optional, Set

import pytest

Expand Down Expand Up @@ -161,15 +162,18 @@ def test_find_find_one_projection(
def test_find_float32(
readonly_vector_collection: AstraDBCollection,
) -> None:
def ite():
def ite() -> Iterable[str]:
for v in [0.1, 0.2]:
yield f"{v}"
sort = {"$vector": ite()}

# we surreptitously trick typing here
sort = {"$vector": cast(List[float], ite())}
options = {"limit": 5}

response = readonly_vector_collection.find(sort=sort, options=options)
assert isinstance(response["data"]["documents"], list)


@pytest.mark.describe("find through vector")
def test_find(readonly_vector_collection: AstraDBCollection) -> None:
sort = {"$vector": [0.2, 0.6]}
Expand Down Expand Up @@ -1176,3 +1180,79 @@ def test_find_find_one_non_equality_operators(
projection=projection,
)
assert resp8["data"]["documents"][0]["marker"] == "abc"


@pytest.mark.describe("store and retrieve dates and datetimes correctly")
def test_insert_find_with_dates(
writable_vector_collection: AstraDBCollection,
) -> None:
date0 = datetime.date(2024, 1, 12)
datetime0 = datetime.datetime(2024, 1, 12, 0, 0)
date1 = datetime.date(2024, 1, 13)
datetime1 = datetime.datetime(2024, 1, 13, 0, 0)

d_doc_id = str(uuid.uuid4())
d_document = {
"_id": d_doc_id,
"my_date": date0,
"my_datetime": datetime0,
"nested": {
"n_date": date1,
"n_datetime": datetime1,
},
"nested_list": {
"the_list": [
date0,
datetime0,
date1,
datetime1,
]
},
}
expected_d_document = {
"_id": d_doc_id,
"my_date": datetime0,
"my_datetime": datetime0,
"nested": {
"n_date": datetime1,
"n_datetime": datetime1,
},
"nested_list": {
"the_list": [
datetime0,
datetime0,
datetime1,
datetime1,
]
},
}

_ = writable_vector_collection.insert_one(d_document)

# retrieve it, simple
response0 = writable_vector_collection.find_one(filter={"_id": d_doc_id})
assert response0 is not None
document0 = response0["data"]["document"]
assert document0 == expected_d_document

# retrieve it, lt condition on a date
response1 = writable_vector_collection.find_one(
filter={"nested_list.the_list.0": {"$lt": date1}}
)
assert response1 is not None
document1 = response1["data"]["document"]
assert document1 == expected_d_document

# retrieve it, gte condition on a datetime
response2 = writable_vector_collection.find_one(
filter={"nested.n_date": {"$gte": datetime0}}
)
assert response2 is not None
document2 = response2["data"]["document"]
assert document2 == expected_d_document

# retrieve it, filter == condition on a datetime
response3 = writable_vector_collection.find_one(filter={"my_date": datetime0})
assert response3 is not None
document3 = response3["data"]["document"]
assert document3 == expected_d_document
Loading

0 comments on commit c0ec96c

Please sign in to comment.