Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 92 additions & 6 deletions integration/test_batch_v4.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import concurrent.futures
import uuid
from dataclasses import dataclass
from typing import Callable, Generator, List, Optional, Protocol, Tuple

import pytest
import pytest_asyncio
from _pytest.fixtures import SubRequest

import weaviate
Expand Down Expand Up @@ -119,6 +121,53 @@ def _factory(
client_fixture.close()


class AsyncClientFactory(Protocol):
"""Typing for fixture."""

async def __call__(
self, name: str = "", ports: Tuple[int, int] = (8080, 50051), multi_tenant: bool = False
) -> Tuple[weaviate.WeaviateAsyncClient, str]:
"""Typing for fixture."""
...


@pytest_asyncio.fixture
async def async_client_factory(request: SubRequest):
name_fixtures: List[str] = []
client_fixture: Optional[weaviate.WeaviateAsyncClient] = None

async def _factory(
name: str = "", ports: Tuple[int, int] = (8080, 50051), multi_tenant: bool = False
):
nonlocal client_fixture, name_fixtures # noqa: F824
name_fixture = _sanitize_collection_name(request.node.name) + name
name_fixtures.append(name_fixture)
if client_fixture is None:
client_fixture = weaviate.use_async_with_local(grpc_port=ports[1], port=ports[0])
await client_fixture.connect()

if await client_fixture.collections.exists(name_fixture):
await client_fixture.collections.delete(name_fixture)

await client_fixture.collections.create(
name=name_fixture,
properties=[
Property(name="name", data_type=DataType.TEXT),
Property(name="age", data_type=DataType.INT),
],
references=[ReferenceProperty(name="test", target_collection=name_fixture)],
multi_tenancy_config=Configure.multi_tenancy(multi_tenant),
vectorizer_config=Configure.Vectorizer.none(),
)
return client_fixture, name_fixture

try:
yield _factory
finally:
if client_fixture is not None:
await client_fixture.close()


def test_add_objects_in_multiple_batches(client_factory: ClientFactory) -> None:
client, name = client_factory()
with client.batch.rate_limit(50) as batch:
Expand Down Expand Up @@ -365,15 +414,15 @@ def test_add_ref_batch_with_tenant(client_factory: ClientFactory) -> None:
@pytest.mark.parametrize(
"batching_method",
[
# lambda client: client.batch.dynamic(),
# lambda client: client.batch.fixed_size(),
# lambda client: client.batch.rate_limit(9999),
lambda client: client.batch.dynamic(),
lambda client: client.batch.fixed_size(),
lambda client: client.batch.rate_limit(9999),
lambda client: client.batch.experimental(concurrency=1),
],
ids=[
# "test_add_ten_thousand_data_objects_dynamic",
# "test_add_ten_thousand_data_objects_fixed_size",
# "test_add_ten_thousand_data_objects_rate_limit",
"test_add_ten_thousand_data_objects_dynamic",
"test_add_ten_thousand_data_objects_fixed_size",
"test_add_ten_thousand_data_objects_rate_limit",
"test_add_ten_thousand_data_objects_experimental",
],
)
Expand Down Expand Up @@ -768,3 +817,40 @@ def test_references_with_to_uuids(client_factory: ClientFactory) -> None:

assert len(client.batch.failed_references) == 0, client.batch.failed_references
client.collections.delete(["target", "source"])


@pytest.mark.asyncio
async def test_add_ten_thousand_data_objects_async(
async_client_factory: AsyncClientFactory,
) -> None:
"""Test adding ten thousand data objects."""
client, name = await async_client_factory()
if client._connection._weaviate_version.is_lower_than(1, 34, 0):
pytest.skip("Server-side batching not supported in Weaviate < 1.34.0")
nr_objects = 100000
import time

start = time.time()
async with client.batch.experimental(concurrency=1) as batch:
async for i in arange(nr_objects):
await batch.add_object(
collection=name,
properties={"name": "test" + str(i)},
)
end = time.time()
print(f"Time taken to add {nr_objects} objects: {end - start} seconds")
assert len(client.batch.results.objs.errors) == 0
assert len(client.batch.results.objs.all_responses) == nr_objects
assert len(client.batch.results.objs.uuids) == nr_objects
assert await client.collections.use(name).length() == nr_objects
assert client.batch.results.objs.has_errors is False
assert len(client.batch.failed_objects) == 0, [
obj.message for obj in client.batch.failed_objects
]
await client.collections.delete(name)


async def arange(count):
for i in range(count):
yield i
await asyncio.sleep(0.0)
3 changes: 2 additions & 1 deletion weaviate/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .auth import AuthCredentials
from .backup import _Backup, _BackupAsync
from .cluster import _Cluster, _ClusterAsync
from .collections.batch.client import _BatchClientWrapper
from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync
from .collections.collections import _Collections, _CollectionsAsync
from .config import AdditionalConfig
from .connect import executor
Expand Down Expand Up @@ -76,6 +76,7 @@ def __init__(
)
self.alias = _AliasAsync(self._connection)
self.backup = _BackupAsync(self._connection)
self.batch = _BatchClientWrapperAsync(self._connection)
self.cluster = _ClusterAsync(self._connection)
self.collections = _CollectionsAsync(self._connection)
self.debug = _DebugAsync(self._connection)
Expand Down
3 changes: 2 additions & 1 deletion weaviate/client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ from weaviate.users.sync import _Users

from .backup import _Backup, _BackupAsync
from .cluster import _Cluster, _ClusterAsync
from .collections.batch.client import _BatchClientWrapper
from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync
from .debug import _Debug, _DebugAsync
from .rbac import _Roles, _RolesAsync
from .types import NUMBER
Expand All @@ -29,6 +29,7 @@ class WeaviateAsyncClient(_WeaviateClientExecutor[ConnectionAsync]):
_connection: ConnectionAsync
alias: _AliasAsync
backup: _BackupAsync
batch: _BatchClientWrapperAsync
collections: _CollectionsAsync
cluster: _ClusterAsync
debug: _DebugAsync
Expand Down
Loading