From 2adce1a6926fa68101e4e545d59b8b93f9a8632e Mon Sep 17 00:00:00 2001 From: Stefano Lottini Date: Fri, 5 Jul 2024 13:35:08 +0200 Subject: [PATCH] SL "embedding headers provider" refactoring (#292) * wip to EmbeddingHeadersProvider transition * embedding headers provider + unit tests * api_options retains emb.header provider with coercion at the front door; docstrings for h.providers * adjust docstrings (collection, database) to emb.header providers * More streamlined flow around redacting headers in logging * vectorize integration testing handles multi-header (i.e. bedrock) * re-enable override for api options with emb.hea.providers * improve api options override/default logic and add docstrings; expose api_options in collection __repr__ * finished support for multiheader and bedrock case --- CHANGES | 4 + README.md | 8 +- astrapy/api_commander.py | 14 +- astrapy/api_options.py | 62 +++++- astrapy/authentication.py | 182 +++++++++++++++++- astrapy/collection.py | 97 +++++----- astrapy/core/defaults.py | 11 +- astrapy/core/utils.py | 16 +- astrapy/database.py | 75 +++++--- pyproject.toml | 2 +- .../integration/test_vectorize_providers.py | 28 ++- tests/vectorize_idiomatic/query_providers.py | 24 ++- .../unit/test_embeddingheadersprovider.py | 67 +++++++ tests/vectorize_idiomatic/vectorize_models.py | 40 +++- 14 files changed, 506 insertions(+), 124 deletions(-) create mode 100644 tests/vectorize_idiomatic/unit/test_embeddingheadersprovider.py diff --git a/CHANGES b/CHANGES index 0387e0f4..e615509a 100644 --- a/CHANGES +++ b/CHANGES @@ -1,5 +1,9 @@ (master) ======== +Support for multiple-header embedding api keys: + - `EmbeddingHeadersProvider` classes for `embedding_api_key` parameter + - AWS header provider in addition to the regular one-header one + - adapt CI to cover this setup Testing: - restructure CI to fully support HCD alongside Astra DB - add details for testing new embedding providers diff --git a/README.md b/README.md index 4aac1cbc..4563828a 100644 --- a/README.md +++ b/README.md @@ -350,7 +350,13 @@ Remove logging noise with: poetry run pytest [...] -o log_cli=0 ``` -Do not drop collections (core): +Increase logging level to `TRACE` (i.e. level `5`): + +``` +poetry run pytest [...] -o log_cli=1 --log-cli-level=5 +``` + +Do not drop collections (valid for core): ``` TEST_SKIP_COLLECTION_DELETE=1 poetry run pytest [...] diff --git a/astrapy/api_commander.py b/astrapy/api_commander.py index 7d54f3eb..2e91bfbb 100644 --- a/astrapy/api_commander.py +++ b/astrapy/api_commander.py @@ -19,11 +19,15 @@ import httpx +from astrapy.authentication import ( + EMBEDDING_HEADER_API_KEY, + EMBEDDING_HEADER_AWS_ACCESS_ID, + EMBEDDING_HEADER_AWS_SECRET_ID, +) from astrapy.core.defaults import ( DEFAULT_AUTH_HEADER, DEFAULT_DEV_OPS_AUTH_HEADER, DEFAULT_TIMEOUT, - DEFAULT_VECTORIZE_SECRET_HEADER, ) from astrapy.core.utils import ( TimeoutInfoWideType, @@ -44,11 +48,13 @@ to_dataapi_timeout_exception, ) -DEFAULT_REDACTED_HEADER_NAMES = [ +DEFAULT_REDACTED_HEADER_NAMES = { DEFAULT_AUTH_HEADER, DEFAULT_DEV_OPS_AUTH_HEADER, - DEFAULT_VECTORIZE_SECRET_HEADER, -] + EMBEDDING_HEADER_AWS_ACCESS_ID, + EMBEDDING_HEADER_AWS_SECRET_ID, + EMBEDDING_HEADER_API_KEY, +} def full_user_agent( diff --git a/astrapy/api_options.py b/astrapy/api_options.py index 03eb6316..88f60f04 100644 --- a/astrapy/api_options.py +++ b/astrapy/api_options.py @@ -14,9 +14,14 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional, TypeVar +from astrapy.authentication import ( + EmbeddingHeadersProvider, + StaticEmbeddingHeadersProvider, +) + AO = TypeVar("AO", bound="BaseAPIOptions") @@ -39,22 +44,54 @@ class BaseAPIOptions: max_time_ms: Optional[int] = None def with_default(self: AO, default: Optional[BaseAPIOptions]) -> AO: + """ + Return a new instance created by completing this instance with a default + API options object. + + In other words, `optA.with_default(optB)` will take fields from optA + when possible and draw defaults from optB when optA has them set to anything + evaluating to False. (This relies on the __bool__ definition of the values, + such as that of the EmbeddingHeadersTokenProvider instances) + + Args: + default: an API options instance to draw defaults from. + + Returns: + a new instance of this class obtained by merging this one and the default. + """ if default: + default_dict = default.__dict__ return self.__class__( **{ - **default.__dict__, - **{k: v for k, v in self.__dict__.items() if v is not None}, + k: self_v or default_dict.get(k) + for k, self_v in self.__dict__.items() } ) else: return self def with_override(self: AO, override: Optional[BaseAPIOptions]) -> AO: + """ + Return a new instance created by overriding the members of this instance + with those taken from a supplied "override" API options object. + + In other words, `optA.with_default(optB)` will take fields from optB + when possible and fall back to optA when optB has them set to anything + evaluating to False. (This relies on the __bool__ definition of the values, + such as that of the EmbeddingHeadersTokenProvider instances) + + Args: + override: an API options instance to preferentially draw fields from. + + Returns: + a new instance of this class obtained by merging the override and this one. + """ if override: + self_dict = self.__dict__ return self.__class__( **{ - **self.__dict__, - **{k: v for k, v in override.__dict__.items() if v is not None}, + k: override_v or self_dict.get(k) + for k, override_v in override.__dict__.items() } ) else: @@ -77,10 +114,15 @@ class CollectionAPIOptions(BaseAPIOptions): `find`, `delete_many`, `insert_many` and so on), it is strongly suggested to provide a specific timeout as the default one likely wouldn't make much sense. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: an `astrapy.authentication.EmbeddingHeadersProvider` + object, encoding embedding-related API keys that will be passed + as headers when interacting with the collection (on each Data API request). + The default value is `StaticEmbeddingHeadersProvider(None)`, i.e. + no embedding-specific headers, whereas if the collection is configured + with an embedding service other choices for this parameter can be + meaningfully supplied. is configured for the collection, """ - embedding_api_key: Optional[str] = None + embedding_api_key: EmbeddingHeadersProvider = field( + default_factory=lambda: StaticEmbeddingHeadersProvider(None) + ) diff --git a/astrapy/authentication.py b/astrapy/authentication.py index f26495e5..49174368 100644 --- a/astrapy/authentication.py +++ b/astrapy/authentication.py @@ -16,16 +16,29 @@ import base64 from abc import ABC, abstractmethod -from typing import Any, Union +from typing import Any, Dict, Optional, Union +EMBEDDING_HEADER_AWS_ACCESS_ID = "X-Embedding-Access-Id" +EMBEDDING_HEADER_AWS_SECRET_ID = "X-Embedding-Secret-Id" +EMBEDDING_HEADER_API_KEY = "X-Embedding-Api-Key" -def coerce_token_provider(token: Any) -> TokenProvider: + +def coerce_token_provider(token: Optional[Union[str, TokenProvider]]) -> TokenProvider: if isinstance(token, TokenProvider): return token else: return StaticTokenProvider(token) +def coerce_embedding_headers_provider( + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]], +) -> EmbeddingHeadersProvider: + if isinstance(embedding_api_key, EmbeddingHeadersProvider): + return embedding_api_key + else: + return StaticEmbeddingHeadersProvider(embedding_api_key) + + class TokenProvider(ABC): """ Abstract base class for a token provider. @@ -54,14 +67,6 @@ def __eq__(self, other: Any) -> bool: @abstractmethod def __repr__(self) -> str: ... - @abstractmethod - def get_token(self) -> Union[str, None]: - """ - Produce a string for direct use as token in a subsequent API request, - or None for no token. - """ - ... - def __or__(self, other: TokenProvider) -> TokenProvider: """ Implement the logic as for "token_str_a or token_str_b" for the TokenProvider, @@ -82,6 +87,14 @@ def __bool__(self) -> bool: """ return self.get_token() is not None + @abstractmethod + def get_token(self) -> Union[str, None]: + """ + Produce a string for direct use as token in a subsequent API request, + or None for no token. + """ + ... + class StaticTokenProvider(TokenProvider): """ @@ -150,3 +163,152 @@ def _b64(cleartext: str) -> str: def get_token(self) -> str: return self.token + + +class EmbeddingHeadersProvider(ABC): + """ + Abstract base class for a provider of embedding-related headers (such as API Keys). + The relevant method in this interface is returning a dict to use as + (part of the) headers in Data API requests for a collection. + + This class captures the fact that, depending on the embedding provider for + the collection, there may be zero, one *or more* headers to be passed + if relying on the HEADERS auth method for Vectorize. + """ + + def __eq__(self, other: Any) -> bool: + my_headers = self.get_headers() + if isinstance(other, EmbeddingHeadersProvider): + return other.get_headers() == my_headers + else: + return False + + @abstractmethod + def __repr__(self) -> str: ... + + def __bool__(self) -> bool: + """ + All headers providers evaluate to True unless they yield the empty dict. + This method enables the override mechanism in APIOptions. + """ + return self.get_headers() != {} + + @abstractmethod + def get_headers(self) -> Dict[str, str]: + """ + Produce a dictionary for use as (part of) the headers in HTTP requests + to the Data API. + """ + ... + + +class StaticEmbeddingHeadersProvider(EmbeddingHeadersProvider): + """ + A "pass-through" header provider representing the single-header + (typically "X-Embedding-Api-Key") auth scheme, in use by most of the + embedding models in Vectorize. + + Args: + embedding_api_key: a string that will be the value for the header. + If None is passed, this results in a no-headers provider (such + as the one used for non-Vectorize collections). + + Example: + >>> from astrapy import DataAPIClient + >>> from astrapy.authentication import ( + CollectionVectorServiceOptions, + StaticEmbeddingHeadersProvider, + ) + >>> my_emb_api_key = StaticEmbeddingHeadersProvider("abc012...") + >>> service_options = CollectionVectorServiceOptions( + ... provider="a-certain-provider", + ... model_name="some-embedding-model", + ... ) + >>> + >>> database = DataAPIClient().get_database( + ... "https://01234567-...-eu-west1.apps.datastax.com", + ... token="AstraCS:...", + ... ) + >>> collection = database.create_collection( + ... "vectorize_collection", + ... service=service_options, + ... embedding_api_key=my_emb_api_key, + ... ) + >>> # likewise: + >>> collection_b = database.get_collection( + ... "vectorize_collection", + ... embedding_api_key=my_emb_api_key, + ... ) + """ + + def __init__(self, embedding_api_key: Optional[str]) -> None: + self.embedding_api_key = embedding_api_key + + def __repr__(self) -> str: + if self.embedding_api_key is None: + return f"{self.__class__.__name__}(empty)" + else: + return f'{self.__class__.__name__}("{self.embedding_api_key[:5]}...")' + + def get_headers(self) -> Dict[str, str]: + if self.embedding_api_key is not None: + return {EMBEDDING_HEADER_API_KEY: self.embedding_api_key} + else: + return {} + + +class AWSEmbeddingHeadersProvider(EmbeddingHeadersProvider): + """ + A header provider representing the two-header auth scheme in use + by the Amazon Web Services (e.g. AWS Bedrock) when using header-based + authentication. + + Args: + embedding_access_id: value of the "Access ID" secret. This will become + the value for the corresponding header. + embedding_secret_id: value of the "Secret ID" secret. This will become + the value for the corresponding header. + + Example: + >>> from astrapy import DataAPIClient + >>> from astrapy.authentication import ( + CollectionVectorServiceOptions, + AWSEmbeddingHeadersProvider, + ) + >>> my_aws_emb_api_key = AWSEmbeddingHeadersProvider( + embedding_access_id="my-access-id-012...", + embedding_secret_id="my-secret-id-abc...", + ) + >>> service_options = CollectionVectorServiceOptions( + ... provider="bedrock", + ... model_name="some-aws-bedrock-model", + ... ) + >>> + >>> database = DataAPIClient().get_database( + ... "https://01234567-...-eu-west1.apps.datastax.com", + ... token="AstraCS:...", + ... ) + >>> collection = database.create_collection( + ... "vectorize_aws_collection", + ... service=service_options, + ... embedding_api_key=my_aws_emb_api_key, + ... ) + >>> # likewise: + >>> collection_b = database.get_collection( + ... "vectorize_aws_collection", + ... embedding_api_key=my_aws_emb_api_key, + ... ) + """ + + def __init__(self, *, embedding_access_id: str, embedding_secret_id: str) -> None: + self.embedding_access_id = embedding_access_id + self.embedding_secret_id = embedding_secret_id + + def __repr__(self) -> str: + return f'{self.__class__.__name__}("{self.embedding_access_id[:3]}...", "{self.embedding_secret_id[:3]}...")' + + def get_headers(self) -> Dict[str, str]: + return { + EMBEDDING_HEADER_AWS_ACCESS_ID: self.embedding_access_id, + EMBEDDING_HEADER_AWS_SECRET_ID: self.embedding_secret_id, + } diff --git a/astrapy/collection.py b/astrapy/collection.py index 632a4f70..a72ebca9 100644 --- a/astrapy/collection.py +++ b/astrapy/collection.py @@ -24,6 +24,7 @@ from astrapy import __version__ from astrapy.api_options import CollectionAPIOptions +from astrapy.authentication import coerce_embedding_headers_provider from astrapy.constants import ( DocumentType, FilterType, @@ -34,10 +35,7 @@ normalize_optional_projection, ) from astrapy.core.db import AstraDBCollection, AsyncAstraDBCollection -from astrapy.core.defaults import ( - DEFAULT_INSERT_NUM_DOCUMENTS, - DEFAULT_VECTORIZE_SECRET_HEADER, -) +from astrapy.core.defaults import DEFAULT_INSERT_NUM_DOCUMENTS from astrapy.cursors import AsyncCursor, Cursor from astrapy.database import AsyncDatabase, Database from astrapy.exceptions import ( @@ -66,6 +64,7 @@ ) if TYPE_CHECKING: + from astrapy.authentication import EmbeddingHeadersProvider from astrapy.operations import AsyncBaseOperation, BaseOperation @@ -264,13 +263,7 @@ def __init__( self.api_options = CollectionAPIOptions() else: self.api_options = api_options - additional_headers = { - k: v - for k, v in { - DEFAULT_VECTORIZE_SECRET_HEADER: self.api_options.embedding_api_key, - }.items() - if v is not None - } + additional_headers = self.api_options.embedding_api_key.get_headers() self._astra_db_collection: AstraDBCollection = AstraDBCollection( collection_name=name, astra_db=database._astra_db, @@ -287,7 +280,8 @@ def __init__( def __repr__(self) -> str: return ( f'{self.__class__.__name__}(name="{self.name}", ' - f'namespace="{self.namespace}", database={self.database})' + f'namespace="{self.namespace}", database={self.database}, ' + f"api_options={self.api_options})" ) def __eq__(self, other: Any) -> bool: @@ -332,7 +326,7 @@ def with_options( self, *, name: Optional[str] = None, - embedding_api_key: Optional[str] = None, + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, collection_max_time_ms: Optional[int] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, @@ -344,10 +338,15 @@ def with_options( name: the name of the collection. This parameter is useful to quickly spawn Collection instances each pointing to a different collection existing in the same namespace. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: optional API key(s) for interacting with the collection. + If an embedding service is configured, and this parameter is not None, + each Data API call will include the necessary embedding-related headers + as specified by this parameter. If a string is passed, it translates + into the one "embedding api key" header + (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + For some vectorize providers/models, if using header-based authentication, + specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` + should be supplied. collection_max_time_ms: a default timeout, in millisecond, for the duration of each operation on the collection. Individual timeouts can be provided to each collection method call and will take precedence, with this value @@ -371,7 +370,7 @@ def with_options( """ _api_options = CollectionAPIOptions( - embedding_api_key=embedding_api_key, + embedding_api_key=coerce_embedding_headers_provider(embedding_api_key), max_time_ms=collection_max_time_ms, ) @@ -388,7 +387,7 @@ def to_async( database: Optional[AsyncDatabase] = None, name: Optional[str] = None, namespace: Optional[str] = None, - embedding_api_key: Optional[str] = None, + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, collection_max_time_ms: Optional[int] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, @@ -406,10 +405,15 @@ def to_async( collection on the database. namespace: this is the namespace to which the collection belongs. If not specified, the database's working namespace is used. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: optional API key(s) for interacting with the collection. + If an embedding service is configured, and this parameter is not None, + each Data API call will include the necessary embedding-related headers + as specified by this parameter. If a string is passed, it translates + into the one "embedding api key" header + (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + For some vectorize providers/models, if using header-based authentication, + specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` + should be supplied. collection_max_time_ms: a default timeout, in millisecond, for the duration of each operation on the collection. Individual timeouts can be provided to each collection method call and will take precedence, with this value @@ -431,7 +435,7 @@ def to_async( """ _api_options = CollectionAPIOptions( - embedding_api_key=embedding_api_key, + embedding_api_key=coerce_embedding_headers_provider(embedding_api_key), max_time_ms=collection_max_time_ms, ) @@ -2677,13 +2681,7 @@ def __init__( self.api_options = CollectionAPIOptions() else: self.api_options = api_options - additional_headers = { - k: v - for k, v in { - DEFAULT_VECTORIZE_SECRET_HEADER: self.api_options.embedding_api_key, - }.items() - if v is not None - } + additional_headers = self.api_options.embedding_api_key.get_headers() self._astra_db_collection: AsyncAstraDBCollection = AsyncAstraDBCollection( collection_name=name, astra_db=database._astra_db, @@ -2700,7 +2698,8 @@ def __init__( def __repr__(self) -> str: return ( f'{self.__class__.__name__}(name="{self.name}", ' - f'namespace="{self.namespace}", database={self.database})' + f'namespace="{self.namespace}", database={self.database}, ' + f"api_options={self.api_options})" ) def __eq__(self, other: Any) -> bool: @@ -2745,7 +2744,7 @@ def with_options( self, *, name: Optional[str] = None, - embedding_api_key: Optional[str] = None, + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, collection_max_time_ms: Optional[int] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, @@ -2757,10 +2756,15 @@ def with_options( name: the name of the collection. This parameter is useful to quickly spawn AsyncCollection instances each pointing to a different collection existing in the same namespace. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: optional API key(s) for interacting with the collection. + If an embedding service is configured, and this parameter is not None, + each Data API call will include the necessary embedding-related headers + as specified by this parameter. If a string is passed, it translates + into the one "embedding api key" header + (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + For some vectorize providers/models, if using header-based authentication, + specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` + should be supplied. collection_max_time_ms: a default timeout, in millisecond, for the duration of each operation on the collection. Individual timeouts can be provided to each collection method call and will take precedence, with this value @@ -2784,7 +2788,7 @@ def with_options( """ _api_options = CollectionAPIOptions( - embedding_api_key=embedding_api_key, + embedding_api_key=coerce_embedding_headers_provider(embedding_api_key), max_time_ms=collection_max_time_ms, ) @@ -2801,7 +2805,7 @@ def to_sync( database: Optional[Database] = None, name: Optional[str] = None, namespace: Optional[str] = None, - embedding_api_key: Optional[str] = None, + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, collection_max_time_ms: Optional[int] = None, caller_name: Optional[str] = None, caller_version: Optional[str] = None, @@ -2819,10 +2823,15 @@ def to_sync( collection on the database. namespace: this is the namespace to which the collection belongs. If not specified, the database's working namespace is used. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: optional API key(s) for interacting with the collection. + If an embedding service is configured, and this parameter is not None, + each Data API call will include the necessary embedding-related headers + as specified by this parameter. If a string is passed, it translates + into the one "embedding api key" header + (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + For some vectorize providers/models, if using header-based authentication, + specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` + should be supplied. collection_max_time_ms: a default timeout, in millisecond, for the duration of each operation on the collection. Individual timeouts can be provided to each collection method call and will take precedence, with this value @@ -2844,7 +2853,7 @@ def to_sync( """ _api_options = CollectionAPIOptions( - embedding_api_key=embedding_api_key, + embedding_api_key=coerce_embedding_headers_provider(embedding_api_key), max_time_ms=collection_max_time_ms, ) diff --git a/astrapy/core/defaults.py b/astrapy/core/defaults.py index 1e6c3d6e..44d0b49f 100644 --- a/astrapy/core/defaults.py +++ b/astrapy/core/defaults.py @@ -31,4 +31,13 @@ MAX_INSERT_NUM_DOCUMENTS = 100 DEFAULT_INSERT_NUM_DOCUMENTS = 50 -DEFAULT_VECTORIZE_SECRET_HEADER = "x-embedding-api-key" +# Some of these are repeated by hand from idiomatic, tolerable duplication +# as long as `core` is in place: +# (the case must match, so that secrets are effectively masked) +DEFAULT_REDACTED_HEADERS = { + DEFAULT_DEV_OPS_AUTH_HEADER, + DEFAULT_AUTH_HEADER, + "X-Embedding-Api-Key", + "X-Embedding-Access-Id", + "X-Embedding-Secret-Id", +} diff --git a/astrapy/core/utils.py b/astrapy/core/utils.py index a6ffcc89..dad44019 100644 --- a/astrapy/core/utils.py +++ b/astrapy/core/utils.py @@ -14,7 +14,6 @@ from __future__ import annotations -import copy import datetime import json import logging @@ -25,11 +24,7 @@ from astrapy import __version__ from astrapy.core.core_types import API_RESPONSE -from astrapy.core.defaults import ( - DEFAULT_AUTH_HEADER, - DEFAULT_TIMEOUT, - DEFAULT_VECTORIZE_SECRET_HEADER, -) +from astrapy.core.defaults import DEFAULT_REDACTED_HEADERS, DEFAULT_TIMEOUT from astrapy.core.ids import UUID, ObjectId @@ -97,11 +92,10 @@ def log_request( logger.debug(f"Request params: {params}") # Redact known secrets from the request headers - headers_log = copy.deepcopy(headers) - if DEFAULT_AUTH_HEADER in headers_log: - headers_log[DEFAULT_AUTH_HEADER] = "***" - if DEFAULT_VECTORIZE_SECRET_HEADER in headers_log: - headers_log[DEFAULT_VECTORIZE_SECRET_HEADER] = "***" + headers_log = { + hdr_k: hdr_v if hdr_k not in DEFAULT_REDACTED_HEADERS else "***" + for hdr_k, hdr_v in headers.items() + } logger.debug(f"Request headers: {headers_log}") diff --git a/astrapy/database.py b/astrapy/database.py index b22960f9..3e572844 100644 --- a/astrapy/database.py +++ b/astrapy/database.py @@ -25,7 +25,10 @@ parse_api_endpoint, ) from astrapy.api_options import CollectionAPIOptions -from astrapy.authentication import coerce_token_provider +from astrapy.authentication import ( + coerce_embedding_headers_provider, + coerce_token_provider, +) from astrapy.constants import Environment from astrapy.core.db import AstraDB, AsyncAstraDB from astrapy.cursors import AsyncCommandCursor, CommandCursor @@ -46,7 +49,7 @@ if TYPE_CHECKING: from astrapy.admin import DatabaseAdmin - from astrapy.authentication import TokenProvider + from astrapy.authentication import EmbeddingHeadersProvider, TokenProvider from astrapy.collection import AsyncCollection, Collection @@ -419,7 +422,7 @@ def get_collection( name: str, *, namespace: Optional[str] = None, - embedding_api_key: Optional[str] = None, + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, collection_max_time_ms: Optional[int] = None, ) -> Collection: """ @@ -436,10 +439,15 @@ def get_collection( name: the name of the collection. namespace: the namespace containing the collection. If no namespace is specified, the general setting for this database is used. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: optional API key(s) for interacting with the collection. + If an embedding service is configured, and this parameter is not None, + each Data API call will include the necessary embedding-related headers + as specified by this parameter. If a string is passed, it translates + into the one "embedding api key" header + (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + For some vectorize providers/models, if using header-based authentication, + specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` + should be supplied. collection_max_time_ms: a default timeout, in millisecond, for the duration of each operation on the collection. Individual timeouts can be provided to each collection method call and will take precedence, with this value @@ -475,7 +483,7 @@ def get_collection( name, namespace=_namespace, api_options=CollectionAPIOptions( - embedding_api_key=embedding_api_key, + embedding_api_key=coerce_embedding_headers_provider(embedding_api_key), max_time_ms=collection_max_time_ms, ), ) @@ -494,7 +502,7 @@ def create_collection( additional_options: Optional[Dict[str, Any]] = None, check_exists: Optional[bool] = None, max_time_ms: Optional[int] = None, - embedding_api_key: Optional[str] = None, + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, collection_max_time_ms: Optional[int] = None, ) -> Collection: """ @@ -539,10 +547,15 @@ def create_collection( preexisting collections, the command will succeed or fail depending on whether the options match or not. max_time_ms: a timeout, in milliseconds, for the underlying HTTP request. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: optional API key(s) for interacting with the collection. + If an embedding service is configured, and this parameter is not None, + each Data API call will include the necessary embedding-related headers + as specified by this parameter. If a string is passed, it translates + into the one "embedding api key" header + (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + For some vectorize providers/models, if using header-based authentication, + specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` + should be supplied. collection_max_time_ms: a default timeout, in millisecond, for the duration of each operation on the collection. Individual timeouts can be provided to each collection method call and will take precedence, with this value @@ -626,7 +639,7 @@ def create_collection( return self.get_collection( name, namespace=namespace, - embedding_api_key=embedding_api_key, + embedding_api_key=coerce_embedding_headers_provider(embedding_api_key), collection_max_time_ms=collection_max_time_ms, ) @@ -1247,7 +1260,7 @@ async def get_collection( name: str, *, namespace: Optional[str] = None, - embedding_api_key: Optional[str] = None, + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, collection_max_time_ms: Optional[int] = None, ) -> AsyncCollection: """ @@ -1264,10 +1277,15 @@ async def get_collection( name: the name of the collection. namespace: the namespace containing the collection. If no namespace is specified, the setting for this database is used. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: optional API key(s) for interacting with the collection. + If an embedding service is configured, and this parameter is not None, + each Data API call will include the necessary embedding-related headers + as specified by this parameter. If a string is passed, it translates + into the one "embedding api key" header + (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + For some vectorize providers/models, if using header-based authentication, + specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` + should be supplied. collection_max_time_ms: a default timeout, in millisecond, for the duration of each operation on the collection. Individual timeouts can be provided to each collection method call and will take precedence, with this value @@ -1306,7 +1324,7 @@ async def get_collection( name, namespace=_namespace, api_options=CollectionAPIOptions( - embedding_api_key=embedding_api_key, + embedding_api_key=coerce_embedding_headers_provider(embedding_api_key), max_time_ms=collection_max_time_ms, ), ) @@ -1325,7 +1343,7 @@ async def create_collection( additional_options: Optional[Dict[str, Any]] = None, check_exists: Optional[bool] = None, max_time_ms: Optional[int] = None, - embedding_api_key: Optional[str] = None, + embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]] = None, collection_max_time_ms: Optional[int] = None, ) -> AsyncCollection: """ @@ -1370,10 +1388,15 @@ async def create_collection( preexisting collections, the command will succeed or fail depending on whether the options match or not. max_time_ms: a timeout, in milliseconds, for the underlying HTTP request. - embedding_api_key: an optional API key for interacting with the collection. - If an embedding service is configured, and this attribute is set, - each Data API call will include a "x-embedding-api-key" header - with the value of this attribute. + embedding_api_key: optional API key(s) for interacting with the collection. + If an embedding service is configured, and this parameter is not None, + each Data API call will include the necessary embedding-related headers + as specified by this parameter. If a string is passed, it translates + into the one "embedding api key" header + (i.e. `astrapy.authentication.StaticEmbeddingHeadersProvider`). + For some vectorize providers/models, if using header-based authentication, + specialized subclasses of `astrapy.authentication.EmbeddingHeadersProvider` + should be supplied. collection_max_time_ms: a default timeout, in millisecond, for the duration of each operation on the collection. Individual timeouts can be provided to each collection method call and will take precedence, with this value @@ -1460,7 +1483,7 @@ async def create_collection( return await self.get_collection( name, namespace=namespace, - embedding_api_key=embedding_api_key, + embedding_api_key=coerce_embedding_headers_provider(embedding_api_key), collection_max_time_ms=collection_max_time_ms, ) diff --git a/pyproject.toml b/pyproject.toml index afcb73ec..4a8f31f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "astrapy" -version = "1.3.1" +version = "1.3.2" description = "AstraPy is a Pythonic SDK for DataStax Astra and its Data API" authors = [ "Stefano Lottini ", diff --git a/tests/vectorize_idiomatic/integration/test_vectorize_providers.py b/tests/vectorize_idiomatic/integration/test_vectorize_providers.py index f63c7292..1f883f64 100644 --- a/tests/vectorize_idiomatic/integration/test_vectorize_providers.py +++ b/tests/vectorize_idiomatic/integration/test_vectorize_providers.py @@ -13,11 +13,12 @@ # limitations under the License. import os -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import pytest from astrapy import Database +from astrapy.authentication import AWSEmbeddingHeadersProvider, EmbeddingHeadersProvider from astrapy.exceptions import DataAPIResponseException, InsertManyException from astrapy.info import CollectionVectorServiceOptions @@ -88,9 +89,28 @@ def test_vectorize_usage_auth_type_header_sync( testable_vectorize_model: Dict[str, Any], ) -> None: simple_tag = testable_vectorize_model["simple_tag"].lower() - embedding_api_key = os.environ[ - f"HEADER_EMBEDDING_API_KEY_{testable_vectorize_model['secret_tag']}" - ] + # switch betewen header providers according to what is needed + # For the time being this is necessary on HEADER only + embedding_api_key: Union[str, EmbeddingHeadersProvider] + at_tokens = testable_vectorize_model["auth_type_tokens"] + at_token_lnames = {tk["accepted"].lower() for tk in at_tokens} + if at_token_lnames == {"x-embedding-api-key"}: + embedding_api_key = os.environ[ + f"HEADER_EMBEDDING_API_KEY_{testable_vectorize_model['secret_tag']}" + ] + elif at_token_lnames == {"x-embedding-access-id", "x-embedding-secret-id"}: + embedding_api_key = AWSEmbeddingHeadersProvider( + embedding_access_id=os.environ[ + f"HEADER_EMBEDDING_ACCESS_ID_{testable_vectorize_model['secret_tag']}" + ], + embedding_secret_id=os.environ[ + f"HEADER_EMBEDDING_SECRET_ID_{testable_vectorize_model['secret_tag']}" + ], + ) + else: + raise ValueError( + f"Unsupported auth type tokens for {testable_vectorize_model['model_tag']}" + ) dimension = testable_vectorize_model.get("dimension") service_options = testable_vectorize_model["service_options"] diff --git a/tests/vectorize_idiomatic/query_providers.py b/tests/vectorize_idiomatic/query_providers.py index c8bd3b40..d85c5e17 100644 --- a/tests/vectorize_idiomatic/query_providers.py +++ b/tests/vectorize_idiomatic/query_providers.py @@ -27,14 +27,24 @@ def desc_param(param_data: Dict[str, Any]) -> str: if param_data["type"].lower() == "string": return "str" elif param_data["type"].lower() == "number": - validation = param_data.get("validation", {}).get("numericRange") - assert isinstance(validation, list) and len(validation) == 2 - range_desc = f"[{validation[0]} : {validation[1]}]" - if "defaultValue" in param_data: - range_desc2 = f"{range_desc} (default={param_data['defaultValue']})" + validation = param_data.get("validation", {}) + if "numericRange" in validation: + validation_nr = validation["numericRange"] + assert isinstance(validation_nr, list) and len(validation_nr) == 2 + range_desc = f"[{validation_nr[0]} : {validation_nr[1]}]" + if "defaultValue" in param_data: + range_desc2 = f"{range_desc} (default={param_data['defaultValue']})" + else: + range_desc2 = range_desc + return f"number, {range_desc2}" + elif "options" in validation: + validation_op = validation["options"] + assert isinstance(validation_op, list) and len(validation_op) > 1 + return f"number, {' / '.join(str(v) for v in validation_op)}" else: - range_desc2 = range_desc - return f"number, {range_desc2}" + raise ValueError( + f"Unknown number validation spec: '{json.dumps(validation)}'" + ) elif param_data["type"].lower() == "boolean": return "bool" else: diff --git a/tests/vectorize_idiomatic/unit/test_embeddingheadersprovider.py b/tests/vectorize_idiomatic/unit/test_embeddingheadersprovider.py new file mode 100644 index 00000000..c650cfb5 --- /dev/null +++ b/tests/vectorize_idiomatic/unit/test_embeddingheadersprovider.py @@ -0,0 +1,67 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from astrapy.authentication import ( + EMBEDDING_HEADER_API_KEY, + EMBEDDING_HEADER_AWS_ACCESS_ID, + EMBEDDING_HEADER_AWS_SECRET_ID, + AWSEmbeddingHeadersProvider, + StaticEmbeddingHeadersProvider, + coerce_embedding_headers_provider, +) + + +class TestEmbeddingHeadersProvider: + @pytest.mark.describe("test of headers from StaticEmbeddingHeadersProvider") + def test_embeddingheadersprovider_static(self) -> None: + ehp = StaticEmbeddingHeadersProvider("x") + assert {k.lower(): v for k, v in ehp.get_headers().items()} == { + EMBEDDING_HEADER_API_KEY.lower(): "x" + } + + @pytest.mark.describe("test of headers from empty StaticEmbeddingHeadersProvider") + def test_embeddingheadersprovider_null(self) -> None: + ehp = StaticEmbeddingHeadersProvider(None) + assert ehp.get_headers() == {} + + @pytest.mark.describe("test of headers from AWSEmbeddingHeadersProvider") + def test_embeddingheadersprovider_aws(self) -> None: + ehp = AWSEmbeddingHeadersProvider( + embedding_access_id="x", + embedding_secret_id="y", + ) + gen_headers_lower = {k.lower(): v for k, v in ehp.get_headers().items()} + exp_headers_lower = { + EMBEDDING_HEADER_AWS_ACCESS_ID.lower(): "x", + EMBEDDING_HEADER_AWS_SECRET_ID.lower(): "y", + } + assert gen_headers_lower == exp_headers_lower + + @pytest.mark.describe("test of embedding headers provider coercion") + def test_embeddingheadersprovider_coercion(self) -> None: + """This doubles as equality test.""" + ehp_s = StaticEmbeddingHeadersProvider("x") + ehp_n = StaticEmbeddingHeadersProvider(None) + ehp_a = AWSEmbeddingHeadersProvider( + embedding_access_id="x", + embedding_secret_id="y", + ) + assert coerce_embedding_headers_provider(ehp_s) == ehp_s + assert coerce_embedding_headers_provider(ehp_n) == ehp_n + assert coerce_embedding_headers_provider(ehp_a) == ehp_a + + assert coerce_embedding_headers_provider("x") == ehp_s + assert coerce_embedding_headers_provider(None) == ehp_n diff --git a/tests/vectorize_idiomatic/vectorize_models.py b/tests/vectorize_idiomatic/vectorize_models.py index 29b9928c..f251da74 100644 --- a/tests/vectorize_idiomatic/vectorize_models.py +++ b/tests/vectorize_idiomatic/vectorize_models.py @@ -14,10 +14,15 @@ import os import sys -from typing import Any, Dict, Iterable, Tuple +from typing import Any, Dict, Iterable, List, Tuple sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +from astrapy.authentication import ( + EMBEDDING_HEADER_API_KEY, + EMBEDDING_HEADER_AWS_ACCESS_ID, + EMBEDDING_HEADER_AWS_SECRET_ID, +) from astrapy.info import CollectionVectorServiceOptions from .live_provider_info import live_provider_info @@ -84,6 +89,7 @@ SECRET_NAME_ROOT_MAP = { "azureOpenAI": "AZURE_OPENAI", + "bedrock": "BEDROCK", "cohere": "COHERE", "huggingface": "HUGGINGFACE", "huggingfaceDedicated": "HUGGINGFACEDED", @@ -149,6 +155,9 @@ "OPENAI_ORGANIZATION_ID" ], ("openai", "text-embedding-ada-002", "projectId"): os.environ["OPENAI_PROJECT_ID"], + # + ("bedrock", "amazon.titan-embed-text-v1", "region"): os.environ["BEDROCK_REGION"], + ("bedrock", "amazon.titan-embed-text-v2:0", "region"): os.environ["BEDROCK_REGION"], } # this is ad-hoc for HF dedicated. Models here, though "optional" dimension, @@ -169,6 +178,12 @@ def _from_validation(pspec: Dict[str, Any]) -> int: m0: int = pspec["validation"]["numericRange"][0] m1: int = pspec["validation"]["numericRange"][1] return (m0 + m1) // 2 + elif "options" in pspec["validation"]: + options: List[int] = pspec["validation"]["options"] + if len(options) > 1: + return options[1] + else: + return options[0] else: raise ValueError("unsupported pspec") @@ -190,12 +205,25 @@ def _collapse(longt: str) -> str: if auth_type_name == "NONE": assert auth_type_desc["tokens"] == [] elif auth_type_name == "HEADER": - assert {t["accepted"] for t in auth_type_desc["tokens"]} == { - "x-embedding-api-key" + header_names_lower = tuple( + sorted( + t["accepted"].lower() for t in auth_type_desc["tokens"] + ) + ) + assert header_names_lower in { + (EMBEDDING_HEADER_API_KEY.lower(),), + ( + EMBEDDING_HEADER_AWS_ACCESS_ID.lower(), + EMBEDDING_HEADER_AWS_SECRET_ID.lower(), + ), } elif auth_type_name == "SHARED_SECRET": - assert {t["accepted"] for t in auth_type_desc["tokens"]} == { - "providerKey" + authkey_names = tuple( + sorted(t["accepted"] for t in auth_type_desc["tokens"]) + ) + assert authkey_names in { + ("providerKey",), + ("accessId", "secretKey"), } else: raise ValueError("Unknown auth type") @@ -269,6 +297,7 @@ def _collapse(longt: str) -> str: "".join(c for c in model_tag_0 if c in alphanum) ), "auth_type_name": auth_type_name, + "auth_type_tokens": auth_type_desc["tokens"], "secret_tag": SECRET_NAME_ROOT_MAP[provider_name], "test_assets": TEST_ASSETS_MAP.get( (provider_name, model["name"]), DEFAULT_TEST_ASSETS @@ -292,6 +321,7 @@ def _collapse(longt: str) -> str: ): root_model = { "auth_type_name": auth_type_name, + "auth_type_tokens": auth_type_desc["tokens"], "dimension": dimension, "secret_tag": SECRET_NAME_ROOT_MAP[provider_name], "test_assets": TEST_ASSETS_MAP.get(