Skip to content

Commit

Permalink
SL "embedding headers provider" refactoring (#292)
Browse files Browse the repository at this point in the history
* wip to EmbeddingHeadersProvider transition

* embedding headers provider + unit tests

* api_options retains emb.header provider with coercion at the front door; docstrings for h.providers

* adjust docstrings (collection, database) to emb.header providers

* More streamlined flow around redacting headers in logging

* vectorize integration testing handles multi-header (i.e. bedrock)

* re-enable override for api options with emb.hea.providers

* improve api options override/default logic and add docstrings; expose api_options in collection __repr__

* finished support for multiheader and bedrock case
  • Loading branch information
hemidactylus committed Jul 5, 2024
1 parent 462c6f6 commit 2adce1a
Show file tree
Hide file tree
Showing 14 changed files with 506 additions and 124 deletions.
4 changes: 4 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
(master)
========
Support for multiple-header embedding api keys:
- `EmbeddingHeadersProvider` classes for `embedding_api_key` parameter
- AWS header provider in addition to the regular one-header one
- adapt CI to cover this setup
Testing:
- restructure CI to fully support HCD alongside Astra DB
- add details for testing new embedding providers
Expand Down
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,13 @@ Remove logging noise with:
poetry run pytest [...] -o log_cli=0
```

Do not drop collections (core):
Increase logging level to `TRACE` (i.e. level `5`):

```
poetry run pytest [...] -o log_cli=1 --log-cli-level=5
```

Do not drop collections (valid for core):

```
TEST_SKIP_COLLECTION_DELETE=1 poetry run pytest [...]
Expand Down
14 changes: 10 additions & 4 deletions astrapy/api_commander.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@

import httpx

from astrapy.authentication import (
EMBEDDING_HEADER_API_KEY,
EMBEDDING_HEADER_AWS_ACCESS_ID,
EMBEDDING_HEADER_AWS_SECRET_ID,
)
from astrapy.core.defaults import (
DEFAULT_AUTH_HEADER,
DEFAULT_DEV_OPS_AUTH_HEADER,
DEFAULT_TIMEOUT,
DEFAULT_VECTORIZE_SECRET_HEADER,
)
from astrapy.core.utils import (
TimeoutInfoWideType,
Expand All @@ -44,11 +48,13 @@
to_dataapi_timeout_exception,
)

DEFAULT_REDACTED_HEADER_NAMES = [
DEFAULT_REDACTED_HEADER_NAMES = {
DEFAULT_AUTH_HEADER,
DEFAULT_DEV_OPS_AUTH_HEADER,
DEFAULT_VECTORIZE_SECRET_HEADER,
]
EMBEDDING_HEADER_AWS_ACCESS_ID,
EMBEDDING_HEADER_AWS_SECRET_ID,
EMBEDDING_HEADER_API_KEY,
}


def full_user_agent(
Expand Down
62 changes: 52 additions & 10 deletions astrapy/api_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional, TypeVar

from astrapy.authentication import (
EmbeddingHeadersProvider,
StaticEmbeddingHeadersProvider,
)

AO = TypeVar("AO", bound="BaseAPIOptions")


Expand All @@ -39,22 +44,54 @@ class BaseAPIOptions:
max_time_ms: Optional[int] = None

def with_default(self: AO, default: Optional[BaseAPIOptions]) -> AO:
"""
Return a new instance created by completing this instance with a default
API options object.
In other words, `optA.with_default(optB)` will take fields from optA
when possible and draw defaults from optB when optA has them set to anything
evaluating to False. (This relies on the __bool__ definition of the values,
such as that of the EmbeddingHeadersTokenProvider instances)
Args:
default: an API options instance to draw defaults from.
Returns:
a new instance of this class obtained by merging this one and the default.
"""
if default:
default_dict = default.__dict__
return self.__class__(
**{
**default.__dict__,
**{k: v for k, v in self.__dict__.items() if v is not None},
k: self_v or default_dict.get(k)
for k, self_v in self.__dict__.items()
}
)
else:
return self

def with_override(self: AO, override: Optional[BaseAPIOptions]) -> AO:
"""
Return a new instance created by overriding the members of this instance
with those taken from a supplied "override" API options object.
In other words, `optA.with_default(optB)` will take fields from optB
when possible and fall back to optA when optB has them set to anything
evaluating to False. (This relies on the __bool__ definition of the values,
such as that of the EmbeddingHeadersTokenProvider instances)
Args:
override: an API options instance to preferentially draw fields from.
Returns:
a new instance of this class obtained by merging the override and this one.
"""
if override:
self_dict = self.__dict__
return self.__class__(
**{
**self.__dict__,
**{k: v for k, v in override.__dict__.items() if v is not None},
k: override_v or self_dict.get(k)
for k, override_v in override.__dict__.items()
}
)
else:
Expand All @@ -77,10 +114,15 @@ class CollectionAPIOptions(BaseAPIOptions):
`find`, `delete_many`, `insert_many` and so on), it is strongly suggested
to provide a specific timeout as the default one likely wouldn't make
much sense.
embedding_api_key: an optional API key for interacting with the collection.
If an embedding service is configured, and this attribute is set,
each Data API call will include a "x-embedding-api-key" header
with the value of this attribute.
embedding_api_key: an `astrapy.authentication.EmbeddingHeadersProvider`
object, encoding embedding-related API keys that will be passed
as headers when interacting with the collection (on each Data API request).
The default value is `StaticEmbeddingHeadersProvider(None)`, i.e.
no embedding-specific headers, whereas if the collection is configured
with an embedding service other choices for this parameter can be
meaningfully supplied. is configured for the collection,
"""

embedding_api_key: Optional[str] = None
embedding_api_key: EmbeddingHeadersProvider = field(
default_factory=lambda: StaticEmbeddingHeadersProvider(None)
)
182 changes: 172 additions & 10 deletions astrapy/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,29 @@

import base64
from abc import ABC, abstractmethod
from typing import Any, Union
from typing import Any, Dict, Optional, Union

EMBEDDING_HEADER_AWS_ACCESS_ID = "X-Embedding-Access-Id"
EMBEDDING_HEADER_AWS_SECRET_ID = "X-Embedding-Secret-Id"
EMBEDDING_HEADER_API_KEY = "X-Embedding-Api-Key"

def coerce_token_provider(token: Any) -> TokenProvider:

def coerce_token_provider(token: Optional[Union[str, TokenProvider]]) -> TokenProvider:
if isinstance(token, TokenProvider):
return token
else:
return StaticTokenProvider(token)


def coerce_embedding_headers_provider(
embedding_api_key: Optional[Union[str, EmbeddingHeadersProvider]],
) -> EmbeddingHeadersProvider:
if isinstance(embedding_api_key, EmbeddingHeadersProvider):
return embedding_api_key
else:
return StaticEmbeddingHeadersProvider(embedding_api_key)


class TokenProvider(ABC):
"""
Abstract base class for a token provider.
Expand Down Expand Up @@ -54,14 +67,6 @@ def __eq__(self, other: Any) -> bool:
@abstractmethod
def __repr__(self) -> str: ...

@abstractmethod
def get_token(self) -> Union[str, None]:
"""
Produce a string for direct use as token in a subsequent API request,
or None for no token.
"""
...

def __or__(self, other: TokenProvider) -> TokenProvider:
"""
Implement the logic as for "token_str_a or token_str_b" for the TokenProvider,
Expand All @@ -82,6 +87,14 @@ def __bool__(self) -> bool:
"""
return self.get_token() is not None

@abstractmethod
def get_token(self) -> Union[str, None]:
"""
Produce a string for direct use as token in a subsequent API request,
or None for no token.
"""
...


class StaticTokenProvider(TokenProvider):
"""
Expand Down Expand Up @@ -150,3 +163,152 @@ def _b64(cleartext: str) -> str:

def get_token(self) -> str:
return self.token


class EmbeddingHeadersProvider(ABC):
"""
Abstract base class for a provider of embedding-related headers (such as API Keys).
The relevant method in this interface is returning a dict to use as
(part of the) headers in Data API requests for a collection.
This class captures the fact that, depending on the embedding provider for
the collection, there may be zero, one *or more* headers to be passed
if relying on the HEADERS auth method for Vectorize.
"""

def __eq__(self, other: Any) -> bool:
my_headers = self.get_headers()
if isinstance(other, EmbeddingHeadersProvider):
return other.get_headers() == my_headers
else:
return False

@abstractmethod
def __repr__(self) -> str: ...

def __bool__(self) -> bool:
"""
All headers providers evaluate to True unless they yield the empty dict.
This method enables the override mechanism in APIOptions.
"""
return self.get_headers() != {}

@abstractmethod
def get_headers(self) -> Dict[str, str]:
"""
Produce a dictionary for use as (part of) the headers in HTTP requests
to the Data API.
"""
...


class StaticEmbeddingHeadersProvider(EmbeddingHeadersProvider):
"""
A "pass-through" header provider representing the single-header
(typically "X-Embedding-Api-Key") auth scheme, in use by most of the
embedding models in Vectorize.
Args:
embedding_api_key: a string that will be the value for the header.
If None is passed, this results in a no-headers provider (such
as the one used for non-Vectorize collections).
Example:
>>> from astrapy import DataAPIClient
>>> from astrapy.authentication import (
CollectionVectorServiceOptions,
StaticEmbeddingHeadersProvider,
)
>>> my_emb_api_key = StaticEmbeddingHeadersProvider("abc012...")
>>> service_options = CollectionVectorServiceOptions(
... provider="a-certain-provider",
... model_name="some-embedding-model",
... )
>>>
>>> database = DataAPIClient().get_database(
... "https://01234567-...-eu-west1.apps.datastax.com",
... token="AstraCS:...",
... )
>>> collection = database.create_collection(
... "vectorize_collection",
... service=service_options,
... embedding_api_key=my_emb_api_key,
... )
>>> # likewise:
>>> collection_b = database.get_collection(
... "vectorize_collection",
... embedding_api_key=my_emb_api_key,
... )
"""

def __init__(self, embedding_api_key: Optional[str]) -> None:
self.embedding_api_key = embedding_api_key

def __repr__(self) -> str:
if self.embedding_api_key is None:
return f"{self.__class__.__name__}(empty)"
else:
return f'{self.__class__.__name__}("{self.embedding_api_key[:5]}...")'

def get_headers(self) -> Dict[str, str]:
if self.embedding_api_key is not None:
return {EMBEDDING_HEADER_API_KEY: self.embedding_api_key}
else:
return {}


class AWSEmbeddingHeadersProvider(EmbeddingHeadersProvider):
"""
A header provider representing the two-header auth scheme in use
by the Amazon Web Services (e.g. AWS Bedrock) when using header-based
authentication.
Args:
embedding_access_id: value of the "Access ID" secret. This will become
the value for the corresponding header.
embedding_secret_id: value of the "Secret ID" secret. This will become
the value for the corresponding header.
Example:
>>> from astrapy import DataAPIClient
>>> from astrapy.authentication import (
CollectionVectorServiceOptions,
AWSEmbeddingHeadersProvider,
)
>>> my_aws_emb_api_key = AWSEmbeddingHeadersProvider(
embedding_access_id="my-access-id-012...",
embedding_secret_id="my-secret-id-abc...",
)
>>> service_options = CollectionVectorServiceOptions(
... provider="bedrock",
... model_name="some-aws-bedrock-model",
... )
>>>
>>> database = DataAPIClient().get_database(
... "https://01234567-...-eu-west1.apps.datastax.com",
... token="AstraCS:...",
... )
>>> collection = database.create_collection(
... "vectorize_aws_collection",
... service=service_options,
... embedding_api_key=my_aws_emb_api_key,
... )
>>> # likewise:
>>> collection_b = database.get_collection(
... "vectorize_aws_collection",
... embedding_api_key=my_aws_emb_api_key,
... )
"""

def __init__(self, *, embedding_access_id: str, embedding_secret_id: str) -> None:
self.embedding_access_id = embedding_access_id
self.embedding_secret_id = embedding_secret_id

def __repr__(self) -> str:
return f'{self.__class__.__name__}("{self.embedding_access_id[:3]}...", "{self.embedding_secret_id[:3]}...")'

def get_headers(self) -> Dict[str, str]:
return {
EMBEDDING_HEADER_AWS_ACCESS_ID: self.embedding_access_id,
EMBEDDING_HEADER_AWS_SECRET_ID: self.embedding_secret_id,
}
Loading

0 comments on commit 2adce1a

Please sign in to comment.