Skip to content

Commit

Permalink
Fix #186 - token and api endpoint not optional (#187)
Browse files Browse the repository at this point in the history
* Fix #186 - token and api endpoint not optional

* Update test config

* Clean param specification
  • Loading branch information
erichare authored Feb 1, 2024
1 parent 1efa370 commit d7c91f1
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 53 deletions.
16 changes: 8 additions & 8 deletions astrapy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,17 +1858,17 @@ class AstraDB:

def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
token: str,
api_endpoint: str,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
) -> None:
"""
Initialize an Astra DB instance.
Args:
token (str, optional): Authentication token for Astra DB.
api_endpoint (str, optional): API endpoint URL.
token (str): Authentication token for Astra DB.
api_endpoint (str): API endpoint URL.
namespace (str, optional): Namespace for the database.
"""
if token is None or api_endpoint is None:
Expand Down Expand Up @@ -2090,17 +2090,17 @@ def truncate_collection(self, collection_name: str) -> AstraDBCollection:
class AsyncAstraDB:
def __init__(
self,
token: Optional[str] = None,
api_endpoint: Optional[str] = None,
token: str,
api_endpoint: str,
api_path: Optional[str] = None,
api_version: Optional[str] = None,
namespace: Optional[str] = None,
) -> None:
"""
Initialize an Astra DB instance.
Args:
token (str, optional): Authentication token for Astra DB.
api_endpoint (str, optional): API endpoint URL.
token (str): Authentication token for Astra DB.
api_endpoint (str): API endpoint URL.
namespace (str, optional): Namespace for the database.
"""
self.client = httpx.AsyncClient()
Expand Down
30 changes: 15 additions & 15 deletions tests/astrapy/test_async_db_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,40 @@
async def test_path_handling(
astra_db_credentials_kwargs: Dict[str, Optional[str]]
) -> None:
async with AsyncAstraDB(**astra_db_credentials_kwargs) as astra_db_1:
token = astra_db_credentials_kwargs["token"]
api_endpoint = astra_db_credentials_kwargs["api_endpoint"]
namespace = astra_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

async with AsyncAstraDB(
token=token, api_endpoint=api_endpoint, namespace=namespace
) as astra_db_1:
url_1 = astra_db_1.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="v1",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="v1"
) as astra_db_2:
url_2 = astra_db_2.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="/v1",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="/v1"
) as astra_db_3:
url_3 = astra_db_3.base_path

async with AsyncAstraDB(
**astra_db_credentials_kwargs,
api_version="/v1/",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="/v1/"
) as astra_db_4:
url_4 = astra_db_4.base_path

assert url_1 == url_2 == url_3 == url_4

# autofill of the default keyspace name
async with AsyncAstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": DEFAULT_KEYSPACE_NAME},
}
token=token, api_endpoint=api_endpoint, namespace=DEFAULT_KEYSPACE_NAME
) as unspecified_ks_client, AsyncAstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": None},
}
token=token, api_endpoint=api_endpoint, namespace=None
) as explicit_ks_client:
assert unspecified_ks_client.base_path == explicit_ks_client.base_path

Expand Down
33 changes: 13 additions & 20 deletions tests/astrapy/test_db_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,46 +33,39 @@

@pytest.mark.describe("should confirm path handling in constructor")
def test_path_handling(astra_db_credentials_kwargs: Dict[str, Optional[str]]) -> None:
astra_db_1 = AstraDB(**astra_db_credentials_kwargs)
token = astra_db_credentials_kwargs["token"]
api_endpoint = astra_db_credentials_kwargs["api_endpoint"]
namespace = astra_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

astra_db_1 = AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace)
url_1 = astra_db_1.base_path

astra_db_2 = AstraDB(
**astra_db_credentials_kwargs,
api_version="v1",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="v1"
)

url_2 = astra_db_2.base_path

astra_db_3 = AstraDB(
**astra_db_credentials_kwargs,
api_version="/v1",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="/v1"
)

url_3 = astra_db_3.base_path

astra_db_4 = AstraDB(
**astra_db_credentials_kwargs,
api_version="/v1/",
token=token, api_endpoint=api_endpoint, namespace=namespace, api_version="/v1/"
)

url_4 = astra_db_4.base_path

assert url_1 == url_2 == url_3 == url_4

# autofill of the default keyspace name
unspecified_ks_client = AstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": DEFAULT_KEYSPACE_NAME},
}
)
explicit_ks_client = AstraDB(
**{
**astra_db_credentials_kwargs,
**{"namespace": None},
}
token=token, api_endpoint=api_endpoint, namespace=DEFAULT_KEYSPACE_NAME
)
explicit_ks_client = AstraDB(token=token, api_endpoint=api_endpoint, namespace=None)

assert unspecified_ks_client.base_path == explicit_ks_client.base_path


Expand Down
63 changes: 53 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
import math

import pytest
from typing import AsyncIterable, Dict, Iterable, List, Optional, Set, TypeVar
from typing import (
AsyncIterable,
Dict,
Iterable,
List,
Optional,
Set,
TypeVar,
TypedDict,
)

import pytest_asyncio

Expand All @@ -15,8 +24,9 @@
T = TypeVar("T")


ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN")
ASTRA_DB_API_ENDPOINT = os.environ.get("ASTRA_DB_API_ENDPOINT")
ASTRA_DB_APPLICATION_TOKEN = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
ASTRA_DB_API_ENDPOINT = os.environ["ASTRA_DB_API_ENDPOINT"]

ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE", DEFAULT_KEYSPACE_NAME)

# fixed
Expand Down Expand Up @@ -49,6 +59,12 @@
]


class AstraDBCredentials(TypedDict, total=False):
token: str
api_endpoint: str
namespace: Optional[str]


def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]:
this_batch = []
for entry in iterable:
Expand All @@ -61,41 +77,68 @@ def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable


@pytest.fixture(scope="session")
def astra_db_credentials_kwargs() -> Dict[str, Optional[str]]:
return {
def astra_db_credentials_kwargs() -> AstraDBCredentials:
astra_db_creds: AstraDBCredentials = {
"token": ASTRA_DB_APPLICATION_TOKEN,
"api_endpoint": ASTRA_DB_API_ENDPOINT,
"namespace": ASTRA_DB_KEYSPACE,
}

return astra_db_creds


@pytest.fixture(scope="session")
def astra_invalid_db_credentials_kwargs() -> Dict[str, Optional[str]]:
return {
def astra_invalid_db_credentials_kwargs() -> AstraDBCredentials:
astra_db_creds: AstraDBCredentials = {
"token": ASTRA_DB_APPLICATION_TOKEN,
"api_endpoint": "http://localhost:1234",
"namespace": ASTRA_DB_KEYSPACE,
}

return astra_db_creds


@pytest.fixture(scope="session")
def db(astra_db_credentials_kwargs: Dict[str, Optional[str]]) -> AstraDB:
return AstraDB(**astra_db_credentials_kwargs)
token = astra_db_credentials_kwargs["token"]
api_endpoint = astra_db_credentials_kwargs["api_endpoint"]
namespace = astra_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace)


@pytest_asyncio.fixture(scope="function")
async def async_db(
astra_db_credentials_kwargs: Dict[str, Optional[str]]
) -> AsyncIterable[AsyncAstraDB]:
async with AsyncAstraDB(**astra_db_credentials_kwargs) as db:
token = astra_db_credentials_kwargs["token"]
api_endpoint = astra_db_credentials_kwargs["api_endpoint"]
namespace = astra_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

async with AsyncAstraDB(
token=token, api_endpoint=api_endpoint, namespace=namespace
) as db:
yield db


@pytest.fixture(scope="module")
def invalid_db(
astra_invalid_db_credentials_kwargs: Dict[str, Optional[str]]
) -> AstraDB:
return AstraDB(**astra_invalid_db_credentials_kwargs)
token = astra_invalid_db_credentials_kwargs["token"]
api_endpoint = astra_invalid_db_credentials_kwargs["api_endpoint"]
namespace = astra_invalid_db_credentials_kwargs.get("namespace")

if token is None or api_endpoint is None:
raise ValueError("Required ASTRA DB configuration is missing")

return AstraDB(token=token, api_endpoint=api_endpoint, namespace=namespace)


@pytest.fixture(scope="session")
Expand Down

0 comments on commit d7c91f1

Please sign in to comment.