Skip to content

Commit

Permalink
heading to 1.4.0; find_embedding_providers returns a top-level object…
Browse files Browse the repository at this point in the history
… with the provider map as its member
  • Loading branch information
hemidactylus committed Jul 9, 2024
1 parent fe454ac commit 5e6b3f9
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 47 deletions.
2 changes: 1 addition & 1 deletion CHANGES
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
(master)
v. 1.4.0
========
DatabaseAdmin classes retain a reference to the Async/Database instance that spawned them, if any
- introduced a spawner_database parameter to database admin constructors
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ from astrapy.info import (
EmbeddingProviderToken,
EmbeddingProviderAuthentication,
EmbeddingProvider,
FindEmbeddingProvidersResult,
)
```

Expand Down
50 changes: 15 additions & 35 deletions astrapy/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
ops_recast_method_sync,
to_dataapi_timeout_exception,
)
from astrapy.info import AdminDatabaseInfo, DatabaseInfo, EmbeddingProvider
from astrapy.info import AdminDatabaseInfo, DatabaseInfo, FindEmbeddingProvidersResult

if TYPE_CHECKING:
from astrapy import AsyncDatabase, Database
Expand Down Expand Up @@ -1453,14 +1453,14 @@ def get_async_database(self, *pargs: Any, **kwargs: Any) -> AsyncDatabase:
@abstractmethod
def find_embedding_providers(
self, *pargs: Any, **kwargs: Any
) -> Dict[str, EmbeddingProvider]:
) -> FindEmbeddingProvidersResult:
"""Query the Data API for the available embedding providers."""
...

@abstractmethod
async def async_find_embedding_providers(
self, *pargs: Any, **kwargs: Any
) -> Dict[str, EmbeddingProvider]:
) -> FindEmbeddingProvidersResult:
"""
Query the Data API for the available embedding providers.
(Async version of the method.)
Expand Down Expand Up @@ -2517,15 +2517,15 @@ def get_async_database(

def find_embedding_providers(
self, *, max_time_ms: Optional[int] = None
) -> Dict[str, EmbeddingProvider]:
) -> FindEmbeddingProvidersResult:
"""
Query the API for the full information on available embedding providers.
Args:
max_time_ms: a timeout, in milliseconds, for the DevOps API request.
Returns:
An `EmbeddingProvidersDescriptor` object with the complete information
A `FindEmbeddingProvidersResult` object with the complete information
returned by the API about available embedding providers
Example (output abridged and indented for clarity):
Expand Down Expand Up @@ -2554,16 +2554,11 @@ def find_embedding_providers(
)
else:
logger.info("finished getting list of embedding providers")
return {
ep_name: EmbeddingProvider.from_dict(ep_dict)
for ep_name, ep_dict in fe_response["status"][
"embeddingProviders"
].items()
}
return FindEmbeddingProvidersResult.from_dict(fe_response["status"])

async def async_find_embedding_providers(
self, *, max_time_ms: Optional[int] = None
) -> Dict[str, EmbeddingProvider]:
) -> FindEmbeddingProvidersResult:
"""
Query the API for the full information on available embedding providers.
Async version of the method, for use in an asyncio context.
Expand All @@ -2572,7 +2567,7 @@ async def async_find_embedding_providers(
max_time_ms: a timeout, in milliseconds, for the DevOps API request.
Returns:
An `EmbeddingProvidersDescriptor` object with the complete information
A `FindEmbeddingProvidersResult` object with the complete information
returned by the API about available embedding providers
Example (output abridged and indented for clarity):
Expand Down Expand Up @@ -2601,12 +2596,7 @@ async def async_find_embedding_providers(
)
else:
logger.info("finished getting list of embedding providers, async")
return {
ep_name: EmbeddingProvider.from_dict(ep_dict)
for ep_name, ep_dict in fe_response["status"][
"embeddingProviders"
].items()
}
return FindEmbeddingProvidersResult.from_dict(fe_response["status"])


class DataAPIDatabaseAdmin(DatabaseAdmin):
Expand Down Expand Up @@ -3190,15 +3180,15 @@ def get_async_database(

def find_embedding_providers(
self, *, max_time_ms: Optional[int] = None
) -> Dict[str, EmbeddingProvider]:
) -> FindEmbeddingProvidersResult:
"""
Query the API for the full information on available embedding providers.
Args:
max_time_ms: a timeout, in milliseconds, for the DevOps API request.
Returns:
An `EmbeddingProvidersDescriptor` object with the complete information
A `FindEmbeddingProvidersResult` object with the complete information
returned by the API about available embedding providers
Example (output abridged and indented for clarity):
Expand Down Expand Up @@ -3227,16 +3217,11 @@ def find_embedding_providers(
)
else:
logger.info("finished getting list of embedding providers")
return {
ep_name: EmbeddingProvider.from_dict(ep_dict)
for ep_name, ep_dict in fe_response["status"][
"embeddingProviders"
].items()
}
return FindEmbeddingProvidersResult.from_dict(fe_response["status"])

async def async_find_embedding_providers(
self, *, max_time_ms: Optional[int] = None
) -> Dict[str, EmbeddingProvider]:
) -> FindEmbeddingProvidersResult:
"""
Query the API for the full information on available embedding providers.
Async version of the method, for use in an asyncio context.
Expand All @@ -3245,7 +3230,7 @@ async def async_find_embedding_providers(
max_time_ms: a timeout, in milliseconds, for the DevOps API request.
Returns:
An `EmbeddingProvidersDescriptor` object with the complete information
A `FindEmbeddingProvidersResult` object with the complete information
returned by the API about available embedding providers
Example (output abridged and indented for clarity):
Expand Down Expand Up @@ -3274,9 +3259,4 @@ async def async_find_embedding_providers(
)
else:
logger.info("finished getting list of embedding providers, async")
return {
ep_name: EmbeddingProvider.from_dict(ep_dict)
for ep_name, ep_dict in fe_response["status"][
"embeddingProviders"
].items()
}
return FindEmbeddingProvidersResult.from_dict(fe_response["status"])
54 changes: 54 additions & 0 deletions astrapy/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,3 +747,57 @@ def from_dict(raw_dict: Dict[str, Any]) -> EmbeddingProvider:
},
url=raw_dict["url"],
)


@dataclass
class FindEmbeddingProvidersResult:
    """
    The outcome of a 'findEmbeddingProviders' Data API call, wrapped in a
    single top-level object.

    Attributes:
        embedding_providers: a map from each provider name to the
            corresponding EmbeddingProvider object.
        raw_info: the (nested) dictionary this object was parsed from,
            stored as received (may be None if not available).
    """

    embedding_providers: Dict[str, EmbeddingProvider]
    raw_info: Optional[Dict[str, Any]]

    def __repr__(self) -> str:
        # Only the provider names are shown, alphabetically, to keep it short.
        provider_names = ", ".join(sorted(self.embedding_providers))
        return f"FindEmbeddingProvidersResult(embedding_providers={provider_names})"

    def as_dict(self) -> Dict[str, Any]:
        """
        Express this object as a plain dictionary.

        Only the provider map is emitted, each provider being converted
        through its own `as_dict` method; `raw_info` is not part of the output.
        """

        providers_as_dicts: Dict[str, Any] = {}
        for provider_name, provider in self.embedding_providers.items():
            providers_as_dicts[provider_name] = provider.as_dict()
        return {"embeddingProviders": providers_as_dicts}

    @staticmethod
    def from_dict(raw_dict: Dict[str, Any]) -> FindEmbeddingProvidersResult:
        """
        Build a FindEmbeddingProvidersResult out of a dictionary, such as the
        one returned by the Data API.

        A warning is issued if keys other than "embeddingProviders" are found;
        the input dictionary is retained, untouched, as `raw_info`.
        """

        known_keys = {"embeddingProviders"}
        unexpected_keys = raw_dict.keys() - known_keys
        if unexpected_keys:
            unexpected_desc = ",".join(sorted(unexpected_keys))
            warnings.warn(
                "Unexpected key(s) encountered parsing a dictionary into "
                f"a `FindEmbeddingProvidersResult`: '{unexpected_desc}'"
            )

        parsed_providers = {}
        for provider_name, provider_body in raw_dict["embeddingProviders"].items():
            parsed_providers[provider_name] = EmbeddingProvider.from_dict(provider_body)

        return FindEmbeddingProvidersResult(
            embedding_providers=parsed_providers,
            raw_info=raw_dict,
        )
1 change: 1 addition & 0 deletions tests/idiomatic/unit/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_imports() -> None:
EmbeddingProviderModel,
EmbeddingProviderParameter,
EmbeddingProviderToken,
FindEmbeddingProvidersResult,
)
from astrapy.operations import ( # noqa: F401
AsyncBaseOperation,
Expand Down
18 changes: 13 additions & 5 deletions tests/vectorize_idiomatic/integration/test_vectorize_ops_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

from astrapy import AsyncDatabase
from astrapy.info import EmbeddingProvider
from astrapy.info import EmbeddingProvider, FindEmbeddingProvidersResult


class TestVectorizeOpsAsync:
Expand All @@ -27,14 +27,22 @@ async def test_collection_methods_vectorize_async(
async_database: AsyncDatabase,
) -> None:
database_admin = async_database.get_database_admin()
ep_map = await database_admin.async_find_embedding_providers()
ep_result = database_admin.find_embedding_providers()

assert isinstance(ep_result, FindEmbeddingProvidersResult)

assert all(
isinstance(emb_prov, EmbeddingProvider) for emb_prov in ep_map.values()
isinstance(emb_prov, EmbeddingProvider)
for emb_prov in ep_result.embedding_providers.values()
)

reconstructed = {
ep_name: EmbeddingProvider.from_dict(emb_prov.as_dict())
for ep_name, emb_prov in ep_map.items()
for ep_name, emb_prov in ep_result.embedding_providers.items()
}
assert reconstructed == ep_result.embedding_providers
dict_mapping = {
ep_name: emb_prov.as_dict()
for ep_name, emb_prov in ep_result.embedding_providers.items()
}
assert reconstructed == ep_map
assert dict_mapping == ep_result.raw_info["embeddingProviders"] # type: ignore[index]
18 changes: 13 additions & 5 deletions tests/vectorize_idiomatic/integration/test_vectorize_ops_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

from astrapy import Database
from astrapy.info import EmbeddingProvider
from astrapy.info import EmbeddingProvider, FindEmbeddingProvidersResult


class TestVectorizeOpsSync:
Expand All @@ -27,14 +27,22 @@ def test_collection_methods_vectorize_sync(
sync_database: Database,
) -> None:
database_admin = sync_database.get_database_admin()
ep_map = database_admin.find_embedding_providers()
ep_result = database_admin.find_embedding_providers()

assert isinstance(ep_result, FindEmbeddingProvidersResult)

assert all(
isinstance(emb_prov, EmbeddingProvider) for emb_prov in ep_map.values()
isinstance(emb_prov, EmbeddingProvider)
for emb_prov in ep_result.embedding_providers.values()
)

reconstructed = {
ep_name: EmbeddingProvider.from_dict(emb_prov.as_dict())
for ep_name, emb_prov in ep_map.items()
for ep_name, emb_prov in ep_result.embedding_providers.items()
}
assert reconstructed == ep_result.embedding_providers
dict_mapping = {
ep_name: emb_prov.as_dict()
for ep_name, emb_prov in ep_result.embedding_providers.items()
}
assert reconstructed == ep_map
assert dict_mapping == ep_result.raw_info["embeddingProviders"] # type: ignore[index]
2 changes: 1 addition & 1 deletion tests/vectorize_idiomatic/live_provider_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,4 @@ def live_provider_info() -> Dict[str, EmbeddingProvider]:

database_admin = database.get_database_admin()
response = database_admin.find_embedding_providers()
return response
return response.embedding_providers

0 comments on commit 5e6b3f9

Please sign in to comment.