From 486248d64ea8057e6e223c0480eb0d5d03803759 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 20 Sep 2024 18:14:02 +0200 Subject: [PATCH] Add ruff rules UP (pyupgrade) --- astrapy/admin.py | 416 +++++++------ astrapy/api_commander.py | 80 ++- astrapy/api_options.py | 8 +- astrapy/authentication.py | 20 +- astrapy/client.py | 86 +-- astrapy/collection.py | 471 +++++++-------- astrapy/constants.py | 6 +- astrapy/core/api.py | 72 +-- astrapy/core/core_types.py | 4 +- astrapy/core/db.py | 568 +++++++++--------- astrapy/core/ops.py | 104 ++-- astrapy/core/utils.py | 80 ++- astrapy/cursors.py | 154 +++-- astrapy/database.py | 304 +++++----- astrapy/exceptions.py | 132 ++-- astrapy/info.py | 132 ++-- astrapy/operations.py | 212 +++---- astrapy/request_tools.py | 10 +- astrapy/results.py | 18 +- astrapy/transform_payload.py | 26 +- astrapy/user_agents.py | 13 +- pyproject.toml | 9 +- scripts/astrapy_latest_interface.py | 3 +- tests/conftest.py | 18 +- tests/core/conftest.py | 14 +- tests/core/test_async_db_dml.py | 22 +- tests/core/test_async_db_dml_pagination.py | 3 +- tests/core/test_db_dml.py | 18 +- tests/core/test_db_dml_pagination.py | 3 +- tests/core/test_ops.py | 2 +- tests/idiomatic/integration/test_admin.py | 16 +- tests/idiomatic/integration/test_dml_async.py | 14 +- tests/idiomatic/integration/test_dml_sync.py | 8 +- .../integration/test_exceptions_async.py | 5 +- tests/idiomatic/unit/test_apicommander.py | 5 +- .../idiomatic/unit/test_collection_options.py | 4 +- .../unit/test_document_extractors.py | 4 +- tests/preprocess_env.py | 29 +- tests/vectorize_idiomatic/conftest.py | 6 +- .../test_vectorize_methods_async.py | 12 +- .../test_vectorize_methods_sync.py | 10 +- .../integration/test_vectorize_providers.py | 18 +- tests/vectorize_idiomatic/query_providers.py | 3 +- tests/vectorize_idiomatic/vectorize_models.py | 9 +- 44 files changed, 1550 insertions(+), 1601 deletions(-) diff --git a/astrapy/admin.py b/astrapy/admin.py index 0dea5eeb..17930b44 
100644 --- a/astrapy/admin.py +++ b/astrapy/admin.py @@ -21,7 +21,7 @@ import warnings from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any from deprecation import DeprecatedWarning @@ -104,7 +104,7 @@ class ParsedAPIEndpoint: environment: str -def parse_api_endpoint(api_endpoint: str) -> Optional[ParsedAPIEndpoint]: +def parse_api_endpoint(api_endpoint: str) -> ParsedAPIEndpoint | None: """ Parse an API Endpoint into a ParsedAPIEndpoint structure. @@ -137,7 +137,7 @@ def api_endpoint_parsing_error_message(failing_url: str) -> str: ) -def parse_generic_api_url(api_endpoint: str) -> Optional[str]: +def parse_generic_api_url(api_endpoint: str) -> str | None: """ Validate a generic API Endpoint string, such as `http://10.1.1.1:123` or `https://my.domain`. @@ -194,10 +194,10 @@ def build_api_endpoint(environment: str, database_id: str, region: str) -> str: def fetch_raw_database_info_from_id_token( id: str, *, - token: Optional[str], + token: str | None, environment: str = Environment.PROD, - max_time_ms: Optional[int] = None, -) -> Dict[str, Any]: + max_time_ms: int | None = None, +) -> dict[str, Any]: """ Fetch database information through the DevOps API and return it in full, exactly like the API gives it back. @@ -214,7 +214,7 @@ def fetch_raw_database_info_from_id_token( The full response from the DevOps API about the database. 
""" - ops_headers: Dict[str, str | None] + ops_headers: dict[str, str | None] if token is not None: ops_headers = { DEFAULT_DEV_OPS_AUTH_HEADER: f"{DEFAULT_DEV_OPS_AUTH_PREFIX}{token}", @@ -245,10 +245,10 @@ def fetch_raw_database_info_from_id_token( async def async_fetch_raw_database_info_from_id_token( id: str, *, - token: Optional[str], + token: str | None, environment: str = Environment.PROD, - max_time_ms: Optional[int] = None, -) -> Dict[str, Any]: + max_time_ms: int | None = None, +) -> dict[str, Any]: """ Fetch database information through the DevOps API and return it in full, exactly like the API gives it back. @@ -266,7 +266,7 @@ async def async_fetch_raw_database_info_from_id_token( The full response from the DevOps API about the database. """ - ops_headers: Dict[str, str | None] + ops_headers: dict[str, str | None] if token is not None: ops_headers = { DEFAULT_DEV_OPS_AUTH_HEADER: f"{DEFAULT_DEV_OPS_AUTH_PREFIX}{token}", @@ -296,10 +296,10 @@ async def async_fetch_raw_database_info_from_id_token( def fetch_database_info( api_endpoint: str, - token: Optional[str], - namespace: Optional[str], - max_time_ms: Optional[int] = None, -) -> Optional[DatabaseInfo]: + token: str | None, + namespace: str | None, + max_time_ms: int | None = None, +) -> DatabaseInfo | None: """ Fetch database information through the DevOps API. @@ -342,10 +342,10 @@ def fetch_database_info( async def async_fetch_database_info( api_endpoint: str, - token: Optional[str], - namespace: Optional[str], - max_time_ms: Optional[int] = None, -) -> Optional[DatabaseInfo]: + token: str | None, + namespace: str | None, + max_time_ms: int | None = None, +) -> DatabaseInfo | None: """ Fetch database information through the DevOps API. Async version of the function, for use in an asyncio context. 
@@ -388,7 +388,7 @@ async def async_fetch_database_info( def _recast_as_admin_database_info( - admin_database_info_dict: Dict[str, Any], + admin_database_info_dict: dict[str, Any], *, environment: str, ) -> AdminDatabaseInfo: @@ -423,10 +423,10 @@ def _recast_as_admin_database_info( def normalize_api_endpoint( id_or_endpoint: str, - region: Optional[str], + region: str | None, token: TokenProvider, environment: str, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> str: """ Ensure that a id(+region) / endpoint init signature is normalized into @@ -480,9 +480,7 @@ def normalize_api_endpoint( return _api_endpoint.strip("/") -def normalize_id_endpoint_parameters( - id: Optional[str], api_endpoint: Optional[str] -) -> str: +def normalize_id_endpoint_parameters(id: str | None, api_endpoint: str | None) -> str: if id is None: if api_endpoint is None: raise ValueError( @@ -538,13 +536,13 @@ class AstraDBAdmin: def __init__( self, - token: Optional[Union[str, TokenProvider]] = None, + token: str | TokenProvider | None = None, *, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> None: self.token_provider = coerce_token_provider(token) self.environment = (environment or Environment.PROD).lower() @@ -557,7 +555,7 @@ def __init__( self._dev_ops_url = dev_ops_url self._dev_ops_api_version = dev_ops_api_version - self._dev_ops_commander_headers: Dict[str, str | None] + self._dev_ops_commander_headers: dict[str, str | None] if self.token_provider: _token = self.token_provider.get_token() self._dev_ops_commander_headers = { @@ -571,12 +569,12 @@ def __init__( self._dev_ops_api_commander = self._get_dev_ops_api_commander() def 
__repr__(self) -> str: - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'"{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - env_desc: Optional[str] + env_desc: str | None if self.environment == Environment.PROD: env_desc = None else: @@ -620,12 +618,12 @@ def _get_dev_ops_api_commander(self) -> APICommander: def _copy( self, *, - token: Optional[Union[str, TokenProvider]] = None, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> AstraDBAdmin: return AstraDBAdmin( token=coerce_token_provider(token) or self.token_provider, @@ -639,9 +637,9 @@ def _copy( def with_options( self, *, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AstraDBAdmin: """ Create a clone of this AstraDBAdmin with some changed attributes. @@ -673,8 +671,8 @@ def with_options( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -702,7 +700,7 @@ def set_caller( def list_databases( self, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> CommandCursor[AdminDatabaseInfo]: """ Get the list of databases, as obtained with a request to the DevOps API. 
@@ -752,7 +750,7 @@ def list_databases( async def async_list_databases( self, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> CommandCursor[AdminDatabaseInfo]: """ Get the list of databases, as obtained with a request to the DevOps API. @@ -802,7 +800,7 @@ async def async_list_databases( ) def database_info( - self, id: str, *, max_time_ms: Optional[int] = None + self, id: str, *, max_time_ms: int | None = None ) -> AdminDatabaseInfo: """ Get the full information on a given database, through a request to the DevOps API. @@ -838,7 +836,7 @@ def database_info( ) async def async_database_info( - self, id: str, *, max_time_ms: Optional[int] = None + self, id: str, *, max_time_ms: int | None = None ) -> AdminDatabaseInfo: """ Get the full information on a given database, through a request to the DevOps API. @@ -879,9 +877,9 @@ def create_database( *, cloud_provider: str, region: str, - namespace: Optional[str] = None, + namespace: str | None = None, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> AstraDBDatabaseAdmin: """ Create a database as requested, optionally waiting for it to be ready. @@ -985,9 +983,9 @@ async def async_create_database( *, cloud_provider: str, region: str, - namespace: Optional[str] = None, + namespace: str | None = None, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> AstraDBDatabaseAdmin: """ Create a database as requested, optionally waiting for it to be ready. @@ -1095,8 +1093,8 @@ def drop_database( id: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop a database, i.e. delete it completely and permanently with all its data. 
@@ -1146,8 +1144,8 @@ def drop_database( ) logger.info(f"DevOps API returned from dropping database '{id}'") if wait_until_active: - last_status_seen: Optional[str] = DEV_OPS_DATABASE_STATUS_TERMINATING - _db_name: Optional[str] = None + last_status_seen: str | None = DEV_OPS_DATABASE_STATUS_TERMINATING + _db_name: str | None = None while last_status_seen == DEV_OPS_DATABASE_STATUS_TERMINATING: logger.info(f"sleeping to poll for status of '{id}'") time.sleep(DEV_OPS_DATABASE_POLL_INTERVAL_S) @@ -1178,8 +1176,8 @@ async def async_drop_database( id: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop a database, i.e. delete it completely and permanently with all its data. Async version of the method, for use in an asyncio context. @@ -1226,8 +1224,8 @@ async def async_drop_database( ) logger.info(f"DevOps API returned from dropping database '{id}', async") if wait_until_active: - last_status_seen: Optional[str] = DEV_OPS_DATABASE_STATUS_TERMINATING - _db_name: Optional[str] = None + last_status_seen: str | None = DEV_OPS_DATABASE_STATUS_TERMINATING + _db_name: str | None = None while last_status_seen == DEV_OPS_DATABASE_STATUS_TERMINATING: logger.info(f"sleeping to poll for status of '{id}', async") await asyncio.sleep(DEV_OPS_DATABASE_POLL_INTERVAL_S) @@ -1255,11 +1253,11 @@ async def async_drop_database( def get_database_admin( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - region: Optional[str] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + region: str | None = None, + max_time_ms: int | None = None, ) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin object for admin work within a certain database. 
@@ -1308,15 +1306,15 @@ def get_database_admin( def get_database( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> Database: """ Create a Database instance for a specific database, to be used @@ -1381,7 +1379,7 @@ def get_database( max_time_ms=max_time_ms, ) - _namespace: Optional[str] + _namespace: str | None if namespace: _namespace = namespace else: @@ -1409,14 +1407,14 @@ def get_database( def get_async_database( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: """ Create an AsyncDatabase instance for a specific database, to be used @@ -1445,10 +1443,10 @@ class DatabaseAdmin(ABC): """ environment: str - spawner_database: Union[Database, AsyncDatabase] + spawner_database: Database | AsyncDatabase @abstractmethod - def list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: + def list_namespaces(self, *pargs: Any, **kwargs: Any) -> list[str]: """Get a list of namespaces for the database.""" ... 
@@ -1457,23 +1455,23 @@ def create_namespace( self, name: str, *, - update_db_namespace: Optional[bool] = None, + update_db_namespace: bool | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. """ ... @abstractmethod - def drop_namespace(self, name: str, *pargs: Any, **kwargs: Any) -> Dict[str, Any]: + def drop_namespace(self, name: str, *pargs: Any, **kwargs: Any) -> dict[str, Any]: """ Drop (delete) a namespace from the database, returning {'ok': 1} if successful. """ ... @abstractmethod - async def async_list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: + async def async_list_namespaces(self, *pargs: Any, **kwargs: Any) -> list[str]: """ Get a list of namespaces for the database. (Async version of the method.) @@ -1482,8 +1480,8 @@ async def async_list_namespaces(self, *pargs: Any, **kwargs: Any) -> List[str]: @abstractmethod async def async_create_namespace( - self, name: str, *, update_db_namespace: Optional[bool] = None, **kwargs: Any - ) -> Dict[str, Any]: + self, name: str, *, update_db_namespace: bool | None = None, **kwargs: Any + ) -> dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. (Async version of the method.) @@ -1493,7 +1491,7 @@ async def async_create_namespace( @abstractmethod async def async_drop_namespace( self, name: str, *pargs: Any, **kwargs: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Drop (delete) a namespace from the database, returning {'ok': 1} if successful. (Async version of the method.) 
@@ -1595,20 +1593,20 @@ class is created by a method such as `Database.get_database_admin()`, def __init__( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - region: Optional[str] = None, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - spawner_database: Optional[Union[Database, AsyncDatabase]] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + region: str | None = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + spawner_database: Database | AsyncDatabase | None = None, + max_time_ms: int | None = None, ) -> None: # lazy import here to avoid circular dependency from astrapy.database import Database @@ -1681,7 +1679,7 @@ def __init__( if dev_ops_api_version is not None else DEV_OPS_VERSION_ENV_MAP[self.environment] ).strip("/") - self._dev_ops_commander_headers: Dict[str, str | None] + self._dev_ops_commander_headers: dict[str, str | None] if self.token_provider: _token = self.token_provider.get_token() self._dev_ops_commander_headers = { @@ -1703,12 +1701,12 @@ def __init__( def __repr__(self) -> str: ep_desc = f'api_endpoint="{self.api_endpoint}"' - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'token="{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - env_desc: Optional[str] + env_desc: str | None if self.environment == Environment.PROD: env_desc = None else: @@ -1767,16 +1765,16 @@ def 
_get_dev_ops_api_commander(self) -> APICommander: def _copy( self, - id: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - region: Optional[str] = None, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + id: str | None = None, + token: str | TokenProvider | None = None, + region: str | None = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AstraDBDatabaseAdmin: return AstraDBDatabaseAdmin( id=id or self._database_id, @@ -1794,10 +1792,10 @@ def _copy( def with_options( self, *, - id: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + id: str | None = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AstraDBDatabaseAdmin: """ Create a clone of this AstraDBDatabaseAdmin with some changed attributes. @@ -1830,8 +1828,8 @@ def with_options( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -1883,9 +1881,9 @@ def region(self) -> str: def from_astra_db_admin( id: str, *, - region: Optional[str], + region: str | None, astra_db_admin: AstraDBAdmin, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin from an AstraDBAdmin and a database ID. 
@@ -1941,11 +1939,11 @@ def from_astra_db_admin( def from_api_endpoint( api_endpoint: str, *, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> AstraDBDatabaseAdmin: """ Create an AstraDBDatabaseAdmin from an API Endpoint and optionally a token. @@ -2001,7 +1999,7 @@ def from_api_endpoint( msg = api_endpoint_parsing_error_message(api_endpoint) raise ValueError(msg) - def info(self, *, max_time_ms: Optional[int] = None) -> AdminDatabaseInfo: + def info(self, *, max_time_ms: int | None = None) -> AdminDatabaseInfo: """ Query the DevOps API for the full info on this database. @@ -2027,9 +2025,7 @@ def info(self, *, max_time_ms: Optional[int] = None) -> AdminDatabaseInfo: logger.info(f"finished getting info ('{self._database_id}')") return req_response - async def async_info( - self, *, max_time_ms: Optional[int] = None - ) -> AdminDatabaseInfo: + async def async_info(self, *, max_time_ms: int | None = None) -> AdminDatabaseInfo: """ Query the DevOps API for the full info on this database. Async version of the method, for use in an asyncio context. @@ -2058,7 +2054,7 @@ async def async_info( logger.info(f"finished getting info ('{self._database_id}'), async") return req_response - def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: + def list_namespaces(self, *, max_time_ms: int | None = None) -> list[str]: """ Query the DevOps API for a list of the namespaces in the database. 
@@ -2082,8 +2078,8 @@ def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: return info.raw_info["info"]["keyspaces"] # type: ignore[no-any-return] async def async_list_namespaces( - self, *, max_time_ms: Optional[int] = None - ) -> List[str]: + self, *, max_time_ms: int | None = None + ) -> list[str]: """ Query the DevOps API for a list of the namespaces in the database. Async version of the method, for use in an asyncio context. @@ -2120,10 +2116,10 @@ def create_namespace( name: str, *, wait_until_active: bool = True, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in this database as requested, optionally waiting for it to be ready. @@ -2208,10 +2204,10 @@ async def async_create_namespace( name: str, *, wait_until_active: bool = True, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in this database as requested, optionally waiting for it to be ready. @@ -2299,8 +2295,8 @@ def drop_namespace( name: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Delete a namespace from the database, optionally waiting for it to become active again. @@ -2379,8 +2375,8 @@ async def async_drop_namespace( name: str, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Delete a namespace from the database, optionally waiting for it to become active again. 
@@ -2461,8 +2457,8 @@ def drop( self, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop this database, i.e. delete it completely and permanently with all its data. @@ -2512,8 +2508,8 @@ async def async_drop( self, *, wait_until_active: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop this database, i.e. delete it completely and permanently with all its data. Async version of the method, for use in an asyncio context. @@ -2560,12 +2556,12 @@ async def async_drop( def get_database( self, *, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> Database: """ Create a Database instance from this database admin, for data-related tasks. 
@@ -2623,12 +2619,12 @@ def get_database( def get_async_database( self, *, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> AsyncDatabase: """ Create an AsyncDatabase instance out of this class for working @@ -2648,7 +2644,7 @@ def get_async_database( ).to_async() def find_embedding_providers( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> FindEmbeddingProvidersResult: """ Query the API for the full information on available embedding providers. @@ -2691,7 +2687,7 @@ def find_embedding_providers( return FindEmbeddingProvidersResult.from_dict(fe_response["status"]) async def async_find_embedding_providers( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> FindEmbeddingProvidersResult: """ Query the API for the full information on available embedding providers. 
@@ -2795,13 +2791,13 @@ def __init__( self, api_endpoint: str, *, - token: Optional[Union[str, TokenProvider]] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - spawner_database: Optional[Union[Database, AsyncDatabase]] = None, + token: str | TokenProvider | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + spawner_database: Database | AsyncDatabase | None = None, ) -> None: # lazy import here to avoid circular dependency from astrapy.database import Database @@ -2836,7 +2832,7 @@ def __init__( def __repr__(self) -> str: ep_desc = f'api_endpoint="{self.api_endpoint}"' - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'token="{redact_secret(str(self.token_provider), 15)}"' else: @@ -2868,13 +2864,13 @@ def _get_api_commander(self) -> APICommander: def _copy( self, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> DataAPIDatabaseAdmin: return DataAPIDatabaseAdmin( api_endpoint=api_endpoint or self.api_endpoint, @@ -2889,10 +2885,10 @@ def _copy( def with_options( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + api_endpoint: str | None = None, + token: 
str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> DataAPIDatabaseAdmin: """ Create a clone of this DataAPIDatabaseAdmin with some changed attributes. @@ -2925,8 +2921,8 @@ def with_options( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -2951,7 +2947,7 @@ def set_caller( self.caller_version = caller_version self._api_commander = self._get_api_commander() - def list_namespaces(self, *, max_time_ms: Optional[int] = None) -> List[str]: + def list_namespaces(self, *, max_time_ms: int | None = None) -> list[str]: """ Query the API for a list of the namespaces in the database. @@ -2983,11 +2979,11 @@ def create_namespace( self, name: str, *, - replication_options: Optional[Dict[str, Any]] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + replication_options: dict[str, Any] | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. @@ -3052,8 +3048,8 @@ def drop_namespace( self, name: str, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop (delete) a namespace from the database. @@ -3092,8 +3088,8 @@ def drop_namespace( return dn_response["status"] # type: ignore[no-any-return] async def async_list_namespaces( - self, *, max_time_ms: Optional[int] = None - ) -> List[str]: + self, *, max_time_ms: int | None = None + ) -> list[str]: """ Query the API for a list of the namespaces in the database. Async version of the method, for use in an asyncio context. 
@@ -3126,11 +3122,11 @@ async def async_create_namespace( self, name: str, *, - replication_options: Optional[Dict[str, Any]] = None, - update_db_namespace: Optional[bool] = None, - max_time_ms: Optional[int] = None, + replication_options: dict[str, Any] | None = None, + update_db_namespace: bool | None = None, + max_time_ms: int | None = None, **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create a namespace in the database, returning {'ok': 1} if successful. Async version of the method, for use in an asyncio context. @@ -3198,8 +3194,8 @@ async def async_drop_namespace( self, name: str, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop (delete) a namespace from the database. Async version of the method, for use in an asyncio context. @@ -3243,10 +3239,10 @@ async def async_drop_namespace( def get_database( self, *, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> Database: """ Create a Database instance out of this class for working with the data in it. 
@@ -3295,10 +3291,10 @@ def get_database( def get_async_database( self, *, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: """ Create an AsyncDatabase instance for the database, to be used @@ -3315,7 +3311,7 @@ def get_async_database( ).to_async() def find_embedding_providers( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> FindEmbeddingProvidersResult: """ Query the API for the full information on available embedding providers. @@ -3358,7 +3354,7 @@ def find_embedding_providers( return FindEmbeddingProvidersResult.from_dict(fe_response["status"]) async def async_find_embedding_providers( - self, *, max_time_ms: Optional[int] = None + self, *, max_time_ms: int | None = None ) -> FindEmbeddingProvidersResult: """ Query the API for the full information on available embedding providers. 
diff --git a/astrapy/api_commander.py b/astrapy/api_commander.py index 32a15490..39d2280d 100644 --- a/astrapy/api_commander.py +++ b/astrapy/api_commander.py @@ -21,11 +21,6 @@ Any, Dict, Iterable, - List, - Optional, - Tuple, - Type, - Union, cast, ) @@ -75,8 +70,8 @@ def __init__( self, api_endpoint: str, path: str, - headers: Dict[str, Union[str, None]] = {}, - callers: List[Tuple[Optional[str], Optional[str]]] = [], + headers: dict[str, str | None] = {}, + callers: list[tuple[str | None, str | None]] = [], redacted_header_names: Iterable[str] = DEFAULT_REDACTED_HEADER_NAMES, dev_ops_api: bool = False, ) -> None: @@ -88,15 +83,14 @@ def __init__( self.redacted_header_names = set(redacted_header_names) self.dev_ops_api = dev_ops_api - self._faulty_response_exc_class: Union[ - Type[DevOpsAPIFaultyResponseException], Type[DataAPIFaultyResponseException] - ] - self._response_exc_class: Union[ - Type[DevOpsAPIResponseException], Type[DataAPIResponseException] - ] - self._http_exc_class: Union[ - Type[DataAPIHttpException], Type[DevOpsAPIHttpException] - ] + self._faulty_response_exc_class: ( + type[DevOpsAPIFaultyResponseException] + | type[DataAPIFaultyResponseException] + ) + self._response_exc_class: ( + type[DevOpsAPIResponseException] | type[DataAPIResponseException] + ) + self._http_exc_class: type[DataAPIHttpException] | type[DevOpsAPIHttpException] if self.dev_ops_api: self._faulty_response_exc_class = DevOpsAPIFaultyResponseException self._response_exc_class = DevOpsAPIResponseException @@ -109,10 +103,10 @@ def __init__( full_user_agent_string = compose_full_user_agent( [user_agent_ragstack] + self.callers + [user_agent_astrapy] ) - self.caller_header: Dict[str, str] = ( + self.caller_header: dict[str, str] = ( {"User-Agent": full_user_agent_string} if full_user_agent_string else {} ) - self.full_headers: Dict[str, str] = { + self.full_headers: dict[str, str] = { **{k: v for k, v in self.headers.items() if v is not None}, **self.caller_header, 
**{"Content-Type": "application/json"}, @@ -143,20 +137,20 @@ async def __aenter__(self) -> APICommander: async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: await self.async_client.aclose() def _copy( self, - api_endpoint: Optional[str] = None, - path: Optional[str] = None, - headers: Optional[Dict[str, Union[str, None]]] = None, - callers: Optional[List[Tuple[Optional[str], Optional[str]]]] = None, - redacted_header_names: Optional[List[str]] = None, - dev_ops_api: Optional[bool] = None, + api_endpoint: str | None = None, + path: str | None = None, + headers: dict[str, str | None] | None = None, + callers: list[tuple[str | None, str | None]] | None = None, + redacted_header_names: list[str] | None = None, + dev_ops_api: bool | None = None, ) -> APICommander: # some care in allowing e.g. 
{} to override (but not None): return APICommander( @@ -178,10 +172,10 @@ def _raw_response_to_json( self, raw_response: httpx.Response, raise_api_errors: bool, - payload: Optional[Dict[str, Any]], - ) -> Dict[str, Any]: + payload: dict[str, Any] | None, + ) -> dict[str, Any]: # try to process the httpx raw response into a JSON or throw a failure - raw_response_json: Dict[str, Any] + raw_response_json: dict[str, Any] try: raw_response_json = cast( Dict[str, Any], @@ -212,15 +206,15 @@ def _raw_response_to_json( response_json = restore_from_api(raw_response_json) return response_json - def _compose_request_url(self, additional_path: Optional[str]) -> str: + def _compose_request_url(self, additional_path: str | None) -> str: if additional_path: return "/".join([self.full_path.rstrip("/"), additional_path.lstrip("/")]) else: return self.full_path def _encode_payload( - self, normalized_payload: Optional[Dict[str, Any]] - ) -> Optional[bytes]: + self, normalized_payload: dict[str, Any] | None + ) -> bytes | None: if normalized_payload is not None: return json.dumps( normalized_payload, @@ -234,8 +228,8 @@ def raw_request( self, *, http_method: str = HttpMethod.POST, - payload: Optional[Dict[str, Any]] = None, - additional_path: Optional[str] = None, + payload: dict[str, Any] | None = None, + additional_path: str | None = None, raise_api_errors: bool = True, timeout_info: TimeoutInfoWideType = None, ) -> httpx.Response: @@ -276,8 +270,8 @@ async def async_raw_request( self, *, http_method: str = HttpMethod.POST, - payload: Optional[Dict[str, Any]] = None, - additional_path: Optional[str] = None, + payload: dict[str, Any] | None = None, + additional_path: str | None = None, raise_api_errors: bool = True, timeout_info: TimeoutInfoWideType = None, ) -> httpx.Response: @@ -318,11 +312,11 @@ def request( self, *, http_method: str = HttpMethod.POST, - payload: Optional[Dict[str, Any]] = None, - additional_path: Optional[str] = None, + payload: dict[str, Any] | None = None, + 
additional_path: str | None = None, raise_api_errors: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: raw_response = self.raw_request( http_method=http_method, payload=payload, @@ -338,11 +332,11 @@ async def async_request( self, *, http_method: str = HttpMethod.POST, - payload: Optional[Dict[str, Any]] = None, - additional_path: Optional[str] = None, + payload: dict[str, Any] | None = None, + additional_path: str | None = None, raise_api_errors: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: raw_response = await self.async_raw_request( http_method=http_method, payload=payload, diff --git a/astrapy/api_options.py b/astrapy/api_options.py index 5b84179b..d0f0dc7e 100644 --- a/astrapy/api_options.py +++ b/astrapy/api_options.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Optional, TypeVar +from typing import TypeVar from astrapy.authentication import ( EmbeddingAPIKeyHeaderProvider, @@ -41,9 +41,9 @@ class BaseAPIOptions: much sense. """ - max_time_ms: Optional[int] = None + max_time_ms: int | None = None - def with_default(self: AO, default: Optional[BaseAPIOptions]) -> AO: + def with_default(self: AO, default: BaseAPIOptions | None) -> AO: """ Return a new instance created by completing this instance with a default API options object. @@ -70,7 +70,7 @@ def with_default(self: AO, default: Optional[BaseAPIOptions]) -> AO: else: return self - def with_override(self: AO, override: Optional[BaseAPIOptions]) -> AO: + def with_override(self: AO, override: BaseAPIOptions | None) -> AO: """ Return a new instance created by overriding the members of this instance with those taken from a supplied "override" API options object. 
diff --git a/astrapy/authentication.py b/astrapy/authentication.py index fbe0776d..8b009f43 100644 --- a/astrapy/authentication.py +++ b/astrapy/authentication.py @@ -16,7 +16,7 @@ import base64 from abc import ABC, abstractmethod -from typing import Any, Dict, Optional, Union +from typing import Any from astrapy.defaults import ( EMBEDDING_HEADER_API_KEY, @@ -28,7 +28,7 @@ ) -def coerce_token_provider(token: Optional[Union[str, TokenProvider]]) -> TokenProvider: +def coerce_token_provider(token: str | TokenProvider | None) -> TokenProvider: if isinstance(token, TokenProvider): return token else: @@ -36,7 +36,7 @@ def coerce_token_provider(token: Optional[Union[str, TokenProvider]]) -> TokenPr def coerce_embedding_headers_provider( - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]], + embedding_api_key: str | EmbeddingHeadersProvider | None, ) -> EmbeddingHeadersProvider: if isinstance(embedding_api_key, EmbeddingHeadersProvider): return embedding_api_key @@ -121,7 +121,7 @@ def __bool__(self) -> bool: return self.get_token() is not None @abstractmethod - def get_token(self) -> Union[str, None]: + def get_token(self) -> str | None: """ Produce a string for direct use as token in a subsequent API request, or None for no token. @@ -146,7 +146,7 @@ class StaticTokenProvider(TokenProvider): ... ) """ - def __init__(self, token: Union[str, None]) -> None: + def __init__(self, token: str | None) -> None: self.token = token def __repr__(self) -> str: @@ -155,7 +155,7 @@ def __repr__(self) -> str: else: return self.token - def get_token(self) -> Union[str, None]: + def get_token(self) -> str | None: return self.token @@ -230,7 +230,7 @@ def __bool__(self) -> bool: return self.get_headers() != {} @abstractmethod - def get_headers(self) -> Dict[str, str]: + def get_headers(self) -> dict[str, str]: """ Produce a dictionary for use as (part of) the headers in HTTP requests to the Data API. 
@@ -277,7 +277,7 @@ class EmbeddingAPIKeyHeaderProvider(EmbeddingHeadersProvider): ... ) """ - def __init__(self, embedding_api_key: Optional[str]) -> None: + def __init__(self, embedding_api_key: str | None) -> None: self.embedding_api_key = embedding_api_key def __repr__(self) -> str: @@ -286,7 +286,7 @@ def __repr__(self) -> str: else: return f'{self.__class__.__name__}("{redact_secret(self.embedding_api_key, 8)}")' - def get_headers(self) -> Dict[str, str]: + def get_headers(self) -> dict[str, str]: if self.embedding_api_key is not None: return {EMBEDDING_HEADER_API_KEY: self.embedding_api_key} else: @@ -347,7 +347,7 @@ def __repr__(self) -> str: f'embedding_secret_id="{redact_secret(self.embedding_secret_id, 6)}")' ) - def get_headers(self) -> Dict[str, str]: + def get_headers(self) -> dict[str, str]: return { EMBEDDING_HEADER_AWS_ACCESS_ID: self.embedding_access_id, EMBEDDING_HEADER_AWS_SECRET_ID: self.embedding_secret_id, diff --git a/astrapy/client.py b/astrapy/client.py index 1ae8439e..cb426263 100644 --- a/astrapy/client.py +++ b/astrapy/client.py @@ -16,7 +16,7 @@ import logging import re -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any from astrapy.admin import ( api_endpoint_parser, @@ -80,11 +80,11 @@ class DataAPIClient: def __init__( self, - token: Optional[Union[str, TokenProvider]] = None, + token: str | TokenProvider | None = None, *, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.token_provider = coerce_token_provider(token) self.environment = (environment or Environment.PROD).lower() @@ -96,12 +96,12 @@ def __init__( self._caller_version = caller_version def __repr__(self) -> str: - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'"{redact_secret(str(self.token_provider), 
15)}"' else: token_desc = None - env_desc: Optional[str] + env_desc: str | None if self.environment == Environment.PROD: env_desc = None else: @@ -139,10 +139,10 @@ def __getitem__(self, database_id_or_api_endpoint: str) -> Database: def _copy( self, *, - token: Optional[Union[str, TokenProvider]] = None, - environment: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | TokenProvider | None = None, + environment: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> DataAPIClient: return DataAPIClient( token=coerce_token_provider(token) or self.token_provider, @@ -154,9 +154,9 @@ def _copy( def with_options( self, *, - token: Optional[Union[str, TokenProvider]] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | TokenProvider | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> DataAPIClient: """ Create a clone of this DataAPIClient with some changed attributes. 
@@ -188,8 +188,8 @@ def with_options( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -212,15 +212,15 @@ def set_caller( def get_database( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> Database: """ Get a Database object from this client, for doing data-related work. @@ -341,15 +341,15 @@ def get_database( def get_async_database( self, - id: Optional[str] = None, + id: str | None = None, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - region: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - max_time_ms: Optional[int] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + region: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + max_time_ms: int | None = None, ) -> AsyncDatabase: """ Get an AsyncDatabase object from this client. 
@@ -373,10 +373,10 @@ def get_database_by_api_endpoint( self, api_endpoint: str, *, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> Database: """ Get a Database object from this client, for doing data-related work. @@ -469,10 +469,10 @@ def get_async_database_by_api_endpoint( self, api_endpoint: str, *, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: """ Get an AsyncDatabase object from this client, for doing data-related work. @@ -497,9 +497,9 @@ def get_async_database_by_api_endpoint( def get_admin( self, *, - token: Optional[Union[str, TokenProvider]] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> AstraDBAdmin: """ Get an AstraDBAdmin instance corresponding to this client, for diff --git a/astrapy/collection.py b/astrapy/collection.py index 203f206c..5d12fbca 100644 --- a/astrapy/collection.py +++ b/astrapy/collection.py @@ -22,13 +22,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, Iterable, - List, - Optional, - Tuple, - Type, - Union, ) import deprecation @@ -85,7 +79,7 @@ logger = logging.getLogger(__name__) -def _prepare_update_info(statuses: List[Dict[str, Any]]) -> Dict[str, Any]: +def _prepare_update_info(statuses: list[dict[str, Any]]) -> dict[str, Any]: reduced_status = { "matchedCount": sum( status["matchedCount"] for status in statuses if "matchedCount" in 
status @@ -116,11 +110,11 @@ def _prepare_update_info(statuses: List[Dict[str, Any]]) -> Dict[str, Any]: def _collate_vector_to_sort( - sort: Optional[SortType], - vector: Optional[VectorType], - vectorize: Optional[str], -) -> Optional[SortType]: - _vsort: Dict[str, Any] + sort: SortType | None, + vector: VectorType | None, + vectorize: str | None, +) -> SortType | None: + _vsort: dict[str, Any] if vector is None: if vectorize is None: return sort @@ -147,7 +141,7 @@ def _collate_vector_to_sort( ) -def _is_vector_sort(sort: Optional[SortType]) -> bool: +def _is_vector_sort(sort: SortType | None) -> bool: if sort is None: return False else: @@ -155,7 +149,7 @@ def _is_vector_sort(sort: Optional[SortType]) -> bool: def _collate_vector_to_document( - document0: DocumentType, vector: Optional[VectorType], vectorize: Optional[str] + document0: DocumentType, vector: VectorType | None, vectorize: str | None ) -> DocumentType: if vector is None: if vectorize is None: @@ -191,9 +185,9 @@ def _collate_vector_to_document( def _collate_vectors_to_documents( documents: Iterable[DocumentType], - vectors: Optional[Iterable[Optional[VectorType]]], - vectorize: Optional[Iterable[Optional[str]]], -) -> List[DocumentType]: + vectors: Iterable[VectorType | None] | None, + vectorize: Iterable[str | None] | None, +) -> list[DocumentType]: if vectors is None and vectorize is None: return list(documents) else: @@ -264,10 +258,10 @@ def __init__( database: Database, name: str, *, - namespace: Optional[str] = None, - api_options: Optional[CollectionAPIOptions] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + namespace: str | None = None, + api_options: CollectionAPIOptions | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: if api_options is None: self.api_options = CollectionAPIOptions() @@ -350,12 +344,12 @@ def _get_api_commander(self) -> APICommander: def _copy( self, *, - database: Optional[Database] = 
None, - name: Optional[str] = None, - namespace: Optional[str] = None, - api_options: Optional[CollectionAPIOptions] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + database: Database | None = None, + name: str | None = None, + namespace: str | None = None, + api_options: CollectionAPIOptions | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> Collection: return Collection( database=database or self.database._copy(), @@ -369,11 +363,11 @@ def _copy( def with_options( self, *, - name: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + name: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> Collection: """ Create a clone of this collection with some changed attributes. @@ -428,13 +422,13 @@ def with_options( def to_async( self, *, - database: Optional[AsyncDatabase] = None, - name: Optional[str] = None, - namespace: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + database: AsyncDatabase | None = None, + name: str | None = None, + namespace: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncCollection: """ Create an AsyncCollection from this one. 
Save for the arguments @@ -494,8 +488,8 @@ def to_async( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -515,7 +509,7 @@ def set_caller( self.caller_version = caller_version or self.caller_version self._api_commander = self._get_api_commander() - def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions: + def options(self, *, max_time_ms: int | None = None) -> CollectionOptions: """ Get the collection options, i.e. its configuration as read from the database. @@ -638,9 +632,9 @@ def insert_one( self, document: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - max_time_ms: Optional[int] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + max_time_ms: int | None = None, ) -> InsertOneResult: """ Insert a single document in the collection in an atomic operation. @@ -726,12 +720,12 @@ def insert_many( self, documents: Iterable[DocumentType], *, - vectors: Optional[Iterable[Optional[VectorType]]] = None, - vectorize: Optional[Iterable[Optional[str]]] = None, + vectors: Iterable[VectorType | None] | None = None, + vectorize: Iterable[str | None] | None = None, ordered: bool = False, - chunk_size: Optional[int] = None, - concurrency: Optional[int] = None, - max_time_ms: Optional[int] = None, + chunk_size: int | None = None, + concurrency: int | None = None, + max_time_ms: int | None = None, ) -> InsertManyResult: """ Insert a list of documents into the collection. 
@@ -840,11 +834,11 @@ def insert_many( _documents = _collate_vectors_to_documents(documents, vectors, vectorize) _max_time_ms = max_time_ms or self.api_options.max_time_ms logger.info(f"inserting {len(_documents)} documents in '{self.name}'") - raw_results: List[Dict[str, Any]] = [] + raw_results: list[dict[str, Any]] = [] timeout_manager = MultiCallTimeoutManager(overall_max_time_ms=_max_time_ms) if ordered: options = {"ordered": True} - inserted_ids: List[Any] = [] + inserted_ids: list[Any] = [] for i in range(0, len(_documents), _chunk_size): im_payload = { "insertMany": { @@ -894,8 +888,8 @@ def insert_many( with ThreadPoolExecutor(max_workers=_concurrency) as executor: def _chunk_insertor( - document_chunk: List[Dict[str, Any]] - ) -> Dict[str, Any]: + document_chunk: list[dict[str, Any]], + ) -> dict[str, Any]: im_payload = { "insertMany": { "documents": document_chunk, @@ -971,17 +965,17 @@ def _chunk_insertor( def find( self, - filter: Optional[FilterType] = None, + filter: FilterType | None = None, *, - projection: Optional[ProjectionType] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - include_similarity: Optional[bool] = None, - include_sort_vector: Optional[bool] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, + projection: ProjectionType | None = None, + skip: int | None = None, + limit: int | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + include_similarity: bool | None = None, + include_sort_vector: bool | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, ) -> Cursor: """ Find documents on the collection, matching a certain provided filter. 
@@ -1184,15 +1178,15 @@ def find( def find_one( self, - filter: Optional[FilterType] = None, + filter: FilterType | None = None, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - include_similarity: Optional[bool] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + include_similarity: bool | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Run a search, returning the first document in the collection that matches provided filters, if any is found. @@ -1293,9 +1287,9 @@ def distinct( self, key: str, *, - filter: Optional[FilterType] = None, - max_time_ms: Optional[int] = None, - ) -> List[Any]: + filter: FilterType | None = None, + max_time_ms: int | None = None, + ) -> list[Any]: """ Return a list of the unique values of `key` across the documents in the collection that match the provided filter. @@ -1373,7 +1367,7 @@ def count_documents( filter: FilterType, *, upper_bound: int, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> int: """ Count the documents in the collection matching the specified filter. @@ -1453,7 +1447,7 @@ def count_documents( def estimated_document_count( self, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> int: """ Query the API server for an estimate of the document count in the collection. 
@@ -1472,7 +1466,7 @@ def estimated_document_count( 35700 """ _max_time_ms = max_time_ms or self.api_options.max_time_ms - ed_payload: Dict[str, Any] = {"estimatedDocumentCount": {}} + ed_payload: dict[str, Any] = {"estimatedDocumentCount": {}} logger.info(f"estimatedDocumentCount on '{self.name}'") ed_response = self._api_commander.request( payload=ed_payload, @@ -1493,14 +1487,14 @@ def find_one_and_replace( filter: FilterType, replacement: DocumentType, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, return_document: str = ReturnDocument.BEFORE, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document on the collection and replace it entirely with a new one, optionally inserting a new one if no match is found. 
@@ -1641,11 +1635,11 @@ def replace_one( filter: FilterType, replacement: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Replace a single document on the collection with a new one, @@ -1742,16 +1736,16 @@ def replace_one( def find_one_and_update( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, return_document: str = ReturnDocument.BEFORE, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document on the collection and update it as requested, optionally inserting a new one if no match is found. 
@@ -1896,13 +1890,13 @@ def find_one_and_update( def update_one( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Update a single document on the collection as requested, @@ -2003,10 +1997,10 @@ def update_one( def update_many( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Apply an update operations to all documents matching a condition, @@ -2066,9 +2060,9 @@ def update_many( api_options = { "upsert": upsert, } - page_state_options: Dict[str, str] = {} - um_responses: List[Dict[str, Any]] = [] - um_statuses: List[Dict[str, Any]] = [] + page_state_options: dict[str, str] = {} + um_responses: list[dict[str, Any]] = [] + um_statuses: list[dict[str, Any]] = [] must_proceed = True _max_time_ms = max_time_ms or self.api_options.max_time_ms logger.info(f"starting update_many on '{self.name}'") @@ -2134,12 +2128,12 @@ def find_one_and_delete( self, filter: FilterType, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document in the collection and delete it. The deleted document, however, is the return value of the method. 
@@ -2251,10 +2245,10 @@ def delete_one( self, filter: FilterType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, ) -> DeleteResult: """ Delete one document matching a provided filter. @@ -2356,7 +2350,7 @@ def delete_many( self, filter: FilterType, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> DeleteResult: """ Delete all documents matching a provided filter. @@ -2400,7 +2394,7 @@ def delete_many( collection is devoid of matches. An exception is the `filter={}` case, whereby the operation is atomic. """ - dm_responses: List[Dict[str, Any]] = [] + dm_responses: list[dict[str, Any]] = [] deleted_count = 0 must_proceed = True _max_time_ms = max_time_ms or self.api_options.max_time_ms @@ -2450,7 +2444,7 @@ def delete_many( current_version=__version__, details="Use delete_many with filter={} instead.", ) - def delete_all(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + def delete_all(self, *, max_time_ms: int | None = None) -> dict[str, Any]: """ Delete all documents in a collection. 
@@ -2488,8 +2482,8 @@ def bulk_write( requests: Iterable[BaseOperation], *, ordered: bool = False, - concurrency: Optional[int] = None, - max_time_ms: Optional[int] = None, + concurrency: int | None = None, + max_time_ms: int | None = None, ) -> BulkWriteResult: """ Execute an arbitrary amount of operations such as inserts, updates, deletes @@ -2553,7 +2547,7 @@ def bulk_write( logger.info(f"startng a bulk write on '{self.name}'") timeout_manager = MultiCallTimeoutManager(overall_max_time_ms=_max_time_ms) if ordered: - bulk_write_results: List[BulkWriteResult] = [] + bulk_write_results: list[BulkWriteResult] = [] for operation_i, operation in enumerate(requests): try: this_bw_result = operation.execute( @@ -2600,7 +2594,7 @@ def bulk_write( def _execute_as_either( operation: BaseOperation, operation_i: int - ) -> Tuple[Optional[BulkWriteResult], Optional[DataAPIResponseException]]: + ) -> tuple[BulkWriteResult | None, DataAPIResponseException | None]: try: ex_result = operation.execute( self, @@ -2660,7 +2654,7 @@ def _execute_as_either( logger.info(f"finished a bulk write on '{self.name}'") return reduce_bulk_write_results(bulk_write_successes) - def drop(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + def drop(self, *, max_time_ms: int | None = None) -> dict[str, Any]: """ Drop the collection, i.e. delete it from the database along with all the documents it contains. @@ -2703,11 +2697,11 @@ def drop(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: def command( self, - body: Dict[str, Any], + body: dict[str, Any], *, raise_api_errors: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Send a POST request to the Data API for this collection with an arbitrary, caller-provided payload. 
@@ -2790,10 +2784,10 @@ def __init__( database: AsyncDatabase, name: str, *, - namespace: Optional[str] = None, - api_options: Optional[CollectionAPIOptions] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + namespace: str | None = None, + api_options: CollectionAPIOptions | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: if api_options is None: self.api_options = CollectionAPIOptions() @@ -2880,9 +2874,9 @@ async def __aenter__(self) -> AsyncCollection: async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: if self._api_commander is not None: await self._api_commander.__aexit__( @@ -2894,12 +2888,12 @@ async def __aexit__( def _copy( self, *, - database: Optional[AsyncDatabase] = None, - name: Optional[str] = None, - namespace: Optional[str] = None, - api_options: Optional[CollectionAPIOptions] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + database: AsyncDatabase | None = None, + name: str | None = None, + namespace: str | None = None, + api_options: CollectionAPIOptions | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncCollection: return AsyncCollection( database=database or self.database._copy(), @@ -2913,11 +2907,11 @@ def _copy( def with_options( self, *, - name: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + name: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, + caller_name: str | None 
= None, + caller_version: str | None = None, ) -> AsyncCollection: """ Create a clone of this collection with some changed attributes. @@ -2972,13 +2966,13 @@ def with_options( def to_sync( self, *, - database: Optional[Database] = None, - name: Optional[str] = None, - namespace: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + database: Database | None = None, + name: str | None = None, + namespace: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> Collection: """ Create a Collection from this one. Save for the arguments @@ -3038,8 +3032,8 @@ def to_sync( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -3059,7 +3053,7 @@ def set_caller( self.caller_version = caller_version or self.caller_version self._api_commander = self._get_api_commander() - async def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions: + async def options(self, *, max_time_ms: int | None = None) -> CollectionOptions: """ Get the collection options, i.e. its configuration as read from the database. @@ -3184,9 +3178,9 @@ async def insert_one( self, document: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - max_time_ms: Optional[int] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + max_time_ms: int | None = None, ) -> InsertOneResult: """ Insert a single document in the collection in an atomic operation. 
@@ -3279,12 +3273,12 @@ async def insert_many( self, documents: Iterable[DocumentType], *, - vectors: Optional[Iterable[Optional[VectorType]]] = None, - vectorize: Optional[Iterable[Optional[str]]] = None, + vectors: Iterable[VectorType | None] | None = None, + vectorize: Iterable[str | None] | None = None, ordered: bool = False, - chunk_size: Optional[int] = None, - concurrency: Optional[int] = None, - max_time_ms: Optional[int] = None, + chunk_size: int | None = None, + concurrency: int | None = None, + max_time_ms: int | None = None, ) -> InsertManyResult: """ Insert a list of documents into the collection. @@ -3406,11 +3400,11 @@ async def insert_many( _documents = _collate_vectors_to_documents(documents, vectors, vectorize) _max_time_ms = max_time_ms or self.api_options.max_time_ms logger.info(f"inserting {len(_documents)} documents in '{self.name}'") - raw_results: List[Dict[str, Any]] = [] + raw_results: list[dict[str, Any]] = [] timeout_manager = MultiCallTimeoutManager(overall_max_time_ms=_max_time_ms) if ordered: options = {"ordered": True} - inserted_ids: List[Any] = [] + inserted_ids: list[Any] = [] for i in range(0, len(_documents), _chunk_size): im_payload = { "insertMany": { @@ -3460,8 +3454,8 @@ async def insert_many( sem = asyncio.Semaphore(_concurrency) async def concurrent_insert_chunk( - document_chunk: List[DocumentType], - ) -> Dict[str, Any]: + document_chunk: list[DocumentType], + ) -> dict[str, Any]: async with sem: im_payload = { "insertMany": { @@ -3527,17 +3521,17 @@ async def concurrent_insert_chunk( def find( self, - filter: Optional[FilterType] = None, + filter: FilterType | None = None, *, - projection: Optional[ProjectionType] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - include_similarity: Optional[bool] = None, - include_sort_vector: Optional[bool] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, + 
projection: ProjectionType | None = None, + skip: int | None = None, + limit: int | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + include_similarity: bool | None = None, + include_sort_vector: bool | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, ) -> AsyncCursor: """ Find documents on the collection, matching a certain provided filter. @@ -3750,15 +3744,15 @@ def find( async def find_one( self, - filter: Optional[FilterType] = None, + filter: FilterType | None = None, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - include_similarity: Optional[bool] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + include_similarity: bool | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Run a search, returning the first document in the collection that matches provided filters, if any is found. @@ -3879,9 +3873,9 @@ async def distinct( self, key: str, *, - filter: Optional[FilterType] = None, - max_time_ms: Optional[int] = None, - ) -> List[Any]: + filter: FilterType | None = None, + max_time_ms: int | None = None, + ) -> list[Any]: """ Return a list of the unique values of `key` across the documents in the collection that match the provided filter. @@ -3967,7 +3961,7 @@ async def count_documents( filter: FilterType, *, upper_bound: int, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> int: """ Count the documents in the collection matching the specified filter. 
@@ -4052,7 +4046,7 @@ async def count_documents( async def estimated_document_count( self, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> int: """ Query the API server for an estimate of the document count in the collection. @@ -4071,7 +4065,7 @@ async def estimated_document_count( 35700 """ _max_time_ms = max_time_ms or self.api_options.max_time_ms - ed_payload: Dict[str, Any] = {"estimatedDocumentCount": {}} + ed_payload: dict[str, Any] = {"estimatedDocumentCount": {}} logger.info(f"estimatedDocumentCount on '{self.name}'") ed_response = await self._api_commander.async_request( payload=ed_payload, @@ -4092,14 +4086,14 @@ async def find_one_and_replace( filter: FilterType, replacement: DocumentType, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, return_document: str = ReturnDocument.BEFORE, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document on the collection and replace it entirely with a new one, optionally inserting a new one if no match is found. 
@@ -4250,11 +4244,11 @@ async def replace_one( filter: FilterType, replacement: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Replace a single document on the collection with a new one, @@ -4371,16 +4365,16 @@ async def replace_one( async def find_one_and_update( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, return_document: str = ReturnDocument.BEFORE, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document on the collection and update it as requested, optionally inserting a new one if no match is found. 
@@ -4535,13 +4529,13 @@ async def find_one_and_update( async def update_one( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Update a single document on the collection as requested, @@ -4661,10 +4655,10 @@ async def update_one( async def update_many( self, filter: FilterType, - update: Dict[str, Any], + update: dict[str, Any], *, upsert: bool = False, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> UpdateResult: """ Apply an update operations to all documents matching a condition, @@ -4735,9 +4729,9 @@ async def update_many( api_options = { "upsert": upsert, } - page_state_options: Dict[str, str] = {} - um_responses: List[Dict[str, Any]] = [] - um_statuses: List[Dict[str, Any]] = [] + page_state_options: dict[str, str] = {} + um_responses: list[dict[str, Any]] = [] + um_statuses: list[dict[str, Any]] = [] must_proceed = True _max_time_ms = max_time_ms or self.api_options.max_time_ms logger.info(f"starting update_many on '{self.name}'") @@ -4803,12 +4797,12 @@ async def find_one_and_delete( self, filter: FilterType, *, - projection: Optional[ProjectionType] = None, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, - ) -> Union[DocumentType, None]: + projection: ProjectionType | None = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, + ) -> DocumentType | None: """ Find a document in the collection and delete it. The deleted document, however, is the return value of the method. 
@@ -4928,10 +4922,10 @@ async def delete_one( self, filter: FilterType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, - max_time_ms: Optional[int] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, + max_time_ms: int | None = None, ) -> DeleteResult: """ Delete one document matching a provided filter. @@ -5035,7 +5029,7 @@ async def delete_many( self, filter: FilterType, *, - max_time_ms: Optional[int] = None, + max_time_ms: int | None = None, ) -> DeleteResult: """ Delete all documents matching a provided filter. @@ -5084,7 +5078,7 @@ async def delete_many( collection is devoid of matches. An exception is the `filter={}` case, whereby the operation is atomic. """ - dm_responses: List[Dict[str, Any]] = [] + dm_responses: list[dict[str, Any]] = [] deleted_count = 0 must_proceed = True _max_time_ms = max_time_ms or self.api_options.max_time_ms @@ -5134,7 +5128,7 @@ async def delete_many( current_version=__version__, details="Use delete_many with filter={} instead.", ) - async def delete_all(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + async def delete_all(self, *, max_time_ms: int | None = None) -> dict[str, Any]: """ Delete all documents in a collection. 
@@ -5179,8 +5173,8 @@ async def bulk_write( requests: Iterable[AsyncBaseOperation], *, ordered: bool = False, - concurrency: Optional[int] = None, - max_time_ms: Optional[int] = None, + concurrency: int | None = None, + max_time_ms: int | None = None, ) -> BulkWriteResult: """ Execute an arbitrary amount of operations such as inserts, updates, deletes @@ -5260,7 +5254,7 @@ async def bulk_write( logger.info(f"startng a bulk write on '{self.name}'") timeout_manager = MultiCallTimeoutManager(overall_max_time_ms=_max_time_ms) if ordered: - bulk_write_results: List[BulkWriteResult] = [] + bulk_write_results: list[BulkWriteResult] = [] for operation_i, operation in enumerate(requests): try: this_bw_result = await operation.execute( @@ -5304,12 +5298,11 @@ async def bulk_write( logger.info(f"finished a bulk write on '{self.name}'") return full_bw_result else: - sem = asyncio.Semaphore(_concurrency) async def _concurrent_execute_as_either( operation: AsyncBaseOperation, operation_i: int - ) -> Tuple[Optional[BulkWriteResult], Optional[DataAPIResponseException]]: + ) -> tuple[BulkWriteResult | None, DataAPIResponseException | None]: async with sem: try: ex_result = await operation.execute( @@ -5360,7 +5353,7 @@ async def _concurrent_execute_as_either( logger.info(f"finished a bulk write on '{self.name}'") return reduce_bulk_write_results(bulk_write_successes) - async def drop(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: + async def drop(self, *, max_time_ms: int | None = None) -> dict[str, Any]: """ Drop the collection, i.e. delete it from the database along with all the documents it contains. 
@@ -5410,11 +5403,11 @@ async def drop(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]: async def command( self, - body: Dict[str, Any], + body: dict[str, Any], *, raise_api_errors: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Send a POST request to the Data API for this collection with an arbitrary, caller-provided payload. diff --git a/astrapy/constants.py b/astrapy/constants.py index 9a881805..84143b8d 100644 --- a/astrapy/constants.py +++ b/astrapy/constants.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Dict, Iterable, Optional, Union +from typing import Any, Dict, Iterable, Union from astrapy.defaults import ( DATA_API_ENVIRONMENT_CASSANDRA, @@ -40,8 +40,8 @@ def normalize_optional_projection( - projection: Optional[ProjectionType], -) -> Optional[Dict[str, Union[bool, Dict[str, Union[int, Iterable[int]]]]]]: + projection: ProjectionType | None, +) -> dict[str, bool | dict[str, int | Iterable[int]]] | None: if projection: if isinstance(projection, dict): # already a dictionary diff --git a/astrapy/core/api.py b/astrapy/core/api.py index eea793c3..1de8d642 100644 --- a/astrapy/core/api.py +++ b/astrapy/core/api.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, Union, cast +from typing import Any, cast import httpx @@ -27,7 +27,7 @@ class APIRequestError(ValueError): def __init__( - self, response: httpx.Response, payload: Optional[Dict[str, Any]] + self, response: httpx.Response, payload: dict[str, Any] | None ) -> None: super().__init__(response.text) @@ -42,15 +42,15 @@ def raw_api_request( client: httpx.Client, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], - caller_name: Optional[str], - caller_version: Optional[str], 
- timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> httpx.Response: return make_request( client=client, @@ -71,7 +71,7 @@ def raw_api_request( def process_raw_api_response( raw_response: httpx.Response, skip_error_check: bool, - json_data: Optional[Dict[str, Any]], + json_data: dict[str, Any] | None, ) -> API_RESPONSE: # In case of other successful responses, parse the JSON body. try: @@ -95,16 +95,16 @@ def api_request( client: httpx.Client, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, skip_error_check: bool, - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> API_RESPONSE: raw_response = raw_api_request( client=client, @@ -131,15 +131,15 @@ async def async_raw_api_request( client: httpx.AsyncClient, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + 
additional_headers: dict[str, str], ) -> httpx.Response: return await amake_request( client=client, @@ -160,7 +160,7 @@ async def async_raw_api_request( async def async_process_raw_api_response( raw_response: httpx.Response, skip_error_check: bool, - json_data: Optional[Dict[str, Any]], + json_data: dict[str, Any] | None, ) -> API_RESPONSE: # In case of other successful responses, parse the JSON body. try: @@ -184,16 +184,16 @@ async def async_api_request( client: httpx.AsyncClient, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, skip_error_check: bool, - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> API_RESPONSE: raw_response = await async_raw_api_request( client=client, diff --git a/astrapy/core/core_types.py b/astrapy/core/core_types.py index 0b2f8c56..ffc61c87 100644 --- a/astrapy/core/core_types.py +++ b/astrapy/core/core_types.py @@ -36,9 +36,9 @@ # This is for the (partialed, if necessary) functions that can be "paginated". class PaginableRequestMethod(Protocol): - def __call__(self, options: Dict[str, Any]) -> API_RESPONSE: ... + def __call__(self, options: dict[str, Any]) -> API_RESPONSE: ... # This is for the (partialed, if necessary) async functions that can be "paginated". class AsyncPaginableRequestMethod(Protocol): - async def __call__(self, options: Dict[str, Any]) -> API_RESPONSE: ... + async def __call__(self, options: dict[str, Any]) -> API_RESPONSE: ... 
diff --git a/astrapy/core/db.py b/astrapy/core/db.py index 88bdbfee..3a53b5ff 100644 --- a/astrapy/core/db.py +++ b/astrapy/core/db.py @@ -27,12 +27,8 @@ from typing import ( Any, Callable, - Dict, Iterator, List, - Optional, - Tuple, - Type, Union, cast, ) @@ -71,10 +67,10 @@ def __init__( self, prefetched: int, request_method: PaginableRequestMethod, - options: Optional[Dict[str, Any]], - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + options: dict[str, Any] | None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ): - self.queue: queue.Queue[Optional[API_DOC]] = queue.Queue(prefetched) + self.queue: queue.Queue[API_DOC | None] = queue.Queue(prefetched) self.request_method = request_method self.options = options self.raw_response_callback = raw_response_callback @@ -93,8 +89,8 @@ def __iter__(self) -> Iterator[API_DOC]: @staticmethod def queue_put( - q: queue.Queue[Optional[API_DOC]], - item: Optional[API_DOC], + q: queue.Queue[API_DOC | None], + item: API_DOC | None, stop: threading.Event, ) -> None: while not stop.is_set(): @@ -139,13 +135,13 @@ class AstraDBCollection: def __init__( self, collection_name: str, - astra_db: Optional[AstraDB] = None, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - additional_headers: Dict[str, str] = {}, + astra_db: AstraDB | None = None, + token: str | None = None, + api_endpoint: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + additional_headers: dict[str, str] = {}, ) -> None: """ Initialize an AstraDBCollection instance. 
@@ -188,8 +184,8 @@ def __init__( # Set the remaining instance attributes self.astra_db = astra_db - self.caller_name: Optional[str] = self.astra_db.caller_name - self.caller_version: Optional[str] = self.astra_db.caller_version + self.caller_name: str | None = self.astra_db.caller_name + self.caller_version: str | None = self.astra_db.caller_version self.additional_headers = additional_headers self.collection_name = collection_name self.base_path: str = f"{self.astra_db.base_path}/{self.collection_name}" @@ -214,15 +210,15 @@ def __eq__(self, other: Any) -> bool: def copy( self, *, - collection_name: Optional[str] = None, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - additional_headers: Optional[Dict[str, str]] = None, + collection_name: str | None = None, + token: str | None = None, + api_endpoint: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + additional_headers: dict[str, str] | None = None, ) -> AstraDBCollection: return AstraDBCollection( collection_name=collection_name or self.collection_name, @@ -251,8 +247,8 @@ def to_async(self) -> AsyncAstraDBCollection: def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.astra_db.set_caller( caller_name=caller_name, @@ -264,9 +260,9 @@ def set_caller( def _request( self, method: str = http_methods.POST, - path: Optional[str] = None, - json_data: Optional[Dict[str, Any]] = None, - url_params: Optional[Dict[str, Any]] = None, + path: str | None = None, + json_data: dict[str, Any] | None = None, + url_params: dict[str, Any] | None = None, 
skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -289,7 +285,7 @@ def _request( return response def post_raw_request( - self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None + self, body: dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return self._request( method=http_methods.POST, @@ -300,10 +296,10 @@ def post_raw_request( def _get( self, - path: Optional[str] = None, - options: Optional[Dict[str, Any]] = None, + path: str | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Optional[API_RESPONSE]: + ) -> API_RESPONSE | None: full_path = f"{self.base_path}/{path}" if path else self.base_path response = self._request( method=http_methods.GET, @@ -317,8 +313,8 @@ def _get( def _put( self, - path: Optional[str] = None, - document: Optional[API_RESPONSE] = None, + path: str | None = None, + document: API_RESPONSE | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path @@ -332,8 +328,8 @@ def _put( def _post( self, - path: Optional[str] = None, - document: Optional[API_DOC] = None, + path: str | None = None, + document: API_DOC | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path @@ -346,8 +342,8 @@ def _post( return response def _recast_as_sort_projection( - self, vector: List[float], fields: Optional[List[str]] = None - ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + self, vector: list[float], fields: list[str] | None = None + ) -> tuple[dict[str, Any], dict[str, Any] | None]: """ Given a vector and optionally a list of fields, reformulate them as a sort, projection pair for regular @@ -362,7 +358,7 @@ def _recast_as_sort_projection( raise ValueError("Please use the `include_similarity` parameter") # Build the new vector parameter - sort: 
Dict[str, Any] = {"$vector": vector} + sort: dict[str, Any] = {"$vector": vector} # Build the new fields parameter # Note: do not leave projection={}, make it None @@ -375,8 +371,8 @@ def _recast_as_sort_projection( return sort, projection def get( - self, path: Optional[str] = None, timeout_info: TimeoutInfoWideType = None - ) -> Optional[API_RESPONSE]: + self, path: str | None = None, timeout_info: TimeoutInfoWideType = None + ) -> API_RESPONSE | None: """ Retrieve a document from the collection by its path. @@ -393,10 +389,10 @@ def get( def find( self, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -428,14 +424,14 @@ def find( def vector_find( self, - vector: List[float], + vector: list[float], *, limit: int, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> List[API_DOC]: + ) -> list[API_DOC]: """ Perform a vector-based search in the collection. @@ -480,8 +476,8 @@ def vector_find( def paginate( *, request_method: PaginableRequestMethod, - options: Optional[Dict[str, Any]], - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + options: dict[str, Any] | None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ) -> Iterator[API_DOC]: """ Generate paginated results for a given database query method. 
@@ -517,13 +513,13 @@ def paginate( def paginated_find( self, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - prefetched: Optional[int] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + prefetched: int | None = None, timeout_info: TimeoutInfoWideType = None, - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ) -> Iterator[API_DOC]: """ Perform a paginated search in the collection. @@ -569,9 +565,9 @@ def paginated_find( def pop( self, - filter: Dict[str, Any], - pop: Dict[str, Any], - options: Dict[str, Any], + filter: dict[str, Any], + pop: dict[str, Any], + options: dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -606,9 +602,9 @@ def pop( def push( self, - filter: Dict[str, Any], - push: Dict[str, Any], - options: Dict[str, Any], + filter: dict[str, Any], + push: dict[str, Any], + options: dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -643,12 +639,12 @@ def push( def find_one_and_replace( self, - replacement: Dict[str, Any], + replacement: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -686,13 +682,13 @@ def find_one_and_replace( def vector_find_one_and_replace( self, - vector: List[float], - replacement: Dict[str, Any], + vector: list[float], + replacement: dict[str, Any], *, 
- filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search and replace the first matched document. @@ -727,11 +723,11 @@ def vector_find_one_and_replace( def find_one_and_update( self, - update: Dict[str, Any], - sort: Optional[Dict[str, Any]] = {}, - filter: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + update: dict[str, Any], + sort: dict[str, Any] | None = {}, + filter: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -769,13 +765,13 @@ def find_one_and_update( def vector_find_one_and_update( self, - vector: List[float], - update: Dict[str, Any], + vector: list[float], + update: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search and update the first matched document. 
@@ -811,9 +807,9 @@ def vector_find_one_and_update( def find_one_and_delete( self, - sort: Optional[Dict[str, Any]] = {}, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + sort: dict[str, Any] | None = {}, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -847,7 +843,7 @@ def find_one_and_delete( return response def count_documents( - self, filter: Dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None + self, filter: dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Count documents matching a given predicate (expressed as filter). @@ -875,10 +871,10 @@ def count_documents( def find_one( self, - filter: Optional[Dict[str, Any]] = {}, - projection: Optional[Dict[str, Any]] = {}, - sort: Optional[Dict[str, Any]] = {}, - options: Optional[Dict[str, Any]] = {}, + filter: dict[str, Any] | None = {}, + projection: dict[str, Any] | None = {}, + sort: dict[str, Any] | None = {}, + options: dict[str, Any] | None = {}, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -914,13 +910,13 @@ def find_one( def vector_find_one( self, - vector: List[float], + vector: list[float], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search to find a single document in the collection. 
@@ -986,8 +982,8 @@ def insert_one( def insert_many( self, - documents: List[API_DOC], - options: Optional[Dict[str, Any]] = None, + documents: list[API_DOC], + options: dict[str, Any] | None = None, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -1023,13 +1019,13 @@ def insert_many( def chunked_insert_many( self, - documents: List[API_DOC], - options: Optional[Dict[str, Any]] = None, + documents: list[API_DOC], + options: dict[str, Any] | None = None, partial_failures_allowed: bool = False, chunk_size: int = DEFAULT_INSERT_NUM_DOCUMENTS, concurrency: int = 1, timeout_info: TimeoutInfoWideType = None, - ) -> List[Union[API_RESPONSE, Exception]]: + ) -> list[API_RESPONSE | Exception]: """ Insert multiple documents into the collection, handling chunking and optionally with concurrent insertions. @@ -1054,7 +1050,7 @@ def chunked_insert_many( This is a list of individual responses from the API: the caller will need to inspect them all, e.g. to collate the inserted IDs. """ - results: List[Union[API_RESPONSE, Exception]] = [] + results: list[API_RESPONSE | Exception] = [] # Raise a warning if ordered and concurrency if options and options.get("ordered") is True and concurrency > 1: @@ -1110,11 +1106,11 @@ def chunked_insert_many( def update_one( self, - filter: Dict[str, Any], - update: Dict[str, Any], - sort: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + update: dict[str, Any], + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, ) -> API_RESPONSE: """ Update a single document in the collection. 
@@ -1149,9 +1145,9 @@ def update_one( def update_many( self, - filter: Dict[str, Any], - update: Dict[str, Any], - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + update: dict[str, Any], + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1204,7 +1200,7 @@ def replace( def delete_one( self, id: str, - sort: Optional[Dict[str, Any]] = None, + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1236,8 +1232,8 @@ def delete_one( def delete_one_by_predicate( self, - filter: Dict[str, Any], - sort: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1269,7 +1265,7 @@ def delete_one_by_predicate( def delete_many( self, - filter: Dict[str, Any], + filter: dict[str, Any], skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -1304,8 +1300,8 @@ def delete_many( return response def chunked_delete_many( - self, filter: Dict[str, Any], timeout_info: TimeoutInfoWideType = None - ) -> List[API_RESPONSE]: + self, filter: dict[str, Any], timeout_info: TimeoutInfoWideType = None + ) -> list[API_RESPONSE]: """ Delete many documents from the collection based on a filter condition, chaining several API calls until exhaustion of the documents to delete. @@ -1436,7 +1432,7 @@ def upsert_many( concurrency: int = 1, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, - ) -> List[Union[str, Exception]]: + ) -> list[str | Exception]: """ Emulate an upsert operation for multiple documents in the collection. @@ -1457,7 +1453,7 @@ def upsert_many( Returns: List[Union[str, Exception]]: A list of "_id"s of the inserted or updated documents. 
""" - results: List[Union[str, Exception]] = [] + results: list[str | Exception] = [] # If concurrency is 1, no need for thread pool if concurrency == 1: @@ -1493,13 +1489,13 @@ class AsyncAstraDBCollection: def __init__( self, collection_name: str, - astra_db: Optional[AsyncAstraDB] = None, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - additional_headers: Dict[str, str] = {}, + astra_db: AsyncAstraDB | None = None, + token: str | None = None, + api_endpoint: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + additional_headers: dict[str, str] = {}, ) -> None: """ Initialize an AstraDBCollection instance. @@ -1542,8 +1538,8 @@ def __init__( # Set the remaining instance attributes self.astra_db: AsyncAstraDB = astra_db - self.caller_name: Optional[str] = self.astra_db.caller_name - self.caller_version: Optional[str] = self.astra_db.caller_version + self.caller_name: str | None = self.astra_db.caller_name + self.caller_version: str | None = self.astra_db.caller_version self.additional_headers = additional_headers self.client = astra_db.client self.collection_name = collection_name @@ -1569,15 +1565,15 @@ def __eq__(self, other: Any) -> bool: def copy( self, *, - collection_name: Optional[str] = None, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - additional_headers: Optional[Dict[str, str]] = None, + collection_name: str | None = None, + token: str | None = None, + api_endpoint: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = 
None, + additional_headers: dict[str, str] | None = None, ) -> AsyncAstraDBCollection: return AsyncAstraDBCollection( collection_name=collection_name or self.collection_name, @@ -1597,8 +1593,8 @@ def copy( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.astra_db.set_caller( caller_name=caller_name, @@ -1619,9 +1615,9 @@ def to_sync(self) -> AstraDBCollection: async def _request( self, method: str = http_methods.POST, - path: Optional[str] = None, - json_data: Optional[Dict[str, Any]] = None, - url_params: Optional[Dict[str, Any]] = None, + path: str | None = None, + json_data: dict[str, Any] | None = None, + url_params: dict[str, Any] | None = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, **kwargs: Any, @@ -1645,7 +1641,7 @@ async def _request( return response async def post_raw_request( - self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None + self, body: dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return await self._request( method=http_methods.POST, @@ -1656,10 +1652,10 @@ async def post_raw_request( async def _get( self, - path: Optional[str] = None, - options: Optional[Dict[str, Any]] = None, + path: str | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Optional[API_RESPONSE]: + ) -> API_RESPONSE | None: full_path = f"{self.base_path}/{path}" if path else self.base_path response = await self._request( method=http_methods.GET, @@ -1673,8 +1669,8 @@ async def _get( async def _put( self, - path: Optional[str] = None, - document: Optional[API_RESPONSE] = None, + path: str | None = None, + document: API_RESPONSE | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path @@ -1688,8 +1684,8 @@ async def _put( async 
def _post( self, - path: Optional[str] = None, - document: Optional[API_DOC] = None, + path: str | None = None, + document: API_DOC | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: full_path = f"{self.base_path}/{path}" if path else self.base_path @@ -1702,8 +1698,8 @@ async def _post( return response def _recast_as_sort_projection( - self, vector: List[float], fields: Optional[List[str]] = None - ) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + self, vector: list[float], fields: list[str] | None = None + ) -> tuple[dict[str, Any], dict[str, Any] | None]: """ Given a vector and optionally a list of fields, reformulate them as a sort, projection pair for regular @@ -1718,7 +1714,7 @@ def _recast_as_sort_projection( raise ValueError("Please use the `include_similarity` parameter") # Build the new vector parameter - sort: Dict[str, Any] = {"$vector": vector} + sort: dict[str, Any] = {"$vector": vector} # Build the new fields parameter # Note: do not leave projection={}, make it None @@ -1731,8 +1727,8 @@ def _recast_as_sort_projection( return sort, projection async def get( - self, path: Optional[str] = None, timeout_info: TimeoutInfoWideType = None - ) -> Optional[API_RESPONSE]: + self, path: str | None = None, timeout_info: TimeoutInfoWideType = None + ) -> API_RESPONSE | None: """ Retrieve a document from the collection by its path. 
@@ -1749,10 +1745,10 @@ async def get( async def find( self, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1784,14 +1780,14 @@ async def find( async def vector_find( self, - vector: List[float], + vector: list[float], *, limit: int, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> List[API_DOC]: + ) -> list[API_DOC]: """ Perform a vector-based search in the collection. @@ -1836,10 +1832,10 @@ async def vector_find( async def paginate( *, request_method: AsyncPaginableRequestMethod, - options: Optional[Dict[str, Any]], - prefetched: Optional[int] = None, + options: dict[str, Any] | None, + prefetched: int | None = None, timeout_info: TimeoutInfoWideType = None, - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ) -> AsyncGenerator[API_DOC, None]: """ Generate paginated results for a given database query method. 
@@ -1867,9 +1863,9 @@ async def paginate( if next_page_state is not None and prefetched: async def queued_paginate( - queue: asyncio.Queue[Optional[API_DOC]], + queue: asyncio.Queue[API_DOC | None], request_method: AsyncPaginableRequestMethod, - options: Optional[Dict[str, Any]], + options: dict[str, Any] | None, ) -> None: try: async for doc in AsyncAstraDBCollection.paginate( @@ -1879,7 +1875,7 @@ async def queued_paginate( finally: await queue.put(None) - queue: asyncio.Queue[Optional[API_DOC]] = asyncio.Queue(prefetched) + queue: asyncio.Queue[API_DOC | None] = asyncio.Queue(prefetched) options1 = {**options0, **{"pageState": next_page_state}} asyncio.create_task(queued_paginate(queue, request_method, options1)) for document in response0["data"]["documents"]: @@ -1902,13 +1898,13 @@ async def queued_paginate( def paginated_find( self, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - prefetched: Optional[int] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + prefetched: int | None = None, timeout_info: TimeoutInfoWideType = None, - raw_response_callback: Optional[Callable[[Dict[str, Any]], None]] = None, + raw_response_callback: Callable[[dict[str, Any]], None] | None = None, ) -> AsyncIterator[API_DOC]: """ Perform a paginated search in the collection. 
@@ -1948,9 +1944,9 @@ def paginated_find( async def pop( self, - filter: Dict[str, Any], - pop: Dict[str, Any], - options: Dict[str, Any], + filter: dict[str, Any], + pop: dict[str, Any], + options: dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -1985,9 +1981,9 @@ async def pop( async def push( self, - filter: Dict[str, Any], - push: Dict[str, Any], - options: Dict[str, Any], + filter: dict[str, Any], + push: dict[str, Any], + options: dict[str, Any], timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2022,12 +2018,12 @@ async def push( async def find_one_and_replace( self, - replacement: Dict[str, Any], + replacement: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, - sort: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, + sort: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2065,13 +2061,13 @@ async def find_one_and_replace( async def vector_find_one_and_replace( self, - vector: List[float], - replacement: Dict[str, Any], + vector: list[float], + replacement: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search and replace the first matched document. 
@@ -2106,11 +2102,11 @@ async def vector_find_one_and_replace( async def find_one_and_update( self, - update: Dict[str, Any], - sort: Optional[Dict[str, Any]] = {}, - filter: Optional[Dict[str, Any]] = None, - options: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + update: dict[str, Any], + sort: dict[str, Any] | None = {}, + filter: dict[str, Any] | None = None, + options: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2148,13 +2144,13 @@ async def find_one_and_update( async def vector_find_one_and_update( self, - vector: List[float], - update: Dict[str, Any], + vector: list[float], + update: dict[str, Any], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search and update the first matched document. @@ -2190,9 +2186,9 @@ async def vector_find_one_and_update( async def find_one_and_delete( self, - sort: Optional[Dict[str, Any]] = {}, - filter: Optional[Dict[str, Any]] = None, - projection: Optional[Dict[str, Any]] = None, + sort: dict[str, Any] | None = {}, + filter: dict[str, Any] | None = None, + projection: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2226,7 +2222,7 @@ async def find_one_and_delete( return response async def count_documents( - self, filter: Dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None + self, filter: dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: """ Count documents matching a given predicate (expressed as filter). 
@@ -2254,10 +2250,10 @@ async def count_documents( async def find_one( self, - filter: Optional[Dict[str, Any]] = {}, - projection: Optional[Dict[str, Any]] = {}, - sort: Optional[Dict[str, Any]] = {}, - options: Optional[Dict[str, Any]] = {}, + filter: dict[str, Any] | None = {}, + projection: dict[str, Any] | None = {}, + sort: dict[str, Any] | None = {}, + options: dict[str, Any] | None = {}, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2293,13 +2289,13 @@ async def find_one( async def vector_find_one( self, - vector: List[float], + vector: list[float], *, - filter: Optional[Dict[str, Any]] = None, - fields: Optional[List[str]] = None, + filter: dict[str, Any] | None = None, + fields: list[str] | None = None, include_similarity: bool = True, timeout_info: TimeoutInfoWideType = None, - ) -> Union[API_DOC, None]: + ) -> API_DOC | None: """ Perform a vector-based search to find a single document in the collection. @@ -2365,8 +2361,8 @@ async def insert_one( async def insert_many( self, - documents: List[API_DOC], - options: Optional[Dict[str, Any]] = None, + documents: list[API_DOC], + options: dict[str, Any] | None = None, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -2401,13 +2397,13 @@ async def insert_many( async def chunked_insert_many( self, - documents: List[API_DOC], - options: Optional[Dict[str, Any]] = None, + documents: list[API_DOC], + options: dict[str, Any] | None = None, partial_failures_allowed: bool = False, chunk_size: int = DEFAULT_INSERT_NUM_DOCUMENTS, concurrency: int = 1, timeout_info: TimeoutInfoWideType = None, - ) -> List[Union[API_RESPONSE, Exception]]: + ) -> list[API_RESPONSE | Exception]: """ Insert multiple documents into the collection, handling chunking and optionally with concurrent insertions. 
@@ -2435,10 +2431,10 @@ async def chunked_insert_many( sem = asyncio.Semaphore(concurrency) async def concurrent_insert_many( - docs: List[API_DOC], + docs: list[API_DOC], index: int, partial_failures_allowed: bool, - ) -> Union[API_RESPONSE, Exception]: + ) -> API_RESPONSE | Exception: async with sem: logger.debug(f"Processing chunk #{index + 1} of size {len(docs)}") try: @@ -2488,11 +2484,11 @@ async def concurrent_insert_many( async def update_one( self, - filter: Dict[str, Any], - update: Dict[str, Any], - sort: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + update: dict[str, Any], + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, ) -> API_RESPONSE: """ Update a single document in the collection. @@ -2527,9 +2523,9 @@ async def update_one( async def update_many( self, - filter: Dict[str, Any], - update: Dict[str, Any], - options: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + update: dict[str, Any], + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2582,7 +2578,7 @@ async def replace( async def delete_one( self, id: str, - sort: Optional[Dict[str, Any]] = None, + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2614,8 +2610,8 @@ async def delete_one( async def delete_one_by_predicate( self, - filter: Dict[str, Any], - sort: Optional[Dict[str, Any]] = None, + filter: dict[str, Any], + sort: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -2647,7 +2643,7 @@ async def delete_one_by_predicate( async def delete_many( self, - filter: Dict[str, Any], + filter: dict[str, Any], skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -2682,8 +2678,8 @@ async def delete_many( return response async def chunked_delete_many( - self, 
filter: Dict[str, Any], timeout_info: TimeoutInfoWideType = None - ) -> List[API_RESPONSE]: + self, filter: dict[str, Any], timeout_info: TimeoutInfoWideType = None + ) -> list[API_RESPONSE]: """ Delete many documents from the collection based on a filter condition, chaining several API calls until exhaustion of the documents to delete. @@ -2818,7 +2814,7 @@ async def upsert_many( concurrency: int = 1, partial_failures_allowed: bool = False, timeout_info: TimeoutInfoWideType = None, - ) -> List[Union[str, Exception]]: + ) -> list[str | Exception]: """ Emulate an upsert operation for multiple documents in the collection. This method attempts to insert the documents. @@ -2860,13 +2856,13 @@ class AstraDB: def __init__( self, - token: Optional[str], + token: str | None, api_endpoint: str, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Initialize an Astra DB instance. 
@@ -2940,13 +2936,13 @@ def __eq__(self, other: Any) -> bool: def copy( self, *, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | None = None, + api_endpoint: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AstraDB: return AstraDB( token=token or self.token, @@ -2971,8 +2967,8 @@ def to_async(self) -> AsyncAstraDB: def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version @@ -2980,9 +2976,9 @@ def set_caller( def _request( self, method: str = http_methods.POST, - path: Optional[str] = None, - json_data: Optional[Dict[str, Any]] = None, - url_params: Optional[Dict[str, Any]] = None, + path: str | None = None, + json_data: dict[str, Any] | None = None, + url_params: dict[str, Any] | None = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -3005,7 +3001,7 @@ def _request( return response def post_raw_request( - self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None + self, body: dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return self._request( method=http_methods.POST, @@ -3028,7 +3024,7 @@ def collection(self, collection_name: str) -> AstraDBCollection: def get_collections( self, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -3066,10 +3062,10 @@ def create_collection( self, collection_name: str, *, - options: 
Optional[Dict[str, Any]] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - service_dict: Optional[Dict[str, str]] = None, + options: dict[str, Any] | None = None, + dimension: int | None = None, + metric: str | None = None, + service_dict: dict[str, str] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> AstraDBCollection: """ @@ -3169,13 +3165,13 @@ def delete_collection( class AsyncAstraDB: def __init__( self, - token: Optional[str], + token: str | None, api_endpoint: str, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Initialize an Astra DB instance. @@ -3252,22 +3248,22 @@ async def __aenter__(self) -> AsyncAstraDB: async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: await self.client.aclose() def copy( self, *, - token: Optional[str] = None, - api_endpoint: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | None = None, + api_endpoint: str | None = None, + api_path: str | None = None, + api_version: str | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncAstraDB: return AsyncAstraDB( token=token or self.token, @@ -3292,8 +3288,8 @@ def to_sync(self) -> AstraDB: def set_caller( self, - caller_name: Optional[str] = 
None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version @@ -3301,9 +3297,9 @@ def set_caller( async def _request( self, method: str = http_methods.POST, - path: Optional[str] = None, - json_data: Optional[Dict[str, Any]] = None, - url_params: Optional[Dict[str, Any]] = None, + path: str | None = None, + json_data: dict[str, Any] | None = None, + url_params: dict[str, Any] | None = None, skip_error_check: bool = False, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: @@ -3326,7 +3322,7 @@ async def _request( return response async def post_raw_request( - self, body: Dict[str, Any], timeout_info: TimeoutInfoWideType = None + self, body: dict[str, Any], timeout_info: TimeoutInfoWideType = None ) -> API_RESPONSE: return await self._request( method=http_methods.POST, @@ -3352,7 +3348,7 @@ async def collection(self, collection_name: str) -> AsyncAstraDBCollection: async def get_collections( self, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -3390,10 +3386,10 @@ async def create_collection( self, collection_name: str, *, - options: Optional[Dict[str, Any]] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - service_dict: Optional[Dict[str, str]] = None, + options: dict[str, Any] | None = None, + dimension: int | None = None, + metric: str | None = None, + service_dict: dict[str, str] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> AsyncAstraDBCollection: """ diff --git a/astrapy/core/ops.py b/astrapy/core/ops.py index 9cc90eb8..2b630d93 100644 --- a/astrapy/core/ops.py +++ b/astrapy/core/ops.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, TypedDict, Union, cast +from typing import Any, TypedDict, cast import httpx @@ 
-38,11 +38,11 @@ class AstraDBOpsConstructorParams(TypedDict): - token: Union[str, None] - dev_ops_url: Optional[str] - dev_ops_api_version: Optional[str] - caller_name: Optional[str] - caller_version: Optional[str] + token: str | None + dev_ops_url: str | None + dev_ops_api_version: str | None + caller_name: str | None + caller_version: str | None class AstraDBOps: @@ -52,11 +52,11 @@ class AstraDBOps: def __init__( self, - token: Union[str, None], - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version @@ -74,7 +74,7 @@ def __init__( dev_ops_api_version or DEFAULT_DEV_OPS_API_VERSION ).strip("/") - self.token: Union[str, None] + self.token: str | None if token is not None: self.token = "Bearer " + token else: @@ -98,11 +98,11 @@ def __eq__(self, other: Any) -> bool: def copy( self, *, - token: Optional[str] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + token: str | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AstraDBOps: return AstraDBOps( token=token or self.constructor_params["token"], @@ -115,8 +115,8 @@ def copy( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: self.caller_name = caller_name self.caller_version = caller_version @@ -125,8 +125,8 @@ def _ops_request( self, method: str, path: str, - options: Optional[Dict[str, Any]] = None, - 
json_data: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, + json_data: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> httpx.Response: _options = {} if options is None else options @@ -151,8 +151,8 @@ async def _async_ops_request( self, method: str, path: str, - options: Optional[Dict[str, Any]] = None, - json_data: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, + json_data: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> httpx.Response: _options = {} if options is None else options @@ -177,8 +177,8 @@ def _json_ops_request( self, method: str, path: str, - options: Optional[Dict[str, Any]] = None, - json_data: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, + json_data: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: _options = {} if options is None else options @@ -204,8 +204,8 @@ async def _async_json_ops_request( self, method: str, path: str, - options: Optional[Dict[str, Any]] = None, - json_data: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, + json_data: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: _options = {} if options is None else options @@ -229,7 +229,7 @@ async def _async_json_ops_request( def get_databases( self, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -252,7 +252,7 @@ def get_databases( async def async_get_databases( self, - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -275,9 +275,9 @@ async def async_get_databases( def create_database( self, - database_definition: Optional[Dict[str, Any]] = None, + database_definition: dict[str, Any] | None = None, timeout_info: 
TimeoutInfoWideType = None, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Create a new database. @@ -305,9 +305,9 @@ def create_database( async def async_create_database( self, - database_definition: Optional[Dict[str, Any]] = None, + database_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Create a new database - async version of the method. @@ -392,7 +392,7 @@ async def async_terminate_database( def get_database( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -418,7 +418,7 @@ def get_database( async def async_get_database( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> API_RESPONSE: """ @@ -446,7 +446,7 @@ def create_keyspace( database: str = "", keyspace: str = "", timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Create a keyspace in a specified database. @@ -476,7 +476,7 @@ async def async_create_keyspace( database: str = "", keyspace: str = "", timeout_info: TimeoutInfoWideType = None, - ) -> Dict[str, str]: + ) -> dict[str, str]: """ Create a keyspace in a specified database - async version of the method. 
@@ -600,7 +600,7 @@ def unpark_database( def resize_database( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -623,7 +623,7 @@ def resize_database( def reset_database_password( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -682,7 +682,7 @@ def get_datacenters( def create_datacenter( self, database: str = "", - options: Optional[Dict[str, Any]] = None, + options: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -745,7 +745,7 @@ def get_access_list( def replace_access_list( self, database: str = "", - access_list: Optional[Dict[str, Any]] = None, + access_list: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -768,7 +768,7 @@ def replace_access_list( def update_access_list( self, database: str = "", - access_list: Optional[Dict[str, Any]] = None, + access_list: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -791,7 +791,7 @@ def update_access_list( def add_access_list_address( self, database: str = "", - address: Optional[Dict[str, Any]] = None, + address: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -873,7 +873,7 @@ def create_datacenter_private_link( self, database: str = "", datacenter: str = "", - private_link: Optional[Dict[str, Any]] = None, + private_link: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -898,7 +898,7 @@ def create_datacenter_endpoint( self, database: str = "", datacenter: str = "", - endpoint: Optional[Dict[str, Any]] = None, + endpoint: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) 
-> OPS_API_RESPONSE: """ @@ -923,7 +923,7 @@ def update_datacenter_endpoint( self, database: str = "", datacenter: str = "", - endpoint: Dict[str, Any] = {}, + endpoint: dict[str, Any] = {}, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1035,7 +1035,7 @@ def get_roles(self, timeout_info: TimeoutInfoWideType = None) -> OPS_API_RESPONS def create_role( self, - role_definition: Optional[Dict[str, Any]] = None, + role_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1075,7 +1075,7 @@ def get_role( def update_role( self, role: str = "", - role_definition: Optional[Dict[str, Any]] = None, + role_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1115,7 +1115,7 @@ def delete_role( def invite_user( self, - user_definition: Optional[Dict[str, Any]] = None, + user_definition: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1186,7 +1186,7 @@ def remove_user( def update_user_roles( self, user: str = "", - roles: Optional[Dict[str, Any]] = None, + roles: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1219,7 +1219,7 @@ def get_clients(self, timeout_info: TimeoutInfoWideType = None) -> OPS_API_RESPO def create_token( self, - roles: Optional[Dict[str, Any]] = None, + roles: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ @@ -1360,7 +1360,7 @@ def get_streaming_tenants( def create_streaming_tenant( self, - tenant: Optional[Dict[str, Any]] = None, + tenant: dict[str, Any] | None = None, timeout_info: TimeoutInfoWideType = None, ) -> OPS_API_RESPONSE: """ diff --git a/astrapy/core/utils.py b/astrapy/core/utils.py index dad44019..6eece183 100644 --- a/astrapy/core/utils.py +++ b/astrapy/core/utils.py @@ -18,7 +18,7 @@ import json import logging import time -from typing 
import Any, Dict, Iterable, List, Optional, TypedDict, Union, cast +from typing import Any, Dict, Iterable, TypedDict, Union, cast import httpx @@ -57,7 +57,7 @@ class http_methods: user_agent_astrapy = f"{package_name}/{__version__}" -def detect_ragstack_user_agent() -> Optional[str]: +def detect_ragstack_user_agent() -> str | None: from importlib import metadata from importlib.metadata import PackageNotFoundError @@ -77,9 +77,9 @@ def detect_ragstack_user_agent() -> Optional[str]: def log_request( method: str, url: str, - params: Optional[Dict[str, Any]], - headers: Dict[str, str], - json_data: Optional[Dict[str, Any]], + params: dict[str, Any] | None, + headers: dict[str, str], + json_data: dict[str, Any] | None, ) -> None: """ Log the details of an HTTP request for debugging purposes. @@ -116,8 +116,8 @@ def log_response(r: httpx.Response) -> None: def user_agent_string( - caller_name: Optional[str], caller_version: Optional[str] -) -> Optional[str]: + caller_name: str | None, caller_version: str | None +) -> str | None: if caller_name: if caller_version: return f"{caller_name}/{caller_version}" @@ -127,9 +127,7 @@ def user_agent_string( return None -def compose_user_agent( - caller_name: Optional[str], caller_version: Optional[str] -) -> str: +def compose_user_agent(caller_name: str | None, caller_version: str | None) -> str: user_agent_caller = user_agent_string(caller_name, caller_version) all_user_agents = [ ua_block @@ -152,7 +150,7 @@ class TimeoutInfo(TypedDict, total=False): TimeoutInfoWideType = Union[TimeoutInfo, float, None] -def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> Union[httpx.Timeout, None]: +def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> httpx.Timeout | None: if timeout_info is None: return None if isinstance(timeout_info, float) or isinstance(timeout_info, int): @@ -170,15 +168,15 @@ def make_request( client: httpx.Client, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - 
json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - path: Optional[str], - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + path: str | None, + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> httpx.Response: """ Make an HTTP request to a specified URL. @@ -233,15 +231,15 @@ async def amake_request( client: httpx.AsyncClient, base_url: str, auth_header: str, - token: Optional[str], + token: str | None, method: str, - path: Optional[str], - json_data: Optional[Dict[str, Any]], - url_params: Optional[Dict[str, Any]], - caller_name: Optional[str], - caller_version: Optional[str], - timeout: Optional[Union[httpx.Timeout, float]], - additional_headers: Dict[str, str], + path: str | None, + json_data: dict[str, Any] | None, + url_params: dict[str, Any] | None, + caller_name: str | None, + caller_version: str | None, + timeout: httpx.Timeout | float | None, + additional_headers: dict[str, str], ) -> httpx.Response: """ Make an HTTP request to a specified URL. @@ -292,7 +290,7 @@ async def amake_request( return r -def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: +def make_payload(top_level: str, **kwargs: Any) -> dict[str, Any]: """ Construct a JSON payload for an HTTP request with a specified top-level key. 
@@ -307,7 +305,7 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: for key, value in kwargs.items(): params[key] = value - json_query: Dict[str, Any] = {top_level: {}} + json_query: dict[str, Any] = {top_level: {}} # Adding keys only if they're provided for key, value in params.items(): @@ -317,7 +315,7 @@ def make_payload(top_level: str, **kwargs: Any) -> Dict[str, Any]: return json_query -def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]: +def convert_vector_to_floats(vector: Iterable[Any]) -> list[float]: """ Convert a vector of strings to a vector of floats. @@ -341,36 +339,36 @@ def is_list_of_floats(vector: Iterable[Any]) -> bool: def convert_to_ejson_date_object( - date_value: Union[datetime.date, datetime.datetime] -) -> Dict[str, int]: + date_value: datetime.date | datetime.datetime, +) -> dict[str, int]: return {"$date": int(time.mktime(date_value.timetuple()) * 1000)} -def convert_to_ejson_uuid_object(uuid_value: UUID) -> Dict[str, str]: +def convert_to_ejson_uuid_object(uuid_value: UUID) -> dict[str, str]: return {"$uuid": str(uuid_value)} -def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> Dict[str, str]: +def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> dict[str, str]: return {"$objectId": str(objectid_value)} def convert_ejson_date_object_to_datetime( - date_object: Dict[str, int] + date_object: dict[str, int], ) -> datetime.datetime: return datetime.datetime.fromtimestamp(date_object["$date"] / 1000.0) -def convert_ejson_uuid_object_to_uuid(uuid_object: Dict[str, str]) -> UUID: +def convert_ejson_uuid_object_to_uuid(uuid_object: dict[str, str]) -> UUID: return UUID(uuid_object["$uuid"]) def convert_ejson_objectid_object_to_objectid( - objectid_object: Dict[str, str] + objectid_object: dict[str, str], ) -> ObjectId: return ObjectId(objectid_object["$objectId"]) -def _normalize_payload_value(path: List[str], value: Any) -> Any: +def _normalize_payload_value(path: list[str], value: 
Any) -> Any: """ The path helps determining special treatments """ @@ -401,9 +399,7 @@ def _normalize_payload_value(path: List[str], value: Any) -> Any: return value -def normalize_for_api( - payload: Union[Dict[str, Any], None] -) -> Union[Dict[str, Any], None]: +def normalize_for_api(payload: dict[str, Any] | None) -> dict[str, Any] | None: """ Normalize a payload for API calls. This includes e.g. ensuring values for "$vector" key @@ -422,7 +418,7 @@ def normalize_for_api( return payload -def _restore_response_value(path: List[str], value: Any) -> Any: +def _restore_response_value(path: list[str], value: Any) -> Any: """ The path helps determining special treatments """ diff --git a/astrapy/cursors.py b/astrapy/cursors.py index fab0b314..bf336b11 100644 --- a/astrapy/cursors.py +++ b/astrapy/cursors.py @@ -23,15 +23,12 @@ TYPE_CHECKING, Any, Callable, - Dict, Generic, Iterable, Iterator, - List, Optional, Tuple, TypeVar, - Union, ) from astrapy.constants import ( @@ -58,7 +55,7 @@ IndexPairType = Tuple[str, Optional[int]] -def _maybe_valid_list_index(key_block: str) -> Optional[int]: +def _maybe_valid_list_index(key_block: str) -> int | None: # '0', '1' is good. '00', '01', '-30' are not. 
try: kb_index = int(key_block) @@ -72,9 +69,8 @@ def _maybe_valid_list_index(key_block: str) -> Optional[int]: def _create_document_key_extractor( key: str, -) -> Callable[[Dict[str, Any]], Iterable[Any]]: - - key_blocks0: List[IndexPairType] = [ +) -> Callable[[dict[str, Any]], Iterable[Any]]: + key_blocks0: list[IndexPairType] = [ (kb_str, _maybe_valid_list_index(kb_str)) for kb_str in key.split(".") ] if key_blocks0 == []: @@ -83,7 +79,7 @@ def _create_document_key_extractor( raise ValueError("Field path components cannot be empty") def _extract_with_key_blocks( - key_blocks: List[IndexPairType], value: Any + key_blocks: list[IndexPairType], value: Any ) -> Iterable[Any]: if key_blocks == []: if isinstance(value, list): @@ -123,7 +119,7 @@ def _extract_with_key_blocks( # keyblocks are deeper than the document. Nothing to extract. return - def _item_extractor(document: Dict[str, Any]) -> Iterable[Any]: + def _item_extractor(document: dict[str, Any]) -> Iterable[Any]: return _extract_with_key_blocks(key_blocks=key_blocks0, value=document) return _item_extractor @@ -148,7 +144,7 @@ def _reduce_distinct_key_to_safe(distinct_key: str) -> str: return ".".join(valid_portion) -def _hash_document(document: Dict[str, Any]) -> str: +def _hash_document(document: dict[str, Any]) -> str: _normalized_item = normalize_payload_value(path=[], value=document) _normalized_json = json.dumps( _normalized_item, sort_keys=True, separators=(",", ":") @@ -165,7 +161,7 @@ class _LookAheadIterator: def __init__(self, iterator: Iterator[DocumentType]): self.iterator = iterator - self.preread_item: Optional[DocumentType] = None + self.preread_item: DocumentType | None = None self.has_preread = False self.preread_exhausted = False @@ -201,7 +197,7 @@ class _AsyncLookAheadIterator: def __init__(self, async_iterator: AsyncIterator[DocumentType]): self.async_iterator = async_iterator - self.preread_item: Optional[DocumentType] = None + self.preread_item: DocumentType | None = None 
self.has_preread = False self.preread_exhausted = False @@ -237,30 +233,30 @@ class BaseCursor: See classes Cursor and AsyncCursor for more information. """ - _collection: Union[Collection, AsyncCollection] - _filter: Optional[Dict[str, Any]] - _projection: Optional[ProjectionType] - _max_time_ms: Optional[int] - _overall_max_time_ms: Optional[int] - _started_time_s: Optional[float] - _limit: Optional[int] - _skip: Optional[int] - _include_similarity: Optional[bool] - _include_sort_vector: Optional[bool] - _sort: Optional[Dict[str, Any]] + _collection: Collection | AsyncCollection + _filter: dict[str, Any] | None + _projection: ProjectionType | None + _max_time_ms: int | None + _overall_max_time_ms: int | None + _started_time_s: float | None + _limit: int | None + _skip: int | None + _include_similarity: bool | None + _include_sort_vector: bool | None + _sort: dict[str, Any] | None _started: bool _retrieved: int _alive: bool - _iterator: Optional[Union[_LookAheadIterator, _AsyncLookAheadIterator]] = None - _api_response_status: Optional[Dict[str, Any]] + _iterator: _LookAheadIterator | _AsyncLookAheadIterator | None = None + _api_response_status: dict[str, Any] | None def __init__( self, - collection: Union[Collection, AsyncCollection], - filter: Optional[Dict[str, Any]], - projection: Optional[ProjectionType], - max_time_ms: Optional[int], - overall_max_time_ms: Optional[int], + collection: Collection | AsyncCollection, + filter: dict[str, Any] | None, + projection: ProjectionType | None, + max_time_ms: int | None, + overall_max_time_ms: int | None, ) -> None: raise NotImplementedError @@ -315,15 +311,15 @@ def _ensure_not_started(self) -> None: def _copy( self: BC, *, - projection: Optional[ProjectionType] = None, - max_time_ms: Optional[int] = None, - overall_max_time_ms: Optional[int] = None, - limit: Optional[int] = None, - skip: Optional[int] = None, - include_similarity: Optional[bool] = None, - include_sort_vector: Optional[bool] = None, - started: 
Optional[bool] = None, - sort: Optional[Dict[str, Any]] = None, + projection: ProjectionType | None = None, + max_time_ms: int | None = None, + overall_max_time_ms: int | None = None, + limit: int | None = None, + skip: int | None = None, + include_similarity: bool | None = None, + include_sort_vector: bool | None = None, + started: bool | None = None, + sort: dict[str, Any] | None = None, ) -> BC: new_cursor = self.__class__( collection=self._collection, @@ -417,7 +413,7 @@ def cursor_id(self) -> int: return id(self) - def limit(self: BC, limit: Optional[int]) -> BC: + def limit(self: BC, limit: int | None) -> BC: """ Set a new `limit` value for this cursor. @@ -433,7 +429,7 @@ def limit(self: BC, limit: Optional[int]) -> BC: self._limit = limit if limit != 0 else None return self - def include_similarity(self: BC, include_similarity: Optional[bool]) -> BC: + def include_similarity(self: BC, include_similarity: bool | None) -> BC: """ Set a new `include_similarity` value for this cursor. @@ -449,7 +445,7 @@ def include_similarity(self: BC, include_similarity: Optional[bool]) -> BC: self._include_similarity = include_similarity return self - def include_sort_vector(self: BC, include_sort_vector: Optional[bool]) -> BC: + def include_sort_vector(self: BC, include_sort_vector: bool | None) -> BC: """ Set a new `include_sort_vector` value for this cursor. @@ -487,7 +483,7 @@ def rewind(self: BC) -> BC: self._iterator = None return self - def skip(self: BC, skip: Optional[int]) -> BC: + def skip(self: BC, skip: int | None) -> BC: """ Set a new `skip` value for this cursor. @@ -509,7 +505,7 @@ def skip(self: BC, skip: Optional[int]) -> BC: def sort( self: BC, - sort: Optional[Dict[str, Any]], + sort: dict[str, Any] | None, ) -> BC: """ Set a new `sort` value for this cursor. 
@@ -581,10 +577,10 @@ class Cursor(BaseCursor): def __init__( self, collection: Collection, - filter: Optional[Dict[str, Any]], - projection: Optional[ProjectionType], - max_time_ms: Optional[int], - overall_max_time_ms: Optional[int], + filter: dict[str, Any] | None, + projection: ProjectionType | None, + max_time_ms: int | None, + overall_max_time_ms: int | None, ) -> None: self._collection: Collection = collection self._filter = filter @@ -594,17 +590,17 @@ def __init__( self._max_time_ms = min(max_time_ms, overall_max_time_ms) else: self._max_time_ms = max_time_ms - self._limit: Optional[int] = None - self._skip: Optional[int] = None - self._include_similarity: Optional[bool] = None - self._include_sort_vector: Optional[bool] = None - self._sort: Optional[Dict[str, Any]] = None + self._limit: int | None = None + self._skip: int | None = None + self._include_similarity: bool | None = None + self._include_sort_vector: bool | None = None + self._sort: dict[str, Any] | None = None self._started = False self._retrieved = 0 self._alive = True # - self._iterator: Optional[_LookAheadIterator] = None - self._api_response_status: Optional[Dict[str, Any]] = None + self._iterator: _LookAheadIterator | None = None + self._api_response_status: dict[str, Any] | None = None def __iter__(self) -> Cursor: self._ensure_alive() @@ -638,7 +634,7 @@ def __next__(self) -> DocumentType: self._alive = False raise - def get_sort_vector(self) -> Optional[List[float]]: + def get_sort_vector(self) -> list[float] | None: """ Return the vector used in this ANN search, if applicable. 
If this is not an ANN search, or it was invoked without the @@ -695,7 +691,7 @@ def _create_iterator(self) -> _LookAheadIterator: } def _find_iterator() -> Iterator[DocumentType]: - next_page_state: Optional[str] = None + next_page_state: str | None = None # resp_0 = self._collection.command( body=f0_payload, @@ -762,7 +758,7 @@ def collection(self) -> Collection: return self._collection - def distinct(self, key: str, max_time_ms: Optional[int] = None) -> List[Any]: + def distinct(self, key: str, max_time_ms: int | None = None) -> list[Any]: """ Compute a list of unique values for a specific field across all documents the cursor iterates through. @@ -860,10 +856,10 @@ class AsyncCursor(BaseCursor): def __init__( self, collection: AsyncCollection, - filter: Optional[Dict[str, Any]], - projection: Optional[ProjectionType], - max_time_ms: Optional[int], - overall_max_time_ms: Optional[int], + filter: dict[str, Any] | None, + projection: ProjectionType | None, + max_time_ms: int | None, + overall_max_time_ms: int | None, ) -> None: self._collection: AsyncCollection = collection self._filter = filter @@ -873,17 +869,17 @@ def __init__( self._max_time_ms = min(max_time_ms, overall_max_time_ms) else: self._max_time_ms = max_time_ms - self._limit: Optional[int] = None - self._skip: Optional[int] = None - self._include_similarity: Optional[bool] = None - self._include_sort_vector: Optional[bool] = None - self._sort: Optional[Dict[str, Any]] = None + self._limit: int | None = None + self._skip: int | None = None + self._include_similarity: bool | None = None + self._include_sort_vector: bool | None = None + self._sort: dict[str, Any] | None = None self._started = False self._retrieved = 0 self._alive = True # - self._iterator: Optional[_AsyncLookAheadIterator] = None - self._api_response_status: Optional[Dict[str, Any]] = None + self._iterator: _AsyncLookAheadIterator | None = None + self._api_response_status: dict[str, Any] | None = None def __aiter__(self) -> AsyncCursor: 
self._ensure_alive() @@ -917,7 +913,7 @@ async def __anext__(self) -> DocumentType: self._alive = False raise - async def get_sort_vector(self) -> Optional[List[float]]: + async def get_sort_vector(self) -> list[float] | None: """ Return the vector used in this ANN search, if applicable. If this is not an ANN search, or it was invoked without the @@ -1034,12 +1030,12 @@ async def _find_iterator() -> AsyncIterator[DocumentType]: def _to_sync( self: AsyncCursor, *, - limit: Optional[int] = None, - skip: Optional[int] = None, - include_similarity: Optional[bool] = None, - include_sort_vector: Optional[bool] = None, - started: Optional[bool] = None, - sort: Optional[Dict[str, Any]] = None, + limit: int | None = None, + skip: int | None = None, + include_similarity: bool | None = None, + include_sort_vector: bool | None = None, + started: bool | None = None, + sort: dict[str, Any] | None = None, ) -> Cursor: new_cursor = Cursor( collection=self._collection.to_sync(), @@ -1079,7 +1075,7 @@ def collection(self) -> AsyncCollection: return self._collection - async def distinct(self, key: str, max_time_ms: Optional[int] = None) -> List[Any]: + async def distinct(self, key: str, max_time_ms: int | None = None) -> list[Any]: """ Compute a list of unique values for a specific field across all documents the cursor iterates through. @@ -1142,7 +1138,7 @@ class CommandCursor(Generic[T]): (such as the database `list_collections` method). """ - def __init__(self, address: str, items: List[T]) -> None: + def __init__(self, address: str, items: list[T]) -> None: self._address = address self.items = items self.iterable = items.__iter__() @@ -1225,7 +1221,7 @@ class AsyncCommandCursor(Generic[T]): (such as the database `list_collections` method). 
""" - def __init__(self, address: str, items: List[T]) -> None: + def __init__(self, address: str, items: list[T]) -> None: self._address = address self.items = items self.iterable = items.__iter__() diff --git a/astrapy/database.py b/astrapy/database.py index f277661a..3156b3dc 100644 --- a/astrapy/database.py +++ b/astrapy/database.py @@ -16,7 +16,7 @@ import logging from types import TracebackType -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any from astrapy.admin import fetch_database_info, parse_api_endpoint from astrapy.api_commander import APICommander @@ -57,13 +57,13 @@ def _normalize_create_collection_options( - dimension: Optional[int], - metric: Optional[str], - service: Optional[Union[CollectionVectorServiceOptions, Dict[str, Any]]], - indexing: Optional[Dict[str, Any]], - default_id_type: Optional[str], - additional_options: Optional[Dict[str, Any]], -) -> Dict[str, Any]: + dimension: int | None, + metric: str | None, + service: CollectionVectorServiceOptions | dict[str, Any] | None, + indexing: dict[str, Any] | None, + default_id_type: str | None, + additional_options: dict[str, Any] | None, +) -> dict[str, Any]: """Raise errors related to invalid input, and return a ready-to-send payload.""" is_vector: bool if service is not None or dimension is not None: @@ -76,7 +76,7 @@ def _normalize_create_collection_options( "create_collection method." 
) # prepare the payload - service_dict: Optional[Dict[str, Any]] + service_dict: dict[str, Any] | None if service is not None: service_dict = service if isinstance(service, dict) else service.as_dict() else: @@ -165,19 +165,19 @@ class Database: def __init__( self, api_endpoint: str, - token: Optional[Union[str, TokenProvider]] = None, + token: str | TokenProvider | None = None, *, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> None: self.environment = (environment or Environment.PROD).lower() # - _api_path: Optional[str] - _api_version: Optional[str] + _api_path: str | None + _api_version: str | None if api_path is None: _api_path = API_PATH_ENV_MAP[self.environment] else: @@ -192,7 +192,7 @@ def __init__( self.api_version = _api_version # enforce defaults if on Astra DB: - self.using_namespace: Optional[str] + self.using_namespace: str | None if namespace is None and self.environment in Environment.astra_db_values: self.using_namespace = DEFAULT_ASTRA_DB_NAMESPACE else: @@ -205,7 +205,7 @@ def __init__( self.caller_name = caller_name self.caller_version = caller_version self._api_commander = self._get_api_commander(namespace=self.namespace) - self._name: Optional[str] = None + self._name: str | None = None def __getattr__(self, collection_name: str) -> Collection: return self.get_collection(name=collection_name) @@ -215,12 +215,12 @@ def __getitem__(self, collection_name: str) -> Collection: def __repr__(self) -> str: ep_desc = f'api_endpoint="{self.api_endpoint}"' - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = 
f'token="{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - namespace_desc: Optional[str] + namespace_desc: str | None if self.namespace is None: namespace_desc = "namespace not set" else: @@ -245,7 +245,7 @@ def __eq__(self, other: Any) -> bool: else: return False - def _get_api_commander(self, namespace: Optional[str]) -> Optional[APICommander]: + def _get_api_commander(self, namespace: str | None) -> APICommander | None: """ Instantiate a new APICommander based on the properties of this class and a provided namespace. @@ -274,12 +274,12 @@ def _get_api_commander(self, namespace: Optional[str]) -> Optional[APICommander] ) return api_commander - def _get_driver_commander(self, namespace: Optional[str]) -> APICommander: + def _get_driver_commander(self, namespace: str | None) -> APICommander: """ Building on _get_api_commander, fall back to class namespace in creating/returning a commander, and in any case raise an error if not set. """ - driver_commander: Optional[APICommander] + driver_commander: APICommander | None if namespace: driver_commander = self._get_api_commander(namespace=namespace) else: @@ -294,14 +294,14 @@ def _get_driver_commander(self, namespace: Optional[str]) -> APICommander: def _copy( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> Database: return Database( api_endpoint=api_endpoint or self.api_endpoint, @@ -317,9 +317,9 @@ def _copy( def with_options( self, *, - 
namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> Database: """ Create a clone of this database with some changed attributes. @@ -352,14 +352,14 @@ def with_options( def to_async( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: """ Create an AsyncDatabase from this one. Save for the arguments @@ -407,8 +407,8 @@ def to_async( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -523,7 +523,7 @@ def name(self) -> str: return self._name @property - def namespace(self) -> Optional[str]: + def namespace(self) -> str | None: """ The namespace this database uses as target for all commands when no method-call-specific namespace is specified. 
@@ -542,9 +542,9 @@ def get_collection( self, name: str, *, - namespace: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, + namespace: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, ) -> Collection: """ Spawn a `Collection` object instance representing a collection @@ -618,17 +618,17 @@ def create_collection( self, name: str, *, - namespace: Optional[str] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - service: Optional[Union[CollectionVectorServiceOptions, Dict[str, Any]]] = None, - indexing: Optional[Dict[str, Any]] = None, - default_id_type: Optional[str] = None, - additional_options: Optional[Dict[str, Any]] = None, - check_exists: Optional[bool] = None, - max_time_ms: Optional[int] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, + namespace: str | None = None, + dimension: int | None = None, + metric: str | None = None, + service: CollectionVectorServiceOptions | dict[str, Any] | None = None, + indexing: dict[str, Any] | None = None, + default_id_type: str | None = None, + additional_options: dict[str, Any] | None = None, + check_exists: bool | None = None, + max_time_ms: int | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, ) -> Collection: """ Creates a collection on the database and return the Collection @@ -752,10 +752,10 @@ def create_collection( def drop_collection( self, - name_or_collection: Union[str, Collection], + name_or_collection: str | Collection, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop a collection from the database, along with all documents therein. 
@@ -783,7 +783,7 @@ def drop_collection( # lazy importing here against circular-import error from astrapy.collection import Collection - _namespace: Optional[str] + _namespace: str | None _collection_name: str if isinstance(name_or_collection, Collection): _namespace = name_or_collection.namespace @@ -804,8 +804,8 @@ def drop_collection( def list_collections( self, *, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, + namespace: str | None = None, + max_time_ms: int | None = None, ) -> CommandCursor[CollectionDescriptor]: """ List all collections in a given namespace for this database. @@ -857,9 +857,9 @@ def list_collections( def list_collection_names( self, *, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, - ) -> List[str]: + namespace: str | None = None, + max_time_ms: int | None = None, + ) -> list[str]: """ List the names of all collections in a given namespace of this database. @@ -877,7 +877,7 @@ def list_collection_names( """ driver_commander = self._get_driver_commander(namespace=namespace) - gc_payload: Dict[str, Any] = {"findCollections": {}} + gc_payload: dict[str, Any] = {"findCollections": {}} logger.info("findCollections") gc_response = driver_commander.request( payload=gc_payload, @@ -895,13 +895,13 @@ def list_collection_names( def command( self, - body: Dict[str, Any], + body: dict[str, Any], *, - namespace: Optional[str] = None, - collection_name: Optional[str] = None, + namespace: str | None = None, + collection_name: str | None = None, raise_api_errors: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Send a POST request to the Data API for this database with an arbitrary, caller-provided payload. 
@@ -959,9 +959,9 @@ def command( def get_database_admin( self, *, - token: Optional[Union[str, TokenProvider]] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> DatabaseAdmin: """ Return a DatabaseAdmin object corresponding to this database, for @@ -1085,19 +1085,19 @@ class AsyncDatabase: def __init__( self, api_endpoint: str, - token: Optional[Union[str, TokenProvider]] = None, + token: str | TokenProvider | None = None, *, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> None: self.environment = (environment or Environment.PROD).lower() # - _api_path: Optional[str] - _api_version: Optional[str] + _api_path: str | None + _api_version: str | None if api_path is None: _api_path = API_PATH_ENV_MAP[self.environment] else: @@ -1112,7 +1112,7 @@ def __init__( self.api_version = _api_version # enforce defaults if on Astra DB: - self.using_namespace: Optional[str] + self.using_namespace: str | None if namespace is None and self.environment in Environment.astra_db_values: self.using_namespace = DEFAULT_ASTRA_DB_NAMESPACE else: @@ -1125,7 +1125,7 @@ def __init__( self.caller_name = caller_name self.caller_version = caller_version self._api_commander = self._get_api_commander(namespace=self.namespace) - self._name: Optional[str] = None + self._name: str | None = None def __getattr__(self, collection_name: str) -> AsyncCollection: return self.to_sync().get_collection(name=collection_name).to_async() @@ -1135,12 +1135,12 @@ def 
__getitem__(self, collection_name: str) -> AsyncCollection: def __repr__(self) -> str: ep_desc = f'api_endpoint="{self.api_endpoint}"' - token_desc: Optional[str] + token_desc: str | None if self.token_provider: token_desc = f'token="{redact_secret(str(self.token_provider), 15)}"' else: token_desc = None - namespace_desc: Optional[str] + namespace_desc: str | None if self.namespace is None: namespace_desc = "namespace not set" else: @@ -1165,7 +1165,7 @@ def __eq__(self, other: Any) -> bool: else: return False - def _get_api_commander(self, namespace: Optional[str]) -> Optional[APICommander]: + def _get_api_commander(self, namespace: str | None) -> APICommander | None: """ Instantiate a new APICommander based on the properties of this class and a provided namespace. @@ -1194,12 +1194,12 @@ def _get_api_commander(self, namespace: Optional[str]) -> Optional[APICommander] ) return api_commander - def _get_driver_commander(self, namespace: Optional[str]) -> APICommander: + def _get_driver_commander(self, namespace: str | None) -> APICommander: """ Building on _get_api_commander, fall back to class namespace in creating/returning a commander, and in any case raise an error if not set. 
""" - driver_commander: Optional[APICommander] + driver_commander: APICommander | None if namespace: driver_commander = self._get_api_commander(namespace=namespace) else: @@ -1216,9 +1216,9 @@ async def __aenter__(self) -> AsyncDatabase: async def __aexit__( self, - exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - traceback: Optional[TracebackType] = None, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, ) -> None: if self._api_commander is not None: await self._api_commander.__aexit__( @@ -1230,14 +1230,14 @@ async def __aexit__( def _copy( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> AsyncDatabase: return AsyncDatabase( api_endpoint=api_endpoint or self.api_endpoint, @@ -1253,9 +1253,9 @@ def _copy( def with_options( self, *, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> AsyncDatabase: """ Create a clone of this database with some changed attributes. 
@@ -1288,14 +1288,14 @@ def with_options( def to_sync( self, *, - api_endpoint: Optional[str] = None, - token: Optional[Union[str, TokenProvider]] = None, - namespace: Optional[str] = None, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, - environment: Optional[str] = None, - api_path: Optional[str] = None, - api_version: Optional[str] = None, + api_endpoint: str | None = None, + token: str | TokenProvider | None = None, + namespace: str | None = None, + caller_name: str | None = None, + caller_version: str | None = None, + environment: str | None = None, + api_path: str | None = None, + api_version: str | None = None, ) -> Database: """ Create a (synchronous) Database from this one. Save for the arguments @@ -1344,8 +1344,8 @@ def to_sync( def set_caller( self, - caller_name: Optional[str] = None, - caller_version: Optional[str] = None, + caller_name: str | None = None, + caller_version: str | None = None, ) -> None: """ Set a new identity for the application/framework on behalf of which @@ -1460,7 +1460,7 @@ def name(self) -> str: return self._name @property - def namespace(self) -> Optional[str]: + def namespace(self) -> str | None: """ The namespace this database uses as target for all commands when no method-call-specific namespace is specified. 
@@ -1479,9 +1479,9 @@ async def get_collection( self, name: str, *, - namespace: Optional[str] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, + namespace: str | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, ) -> AsyncCollection: """ Spawn an `AsyncCollection` object instance representing a collection @@ -1558,17 +1558,17 @@ async def create_collection( self, name: str, *, - namespace: Optional[str] = None, - dimension: Optional[int] = None, - metric: Optional[str] = None, - service: Optional[Union[CollectionVectorServiceOptions, Dict[str, Any]]] = None, - indexing: Optional[Dict[str, Any]] = None, - default_id_type: Optional[str] = None, - additional_options: Optional[Dict[str, Any]] = None, - check_exists: Optional[bool] = None, - max_time_ms: Optional[int] = None, - embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, - collection_max_time_ms: Optional[int] = None, + namespace: str | None = None, + dimension: int | None = None, + metric: str | None = None, + service: CollectionVectorServiceOptions | dict[str, Any] | None = None, + indexing: dict[str, Any] | None = None, + default_id_type: str | None = None, + additional_options: dict[str, Any] | None = None, + check_exists: bool | None = None, + max_time_ms: int | None = None, + embedding_api_key: str | EmbeddingHeadersProvider | None = None, + collection_max_time_ms: int | None = None, ) -> AsyncCollection: """ Creates a collection on the database and return the AsyncCollection @@ -1696,10 +1696,10 @@ async def create_collection( async def drop_collection( self, - name_or_collection: Union[str, AsyncCollection], + name_or_collection: str | AsyncCollection, *, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Drop a collection from the database, along with all 
documents therein. @@ -1727,7 +1727,7 @@ async def drop_collection( # lazy importing here against circular-import error from astrapy.collection import AsyncCollection - _namespace: Optional[str] + _namespace: str | None _collection_name: str if isinstance(name_or_collection, AsyncCollection): _namespace = name_or_collection.namespace @@ -1748,8 +1748,8 @@ async def drop_collection( def list_collections( self, *, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, + namespace: str | None = None, + max_time_ms: int | None = None, ) -> AsyncCommandCursor[CollectionDescriptor]: """ List all collections in a given namespace for this database. @@ -1803,9 +1803,9 @@ def list_collections( async def list_collection_names( self, *, - namespace: Optional[str] = None, - max_time_ms: Optional[int] = None, - ) -> List[str]: + namespace: str | None = None, + max_time_ms: int | None = None, + ) -> list[str]: """ List the names of all collections in a given namespace of this database. @@ -1823,7 +1823,7 @@ async def list_collection_names( """ driver_commander = self._get_driver_commander(namespace=namespace) - gc_payload: Dict[str, Any] = {"findCollections": {}} + gc_payload: dict[str, Any] = {"findCollections": {}} logger.info("findCollections") gc_response = await driver_commander.async_request( payload=gc_payload, @@ -1841,13 +1841,13 @@ async def list_collection_names( async def command( self, - body: Dict[str, Any], + body: dict[str, Any], *, - namespace: Optional[str] = None, - collection_name: Optional[str] = None, + namespace: str | None = None, + collection_name: str | None = None, raise_api_errors: bool = True, - max_time_ms: Optional[int] = None, - ) -> Dict[str, Any]: + max_time_ms: int | None = None, + ) -> dict[str, Any]: """ Send a POST request to the Data API for this database with an arbitrary, caller-provided payload. 
@@ -1909,9 +1909,9 @@ async def command( def get_database_admin( self, *, - token: Optional[Union[str, TokenProvider]] = None, - dev_ops_url: Optional[str] = None, - dev_ops_api_version: Optional[str] = None, + token: str | TokenProvider | None = None, + dev_ops_url: str | None = None, + dev_ops_api_version: str | None = None, ) -> DatabaseAdmin: """ Return a DatabaseAdmin object corresponding to this database, for diff --git a/astrapy/exceptions.py b/astrapy/exceptions.py index 77d1d438..9ca1fe90 100644 --- a/astrapy/exceptions.py +++ b/astrapy/exceptions.py @@ -16,7 +16,7 @@ import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any import httpx @@ -55,15 +55,15 @@ class DevOpsAPIHttpException(DevOpsAPIException, httpx.HTTPStatusError): found in the response. """ - text: Optional[str] - error_descriptors: List[DevOpsAPIErrorDescriptor] + text: str | None + error_descriptors: list[DevOpsAPIErrorDescriptor] def __init__( self, - text: Optional[str], + text: str | None, *, httpx_error: httpx.HTTPStatusError, - error_descriptors: List[DevOpsAPIErrorDescriptor], + error_descriptors: list[DevOpsAPIErrorDescriptor], ) -> None: DataAPIException.__init__(self, text) httpx.HTTPStatusError.__init__( @@ -87,7 +87,7 @@ def from_httpx_error( ) -> DevOpsAPIHttpException: """Parse a httpx status error into this exception.""" - raw_response: Dict[str, Any] + raw_response: dict[str, Any] # the attempt to extract a response structure cannot afford failure. 
try: raw_response = httpx_error.response.json() @@ -128,16 +128,16 @@ class DevOpsAPITimeoutException(DevOpsAPIException): text: str timeout_type: str - endpoint: Optional[str] - raw_payload: Optional[str] + endpoint: str | None + raw_payload: str | None def __init__( self, text: str, *, timeout_type: str, - endpoint: Optional[str], - raw_payload: Optional[str], + endpoint: str | None, + raw_payload: str | None, ) -> None: super().__init__(text) self.text = text @@ -160,11 +160,11 @@ class DevOpsAPIErrorDescriptor: attributes: a dict with any further key-value pairs returned by the API. """ - id: Optional[int] - message: Optional[str] - attributes: Dict[str, Any] + id: int | None + message: str | None + attributes: dict[str, Any] - def __init__(self, error_dict: Dict[str, Any]) -> None: + def __init__(self, error_dict: dict[str, Any]) -> None: self.id = error_dict.get("ID") self.message = error_dict.get("message") self.attributes = { @@ -184,12 +184,12 @@ class DevOpsAPIFaultyResponseException(DevOpsAPIException): """ text: str - raw_response: Optional[Dict[str, Any]] + raw_response: dict[str, Any] | None def __init__( self, text: str, - raw_response: Optional[Dict[str, Any]], + raw_response: dict[str, Any] | None, ) -> None: super().__init__(text) self.text = text @@ -208,16 +208,16 @@ class DevOpsAPIResponseException(DevOpsAPIException): returned by the API in the response. 
""" - text: Optional[str] - command: Optional[Dict[str, Any]] - error_descriptors: List[DevOpsAPIErrorDescriptor] + text: str | None + command: dict[str, Any] | None + error_descriptors: list[DevOpsAPIErrorDescriptor] def __init__( self, - text: Optional[str] = None, + text: str | None = None, *, - command: Optional[Dict[str, Any]] = None, - error_descriptors: List[DevOpsAPIErrorDescriptor] = [], + command: dict[str, Any] | None = None, + error_descriptors: list[DevOpsAPIErrorDescriptor] = [], ) -> None: super().__init__(text or self.__class__.__name__) self.text = text @@ -226,8 +226,8 @@ def __init__( @staticmethod def from_response( - command: Optional[Dict[str, Any]], - raw_response: Dict[str, Any], + command: dict[str, Any] | None, + raw_response: dict[str, Any], ) -> DevOpsAPIResponseException: """Parse a raw response from the API into this exception.""" @@ -263,13 +263,13 @@ class DataAPIErrorDescriptor: attributes: a dict with any further key-value pairs returned by the API. """ - title: Optional[str] - error_code: Optional[str] - message: Optional[str] - family: Optional[str] - scope: Optional[str] - id: Optional[str] - attributes: Dict[str, Any] + title: str | None + error_code: str | None + message: str | None + family: str | None + scope: str | None + id: str | None + attributes: dict[str, Any] _known_dict_fields = { "title", @@ -280,7 +280,7 @@ class DataAPIErrorDescriptor: "id", } - def __init__(self, error_dict: Dict[str, str]) -> None: + def __init__(self, error_dict: dict[str, str]) -> None: self.title = error_dict.get("title") self.error_code = error_dict.get("errorCode") self.message = error_dict.get("message") @@ -322,9 +322,9 @@ class DataAPIDetailedErrorDescriptor: raw_response: the full API response in the form of a dict. 
""" - error_descriptors: List[DataAPIErrorDescriptor] - command: Optional[Dict[str, Any]] - raw_response: Dict[str, Any] + error_descriptors: list[DataAPIErrorDescriptor] + command: dict[str, Any] | None + raw_response: dict[str, Any] class DataAPIException(ValueError): @@ -356,15 +356,15 @@ class DataAPIHttpException(DataAPIException, httpx.HTTPStatusError): found in the response. """ - text: Optional[str] - error_descriptors: List[DataAPIErrorDescriptor] + text: str | None + error_descriptors: list[DataAPIErrorDescriptor] def __init__( self, - text: Optional[str], + text: str | None, *, httpx_error: httpx.HTTPStatusError, - error_descriptors: List[DataAPIErrorDescriptor], + error_descriptors: list[DataAPIErrorDescriptor], ) -> None: DataAPIException.__init__(self, text) httpx.HTTPStatusError.__init__( @@ -388,7 +388,7 @@ def from_httpx_error( ) -> DataAPIHttpException: """Parse a httpx status error into this exception.""" - raw_response: Dict[str, Any] + raw_response: dict[str, Any] # the attempt to extract a response structure cannot afford failure. try: raw_response = httpx_error.response.json() @@ -431,16 +431,16 @@ class DataAPITimeoutException(DataAPIException): text: str timeout_type: str - endpoint: Optional[str] - raw_payload: Optional[str] + endpoint: str | None + raw_payload: str | None def __init__( self, text: str, *, timeout_type: str, - endpoint: Optional[str], - raw_payload: Optional[str], + endpoint: str | None, + raw_payload: str | None, ) -> None: super().__init__(text) self.text = text @@ -573,12 +573,12 @@ class DataAPIFaultyResponseException(DataAPIException): """ text: str - raw_response: Optional[Dict[str, Any]] + raw_response: dict[str, Any] | None def __init__( self, text: str, - raw_response: Optional[Dict[str, Any]], + raw_response: dict[str, Any] | None, ) -> None: super().__init__(text) self.text = text @@ -608,16 +608,16 @@ class DataAPIResponseException(DataAPIException): has a single element. 
""" - text: Optional[str] - error_descriptors: List[DataAPIErrorDescriptor] - detailed_error_descriptors: List[DataAPIDetailedErrorDescriptor] + text: str | None + error_descriptors: list[DataAPIErrorDescriptor] + detailed_error_descriptors: list[DataAPIDetailedErrorDescriptor] def __init__( self, - text: Optional[str], + text: str | None, *, - error_descriptors: List[DataAPIErrorDescriptor], - detailed_error_descriptors: List[DataAPIDetailedErrorDescriptor], + error_descriptors: list[DataAPIErrorDescriptor], + detailed_error_descriptors: list[DataAPIDetailedErrorDescriptor], ) -> None: super().__init__(text) self.text = text @@ -627,8 +627,8 @@ def __init__( @classmethod def from_response( cls, - command: Optional[Dict[str, Any]], - raw_response: Dict[str, Any], + command: dict[str, Any] | None, + raw_response: dict[str, Any], **kwargs: Any, ) -> DataAPIResponseException: """Parse a raw response from the API into this exception.""" @@ -642,13 +642,13 @@ def from_response( @classmethod def from_responses( cls, - commands: List[Optional[Dict[str, Any]]], - raw_responses: List[Dict[str, Any]], + commands: list[dict[str, Any] | None], + raw_responses: list[dict[str, Any]], **kwargs: Any, ) -> DataAPIResponseException: """Parse a list of raw responses from the API into this exception.""" - detailed_error_descriptors: List[DataAPIDetailedErrorDescriptor] = [] + detailed_error_descriptors: list[DataAPIDetailedErrorDescriptor] = [] for command, raw_response in zip(commands, raw_responses): if raw_response.get("errors", []): error_descriptors = [ @@ -832,13 +832,13 @@ class BulkWriteException(DataAPIResponseException): """ partial_result: BulkWriteResult - exceptions: List[DataAPIResponseException] + exceptions: list[DataAPIResponseException] def __init__( self, - text: Optional[str], + text: str | None, partial_result: BulkWriteResult, - exceptions: List[DataAPIResponseException], + exceptions: list[DataAPIResponseException], *pargs: Any, **kwargs: Any, ) -> None: @@ 
-910,7 +910,7 @@ def to_devopsapi_timeout_exception( ) -def base_timeout_info(max_time_ms: Optional[int]) -> Union[TimeoutInfo, None]: +def base_timeout_info(max_time_ms: int | None) -> TimeoutInfo | None: if max_time_ms is not None: return {"base": max_time_ms / 1000.0} else: @@ -931,12 +931,12 @@ class MultiCallTimeoutManager: deadline_ms: optional deadline in milliseconds (computed by the class). """ - overall_max_time_ms: Optional[int] + overall_max_time_ms: int | None started_ms: int = -1 - deadline_ms: Optional[int] + deadline_ms: int | None def __init__( - self, overall_max_time_ms: Optional[int], dev_ops_api: bool = False + self, overall_max_time_ms: int | None, dev_ops_api: bool = False ) -> None: self.started_ms = int(time.time() * 1000) self.overall_max_time_ms = overall_max_time_ms @@ -946,7 +946,7 @@ def __init__( else: self.deadline_ms = None - def remaining_timeout_ms(self) -> Union[int, None]: + def remaining_timeout_ms(self) -> int | None: """ Ensure the deadline, if any, is not yet in the past. If it is, raise an appropriate timeout error. @@ -975,7 +975,7 @@ def remaining_timeout_ms(self) -> Union[int, None]: else: return None - def remaining_timeout_info(self) -> Union[TimeoutInfo, None]: + def remaining_timeout_info(self) -> TimeoutInfo | None: """ Ensure the deadline, if any, is not yet in the past. If it is, raise an appropriate timeout error. 
diff --git a/astrapy/info.py b/astrapy/info.py index 685bb6b7..05ad651b 100644 --- a/astrapy/info.py +++ b/astrapy/info.py @@ -16,7 +16,7 @@ import warnings from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any @dataclass @@ -51,10 +51,10 @@ class DatabaseInfo: id: str region: str - namespace: Optional[str] + namespace: str | None name: str environment: str - raw_info: Optional[Dict[str, Any]] + raw_info: dict[str, Any] | None @dataclass @@ -90,8 +90,8 @@ class AdminDatabaseInfo: """ info: DatabaseInfo - available_actions: Optional[List[str]] - cost: Dict[str, Any] + available_actions: list[str] | None + cost: dict[str, Any] cqlsh_url: str creation_time: str data_endpoint_url: str @@ -99,14 +99,14 @@ class AdminDatabaseInfo: graphql_url: str id: str last_usage_time: str - metrics: Dict[str, Any] + metrics: dict[str, Any] observed_status: str org_id: str owner_id: str status: str - storage: Dict[str, Any] + storage: dict[str, Any] termination_time: str - raw_info: Optional[Dict[str, Any]] + raw_info: dict[str, Any] | None @dataclass @@ -141,15 +141,13 @@ class CollectionDefaultIDOptions: default_id_type: str - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return {"type": self.default_id_type} @staticmethod - def from_dict( - raw_dict: Optional[Dict[str, Any]] - ) -> Optional[CollectionDefaultIDOptions]: + def from_dict(raw_dict: dict[str, Any] | None) -> CollectionDefaultIDOptions | None: """ Create an instance of CollectionDefaultIDOptions from a dictionary such as one from the Data API. @@ -176,12 +174,12 @@ class CollectionVectorServiceOptions: in the vector service options. 
""" - provider: Optional[str] - model_name: Optional[str] - authentication: Optional[Dict[str, Any]] = None - parameters: Optional[Dict[str, Any]] = None + provider: str | None + model_name: str | None + authentication: dict[str, Any] | None = None + parameters: dict[str, Any] | None = None - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -197,8 +195,8 @@ def as_dict(self) -> Dict[str, Any]: @staticmethod def from_dict( - raw_dict: Optional[Dict[str, Any]] - ) -> Optional[CollectionVectorServiceOptions]: + raw_dict: dict[str, Any] | None, + ) -> CollectionVectorServiceOptions | None: """ Create an instance of CollectionVectorServiceOptions from a dictionary such as one from the Data API. @@ -229,11 +227,11 @@ class CollectionVectorOptions: service is configured for the collection. """ - dimension: Optional[int] - metric: Optional[str] - service: Optional[CollectionVectorServiceOptions] + dimension: int | None + metric: str | None + service: CollectionVectorServiceOptions | None - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -247,9 +245,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict( - raw_dict: Optional[Dict[str, Any]] - ) -> Optional[CollectionVectorOptions]: + def from_dict(raw_dict: dict[str, Any] | None) -> CollectionVectorOptions | None: """ Create an instance of CollectionVectorOptions from a dictionary such as one from the Data API. @@ -280,10 +276,10 @@ class CollectionOptions: raw_options: the raw response from the Data API for the collection configuration. 
""" - vector: Optional[CollectionVectorOptions] - indexing: Optional[Dict[str, Any]] - default_id: Optional[CollectionDefaultIDOptions] - raw_options: Optional[Dict[str, Any]] + vector: CollectionVectorOptions | None + indexing: dict[str, Any] | None + default_id: CollectionDefaultIDOptions | None + raw_options: dict[str, Any] | None def __repr__(self) -> str: not_null_pieces = [ @@ -306,7 +302,7 @@ def __repr__(self) -> str: ] return f"{self.__class__.__name__}({', '.join(not_null_pieces)})" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -321,17 +317,17 @@ def as_dict(self) -> Dict[str, Any]: if v is not None } - def flatten(self) -> Dict[str, Any]: + def flatten(self) -> dict[str, Any]: """ Recast this object as a flat key-value pair suitable for use as kwargs in a create_collection method call (including recasts). """ - _dimension: Optional[int] - _metric: Optional[str] - _indexing: Optional[Dict[str, Any]] - _service: Optional[Dict[str, Any]] - _default_id_type: Optional[str] + _dimension: int | None + _metric: str | None + _indexing: dict[str, Any] | None + _service: dict[str, Any] | None + _default_id_type: str | None if self.vector is not None: _dimension = self.vector.dimension _metric = self.vector.metric @@ -362,7 +358,7 @@ def flatten(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> CollectionOptions: + def from_dict(raw_dict: dict[str, Any]) -> CollectionOptions: """ Create an instance of CollectionOptions from a dictionary such as one from the Data API. 
@@ -390,7 +386,7 @@ class CollectionDescriptor: name: str options: CollectionOptions - raw_descriptor: Optional[Dict[str, Any]] + raw_descriptor: dict[str, Any] | None def __repr__(self) -> str: not_null_pieces = [ @@ -404,7 +400,7 @@ def __repr__(self) -> str: ] return f"{self.__class__.__name__}({', '.join(not_null_pieces)})" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """ Recast this object into a dictionary. Empty `options` will not be returned at all. @@ -419,7 +415,7 @@ def as_dict(self) -> Dict[str, Any]: if v } - def flatten(self) -> Dict[str, Any]: + def flatten(self) -> dict[str, Any]: """ Recast this object as a flat key-value pair suitable for use as kwargs in a create_collection method call (including recasts). @@ -431,7 +427,7 @@ def flatten(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> CollectionDescriptor: + def from_dict(raw_dict: dict[str, Any]) -> CollectionDescriptor: """ Create an instance of CollectionDescriptor from a dictionary such as one from the Data API. @@ -460,18 +456,18 @@ class EmbeddingProviderParameter: """ default_value: Any - display_name: Optional[str] - help: Optional[str] - hint: Optional[str] + display_name: str | None + help: str | None + hint: str | None name: str required: bool parameter_type: str - validation: Dict[str, Any] + validation: dict[str, Any] def __repr__(self) -> str: return f"EmbeddingProviderParameter(name='{self.name}')" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -490,7 +486,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderParameter: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProviderParameter: """ Create an instance of EmbeddingProviderParameter from a dictionary such as one from the Data API. 
@@ -539,13 +535,13 @@ class EmbeddingProviderModel: """ name: str - parameters: List[EmbeddingProviderParameter] - vector_dimension: Optional[int] + parameters: list[EmbeddingProviderParameter] + vector_dimension: int | None def __repr__(self) -> str: return f"EmbeddingProviderModel(name='{self.name}')" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -555,7 +551,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderModel: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProviderModel: """ Create an instance of EmbeddingProviderModel from a dictionary such as one from the Data API. @@ -602,7 +598,7 @@ class EmbeddingProviderToken: def __repr__(self) -> str: return f"EmbeddingProviderToken('{self.accepted}')" - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -611,7 +607,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderToken: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProviderToken: """ Create an instance of EmbeddingProviderToken from a dictionary such as one from the Data API. 
@@ -646,7 +642,7 @@ class EmbeddingProviderAuthentication: """ enabled: bool - tokens: List[EmbeddingProviderToken] + tokens: list[EmbeddingProviderToken] def __repr__(self) -> str: return ( @@ -654,7 +650,7 @@ def __repr__(self) -> str: f"tokens={','.join(str(token) for token in self.tokens)})" ) - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -663,7 +659,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProviderAuthentication: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProviderAuthentication: """ Create an instance of EmbeddingProviderAuthentication from a dictionary such as one from the Data API. @@ -710,13 +706,13 @@ class EmbeddingProvider: def __repr__(self) -> str: return f"EmbeddingProvider(display_name='{self.display_name}', models={self.models})" - display_name: Optional[str] - models: List[EmbeddingProviderModel] - parameters: List[EmbeddingProviderParameter] - supported_authentication: Dict[str, EmbeddingProviderAuthentication] - url: Optional[str] + display_name: str | None + models: list[EmbeddingProviderModel] + parameters: list[EmbeddingProviderParameter] + supported_authentication: dict[str, EmbeddingProviderAuthentication] + url: str | None - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -731,7 +727,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProvider: + def from_dict(raw_dict: dict[str, Any]) -> EmbeddingProvider: """ Create an instance of EmbeddingProvider from a dictionary such as one from the Data API. 
@@ -784,10 +780,10 @@ def __repr__(self) -> str: f"{', '.join(sorted(self.embedding_providers.keys()))})" ) - embedding_providers: Dict[str, EmbeddingProvider] - raw_info: Optional[Dict[str, Any]] + embedding_providers: dict[str, EmbeddingProvider] + raw_info: dict[str, Any] | None - def as_dict(self) -> Dict[str, Any]: + def as_dict(self) -> dict[str, Any]: """Recast this object into a dictionary.""" return { @@ -798,7 +794,7 @@ def as_dict(self) -> Dict[str, Any]: } @staticmethod - def from_dict(raw_dict: Dict[str, Any]) -> FindEmbeddingProvidersResult: + def from_dict(raw_dict: dict[str, Any]) -> FindEmbeddingProvidersResult: """ Create an instance of FindEmbeddingProvidersResult from a dictionary such as one from the Data API. diff --git a/astrapy/operations.py b/astrapy/operations.py index f3101130..ab18850d 100644 --- a/astrapy/operations.py +++ b/astrapy/operations.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from functools import reduce -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Iterable from astrapy.collection import AsyncCollection, Collection from astrapy.constants import DocumentType, SortType, VectorType @@ -31,7 +31,7 @@ ) -def reduce_bulk_write_results(results: List[BulkWriteResult]) -> BulkWriteResult: +def reduce_bulk_write_results(results: list[BulkWriteResult]) -> BulkWriteResult: """ Reduce a list of bulk write results into a single one. @@ -79,7 +79,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: ... 
@@ -100,15 +100,15 @@ class InsertOne(BaseOperation): """ document: DocumentType - vector: Optional[VectorType] - vectorize: Optional[str] + vector: VectorType | None + vectorize: str | None def __init__( self, document: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, + vector: VectorType | None = None, + vectorize: str | None = None, ) -> None: self.document = document check_deprecated_vector_ize( @@ -125,7 +125,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -166,21 +166,21 @@ class InsertMany(BaseOperation): """ documents: Iterable[DocumentType] - vectors: Optional[Iterable[Optional[VectorType]]] - vectorize: Optional[Iterable[Optional[str]]] + vectors: Iterable[VectorType | None] | None + vectorize: Iterable[str | None] | None ordered: bool - chunk_size: Optional[int] - concurrency: Optional[int] + chunk_size: int | None + concurrency: int | None def __init__( self, documents: Iterable[DocumentType], *, - vectors: Optional[Iterable[Optional[VectorType]]] = None, - vectorize: Optional[Iterable[Optional[str]]] = None, + vectors: Iterable[VectorType | None] | None = None, + vectorize: Iterable[str | None] | None = None, ordered: bool = True, - chunk_size: Optional[int] = None, - concurrency: Optional[int] = None, + chunk_size: int | None = None, + concurrency: int | None = None, ) -> None: self.documents = documents self.ordered = ordered @@ -200,7 +200,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -243,21 +243,21 @@ class UpdateOne(BaseOperation): upsert: controls what to do when no documents are found. 
""" - filter: Dict[str, Any] - update: Dict[str, Any] - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + filter: dict[str, Any] + update: dict[str, Any] + vector: VectorType | None + vectorize: str | None + sort: SortType | None upsert: bool def __init__( self, - filter: Dict[str, Any], - update: Dict[str, Any], + filter: dict[str, Any], + update: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, ) -> None: self.filter = filter @@ -278,7 +278,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -312,14 +312,14 @@ class UpdateMany(BaseOperation): upsert: controls what to do when no documents are found. """ - filter: Dict[str, Any] - update: Dict[str, Any] + filter: dict[str, Any] + update: dict[str, Any] upsert: bool def __init__( self, - filter: Dict[str, Any], - update: Dict[str, Any], + filter: dict[str, Any], + update: dict[str, Any], *, upsert: bool = False, ) -> None: @@ -331,7 +331,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -371,21 +371,21 @@ class ReplaceOne(BaseOperation): upsert: controls what to do when no documents are found. 
""" - filter: Dict[str, Any] + filter: dict[str, Any] replacement: DocumentType - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + vector: VectorType | None + vectorize: str | None + sort: SortType | None upsert: bool def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], replacement: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, ) -> None: self.filter = filter @@ -406,7 +406,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -447,18 +447,18 @@ class DeleteOne(BaseOperation): sort: controls ordering of results, hence which document is affected. """ - filter: Dict[str, Any] - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + filter: dict[str, Any] + vector: VectorType | None + vectorize: str | None + sort: SortType | None def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, ) -> None: self.filter = filter check_deprecated_vector_ize( @@ -476,7 +476,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -506,11 +506,11 @@ class DeleteMany(BaseOperation): filter: a filter condition to select target documents. 
""" - filter: Dict[str, Any] + filter: dict[str, Any] def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], ) -> None: self.filter = filter @@ -518,7 +518,7 @@ def execute( self, collection: Collection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -545,7 +545,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: ... @@ -566,15 +566,15 @@ class AsyncInsertOne(AsyncBaseOperation): """ document: DocumentType - vector: Optional[VectorType] - vectorize: Optional[str] + vector: VectorType | None + vectorize: str | None def __init__( self, document: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, + vector: VectorType | None = None, + vectorize: str | None = None, ) -> None: self.document = document check_deprecated_vector_ize( @@ -592,7 +592,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. 
@@ -633,21 +633,21 @@ class AsyncInsertMany(AsyncBaseOperation): """ documents: Iterable[DocumentType] - vectors: Optional[Iterable[Optional[VectorType]]] - vectorize: Optional[Iterable[Optional[str]]] + vectors: Iterable[VectorType | None] | None + vectorize: Iterable[str | None] | None ordered: bool - chunk_size: Optional[int] - concurrency: Optional[int] + chunk_size: int | None + concurrency: int | None def __init__( self, documents: Iterable[DocumentType], *, - vectors: Optional[Iterable[Optional[VectorType]]] = None, - vectorize: Optional[Iterable[Optional[str]]] = None, + vectors: Iterable[VectorType | None] | None = None, + vectorize: Iterable[str | None] | None = None, ordered: bool = True, - chunk_size: Optional[int] = None, - concurrency: Optional[int] = None, + chunk_size: int | None = None, + concurrency: int | None = None, ) -> None: self.documents = documents check_deprecated_vector_ize( @@ -668,7 +668,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -711,21 +711,21 @@ class AsyncUpdateOne(AsyncBaseOperation): upsert: controls what to do when no documents are found. 
""" - filter: Dict[str, Any] - update: Dict[str, Any] - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + filter: dict[str, Any] + update: dict[str, Any] + vector: VectorType | None + vectorize: str | None + sort: SortType | None upsert: bool def __init__( self, - filter: Dict[str, Any], - update: Dict[str, Any], + filter: dict[str, Any], + update: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, ) -> None: self.filter = filter @@ -747,7 +747,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -781,14 +781,14 @@ class AsyncUpdateMany(AsyncBaseOperation): upsert: controls what to do when no documents are found. """ - filter: Dict[str, Any] - update: Dict[str, Any] + filter: dict[str, Any] + update: dict[str, Any] upsert: bool def __init__( self, - filter: Dict[str, Any], - update: Dict[str, Any], + filter: dict[str, Any], + update: dict[str, Any], *, upsert: bool = False, ) -> None: @@ -800,7 +800,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -840,21 +840,21 @@ class AsyncReplaceOne(AsyncBaseOperation): upsert: controls what to do when no documents are found. 
""" - filter: Dict[str, Any] + filter: dict[str, Any] replacement: DocumentType - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + vector: VectorType | None + vectorize: str | None + sort: SortType | None upsert: bool def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], replacement: DocumentType, *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, upsert: bool = False, ) -> None: self.filter = filter @@ -876,7 +876,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -917,18 +917,18 @@ class AsyncDeleteOne(AsyncBaseOperation): sort: controls ordering of results, hence which document is affected. """ - filter: Dict[str, Any] - vector: Optional[VectorType] - vectorize: Optional[str] - sort: Optional[SortType] + filter: dict[str, Any] + vector: VectorType | None + vectorize: str | None + sort: SortType | None def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], *, - vector: Optional[VectorType] = None, - vectorize: Optional[str] = None, - sort: Optional[SortType] = None, + vector: VectorType | None = None, + vectorize: str | None = None, + sort: SortType | None = None, ) -> None: self.filter = filter check_deprecated_vector_ize( @@ -947,7 +947,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. @@ -977,11 +977,11 @@ class AsyncDeleteMany(AsyncBaseOperation): filter: a filter condition to select target documents. 
""" - filter: Dict[str, Any] + filter: dict[str, Any] def __init__( self, - filter: Dict[str, Any], + filter: dict[str, Any], ) -> None: self.filter = filter @@ -989,7 +989,7 @@ async def execute( self, collection: AsyncCollection, index_in_bulk_write: int, - bulk_write_timeout_ms: Optional[int], + bulk_write_timeout_ms: int | None, ) -> BulkWriteResult: """ Execute this operation against a collection as part of a bulk write. diff --git a/astrapy/request_tools.py b/astrapy/request_tools.py index 8d67812a..3a1a1be6 100644 --- a/astrapy/request_tools.py +++ b/astrapy/request_tools.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, TypedDict, Union +from typing import Any, TypedDict, Union import httpx @@ -27,9 +27,9 @@ def log_httpx_request( http_method: str, full_url: str, - request_params: Optional[Dict[str, Any]], - redacted_request_headers: Dict[str, str], - payload: Optional[Dict[str, Any]], + request_params: dict[str, Any] | None, + redacted_request_headers: dict[str, str], + payload: dict[str, Any] | None, ) -> None: """ Log the details of an HTTP request for debugging purposes. @@ -79,7 +79,7 @@ class TimeoutInfo(TypedDict, total=False): TimeoutInfoWideType = Union[TimeoutInfo, float, None] -def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> Union[httpx.Timeout, None]: +def to_httpx_timeout(timeout_info: TimeoutInfoWideType) -> httpx.Timeout | None: if timeout_info is None: return None if isinstance(timeout_info, float) or isinstance(timeout_info, int): diff --git a/astrapy/results.py b/astrapy/results.py index 7f150af0..e2c5715c 100644 --- a/astrapy/results.py +++ b/astrapy/results.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any @dataclass @@ -30,9 +30,9 @@ class OperationResult(ABC): list of raw responses can contain exactly one or a number of items. 
""" - raw_results: List[Dict[str, Any]] + raw_results: list[dict[str, Any]] - def _piecewise_repr(self, pieces: List[Optional[str]]) -> str: + def _piecewise_repr(self, pieces: list[str | None]) -> str: return f"{self.__class__.__name__}({', '.join(pc for pc in pieces if pc)})" @abstractmethod @@ -51,7 +51,7 @@ class DeleteResult(OperationResult): list of raw responses can contain exactly one or a number of items. """ - deleted_count: Optional[int] + deleted_count: int | None def __repr__(self) -> str: return self._piecewise_repr( @@ -115,7 +115,7 @@ class InsertManyResult(OperationResult): inserted_ids: list of the IDs of the inserted documents """ - inserted_ids: List[Any] + inserted_ids: list[Any] def __repr__(self) -> str: return self._piecewise_repr( @@ -153,7 +153,7 @@ class UpdateResult(OperationResult): """ - update_info: Dict[str, Any] + update_info: dict[str, Any] def __repr__(self) -> str: return self._piecewise_repr( @@ -201,13 +201,13 @@ class BulkWriteResult: upserted_ids: a (sparse) map from indices to ID of the upserted document """ - bulk_api_results: Dict[int, List[Dict[str, Any]]] - deleted_count: Optional[int] + bulk_api_results: dict[int, list[dict[str, Any]]] + deleted_count: int | None inserted_count: int matched_count: int modified_count: int upserted_count: int - upserted_ids: Dict[int, Any] + upserted_ids: dict[int, Any] def __repr__(self) -> str: pieces = [ diff --git a/astrapy/transform_payload.py b/astrapy/transform_payload.py index 9c2f78a1..e3bda7f3 100644 --- a/astrapy/transform_payload.py +++ b/astrapy/transform_payload.py @@ -16,13 +16,13 @@ import datetime import time -from typing import Any, Dict, Iterable, List, Union, cast +from typing import Any, Dict, Iterable, cast from astrapy.constants import DocumentType from astrapy.ids import UUID, ObjectId -def convert_vector_to_floats(vector: Iterable[Any]) -> List[float]: +def convert_vector_to_floats(vector: Iterable[Any]) -> list[float]: """ Convert a vector of strings to a vector 
of floats. @@ -46,36 +46,36 @@ def is_list_of_floats(vector: Iterable[Any]) -> bool: def convert_to_ejson_date_object( - date_value: Union[datetime.date, datetime.datetime] -) -> Dict[str, int]: + date_value: datetime.date | datetime.datetime, +) -> dict[str, int]: return {"$date": int(time.mktime(date_value.timetuple()) * 1000)} -def convert_to_ejson_uuid_object(uuid_value: UUID) -> Dict[str, str]: +def convert_to_ejson_uuid_object(uuid_value: UUID) -> dict[str, str]: return {"$uuid": str(uuid_value)} -def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> Dict[str, str]: +def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> dict[str, str]: return {"$objectId": str(objectid_value)} def convert_ejson_date_object_to_datetime( - date_object: Dict[str, int] + date_object: dict[str, int], ) -> datetime.datetime: return datetime.datetime.fromtimestamp(date_object["$date"] / 1000.0) -def convert_ejson_uuid_object_to_uuid(uuid_object: Dict[str, str]) -> UUID: +def convert_ejson_uuid_object_to_uuid(uuid_object: dict[str, str]) -> UUID: return UUID(uuid_object["$uuid"]) def convert_ejson_objectid_object_to_objectid( - objectid_object: Dict[str, str] + objectid_object: dict[str, str], ) -> ObjectId: return ObjectId(objectid_object["$objectId"]) -def normalize_payload_value(path: List[str], value: Any) -> Any: +def normalize_payload_value(path: list[str], value: Any) -> Any: """ The path helps determining special treatments """ @@ -104,9 +104,7 @@ def normalize_payload_value(path: List[str], value: Any) -> Any: return value -def normalize_for_api( - payload: Union[Dict[str, Any], None] -) -> Union[Dict[str, Any], None]: +def normalize_for_api(payload: dict[str, Any] | None) -> dict[str, Any] | None: """ Normalize a payload for API calls. This includes e.g. 
ensuring values for "$vector" key @@ -125,7 +123,7 @@ def normalize_for_api( return payload -def restore_response_value(path: List[str], value: Any) -> Any: +def restore_response_value(path: list[str], value: Any) -> Any: """ The path helps determining special treatments """ diff --git a/astrapy/user_agents.py b/astrapy/user_agents.py index b17aee60..9c9572a5 100644 --- a/astrapy/user_agents.py +++ b/astrapy/user_agents.py @@ -16,17 +16,16 @@ from importlib import metadata from importlib.metadata import PackageNotFoundError -from typing import List, Optional, Tuple from astrapy import __version__ -def detect_astrapy_user_agent() -> Tuple[Optional[str], Optional[str]]: +def detect_astrapy_user_agent() -> tuple[str | None, str | None]: package_name = __name__.split(".")[0] return (package_name, __version__) -def detect_ragstack_user_agent() -> Tuple[Optional[str], Optional[str]]: +def detect_ragstack_user_agent() -> tuple[str | None, str | None]: try: ragstack_meta = metadata.metadata("ragstack-ai") if ragstack_meta: @@ -38,8 +37,8 @@ def detect_ragstack_user_agent() -> Tuple[Optional[str], Optional[str]]: def compose_user_agent_string( - caller_name: Optional[str], caller_version: Optional[str] -) -> Optional[str]: + caller_name: str | None, caller_version: str | None +) -> str | None: if caller_name: if caller_version: return f"{caller_name}/{caller_version}" @@ -49,9 +48,7 @@ def compose_user_agent_string( return None -def compose_full_user_agent( - callers: List[Tuple[Optional[str], Optional[str]]] -) -> Optional[str]: +def compose_full_user_agent(callers: list[tuple[str | None, str | None]]) -> str | None: user_agent_strings = [ ua_string for ua_string in ( diff --git a/pyproject.toml b/pyproject.toml index f9902821..19091059 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +[project] +requires-python = ">=3.8" + [tool.poetry] name = "astrapy" version = "1.5.0" @@ -27,7 +30,6 @@ classifiers = [ ] [tool.poetry.dependencies] -python = "^3.8.0" 
cassio = "~0.1.4" deprecation = "~2.1.0" toml = "^0.10.2" @@ -47,7 +49,7 @@ pytest = "~8.0.0" python-dotenv = "~1.0.1" pytest-httpserver = "~1.0.8" testcontainers = "~3.7.1" -ruff = "~0.2.1" +ruff = "^0.6.6" types-toml = "^0.10.8.7" isort = "^5.13.2" @@ -55,6 +57,9 @@ isort = "^5.13.2" requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" +[tool.ruff.lint] +select = ["E4", "E7", "E9", "F", "FA", "I", "UP"] + [tool.mypy] disallow_any_generics = true disallow_incomplete_defs = true diff --git a/scripts/astrapy_latest_interface.py b/scripts/astrapy_latest_interface.py index 136d402e..2aecc6be 100644 --- a/scripts/astrapy_latest_interface.py +++ b/scripts/astrapy_latest_interface.py @@ -1,10 +1,9 @@ import os import sys -import astrapy - from dotenv import load_dotenv +import astrapy sys.path.append("../") diff --git a/tests/conftest.py b/tests/conftest.py index de9b8514..cddb331a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,7 +20,7 @@ import functools import warnings -from typing import Any, Awaitable, Callable, Optional, Tuple, TypedDict +from typing import Any, Awaitable, Callable, TypedDict import pytest from deprecation import UnsupportedWarning @@ -72,10 +72,10 @@ class DataAPICoreCredentials(TypedDict): class DataAPICredentialsInfo(TypedDict): environment: str region: str - secondary_namespace: Optional[str] + secondary_namespace: str | None -def env_region_from_endpoint(api_endpoint: str) -> Tuple[str, str]: +def env_region_from_endpoint(api_endpoint: str) -> tuple[str, str]: parsed = parse_api_endpoint(api_endpoint) if parsed is not None: return (parsed.environment, parsed.region) @@ -84,7 +84,7 @@ def env_region_from_endpoint(api_endpoint: str) -> Tuple[str, str]: def async_fail_if_not_removed( - method: Callable[..., Awaitable[Any]] + method: Callable[..., Awaitable[Any]], ) -> Callable[..., Awaitable[Any]]: """ Decorate a test async method to track removal of deprecated code. 
@@ -105,10 +105,7 @@ async def test_inner(*args: Any, **kwargs: Any) -> Any: for warning in caught_warnings: if warning.category == UnsupportedWarning: raise AssertionError( - ( - "%s uses a function that should be removed: %s" - % (method, str(warning.message)) - ) + f"{method} uses a function that should be removed: {str(warning.message)}" ) return rv @@ -132,10 +129,7 @@ def test_inner(*args: Any, **kwargs: Any) -> Any: for warning in caught_warnings: if warning.category == UnsupportedWarning: raise AssertionError( - ( - "%s uses a function that should be removed: %s" - % (method, str(warning.message)) - ) + f"{method} uses a function that should be removed: {str(warning.message)}" ) return rv diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 7e9710a8..41e9813e 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -19,7 +19,7 @@ from __future__ import annotations import math -from typing import AsyncIterable, Dict, Iterable, List, Optional, Set, TypeVar +from typing import AsyncIterable, Iterable, TypeVar import pytest import pytest_asyncio @@ -111,7 +111,7 @@ def db( if token is None or api_endpoint is None: raise ValueError("Required ASTRA DB configuration is missing") - db_kwargs: Dict[str, str] + db_kwargs: dict[str, str] if data_api_credentials_info["environment"] in {"prod", "dev", "test"}: db_kwargs = {} else: @@ -134,7 +134,7 @@ async def async_db( if token is None or api_endpoint is None: raise ValueError("Required ASTRA DB configuration is missing") - db_kwargs: Dict[str, str] + db_kwargs: dict[str, str] if data_api_credentials_info["environment"] in {"prod", "dev", "test"}: db_kwargs = {} else: @@ -151,14 +151,14 @@ async def async_db( @pytest.fixture(scope="module") def invalid_db( - data_api_core_bad_credentials_kwargs: Dict[str, Optional[str]], + data_api_core_bad_credentials_kwargs: dict[str, str | None], data_api_credentials_info: DataAPICredentialsInfo, ) -> AstraDB: token = 
data_api_core_bad_credentials_kwargs["token"] api_endpoint = data_api_core_bad_credentials_kwargs["api_endpoint"] namespace = data_api_core_bad_credentials_kwargs.get("namespace") - db_kwargs: Dict[str, str] + db_kwargs: dict[str, str] if data_api_credentials_info["environment"] in {"prod", "dev", "test"}: db_kwargs = {} else: @@ -319,11 +319,11 @@ def pagination_v_collection( INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints N = 200 # must be EVEN - def _mk_vector(index: int, n_total_steps: int) -> List[float]: + def _mk_vector(index: int, n_total_steps: int) -> list[float]: angle = 2 * math.pi * index / n_total_steps return [math.cos(angle), math.sin(angle)] - inserted_ids: Set[str] = set() + inserted_ids: set[str] = set() for i_batch in _batch_iterable(range(N), INSERT_BATCH_SIZE): batch_ids = empty_v_collection.insert_many( documents=[{"_id": str(i), "$vector": _mk_vector(i, N)} for i in i_batch] diff --git a/tests/core/test_async_db_dml.py b/tests/core/test_async_db_dml.py index 83148909..90f7acfd 100644 --- a/tests/core/test_async_db_dml.py +++ b/tests/core/test_async_db_dml.py @@ -22,7 +22,7 @@ import datetime import logging import uuid -from typing import Any, Dict, Iterable, List, Literal, Optional, Union, cast +from typing import Any, Iterable, List, Literal, cast import pytest @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) -def _cleanvec(doc: Dict[str, Any]) -> Dict[str, Any]: +def _cleanvec(doc: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in doc.items() if k != "$vector"} @@ -135,7 +135,7 @@ async def test_find_find_one_projection( sort = {"$vector": query} options = {"limit": 1} - projs: List[Optional[Dict[str, Literal[1]]]] = [ + projs: list[dict[str, Literal[1]] | None] = [ None, {}, {"text": 1}, @@ -344,7 +344,7 @@ async def test_insert_many( ) -> None: _id0 = str(uuid.uuid4()) _id2 = str(uuid.uuid4()) - documents: List[API_DOC] = [ + documents: list[API_DOC] = [ { "_id": _id0, "name": "Abba", @@ -375,7 +375,7 @@ async 
def test_chunked_insert_many( async_writable_v_collection: AsyncAstraDBCollection, ) -> None: _ids0 = [str(uuid.uuid4()) for _ in range(20)] - documents0: List[API_DOC] = [ + documents0: list[API_DOC] = [ { "_id": _id, "specs": { @@ -387,9 +387,9 @@ async def test_chunked_insert_many( for doc_idx, _id in enumerate(_ids0) ] - responses0: List[Union[Dict[str, Any], Exception]] = ( - await async_writable_v_collection.chunked_insert_many(documents0, chunk_size=3) - ) + responses0: list[ + dict[str, Any] | Exception + ] = await async_writable_v_collection.chunked_insert_many(documents0, chunk_size=3) assert responses0 is not None inserted_ids0 = [ ins_id @@ -408,7 +408,7 @@ async def test_chunked_insert_many( _ids1 = [ _id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0) ] - documents1: List[API_DOC] = [ + documents1: list[API_DOC] = [ { "_id": _id, "specs": { @@ -458,7 +458,7 @@ async def test_concurrent_chunked_insert_many( async_writable_v_collection: AsyncAstraDBCollection, ) -> None: _ids0 = [str(uuid.uuid4()) for _ in range(20)] - documents0: List[API_DOC] = [ + documents0: list[API_DOC] = [ { "_id": _id, "specs": { @@ -493,7 +493,7 @@ async def test_concurrent_chunked_insert_many( _ids1 = [ _id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0) ] - documents1: List[API_DOC] = [ + documents1: list[API_DOC] = [ { "_id": _id, "specs": { diff --git a/tests/core/test_async_db_dml_pagination.py b/tests/core/test_async_db_dml_pagination.py index 09ac4878..44aabfe1 100644 --- a/tests/core/test_async_db_dml_pagination.py +++ b/tests/core/test_async_db_dml_pagination.py @@ -19,7 +19,6 @@ from __future__ import annotations import logging -from typing import Optional import pytest @@ -43,7 +42,7 @@ ], ) async def test_find_paginated( - prefetched: Optional[int], + prefetched: int | None, async_pagination_v_collection: AsyncAstraDBCollection, ) -> None: options = {"limit": FIND_LIMIT} diff --git a/tests/core/test_db_dml.py 
b/tests/core/test_db_dml.py index 3d69f7d9..a7006bd6 100644 --- a/tests/core/test_db_dml.py +++ b/tests/core/test_db_dml.py @@ -23,7 +23,7 @@ import json import logging import uuid -from typing import Any, Dict, Iterable, List, Literal, Optional, Set, cast +from typing import Any, Iterable, List, Literal, cast import httpx import pytest @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) -def _cleanvec(doc: Dict[str, Any]) -> Dict[str, Any]: +def _cleanvec(doc: dict[str, Any]) -> dict[str, Any]: return {k: v for k, v in doc.items() if k != "$vector"} @@ -130,14 +130,14 @@ def test_find_find_one_projection( query = [0.2, 0.6] sort = {"$vector": query} options = {"limit": 1} - projs: List[Optional[Dict[str, Literal[1]]]] = [ + projs: list[dict[str, Literal[1]] | None] = [ None, {}, {"text": 1}, {"$vector": 1}, {"text": 1, "$vector": 1}, ] - exp_fieldsets: List[Set[str]] = [ + exp_fieldsets: list[set[str]] = [ {"$vector", "_id", "otherfield", "anotherfield", "text"}, {"$vector", "_id", "otherfield", "anotherfield", "text"}, {"_id", "text"}, @@ -327,7 +327,7 @@ def test_insert_float32(writable_v_collection: AstraDBCollection, N: int = 2) -> def test_insert_many(writable_v_collection: AstraDBCollection) -> None: _id0 = str(uuid.uuid4()) _id2 = str(uuid.uuid4()) - documents: List[API_DOC] = [ + documents: list[API_DOC] = [ { "_id": _id0, "name": "Abba", @@ -358,7 +358,7 @@ def test_chunked_insert_many( writable_v_collection: AstraDBCollection, ) -> None: _ids0 = [str(uuid.uuid4()) for _ in range(20)] - documents0: List[API_DOC] = [ + documents0: list[API_DOC] = [ { "_id": _id, "specs": { @@ -389,7 +389,7 @@ def test_chunked_insert_many( _ids1 = [ _id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0) ] - documents1: List[API_DOC] = [ + documents1: list[API_DOC] = [ { "_id": _id, "specs": { @@ -439,7 +439,7 @@ def test_concurrent_chunked_insert_many( writable_v_collection: AstraDBCollection, ) -> None: _ids0 = [str(uuid.uuid4()) for _ in 
range(20)] - documents0: List[API_DOC] = [ + documents0: list[API_DOC] = [ { "_id": _id, "specs": { @@ -472,7 +472,7 @@ def test_concurrent_chunked_insert_many( _ids1 = [ _id0 if idx % 3 == 0 else str(uuid.uuid4()) for idx, _id0 in enumerate(_ids0) ] - documents1: List[API_DOC] = [ + documents1: list[API_DOC] = [ { "_id": _id, "specs": { diff --git a/tests/core/test_db_dml_pagination.py b/tests/core/test_db_dml_pagination.py index 6d94ecb5..23ff3113 100644 --- a/tests/core/test_db_dml_pagination.py +++ b/tests/core/test_db_dml_pagination.py @@ -20,7 +20,6 @@ import logging import time -from typing import Optional import pytest @@ -44,7 +43,7 @@ ], ) def test_find_paginated( - prefetched: Optional[int], + prefetched: int | None, pagination_v_collection: AstraDBCollection, caplog: pytest.LogCaptureFixture, ) -> None: diff --git a/tests/core/test_ops.py b/tests/core/test_ops.py index b6ada837..2139d15c 100644 --- a/tests/core/test_ops.py +++ b/tests/core/test_ops.py @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) -def find_new_name(existing: List[str], prefix: str) -> str: +def find_new_name(existing: list[str], prefix: str) -> str: candidate_name = prefix for idx in itertools.count(): candidate_name = f"{prefix}{idx}" diff --git a/tests/idiomatic/integration/test_admin.py b/tests/idiomatic/integration/test_admin.py index 635b9cf2..cbff481c 100644 --- a/tests/idiomatic/integration/test_admin.py +++ b/tests/idiomatic/integration/test_admin.py @@ -15,7 +15,7 @@ from __future__ import annotations import time -from typing import Any, Awaitable, Callable, List, Optional, Tuple +from typing import Any, Awaitable, Callable import pytest @@ -37,15 +37,15 @@ PRE_DROP_SAFETY_TIMEOUT = 120 -def admin_test_envs_tokens() -> List[Any]: +def admin_test_envs_tokens() -> list[Any]: """ This actually returns a List of `_pytest.mark.structures.ParameterSet` instances, each wrapping a Tuple[str, Optional[str]] = (env, token) """ - envs_tokens: List[Any] = [] + envs_tokens: 
list[Any] = [] for admin_env in ADMIN_ENV_LIST: markers = [] - pair: Tuple[str, Optional[str]] + pair: tuple[str, str | None] if ADMIN_ENV_VARIABLE_MAP[admin_env]["token"]: pair = (admin_env, ADMIN_ENV_VARIABLE_MAP[admin_env]["token"]) else: @@ -84,7 +84,7 @@ class TestAdmin: ) @pytest.mark.describe("test of the full tour with AstraDBDatabaseAdmin, sync") def test_astra_db_database_admin_sync( - self, admin_env_token: Tuple[str, str] + self, admin_env_token: tuple[str, str] ) -> None: """ Test plan (it has to be a single giant test to use one DB throughout): @@ -206,7 +206,7 @@ def test_astra_db_database_admin_sync( @pytest.mark.describe( "test of the full tour with AstraDBAdmin and client methods, sync" ) - def test_astra_db_admin_sync(self, admin_env_token: Tuple[str, str]) -> None: + def test_astra_db_admin_sync(self, admin_env_token: tuple[str, str]) -> None: """ Test plan (it has to be a single giant test to use the two DBs throughout): - create client -> get_admin @@ -336,7 +336,7 @@ def _waiter2() -> bool: ) @pytest.mark.describe("test of the full tour with AstraDBDatabaseAdmin, async") async def test_astra_db_database_admin_async( - self, admin_env_token: Tuple[str, str] + self, admin_env_token: tuple[str, str] ) -> None: """ Test plan (it has to be a single giant test to use one DB throughout): @@ -471,7 +471,7 @@ async def _awaiter3() -> bool: @pytest.mark.describe( "test of the full tour with AstraDBAdmin and client methods, async" ) - async def test_astra_db_admin_async(self, admin_env_token: Tuple[str, str]) -> None: + async def test_astra_db_admin_async(self, admin_env_token: tuple[str, str]) -> None: """ Test plan (it has to be a single giant test to use the two DBs throughout): - create client -> get_admin diff --git a/tests/idiomatic/integration/test_dml_async.py b/tests/idiomatic/integration/test_dml_async.py index 93b9ed5c..30a82350 100644 --- a/tests/idiomatic/integration/test_dml_async.py +++ b/tests/idiomatic/integration/test_dml_async.py @@ 
-15,7 +15,7 @@ from __future__ import annotations import datetime -from typing import Any, Dict, List +from typing import Any import pytest @@ -258,7 +258,7 @@ async def test_collection_find_async( Nsor = {"seq": SortDocuments.DESCENDING} Nfil = {"seq": {"$exists": True}} - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] # case 0000 of find-pattern matrix @@ -466,7 +466,7 @@ async def test_collection_cursors_async( document0b = await cursor0b.__anext__() assert "ternary" in document0b - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] # rewinding, slicing and retrieved @@ -544,7 +544,7 @@ async def test_collection_distinct_nonhashable_async( async_empty_collection: AsyncCollection, ) -> None: acol = async_empty_collection - documents: List[Dict[str, Any]] = [ + documents: list[dict[str, Any]] = [ {}, {"f": 1}, {"f": "a"}, @@ -685,13 +685,13 @@ async def test_collection_include_sort_vector_find_async( ) -> None: q_vector = [10, 9] - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] # with empty collection for include_sv in [False, True]: for sort_cl_label in ["reg", "vec"]: - sort_cl_e: Dict[str, Any] = ( + sort_cl_e: dict[str, Any] = ( {} if sort_cl_label == "reg" else {"$vector": q_vector} ) vec_expected = include_sv and sort_cl_label == "vec" @@ -726,7 +726,7 @@ async def _alist(acursor: AsyncCursor) -> List[DocumentType]: # with non-empty collection for include_sv in [False, True]: for sort_cl_label in ["reg", "vec"]: - sort_cl_f: Dict[str, Any] = ( + sort_cl_f: dict[str, Any] = ( {} if sort_cl_label == "reg" else {"$vector": q_vector} ) vec_expected = include_sv and sort_cl_label == "vec" diff --git 
a/tests/idiomatic/integration/test_dml_sync.py b/tests/idiomatic/integration/test_dml_sync.py index 8b1f4557..6eba5634 100644 --- a/tests/idiomatic/integration/test_dml_sync.py +++ b/tests/idiomatic/integration/test_dml_sync.py @@ -15,7 +15,7 @@ from __future__ import annotations import datetime -from typing import Any, Dict, List +from typing import Any import pytest @@ -478,7 +478,7 @@ def test_collection_distinct_nonhashable_sync( sync_empty_collection: Collection, ) -> None: col = sync_empty_collection - documents: List[Dict[str, Any]] = [ + documents: list[dict[str, Any]] = [ {}, {"f": 1}, {"f": "a"}, @@ -619,7 +619,7 @@ def test_collection_include_sort_vector_find_sync( # with empty collection for include_sv in [False, True]: for sort_cl_label in ["reg", "vec"]: - sort_cl_e: Dict[str, Any] = ( + sort_cl_e: dict[str, Any] = ( {} if sort_cl_label == "reg" else {"$vector": q_vector} ) vec_expected = include_sv and sort_cl_label == "vec" @@ -654,7 +654,7 @@ def test_collection_include_sort_vector_find_sync( # with non-empty collection for include_sv in [False, True]: for sort_cl_label in ["reg", "vec"]: - sort_cl_f: Dict[str, Any] = ( + sort_cl_f: dict[str, Any] = ( {} if sort_cl_label == "reg" else {"$vector": q_vector} ) vec_expected = include_sv and sort_cl_label == "vec" diff --git a/tests/idiomatic/integration/test_exceptions_async.py b/tests/idiomatic/integration/test_exceptions_async.py index 3f8f07e7..79f25f9d 100644 --- a/tests/idiomatic/integration/test_exceptions_async.py +++ b/tests/idiomatic/integration/test_exceptions_async.py @@ -14,8 +14,6 @@ from __future__ import annotations -from typing import List - import pytest from astrapy import AsyncCollection, AsyncDatabase @@ -62,8 +60,7 @@ async def test_collection_insert_many_insert_failures_async( self, async_empty_collection: AsyncCollection, ) -> None: - - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async 
for doc in acursor] acol = async_empty_collection diff --git a/tests/idiomatic/unit/test_apicommander.py b/tests/idiomatic/unit/test_apicommander.py index 713ecdc7..8b4d5ffe 100644 --- a/tests/idiomatic/unit/test_apicommander.py +++ b/tests/idiomatic/unit/test_apicommander.py @@ -16,7 +16,6 @@ import json import time -from typing import Optional import pytest import werkzeug @@ -90,7 +89,7 @@ def test_apicommander_request_sync(self, httpserver: HTTPServer) -> None: callers=[("cn", "cv")], ) - def hv_matcher(hk: str, hv: Optional[str], ev: str) -> bool: + def hv_matcher(hk: str, hv: str | None, ev: str) -> bool: if hk == "v": return hv == ev elif hk.lower() == "user-agent": @@ -138,7 +137,7 @@ async def test_apicommander_request_async(self, httpserver: HTTPServer) -> None: callers=[("cn", "cv")], ) - def hv_matcher(hk: str, hv: Optional[str], ev: str) -> bool: + def hv_matcher(hk: str, hv: str | None, ev: str) -> bool: if hk == "v": return hv == ev elif hk.lower() == "user-agent": diff --git a/tests/idiomatic/unit/test_collection_options.py b/tests/idiomatic/unit/test_collection_options.py index f6e09244..8b98bf5e 100644 --- a/tests/idiomatic/unit/test_collection_options.py +++ b/tests/idiomatic/unit/test_collection_options.py @@ -18,7 +18,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Tuple +from typing import Any import pytest @@ -27,7 +27,7 @@ @pytest.mark.describe("test of recasting the collection options from the api") def test_recast_api_collection_dict() -> None: - api_coll_descs: List[Tuple[Dict[str, Any], Dict[str, Any]]] = [ + api_coll_descs: list[tuple[dict[str, Any], dict[str, Any]]] = [ # minimal: ( { diff --git a/tests/idiomatic/unit/test_document_extractors.py b/tests/idiomatic/unit/test_document_extractors.py index 985f28e8..3131c67a 100644 --- a/tests/idiomatic/unit/test_document_extractors.py +++ b/tests/idiomatic/unit/test_document_extractors.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import 
Any, Dict, List +from typing import Any import pytest @@ -57,7 +57,7 @@ def test_dotted_fieldname_document_extractor(self) -> None: } def assert_extracts( - document: Dict[str, Any], key: str, expected: List[Any] + document: dict[str, Any], key: str, expected: list[Any] ) -> None: _extractor = _create_document_key_extractor(key) _extracted = list(_extractor(document)) diff --git a/tests/preprocess_env.py b/tests/preprocess_env.py index 876e9945..244762dd 100644 --- a/tests/preprocess_env.py +++ b/tests/preprocess_env.py @@ -22,7 +22,6 @@ import os import time -from typing import List, Optional from testcontainers.compose import DockerCompose @@ -40,18 +39,18 @@ IS_ASTRA_DB: bool DOCKER_COMPOSE_LOCAL_DATA_API: bool -SECONDARY_NAMESPACE: Optional[str] = None -ASTRA_DB_API_ENDPOINT: Optional[str] = None -ASTRA_DB_APPLICATION_TOKEN: Optional[str] = None -ASTRA_DB_KEYSPACE: Optional[str] = None -LOCAL_DATA_API_USERNAME: Optional[str] = None -LOCAL_DATA_API_PASSWORD: Optional[str] = None -LOCAL_DATA_API_APPLICATION_TOKEN: Optional[str] = None -LOCAL_DATA_API_ENDPOINT: Optional[str] = None -LOCAL_DATA_API_KEYSPACE: Optional[str] = None - -ASTRA_DB_TOKEN_PROVIDER: Optional[TokenProvider] = None -LOCAL_DATA_API_TOKEN_PROVIDER: Optional[TokenProvider] = None +SECONDARY_NAMESPACE: str | None = None +ASTRA_DB_API_ENDPOINT: str | None = None +ASTRA_DB_APPLICATION_TOKEN: str | None = None +ASTRA_DB_KEYSPACE: str | None = None +LOCAL_DATA_API_USERNAME: str | None = None +LOCAL_DATA_API_PASSWORD: str | None = None +LOCAL_DATA_API_APPLICATION_TOKEN: str | None = None +LOCAL_DATA_API_ENDPOINT: str | None = None +LOCAL_DATA_API_KEYSPACE: str | None = None + +ASTRA_DB_TOKEN_PROVIDER: TokenProvider | None = None +LOCAL_DATA_API_TOKEN_PROVIDER: TokenProvider | None = None # idiomatic-related settings if "LOCAL_DATA_API_ENDPOINT" in os.environ: @@ -114,7 +113,6 @@ is_docker_compose_started = False if DOCKER_COMPOSE_LOCAL_DATA_API: if not is_docker_compose_started: - """ Note: this is a 
trick to invoke `docker compose` as opposed to `docker-compose` while using testcontainers < 4. @@ -133,8 +131,7 @@ """ class RedefineCommandDockerCompose(DockerCompose): - - def docker_compose_command(self) -> List[str]: + def docker_compose_command(self) -> list[str]: docker_compose_cmd = ["docker", "compose"] for file in self.compose_file_names: docker_compose_cmd += ["-f", file] diff --git a/tests/vectorize_idiomatic/conftest.py b/tests/vectorize_idiomatic/conftest.py index 56eab027..61989175 100644 --- a/tests/vectorize_idiomatic/conftest.py +++ b/tests/vectorize_idiomatic/conftest.py @@ -19,7 +19,7 @@ from __future__ import annotations import os -from typing import Any, Dict, Iterable +from typing import Any, Iterable import pytest @@ -54,7 +54,7 @@ def async_database( @pytest.fixture(scope="session") -def service_collection_parameters() -> Iterable[Dict[str, Any]]: +def service_collection_parameters() -> Iterable[dict[str, Any]]: yield { "dimension": 1536, "provider": "openai", @@ -67,7 +67,7 @@ def service_collection_parameters() -> Iterable[Dict[str, Any]]: def sync_service_collection( data_api_credentials_kwargs: DataAPICredentials, sync_database: Database, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> Iterable[Collection]: """ An actual collection on DB, in the main namespace. 
diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py b/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py index 9a0c5fd8..cbeb2a15 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_methods_async.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Dict, List +from typing import Any import pytest @@ -36,7 +36,7 @@ class TestVectorizeMethodsAsync: async def test_collection_methods_vectorize_async( self, async_empty_service_collection: AsyncCollection, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> None: acol = async_empty_service_collection service_vector_dimension = service_collection_parameters["dimension"] @@ -187,12 +187,12 @@ async def test_collection_include_sort_vector_vectorize_find_async( def _is_vector(v: Any) -> bool: return isinstance(v, list) and isinstance(v[0], float) - async def _alist(acursor: AsyncCursor) -> List[DocumentType]: + async def _alist(acursor: AsyncCursor) -> list[DocumentType]: return [doc async for doc in acursor] for include_sv in [False, True]: for sort_cl_label in ["vze"]: - sort_cl_e: Dict[str, Any] = {"$vectorize": q_text} + sort_cl_e: dict[str, Any] = {"$vectorize": q_text} vec_expected = include_sv and sort_cl_label == "vze" # pristine iterator this_ite_1 = async_empty_service_collection.find( @@ -228,7 +228,7 @@ async def _alist(acursor: AsyncCursor) -> List[DocumentType]: # with non-empty collection for include_sv in [False, True]: for sort_cl_label in ["vze"]: - sort_cl_f: Dict[str, Any] = {"$vectorize": q_text} + sort_cl_f: dict[str, Any] = {"$vectorize": q_text} vec_expected = include_sv and sort_cl_label == "vze" # pristine iterator this_ite_1 = async_empty_service_collection.find( @@ -274,7 +274,7 @@ async def _alist(acursor: AsyncCursor) -> List[DocumentType]: async def 
test_database_create_collection_dimension_mismatch_failure_async( self, async_database: AsyncDatabase, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> None: with pytest.raises(DataAPIResponseException): await async_database.create_collection( diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py b/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py index 51e6d092..cf873224 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_methods_sync.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Any, Dict +from typing import Any import pytest @@ -28,7 +28,7 @@ class TestVectorizeMethodsSync: def test_collection_methods_vectorize_sync( self, sync_empty_service_collection: Collection, - service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> None: col = sync_empty_service_collection service_vector_dimension = service_collection_parameters["dimension"] @@ -178,7 +178,7 @@ def _is_vector(v: Any) -> bool: for include_sv in [False, True]: for sort_cl_label in ["vze"]: - sort_cl_e: Dict[str, Any] = {"$vectorize": q_text} + sort_cl_e: dict[str, Any] = {"$vectorize": q_text} vec_expected = include_sv and sort_cl_label == "vze" # pristine iterator this_ite_1 = sync_empty_service_collection.find( @@ -214,7 +214,7 @@ def _is_vector(v: Any) -> bool: # with non-empty collection for include_sv in [False, True]: for sort_cl_label in ["vze"]: - sort_cl_f: Dict[str, Any] = {"$vectorize": q_text} + sort_cl_f: dict[str, Any] = {"$vectorize": q_text} vec_expected = include_sv and sort_cl_label == "vze" # pristine iterator this_ite_1 = sync_empty_service_collection.find( @@ -256,7 +256,7 @@ def _is_vector(v: Any) -> bool: def test_database_create_collection_dimension_mismatch_failure_sync( self, sync_database: Database, - 
service_collection_parameters: Dict[str, Any], + service_collection_parameters: dict[str, Any], ) -> None: with pytest.raises(DataAPIResponseException): sync_database.create_collection( diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_providers.py b/tests/vectorize_idiomatic/integration/test_vectorize_providers.py index cd4e6989..002f67e7 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_providers.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_providers.py @@ -16,7 +16,7 @@ import os import sys -from typing import Any, Dict, List, Union +from typing import Any import pytest @@ -31,7 +31,7 @@ from ..vectorize_models import live_test_models -def enabled_vectorize_models(auth_type: str) -> List[Any]: +def enabled_vectorize_models(auth_type: str) -> list[Any]: """ This actually returns a List of `_pytest.mark.structures.ParameterSet` instances, each wrapping a dict with the needed info to test the model @@ -41,7 +41,7 @@ def enabled_vectorize_models(auth_type: str) -> List[Any]: where `tag` = "provider/model/auth_type/[0 or f]" """ all_test_models = list(live_test_models()) - all_model_ids: List[str] = [ + all_model_ids: list[str] = [ str(model_desc["model_tag"]) for model_desc in all_test_models ] # @@ -50,10 +50,10 @@ def enabled_vectorize_models(auth_type: str) -> List[Any]: for test_model in all_test_models if test_model["auth_type_name"] == auth_type ] - at_model_ids: List[str] = [ + at_model_ids: list[str] = [ str(model_desc["model_tag"]) for model_desc in at_test_models ] - at_chosen_models: List[Any] = [] + at_chosen_models: list[Any] = [] if "EMBEDDING_MODEL_TAGS" in os.environ: whitelisted_models = [ _wmd.strip() @@ -91,12 +91,12 @@ class TestVectorizeProviders: def test_vectorize_usage_auth_type_header_sync( self, sync_database: Database, - testable_vectorize_model: Dict[str, Any], + testable_vectorize_model: dict[str, Any], ) -> None: simple_tag = testable_vectorize_model["simple_tag"].lower() # switch betewen 
header providers according to what is needed # For the time being this is necessary on HEADER only - embedding_api_key: Union[str, EmbeddingHeadersProvider] + embedding_api_key: str | EmbeddingHeadersProvider at_tokens = testable_vectorize_model["auth_type_tokens"] at_token_lnames = {tk.accepted.lower() for tk in at_tokens} if at_token_lnames == {"x-embedding-api-key"}: @@ -201,7 +201,7 @@ def test_vectorize_usage_auth_type_header_sync( def test_vectorize_usage_auth_type_none_sync( self, sync_database: Database, - testable_vectorize_model: Dict[str, Any], + testable_vectorize_model: dict[str, Any], ) -> None: simple_tag = testable_vectorize_model["simple_tag"].lower() dimension = testable_vectorize_model.get("dimension") @@ -281,7 +281,7 @@ def test_vectorize_usage_auth_type_none_sync( def test_vectorize_usage_auth_type_shared_secret_sync( self, sync_database: Database, - testable_vectorize_model: Dict[str, Any], + testable_vectorize_model: dict[str, Any], ) -> None: simple_tag = testable_vectorize_model["simple_tag"].lower() secret_tag = testable_vectorize_model["secret_tag"] diff --git a/tests/vectorize_idiomatic/query_providers.py b/tests/vectorize_idiomatic/query_providers.py index 7ab2fccc..aace0916 100644 --- a/tests/vectorize_idiomatic/query_providers.py +++ b/tests/vectorize_idiomatic/query_providers.py @@ -17,7 +17,6 @@ import json import os import sys -from typing import List from astrapy.info import EmbeddingProviderParameter, FindEmbeddingProvidersResult @@ -113,7 +112,7 @@ def desc_param(param_data: EmbeddingProviderParameter) -> str: for test_model in all_test_models if test_model["auth_type_name"] == auth_type ] - at_model_ids: List[str] = sorted( + at_model_ids: list[str] = sorted( [str(model_desc["model_tag"]) for model_desc in at_test_models] ) if at_model_ids: diff --git a/tests/vectorize_idiomatic/vectorize_models.py b/tests/vectorize_idiomatic/vectorize_models.py index 6e5965eb..e2bd4a39 100644 --- 
a/tests/vectorize_idiomatic/vectorize_models.py +++ b/tests/vectorize_idiomatic/vectorize_models.py @@ -16,7 +16,7 @@ import os import sys -from typing import Any, Dict, Iterable, List, Tuple +from typing import Any, Iterable from astrapy.defaults import ( EMBEDDING_HEADER_API_KEY, @@ -85,7 +85,7 @@ ("voyageAI", "voyage-code-2"): CODE_TEST_ASSETS, } -USE_INSERT_ONE_MAP: Dict[Tuple[str, str], bool] = { +USE_INSERT_ONE_MAP: dict[tuple[str, str], bool] = { # ("upstageAI", "solar-1-mini-embedding"): True, } @@ -172,8 +172,7 @@ } -def live_test_models() -> Iterable[Dict[str, Any]]: - +def live_test_models() -> Iterable[dict[str, Any]]: def _from_validation(pspec: EmbeddingProviderParameter) -> int: assert pspec.parameter_type == "number" if "numericRange" in pspec.validation: @@ -181,7 +180,7 @@ def _from_validation(pspec: EmbeddingProviderParameter) -> int: m1: int = pspec.validation["numericRange"][1] return (m0 + m1) // 2 elif "options" in pspec.validation: - options: List[int] = pspec.validation["options"] + options: list[int] = pspec.validation["options"] if len(options) > 1: return options[1] else: