diff --git a/astrapy/api.py b/astrapy/api.py new file mode 100644 index 00000000..c49c4c1b --- /dev/null +++ b/astrapy/api.py @@ -0,0 +1,150 @@ +import logging +import httpx +from typing import Any, Dict, Optional, TypeVar, cast + +from astrapy.types import API_RESPONSE +from astrapy.utils import amake_request, make_request + +T = TypeVar("T", bound="APIRequestHandler") +AT = TypeVar("AT", bound="AsyncAPIRequestHandler") + + +logger = logging.getLogger(__name__) + + +class APIRequestError(ValueError): + def __init__(self, response: httpx.Response) -> None: + super().__init__(response.text) + + self.response = response + + def __repr__(self) -> str: + return f"{self.response}" + + +class APIRequestHandler: + def __init__( + self: T, + client: httpx.Client, + base_url: str, + auth_header: str, + token: str, + method: str, + json_data: Optional[Dict[str, Any]], + url_params: Optional[Dict[str, Any]], + path: Optional[str] = None, + skip_error_check: bool = False, + ) -> None: + self.client = client + self.base_url = base_url + self.auth_header = auth_header + self.token = token + self.method = method + self.path = path + self.json_data = json_data + self.url_params = url_params + self.skip_error_check = skip_error_check + + def raw_request(self: T) -> httpx.Response: + return make_request( + client=self.client, + base_url=self.base_url, + auth_header=self.auth_header, + token=self.token, + method=self.method, + path=self.path, + json_data=self.json_data, + url_params=self.url_params, + ) + + def request(self: T) -> API_RESPONSE: + # Make the raw request to the API + self.response = self.raw_request() + + # If the response was not successful (non-success error code) raise an error directly + self.response.raise_for_status() + + # Otherwise, process the successful response + return self._process_response() + + def _process_response(self: T) -> API_RESPONSE: + # In case of other successful responses, parse the JSON body. + try: + # Cast the response to the expected type. + response_body: API_RESPONSE = cast(API_RESPONSE, self.response.json()) + + # If the API produced an error, warn and return the API request error class + if "errors" in response_body and not self.skip_error_check: + logger.debug(response_body["errors"]) + + raise APIRequestError(self.response) + + # Otherwise, set the response body + return response_body + except ValueError: + # Handle cases where json() parsing fails (e.g., empty body) + raise APIRequestError(self.response) + + +class AsyncAPIRequestHandler: + def __init__( + self: AT, + client: httpx.AsyncClient, + base_url: str, + auth_header: str, + token: str, + method: str, + json_data: Optional[Dict[str, Any]], + url_params: Optional[Dict[str, Any]], + path: Optional[str] = None, + skip_error_check: bool = False, + ) -> None: + self.client = client + self.base_url = base_url + self.auth_header = auth_header + self.token = token + self.method = method + self.path = path + self.json_data = json_data + self.url_params = url_params + self.skip_error_check = skip_error_check + + async def raw_request(self: AT) -> httpx.Response: + return await amake_request( + client=self.client, + base_url=self.base_url, + auth_header=self.auth_header, + token=self.token, + method=self.method, + path=self.path, + json_data=self.json_data, + url_params=self.url_params, + ) + + async def request(self: AT) -> API_RESPONSE: + # Make the raw request to the API + self.response = await self.raw_request() + + # If the response was not successful (non-success error code) raise an error directly + self.response.raise_for_status() + + # Otherwise, process the successful response + return await self._process_response() + + async def _process_response(self: AT) -> API_RESPONSE: + # In case of other successful responses, parse the JSON body. + try: + # Cast the response to the expected type. + response_body: API_RESPONSE = cast(API_RESPONSE, self.response.json()) + + # If the API produced an error, warn and return the API request error class + if "errors" in response_body and not self.skip_error_check: + logger.debug(response_body["errors"]) + + raise APIRequestError(self.response) + + # Otherwise, set the response body + return response_body + except ValueError: + # Handle cases where json() parsing fails (e.g., empty body) + raise APIRequestError(self.response) diff --git a/astrapy/db.py b/astrapy/db.py index 62a0fc9e..02e13cbd 100644 --- a/astrapy/db.py +++ b/astrapy/db.py @@ -15,10 +15,11 @@ import asyncio import httpx -import json import logging +import json import threading + from concurrent.futures import ThreadPoolExecutor from functools import partial from queue import Queue @@ -37,6 +38,7 @@ AsyncGenerator, ) +from astrapy.api import AsyncAPIRequestHandler, APIRequestHandler from astrapy.defaults import ( DEFAULT_AUTH_HEADER, DEFAULT_JSON_API_PATH, @@ -59,6 +61,7 @@ AsyncPaginableRequestMethod, ) + logger = logging.getLogger(__name__) @@ -109,7 +112,7 @@ def _request( skip_error_check: bool = False, **kwargs: Any, ) -> API_RESPONSE: - response = make_request( + request_handler = APIRequestHandler( client=self.client, base_url=self.astra_db.base_url, auth_header=DEFAULT_AUTH_HEADER, @@ -118,13 +121,13 @@ def _request( path=path, json_data=json_data, url_params=url_params, + skip_error_check=skip_error_check, + **kwargs, ) - responsebody = cast(API_RESPONSE, response.json()) - if not skip_error_check and "errors" in responsebody: - raise ValueError(json.dumps(responsebody["errors"])) - else: - return responsebody + response = request_handler.request() + + return response def _get( self, path: Optional[str] = None, options: Optional[Dict[str, Any]] = None @@ -955,7 +958,7 @@ async def _request( skip_error_check: bool = False, **kwargs: Any, ) -> API_RESPONSE: - response = await amake_request( + arequest_handler = AsyncAPIRequestHandler( client=self.client, base_url=self.astra_db.base_url, auth_header=DEFAULT_AUTH_HEADER, @@ -964,13 +967,12 @@ async def _request( path=path, json_data=json_data, url_params=url_params, + skip_error_check=skip_error_check, ) - responsebody = cast(API_RESPONSE, response.json()) - if not skip_error_check and "errors" in responsebody: - raise ValueError(json.dumps(responsebody["errors"])) - else: - return responsebody + response = await arequest_handler.request() + + return response async def _get( self, path: Optional[str] = None, options: Optional[Dict[str, Any]] = None diff --git a/astrapy/ops.py b/astrapy/ops.py index 9fdb52ff..15c319cd 100644 --- a/astrapy/ops.py +++ b/astrapy/ops.py @@ -16,8 +16,9 @@ from typing import Any, cast, Dict, Optional import httpx +from astrapy.api import APIRequestHandler -from astrapy.utils import make_request, http_methods +from astrapy.utils import http_methods from astrapy.defaults import ( DEFAULT_DEV_OPS_AUTH_HEADER, DEFAULT_DEV_OPS_API_VERSION, @@ -56,17 +57,21 @@ def _ops_request( ) -> httpx.Response: _options = {} if options is None else options - return make_request( + request_handler = APIRequestHandler( client=self.client, base_url=self.base_url, - method=method, auth_header=DEFAULT_DEV_OPS_AUTH_HEADER, token=self.token, + method=method, + path=path, json_data=json_data, url_params=_options, - path=path, ) + response = request_handler.raw_request() + + return response + def _json_ops_request( self, method: str, @@ -74,16 +79,22 @@ def _json_ops_request( options: Optional[Dict[str, Any]] = None, json_data: Optional[Dict[str, Any]] = None, ) -> OPS_API_RESPONSE: - req_result = self._ops_request( + _options = {} if options is None else options + + request_handler = APIRequestHandler( + client=self.client, + base_url=self.base_url, + auth_header="Authorization", + token=self.token, method=method, path=path, - options=options, json_data=json_data, + url_params=_options, ) - return cast( - OPS_API_RESPONSE, - req_result.json(), - ) + + response = request_handler.request() + + return response def get_databases( self, options: Optional[Dict[str, Any]] = None diff --git a/astrapy/utils.py b/astrapy/utils.py index 5034c279..5423335c 100644 --- a/astrapy/utils.py +++ b/astrapy/utils.py @@ -141,8 +141,6 @@ async def amake_request( if logger.isEnabledFor(logging.DEBUG): log_request_response(r, json_data) - r.raise_for_status() - return r diff --git a/tests/astrapy/test_async_db_dml.py b/tests/astrapy/test_async_db_dml.py index ef7520cb..faed11f4 100644 --- a/tests/astrapy/test_async_db_dml.py +++ b/tests/astrapy/test_async_db_dml.py @@ -23,6 +23,7 @@ import pytest +from astrapy.api import APIRequestError from astrapy.types import API_DOC from astrapy.db import AsyncAstraDB, AsyncAstraDBCollection @@ -176,7 +177,7 @@ async def test_find_error( sort = {"$vector": "clearly not a list of floats!"} options = {"limit": 100} - with pytest.raises(ValueError): + with pytest.raises(APIRequestError): await async_readonly_vector_collection.find(sort=sort, options=options) diff --git a/tests/astrapy/test_db_dml.py b/tests/astrapy/test_db_dml.py index 1dc0a4e9..f3c59665 100644 --- a/tests/astrapy/test_db_dml.py +++ b/tests/astrapy/test_db_dml.py @@ -19,10 +19,13 @@ import uuid import logging +import json +import httpx from typing import Dict, List, Literal, Optional, Set import pytest +from astrapy.api import APIRequestError from astrapy.types import API_DOC from astrapy.db import AstraDB, AstraDBCollection @@ -41,6 +44,7 @@ def test_truncate_collection_fail(db: AstraDB) -> None: @pytest.mark.describe("should truncate a nonvector collection") def test_truncate_nonvector_collection(db: AstraDB) -> None: col = db.create_collection(TEST_TRUNCATED_NONVECTOR_COLLECTION_NAME) + try: col.insert_one({"a": 1}) assert len(col.find()["data"]["documents"]) == 1 @@ -53,6 +57,7 @@ def test_truncate_nonvector_collection(db: AstraDB) -> None: @pytest.mark.describe("should truncate a collection") def test_truncate_vector_collection(db: AstraDB) -> None: col = db.create_collection(TEST_TRUNCATED_VECTOR_COLLECTION_NAME, dimension=2) + try: col.insert_one({"a": 1, "$vector": [0.1, 0.2]}) assert len(col.find()["data"]["documents"]) == 1 @@ -167,7 +172,7 @@ def test_find_error(readonly_vector_collection: AstraDBCollection) -> None: sort = {"$vector": "clearly not a list of floats!"} options = {"limit": 100} - with pytest.raises(ValueError): + with pytest.raises(APIRequestError): readonly_vector_collection.find(sort=sort, options=options) @@ -531,6 +536,68 @@ def test_insert_many_ordered_false( assert check_response["data"]["document"]["_id"] == _id1 +@pytest.mark.describe("test error handling - duplicate document") +def test_error_handling_duplicate( + writable_vector_collection: AstraDBCollection, +) -> None: + _id1 = str(uuid.uuid4()) + + result1 = writable_vector_collection.insert_one( + { + "_id": _id1, + "a": 1, + "$vector": [0.3, 0.5], + } + ) + + assert result1["status"]["insertedIds"] == [_id1] + assert ( + writable_vector_collection.find_one( + {"_id": result1["status"]["insertedIds"][0]} + )["data"]["document"]["a"] + == 1 + ) + + with pytest.raises(ValueError): + writable_vector_collection.insert_one( + { + "_id": _id1, + "a": 1, + "$vector": [0.3, 0.5], + } + ) + + try: + writable_vector_collection.insert_one( + { + "_id": _id1, + "a": 1, + "$vector": [0.3, 0.5], + } + ) + except ValueError as e: + message = str(e) + parsed_json = json.loads(message) + + assert parsed_json["errors"][0]["errorCode"] == "DOCUMENT_ALREADY_EXISTS" + + +@pytest.mark.describe("test error handling - network error") +def test_error_handling_network( + invalid_writable_vector_collection: AstraDBCollection, +) -> None: + _id1 = str(uuid.uuid4()) + + with pytest.raises(httpx.ConnectError): + invalid_writable_vector_collection.insert_one( + { + "_id": _id1, + "a": 1, + "$vector": [0.3, 0.5], + } + ) + + @pytest.mark.describe("upsert_many") def test_upsert_many( writable_vector_collection: AstraDBCollection, diff --git a/tests/conftest.py b/tests/conftest.py index 1da266e8..5a17d2f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,15 @@ def astra_db_credentials_kwargs() -> Dict[str, Optional[str]]: } +@pytest.fixture(scope="session") +def astra_invalid_db_credentials_kwargs() -> Dict[str, Optional[str]]: + return { + "token": ASTRA_DB_APPLICATION_TOKEN, + "api_endpoint": "http://localhost:1234", + "namespace": ASTRA_DB_KEYSPACE, + } + + @pytest.fixture(scope="module") def cliff_uuid() -> str: return str(uuid.uuid4()) @@ -82,6 +91,13 @@ async def async_db( yield db +@pytest.fixture(scope="module") +def invalid_db( + astra_invalid_db_credentials_kwargs: Dict[str, Optional[str]] +) -> AstraDB: + return AstraDB(**astra_invalid_db_credentials_kwargs) + + @pytest.fixture(scope="module") def writable_vector_collection(db: AstraDB) -> Iterable[AstraDBCollection]: """ @@ -120,6 +136,21 @@ async def async_writable_vector_collection( await async_db.delete_collection(TEST_WRITABLE_VECTOR_COLLECTION) +@pytest.fixture(scope="module") +def invalid_writable_vector_collection( + invalid_db: AstraDB, +) -> Iterable[AstraDBCollection]: + """ + This is lasting for the whole test. Functions can write to it, + no guarantee (i.e. each test should use a different ID... + """ + collection = invalid_db.collection( + TEST_WRITABLE_VECTOR_COLLECTION, + ) + + yield collection + + @pytest.fixture(scope="module") def readonly_vector_collection(db: AstraDB) -> Iterable[AstraDBCollection]: collection = db.create_collection(