Skip to content

Commit

Permalink
Add async API
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Dec 19, 2023
1 parent 3534094 commit eafc5dc
Show file tree
Hide file tree
Showing 10 changed files with 2,488 additions and 130 deletions.
1,227 changes: 1,100 additions & 127 deletions astrapy/db.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions astrapy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,9 @@
class PaginableRequestMethod(Protocol):
def __call__(self, options: Dict[str, Any]) -> API_RESPONSE:
...


# This is for the (partialed, if necessary) async functions that can be "paginated".
class AsyncPaginableRequestMethod(Protocol):
async def __call__(self, options: Dict[str, Any]) -> API_RESPONSE:
...
43 changes: 43 additions & 0 deletions astrapy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,49 @@ def make_request(
return r


async def amake_request(
client: httpx.AsyncClient,
base_url: str,
auth_header: str,
token: str,
method: str = http_methods.POST,
path: Optional[str] = None,
json_data: Optional[Dict[str, Any]] = None,
url_params: Optional[Dict[str, Any]] = None,
) -> httpx.Response:
"""
Make an HTTP request to a specified URL.
Args:
client (httpx): The httpx client for the request.
base_url (str): The base URL for the request.
auth_header (str): The authentication header key.
token (str): The token used for authentication.
method (str, optional): The HTTP method to use for the request. Default is POST.
path (str, optional): The specific path to append to the base URL.
json_data (dict, optional): JSON payload to be sent with the request.
url_params (dict, optional): URL parameters to be sent with the request.
Returns:
requests.Response: The response from the HTTP request.
"""
r = await client.request(
method=method,
url=f"{base_url}{path}",
params=url_params,
json=json_data,
timeout=DEFAULT_TIMEOUT,
headers={auth_header: token, "User-Agent": f"{package_name}/{__version__}"},
)

if logger.isEnabledFor(logging.DEBUG):
log_request_response(r, json_data)

r.raise_for_status()

return r


def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]:
"""
Construct a JSON payload for an HTTP request with a specified top-level key.
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[pytest]
filterwarnings = ignore::DeprecationWarning
addopts = -v --cov=astrapy --testdox --cov-report term-missing
asyncio_mode = auto
log_cli = 1
log_cli_level = INFO
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ black~=23.11.0
faker~=20.0.0
mypy~=1.7.0
pre-commit~=3.5.0
pytest-asyncio~=0.23.2
pytest-cov~=4.1.0
pytest-testdox~=3.1.0
pytest~=7.4.3
Expand Down
116 changes: 116 additions & 0 deletions tests/astrapy/test_async_db_ddl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for the `db.py` parts related to DML & client creation
"""

import logging
from typing import Dict, Optional

import pytest

from astrapy.db import AsyncAstraDB, AsyncAstraDBCollection
from astrapy.defaults import DEFAULT_KEYSPACE_NAME

TEST_CREATE_DELETE_VECTOR_COLLECTION_NAME = "ephemeral_v_col"
TEST_CREATE_DELETE_NONVECTOR_COLLECTION_NAME = "ephemeral_non_v_col"

logger = logging.getLogger(__name__)


@pytest.mark.describe("should confirm path handling in constructor")
async def test_path_handling(
astra_db_credentials_kwargs: Dict[str, Optional[str]]
) -> None:
async with AsyncAstraDB(**astra_db_credentials_kwargs) as astra_db_1:
url_1 = astra_db_1.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="v1",
) as astra_db_2:
url_2 = astra_db_2.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="/v1",
) as astra_db_3:
url_3 = astra_db_3.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="/v1/",
) as astra_db_4:
url_4 = astra_db_4.base_path

assert url_1 == url_2 == url_3 == url_4

# autofill of the default keyspace name
async with AsyncAstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": DEFAULT_KEYSPACE_NAME},
}
) as unspecified_ks_client, AsyncAstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": None},
}
) as explicit_ks_client:
assert unspecified_ks_client.base_path == explicit_ks_client.base_path


@pytest.mark.describe("should create, use and destroy a non-vector collection")
async def test_create_use_destroy_nonvector_collection(async_db: AsyncAstraDB) -> None:
col = await async_db.create_collection(TEST_CREATE_DELETE_NONVECTOR_COLLECTION_NAME)
assert isinstance(col, AsyncAstraDBCollection)
await col.insert_one({"_id": "first", "name": "a"})
await col.insert_many(
[
{"_id": "second", "name": "b", "room": 7},
{"name": "c", "room": 7},
{"_id": "last", "type": "unnamed", "room": 7},
]
)
docs = await col.find(filter={"room": 7}, projection={"name": 1})
ids = [doc["_id"] for doc in docs["data"]["documents"]]
assert len(ids) == 3
assert "second" in ids
assert "first" not in ids
auto_id = [id for id in ids if id not in {"second", "last"}][0]
await col.delete_one(auto_id)
assert (await col.find_one(filter={"name": "c"}))["data"]["document"] is None
del_res = await async_db.delete_collection(
TEST_CREATE_DELETE_NONVECTOR_COLLECTION_NAME
)
assert del_res["status"]["ok"] == 1


@pytest.mark.describe("should create and destroy a vector collection")
async def test_create_use_destroy_vector_collection(async_db: AsyncAstraDB) -> None:
col = await async_db.create_collection(
collection_name=TEST_CREATE_DELETE_VECTOR_COLLECTION_NAME, dimension=2
)
assert isinstance(col, AsyncAstraDBCollection)
del_res = await async_db.delete_collection(
collection_name=TEST_CREATE_DELETE_VECTOR_COLLECTION_NAME
)
assert del_res["status"]["ok"] == 1


@pytest.mark.describe("should get all collections")
async def test_get_collections(async_db: AsyncAstraDB) -> None:
res = await async_db.get_collections()
assert res["status"]["collections"] is not None
Loading

0 comments on commit eafc5dc

Please sign in to comment.