diff --git a/hazelcast/asyncio/client.py b/hazelcast/asyncio/client.py index 758ae7011f..769a15c9c1 100644 --- a/hazelcast/asyncio/client.py +++ b/hazelcast/asyncio/client.py @@ -5,13 +5,14 @@ from hazelcast.internal.asyncio_cluster import ClusterService, _InternalClusterService from hazelcast.internal.asyncio_compact import CompactSchemaService -from hazelcast.config import Config +from hazelcast.config import Config, IndexConfig from hazelcast.internal.asyncio_connection import ConnectionManager, DefaultAddressProvider from hazelcast.core import DistributedObjectEvent, DistributedObjectInfo from hazelcast.cp import CPSubsystem, ProxySessionManager from hazelcast.discovery import HazelcastCloudAddressProvider from hazelcast.errors import IllegalStateError, InvalidConfigurationError from hazelcast.internal.asyncio_invocation import InvocationService, Invocation +from hazelcast.internal.asyncio_proxy.vector_collection import VectorCollection from hazelcast.lifecycle import LifecycleService, LifecycleState, _InternalLifecycleService from hazelcast.internal.asyncio_listener import ClusterViewListenerService, ListenerService from hazelcast.near_cache import NearCacheManager @@ -20,10 +21,12 @@ client_add_distributed_object_listener_codec, client_get_distributed_objects_codec, client_remove_distributed_object_listener_codec, + dynamic_config_add_vector_collection_config_codec, ) from hazelcast.internal.asyncio_proxy.manager import ( MAP_SERVICE, ProxyManager, + VECTOR_SERVICE, ) from hazelcast.internal.asyncio_proxy.base import Proxy from hazelcast.internal.asyncio_proxy.map import Map @@ -164,7 +167,7 @@ def _init_context(self): async def _start(self): try: self._internal_lifecycle_service.start() - self._invocation_service.start() + await self._invocation_service.start() membership_listeners = self._config.membership_listeners self._internal_cluster_service.start(self._connection_manager, membership_listeners) self._cluster_view_listener.start() @@ -185,6 +188,37 @@ async def _start(self): async def get_map(self, name: str) -> Map[KeyType, ValueType]: return await self._proxy_manager.get_or_create(MAP_SERVICE, name) + async def create_vector_collection_config( + self, + name: str, + indexes: typing.List[IndexConfig], + backup_count: int = 1, + async_backup_count: int = 0, + split_brain_protection_name: typing.Optional[str] = None, + merge_policy: str = "PutIfAbsentMergePolicy", + merge_batch_size: int = 100, + ) -> None: + # check that indexes have different names + if indexes: + index_names = set(index.name for index in indexes) + if len(index_names) != len(indexes): + raise AssertionError("index names must be unique") + + request = dynamic_config_add_vector_collection_config_codec.encode_request( + name, + indexes, + backup_count, + async_backup_count, + split_brain_protection_name, + merge_policy, + merge_batch_size, + ) + invocation = Invocation(request, response_handler=lambda m: m) + await self._invocation_service.ainvoke(invocation) + + async def get_vector_collection(self, name: str) -> VectorCollection: + return await self._proxy_manager.get_or_create(VECTOR_SERVICE, name) + async def add_distributed_object_listener( self, listener_func: typing.Callable[[DistributedObjectEvent], None] ) -> str: diff --git a/hazelcast/internal/asyncio_compact.py b/hazelcast/internal/asyncio_compact.py index 06f53ab97c..db2d292a6d 100644 --- a/hazelcast/internal/asyncio_compact.py +++ b/hazelcast/internal/asyncio_compact.py @@ -57,35 +57,35 @@ def fetch_schema(self, schema_id: int) -> asyncio.Future: 
self._invocation_service.invoke(fetch_schema_invocation) return fetch_schema_invocation.future - def send_schema_and_retry( + async def send_schema_and_retry( self, error: "SchemaNotReplicatedError", func: typing.Callable[..., asyncio.Future], *args: typing.Any, **kwargs: typing.Any, - ) -> asyncio.Future: + ) -> None: schema = error.schema clazz = error.clazz request = client_send_schema_codec.encode_request(schema) - def callback(): + async def callback(): self._has_replicated_schemas = True self._compact_serializer.register_schema_to_type(schema, clazz) - return func(*args, **kwargs) + return await func(*args, **kwargs) - return self._replicate_schema( - schema, request, CompactSchemaService._SEND_SCHEMA_RETRY_COUNT, callback + return await self._replicate_schema( + schema, request, CompactSchemaService._SEND_SCHEMA_RETRY_COUNT, callback() ) - def _replicate_schema( + async def _replicate_schema( self, schema: "Schema", request: "OutboundMessage", remaining_retries: int, - callback: typing.Callable[..., asyncio.Future], - ) -> asyncio.Future: - def continuation(future: asyncio.Future): - replicated_members = future.result() + callback: typing.Coroutine[typing.Any, typing.Any, typing.Any], + ) -> None: + while remaining_retries >= 2: + replicated_members = await self._send_schema_replication_request(request) members = self._cluster_service.get_members() for member in members: if member.uuid not in replicated_members: @@ -93,41 +93,25 @@ def continuation(future: asyncio.Future): else: # Loop completed normally. # All members in our member list all known to have the schema - return callback() + return await callback # There is a member in our member list that the schema # is not known to be replicated yet. We should retry # sending it in a random member. - if remaining_retries <= 1: - # We tried to send it a couple of times, but the member list - # in our local and the member list returned by the initiator - # nodes did not match. - raise IllegalStateError( - f"The schema {schema} cannot be replicated in the cluster, " - f"after {CompactSchemaService._SEND_SCHEMA_RETRY_COUNT} retries. " - f"It might be the case that the client is connected to the two " - f"halves of the cluster that is experiencing a split-brain, " - f"and continue putting the data associated with that schema " - f"might result in data loss. It might be possible to replicate " - f"the schema after some time, when the cluster is healed." - ) - - delayed_future: asyncio.Future = asyncio.get_running_loop().create_future() - self._reactor.add_timer( - self._invocation_retry_pause, - lambda: delayed_future.set_result(None), - ) - - def retry(_): - return self._replicate_schema( - schema, request.copy(), remaining_retries - 1, callback - ) - - return delayed_future.add_done_callback(retry) - - fut = self._send_schema_replication_request(request) - fut.add_done_callback(continuation) - return fut + await asyncio.sleep(self._invocation_retry_pause) + + # We tried to send it a couple of times, but the member list + # in our local and the member list returned by the initiator + # nodes did not match. + raise IllegalStateError( + f"The schema {schema} cannot be replicated in the cluster, " + f"after {CompactSchemaService._SEND_SCHEMA_RETRY_COUNT} retries. " + f"It might be the case that the client is connected to the two " + f"halves of the cluster that is experiencing a split-brain, " + f"and continue putting the data associated with that schema " + f"might result in data loss. 
It might be possible to replicate " + f"the schema after some time, when the cluster is healed." + ) def _send_schema_replication_request(self, request: "OutboundMessage") -> asyncio.Future: invocation = Invocation(request, response_handler=client_send_schema_codec.decode_response) diff --git a/hazelcast/internal/asyncio_invocation.py b/hazelcast/internal/asyncio_invocation.py index 591740faa0..f7effc6e5b 100644 --- a/hazelcast/internal/asyncio_invocation.py +++ b/hazelcast/internal/asyncio_invocation.py @@ -96,7 +96,7 @@ def __init__(self, client, config, reactor): self._backup_ack_to_client_enabled = smart_routing and config.backup_ack_to_client_enabled self._fail_on_indeterminate_state = config.fail_on_indeterminate_operation_state self._backup_timeout = config.operation_backup_timeout - self._clean_resources_timer = None + self._clean_resources_task = None self._shutdown = False self._compact_schema_service = None @@ -107,8 +107,8 @@ def init(self, partition_service, connection_manager, listener_service, compact_ self._check_invocation_allowed_fn = connection_manager.check_invocation_allowed self._compact_schema_service = compact_schema_service - def start(self): - self._start_clean_resources_timer() + async def start(self): + await self._start_clean_resources_timer() async def add_backup_listener(self): if self._backup_ack_to_client_enabled: @@ -152,8 +152,8 @@ def shutdown(self): return self._shutdown = True - if self._clean_resources_timer: - self._clean_resources_timer.cancel() + if self._clean_resources_task: + self._clean_resources_task.cancel() for invocation in list(self._pending.values()): self._notify_error(invocation, HazelcastClientNotActiveError()) @@ -400,8 +400,9 @@ def _notify_backup_complete(self, invocation): self._complete(invocation, invocation.pending_response) - def _start_clean_resources_timer(self): - def run(): + async def _start_clean_resources_timer(self): + async def run(): + await asyncio.sleep(self._CLEAN_RESOURCES_PERIOD) if self._shutdown: return @@ -419,9 +420,9 @@ def run(): if self._backup_ack_to_client_enabled: self._detect_and_handle_backup_timeout(invocation, now) - self._clean_resources_timer = self._reactor.add_timer(self._CLEAN_RESOURCES_PERIOD, run) + self._clean_resources_task = asyncio.create_task(run()) - self._clean_resources_timer = self._reactor.add_timer(self._CLEAN_RESOURCES_PERIOD, run) + self._clean_resources_task = asyncio.create_task(run()) def _detect_and_handle_backup_timeout(self, invocation, now): if not invocation.pending_response: diff --git a/hazelcast/internal/asyncio_proxy/base.py b/hazelcast/internal/asyncio_proxy/base.py index 4d4ba8b4e1..6219dc22a2 100644 --- a/hazelcast/internal/asyncio_proxy/base.py +++ b/hazelcast/internal/asyncio_proxy/base.py @@ -65,15 +65,14 @@ def _invoke_on_target( self._invocation_service.invoke(invocation) return invocation.future - def _invoke_on_key( + async def _invoke_on_key( self, request, key_data, response_handler=_no_op_response_handler - ) -> asyncio.Future: + ) -> typing.Any: partition_id = self._partition_service.get_partition_id(key_data) invocation = Invocation( request, partition_id=partition_id, response_handler=response_handler ) - self._invocation_service.invoke(invocation) - return invocation.future + return await self._invocation_service.ainvoke(invocation) def _invoke_on_partition( self, request, partition_id, response_handler=_no_op_response_handler diff --git a/hazelcast/internal/asyncio_proxy/manager.py b/hazelcast/internal/asyncio_proxy/manager.py index 
a5028addca..9daeca0e1f 100644 --- a/hazelcast/internal/asyncio_proxy/manager.py +++ b/hazelcast/internal/asyncio_proxy/manager.py @@ -1,5 +1,9 @@ import typing +from hazelcast.internal.asyncio_proxy.vector_collection import ( + VectorCollection, + create_vector_collection_proxy, +) from hazelcast.protocol.codec import client_create_proxy_codec, client_destroy_proxy_codec from hazelcast.internal.asyncio_invocation import Invocation from hazelcast.internal.asyncio_proxy.base import Proxy @@ -7,12 +11,14 @@ from hazelcast.util import to_list MAP_SERVICE = "hz:impl:mapService" +VECTOR_SERVICE = "hz:service:vector" _proxy_init: typing.Dict[ str, typing.Callable[[str, str, typing.Any], typing.Coroutine[typing.Any, typing.Any, typing.Any]], ] = { MAP_SERVICE: create_map_proxy, + VECTOR_SERVICE: create_vector_collection_proxy, } diff --git a/hazelcast/internal/asyncio_proxy/map.py b/hazelcast/internal/asyncio_proxy/map.py index a7e20bed55..84c8ecaa14 100644 --- a/hazelcast/internal/asyncio_proxy/map.py +++ b/hazelcast/internal/asyncio_proxy/map.py @@ -5,7 +5,6 @@ from hazelcast.aggregator import Aggregator from hazelcast.config import IndexUtil, IndexType, IndexConfig from hazelcast.core import SimpleEntryView -from hazelcast.errors import InvalidConfigurationError from hazelcast.projection import Projection from hazelcast.protocol import PagingPredicateHolder from hazelcast.protocol.codec import ( @@ -273,7 +272,7 @@ async def add_interceptor(self, interceptor: typing.Any) -> str: try: interceptor_data = self._to_data(interceptor) except SchemaNotReplicatedError as e: - return self._send_schema_and_retry(e, self.add_interceptor, interceptor) + return await self._send_schema_and_retry(e, self.add_interceptor, interceptor) request = map_add_interceptor_codec.encode_request(self.name, interceptor_data) return await self._invoke(request, map_add_interceptor_codec.decode_response) @@ -873,7 +872,7 @@ def _delete_internal(self, key_data): request = map_delete_codec.encode_request(self.name, key_data, thread_id()) return self._invoke_on_key(request, key_data) - def _put_internal(self, key_data, value_data, ttl, max_idle): + async def _put_internal(self, key_data, value_data, ttl, max_idle): def handler(message): return self._to_object(map_put_codec.decode_response(message)) @@ -885,7 +884,7 @@ def handler(message): request = map_put_codec.encode_request( self.name, key_data, value_data, thread_id(), to_millis(ttl) ) - return self._invoke_on_key(request, key_data, handler) + return await self._invoke_on_key(request, key_data, handler) def _set_internal(self, key_data, value_data, ttl, max_idle): if max_idle is not None: @@ -1107,9 +1106,11 @@ def _put_transient_internal(self, key_data, value_data, ttl, max_idle): key_data, value_data, ttl, max_idle ) - def _put_internal(self, key_data, value_data, ttl, max_idle): + async def _put_internal(self, key_data, value_data, ttl, max_idle): self._invalidate_cache(key_data) - return super(MapFeatNearCache, self)._put_internal(key_data, value_data, ttl, max_idle) + return await super(MapFeatNearCache, self)._put_internal( + key_data, value_data, ttl, max_idle + ) def _put_if_absent_internal(self, key_data, value_data, ttl, max_idle): self._invalidate_cache(key_data) diff --git a/hazelcast/internal/asyncio_proxy/vector_collection.py b/hazelcast/internal/asyncio_proxy/vector_collection.py new file mode 100644 index 0000000000..4abe7fb3b2 --- /dev/null +++ b/hazelcast/internal/asyncio_proxy/vector_collection.py @@ -0,0 +1,255 @@ +import asyncio +import copy +import 
typing +import uuid +from typing import Any, Dict, List, Optional, Tuple + +from hazelcast.protocol.codec import ( + vector_collection_set_codec, + vector_collection_get_codec, + vector_collection_search_near_vector_codec, + vector_collection_delete_codec, + vector_collection_put_codec, + vector_collection_put_if_absent_codec, + vector_collection_remove_codec, + vector_collection_put_all_codec, + vector_collection_clear_codec, + vector_collection_optimize_codec, + vector_collection_size_codec, +) +from hazelcast.internal.asyncio_proxy.base import Proxy +from hazelcast.serialization.compact import SchemaNotReplicatedError +from hazelcast.serialization.data import Data +from hazelcast.types import KeyType, ValueType +from hazelcast.util import check_not_none +from hazelcast.vector import ( + Document, + SearchResult, + Vector, + VectorType, + VectorSearchOptions, +) + + +class VectorCollection(Proxy, typing.Generic[KeyType, ValueType]): + def __init__(self, service_name, name, context): + super(VectorCollection, self).__init__(service_name, name, context) + + async def get(self, key: Any) -> Document | None: + check_not_none(key, "key can't be None") + return await self._get_internal(key) + + async def set(self, key: Any, document: Document) -> None: + check_not_none(key, "key can't be None") + check_not_none(document, "document can't be None") + check_not_none(document.value, "document value can't be None") + return await self._set_internal(key, document) + + async def put(self, key: Any, document: Document) -> Document | None: + check_not_none(key, "key can't be None") + check_not_none(document, "document can't be None") + check_not_none(document.value, "document value can't be None") + return await self._put_internal(key, document) + + async def put_all(self, map: Dict[Any, Document]) -> None: + check_not_none(map, "map can't be None") + if not map: + return None + partition_service = self._context.partition_service + partition_map: Dict[int, List[Tuple[Data, Document]]] = {} + for key, doc in map.items(): + check_not_none(key, "key can't be None") + check_not_none(doc, "value can't be None") + doc = copy.copy(doc) + try: + entry = (self._to_data(key), doc) + doc.value = self._to_data(doc.value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.put_all, map) + + partition_id = partition_service.get_partition_id(entry[0]) + partition_map.setdefault(partition_id, []).append(entry) + + async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] + for partition_id, entry_list in partition_map.items(): + request = vector_collection_put_all_codec.encode_request(self.name, entry_list) + tg.create_task(self._ainvoke_on_partition(request, partition_id)) + + return None + + async def put_if_absent(self, key: Any, document: Document) -> Document | None: + check_not_none(key, "key can't be None") + check_not_none(document, "document can't be None") + check_not_none(document.value, "document value can't be None") + return await self._put_if_absent_internal(key, document) + + async def search_near_vector( + self, + vector: Vector, + *, + include_value: bool = False, + include_vectors: bool = False, + limit: int = 10, + hints: Dict[str, str] = None + ) -> List[SearchResult]: + check_not_none(vector, "vector can't be None") + if limit <= 0: + raise AssertionError("limit must be positive") + return await self._search_near_vector_internal( + vector, + include_value=include_value, + include_vectors=include_vectors, + limit=limit, + hints=hints, + ) + + async 
def remove(self, key: Any) -> Document | None: + check_not_none(key, "key can't be None") + return await self._remove_internal(key) + + async def delete(self, key: Any) -> None: + check_not_none(key, "key can't be None") + return await self._delete_internal(key) + + async def optimize(self, index_name: str = None) -> None: + request = vector_collection_optimize_codec.encode_request( + self.name, index_name, uuid.uuid4() + ) + return await self._invoke(request) + + async def clear(self) -> None: + request = vector_collection_clear_codec.encode_request(self.name) + return await self._invoke(request) + + async def size(self) -> int: + request = vector_collection_size_codec.encode_request(self.name) + return await self._invoke(request, vector_collection_size_codec.decode_response) + + async def _set_internal(self, key: Any, document: Document) -> None: + try: + key_data = self._to_data(key) + value_data = self._to_data(document.value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.set, key, document) + document = copy.copy(document) + document.value = value_data + request = vector_collection_set_codec.encode_request( + self.name, + key_data, + document, + ) + return await self._invoke_on_key(request, key_data) + + async def _get_internal(self, key: Any) -> Any: + def handler(message): + doc = vector_collection_get_codec.decode_response(message) + return self._transform_document(doc) + + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.get, key) + request = vector_collection_get_codec.encode_request( + self.name, + key_data, + ) + return await self._invoke_on_key(request, key_data, response_handler=handler) + + def _search_near_vector_internal( + self, + vector: Vector, + *, + include_value: bool = False, + include_vectors: bool = False, + limit: int = 10, + hints: Dict[str, str] = None + ) -> asyncio.Future[List[SearchResult]]: + def handler(message): + results: List[ + SearchResult + ] = vector_collection_search_near_vector_codec.decode_response(message) + for result in results: + if result.key is not None: + result.key = self._to_object(result.key) + if result.value is not None: + result.value = self._to_object(result.value) + if result.vectors: + for vec in result.vectors: + vec.type = VectorType(vec.type) + return results + + options = VectorSearchOptions( + include_value=include_value, + include_vectors=include_vectors, + limit=limit, + hints=hints or {}, + ) + request = vector_collection_search_near_vector_codec.encode_request( + self.name, + [vector], + options, + ) + return self._invoke(request, response_handler=handler) + + async def _delete_internal(self, key: Any) -> None: + key_data = self._to_data(key) + request = vector_collection_delete_codec.encode_request(self.name, key_data) + return await self._invoke_on_key(request, key_data) + + async def _remove_internal(self, key: Any) -> Document | None: + def handler(message): + doc = vector_collection_remove_codec.decode_response(message) + return self._transform_document(doc) + + key_data = self._to_data(key) + request = vector_collection_remove_codec.encode_request(self.name, key_data) + return await self._invoke_on_key(request, key_data, response_handler=handler) + + async def _put_internal(self, key: Any, document: Document) -> Document | None: + def handler(message): + doc = vector_collection_put_codec.decode_response(message) + return self._transform_document(doc) + + try: + key_data = self._to_data(key) + 
value_data = self._to_data(document.value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.set, key, document) + document = copy.copy(document) + document.value = value_data + request = vector_collection_put_codec.encode_request( + self.name, + key_data, + document, + ) + return await self._invoke_on_key(request, key_data, response_handler=handler) + + async def _put_if_absent_internal(self, key: Any, document: Document) -> Document | None: + def handler(message): + doc = vector_collection_put_if_absent_codec.decode_response(message) + return self._transform_document(doc) + + try: + key_data = self._to_data(key) + value_data = self._to_data(document.value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.set, key, document) + document.value = value_data + request = vector_collection_put_if_absent_codec.encode_request( + self.name, + key_data, + document, + ) + return await self._invoke_on_key(request, key_data, response_handler=handler) + + def _transform_document(self, doc: Optional[Document]) -> Optional[Document]: + if doc is not None: + if doc.value is not None: + doc.value = self._to_object(doc.value) + for vec in doc.vectors: + vec.type = VectorType(vec.type) + return doc + + +async def create_vector_collection_proxy(service_name, name, context): + return VectorCollection(service_name, name, context) diff --git a/hazelcast/internal/asyncio_reactor.py b/hazelcast/internal/asyncio_reactor.py index e1026c78c3..7311f10bfc 100644 --- a/hazelcast/internal/asyncio_reactor.py +++ b/hazelcast/internal/asyncio_reactor.py @@ -172,7 +172,6 @@ def __init__(self, conn: AsyncioConnection): self._write_buf = io.BytesIO() self._write_buf_size = 0 self._recv_buf = None - self._alive = True # asyncio tasks are weakly referenced # storing tasks here in order not to lose them midway # see: https: // docs.python.org / 3 / library / asyncio - task.html # creating-tasks @@ -186,8 +185,10 @@ def connection_made(self, transport: transports.BaseTransport): self._conn._loop.call_soon(self._write_loop) def connection_lost(self, exc): - self._alive = False - task = self._conn._loop.create_task(self._conn.close_connection(str(exc), None)) + _logger.warning("Connection closed by server") + task = self._conn._loop.create_task( + self._conn.close_connection(None, IOError("Connection closed by server")) + ) self._tasks.add(task) task.add_done_callback(self._tasks.discard) return False @@ -213,9 +214,6 @@ def buffer_updated(self, nbytes): if self._conn._reader.length: self._conn._reader.process() - def eof_received(self): - self._alive = False - def _do_write(self): if not self._write_buf_size: return diff --git a/tests/integration/asyncio/proxy/vector_collection_test.py b/tests/integration/asyncio/proxy/vector_collection_test.py new file mode 100644 index 0000000000..d6db835782 --- /dev/null +++ b/tests/integration/asyncio/proxy/vector_collection_test.py @@ -0,0 +1,339 @@ +import os +import unittest + +import pytest + +import hazelcast.errors +from tests.integration.asyncio.base import SingleMemberTestCase +from tests.util import ( + random_string, + compare_client_version, + skip_if_server_version_older_than, + skip_if_client_version_older_than, +) + +try: + from hazelcast.vector import IndexConfig, Metric, Document, Vector, Type +except ImportError: + # backward compatibility + pass + + +@unittest.skipIf( + compare_client_version("5.5") < 0, "Tests the features added in 5.5 version of the client" +) +@pytest.mark.enterprise +class 
VectorCollectionTest(SingleMemberTestCase): + @classmethod + def configure_cluster(cls): + path = os.path.abspath(__file__) + dir_path = os.path.dirname(path) + with open(os.path.join(dir_path, "../../backward_compatible/proxy/hazelcast.xml")) as f: + return f.read() + + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + skip_if_server_version_older_than(self, self.client, "5.5") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)] + ) + self.vector_collection = await self.client.get_vector_collection(name) + + async def asyncTearDown(self): + await self.vector_collection.destroy() + await super().asyncTearDown() + + async def test_set(self): + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + + async def test_get(self): + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + got_doc = await self.vector_collection.get("k1") + self.assert_document_equal(got_doc, doc) + + async def test_put(self): + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + doc_old = await self.vector_collection.put("k1", doc) + self.assertIsNone(doc_old) + doc2 = Document("v1", Vector("vector", Type.DENSE, [0.4, 0.5, 0.6])) + doc_old = await self.vector_collection.put("k1", doc2) + self.assert_document_equal(doc_old, doc) + + async def test_delete(self): + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + await self.vector_collection.delete("k1") + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + + async def test_remove(self): + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + doc2 = await self.vector_collection.remove("k1") + self.assert_document_equal(doc, doc2) + + async def test_put_all(self): + doc1 = self.doc1("v1", [0.1, 0.2, 0.3]) + doc2 = self.doc1("v1", [0.2, 0.3, 0.4]) + await self.vector_collection.put_all( + { + "k1": doc1, + "k2": doc2, + } + ) + k1 = await self.vector_collection.get("k1") + self.assert_document_equal(k1, doc1) + k2 = await self.vector_collection.get("k2") + self.assert_document_equal(k2, doc2) + + async def test_clear(self): + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + doc = Document("v1", self.vec1([0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + await self.vector_collection.clear() + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + + async def test_optimize(self): + doc = Document("v1", self.vec1([0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + # it is hard to observe results of optimize, so just test that the invocation works + await self.vector_collection.optimize() + + async def test_optimize_with_name(self): + doc = Document("v1", self.vec1([0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + # it is hard to observe results of optimize, so just test that the invocation works + await self.vector_collection.optimize("vector") + + async def test_search_near_vector_include_all(self): 
+ target_doc = self.doc1("v1", [0.3, 0.4, 0.5]) + await self.vector_collection.put_all( + { + "k1": self.doc1("v1", [0.1, 0.2, 0.3]), + "k2": self.doc1("v1", [0.2, 0.3, 0.4]), + "k3": target_doc, + } + ) + result = await self.vector_collection.search_near_vector( + self.vec1([0.2, 0.2, 0.3]), limit=1, include_vectors=True, include_value=True + ) + self.assertEqual(1, len(result)) + self.assert_document_equal(target_doc, result[0]) + self.assertAlmostEqual(0.9973459243774414, result[0].score) + + async def test_search_near_vector_include_none(self): + target_doc = self.doc1("v1", [0.3, 0.4, 0.5]) + await self.vector_collection.put_all( + { + "k1": self.doc1("v1", [0.1, 0.2, 0.3]), + "k2": self.doc1("v1", [0.2, 0.3, 0.4]), + "k3": target_doc, + } + ) + result = await self.vector_collection.search_near_vector( + self.vec1([0.2, 0.2, 0.3]), limit=1, include_vectors=False, include_value=False + ) + self.assertEqual(1, len(result)) + result1 = result[0] + self.assertAlmostEqual(0.9973459243774414, result1.score) + self.assertIsNone(result1.value) + self.assertIsNone(result1.vectors) + + async def test_search_near_vector_hint(self): + # not empty collection is needed for search to do something + doc = Document("v1", self.vec1([0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + # trigger validation error to check if hint was sent + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.vector_collection.search_near_vector( + self.vec1([0.2, 0.2, 0.3]), + limit=1, + include_vectors=False, + include_value=False, + hints={"partitionLimit": "-1"}, + ) + + async def test_size(self): + self.assertEqual(await self.vector_collection.size(), 0) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.put("k1", doc) + self.assertEqual(await self.vector_collection.size(), 1) + await self.vector_collection.clear() + self.assertEqual(await self.vector_collection.size(), 0) + + async def test_backup_count_valid_values_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=2, async_backup_count=2 + ) + await self.client.get_vector_collection(name) + + async def test_backup_count_max_value_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=6 + ) + await self.client.get_vector_collection(name) + + async def test_backup_count_min_value_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=0 + ) + await self.client.get_vector_collection(name) + + async def test_backup_count_more_than_max_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + # check that the parameter is used by ensuring that it is validated on server side + # there is no simple way to check number of backups + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + backup_count=7, + async_backup_count=0, + ) + + async def test_backup_count_less_than_min_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + with 
self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=-1 + ) + + async def test_async_backup_count_max_value_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + backup_count=0, + async_backup_count=6, + ) + await self.client.get_vector_collection(name) + + async def test_async_backup_count_min_value_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], async_backup_count=0 + ) + await self.client.get_vector_collection(name) + + async def test_async_backup_count_more_than_max_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + # check that the parameter is used by ensuring that it is validated on server side + # there is no simple way to check number of backups + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + backup_count=0, + async_backup_count=7, + ) + + async def test_async_backup_count_less_than_min_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + async_backup_count=-1, + ) + + async def test_sync_and_async_backup_count_more_than_max_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + backup_count=4, + async_backup_count=3, + ) + + async def test_merge_policy_can_be_sent(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + merge_policy="DiscardMergePolicy", + merge_batch_size=1000, + ) + # validation happens when the collection proxy is created + await self.client.get_vector_collection(name) + + async def test_wrong_merge_policy_fails(self): + skip_if_client_version_older_than(self, "6.0") + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + with self.assertRaises(hazelcast.errors.InvalidConfigurationError): + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], merge_policy="non-existent" + ) + # validation happens when the collection proxy is created + await self.client.get_vector_collection(name) + + async def test_split_brain_name_can_be_sent(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + # wrong name will be ignored + split_brain_protection_name="non-existent", + ) + col = await self.client.get_vector_collection(name) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await col.set("k1", doc) + + def assert_document_equal(self, doc1, doc2) -> None: + self.assertEqual(doc1.value, doc2.value) + 
self.assertEqual(len(doc1.vectors), len(doc2.vectors)) + # currently there's a bug on the server-side about vector names. + # if there's a single vector, its name is not returned + # see: https://hazelcast.atlassian.net/browse/HZAI-67 + # working around that for now + skip_check_name = len(doc1.vectors) == 1 + for i in range(len(doc1.vectors)): + self.assert_vector_equal(doc1.vectors[i], doc2.vectors[i], skip_check_name) + + def assert_vector_equal(self, vec1, vec2, skip_check_name=False): + if not skip_check_name: + self.assertEqual(vec1.name, vec2.name) + self.assertEqual(vec1.type, vec2.type) + self.assertEqual(len(vec1.vector), len(vec2.vector)) + for i in range(len(vec1.vector)): + self.assertAlmostEqual(vec1.vector[i], vec2.vector[i]) + + @classmethod + def vec1(cls, elems): + return Vector("vector", Type.DENSE, elems) + + @classmethod + def doc1(cls, value, vector_elems): + return Document(value, cls.vec1(vector_elems)) diff --git a/tests/integration/asyncio/serialization/__init__.py b/tests/integration/asyncio/serialization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/asyncio/serialization/compact_test.py b/tests/integration/asyncio/serialization/compact_test.py new file mode 100644 index 0000000000..c49bf4e1fe --- /dev/null +++ b/tests/integration/asyncio/serialization/compact_test.py @@ -0,0 +1,929 @@ +import asyncio +import copy +import datetime +import decimal +import enum +import itertools +import random +import typing +import unittest + +from hazelcast.errors import HazelcastSerializationError +from hazelcast.predicate import sql +from hazelcast.util import AtomicInteger +from tests.integration.asyncio.base import HazelcastTestCase +from tests.util import ( + is_equal, + random_string, + compare_client_version, + compare_server_version_with_rc, +) + +try: + from hazelcast.serialization.api import ( + CompactSerializer, + CompactReader, + CompactWriter, + FieldKind, + ) + from hazelcast.serialization.compact import FIELD_OPERATIONS + + _COMPACT_AVAILABLE = True +except ImportError: + # For backward compatibility tests + + T = typing.TypeVar("T") + + class CompactSerializer(typing.Generic[T]): + pass + + class CompactReader: + pass + + class CompactWriter: + pass + + class FieldKind(enum.Enum): + pass + + _COMPACT_AVAILABLE = False + + +if _COMPACT_AVAILABLE: + FIELD_KINDS = [kind for kind in FieldKind if FIELD_OPERATIONS[kind] is not None] + FIX_SIZED_FIELD_KINDS = [ + kind for kind in FIELD_KINDS if not FIELD_OPERATIONS[kind].is_var_sized() + ] + VAR_SIZED_FIELD_KINDS = [kind for kind in FIELD_KINDS if FIELD_OPERATIONS[kind].is_var_sized()] + + FIX_SIZED_TO_NULLABLE = { + FieldKind.BOOLEAN: FieldKind.NULLABLE_BOOLEAN, + FieldKind.INT8: FieldKind.NULLABLE_INT8, + FieldKind.INT16: FieldKind.NULLABLE_INT16, + FieldKind.INT32: FieldKind.NULLABLE_INT32, + FieldKind.INT64: FieldKind.NULLABLE_INT64, + FieldKind.FLOAT32: FieldKind.NULLABLE_FLOAT32, + FieldKind.FLOAT64: FieldKind.NULLABLE_FLOAT64, + } + + FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY = { + FieldKind.ARRAY_OF_BOOLEAN: FieldKind.ARRAY_OF_NULLABLE_BOOLEAN, + FieldKind.ARRAY_OF_INT8: FieldKind.ARRAY_OF_NULLABLE_INT8, + FieldKind.ARRAY_OF_INT16: FieldKind.ARRAY_OF_NULLABLE_INT16, + FieldKind.ARRAY_OF_INT32: FieldKind.ARRAY_OF_NULLABLE_INT32, + FieldKind.ARRAY_OF_INT64: FieldKind.ARRAY_OF_NULLABLE_INT64, + FieldKind.ARRAY_OF_FLOAT32: FieldKind.ARRAY_OF_NULLABLE_FLOAT32, + FieldKind.ARRAY_OF_FLOAT64: FieldKind.ARRAY_OF_NULLABLE_FLOAT64, + } + + 
ARRAY_FIELD_KINDS_WITH_NULLABLE_ITEMS = [ + kind + for kind in VAR_SIZED_FIELD_KINDS + if ("ARRAY" in kind.name) and (kind not in FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY) + ] +else: + FIELD_KINDS = [] + FIX_SIZED_FIELD_KINDS = [] + VAR_SIZED_FIELD_KINDS = [] + FIX_SIZED_TO_NULLABLE = {} + FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY = {} + ARRAY_FIELD_KINDS_WITH_NULLABLE_ITEMS = [] + + +@unittest.skipIf( + compare_client_version("5.2") < 0, "Tests the features added in 5.2 version of the client" +) +class CompactTestBase(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + rc = None + cluster = None + member = None + + @classmethod + def setUpClass(cls) -> None: + cls.rc = cls.create_rc() + if compare_server_version_with_rc(cls.rc, "5.2") < 0: + cls.rc.exit() + raise unittest.SkipTest("Compact serialization requires 5.2 server") + + cls.cluster = cls.create_cluster(cls.rc, None) + cls.member = cls.cluster.start_member() + + @classmethod + def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + async def asyncTearDown(self) -> None: + await self.shutdown_all_clients() + + +class CompactTest(CompactTestBase): + async def test_write_then_read_with_all_fields(self): + serializer = SomeFieldsSerializer.from_kinds(FIELD_KINDS) + await self._write_then_read(FIELD_KINDS, REFERENCE_OBJECTS, serializer) + + async def test_write_then_read_with_no_fields(self): + serializer = SomeFieldsSerializer.from_kinds([]) + await self._write_then_read([], {}, serializer) + + async def test_write_then_read_with_just_var_sized_fields(self): + serializer = SomeFieldsSerializer.from_kinds(VAR_SIZED_FIELD_KINDS) + await self._write_then_read(VAR_SIZED_FIELD_KINDS, REFERENCE_OBJECTS, serializer) + + async def test_write_then_read_with_just_fix_sized_fields(self): + serializer = SomeFieldsSerializer.from_kinds(FIX_SIZED_FIELD_KINDS) + await self._write_then_read(FIX_SIZED_FIELD_KINDS, REFERENCE_OBJECTS, serializer) + + async def test_write_then_read_object_with_different_position_readers(self): + params = [ + ("uint8_reader", 1), + ("uint16_reader", 20), + ("int32_reader", 42), + ] + for name, array_item_count in params: + with self.subTest(name, array_item_count=array_item_count): + reference_objects = { + FieldKind.ARRAY_OF_STRING: [ + "x" * (i * 100) for i in range(1, array_item_count) + ], + FieldKind.INT32: 32, + FieldKind.STRING: "hey", + } + reference_objects[FieldKind.ARRAY_OF_STRING].append(None) + serializer = SomeFieldsSerializer.from_kinds(list(reference_objects.keys())) + await self._write_then_read( + list(reference_objects.keys()), reference_objects, serializer + ) + + async def test_write_then_read_boolean_array(self): + params = [ + ("0", 0), + ("1", 1), + ("8", 8), + ("10", 10), + ("100", 100), + ("1000", 1000), + ] + for name, item_count in params: + with self.subTest(name, item_count=item_count): + reference_objects = { + FieldKind.ARRAY_OF_BOOLEAN: [ + random.randrange(0, 10) % 2 == 0 for _ in range(item_count) + ] + } + serializer = SomeFieldsSerializer.from_kinds(list(reference_objects.keys())) + await self._write_then_read( + list(reference_objects.keys()), reference_objects, serializer + ) + + async def test_write_and_read_with_multiple_boolean_fields(self): + params = [ + ("0", 0), + ("1", 1), + ("8", 8), + ("10", 10), + ("100", 100), + ("1000", 1000), + ] + for name, field_count in params: + with self.subTest(name, field_count=field_count): + all_fields = {str(i): random.randrange(0, 2) % 2 == 0 for i in range(field_count)} + + class 
Serializer(CompactSerializer[SomeFields]): + def __init__(self, field_names: typing.List[str]): + self._field_names = field_names + + def read(self, reader: CompactReader) -> SomeFields: + fields = {} + for field_name in self._field_names: + fields[field_name] = reader.read_boolean(field_name) + + return SomeFields(**fields) + + def write(self, writer: CompactWriter, obj: SomeFields) -> None: + for field_name in self._field_names: + writer.write_boolean(field_name, getattr(obj, field_name)) + + def get_type_name(self) -> str: + return SomeFields.__name__ + + def get_class(self) -> typing.Type[SomeFields]: + return SomeFields + + await self._write_then_read0(all_fields, Serializer(list(all_fields.keys()))) + + async def test_write_then_read(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + m = await self._put_entry( + map_name=random_string(), + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + ) + obj = await m.get("key") + self.assertTrue(is_equal(REFERENCE_OBJECTS[field_kind], getattr(obj, field_name))) + + async def test_write_none_then_read(self): + params = [(field_kind.name, field_kind) for field_kind in VAR_SIZED_FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + m = await self._put_entry( + map_name=random_string(), + value_to_put=None, + field_name=field_name, + ) + obj = await m.get("key") + self.assertIsNone(getattr(obj, field_name)) + + async def test_write_array_with_none_items_then_read(self): + params = [ + (field_kind.name, field_kind) for field_kind in ARRAY_FIELD_KINDS_WITH_NULLABLE_ITEMS + ] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + value = [None] + REFERENCE_OBJECTS[field_kind] + [None] + value.insert(2, None) + m = await self._put_entry( + map_name=random_string(), + value_to_put=value, + field_name=field_name, + ) + obj = await m.get("key") + self.assertTrue(is_equal(value, getattr(obj, field_name))) + + async def test_read_when_field_does_not_exist(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + map_name = random_string() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + ) + + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer( + [ + FieldDefinition( + name=field_name, + name_to_read="not-a-field", + reader_method_name=f"read_{field_name}", + ) + ] + ), + NestedSerializer(), + ], + } + ) + + evolved_m = await client.get_map(map_name) + with self.assertRaisesRegex(HazelcastSerializationError, "No field with the name"): + await evolved_m.get("key") + + async def test_read_with_type_mismatch(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + map_name = random_string() + mismatched_field_kind = 
FIELD_KINDS[(field_kind.value + 1) % len(FIELD_KINDS)] + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=REFERENCE_OBJECTS[mismatched_field_kind], + field_name=field_name, + writer_method_name=f"write_{mismatched_field_kind.name.lower()}", + ) + + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([FieldDefinition(name=field_name)]), + NestedSerializer(), + ], + } + ) + + m = await client.get_map(map_name) + with self.assertRaisesRegex(HazelcastSerializationError, "Mismatched field types"): + await m.get("key") + + async def test_write_then_read_as_nullable(self): + params = [ + (field_kind.name, field_kind, nullable_field_kind) + for field_kind, nullable_field_kind in itertools.chain( + FIX_SIZED_TO_NULLABLE.items(), + FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY.items(), + ) + ] + if not params: + self.skipTest("empty") + for name, field_kind, nullable_field_kind in params: + with self.subTest(name, field_kind=field_kind, nullable_field_kind=nullable_field_kind): + map_name = random_string() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + ) + nullable_method_suffix = nullable_field_kind.name.lower() + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer( + [ + FieldDefinition( + name=field_name, + reader_method_name=f"read_{nullable_method_suffix}", + ) + ] + ), + ], + } + ) + + m = await client.get_map(map_name) + obj = await m.get("key") + self.assertTrue(is_equal(REFERENCE_OBJECTS[field_kind], getattr(obj, field_name))) + + async def test_write_as_nullable_then_read(self): + params = [ + (field_kind.name, field_kind, nullable_field_kind) + for field_kind, nullable_field_kind in itertools.chain( + FIX_SIZED_TO_NULLABLE.items(), + FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY.items(), + ) + ] + if not params: + self.skipTest("empty") + for name, field_kind, nullable_field_kind in params: + with self.subTest(name, field_kind=field_kind, nullable_field_kind=nullable_field_kind): + map_name = random_string() + nullable_method_suffix = nullable_field_kind.name.lower() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + writer_method_name=f"write_{nullable_method_suffix}", + ) + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([FieldDefinition(name=field_name)]), + ], + } + ) + + m = await client.get_map(map_name) + obj = await m.get("key") + self.assertTrue(is_equal(REFERENCE_OBJECTS[field_kind], getattr(obj, field_name))) + + async def test_write_nullable_fix_sized_as_none_then_read_as_fix_sized(self): + params = [ + (field_kind.name, field_kind, nullable_field_kind) + for field_kind, nullable_field_kind in FIX_SIZED_TO_NULLABLE.items() + ] + if not params: + self.skipTest("empty") + for name, field_kind, nullable_field_kind in params: + with self.subTest(name, field_kind=field_kind, nullable_field_kind=nullable_field_kind): + map_name = random_string() + nullable_method_suffix = nullable_field_kind.name.lower() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=None, + field_name=field_name, + writer_method_name=f"write_{nullable_method_suffix}", + ) + client = await 
self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([FieldDefinition(name=field_name)]), + ], + } + ) + + m = await client.get_map(map_name) + with self.assertRaisesRegex( + HazelcastSerializationError, "A 'None' value cannot be read" + ): + await m.get("key") + + async def test_write_nullable_fix_sized_array_with_none_item_then_read_as_fix_sized_array(self): + params = [ + (field_kind.name, field_kind, nullable_field_kind) + for field_kind, nullable_field_kind in FIX_SIZED_ARRAY_TO_NULLABLE_FIX_SIZED_ARRAY.items() + ] + if not params: + self.skipTest("empty") + for name, field_kind, nullable_field_kind in params: + with self.subTest(name, field_kind=field_kind, nullable_field_kind=nullable_field_kind): + map_name = random_string() + nullable_method_suffix = nullable_field_kind.name.lower() + field_name = field_kind.name.lower() + await self._put_entry( + map_name=map_name, + value_to_put=[None], + field_name=field_name, + writer_method_name=f"write_{nullable_method_suffix}", + ) + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([FieldDefinition(name=field_name)]), + ], + } + ) + m = await client.get_map(map_name) + with self.assertRaisesRegex( + HazelcastSerializationError, "A `None` item cannot be read" + ): + await m.get("key") + + async def test_write_then_read_with_default_value(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + m = await self._put_entry( + map_name=random_string(), + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + reader_method_name=f"read_{field_name}_or_default", + default_value_to_read=object(), + ) + obj = await m.get("key") + self.assertTrue(is_equal(REFERENCE_OBJECTS[field_kind], getattr(obj, field_name))) + + async def test_write_then_read_with_default_value_when_field_name_does_not_match(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + default_value = object() + m = await self._put_entry( + map_name=random_string(), + value_to_put=REFERENCE_OBJECTS[field_kind], + field_name=field_name, + field_name_to_read="not-a-field", + reader_method_name=f"read_{field_name}_or_default", + default_value_to_read=default_value, + ) + obj = await m.get("key") + self.assertTrue(getattr(obj, field_name) is default_value) + + async def test_write_then_read_with_default_value_when_field_type_does_not_match(self): + params = [(field_kind.name, field_kind) for field_kind in FIELD_KINDS] + if not params: + self.skipTest("empty") + for name, field_kind in params: + with self.subTest(name, field_kind=field_kind): + field_name = field_kind.name.lower() + mismatched_field_kind = FIELD_KINDS[(field_kind.value + 1) % len(FIELD_KINDS)] + default_value = object() + m = await self._put_entry( + map_name=random_string(), + value_to_put=REFERENCE_OBJECTS[mismatched_field_kind], + field_name=field_name, + field_name_to_read=field_name, + writer_method_name=f"write_{mismatched_field_kind.name.lower()}", + reader_method_name=f"read_{field_name}_or_default", + default_value_to_read=default_value, + ) + obj = await m.get("key") + self.assertTrue(getattr(obj, 
field_name) is default_value) + + async def _put_entry( + self, + *, + map_name: str, + value_to_put: typing.Any, + field_name: str, + field_name_to_read=None, + writer_method_name=None, + reader_method_name=None, + default_value_to_read=None, + ): + field_definition = FieldDefinition( + name=field_name, + name_to_read=field_name_to_read or field_name, + writer_method_name=writer_method_name or f"write_{field_name}", + reader_method_name=reader_method_name or f"read_{field_name}", + default_value_to_read=default_value_to_read, + ) + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [ + SomeFieldsSerializer([field_definition]), + NestedSerializer(), + ], + } + ) + + m = await client.get_map(map_name) + await m.put("key", SomeFields(**{field_name: value_to_put})) + return m + + async def _write_then_read( + self, + kinds: typing.List[FieldKind], + reference_objects: typing.Dict[FieldKind, typing.Any], + serializer: CompactSerializer, + ): + fields = {kind.name.lower(): reference_objects[kind] for kind in kinds} + await self._write_then_read0(fields, serializer) + + async def _write_then_read0( + self, fields: typing.Dict[str, typing.Any], serializer: CompactSerializer + ): + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [serializer, NestedSerializer()], + } + ) + + m = await client.get_map(random_string()) + await m.put("key", SomeFields(**fields)) + obj = await m.get("key") + for name, value in fields.items(): + self.assertTrue(is_equal(value, getattr(obj, name))) + + +class CompactSchemaEvolutionTest(CompactTestBase): + async def test_adding_a_fix_sized_field(self): + await self._verify_adding_a_field( + ("int32", 42), + ("string", "42"), + new_field_name="int64", + new_field_value=24, + new_field_default_value=12, + ) + + async def test_removing_a_fix_sized_field(self): + await self._verify_removing_a_field( + ("int64", 1234), + ("string", "hey"), + removed_field_name="int64", + removed_field_default_value=43321, + ) + + async def test_adding_a_var_sized_field(self): + await self._verify_adding_a_field( + ("int32", 42), + ("string", "42"), + new_field_name="array_of_boolean", + new_field_value=[True, False, True], + new_field_default_value=[False, False, False, True], + ) + + async def test_removing_a_var_sized_field(self): + await self._verify_removing_a_field( + ("int64", 1234), + ("string", "hey"), + removed_field_name="string", + removed_field_default_value="43321", + ) + + async def _create_client(self, serializer: CompactSerializer): + return await self.create_client( + { + "cluster_name": self.cluster.id, + "compact_serializers": [serializer], + } + ) + + async def _verify_adding_a_field( + self, + *existing_fields: typing.Tuple[str, typing.Any], + new_field_name: str, + new_field_value: typing.Any, + new_field_default_value: typing.Any, + ): + map_name = random_string() + v1_field_definitions = [FieldDefinition(name=name) for name, _ in existing_fields] + v1_serializer = SomeFieldsSerializer(v1_field_definitions) + v1_client = await self._create_client(v1_serializer) + v1_map = await v1_client.get_map(map_name) + v1_fields = {name: value for name, value in existing_fields} + await v1_map.put("key1", SomeFields(**v1_fields)) + + v2_field_definitions = v1_field_definitions + [FieldDefinition(name=new_field_name)] + v2_serializer = SomeFieldsSerializer(v2_field_definitions) + v2_client = await self._create_client(v2_serializer) + v2_map = await v2_client.get_map(map_name) + v2_fields 
= copy.deepcopy(v1_fields) + v2_fields[new_field_name] = new_field_value + await v2_map.put("key2", SomeFields(**v2_fields)) + + careful_v2_field_definitions = v1_field_definitions + [ + FieldDefinition( + name=new_field_name, + reader_method_name=f"read_{new_field_name}_or_default", + default_value_to_read=new_field_default_value, + ) + ] + careful_v2_serializer = SomeFieldsSerializer(careful_v2_field_definitions) + careful_client_v2 = await self._create_client(careful_v2_serializer) + careful_v2_map = await careful_client_v2.get_map(map_name) + + # Old client can read data written by the new client + v1_obj = await v1_map.get("key2") + for name in v1_fields: + self.assertEqual(v2_fields[name], getattr(v1_obj, name)) + + # New client cannot read data written by the old client, since + # there is no such field on the old data. + + with self.assertRaisesRegex(HazelcastSerializationError, "No field with the name"): + await v2_map.get("key1") + + # However, if it has default value, everything should work + + careful_v2_obj = await careful_v2_map.get("key1") + for name in v2_fields: + self.assertEqual( + v1_fields.get(name) or new_field_default_value, + getattr(careful_v2_obj, name), + ) + + async def _verify_removing_a_field( + self, + *existing_fields: typing.Tuple[str, typing.Any], + removed_field_name: str, + removed_field_default_value: typing.Any, + ): + map_name = random_string() + v1_field_definitions = [FieldDefinition(name=name) for name, _ in existing_fields] + v1_serializer = SomeFieldsSerializer(v1_field_definitions) + v1_client = await self._create_client(v1_serializer) + v1_map = await v1_client.get_map(map_name) + v1_fields = {name: value for name, value in existing_fields} + await v1_map.put("key1", SomeFields(**v1_fields)) + + v2_field_definitions = [ + FieldDefinition(name=name) for name, _ in existing_fields if name != removed_field_name + ] + v2_serializer = SomeFieldsSerializer(v2_field_definitions) + v2_client = await self._create_client(v2_serializer) + v2_map = await v2_client.get_map(map_name) + v2_fields = copy.deepcopy(v1_fields) + del v2_fields[removed_field_name] + await v2_map.put("key2", SomeFields(**v2_fields)) + + careful_v1_field_definitions = v2_field_definitions + [ + FieldDefinition( + name=removed_field_name, + reader_method_name=f"read_{removed_field_name}_or_default", + default_value_to_read=removed_field_default_value, + ) + ] + careful_v1_serializer = SomeFieldsSerializer(careful_v1_field_definitions) + careful_client_v1 = await self._create_client(careful_v1_serializer) + careful_v1_map = await careful_client_v1.get_map(map_name) + + # Old client cannot read data written by the new client, since + # there is no such field on the new data + + with self.assertRaisesRegex(HazelcastSerializationError, "No field with the name"): + await v1_map.get("key2") + + # However, if it has default value, everything should work + v1_obj = await careful_v1_map.get("key2") + for name in v1_fields: + self.assertEqual( + v2_fields.get(name) or removed_field_default_value, + getattr(v1_obj, name), + ) + + # New client can read data written by the old client + v2_obj = await v2_map.get("key1") + for name in v2_fields: + self.assertEqual(v1_fields[name], getattr(v2_obj, name)) + + with self.assertRaises(AttributeError): + getattr(v2_obj, removed_field_name) # no such field for the new schema + + +class CompactOnClusterRestartTest(CompactTestBase): + async def test_cluster_restart(self): + client = await self.create_client( + { + "cluster_name": self.cluster.id, + 
"compact_serializers": [SomeFieldsSerializer([FieldDefinition(name="int32")])], + } + ) + m = await client.get_map(random_string()) + await m.put(1, SomeFields(int32=42)) + # self.rc.terminateMember(self.cluster.id, self.member.uuid) + # CompactOnClusterRestartTest.member = self.cluster.start_member() + await asyncio.to_thread(self._restart) + await m.put(1, SomeFields(int32=42)) + obj = await m.get(1) + self.assertEqual(42, obj.int32) + # Perform a query to make sure that the schema is available on the cluster + self.assertEqual(1, len(await m.values(sql("int32 == 42")))) + + def _restart(self): + self.rc.terminateMember(self.cluster.id, self.member.uuid) + CompactOnClusterRestartTest.member = self.cluster.start_member() + + +class CompactWithListenerTest(CompactTestBase): + async def test_map_listener(self): + config = { + "cluster_name": self.cluster.id, + "compact_serializers": [SomeFieldsSerializer([FieldDefinition(name="int32")])], + } + client = await self.create_client(config) + map_name = random_string() + m = await client.get_map(map_name) + counter = AtomicInteger() + + def listener(_): + counter.add(1) + + await m.add_entry_listener(include_value=True, added_func=listener) + # Put the entry from other client to not create a local + # registry in the actual client. This will force it to + # go the cluster to fetch the schema. + other_client = await self.create_client(config) + other_client_map = await other_client.get_map(map_name) + await other_client_map.put(1, SomeFields(int32=42)) + await self.assertTrueEventually(lambda: self.assertEqual(1, counter.get())) + + +class SomeFields: + def __init__(self, **fields): + self._fields = fields + + def __getattr__(self, item): + if item not in self._fields: + raise AttributeError() + + return self._fields[item] + + +class Nested: + def __init__(self, i32_field, string_field): + self.i32_field = i32_field + self.string_field = string_field + + def __eq__(self, other): + return ( + isinstance(other, Nested) + and self.i32_field == other.i32_field + and self.string_field == other.string_field + ) + + +class NestedSerializer(CompactSerializer[Nested]): + def read(self, reader: CompactReader) -> Nested: + return Nested(reader.read_int32("i32_field"), reader.read_string("string_field")) + + def write(self, writer: CompactWriter, obj: Nested) -> None: + writer.write_int32("i32_field", obj.i32_field) + writer.write_string("string_field", obj.string_field) + + def get_type_name(self) -> str: + return Nested.__name__ + + def get_class(self) -> typing.Type[Nested]: + return Nested + + +class FieldDefinition: + def __init__( + self, + *, + name: str, + name_to_read: str = None, + writer_method_name: str = None, + reader_method_name: str = None, + default_value_to_read: typing.Any = None, + ): + self.name = name + self.name_to_read = name_to_read or name + self.writer_method_name = writer_method_name or f"write_{name}" + self.reader_method_name = reader_method_name or f"read_{name}" + self.default_value_to_read = default_value_to_read + + +class SomeFieldsSerializer(CompactSerializer[SomeFields]): + def __init__(self, field_definitions: typing.List[FieldDefinition]): + self._field_definitions = field_definitions + + def read(self, reader: CompactReader) -> SomeFields: + fields = {} + for field_definition in self._field_definitions: + reader_parameters = [field_definition.name_to_read] + default_value_to_read = field_definition.default_value_to_read + if default_value_to_read is not None: + reader_parameters.append(default_value_to_read) + + value 
+
+
+class SomeFieldsSerializer(CompactSerializer[SomeFields]):
+    def __init__(self, field_definitions: typing.List[FieldDefinition]):
+        self._field_definitions = field_definitions
+
+    def read(self, reader: CompactReader) -> SomeFields:
+        fields = {}
+        for field_definition in self._field_definitions:
+            reader_parameters = [field_definition.name_to_read]
+            default_value_to_read = field_definition.default_value_to_read
+            if default_value_to_read is not None:
+                reader_parameters.append(default_value_to_read)
+
+            value = getattr(reader, field_definition.reader_method_name)(*reader_parameters)
+            fields[field_definition.name] = value
+
+        return SomeFields(**fields)
+
+    def write(self, writer: CompactWriter, obj: SomeFields) -> None:
+        for field_definition in self._field_definitions:
+            getattr(writer, field_definition.writer_method_name)(
+                field_definition.name,
+                getattr(obj, field_definition.name),
+            )
+
+    def get_type_name(self) -> str:
+        return SomeFields.__name__
+
+    def get_class(self) -> typing.Type[SomeFields]:
+        return SomeFields
+
+    @staticmethod
+    def from_kinds(kinds: typing.List[FieldKind]) -> "SomeFieldsSerializer":
+        field_definitions = [FieldDefinition(name=kind.name.lower()) for kind in kinds]
+        return SomeFieldsSerializer(field_definitions)
+
+
+if _COMPACT_AVAILABLE:
+    REFERENCE_OBJECTS = {
+        FieldKind.BOOLEAN: True,
+        FieldKind.ARRAY_OF_BOOLEAN: [True, False, True, True, True, False, True, True, False],
+        FieldKind.INT8: 42,
+        FieldKind.ARRAY_OF_INT8: [42, -128, -1, 127],
+        FieldKind.INT16: -456,
+        FieldKind.ARRAY_OF_INT16: [-4231, 12343, 0],
+        FieldKind.INT32: 21212121,
+        FieldKind.ARRAY_OF_INT32: [-1, 1, 0, 9999999],
+        FieldKind.INT64: 123456789,
+        FieldKind.ARRAY_OF_INT64: [11, -123456789],
+        FieldKind.FLOAT32: 12.5,
+        FieldKind.ARRAY_OF_FLOAT32: [-13.13, 12345.67, 0.1, 9876543.2, -99999.99],
+        FieldKind.FLOAT64: 12345678.90123,
+        FieldKind.ARRAY_OF_FLOAT64: [-12345.67],
+        FieldKind.STRING: "üğişçöa",
+        FieldKind.ARRAY_OF_STRING: ["17", "😊 😇 🙂", "abc"],
+        FieldKind.DECIMAL: decimal.Decimal("123.456"),
+        FieldKind.ARRAY_OF_DECIMAL: [decimal.Decimal("0"), decimal.Decimal("-123456.789")],
+        FieldKind.TIME: datetime.time(2, 3, 4, 5),
+        FieldKind.ARRAY_OF_TIME: [datetime.time(8, 7, 6, 5)],
+        FieldKind.DATE: datetime.date(2022, 1, 1),
+        FieldKind.ARRAY_OF_DATE: [datetime.date(2021, 11, 11), datetime.date(2020, 3, 3)],
+        FieldKind.TIMESTAMP: datetime.datetime(2022, 2, 2, 3, 3, 3, 4),
+        FieldKind.ARRAY_OF_TIMESTAMP: [datetime.datetime(1990, 2, 12, 13, 14, 54, 98765)],
+        FieldKind.TIMESTAMP_WITH_TIMEZONE: datetime.datetime(
+            200, 10, 10, 16, 44, 42, 12345, datetime.timezone(datetime.timedelta(hours=2))
+        ),
+        FieldKind.ARRAY_OF_TIMESTAMP_WITH_TIMEZONE: [
+            datetime.datetime(
+                2001, 1, 10, 12, 24, 2, 45, datetime.timezone(datetime.timedelta(hours=-2))
+            )
+        ],
+        FieldKind.COMPACT: Nested(42, "42"),
+        FieldKind.ARRAY_OF_COMPACT: [Nested(-42, "-42"), Nested(123, "123")],
+        FieldKind.NULLABLE_BOOLEAN: False,
+        FieldKind.ARRAY_OF_NULLABLE_BOOLEAN: [False, False, True],
+        FieldKind.NULLABLE_INT8: 34,
+        FieldKind.ARRAY_OF_NULLABLE_INT8: [-32, 32],
+        FieldKind.NULLABLE_INT16: 36,
+        FieldKind.ARRAY_OF_NULLABLE_INT16: [37, -37, 0, 12345],
+        FieldKind.NULLABLE_INT32: -38,
+        FieldKind.ARRAY_OF_NULLABLE_INT32: [-39, 2134567, -8765432, 39],
+        FieldKind.NULLABLE_INT64: -4040,
+        FieldKind.ARRAY_OF_NULLABLE_INT64: [1, 41, -1, 12312312312, -9312912391],
+        FieldKind.NULLABLE_FLOAT32: 42.4,
+        FieldKind.ARRAY_OF_NULLABLE_FLOAT32: [-43.4, 434.43],
+        FieldKind.NULLABLE_FLOAT64: 44.12,
+        FieldKind.ARRAY_OF_NULLABLE_FLOAT64: [45.678, -4567.8, 0.12345],
+    }
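
For completeness, a round trip over these reference values only needs a list of kinds and the matching generated serializer. The test class below is a hypothetical usage sketch and assumes the _write_then_read helper shown earlier is available on the test base class used here.

    class CompactInt32StringRoundTripTest(CompactTestBase):
        async def test_round_trip(self):
            # Hypothetical: round-trip the reference values for two kinds
            # through a serializer generated from the same kinds.
            kinds = [FieldKind.INT32, FieldKind.STRING]
            await self._write_then_read(
                kinds,
                REFERENCE_OBJECTS,
                SomeFieldsSerializer.from_kinds(kinds),
            )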