Skip to content

Commit bebe401

Browse files
cbornetErick Friis
andauthored
astradb[patch]: Add AstraDBStore to langchain-astradb package (langchain-ai#17789)
Co-authored-by: Erick Friis <erick@langchain.dev>
1 parent 4e28888 commit bebe401

File tree

6 files changed

+551
-0
lines changed

6 files changed

+551
-0
lines changed

libs/community/langchain_community/storage/astradb.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
TypeVar,
1616
)
1717

18+
from langchain_core._api.deprecation import deprecated
1819
from langchain_core.stores import BaseStore, ByteStore
1920

2021
from langchain_community.utilities.astradb import (
@@ -124,6 +125,11 @@ async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[st
124125
yield key
125126

126127

128+
@deprecated(
129+
since="0.0.22",
130+
removal="0.2.0",
131+
alternative_import="langchain_astradb.AstraDBStore",
132+
)
127133
class AstraDBStore(AstraDBBaseStore[Any]):
128134
"""BaseStore implementation using DataStax AstraDB as the underlying store.
129135
@@ -143,6 +149,11 @@ def encode_value(self, value: Any) -> Any:
143149
return value
144150

145151

152+
@deprecated(
153+
since="0.0.22",
154+
removal="0.2.0",
155+
alternative_import="langchain_astradb.AstraDBByteStore",
156+
)
146157
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
147158
"""ByteStore implementation using DataStax AstraDB as the underlying store.
148159
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from langchain_astradb.storage import AstraDBByteStore, AstraDBStore
12
from langchain_astradb.vectorstores import AstraDBVectorStore
23

34
__all__ = [
5+
"AstraDBByteStore",
6+
"AstraDBStore",
47
"AstraDBVectorStore",
58
]
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
from __future__ import annotations
2+
3+
import base64
4+
from abc import ABC, abstractmethod
5+
from typing import (
6+
Any,
7+
AsyncIterator,
8+
Generic,
9+
Iterator,
10+
List,
11+
Optional,
12+
Sequence,
13+
Tuple,
14+
TypeVar,
15+
)
16+
17+
from astrapy.db import AstraDB, AsyncAstraDB
18+
from langchain_core.stores import BaseStore, ByteStore
19+
20+
from langchain_astradb.utils.astradb import (
21+
SetupMode,
22+
_AstraDBCollectionEnvironment,
23+
)
24+
25+
V = TypeVar("V")
26+
27+
28+
class AstraDBBaseStore(Generic[V], BaseStore[str, V], ABC):
29+
"""Base class for the DataStax AstraDB data store."""
30+
31+
def __init__(self, *args: Any, **kwargs: Any) -> None:
32+
self.astra_env = _AstraDBCollectionEnvironment(*args, **kwargs)
33+
self.collection = self.astra_env.collection
34+
self.async_collection = self.astra_env.async_collection
35+
36+
@abstractmethod
37+
def decode_value(self, value: Any) -> Optional[V]:
38+
"""Decodes value from Astra DB"""
39+
40+
@abstractmethod
41+
def encode_value(self, value: Optional[V]) -> Any:
42+
"""Encodes value for Astra DB"""
43+
44+
def mget(self, keys: Sequence[str]) -> List[Optional[V]]:
45+
self.astra_env.ensure_db_setup()
46+
docs_dict = {}
47+
for doc in self.collection.paginated_find(filter={"_id": {"$in": list(keys)}}):
48+
docs_dict[doc["_id"]] = doc.get("value")
49+
return [self.decode_value(docs_dict.get(key)) for key in keys]
50+
51+
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
52+
await self.astra_env.aensure_db_setup()
53+
docs_dict = {}
54+
async for doc in self.async_collection.paginated_find(
55+
filter={"_id": {"$in": list(keys)}}
56+
):
57+
docs_dict[doc["_id"]] = doc.get("value")
58+
return [self.decode_value(docs_dict.get(key)) for key in keys]
59+
60+
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
61+
self.astra_env.ensure_db_setup()
62+
for k, v in key_value_pairs:
63+
self.collection.upsert({"_id": k, "value": self.encode_value(v)})
64+
65+
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
66+
await self.astra_env.aensure_db_setup()
67+
for k, v in key_value_pairs:
68+
await self.async_collection.upsert(
69+
{"_id": k, "value": self.encode_value(v)}
70+
)
71+
72+
def mdelete(self, keys: Sequence[str]) -> None:
73+
self.astra_env.ensure_db_setup()
74+
self.collection.delete_many(filter={"_id": {"$in": list(keys)}})
75+
76+
async def amdelete(self, keys: Sequence[str]) -> None:
77+
await self.astra_env.aensure_db_setup()
78+
await self.async_collection.delete_many(filter={"_id": {"$in": list(keys)}})
79+
80+
def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
81+
self.astra_env.ensure_db_setup()
82+
docs = self.collection.paginated_find()
83+
for doc in docs:
84+
key = doc["_id"]
85+
if not prefix or key.startswith(prefix):
86+
yield key
87+
88+
async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
89+
await self.astra_env.aensure_db_setup()
90+
async for doc in self.async_collection.paginated_find():
91+
key = doc["_id"]
92+
if not prefix or key.startswith(prefix):
93+
yield key
94+
95+
96+
class AstraDBStore(AstraDBBaseStore[Any]):
97+
def __init__(
98+
self,
99+
collection_name: str,
100+
*,
101+
token: Optional[str] = None,
102+
api_endpoint: Optional[str] = None,
103+
astra_db_client: Optional[AstraDB] = None,
104+
namespace: Optional[str] = None,
105+
async_astra_db_client: Optional[AsyncAstraDB] = None,
106+
pre_delete_collection: bool = False,
107+
setup_mode: SetupMode = SetupMode.SYNC,
108+
) -> None:
109+
"""BaseStore implementation using DataStax AstraDB as the underlying store.
110+
111+
The value type can be any type serializable by json.dumps.
112+
Can be used to store embeddings with the CacheBackedEmbeddings.
113+
114+
Documents in the AstraDB collection will have the format
115+
116+
.. code-block:: json
117+
{
118+
"_id": "<key>",
119+
"value": <value>
120+
}
121+
122+
Args:
123+
collection_name: name of the Astra DB collection to create/use.
124+
token: API token for Astra DB usage.
125+
api_endpoint: full URL to the API endpoint,
126+
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
127+
astra_db_client: *alternative to token+api_endpoint*,
128+
you can pass an already-created 'astrapy.db.AstraDB' instance.
129+
async_astra_db_client: *alternative to token+api_endpoint*,
130+
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
131+
namespace: namespace (aka keyspace) where the
132+
collection is created. Defaults to the database's "default namespace".
133+
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
134+
OFF).
135+
pre_delete_collection: whether to delete the collection
136+
before creating it. If False and the collection already exists,
137+
the collection will be used as is.
138+
"""
139+
super().__init__(
140+
collection_name=collection_name,
141+
token=token,
142+
api_endpoint=api_endpoint,
143+
astra_db_client=astra_db_client,
144+
async_astra_db_client=async_astra_db_client,
145+
namespace=namespace,
146+
setup_mode=setup_mode,
147+
pre_delete_collection=pre_delete_collection,
148+
)
149+
150+
def decode_value(self, value: Any) -> Any:
151+
return value
152+
153+
def encode_value(self, value: Any) -> Any:
154+
return value
155+
156+
157+
class AstraDBByteStore(AstraDBBaseStore[bytes], ByteStore):
158+
def __init__(
159+
self,
160+
*,
161+
collection_name: str,
162+
token: Optional[str] = None,
163+
api_endpoint: Optional[str] = None,
164+
astra_db_client: Optional[AstraDB] = None,
165+
namespace: Optional[str] = None,
166+
async_astra_db_client: Optional[AsyncAstraDB] = None,
167+
pre_delete_collection: bool = False,
168+
setup_mode: SetupMode = SetupMode.SYNC,
169+
) -> None:
170+
"""ByteStore implementation using DataStax AstraDB as the underlying store.
171+
172+
The bytes values are converted to base64 encoded strings
173+
Documents in the AstraDB collection will have the format
174+
175+
.. code-block:: json
176+
{
177+
"_id": "<key>",
178+
"value": "<byte64 string value>"
179+
}
180+
181+
Args:
182+
collection_name: name of the Astra DB collection to create/use.
183+
token: API token for Astra DB usage.
184+
api_endpoint: full URL to the API endpoint,
185+
such as `https://<DB-ID>-us-east1.apps.astra.datastax.com`.
186+
astra_db_client: *alternative to token+api_endpoint*,
187+
you can pass an already-created 'astrapy.db.AstraDB' instance.
188+
async_astra_db_client: *alternative to token+api_endpoint*,
189+
you can pass an already-created 'astrapy.db.AsyncAstraDB' instance.
190+
namespace: namespace (aka keyspace) where the
191+
collection is created. Defaults to the database's "default namespace".
192+
setup_mode: mode used to create the Astra DB collection (SYNC, ASYNC or
193+
OFF).
194+
pre_delete_collection: whether to delete the collection
195+
before creating it. If False and the collection already exists,
196+
the collection will be used as is.
197+
"""
198+
super().__init__(
199+
collection_name=collection_name,
200+
token=token,
201+
api_endpoint=api_endpoint,
202+
astra_db_client=astra_db_client,
203+
async_astra_db_client=async_astra_db_client,
204+
namespace=namespace,
205+
setup_mode=setup_mode,
206+
pre_delete_collection=pre_delete_collection,
207+
)
208+
209+
def decode_value(self, value: Any) -> Optional[bytes]:
210+
if value is None:
211+
return None
212+
return base64.b64decode(value)
213+
214+
def encode_value(self, value: Optional[bytes]) -> Any:
215+
if value is None:
216+
return None
217+
return base64.b64encode(value).decode("ascii")

0 commit comments

Comments
 (0)