diff --git a/integration/test_batch_v4.py b/integration/test_batch_v4.py
index 4a55b0131..7788ef03d 100644
--- a/integration/test_batch_v4.py
+++ b/integration/test_batch_v4.py
@@ -1,9 +1,11 @@
+import asyncio
 import concurrent.futures
 import uuid
 from dataclasses import dataclass
 from typing import Callable, Generator, List, Optional, Protocol, Tuple

 import pytest
+import pytest_asyncio
 from _pytest.fixtures import SubRequest

 import weaviate
@@ -119,6 +121,53 @@ def _factory(
     client_fixture.close()


+class AsyncClientFactory(Protocol):
+    """Typing for fixture."""
+
+    async def __call__(
+        self, name: str = "", ports: Tuple[int, int] = (8080, 50051), multi_tenant: bool = False
+    ) -> Tuple[weaviate.WeaviateAsyncClient, str]:
+        """Typing for fixture."""
+        ...
+
+
+@pytest_asyncio.fixture
+async def async_client_factory(request: SubRequest):
+    name_fixtures: List[str] = []
+    client_fixture: Optional[weaviate.WeaviateAsyncClient] = None
+
+    async def _factory(
+        name: str = "", ports: Tuple[int, int] = (8080, 50051), multi_tenant: bool = False
+    ):
+        nonlocal client_fixture, name_fixtures  # noqa: F824
+        name_fixture = _sanitize_collection_name(request.node.name) + name
+        name_fixtures.append(name_fixture)
+        if client_fixture is None:
+            client_fixture = weaviate.use_async_with_local(grpc_port=ports[1], port=ports[0])
+            await client_fixture.connect()
+
+        if await client_fixture.collections.exists(name_fixture):
+            await client_fixture.collections.delete(name_fixture)
+
+        await client_fixture.collections.create(
+            name=name_fixture,
+            properties=[
+                Property(name="name", data_type=DataType.TEXT),
+                Property(name="age", data_type=DataType.INT),
+            ],
+            references=[ReferenceProperty(name="test", target_collection=name_fixture)],
+            multi_tenancy_config=Configure.multi_tenancy(multi_tenant),
+            vectorizer_config=Configure.Vectorizer.none(),
+        )
+        return client_fixture, name_fixture
+
+    try:
+        yield _factory
+    finally:
+        if client_fixture is not None:
+            await client_fixture.close()
+
+
 def test_add_objects_in_multiple_batches(client_factory: ClientFactory) -> None:
     client, name = client_factory()
     with client.batch.rate_limit(50) as batch:
@@ -365,15 +414,15 @@ def test_add_ref_batch_with_tenant(client_factory: ClientFactory) -> None:
 @pytest.mark.parametrize(
     "batching_method",
     [
-        # lambda client: client.batch.dynamic(),
-        # lambda client: client.batch.fixed_size(),
-        # lambda client: client.batch.rate_limit(9999),
+        lambda client: client.batch.dynamic(),
+        lambda client: client.batch.fixed_size(),
+        lambda client: client.batch.rate_limit(9999),
         lambda client: client.batch.experimental(concurrency=1),
     ],
     ids=[
-        # "test_add_ten_thousand_data_objects_dynamic",
-        # "test_add_ten_thousand_data_objects_fixed_size",
-        # "test_add_ten_thousand_data_objects_rate_limit",
+        "test_add_ten_thousand_data_objects_dynamic",
+        "test_add_ten_thousand_data_objects_fixed_size",
+        "test_add_ten_thousand_data_objects_rate_limit",
         "test_add_ten_thousand_data_objects_experimental",
     ],
 )
@@ -768,3 +817,40 @@ def test_references_with_to_uuids(client_factory: ClientFactory) -> None:
     assert len(client.batch.failed_references) == 0, client.batch.failed_references

     client.collections.delete(["target", "source"])
+
+
+@pytest.mark.asyncio
+async def test_add_ten_thousand_data_objects_async(
+    async_client_factory: AsyncClientFactory,
+) -> None:
+    """Test adding ten thousand data objects."""
+    client, name = await async_client_factory()
+    if client._connection._weaviate_version.is_lower_than(1, 34, 0):
+        pytest.skip("Server-side batching not supported in Weaviate < 1.34.0")
+    nr_objects = 10000
+    import time
+
+    start = time.time()
+    async with client.batch.experimental(concurrency=1) as batch:
+        async for i in arange(nr_objects):
+            await batch.add_object(
+                collection=name,
+                properties={"name": "test" + str(i)},
+            )
+    end = time.time()
+    print(f"Time taken to add {nr_objects} objects: {end - start} seconds")
+    assert len(client.batch.results.objs.errors) == 0
+    assert len(client.batch.results.objs.all_responses) == nr_objects
+    assert len(client.batch.results.objs.uuids) == nr_objects
+    assert await client.collections.use(name).length() == nr_objects
+    assert client.batch.results.objs.has_errors is False
+    assert len(client.batch.failed_objects) == 0, [
+        obj.message for obj in client.batch.failed_objects
+    ]
+    await client.collections.delete(name)
+
+
+async def arange(count):
+    for i in range(count):
+        yield i
+        await asyncio.sleep(0.0)
diff --git a/weaviate/client.py b/weaviate/client.py
index 8cf856c51..d7f9080f4 100644
--- a/weaviate/client.py
+++ b/weaviate/client.py
@@ -10,7 +10,7 @@
 from .auth import AuthCredentials
 from .backup import _Backup, _BackupAsync
 from .cluster import _Cluster, _ClusterAsync
-from .collections.batch.client import _BatchClientWrapper
+from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync
 from .collections.collections import _Collections, _CollectionsAsync
 from .config import AdditionalConfig
 from .connect import executor
@@ -76,6 +76,7 @@ def __init__(
         )
         self.alias = _AliasAsync(self._connection)
         self.backup = _BackupAsync(self._connection)
+        self.batch = _BatchClientWrapperAsync(self._connection)
         self.cluster = _ClusterAsync(self._connection)
         self.collections = _CollectionsAsync(self._connection)
         self.debug = _DebugAsync(self._connection)
diff --git a/weaviate/client.pyi b/weaviate/client.pyi
index 205a34b4e..9b32af15f 100644
--- a/weaviate/client.pyi
+++ b/weaviate/client.pyi
@@ -18,7 +18,7 @@ from weaviate.users.sync import _Users
 from .backup import _Backup, _BackupAsync
 from .cluster import _Cluster, _ClusterAsync
-from .collections.batch.client import _BatchClientWrapper
+from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync
 from .debug import _Debug, _DebugAsync
 from .rbac import _Roles, _RolesAsync
 from .types import NUMBER
@@ -29,6 +29,7 @@ class WeaviateAsyncClient(_WeaviateClientExecutor[ConnectionAsync]):
     _connection: ConnectionAsync
     alias: _AliasAsync
     backup: _BackupAsync
+    batch: _BatchClientWrapperAsync
     collections: _CollectionsAsync
     cluster: _ClusterAsync
     debug: _DebugAsync
diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py
new file mode 100644
index 000000000..f5f5758d3
--- /dev/null
+++ b/weaviate/collections/batch/async_.py
@@ -0,0 +1,482 @@
+import asyncio
+import time
+import uuid as uuid_package
+from typing import (
+    AsyncGenerator,
+    List,
+    Optional,
+    Set,
+    Union,
+)
+
+from pydantic import ValidationError
+
+from weaviate.collections.batch.base import (
+    ObjectsBatchRequest,
+    ReferencesBatchRequest,
+    _BatchDataWrapper,
+    _BatchMode,
+    _ServerSideBatching,
+)
+from weaviate.collections.batch.grpc_batch import _BatchGRPC
+from weaviate.collections.classes.batch import (
+    BatchObject,
+    BatchObjectReturn,
+    BatchReference,
+    ErrorObject,
+    ErrorReference,
+    Shard,
+)
+from weaviate.collections.classes.config import ConsistencyLevel
+from weaviate.collections.classes.internal import (
+    ReferenceInput,
+    ReferenceInputs,
+    ReferenceToMulti,
+)
+from weaviate.collections.classes.types import WeaviateProperties
+from weaviate.connect.v4 import ConnectionAsync
+from weaviate.exceptions import (
+    WeaviateBatchValidationError,
+    WeaviateGRPCUnavailableError,
+    WeaviateStartUpError,
+)
+from weaviate.logger import logger
+from weaviate.proto.v1 import batch_pb2
+from weaviate.types import UUID, VECTORS
+
+
+class _BgTasks:
+    def __init__(self, send: asyncio.Task[None], recv: asyncio.Task[None]) -> None:
+        self.send = send
+        self.recv = recv
+
+
+class _BatchBaseAsync:
+    def __init__(
+        self,
+        connection: ConnectionAsync,
+        consistency_level: Optional[ConsistencyLevel],
+        results: _BatchDataWrapper,
+        batch_mode: _BatchMode,
+        objects: Optional[ObjectsBatchRequest[batch_pb2.BatchObject]] = None,
+        references: Optional[ReferencesBatchRequest] = None,
+    ) -> None:
+        self.__batch_objects = objects or ObjectsBatchRequest[batch_pb2.BatchObject]()
+        self.__batch_references = references or ReferencesBatchRequest[batch_pb2.BatchReference]()
+
+        self.__connection = connection
+        self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM
+        self.__batch_size = 100
+
+        self.__batch_grpc = _BatchGRPC(
+            connection._weaviate_version, self.__consistency_level, connection._grpc_max_msg_size
+        )
+        self.__stream = self.__connection.grpc_batch_stream()
+
+        # lookup table for objects that are currently being processed - used to avoid sending references from objects that have not been added yet
+        self.__uuid_lookup: Set[str] = set()
+
+        # we do not want users to access the results directly as they are not thread-safe
+        self.__results_for_wrapper_backup = results
+        self.__results_for_wrapper = _BatchDataWrapper()
+
+        self.__objs_count = 0
+        self.__refs_count = 0
+
+        self.__uuid_lookup_lock = asyncio.Lock()
+
+        self.__is_shutting_down = asyncio.Event()
+        self.__is_shutdown = asyncio.Event()
+
+        self.__objs_cache_lock = asyncio.Lock()
+        self.__refs_cache_lock = asyncio.Lock()
+        self.__objs_cache: dict[str, BatchObject] = {}
+        self.__refs_cache: dict[int, BatchReference] = {}
+
+        self.__stop = False
+
+        self.__batch_mode = batch_mode
+
+    @property
+    def number_errors(self) -> int:
+        """Return the number of errors in the batch."""
+        return len(self.__results_for_wrapper.failed_objects) + len(
+            self.__results_for_wrapper.failed_references
+        )
+
+    async def _start(self):
+        assert isinstance(self.__batch_mode, _ServerSideBatching), (
+            "Only server-side batching is supported in this mode"
+        )
+        return _BgTasks(
+            send=asyncio.create_task(self.__send()), recv=asyncio.create_task(self.__recv())
+        )
+
+    async def _shutdown(self) -> None:
+        # Shutdown the current batch and wait for all requests to be finished
+        await self.flush()
+        self.__stop = True
+
+        # copy the results to the public results
+        self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results
+        self.__results_for_wrapper_backup.failed_objects = self.__results_for_wrapper.failed_objects
+        self.__results_for_wrapper_backup.failed_references = (
+            self.__results_for_wrapper.failed_references
+        )
+        self.__results_for_wrapper_backup.imported_shards = (
+            self.__results_for_wrapper.imported_shards
+        )
+
+    async def __send(self) -> None:
+        refresh_time: float = 0.01
+        await self.__connection.grpc_batch_stream_write(
+            self.__stream, batch_pb2.BatchStreamRequest(start=batch_pb2.BatchStreamRequest.Start())
+        )
+        while True:
+            if len(self.__batch_objects) + len(self.__batch_references) > 0:
+                start = time.time()
+                while (len_o := len(self.__batch_objects)) + (
+                    len_r := len(self.__batch_references)
+                ) < self.__batch_size:
+                    # wait for more objects to be added up to the batch size
+                    await asyncio.sleep(0.01)
+                    if time.time() - start >= 1 and (
+                        len_o == len(self.__batch_objects) or len_r == len(self.__batch_references)
+                    ):
+                        # no new objects were added in the last second, exit the loop
+                        break
+
+                objs = await self.__batch_objects.apop_items(self.__batch_size)
+                refs = await self.__batch_references.apop_items(
+                    self.__batch_size - len(objs),
+                    uuid_lookup=self.__uuid_lookup,
+                )
+                async with self.__uuid_lookup_lock:
+                    self.__uuid_lookup.difference_update(obj.uuid for obj in objs)
+
+                async for req in self.__generate_stream_requests(objs, refs):
+                    logged = False
+                    while self.__is_shutting_down.is_set() or self.__is_shutdown.is_set():
+                        # if we were shutdown by the node we were connected to, we need to wait for the stream to be restarted
+                        # so that the connection is refreshed to a new node where the objects can be accepted
+                        # otherwise, we wait until the stream has been started by __batch_stream to send the first batch
+                        if not logged:
+                            logger.warning("Waiting for stream to be re-established...")
+                            logged = True
+                        # signal the end of the current stream to the server
+                        await self.__stream.done_writing()
+                        await asyncio.sleep(1)
+                    if logged:
+                        logger.warning("Stream re-established, resuming sending batches")
+                    await self.__connection.grpc_batch_stream_write(self.__stream, req)
+            elif self.__stop:
+                await self.__connection.grpc_batch_stream_write(
+                    self.__stream,
+                    batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()),
+                )
+                await self.__stream.done_writing()
+                logger.warning("Batching finished, sent stop signal to batch stream")
+                return
+            await asyncio.sleep(refresh_time)
+
+    async def __generate_stream_requests(
+        self,
+        objs: List[batch_pb2.BatchObject],
+        refs: List[batch_pb2.BatchReference],
+    ) -> AsyncGenerator[batch_pb2.BatchStreamRequest, None]:
+        per_object_overhead = 4  # extra overhead bytes per object in the request
+
+        def request_maker():
+            return batch_pb2.BatchStreamRequest()
+
+        request = request_maker()
+        total_size = request.ByteSize()
+
+        for obj in objs:
+            obj_size = obj.ByteSize() + per_object_overhead
+
+            if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size:
+                await asyncio.sleep(0)  # yield control to event loop
+                yield request
+                request = request_maker()
+                total_size = request.ByteSize()
+
+            request.data.objects.values.append(obj)
+            total_size += obj_size
+
+        for ref in refs:
+            ref_size = ref.ByteSize() + per_object_overhead
+
+            if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size:
+                await asyncio.sleep(0)  # yield control to event loop
+                yield request
+                request = request_maker()
+                total_size = request.ByteSize()
+
+            request.data.references.values.append(ref)
+            total_size += ref_size
+
+        if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0:
+            await asyncio.sleep(0)  # yield control to event loop
+            yield request
+
+    async def __recv(self) -> None:
+        while True:
+            message = await self.__connection.grpc_batch_stream_read(self.__stream)
+            if not isinstance(message, batch_pb2.BatchStreamReply):
+                logger.warning("Server closed the stream from its side, shutting down batch")
+                return
+            if message.HasField("started"):
+                logger.warning("Batch stream started successfully")
+            if message.HasField("backoff"):
+                if (
+                    message.backoff.batch_size != self.__batch_size
+                    and not self.__is_shutting_down.is_set()
+                    and not self.__is_shutdown.is_set()
+                    and not self.__stop
+                ):
+                    self.__batch_size = message.backoff.batch_size
+                    logger.warning(
+                        f"Updated batch size to {self.__batch_size} as per server request"
+                    )
+            if message.HasField("results"):
+                result_objs = BatchObjectReturn()
+                # result_refs = BatchReferenceReturn()
+                failed_objs: List[ErrorObject] = []
+                failed_refs: List[ErrorReference] = []
+                for error in message.results.errors:
+                    if error.HasField("uuid"):
+                        try:
+                            async with self.__objs_cache_lock:
+                                cached = self.__objs_cache.pop(error.uuid)
+                        except KeyError:
+                            continue
+                        err = ErrorObject(
+                            message=error.error,
+                            object_=cached,
+                        )
+                        result_objs += BatchObjectReturn(
+                            _all_responses=[err],
+                            errors={cached.index: err},
+                        )
+                        failed_objs.append(err)
+                        logger.warning(
+                            {
+                                "error": error.error,
+                                "object": error.uuid,
+                                "action": "use {client,collection}.batch.failed_objects to access this error",
+                            }
+                        )
+                    if error.HasField("beacon"):
+                        # TODO: get cached ref from beacon
+                        err = ErrorReference(
+                            message=error.error,
+                            reference=error.beacon,  # pyright: ignore
+                        )
+                        failed_refs.append(err)
+                        logger.warning(
+                            {
+                                "error": error.error,
+                                "reference": error.beacon,
+                                "action": "use {client,collection}.batch.failed_references to access this error",
+                            }
+                        )
+                for success in message.results.successes:
+                    if success.HasField("uuid"):
+                        try:
+                            async with self.__objs_cache_lock:
+                                cached = self.__objs_cache.pop(success.uuid)
+                        except KeyError:
+                            continue
+                        uuid = uuid_package.UUID(success.uuid)
+                        result_objs += BatchObjectReturn(
+                            _all_responses=[uuid],
+                            uuids={cached.index: uuid},
+                        )
+                    if success.HasField("beacon"):
+                        # TODO: remove cached ref using beacon
+                        # self.__refs_cache.pop(success.beacon, None)
+                        pass
+                self.__results_for_wrapper.results.objs += result_objs
+                self.__results_for_wrapper.failed_objects.extend(failed_objs)
+                self.__results_for_wrapper.failed_references.extend(failed_refs)
+            elif message.HasField("shutting_down"):
+                logger.warning(
+                    "Received shutting down message from server, pausing sending until stream is re-established"
+                )
+                self.__is_shutting_down.set()
+            elif message.HasField("shutdown"):
+                logger.warning("Received shutdown finished message from server")
+                self.__is_shutdown.set()
+                self.__is_shutting_down.clear()
+                await self.__reconnect()
+
+        # restart the stream if we were shutdown by the node we were connected to
+        if self.__is_shutdown.is_set():
+            logger.warning("Restarting batch recv after shutdown...")
+            self.__is_shutdown.clear()
+            return await self.__recv()
+
+    async def __reconnect(self, retry: int = 0) -> None:
+        try:
+            logger.warning(f"Trying to reconnect after shutdown... {retry + 1}/{5}")
{retry + 1}/{5}") + self.__connection.close("sync") + await self.__connection.connect(force=True) + logger.warning("Reconnected successfully") + self.__stream = self.__connection.grpc_batch_stream() + except (WeaviateStartUpError, WeaviateGRPCUnavailableError) as e: + if retry < 5: + await asyncio.sleep(2**retry) + await self.__reconnect(retry + 1) + else: + logger.error("Failed to reconnect after 5 attempts") + self.__bg_thread_exception = e + + # def __start_bg_threads(self) -> _BgThreads: + # """Create a background thread that periodically checks how congested the batch queue is.""" + # self.__shut_background_thread_down = threading.Event() + + # def batch_send_wrapper() -> None: + # try: + # self.__batch_send() + # logger.warning("exited batch send thread") + # except Exception as e: + # logger.error(e) + # self.__bg_thread_exception = e + + # def batch_recv_wrapper() -> None: + # socket_hung_up = False + # try: + # self.__batch_recv() + # logger.warning("exited batch receive thread") + # except Exception as e: + # if isinstance(e, WeaviateBatchStreamError) and ( + # "Socket closed" in e.message or "context canceled" in e.message + # ): + # socket_hung_up = True + # else: + # logger.error(e) + # logger.error(type(e)) + # self.__bg_thread_exception = e + # if socket_hung_up: + # # this happens during ungraceful shutdown of the coordinator + # # lets restart the stream and add the cached objects again + # logger.warning("Stream closed unexpectedly, restarting...") + # self.__reconnect() + # # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now + # self.__is_shutting_down.clear() + # with self.__objs_cache_lock: + # logger.warning( + # f"Re-adding {len(self.__objs_cache)} cached objects to the batch" + # ) + # self.__batch_objects.prepend( + # [ + # self.__batch_grpc.grpc_object(o._to_internal()) + # for o in self.__objs_cache.values() + # ] + # ) + # with self.__refs_cache_lock: + # self.__batch_references.prepend( + # [ + # self.__batch_grpc.grpc_reference(o._to_internal()) + # for o in self.__refs_cache.values() + # ] + # ) + # # start a new stream with a newly reconnected channel + # return batch_recv_wrapper() + + # threads = _BgThreads( + # send=threading.Thread( + # target=batch_send_wrapper, + # daemon=True, + # name="BgBatchSend", + # ), + # recv=threading.Thread( + # target=batch_recv_wrapper, + # daemon=True, + # name="BgBatchRecv", + # ), + # ) + # threads.start_recv() + # return threads + + async def flush(self) -> None: + """Flush the batch queue and wait for all requests to be finished.""" + # bg thread is sending objs+refs automatically, so simply wait for everything to be done + while len(self.__batch_objects) > 0 or len(self.__batch_references) > 0: + await asyncio.sleep(0.01) + + async def _add_object( + self, + collection: str, + properties: Optional[WeaviateProperties] = None, + references: Optional[ReferenceInputs] = None, + uuid: Optional[UUID] = None, + vector: Optional[VECTORS] = None, + tenant: Optional[str] = None, + ) -> UUID: + try: + batch_object = BatchObject( + collection=collection, + properties=properties, + references=references, + uuid=uuid, + vector=vector, + tenant=tenant, + index=self.__objs_count, + ) + self.__results_for_wrapper.imported_shards.add( + Shard(collection=collection, tenant=tenant) + ) + except ValidationError as e: + raise WeaviateBatchValidationError(repr(e)) + uuid = str(batch_object.uuid) + async with self.__uuid_lookup_lock: + self.__uuid_lookup.add(uuid) + await 
+        async with self.__objs_cache_lock:
+            self.__objs_cache[uuid] = batch_object
+        self.__objs_count += 1
+
+        while len(self.__batch_objects) >= self.__batch_size * 2:
+            await asyncio.sleep(0.01)
+
+        assert batch_object.uuid is not None
+        return batch_object.uuid
+
+    async def _add_reference(
+        self,
+        from_object_uuid: UUID,
+        from_object_collection: str,
+        from_property_name: str,
+        to: ReferenceInput,
+        tenant: Optional[str] = None,
+    ) -> None:
+        if isinstance(to, ReferenceToMulti):
+            to_strs: Union[List[str], List[UUID]] = to.uuids_str
+        elif isinstance(to, str) or isinstance(to, uuid_package.UUID):
+            to_strs = [to]
+        else:
+            to_strs = list(to)
+
+        for uid in to_strs:
+            try:
+                batch_reference = BatchReference(
+                    from_object_collection=from_object_collection,
+                    from_object_uuid=from_object_uuid,
+                    from_property_name=from_property_name,
+                    to_object_collection=(
+                        to.target_collection if isinstance(to, ReferenceToMulti) else None
+                    ),
+                    to_object_uuid=uid,
+                    tenant=tenant,
+                    index=self.__refs_count,
+                )
+            except ValidationError as e:
+                raise WeaviateBatchValidationError(repr(e))
+            await self.__batch_references.aadd(
+                self.__batch_grpc.grpc_reference(batch_reference._to_internal())
+            )
+            async with self.__refs_cache_lock:
+                self.__refs_cache[self.__refs_count] = batch_reference
+            self.__refs_count += 1
diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py
index f9c5e491a..054c422d3 100644
--- a/weaviate/collections/batch/base.py
+++ b/weaviate/collections/batch/base.py
@@ -1,3 +1,4 @@
+import asyncio
 import contextvars
 import functools
 import math
@@ -10,8 +11,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from copy import copy
 from dataclasses import dataclass, field
-from queue import Queue
-from typing import Any, Dict, Generator, Generic, List, Optional, Set, TypeVar, Union, cast
+from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast

 from httpx import ConnectError
 from pydantic import ValidationError
@@ -40,13 +40,10 @@
 )
 from weaviate.collections.classes.types import WeaviateProperties
 from weaviate.connect import executor
-from weaviate.connect.v4 import ConnectionSync
+from weaviate.connect.v4 import ConnectionAsync, ConnectionSync
 from weaviate.exceptions import (
     EmptyResponseException,
-    WeaviateBatchStreamError,
     WeaviateBatchValidationError,
-    WeaviateGRPCUnavailableError,
-    WeaviateStartUpError,
 )
 from weaviate.logger import logger
 from weaviate.proto.v1 import batch_pb2
@@ -75,24 +72,42 @@ class BatchRequest(ABC, Generic[TBatchInput, TBatchReturn]):
     def __init__(self) -> None:
         self._items: List[TBatchInput] = []
         self._lock = threading.Lock()
+        self._alock = asyncio.Lock()

     def __len__(self) -> int:
-        return len(self._items)
+        with self._lock:
+            return len(self._items)
+
+    async def alen(self) -> int:
+        """Asynchronously get the length of the BatchRequest."""
+        async with self._alock:
+            return len(self._items)

     def add(self, item: TBatchInput) -> None:
         """Add an item to the BatchRequest."""
-        self._lock.acquire()
-        self._items.append(item)
-        self._lock.release()
+        with self._lock:
+            self._items.append(item)
+
+    async def aadd(self, item: TBatchInput) -> None:
+        """Asynchronously add an item to the BatchRequest."""
+        async with self._alock:
+            self._items.append(item)

     def prepend(self, item: List[TBatchInput]) -> None:
         """Add items to the front of the BatchRequest.

         This is intended to be used when objects should be retries, eg. after a temporary error.
         """
-        self._lock.acquire()
-        self._items = item + self._items
-        self._lock.release()
+        with self._lock:
+            self._items = item + self._items
+
+    async def aprepend(self, item: List[TBatchInput]) -> None:
+        """Asynchronously add items to the front of the BatchRequest.
+
+        This is intended to be used when objects should be retried, e.g. after a temporary error.
+        """
+        async with self._alock:
+            self._items = item + self._items


 Ref = TypeVar("Ref", bound=Union[_BatchReference, batch_pb2.BatchReference])
@@ -101,15 +116,9 @@ def prepend(self, item: List[TBatchInput]) -> None:
 class ReferencesBatchRequest(BatchRequest[Ref, BatchReferenceReturn]):
     """Collect Weaviate-object references to add them in one request to Weaviate."""

-    def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
-        """Pop the given number of items from the BatchRequest queue.
-
-        Returns:
-            A list of items from the BatchRequest.
-        """
+    def __pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
         ret: List[Ref] = []
         i = 0
-        self._lock.acquire()
         while len(ret) < pop_amount and len(self._items) > 0 and i < len(self._items):
             if self._items[i].from_uuid not in uuid_lookup and (
                 self._items[i].to_uuid is None or self._items[i].to_uuid not in uuid_lookup
@@ -117,19 +126,48 @@ def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
             ):
                 ret.append(self._items.pop(i))
             else:
                 i += 1
-        self._lock.release()
         return ret

+    def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
+        """Pop the given number of items from the BatchRequest queue.
+
+        Returns:
+            A list of items from the BatchRequest.
+        """
+        with self._lock:
+            return self.__pop_items(pop_amount, uuid_lookup)
+
+    async def apop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]:
+        """Asynchronously pop the given number of items from the BatchRequest queue.
+
+        Returns:
+            A list of items from the BatchRequest.
+        """
+        async with self._alock:
+            return self.__pop_items(pop_amount, uuid_lookup)
+
+    def __head(self) -> Optional[Ref]:
+        if len(self._items) > 0:
+            return self._items[0]
+        return None
+
     def head(self) -> Optional[Ref]:
         """Get the first item from the BatchRequest queue without removing it.

         Returns:
             The first item from the BatchRequest or None if the queue is empty.
         """
-        self._lock.acquire()
-        item = self._items[0] if len(self._items) > 0 else None
-        self._lock.release()
-        return item
+        with self._lock:
+            return self.__head()
+
+    async def ahead(self) -> Optional[Ref]:
+        """Asynchronously get the first item from the BatchRequest queue without removing it.
+
+        Returns:
+            The first item from the BatchRequest or None if the queue is empty.
+        """
+        async with self._alock:
+            return self.__head()


 Obj = TypeVar("Obj", bound=Union[_BatchObject, batch_pb2.BatchObject])
@@ -138,33 +176,55 @@ def head(self) -> Optional[Ref]:
 class ObjectsBatchRequest(Generic[Obj], BatchRequest[Obj, BatchObjectReturn]):
     """Collect objects for one batch request to weaviate."""

-    def pop_items(self, pop_amount: int) -> List[Obj]:
-        """Pop the given number of items from the BatchRequest queue.
-
-        Returns:
-            A list of items from the BatchRequest.
- """ - self._lock.acquire() + def __pop_items(self, pop_amount: int) -> List[Obj]: if pop_amount >= len(self._items): ret = copy(self._items) self._items.clear() else: ret = copy(self._items[:pop_amount]) self._items = self._items[pop_amount:] - - self._lock.release() return ret + def pop_items(self, pop_amount: int) -> List[Obj]: + """Pop the given number of items from the BatchRequest queue. + + Returns: + A list of items from the BatchRequest. + """ + with self._lock: + return self.__pop_items(pop_amount) + + async def apop_items(self, pop_amount: int) -> List[Obj]: + """Asynchronously pop the given number of items from the BatchRequest queue. + + Returns: + A list of items from the BatchRequest. + """ + async with self._alock: + return self.__pop_items(pop_amount) + + def __head(self) -> Optional[Obj]: + if len(self._items) > 0: + return self._items[0] + return None + def head(self) -> Optional[Obj]: """Get the first item from the BatchRequest queue without removing it. Returns: The first item from the BatchRequest or None if the queue is empty. """ - self._lock.acquire() - item = self._items[0] if len(self._items) > 0 else None - self._lock.release() - return item + with self._lock: + return self.__head() + + async def ahead(self) -> Optional[Obj]: + """Asynchronously get the first item from the BatchRequest queue without removing it. + + Returns: + The first item from the BatchRequest or None if the queue is empty. + """ + async with self._alock: + return self.__head() @dataclass @@ -834,547 +894,35 @@ def recv_alive(self) -> bool: return self.recv.is_alive() -class _BatchBaseNew: - def __init__( - self, - connection: ConnectionSync, - consistency_level: Optional[ConsistencyLevel], - results: _BatchDataWrapper, - batch_mode: _BatchMode, - executor: ThreadPoolExecutor, - vectorizer_batching: bool, - objects: Optional[ObjectsBatchRequest[batch_pb2.BatchObject]] = None, - references: Optional[ReferencesBatchRequest] = None, - ) -> None: - self.__batch_objects = objects or ObjectsBatchRequest[batch_pb2.BatchObject]() - self.__batch_references = references or ReferencesBatchRequest[batch_pb2.BatchReference]() - - self.__connection = connection - self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM - self.__batch_size = 100 - - self.__batch_grpc = _BatchGRPC( - connection._weaviate_version, self.__consistency_level, connection._grpc_max_msg_size - ) - - # lookup table for objects that are currently being processed - is used to not send references from objects that have not been added yet - self.__uuid_lookup: Set[str] = set() - - # we do not want that users can access the results directly as they are not thread-safe - self.__results_for_wrapper_backup = results - self.__results_for_wrapper = _BatchDataWrapper() - - self.__objs_count = 0 - self.__refs_count = 0 - - self.__uuid_lookup_lock = threading.Lock() - self.__results_lock = threading.Lock() - - self.__bg_thread_exception: Optional[Exception] = None - self.__is_shutting_down = threading.Event() - self.__is_shutdown = threading.Event() - - self.__objs_cache_lock = threading.Lock() - self.__refs_cache_lock = threading.Lock() - self.__objs_cache: dict[str, BatchObject] = {} - self.__refs_cache: dict[str, BatchReference] = {} - - # maxsize=1 so that __batch_send does not run faster than generator for __batch_recv - # thereby using too much buffer in case of server-side shutdown - self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1) - - self.__stop = False - - self.__batch_mode 
-
-        self.__total = 0
-
-    @property
-    def number_errors(self) -> int:
-        """Return the number of errors in the batch."""
-        return len(self.__results_for_wrapper.failed_objects) + len(
-            self.__results_for_wrapper.failed_references
-        )
-
-    def __all_threads_alive(self) -> bool:
-        return self.__bg_threads is not None and all(
-            thread.is_alive() for thread in self.__bg_threads
-        )
-
-    def __any_threads_alive(self) -> bool:
-        return self.__bg_threads is not None and any(
-            thread.is_alive() for thread in self.__bg_threads
-        )
-
-    def _start(self) -> None:
-        assert isinstance(self.__batch_mode, _ServerSideBatching), (
-            "Only server-side batching is supported in this mode"
-        )
-        self.__bg_threads = [
-            self.__start_bg_threads() for _ in range(self.__batch_mode.concurrency)
-        ]
-        logger.warning(
-            f"Provisioned {len(self.__bg_threads)} stream(s) to the server for batch processing"
-        )
-        now = time.time()
-        while not self.__all_threads_alive():
-            # wait for the stream to be started by __batch_stream
-            time.sleep(0.01)
-            if time.time() - now > 10:
-                raise WeaviateBatchValidationError(
-                    "Batch stream was not started within 10 seconds. Please check your connection."
-                )
-
-    def _shutdown(self) -> None:
-        # Shutdown the current batch and wait for all requests to be finished
-        self.flush()
-        self.__stop = True
-
-        # we are done, wait for bg threads to finish
-        # self.__batch_stream will set the shutdown event when it receives
-        # the stop message from the server
-        while self.__any_threads_alive():
-            time.sleep(1)
-        logger.warning("Send & receive threads finished.")
-
-        # copy the results to the public results
-        self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results
-        self.__results_for_wrapper_backup.failed_objects = self.__results_for_wrapper.failed_objects
-        self.__results_for_wrapper_backup.failed_references = (
-            self.__results_for_wrapper.failed_references
-        )
-        self.__results_for_wrapper_backup.imported_shards = (
-            self.__results_for_wrapper.imported_shards
-        )
-
-    def __batch_send(self) -> None:
-        refresh_time: float = 0.01
-        while (
-            self.__shut_background_thread_down is not None
-            and not self.__shut_background_thread_down.is_set()
-        ):
-            if len(self.__batch_objects) + len(self.__batch_references) > 0:
-                self._batch_send = True
-                start = time.time()
-                while (len_o := len(self.__batch_objects)) + (
-                    len_r := len(self.__batch_references)
-                ) < self.__batch_size:
-                    # wait for more objects to be added up to the batch size
-                    time.sleep(0.01)
-                    if (
-                        self.__shut_background_thread_down is not None
-                        and self.__shut_background_thread_down.is_set()
-                    ):
-                        logger.warning("Threads were shutdown, exiting batch send loop")
-                        # shutdown was requested, exit early
-                        self.__reqs.put(None)
-                        return
-                    if time.time() - start >= 1 and (
-                        len_o == len(self.__batch_objects) or len_r == len(self.__batch_references)
-                    ):
-                        # no new objects were added in the last second, exit the loop
-                        break
-
-                objs = self.__batch_objects.pop_items(self.__batch_size)
-                refs = self.__batch_references.pop_items(
-                    self.__batch_size - len(objs),
-                    uuid_lookup=self.__uuid_lookup,
-                )
-                with self.__uuid_lookup_lock:
-                    self.__uuid_lookup.difference_update(obj.uuid for obj in objs)
-
-                for req in self.__generate_stream_requests(objs, refs):
-                    logged = False
-                    while self.__is_shutting_down.is_set() or self.__is_shutdown.is_set():
-                        # if we were shutdown by the node we were connected to, we need to wait for the stream to be restarted
-                        # so that the connection is refreshed to a new node where the objects can be accepted
-                        # otherwise, we wait until the stream has been started by __batch_stream to send the first batch
-                        if not logged:
-                            logger.warning("Waiting for stream to be re-established...")
-                            logged = True
-                        # put sentinel into our queue to signal the end of the current stream
-                        self.__reqs.put(None)
-                        time.sleep(1)
-                    if logged:
-                        logger.warning("Stream re-established, resuming sending batches")
-                    self.__reqs.put(req)
-            elif self.__stop:
-                # we are done, send the sentinel into our queue to be consumed by the batch sender
-                self.__reqs.put(None)  # signal the end of the stream
-                logger.warning("Batching finished, sent stop signal to batch stream")
-                return
-            time.sleep(refresh_time)
-
-    def __generate_stream_requests(
-        self,
-        objs: List[batch_pb2.BatchObject],
-        refs: List[batch_pb2.BatchReference],
-    ) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
-        per_object_overhead = 4  # extra overhead bytes per object in the request
-
-        def request_maker():
-            return batch_pb2.BatchStreamRequest()
-
-        request = request_maker()
-        total_size = request.ByteSize()
-
-        for obj in objs:
-            obj_size = obj.ByteSize() + per_object_overhead
-
-            if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size:
-                yield request
-                request = request_maker()
-                total_size = request.ByteSize()
-
-            request.data.objects.values.append(obj)
-            total_size += obj_size
-
-        for ref in refs:
-            ref_size = ref.ByteSize() + per_object_overhead
-
-            if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size:
-                yield request
-                request = request_maker()
-                total_size = request.ByteSize()
-
-            request.data.references.values.append(ref)
-            total_size += ref_size
-
-        if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0:
-            yield request
-
-    def __generate_stream_requests_for_grpc(
-        self,
-    ) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
-        yield batch_pb2.BatchStreamRequest(
-            start=batch_pb2.BatchStreamRequest.Start(
-                consistency_level=self.__batch_grpc._consistency_level,
-            ),
-        )
-        while (
-            self.__shut_background_thread_down is not None
-            and not self.__shut_background_thread_down.is_set()
-        ):
-            req = self.__reqs.get()
-            if req is not None:
-                self.__total += len(req.data.objects.values) + len(req.data.references.values)
-                yield req
-                continue
-            if self.__stop and not (
-                self.__is_shutting_down.is_set() or self.__is_shutdown.is_set()
-            ):
-                logger.warning("Batching finished, closing the client-side of the stream")
-                yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop())
-                return
-            if self.__is_shutting_down.is_set():
-                logger.warning("Server shutting down, closing the client-side of the stream")
-                return
-            logger.warning("Received sentinel, but not stopping, continuing...")
-
-    def __batch_recv(self) -> None:
-        for message in self.__batch_grpc.stream(
-            connection=self.__connection,
-            requests=self.__generate_stream_requests_for_grpc(),
-        ):
-            if message.HasField("started"):
-                logger.warning("Batch stream started successfully")
-                for threads in self.__bg_threads:
-                    threads.start_send()
-            if message.HasField("backoff"):
-                if (
-                    message.backoff.batch_size != self.__batch_size
-                    and not self.__is_shutting_down.is_set()
-                    and not self.__is_shutdown.is_set()
-                    and not self.__stop
-                ):
-                    self.__batch_size = message.backoff.batch_size
-                    logger.warning(
-                        f"Updated batch size to {self.__batch_size} as per server request"
-                    )
-            if message.HasField("results"):
-                result_objs = BatchObjectReturn()
-                result_refs = BatchReferenceReturn()
-                failed_objs: List[ErrorObject] = []
-                failed_refs: List[ErrorReference] = []
-                for error in message.results.errors:
-                    if error.HasField("uuid"):
-                        try:
-                            cached = self.__objs_cache.pop(error.uuid)
-                        except KeyError:
-                            continue
-                        err = ErrorObject(
-                            message=error.error,
-                            object_=cached,
-                        )
-                        result_objs += BatchObjectReturn(
-                            _all_responses=[err],
-                            errors={cached.index: err},
-                        )
-                        failed_objs.append(err)
-                        logger.warning(
-                            {
-                                "error": error.error,
-                                "object": error.uuid,
-                                "action": "use {client,collection}.batch.failed_objects to access this error",
-                            }
-                        )
-                    if error.HasField("beacon"):
-                        try:
-                            cached = self.__refs_cache.pop(error.beacon)
-                        except KeyError:
-                            continue
-                        err = ErrorReference(
-                            message=error.error,
-                            reference=error.beacon,  # pyright: ignore
-                        )
-                        failed_refs.append(err)
-                        result_refs += BatchReferenceReturn(
-                            errors={cached.index: err},
-                        )
-                        logger.warning(
-                            {
-                                "error": error.error,
-                                "reference": error.beacon,
-                                "action": "use {client,collection}.batch.failed_references to access this error",
-                            }
-                        )
-                for success in message.results.successes:
-                    if success.HasField("uuid"):
-                        try:
-                            cached = self.__objs_cache.pop(success.uuid)
-                        except KeyError:
-                            continue
-                        uuid = uuid_package.UUID(success.uuid)
-                        result_objs += BatchObjectReturn(
-                            _all_responses=[uuid],
-                            uuids={cached.index: uuid},
-                        )
-                    if success.HasField("beacon"):
-                        try:
-                            self.__refs_cache.pop(success.beacon, None)
-                        except KeyError:
-                            continue
-                with self.__results_lock:
-                    self.__results_for_wrapper.results.objs += result_objs
-                    self.__results_for_wrapper.results.refs += result_refs
-                    self.__results_for_wrapper.failed_objects.extend(failed_objs)
-                    self.__results_for_wrapper.failed_references.extend(failed_refs)
-            elif message.HasField("shutting_down"):
-                logger.warning(
-                    "Received shutting down message from server, pausing sending until stream is re-established"
-                )
-                self.__is_shutting_down.set()
-            elif message.HasField("shutdown"):
-                logger.warning("Received shutdown finished message from server")
-                self.__is_shutdown.set()
-                self.__is_shutting_down.clear()
-                self.__reconnect()
-
-        # restart the stream if we were shutdown by the node we were connected to ensuring that the index is
-        # propagated properly from it to the new one
-        if self.__is_shutdown.is_set():
-            logger.warning("Restarting batch recv after shutdown...")
-            self.__is_shutdown.clear()
-            return self.__batch_recv()
-        else:
-            logger.warning("Server closed the stream from its side, shutting down batch")
-            return
-
-    def __reconnect(self, retry: int = 0) -> None:
-        if self.__consistency_level == ConsistencyLevel.ALL:
-            # check that all nodes are available before reconnecting
-            cluster = _ClusterBatch(self.__connection)
-            while len(nodes := cluster.get_nodes_status()) != 3 or any(
-                node["status"] != "HEALTHY" for node in nodes
-            ):
-                logger.warning(
-                    "Waiting for all nodes to be HEALTHY before reconnecting to batch stream due to CL=ALL..."
-                )
-                time.sleep(5)
-        try:
-            logger.warning(f"Trying to reconnect after shutdown... {retry + 1}/{5}")
{retry + 1}/{5}") - self.__connection.close("sync") - self.__connection.connect(force=True) - logger.warning("Reconnected successfully") - except (WeaviateStartUpError, WeaviateGRPCUnavailableError) as e: - if retry < 5: - time.sleep(2**retry) - self.__reconnect(retry + 1) - else: - logger.error("Failed to reconnect after 5 attempts") - self.__bg_thread_exception = e - - def __start_bg_threads(self) -> _BgThreads: - """Create a background thread that periodically checks how congested the batch queue is.""" - self.__shut_background_thread_down = threading.Event() - - def batch_send_wrapper() -> None: - try: - self.__batch_send() - logger.warning("exited batch send thread") - except Exception as e: - logger.error(e) - self.__bg_thread_exception = e - - def batch_recv_wrapper() -> None: - socket_hung_up = False - try: - self.__batch_recv() - logger.warning("exited batch receive thread") - except Exception as e: - if isinstance(e, WeaviateBatchStreamError) and ( - "Socket closed" in e.message or "context canceled" in e.message - ): - socket_hung_up = True - else: - logger.error(e) - logger.error(type(e)) - self.__bg_thread_exception = e - if socket_hung_up: - # this happens during ungraceful shutdown of the coordinator - # lets restart the stream and add the cached objects again - logger.warning("Stream closed unexpectedly, restarting...") - self.__reconnect() - # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now - self.__is_shutting_down.clear() - with self.__objs_cache_lock: - logger.warning( - f"Re-adding {len(self.__objs_cache)} cached objects to the batch" - ) - self.__batch_objects.prepend( - [ - self.__batch_grpc.grpc_object(o._to_internal()) - for o in self.__objs_cache.values() - ] - ) - with self.__refs_cache_lock: - self.__batch_references.prepend( - [ - self.__batch_grpc.grpc_reference(o._to_internal()) - for o in self.__refs_cache.values() - ] - ) - # start a new stream with a newly reconnected channel - return batch_recv_wrapper() - - threads = _BgThreads( - send=threading.Thread( - target=batch_send_wrapper, - daemon=True, - name="BgBatchSend", - ), - recv=threading.Thread( - target=batch_recv_wrapper, - daemon=True, - name="BgBatchRecv", - ), - ) - threads.start_recv() - return threads - - def flush(self) -> None: - """Flush the batch queue and wait for all requests to be finished.""" - # bg thread is sending objs+refs automatically, so simply wait for everything to be done - while len(self.__batch_objects) > 0 or len(self.__batch_references) > 0: - time.sleep(0.01) - self.__check_bg_threads_alive() +class _ClusterBatch: + def __init__(self, connection: ConnectionSync): + self._connection = connection - def _add_object( + def get_nodes_status( self, - collection: str, - properties: Optional[WeaviateProperties] = None, - references: Optional[ReferenceInputs] = None, - uuid: Optional[UUID] = None, - vector: Optional[VECTORS] = None, - tenant: Optional[str] = None, - ) -> UUID: - self.__check_bg_threads_alive() + ) -> List[Node]: try: - batch_object = BatchObject( - collection=collection, - properties=properties, - references=references, - uuid=uuid, - vector=vector, - tenant=tenant, - index=self.__objs_count, - ) - self.__results_for_wrapper.imported_shards.add( - Shard(collection=collection, tenant=tenant) - ) - except ValidationError as e: - raise WeaviateBatchValidationError(repr(e)) - uuid = str(batch_object.uuid) - with self.__uuid_lookup_lock: - self.__uuid_lookup.add(uuid) - 
-        self.__batch_objects.add(self.__batch_grpc.grpc_object(batch_object._to_internal()))
-        with self.__objs_cache_lock:
-            self.__objs_cache[uuid] = batch_object
-        self.__objs_count += 1
-
-        # block if queue gets too long or weaviate is overloaded - reading files is faster them sending them so we do
-        # not need a long queue
-        while len(self.__batch_objects) >= self.__batch_size * 2:
-            self.__check_bg_threads_alive()
-            time.sleep(0.01)
-
-        assert batch_object.uuid is not None
-        return batch_object.uuid
-
-    def _add_reference(
-        self,
-        from_object_uuid: UUID,
-        from_object_collection: str,
-        from_property_name: str,
-        to: ReferenceInput,
-        tenant: Optional[str] = None,
-    ) -> None:
-        self.__check_bg_threads_alive()
-        if isinstance(to, ReferenceToMulti):
-            to_strs: Union[List[str], List[UUID]] = to.uuids_str
-        elif isinstance(to, str) or isinstance(to, uuid_package.UUID):
-            to_strs = [to]
-        else:
-            to_strs = list(to)
-
-        for uid in to_strs:
-            try:
-                batch_reference = BatchReference(
-                    from_object_collection=from_object_collection,
-                    from_object_uuid=from_object_uuid,
-                    from_property_name=from_property_name,
-                    to_object_collection=(
-                        to.target_collection if isinstance(to, ReferenceToMulti) else None
-                    ),
-                    to_object_uuid=uid,
-                    tenant=tenant,
-                    index=self.__refs_count,
-                )
-            except ValidationError as e:
-                raise WeaviateBatchValidationError(repr(e))
-            self.__batch_references.add(
-                self.__batch_grpc.grpc_reference(batch_reference._to_internal())
-            )
-            with self.__refs_cache_lock:
-                self.__refs_cache[batch_reference._to_beacon()] = batch_reference
-            self.__refs_count += 1
-
-    def __check_bg_threads_alive(self) -> None:
-        if self.__any_threads_alive():
-            return
+            response = executor.result(self._connection.get(path="/nodes"))
+        except ConnectError as conn_err:
+            raise ConnectError("Get nodes status failed due to connection error") from conn_err

-        raise self.__bg_thread_exception or Exception("Batch thread died unexpectedly")
+        response_typed = _decode_json_response_dict(response, "Nodes status")
+        assert response_typed is not None

+        nodes = response_typed.get("nodes")
+        if nodes is None or nodes == []:
+            raise EmptyResponseException("Nodes status response returned empty")
+        return cast(List[Node], nodes)

-class _ClusterBatch:
-    def __init__(self, connection: ConnectionSync):
+
+class _ClusterBatchAsync:
+    def __init__(self, connection: ConnectionAsync):
         self._connection = connection

-    def get_nodes_status(
+    async def get_nodes_status(
         self,
     ) -> List[Node]:
         try:
-            response = executor.result(self._connection.get(path="/nodes"))
+            response = await executor.aresult(self._connection.get(path="/nodes"))
         except ConnectError as conn_err:
             raise ConnectError("Get nodes status failed due to connection error") from conn_err
diff --git a/weaviate/collections/batch/batch_wrapper.py b/weaviate/collections/batch/batch_wrapper.py
index a64f267ca..f8e40395c 100644
--- a/weaviate/collections/batch/batch_wrapper.py
+++ b/weaviate/collections/batch/batch_wrapper.py
@@ -1,14 +1,18 @@
+import asyncio
 import time
 from typing import Any, Generic, List, Optional, Protocol, TypeVar, Union, cast

+from weaviate.collections.batch.async_ import _BatchBaseAsync
 from weaviate.collections.batch.base import (
     _BatchBase,
-    _BatchBaseNew,
     _BatchDataWrapper,
     _BatchMode,
     _ClusterBatch,
+    _ClusterBatchAsync,
     _DynamicBatching,
+    _ServerSideBatching,
 )
+from weaviate.collections.batch.sync import _BatchBaseSync
 from weaviate.collections.classes.batch import (
     BatchResult,
     ErrorObject,
@@ -20,7 +24,7 @@ from weaviate.collections.classes.tenants import Tenant
 from weaviate.collections.classes.types import Properties, WeaviateProperties
 from weaviate.connect import executor
-from weaviate.connect.v4 import ConnectionSync
+from weaviate.connect.v4 import ConnectionAsync, ConnectionSync
 from weaviate.logger import logger
 from weaviate.types import UUID, VECTORS
 from weaviate.util import _capitalize_first_letter, _decode_json_response_list
@@ -34,7 +38,7 @@ def __init__(
     ):
         self._connection = connection
         self._consistency_level = consistency_level
-        self._current_batch: Optional[Union[_BatchBase, _BatchBaseNew]] = None
+        self._current_batch: Optional[Union[_BatchBase, _BatchBaseSync]] = None
         # config options
         self._batch_mode: _BatchMode = _DynamicBatching()
@@ -127,6 +131,109 @@ def results(self) -> BatchResult:
         return self._batch_data.results


+class _BatchWrapperAsync:
+    def __init__(
+        self,
+        connection: ConnectionAsync,
+        consistency_level: Optional[ConsistencyLevel],
+    ):
+        self._connection = connection
+        self._consistency_level = consistency_level
+        self._current_batch: Optional[_BatchBaseAsync] = None
+        # config options
+        self._batch_mode: _BatchMode = _ServerSideBatching(1)
+
+        self._batch_data = _BatchDataWrapper()
+        self._cluster = _ClusterBatchAsync(connection)
+
+    async def __is_ready(
+        self, max_count: int, shards: Optional[List[Shard]], backoff_count: int = 0
+    ) -> bool:
+        try:
+            readinesses = await asyncio.gather(
+                *[
+                    self.__get_shards_readiness(shard)
+                    for shard in shards or self._batch_data.imported_shards
+                ]
+            )
+            return all(all(readiness) for readiness in readinesses)
+        except Exception as e:
+            logger.warning(
+                f"Error while getting class shards statuses: {e}, trying again with 2**n={2**backoff_count}s exponential backoff with n={backoff_count}"
+            )
+            if backoff_count >= max_count:
+                raise e
+            await asyncio.sleep(2**backoff_count)
+            return await self.__is_ready(max_count, shards, backoff_count + 1)
+
+    async def wait_for_vector_indexing(
+        self, shards: Optional[List[Shard]] = None, how_many_failures: int = 5
+    ) -> None:
+        """Wait for all the vectors of the batch imported objects to be indexed.
+
+        Upon network error, it will retry to get the shards' status for `how_many_failures` times
+        with exponential backoff (2**n seconds with n=0,1,2,...,how_many_failures).
+
+        Args:
+            shards: The shards to check the status of. If `None` it will check the status of all the shards of the imported objects in the batch.
+            how_many_failures: How many times to try to get the shards' status before raising an exception. Default 5.
+        """
+        if shards is not None and not isinstance(shards, list):
+            raise TypeError(f"'shards' must be of type List[Shard]. Given type: {type(shards)}.")
+        if shards is not None and not isinstance(shards[0], Shard):
+            raise TypeError(f"'shards' must be of type List[Shard]. Given type: {type(shards)}.")
Given type: {type(shards)}.") + + waiting_count = 0 + while not await self.__is_ready(how_many_failures, shards): + if waiting_count % 20 == 0: # print every 5s + logger.debug("Waiting for async indexing to finish...") + await asyncio.sleep(0.25) + waiting_count += 1 + logger.debug("Async indexing finished!") + + async def __get_shards_readiness(self, shard: Shard) -> List[bool]: + path = f"/schema/{_capitalize_first_letter(shard.collection)}/shards{'' if shard.tenant is None else f'?tenant={shard.tenant}'}" + response = await executor.aresult(self._connection.get(path=path)) + + res = _decode_json_response_list(response, "Get shards' status") + assert res is not None + return [ + (cast(str, shard.get("status")) == "READY") + & (cast(int, shard.get("vectorQueueSize")) == 0) + for shard in res + ] + + async def _get_shards_readiness(self, shard: Shard) -> List[bool]: + return await self.__get_shards_readiness(shard) + + @property + def failed_objects(self) -> List[ErrorObject]: + """Get all failed objects from the batch manager. + + Returns: + A list of all the failed objects from the batch. + """ + return self._batch_data.failed_objects + + @property + def failed_references(self) -> List[ErrorReference]: + """Get all failed references from the batch manager. + + Returns: + A list of all the failed references from the batch. + """ + return self._batch_data.failed_references + + @property + def results(self) -> BatchResult: + """Get the results of the batch operation. + + Returns: + The results of the batch operation. + """ + return self._batch_data.results + + class BatchClientProtocol(Protocol): def add_object( self, @@ -204,6 +311,83 @@ def number_errors(self) -> int: ... +class BatchClientProtocolAsync(Protocol): + async def add_object( + self, + collection: str, + properties: Optional[WeaviateProperties] = None, + references: Optional[ReferenceInputs] = None, + uuid: Optional[UUID] = None, + vector: Optional[VECTORS] = None, + tenant: Optional[Union[str, Tenant]] = None, + ) -> UUID: + """Add one object to this batch. + + NOTE: If the UUID of one of the objects already exists then the existing object will be + replaced by the new object. + + Args: + collection: The name of the collection this object belongs to. + properties: The data properties of the object to be added as a dictionary. + references: The references of the object to be added as a dictionary. + uuid: The UUID of the object as an uuid.UUID object or str. It can be a Weaviate beacon or Weaviate href. + If it is None an UUIDv4 will generated, by default None + vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given + vector was generated using the _identical_ vectorization module that is configured for the class. In this + case this vector takes precedence. + Supported types are: + - for single vectors: `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None. + - for named vectors: Dict[str, *list above*], where the string is the name of the vector. + tenant: The tenant name or Tenant object to be used for this request. + + Returns: + The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here. + + Raises: + WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. + """ + ... 
+
+    async def add_reference(
+        self,
+        from_uuid: UUID,
+        from_collection: str,
+        from_property: str,
+        to: ReferenceInput,
+        tenant: Optional[Union[str, Tenant]] = None,
+    ) -> None:
+        """Add one reference to this batch.
+
+        Args:
+            from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object.
+            from_collection: The name of the collection that should reference another object.
+            from_property: The name of the property that contains the reference.
+            to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced.
+                For multi-target references use wvc.Reference.to_multi_target().
+            tenant: The tenant name or Tenant object to be used for this request.
+
+        Raises:
+            WeaviateBatchValidationError: If the provided options are not in the format required by Weaviate.
+        """
+        ...
+
+    async def flush(self) -> None:
+        """Flush the current batch.
+
+        This will send all the objects and references in the current batch to Weaviate.
+        """
+        ...
+
+    @property
+    def number_errors(self) -> int:
+        """Get the number of errors in the current batch.
+
+        Returns:
+            The number of errors in the current batch.
+        """
+        ...
+
+
 class BatchCollectionProtocol(Generic[Properties], Protocol[Properties]):
     def add_object(
         self,
@@ -260,7 +444,7 @@ def number_errors(self) -> int:
         ...


-T = TypeVar("T", bound=Union[_BatchBase, _BatchBaseNew])
+T = TypeVar("T", bound=Union[_BatchBase, _BatchBaseSync])
 P = TypeVar("P", bound=Union[BatchClientProtocol, BatchCollectionProtocol[Properties]])

@@ -274,3 +458,91 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
     def __enter__(self) -> P:
         self.__current_batch._start()
         return self.__current_batch  # pyright: ignore[reportReturnType]
+
+
+class BatchClientAsync(_BatchBaseAsync):
+    async def add_object(
+        self,
+        collection: str,
+        properties: Optional[WeaviateProperties] = None,
+        references: Optional[ReferenceInputs] = None,
+        uuid: Optional[UUID] = None,
+        vector: Optional[VECTORS] = None,
+        tenant: Optional[Union[str, Tenant]] = None,
+    ) -> UUID:
+        """Add one object to this batch.
+
+        NOTE: If the UUID of one of the objects already exists then the existing object will be
+        replaced by the new object.
+
+        Args:
+            collection: The name of the collection this object belongs to.
+            properties: The data properties of the object to be added as a dictionary.
+            references: The references of the object to be added as a dictionary.
+            uuid: The UUID of the object as an uuid.UUID object or str. It can be a Weaviate beacon or Weaviate href.
+                If it is None an UUIDv4 will be generated, by default None
+            vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given
+                vector was generated using the _identical_ vectorization module that is configured for the class. In this
+                case this vector takes precedence.
+                Supported types are:
+                - for single vectors: `list`, `numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None.
+                - for named vectors: Dict[str, *list above*], where the string is the name of the vector.
+            tenant: The tenant name or Tenant object to be used for this request.
+
+        Returns:
+            The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here.
+
+        Raises:
+            WeaviateBatchValidationError: If the provided options are not in the format required by Weaviate.
+ """ + return await super()._add_object( + collection=collection, + properties=properties, + references=references, + uuid=uuid, + vector=vector, + tenant=tenant.name if isinstance(tenant, Tenant) else tenant, + ) + + async def add_reference( + self, + from_uuid: UUID, + from_collection: str, + from_property: str, + to: ReferenceInput, + tenant: Optional[Union[str, Tenant]] = None, + ) -> None: + """Add one reference to this batch. + + Args: + from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object. + from_collection: The name of the collection that should reference another object. + from_property: The name of the property that contains the reference. + to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced. + For multi-target references use wvc.Reference.to_multi_target(). + tenant: The tenant name or Tenant object to be used for this request. + + Raises: + WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. + """ + await super()._add_reference( + from_object_uuid=from_uuid, + from_object_collection=from_collection, + from_property_name=from_property, + to=to, + tenant=tenant.name if isinstance(tenant, Tenant) else tenant, + ) + + +class _ContextManagerWrapperAsync: + def __init__(self, current_batch: BatchClientAsync): + self.__current_batch = current_batch + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + await self.__current_batch._shutdown() + await self.__bg_tasks.send + await self.__bg_tasks.recv + + async def __aenter__(self) -> BatchClientAsync: + self.__bg_tasks = await self.__current_batch._start() + return self.__current_batch diff --git a/weaviate/collections/batch/client.py b/weaviate/collections/batch/client.py index a86a3be10..0aaddd718 100644 --- a/weaviate/collections/batch/client.py +++ b/weaviate/collections/batch/client.py @@ -3,7 +3,6 @@ from weaviate.collections.batch.base import ( _BatchBase, - _BatchBaseNew, _BatchDataWrapper, _DynamicBatching, _FixedSizeBatching, @@ -11,16 +10,20 @@ _ServerSideBatching, ) from weaviate.collections.batch.batch_wrapper import ( + BatchClientAsync, BatchClientProtocol, _BatchMode, _BatchWrapper, + _BatchWrapperAsync, _ContextManagerWrapper, + _ContextManagerWrapperAsync, ) +from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers from weaviate.collections.classes.internal import ReferenceInput, ReferenceInputs from weaviate.collections.classes.tenants import Tenant from weaviate.collections.classes.types import WeaviateProperties -from weaviate.connect.v4 import ConnectionSync +from weaviate.connect.v4 import ConnectionAsync, ConnectionSync from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateUnsupportedFeatureError from weaviate.types import UUID, VECTORS @@ -102,7 +105,7 @@ def add_reference( ) -class _BatchClientNew(_BatchBaseNew): +class _BatchClientSync(_BatchBaseSync): def add_object( self, collection: str, @@ -177,10 +180,11 @@ def add_reference( BatchClient = _BatchClient -BatchClientNew = _BatchClientNew +BatchClientSync = _BatchClientSync ClientBatchingContextManager = _ContextManagerWrapper[ - Union[BatchClient, BatchClientNew], BatchClientProtocol + Union[BatchClient, BatchClientSync], BatchClientProtocol ] +AsyncClientBatchingContextManager = _ContextManagerWrapperAsync class _BatchClientWrapper(_BatchWrapper): @@ -197,7 +201,7 @@ def __init__( # define one executor 
     def __create_batch_and_reset(
-        self, batch_client: Union[Type[_BatchClient], Type[_BatchClientNew]]
+        self, batch_client: Union[Type[_BatchClient], Type[_BatchClientSync]]
     ):
         if self._vectorizer_batching is None or not self._vectorizer_batching:
             try:
@@ -311,4 +315,47 @@ def experimental(
             concurrency=1,  # hard-code until client-side multi-threading is fixed
         )
         self._consistency_level = consistency_level
-        return self.__create_batch_and_reset(_BatchClientNew)
+        return self.__create_batch_and_reset(_BatchClientSync)
+
+
+class _BatchClientWrapperAsync(_BatchWrapperAsync):
+    def __init__(
+        self,
+        connection: ConnectionAsync,
+    ):
+        super().__init__(connection, None)
+        self._vectorizer_batching: Optional[bool] = None
+
+    def __create_batch_and_reset(self):
+        self._batch_data = _BatchDataWrapper()  # clear old data
+        return _ContextManagerWrapperAsync(
+            BatchClientAsync(
+                connection=self._connection,
+                consistency_level=self._consistency_level,
+                results=self._batch_data,
+                batch_mode=self._batch_mode,
+            )
+        )
+
+    def experimental(
+        self,
+        *,
+        concurrency: Optional[int] = None,
+        consistency_level: Optional[ConsistencyLevel] = None,
+    ) -> AsyncClientBatchingContextManager:
+        """Configure the batching context manager using the experimental server-side batching mode.
+
+        When you exit the context manager, the final batch will be sent automatically.
+        """
+        if self._connection._weaviate_version.is_lower_than(1, 34, 0):
+            raise WeaviateUnsupportedFeatureError(
+                "Server-side batching", str(self._connection._weaviate_version), "1.34.0"
+            )
+        self._batch_mode = _ServerSideBatching(
+            # concurrency=concurrency
+            # if concurrency is not None
+            # else len(self._cluster.get_nodes_status())
+            concurrency=1,  # hard-code until client-side multi-threading is fixed
+        )
+        self._consistency_level = consistency_level
+        return self.__create_batch_and_reset()
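
Since `experimental()` raises eagerly on servers below 1.34.0, callers can feature-detect instead of comparing version strings themselves. A hedged sketch for the sync client, whose wrapper also offers the client-side modes:

from weaviate.exceptions import WeaviateUnsupportedFeatureError

try:
    ctx = client.batch.experimental()
except WeaviateUnsupportedFeatureError:
    ctx = client.batch.fixed_size()  # client-side batching fallback (sync wrapper only)

with ctx as batch:
    batch.add_object(collection="Articles", properties={"title": "hello"})  # illustrative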
diff --git a/weaviate/collections/batch/collection.py b/weaviate/collections/batch/collection.py
index 6abe4aaac..dcca3d9d7 100644
--- a/weaviate/collections/batch/collection.py
+++ b/weaviate/collections/batch/collection.py
@@ -3,7 +3,6 @@
 from weaviate.collections.batch.base import (
     _BatchBase,
-    _BatchBaseNew,
     _BatchDataWrapper,
     _BatchMode,
     _DynamicBatching,
@@ -16,6 +15,7 @@
     _BatchWrapper,
     _ContextManagerWrapper,
 )
+from weaviate.collections.batch.sync import _BatchBaseSync
 from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers
 from weaviate.collections.classes.internal import ReferenceInput, ReferenceInputs
 from weaviate.collections.classes.types import Properties
@@ -78,7 +78,7 @@ def add_reference(
     )
 
 
-class _BatchCollectionNew(Generic[Properties], _BatchBaseNew):
+class _BatchCollectionSync(Generic[Properties], _BatchBaseSync):
     def __init__(
         self,
         executor: ThreadPoolExecutor,
@@ -161,9 +161,9 @@ def add_reference(
 
 
 BatchCollection = _BatchCollection
-BatchCollectionNew = _BatchCollectionNew
+BatchCollectionSync = _BatchCollectionSync
 CollectionBatchingContextManager = _ContextManagerWrapper[
-    Union[BatchCollection[Properties], BatchCollectionNew[Properties]],
+    Union[BatchCollection[Properties], BatchCollectionSync[Properties]],
     BatchCollectionProtocol[Properties],
 ]
@@ -177,7 +177,7 @@ def __init__(
         tenant: Optional[str],
         config: "_ConfigCollection",
         batch_client: Union[
-            Type[_BatchCollection[Properties]], Type[_BatchCollectionNew[Properties]]
+            Type[_BatchCollection[Properties]], Type[_BatchCollectionSync[Properties]]
         ],
     ) -> None:
         super().__init__(connection, consistency_level)
@@ -192,7 +192,7 @@ def __init__(
     def __create_batch_and_reset(
         self,
         batch_client: Union[
-            Type[_BatchCollection[Properties]], Type[_BatchCollectionNew[Properties]]
+            Type[_BatchCollection[Properties]], Type[_BatchCollectionSync[Properties]]
         ],
     ):
         if self._vectorizer_batching is None:
@@ -278,4 +278,4 @@ def experimental(
             # else len(self._cluster.get_nodes_status())
             concurrency=1,  # hard-code until client-side multi-threading is fixed
         )
-        return self.__create_batch_and_reset(_BatchCollectionNew)
+        return self.__create_batch_and_reset(_BatchCollectionSync)
diff --git a/weaviate/collections/batch/grpc_batch.py b/weaviate/collections/batch/grpc_batch.py
index 7384dcb49..6f01b2287 100644
--- a/weaviate/collections/batch/grpc_batch.py
+++ b/weaviate/collections/batch/grpc_batch.py
@@ -20,7 +20,7 @@
 from weaviate.collections.grpc.shared import _BaseGRPC, _is_1d_vector, _Pack
 from weaviate.connect import executor
 from weaviate.connect.base import MAX_GRPC_MESSAGE_LENGTH
-from weaviate.connect.v4 import Connection, ConnectionSync
+from weaviate.connect.v4 import Connection, ConnectionAsync, ConnectionSync
 from weaviate.exceptions import (
     WeaviateInsertInvalidPropertyError,
     WeaviateInsertManyAllFailedError,
@@ -203,8 +203,8 @@ def stream(
         connection: ConnectionSync,
         *,
         requests: Generator[batch_pb2.BatchStreamRequest, None, None],
-    ) -> Generator[batch_pb2.BatchStreamReply, None, None]:
-        """Start a new stream for receiving messages about the ongoing server-side batching from Weaviate.
+    ):
+        """Start a new sync stream for sending and receiving messages about the ongoing server-side batching from Weaviate.
 
         Args:
             connection: The connection to the Weaviate instance.
@@ -212,6 +212,17 @@
         """
         return connection.grpc_batch_stream(requests=requests)
 
+    def astream(
+        self,
+        connection: ConnectionAsync,
+    ):
+        """Start a new async stream for sending and receiving messages about the ongoing server-side batching from Weaviate.
+
+        Args:
+            connection: The connection to the Weaviate instance.
+        """
+        return connection.grpc_batch_stream()
+
     def __translate_properties_from_python_to_grpc(
         self, data: Dict[str, Any], refs: ReferenceInputs
     ) -> batch_pb2.BatchObject.Properties:
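
The sync `stream()` is driven entirely by the `requests` generator: a Start message opens the stream, data messages carry objects and references, and a Stop message closes it. A stripped-down sketch of that shape (the `grpc_objects` list is assumed to hold pre-built `batch_pb2.BatchObject` values; the full logic with size limits and sentinels lives in `__generate_stream_requests_for_grpc` below):

from typing import Generator, List

from weaviate.proto.v1 import batch_pb2


def make_requests(
    grpc_objects: List[batch_pb2.BatchObject],
) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
    # open the stream
    yield batch_pb2.BatchStreamRequest(start=batch_pb2.BatchStreamRequest.Start())
    # one object per request, for simplicity; the real code packs up to the gRPC size limit
    for obj in grpc_objects:
        req = batch_pb2.BatchStreamRequest()
        req.data.objects.values.append(obj)
        yield req
    # close the client side of the stream
    yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop())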
+ """ + return connection.grpc_batch_stream() + def __translate_properties_from_python_to_grpc( self, data: Dict[str, Any], refs: ReferenceInputs ) -> batch_pb2.BatchObject.Properties: diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py new file mode 100644 index 000000000..d87b5db02 --- /dev/null +++ b/weaviate/collections/batch/sync.py @@ -0,0 +1,577 @@ +import threading +import time +import uuid as uuid_package +from concurrent.futures import ThreadPoolExecutor +from queue import Queue +from typing import Generator, List, Optional, Set, Union + +from pydantic import ValidationError + +from weaviate.collections.batch.base import ( + ObjectsBatchRequest, + ReferencesBatchRequest, + _BatchDataWrapper, + _BatchMode, + _BgThreads, + _ClusterBatch, + _ServerSideBatching, +) +from weaviate.collections.batch.grpc_batch import _BatchGRPC +from weaviate.collections.classes.batch import ( + BatchObject, + BatchObjectReturn, + BatchReference, + BatchReferenceReturn, + ErrorObject, + ErrorReference, + Shard, +) +from weaviate.collections.classes.config import ConsistencyLevel +from weaviate.collections.classes.internal import ( + ReferenceInput, + ReferenceInputs, + ReferenceToMulti, +) +from weaviate.collections.classes.types import WeaviateProperties +from weaviate.connect.v4 import ConnectionSync +from weaviate.exceptions import ( + WeaviateBatchStreamError, + WeaviateBatchValidationError, + WeaviateGRPCUnavailableError, + WeaviateStartUpError, +) +from weaviate.logger import logger +from weaviate.proto.v1 import batch_pb2 +from weaviate.types import UUID, VECTORS + + +class _BatchBaseSync: + def __init__( + self, + connection: ConnectionSync, + consistency_level: Optional[ConsistencyLevel], + results: _BatchDataWrapper, + batch_mode: _BatchMode, + executor: ThreadPoolExecutor, + vectorizer_batching: bool, + objects: Optional[ObjectsBatchRequest[batch_pb2.BatchObject]] = None, + references: Optional[ReferencesBatchRequest] = None, + ) -> None: + self.__batch_objects = objects or ObjectsBatchRequest[batch_pb2.BatchObject]() + self.__batch_references = references or ReferencesBatchRequest[batch_pb2.BatchReference]() + + self.__connection = connection + self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM + self.__batch_size = 100 + + self.__batch_grpc = _BatchGRPC( + connection._weaviate_version, self.__consistency_level, connection._grpc_max_msg_size + ) + + # lookup table for objects that are currently being processed - is used to not send references from objects that have not been added yet + self.__uuid_lookup: Set[str] = set() + + # we do not want that users can access the results directly as they are not thread-safe + self.__results_for_wrapper_backup = results + self.__results_for_wrapper = _BatchDataWrapper() + + self.__objs_count = 0 + self.__refs_count = 0 + + self.__uuid_lookup_lock = threading.Lock() + self.__results_lock = threading.Lock() + + self.__bg_thread_exception: Optional[Exception] = None + self.__is_shutting_down = threading.Event() + self.__is_shutdown = threading.Event() + + self.__objs_cache_lock = threading.Lock() + self.__refs_cache_lock = threading.Lock() + self.__objs_cache: dict[str, BatchObject] = {} + self.__refs_cache: dict[str, BatchReference] = {} + + # maxsize=1 so that __batch_send does not run faster than generator for __batch_recv + # thereby using too much buffer in case of server-side shutdown + self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1) + + self.__stop 
+
+    @property
+    def number_errors(self) -> int:
+        """Return the number of errors in the batch."""
+        return len(self.__results_for_wrapper.failed_objects) + len(
+            self.__results_for_wrapper.failed_references
+        )
+
+    def __all_threads_alive(self) -> bool:
+        return self.__bg_threads is not None and all(
+            thread.is_alive() for thread in self.__bg_threads
+        )
+
+    def __any_threads_alive(self) -> bool:
+        return self.__bg_threads is not None and any(
+            thread.is_alive() for thread in self.__bg_threads
+        )
+
+    def _start(self) -> None:
+        assert isinstance(self.__batch_mode, _ServerSideBatching), (
+            "Only server-side batching is supported in this mode"
+        )
+        self.__bg_threads = [
+            self.__start_bg_threads() for _ in range(self.__batch_mode.concurrency)
+        ]
+        logger.warning(
+            f"Provisioned {len(self.__bg_threads)} stream(s) to the server for batch processing"
+        )
+        now = time.time()
+        while not self.__all_threads_alive():
+            # wait for the stream to be started by __batch_stream
+            time.sleep(0.01)
+            if time.time() - now > 10:
+                raise WeaviateBatchValidationError(
+                    "Batch stream was not started within 10 seconds. Please check your connection."
+                )
+
+    def _shutdown(self) -> None:
+        # Shut down the current batch and wait for all requests to be finished
+        self.flush()
+        self.__stop = True
+
+        # we are done, wait for bg threads to finish
+        # self.__batch_stream will set the shutdown event when it receives
+        # the stop message from the server
+        while self.__any_threads_alive():
+            time.sleep(1)
+        logger.warning("Send & receive threads finished.")
+
+        # copy the results to the public results
+        self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results
+        self.__results_for_wrapper_backup.failed_objects = self.__results_for_wrapper.failed_objects
+        self.__results_for_wrapper_backup.failed_references = (
+            self.__results_for_wrapper.failed_references
+        )
+        self.__results_for_wrapper_backup.imported_shards = (
+            self.__results_for_wrapper.imported_shards
+        )
+
+    def __batch_send(self) -> None:
+        refresh_time: float = 0.01
+        while (
+            self.__shut_background_thread_down is not None
+            and not self.__shut_background_thread_down.is_set()
+        ):
+            if len(self.__batch_objects) + len(self.__batch_references) > 0:
+                self._batch_send = True
+                start = time.time()
+                while (len_o := len(self.__batch_objects)) + (
+                    len_r := len(self.__batch_references)
+                ) < self.__batch_size:
+                    # wait for more objects to be added up to the batch size
+                    time.sleep(0.01)
+                    if (
+                        self.__shut_background_thread_down is not None
+                        and self.__shut_background_thread_down.is_set()
+                    ):
+                        logger.warning("Threads were shut down, exiting batch send loop")
+                        # shutdown was requested, exit early
+                        self.__reqs.put(None)
+                        return
+                    if time.time() - start >= 1 and (
+                        len_o == len(self.__batch_objects) or len_r == len(self.__batch_references)
+                    ):
+                        # no new objects were added in the last second, exit the loop
+                        break
+
+                objs = self.__batch_objects.pop_items(self.__batch_size)
+                refs = self.__batch_references.pop_items(
+                    self.__batch_size - len(objs),
+                    uuid_lookup=self.__uuid_lookup,
+                )
+                with self.__uuid_lookup_lock:
+                    self.__uuid_lookup.difference_update(obj.uuid for obj in objs)
+
+                for req in self.__generate_stream_requests(objs, refs):
+                    logged = False
+                    while self.__is_shutting_down.is_set() or self.__is_shutdown.is_set():
+                        # if we were shut down by the node we were connected to, we need to wait for the
+                        # stream to be restarted so that the connection is
+                        # refreshed to a new node where the objects can be accepted;
+                        # otherwise, we wait until the stream has been started by __batch_stream to send the first batch
+                        if not logged:
+                            logger.warning("Waiting for stream to be re-established...")
+                            logged = True
+                            # put sentinel into our queue to signal the end of the current stream
+                            self.__reqs.put(None)
+                        time.sleep(1)
+                    if logged:
+                        logger.warning("Stream re-established, resuming sending batches")
+                    self.__reqs.put(req)
+            elif self.__stop:
+                # we are done, send the sentinel into our queue to be consumed by the batch sender
+                self.__reqs.put(None)  # signal the end of the stream
+                logger.warning("Batching finished, sent stop signal to batch stream")
+                return
+            time.sleep(refresh_time)
+
+    def __generate_stream_requests(
+        self,
+        objs: List[batch_pb2.BatchObject],
+        refs: List[batch_pb2.BatchReference],
+    ) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
+        per_object_overhead = 4  # extra overhead bytes per object in the request
+
+        def request_maker():
+            return batch_pb2.BatchStreamRequest()
+
+        request = request_maker()
+        total_size = request.ByteSize()
+
+        for obj in objs:
+            obj_size = obj.ByteSize() + per_object_overhead
+
+            if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size:
+                yield request
+                request = request_maker()
+                total_size = request.ByteSize()
+
+            request.data.objects.values.append(obj)
+            total_size += obj_size
+
+        for ref in refs:
+            ref_size = ref.ByteSize() + per_object_overhead
+
+            if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size:
+                yield request
+                request = request_maker()
+                total_size = request.ByteSize()
+
+            request.data.references.values.append(ref)
+            total_size += ref_size
+
+        if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0:
+            yield request
+
+    def __generate_stream_requests_for_grpc(
+        self,
+    ) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
+        yield batch_pb2.BatchStreamRequest(
+            start=batch_pb2.BatchStreamRequest.Start(
+                consistency_level=self.__batch_grpc._consistency_level,
+            ),
+        )
+        while (
+            self.__shut_background_thread_down is not None
+            and not self.__shut_background_thread_down.is_set()
+        ):
+            req = self.__reqs.get()
+            if req is not None:
+                self.__total += len(req.data.objects.values) + len(req.data.references.values)
+                yield req
+                continue
+            if self.__stop and not (
+                self.__is_shutting_down.is_set() or self.__is_shutdown.is_set()
+            ):
+                logger.warning("Batching finished, closing the client-side of the stream")
+                yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop())
+                return
+            if self.__is_shutting_down.is_set():
+                logger.warning("Server shutting down, closing the client-side of the stream")
+                return
+            logger.warning("Received sentinel, but not stopping, continuing...")
+
+    def __batch_recv(self) -> None:
+        for message in self.__batch_grpc.stream(
+            connection=self.__connection,
+            requests=self.__generate_stream_requests_for_grpc(),
+        ):
+            if message.HasField("started"):
+                logger.warning("Batch stream started successfully")
+                for threads in self.__bg_threads:
+                    threads.start_send()
+            if message.HasField("backoff"):
+                if (
+                    message.backoff.batch_size != self.__batch_size
+                    and not self.__is_shutting_down.is_set()
+                    and not self.__is_shutdown.is_set()
+                    and not self.__stop
+                ):
+                    self.__batch_size = message.backoff.batch_size
+                    logger.warning(
+                        f"Updated batch size to {self.__batch_size} as per server request"
+                    )
+            if message.HasField("results"):
+                result_objs = BatchObjectReturn()
+                result_refs = BatchReferenceReturn()
+                failed_objs: List[ErrorObject] = []
+                failed_refs: List[ErrorReference] = []
+                for error in message.results.errors:
+                    if error.HasField("uuid"):
+                        try:
+                            cached = self.__objs_cache.pop(error.uuid)
+                        except KeyError:
+                            continue
+                        err = ErrorObject(
+                            message=error.error,
+                            object_=cached,
+                        )
+                        result_objs += BatchObjectReturn(
+                            _all_responses=[err],
+                            errors={cached.index: err},
+                        )
+                        failed_objs.append(err)
+                        logger.warning(
+                            {
+                                "error": error.error,
+                                "object": error.uuid,
+                                "action": "use {client,collection}.batch.failed_objects to access this error",
+                            }
+                        )
+                    if error.HasField("beacon"):
+                        try:
+                            cached = self.__refs_cache.pop(error.beacon)
+                        except KeyError:
+                            continue
+                        err = ErrorReference(
+                            message=error.error,
+                            reference=error.beacon,  # pyright: ignore
+                        )
+                        failed_refs.append(err)
+                        result_refs += BatchReferenceReturn(
+                            errors={cached.index: err},
+                        )
+                        logger.warning(
+                            {
+                                "error": error.error,
+                                "reference": error.beacon,
+                                "action": "use {client,collection}.batch.failed_references to access this error",
+                            }
+                        )
+                for success in message.results.successes:
+                    if success.HasField("uuid"):
+                        try:
+                            cached = self.__objs_cache.pop(success.uuid)
+                        except KeyError:
+                            continue
+                        uuid = uuid_package.UUID(success.uuid)
+                        result_objs += BatchObjectReturn(
+                            _all_responses=[uuid],
+                            uuids={cached.index: uuid},
+                        )
+                    if success.HasField("beacon"):
+                        try:
+                            self.__refs_cache.pop(success.beacon, None)
+                        except KeyError:
+                            continue
+                with self.__results_lock:
+                    self.__results_for_wrapper.results.objs += result_objs
+                    self.__results_for_wrapper.results.refs += result_refs
+                    self.__results_for_wrapper.failed_objects.extend(failed_objs)
+                    self.__results_for_wrapper.failed_references.extend(failed_refs)
+            elif message.HasField("shutting_down"):
+                logger.warning(
+                    "Received shutting down message from server, pausing sending until stream is re-established"
+                )
+                self.__is_shutting_down.set()
+            elif message.HasField("shutdown"):
+                logger.warning("Received shutdown finished message from server")
+                self.__is_shutdown.set()
+                self.__is_shutting_down.clear()
+                self.__reconnect()
+
+        # restart the stream if we were shut down by the node we were connected to, ensuring that the index is
+        # propagated properly from it to the new one
+        if self.__is_shutdown.is_set():
+            logger.warning("Restarting batch recv after shutdown...")
+            self.__is_shutdown.clear()
+            return self.__batch_recv()
+        else:
+            logger.warning("Server closed the stream from its side, shutting down batch")
+            return
+
+    def __reconnect(self, retry: int = 0) -> None:
+        if self.__consistency_level == ConsistencyLevel.ALL:
+            # check that all nodes are available before reconnecting
+            cluster = _ClusterBatch(self.__connection)
+            while len(nodes := cluster.get_nodes_status()) != 3 or any(
+                node["status"] != "HEALTHY" for node in nodes
+            ):
+                logger.warning(
+                    "Waiting for all nodes to be HEALTHY before reconnecting to batch stream due to CL=ALL..."
+                )
+                time.sleep(5)
+        try:
+            logger.warning(f"Trying to reconnect after shutdown... {retry + 1}/{5}")
{retry + 1}/{5}") + self.__connection.close("sync") + self.__connection.connect(force=True) + logger.warning("Reconnected successfully") + except (WeaviateStartUpError, WeaviateGRPCUnavailableError) as e: + if retry < 5: + time.sleep(2**retry) + self.__reconnect(retry + 1) + else: + logger.error("Failed to reconnect after 5 attempts") + self.__bg_thread_exception = e + + def __start_bg_threads(self) -> _BgThreads: + """Create a background thread that periodically checks how congested the batch queue is.""" + self.__shut_background_thread_down = threading.Event() + + def batch_send_wrapper() -> None: + try: + self.__batch_send() + logger.warning("exited batch send thread") + except Exception as e: + logger.error(e) + self.__bg_thread_exception = e + + def batch_recv_wrapper() -> None: + socket_hung_up = False + try: + self.__batch_recv() + logger.warning("exited batch receive thread") + except Exception as e: + if isinstance(e, WeaviateBatchStreamError) and ( + "Socket closed" in e.message or "context canceled" in e.message + ): + socket_hung_up = True + else: + logger.error(e) + logger.error(type(e)) + self.__bg_thread_exception = e + if socket_hung_up: + # this happens during ungraceful shutdown of the coordinator + # lets restart the stream and add the cached objects again + logger.warning("Stream closed unexpectedly, restarting...") + self.__reconnect() + # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now + self.__is_shutting_down.clear() + with self.__objs_cache_lock: + logger.warning( + f"Re-adding {len(self.__objs_cache)} cached objects to the batch" + ) + self.__batch_objects.prepend( + [ + self.__batch_grpc.grpc_object(o._to_internal()) + for o in self.__objs_cache.values() + ] + ) + with self.__refs_cache_lock: + self.__batch_references.prepend( + [ + self.__batch_grpc.grpc_reference(o._to_internal()) + for o in self.__refs_cache.values() + ] + ) + # start a new stream with a newly reconnected channel + return batch_recv_wrapper() + + threads = _BgThreads( + send=threading.Thread( + target=batch_send_wrapper, + daemon=True, + name="BgBatchSend", + ), + recv=threading.Thread( + target=batch_recv_wrapper, + daemon=True, + name="BgBatchRecv", + ), + ) + threads.start_recv() + return threads + + def flush(self) -> None: + """Flush the batch queue and wait for all requests to be finished.""" + # bg thread is sending objs+refs automatically, so simply wait for everything to be done + while len(self.__batch_objects) > 0 or len(self.__batch_references) > 0: + time.sleep(0.01) + self.__check_bg_threads_alive() + + def _add_object( + self, + collection: str, + properties: Optional[WeaviateProperties] = None, + references: Optional[ReferenceInputs] = None, + uuid: Optional[UUID] = None, + vector: Optional[VECTORS] = None, + tenant: Optional[str] = None, + ) -> UUID: + self.__check_bg_threads_alive() + try: + batch_object = BatchObject( + collection=collection, + properties=properties, + references=references, + uuid=uuid, + vector=vector, + tenant=tenant, + index=self.__objs_count, + ) + self.__results_for_wrapper.imported_shards.add( + Shard(collection=collection, tenant=tenant) + ) + except ValidationError as e: + raise WeaviateBatchValidationError(repr(e)) + uuid = str(batch_object.uuid) + with self.__uuid_lookup_lock: + self.__uuid_lookup.add(uuid) + self.__batch_objects.add(self.__batch_grpc.grpc_object(batch_object._to_internal())) + with self.__objs_cache_lock: + self.__objs_cache[uuid] = batch_object + self.__objs_count += 1 + + # block if 
+
+    def _add_object(
+        self,
+        collection: str,
+        properties: Optional[WeaviateProperties] = None,
+        references: Optional[ReferenceInputs] = None,
+        uuid: Optional[UUID] = None,
+        vector: Optional[VECTORS] = None,
+        tenant: Optional[str] = None,
+    ) -> UUID:
+        self.__check_bg_threads_alive()
+        try:
+            batch_object = BatchObject(
+                collection=collection,
+                properties=properties,
+                references=references,
+                uuid=uuid,
+                vector=vector,
+                tenant=tenant,
+                index=self.__objs_count,
+            )
+            self.__results_for_wrapper.imported_shards.add(
+                Shard(collection=collection, tenant=tenant)
+            )
+        except ValidationError as e:
+            raise WeaviateBatchValidationError(repr(e))
+        uuid = str(batch_object.uuid)
+        with self.__uuid_lookup_lock:
+            self.__uuid_lookup.add(uuid)
+        self.__batch_objects.add(self.__batch_grpc.grpc_object(batch_object._to_internal()))
+        with self.__objs_cache_lock:
+            self.__objs_cache[uuid] = batch_object
+        self.__objs_count += 1
+
+        # block if the queue gets too long or Weaviate is overloaded - reading files is faster than sending them,
+        # so we do not need a long queue
+        while len(self.__batch_objects) >= self.__batch_size * 2:
+            self.__check_bg_threads_alive()
+            time.sleep(0.01)
+
+        assert batch_object.uuid is not None
+        return batch_object.uuid
+
+    def _add_reference(
+        self,
+        from_object_uuid: UUID,
+        from_object_collection: str,
+        from_property_name: str,
+        to: ReferenceInput,
+        tenant: Optional[str] = None,
+    ) -> None:
+        self.__check_bg_threads_alive()
+        if isinstance(to, ReferenceToMulti):
+            to_strs: Union[List[str], List[UUID]] = to.uuids_str
+        elif isinstance(to, str) or isinstance(to, uuid_package.UUID):
+            to_strs = [to]
+        else:
+            to_strs = list(to)
+
+        for uid in to_strs:
+            try:
+                batch_reference = BatchReference(
+                    from_object_collection=from_object_collection,
+                    from_object_uuid=from_object_uuid,
+                    from_property_name=from_property_name,
+                    to_object_collection=(
+                        to.target_collection if isinstance(to, ReferenceToMulti) else None
+                    ),
+                    to_object_uuid=uid,
+                    tenant=tenant,
+                    index=self.__refs_count,
+                )
+            except ValidationError as e:
+                raise WeaviateBatchValidationError(repr(e))
+            self.__batch_references.add(
+                self.__batch_grpc.grpc_reference(batch_reference._to_internal())
+            )
+            with self.__refs_cache_lock:
+                self.__refs_cache[batch_reference._to_beacon()] = batch_reference
+            self.__refs_count += 1
+
+    def __check_bg_threads_alive(self) -> None:
+        if self.__any_threads_alive():
+            return
+
+        raise self.__bg_thread_exception or Exception("Batch thread died unexpectedly")
diff --git a/weaviate/collections/collection/sync.py b/weaviate/collections/collection/sync.py
index 88f728b30..6f9369cec 100644
--- a/weaviate/collections/collection/sync.py
+++ b/weaviate/collections/collection/sync.py
@@ -7,7 +7,7 @@
 from weaviate.collections.backups import _CollectionBackup
 from weaviate.collections.batch.collection import (
     _BatchCollection,
-    _BatchCollectionNew,
+    _BatchCollectionSync,
     _BatchCollectionWrapper,
 )
 from weaviate.collections.classes.cluster import Shard
@@ -101,10 +101,8 @@ def __init__(
             name,
             tenant,
             config,
-            batch_client=_BatchCollectionNew[Properties]
-            if connection._weaviate_version.is_at_least(
-                1, 32, 0
-            )  # todo: change to 1.33.0 when it lands
+            batch_client=_BatchCollectionSync[Properties]
+            if connection._weaviate_version.is_at_least(1, 34, 0)
             else _BatchCollection[Properties],
         )
         """This namespace contains all the functionality to upload data in batches to Weaviate for this specific collection."""
diff --git a/weaviate/collections/data/async_.pyi b/weaviate/collections/data/async_.pyi
index 28dd4e2e4..15108447a 100644
--- a/weaviate/collections/data/async_.pyi
+++ b/weaviate/collections/data/async_.pyi
@@ -1,6 +1,10 @@
 import uuid as uuid_package
 from typing import Generic, List, Literal, Optional, Sequence, Union, overload
 
+from weaviate.collections.batch.collection import _BatchCollectionWrapper
+from weaviate.collections.batch.grpc_batch import _BatchGRPC
+from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC
+from weaviate.collections.batch.rest import _BatchREST
 from weaviate.collections.classes.batch import (
     BatchObjectReturn,
     BatchReferenceReturn,
@@ -23,6 +27,11 @@
 class _DataCollectionAsync(
     Generic[Properties,], _DataCollectionExecutor[ConnectionAsync, Properties]
 ):
+    __batch_delete: _BatchDeleteGRPC
+    __batch_grpc: _BatchGRPC
+    __batch_rest: _BatchREST
+    __batch: _BatchCollectionWrapper[Properties]
+
     async def insert(
         self,
         properties: Properties,
diff --git a/weaviate/collections/data/executor.py b/weaviate/collections/data/executor.py
index eb63a744d..8d6d12d40 100644
--- a/weaviate/collections/data/executor.py
+++ b/weaviate/collections/data/executor.py
@@ -19,6 +19,7 @@
 
 from httpx import Response
 
+from weaviate.collections.batch.collection import _BatchCollectionWrapper
 from weaviate.collections.batch.grpc_batch import _BatchGRPC
 from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC
 from weaviate.collections.batch.rest import _BatchREST
@@ -57,6 +58,11 @@
 class _DataCollectionExecutor(Generic[ConnectionType, Properties]):
+    __batch_delete: _BatchDeleteGRPC
+    __batch_grpc: _BatchGRPC
+    __batch_rest: _BatchREST
+    __batch: _BatchCollectionWrapper[Properties]
+
     def __init__(
         self,
         connection: ConnectionType,
diff --git a/weaviate/collections/data/sync.pyi b/weaviate/collections/data/sync.pyi
index 3fa145a4e..eda3da21a 100644
--- a/weaviate/collections/data/sync.pyi
+++ b/weaviate/collections/data/sync.pyi
@@ -1,6 +1,10 @@
 import uuid as uuid_package
 from typing import Generic, List, Literal, Optional, Sequence, Union, overload
 
+from weaviate.collections.batch.collection import _BatchCollectionWrapper
+from weaviate.collections.batch.grpc_batch import _BatchGRPC
+from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC
+from weaviate.collections.batch.rest import _BatchREST
 from weaviate.collections.classes.batch import (
     BatchObjectReturn,
     BatchReferenceReturn,
@@ -21,6 +25,11 @@
 from weaviate.types import UUID, VECTORS
 
 from .executor import _DataCollectionExecutor
 
 class _DataCollection(Generic[Properties,], _DataCollectionExecutor[ConnectionSync, Properties]):
+    __batch_delete: _BatchDeleteGRPC
+    __batch_grpc: _BatchGRPC
+    __batch_rest: _BatchREST
+    __batch: _BatchCollectionWrapper[Properties]
+
     def insert(
         self,
         properties: Properties,
diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py
index 3734f650a..ecdaf666b 100644
--- a/weaviate/connect/v4.py
+++ b/weaviate/connect/v4.py
@@ -20,13 +20,14 @@
     overload,
 )
 
+import grpc
 from authlib.integrations.httpx_client import (  # type: ignore
     AsyncOAuth2Client,
     OAuth2Client,
 )
 from grpc import Call, RpcError, StatusCode
 from grpc import Channel as SyncChannel  # type: ignore
-from grpc.aio import AioRpcError
+from grpc.aio import AioRpcError, StreamStreamCall
 from grpc.aio import Channel as AsyncChannel  # type: ignore
 
 # from grpclib.client import Channel
@@ -1114,8 +1115,8 @@ def grpc_aggregate(
 
 class ConnectionAsync(_ConnectionBase):
     """Connection class used to communicate to a weaviate instance."""
 
-    async def connect(self) -> None:
-        if self._connected:
+    async def connect(self, force: bool = False) -> None:
+        if self._connected and not force:
             return None
 
         await executor.aresult(self._open_connections_rest(self._auth, "async"))
@@ -1247,6 +1248,52 @@ async def grpc_batch_delete(
                 raise InsufficientPermissionsError(e)
             raise WeaviateDeleteManyError(str(e))
 
+    def grpc_batch_stream(
+        self,
+    ) -> StreamStreamCall[batch_pb2.BatchStreamRequest, batch_pb2.BatchStreamReply]:
+        assert isinstance(self._grpc_channel, grpc.aio.Channel)
+        return self._grpc_channel.stream_stream(
+            "/weaviate.v1.Weaviate/BatchStream",
+            request_serializer=batch_pb2.BatchStreamRequest.SerializeToString,
+            response_deserializer=batch_pb2.BatchStreamReply.FromString,
+        )(
+            request_iterator=None,
+            timeout=self.timeout_config.insert,
+            metadata=self.grpc_headers(),
+        )
+
+    async def grpc_batch_stream_write(
+        self,
+        stream: StreamStreamCall[batch_pb2.BatchStreamRequest, batch_pb2.BatchStreamReply],
+        request: batch_pb2.BatchStreamRequest,
+    ) -> None:
+        try:
+            await stream.write(request)
+        except AioRpcError as e:
+            error = cast(Call, e)
+            if error.code() == StatusCode.PERMISSION_DENIED:
+                raise InsufficientPermissionsError(error)
+            if error.code() == StatusCode.ABORTED:
+                raise _BatchStreamShutdownError()
+            raise WeaviateBatchStreamError(str(error.details()))
+
+    async def grpc_batch_stream_read(
+        self,
+        stream: StreamStreamCall[batch_pb2.BatchStreamRequest, batch_pb2.BatchStreamReply],
+    ) -> Optional[batch_pb2.BatchStreamReply]:
+        try:
+            msg = await stream.read()
+            if not isinstance(msg, batch_pb2.BatchStreamReply):
+                return None
+            return msg
+        except AioRpcError as e:
+            error = cast(Call, e)
+            if error.code() == StatusCode.PERMISSION_DENIED:
+                raise InsufficientPermissionsError(error)
+            if error.code() == StatusCode.ABORTED:
+                raise _BatchStreamShutdownError()
+            raise WeaviateBatchStreamError(str(error.details()))
+
     async def grpc_tenants_get(
         self, request: tenants_pb2.TenantsGetRequest
     ) -> tenants_pb2.TenantsGetReply:
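
The three helpers above split the bidirectional stream into create/write/read so the async batch implementation can interleave them from separate tasks. A minimal, illustrative driver sketch (a real implementation runs the write and read sides concurrently, as the sync threads do; assumes an already-connected ConnectionAsync):

from weaviate.proto.v1 import batch_pb2


async def drive(connection) -> None:  # connection: a connected ConnectionAsync
    stream = connection.grpc_batch_stream()
    await connection.grpc_batch_stream_write(
        stream, batch_pb2.BatchStreamRequest(start=batch_pb2.BatchStreamRequest.Start())
    )
    # ... write data requests here ...
    await connection.grpc_batch_stream_write(
        stream, batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop())
    )
    while (reply := await connection.grpc_batch_stream_read(stream)) is not None:
        if reply.HasField("results"):
            ...  # reconcile successes/errors against the client-side caches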