diff --git a/Makefile b/Makefile index d440d03503..e1032e5525 100644 --- a/Makefile +++ b/Makefile @@ -5,10 +5,10 @@ check: black --check --config black.toml . test: - pytest -m "not enterprise" + pytest --verbose -m "not enterprise" test-enterprise: - pytest + pytest --verbose test-cover: pytest --cov=hazelcast --cov-report=xml diff --git a/hazelcast/asyncio/client.py b/hazelcast/asyncio/client.py index 0f6db252fe..6920eef361 100644 --- a/hazelcast/asyncio/client.py +++ b/hazelcast/asyncio/client.py @@ -5,13 +5,14 @@ from hazelcast.internal.asyncio_cluster import ClusterService, _InternalClusterService from hazelcast.internal.asyncio_compact import CompactSchemaService -from hazelcast.config import Config +from hazelcast.config import Config, IndexConfig from hazelcast.internal.asyncio_connection import ConnectionManager, DefaultAddressProvider from hazelcast.core import DistributedObjectEvent, DistributedObjectInfo from hazelcast.cp import CPSubsystem, ProxySessionManager from hazelcast.discovery import HazelcastCloudAddressProvider from hazelcast.errors import IllegalStateError, InvalidConfigurationError from hazelcast.internal.asyncio_invocation import InvocationService, Invocation +from hazelcast.internal.asyncio_proxy.vector_collection import VectorCollection from hazelcast.lifecycle import LifecycleService, LifecycleState, _InternalLifecycleService from hazelcast.internal.asyncio_listener import ClusterViewListenerService, ListenerService from hazelcast.near_cache import NearCacheManager @@ -20,17 +21,19 @@ client_add_distributed_object_listener_codec, client_get_distributed_objects_codec, client_remove_distributed_object_listener_codec, + dynamic_config_add_vector_collection_config_codec, ) from hazelcast.internal.asyncio_proxy.manager import ( MAP_SERVICE, ProxyManager, + VECTOR_SERVICE, ) from hazelcast.internal.asyncio_proxy.base import Proxy from hazelcast.internal.asyncio_proxy.map import Map from hazelcast.internal.asyncio_reactor import 
AsyncioReactor from hazelcast.serialization import SerializationServiceV1 from hazelcast.sql import SqlService, _InternalSqlService -from hazelcast.statistics import Statistics +from hazelcast.internal.asyncio_statistics import Statistics from hazelcast.types import KeyType, ValueType, ItemType, MessageType from hazelcast.util import AtomicInteger, RoundRobinLB @@ -162,7 +165,6 @@ def _init_context(self): ) async def _start(self): - self._reactor.start() try: self._internal_lifecycle_service.start() self._invocation_service.start() @@ -177,7 +179,7 @@ async def _start(self): self._listener_service.start() await self._invocation_service.add_backup_listener() self._load_balancer.init(self._cluster_service) - self._statistics.start() + await self._statistics.start() except Exception: await self.shutdown() raise @@ -186,6 +188,37 @@ async def _start(self): async def get_map(self, name: str) -> Map[KeyType, ValueType]: return await self._proxy_manager.get_or_create(MAP_SERVICE, name) + async def create_vector_collection_config( + self, + name: str, + indexes: typing.List[IndexConfig], + backup_count: int = 1, + async_backup_count: int = 0, + split_brain_protection_name: typing.Optional[str] = None, + merge_policy: str = "PutIfAbsentMergePolicy", + merge_batch_size: int = 100, + ) -> None: + # check that indexes have different names + if indexes: + index_names = set(index.name for index in indexes) + if len(index_names) != len(indexes): + raise AssertionError("index names must be unique") + + request = dynamic_config_add_vector_collection_config_codec.encode_request( + name, + indexes, + backup_count, + async_backup_count, + split_brain_protection_name, + merge_policy, + merge_batch_size, + ) + invocation = Invocation(request, response_handler=lambda m: m) + await self._invocation_service.ainvoke(invocation) + + async def get_vector_collection(self, name: str) -> VectorCollection: + return await self._proxy_manager.get_or_create(VECTOR_SERVICE, name) + async def 
add_distributed_object_listener( self, listener_func: typing.Callable[[DistributedObjectEvent], None] ) -> str: @@ -250,7 +283,6 @@ async def shutdown(self) -> None: await self._connection_manager.shutdown() self._invocation_service.shutdown() self._statistics.shutdown() - self._reactor.shutdown() self._internal_lifecycle_service.fire_lifecycle_event(LifecycleState.SHUTDOWN) @property diff --git a/hazelcast/internal/asyncio_connection.py b/hazelcast/internal/asyncio_connection.py index f747a53110..27f157547b 100644 --- a/hazelcast/internal/asyncio_connection.py +++ b/hazelcast/internal/asyncio_connection.py @@ -185,6 +185,9 @@ def __init__( self._use_public_ip = ( isinstance(address_provider, DefaultAddressProvider) and config.use_public_ip ) + # asyncio tasks are weakly referenced + # storing tasks here in order not to lose them midway + self._tasks = set() def add_listener(self, on_connection_opened=None, on_connection_closed=None): """Registers a ConnectionListener. @@ -315,22 +318,21 @@ async def on_connection_close(self, closed_connection): disconnected = False removed = False trigger_reconnection = False - async with self._lock: - connection = self.active_connections.get(remote_uuid, None) - if connection == closed_connection: - self.active_connections.pop(remote_uuid, None) - removed = True - _logger.info( - "Removed connection to %s:%s, connection: %s", - remote_address, - remote_uuid, - connection, - ) + connection = self.active_connections.get(remote_uuid, None) + if connection == closed_connection: + self.active_connections.pop(remote_uuid, None) + removed = True + _logger.info( + "Removed connection to %s:%s, connection: %s", + remote_address, + remote_uuid, + connection, + ) - if not self.active_connections: - trigger_reconnection = True - if self._client_state == ClientState.INITIALIZED_ON_CLUSTER: - disconnected = True + if not self.active_connections: + trigger_reconnection = True + if self._client_state == ClientState.INITIALIZED_ON_CLUSTER: + 
disconnected = True if disconnected: self._lifecycle_service.fire_lifecycle_event(LifecycleState.DISCONNECTED) @@ -809,6 +811,9 @@ def __init__(self, connection_manager, client, config, reactor, invocation_servi self._heartbeat_timeout = config.heartbeat_timeout self._heartbeat_interval = config.heartbeat_interval self._heartbeat_task: asyncio.Task | None = None + # asyncio tasks are weakly referenced + # storing tasks here in order not to lose them midway + self._tasks = set() def start(self): """Starts sending periodic HeartBeat operations.""" @@ -848,7 +853,9 @@ async def _check_connection(self, now, connection): if (now - connection.last_write_time) > self._heartbeat_interval: request = client_ping_codec.encode_request() invocation = Invocation(request, connection=connection, urgent=True) - asyncio.create_task(self._invocation_service.ainvoke(invocation)) + task = asyncio.create_task(self._invocation_service.ainvoke(invocation)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) _frame_header = struct.Struct(" bool: ``True`` if this proxy is destroyed successfully, ``False`` otherwise. 
""" - self._on_destroy() - return await self._context.proxy_manager.destroy_proxy(self.service_name, self.name) + async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] + tg.create_task(self._on_destroy()) + return await tg.create_task( + self._context.proxy_manager.destroy_proxy(self.service_name, self.name) + ) - def _on_destroy(self): + async def _on_destroy(self): pass def __repr__(self) -> str: diff --git a/hazelcast/internal/asyncio_proxy/manager.py b/hazelcast/internal/asyncio_proxy/manager.py index 6bf635bcfc..9daeca0e1f 100644 --- a/hazelcast/internal/asyncio_proxy/manager.py +++ b/hazelcast/internal/asyncio_proxy/manager.py @@ -1,5 +1,9 @@ import typing +from hazelcast.internal.asyncio_proxy.vector_collection import ( + VectorCollection, + create_vector_collection_proxy, +) from hazelcast.protocol.codec import client_create_proxy_codec, client_destroy_proxy_codec from hazelcast.internal.asyncio_invocation import Invocation from hazelcast.internal.asyncio_proxy.base import Proxy @@ -7,9 +11,14 @@ from hazelcast.util import to_list MAP_SERVICE = "hz:impl:mapService" +VECTOR_SERVICE = "hz:service:vector" -_proxy_init: typing.Dict[str, typing.Callable[[str, str, typing.Any], Proxy]] = { +_proxy_init: typing.Dict[ + str, + typing.Callable[[str, str, typing.Any], typing.Coroutine[typing.Any, typing.Any, typing.Any]], +] = { MAP_SERVICE: create_map_proxy, + VECTOR_SERVICE: create_vector_collection_proxy, } @@ -34,7 +43,7 @@ async def _create_proxy(self, service_name, name, create_on_remote) -> Proxy: invocation_service = self._context.invocation_service await invocation_service.ainvoke(invocation) - return _proxy_init[service_name](service_name, name, self._context) + return await _proxy_init[service_name](service_name, name, self._context) async def destroy_proxy(self, service_name, name, destroy_on_remote=True): ns = (service_name, name) diff --git a/hazelcast/internal/asyncio_proxy/map.py b/hazelcast/internal/asyncio_proxy/map.py index 
9f2f765ec1..913bd4ef34 100644 --- a/hazelcast/internal/asyncio_proxy/map.py +++ b/hazelcast/internal/asyncio_proxy/map.py @@ -5,7 +5,6 @@ from hazelcast.aggregator import Aggregator from hazelcast.config import IndexUtil, IndexType, IndexConfig from hazelcast.core import SimpleEntryView -from hazelcast.errors import InvalidConfigurationError from hazelcast.projection import Projection from hazelcast.protocol import PagingPredicateHolder from hazelcast.protocol.codec import ( @@ -65,6 +64,7 @@ map_set_with_max_idle_codec, map_remove_interceptor_codec, map_remove_all_codec, + map_add_near_cache_invalidation_listener_codec, ) from hazelcast.internal.asyncio_proxy.base import ( Proxy, @@ -971,8 +971,177 @@ def handler(message): return self._invoke_on_key(request, key_data, handler) -def create_map_proxy(service_name, name, context): +class MapFeatNearCache(Map[KeyType, ValueType]): + """Map proxy implementation featuring Near Cache""" + + def __init__(self, service_name, name, context): + super(MapFeatNearCache, self).__init__(service_name, name, context) + self._invalidation_listener_id = None + self._near_cache = context.near_cache_manager.get_or_create_near_cache(name) + + async def clear(self): + self._near_cache._clear() + return await super(MapFeatNearCache, self).clear() + + async def evict_all(self): + self._near_cache.clear() + return await super(MapFeatNearCache, self).evict_all() + + async def load_all(self, keys=None, replace_existing_values=True): + if keys is None and replace_existing_values: + self._near_cache.clear() + return await super(MapFeatNearCache, self).load_all(keys, replace_existing_values) + + async def _on_destroy(self): + await self._remove_near_cache_invalidation_listener() + self._near_cache.clear() + await super(MapFeatNearCache, self)._on_destroy() + + async def _add_near_cache_invalidation_listener(self): + codec = map_add_near_cache_invalidation_listener_codec + request = codec.encode_request(self.name, EntryEventType.INVALIDATION, 
self._is_smart) + self._invalidation_listener_id = await self._register_listener( + request, + lambda r: codec.decode_response(r), + lambda reg_id: map_remove_entry_listener_codec.encode_request(self.name, reg_id), + lambda m: codec.handle(m, self._handle_invalidation, self._handle_batch_invalidation), + ) + + async def _remove_near_cache_invalidation_listener(self): + if self._invalidation_listener_id: + await self.remove_entry_listener(self._invalidation_listener_id) + + def _handle_invalidation(self, key, source_uuid, partition_uuid, sequence): + # key is always ``Data`` + # null key means near cache has to remove all entries in it. + # see MapAddNearCacheEntryListenerMessageTask. + if key is None: + self._near_cache._clear() + else: + self._invalidate_cache(key) + + def _handle_batch_invalidation(self, keys, source_uuids, partition_uuids, sequences): + # key_list is always list of ``Data`` + for key_data in keys: + self._invalidate_cache(key_data) + + def _invalidate_cache(self, key_data): + self._near_cache._invalidate(key_data) + + def _invalidate_cache_batch(self, key_data_list): + for key_data in key_data_list: + self._near_cache._invalidate(key_data) + + # internals + async def _contains_key_internal(self, key_data): + try: + return self._near_cache[key_data] + except KeyError: + return await super(MapFeatNearCache, self)._contains_key_internal(key_data) + + async def _get_internal(self, key_data): + try: + return self._near_cache[key_data] + except KeyError: + value = await super(MapFeatNearCache, self)._get_internal(key_data) + self._near_cache.__setitem__(key_data, value) + return value + + async def _get_all_internal(self, partition_to_keys, tasks=None): + tasks = tasks or [] + for key_dic in partition_to_keys.values(): + for key in list(key_dic.keys()): + try: + key_data = key_dic[key] + value = self._near_cache[key_data] + future = asyncio.Future() + future.set_result((key, value)) + tasks.append(future) + del key_dic[key] + except KeyError: + pass + 
return await super(MapFeatNearCache, self)._get_all_internal(partition_to_keys, tasks) + + def _try_remove_internal(self, key_data, timeout): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._try_remove_internal(key_data, timeout) + + def _try_put_internal(self, key_data, value_data, timeout): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._try_put_internal(key_data, value_data, timeout) + + def _set_internal(self, key_data, value_data, ttl, max_idle): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._set_internal(key_data, value_data, ttl, max_idle) + + def _set_ttl_internal(self, key_data, ttl): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._set_ttl_internal(key_data, ttl) + + def _replace_internal(self, key_data, value_data): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._replace_internal(key_data, value_data) + + def _replace_if_same_internal(self, key_data, old_value_data, new_value_data): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._replace_if_same_internal( + key_data, old_value_data, new_value_data + ) + + def _remove_internal(self, key_data): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._remove_internal(key_data) + + def _remove_all_internal(self, predicate_data): + self._near_cache.clear() + return super(MapFeatNearCache, self)._remove_all_internal(predicate_data) + + def _remove_if_same_internal_(self, key_data, value_data): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._remove_if_same_internal_(key_data, value_data) + + def _put_transient_internal(self, key_data, value_data, ttl, max_idle): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._put_transient_internal( + key_data, value_data, ttl, max_idle + ) + + def _put_internal(self, key_data, value_data, ttl, max_idle): + self._invalidate_cache(key_data) + return 
super(MapFeatNearCache, self)._put_internal(key_data, value_data, ttl, max_idle) + + def _put_if_absent_internal(self, key_data, value_data, ttl, max_idle): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._put_if_absent_internal( + key_data, value_data, ttl, max_idle + ) + + def _load_all_internal(self, key_data_list, replace_existing_values): + self._invalidate_cache_batch(key_data_list) + return super(MapFeatNearCache, self)._load_all_internal( + key_data_list, replace_existing_values + ) + + def _execute_on_key_internal(self, key_data, entry_processor_data): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._execute_on_key_internal( + key_data, entry_processor_data + ) + + def _evict_internal(self, key_data): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._evict_internal(key_data) + + def _delete_internal(self, key_data): + self._invalidate_cache(key_data) + return super(MapFeatNearCache, self)._delete_internal(key_data) + + +async def create_map_proxy(service_name, name, context): near_cache_config = context.config.near_caches.get(name, None) if near_cache_config is None: return Map(service_name, name, context) - raise InvalidConfigurationError("near cache is not supported") + nc = MapFeatNearCache(service_name, name, context) + if nc._near_cache.invalidate_on_change: + await nc._add_near_cache_invalidation_listener() + return nc diff --git a/hazelcast/internal/asyncio_proxy/vector_collection.py b/hazelcast/internal/asyncio_proxy/vector_collection.py new file mode 100644 index 0000000000..ef12e032af --- /dev/null +++ b/hazelcast/internal/asyncio_proxy/vector_collection.py @@ -0,0 +1,257 @@ +import asyncio +import copy +import typing +import uuid +from typing import Any, Dict, List, Optional, Tuple + +from hazelcast.protocol.codec import ( + vector_collection_set_codec, + vector_collection_get_codec, + vector_collection_search_near_vector_codec, + vector_collection_delete_codec, + 
vector_collection_put_codec, + vector_collection_put_if_absent_codec, + vector_collection_remove_codec, + vector_collection_put_all_codec, + vector_collection_clear_codec, + vector_collection_optimize_codec, + vector_collection_size_codec, +) +from hazelcast.internal.asyncio_proxy.base import Proxy +from hazelcast.serialization.compact import SchemaNotReplicatedError +from hazelcast.serialization.data import Data +from hazelcast.types import KeyType, ValueType +from hazelcast.util import check_not_none +from hazelcast.vector import ( + Document, + SearchResult, + Vector, + VectorType, + VectorSearchOptions, +) + + +class VectorCollection(Proxy, typing.Generic[KeyType, ValueType]): + def __init__(self, service_name, name, context): + super(VectorCollection, self).__init__(service_name, name, context) + + async def get(self, key: Any) -> Document | None: + check_not_none(key, "key can't be None") + return await self._get_internal(key) + + async def set(self, key: Any, document: Document) -> None: + check_not_none(key, "key can't be None") + check_not_none(document, "document can't be None") + check_not_none(document.value, "document value can't be None") + return await self._set_internal(key, document) + + async def put(self, key: Any, document: Document) -> Document | None: + check_not_none(key, "key can't be None") + check_not_none(document, "document can't be None") + check_not_none(document.value, "document value can't be None") + return await self._put_internal(key, document) + + async def put_all(self, map: Dict[Any, Document]) -> None: + check_not_none(map, "map can't be None") + if not map: + return None + partition_service = self._context.partition_service + partition_map: Dict[int, List[Tuple[Data, Document]]] = {} + for key, doc in map.items(): + check_not_none(key, "key can't be None") + check_not_none(doc, "value can't be None") + doc = copy.copy(doc) + try: + entry = (self._to_data(key), doc) + doc.value = self._to_data(doc.value) + except 
SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.put_all, map) + + partition_id = partition_service.get_partition_id(entry[0]) + partition_map.setdefault(partition_id, []).append(entry) + + async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] + for partition_id, entry_list in partition_map.items(): + request = vector_collection_put_all_codec.encode_request(self.name, entry_list) + tg.create_task(self._ainvoke_on_partition(request, partition_id)) + + return None + + async def put_if_absent(self, key: Any, document: Document) -> Document | None: + check_not_none(key, "key can't be None") + check_not_none(document, "document can't be None") + check_not_none(document.value, "document value can't be None") + return await self._put_if_absent_internal(key, document) + + async def search_near_vector( + self, + vector: Vector, + *, + include_value: bool = False, + include_vectors: bool = False, + limit: int = 10, + hints: Dict[str, str] = None + ) -> List[SearchResult]: + check_not_none(vector, "vector can't be None") + if limit <= 0: + raise AssertionError("limit must be positive") + return await self._search_near_vector_internal( + vector, + include_value=include_value, + include_vectors=include_vectors, + limit=limit, + hints=hints, + ) + + async def remove(self, key: Any) -> Document | None: + check_not_none(key, "key can't be None") + return await self._remove_internal(key) + + async def delete(self, key: Any) -> None: + check_not_none(key, "key can't be None") + return await self._delete_internal(key) + + async def optimize(self, index_name: str = None) -> None: + request = vector_collection_optimize_codec.encode_request( + self.name, index_name, uuid.uuid4() + ) + return await self._invoke(request) + + async def clear(self) -> None: + request = vector_collection_clear_codec.encode_request(self.name) + return await self._invoke(request) + + async def size(self) -> int: + request = 
vector_collection_size_codec.encode_request(self.name) + return await self._invoke(request, vector_collection_size_codec.decode_response) + + def _set_internal(self, key: Any, document: Document) -> asyncio.Future[None]: + try: + key_data = self._to_data(key) + value_data = self._to_data(document.value) + except SchemaNotReplicatedError as e: + return self._send_schema_and_retry(e, self.set, key, document) + document = copy.copy(document) + document.value = value_data + request = vector_collection_set_codec.encode_request( + self.name, + key_data, + document, + ) + return self._invoke_on_key(request, key_data) + + def _get_internal(self, key: Any) -> asyncio.Future[Any]: + def handler(message): + doc = vector_collection_get_codec.decode_response(message) + return self._transform_document(doc) + + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return self._send_schema_and_retry(e, self.get, key) + request = vector_collection_get_codec.encode_request( + self.name, + key_data, + ) + return self._invoke_on_key(request, key_data, response_handler=handler) + + def _search_near_vector_internal( + self, + vector: Vector, + *, + include_value: bool = False, + include_vectors: bool = False, + limit: int = 10, + hints: Dict[str, str] = None + ) -> asyncio.Future[List[SearchResult]]: + def handler(message): + results: List[ + SearchResult + ] = vector_collection_search_near_vector_codec.decode_response(message) + for result in results: + if result.key is not None: + result.key = self._to_object(result.key) + if result.value is not None: + result.value = self._to_object(result.value) + if result.vectors: + for vec in result.vectors: + vec.type = VectorType(vec.type) + return results + + options = VectorSearchOptions( + include_value=include_value, + include_vectors=include_vectors, + limit=limit, + hints=hints or {}, + ) + request = vector_collection_search_near_vector_codec.encode_request( + self.name, + [vector], + options, + ) + return 
self._invoke(request, response_handler=handler) + + def _delete_internal(self, key: Any) -> asyncio.Future[None]: + key_data = self._to_data(key) + request = vector_collection_delete_codec.encode_request(self.name, key_data) + return self._invoke_on_key(request, key_data) + + def _remove_internal(self, key: Any) -> asyncio.Future[Document | None]: + def handler(message): + doc = vector_collection_remove_codec.decode_response(message) + return self._transform_document(doc) + + key_data = self._to_data(key) + request = vector_collection_remove_codec.encode_request(self.name, key_data) + return self._invoke_on_key(request, key_data, response_handler=handler) + + def _put_internal(self, key: Any, document: Document) -> asyncio.Future[Document | None]: + def handler(message): + doc = vector_collection_put_codec.decode_response(message) + return self._transform_document(doc) + + try: + key_data = self._to_data(key) + value_data = self._to_data(document.value) + except SchemaNotReplicatedError as e: + return self._send_schema_and_retry(e, self.set, key, document) + document = copy.copy(document) + document.value = value_data + request = vector_collection_put_codec.encode_request( + self.name, + key_data, + document, + ) + return self._invoke_on_key(request, key_data, response_handler=handler) + + def _put_if_absent_internal( + self, key: Any, document: Document + ) -> asyncio.Future[Document | None]: + def handler(message): + doc = vector_collection_put_if_absent_codec.decode_response(message) + return self._transform_document(doc) + + try: + key_data = self._to_data(key) + value_data = self._to_data(document.value) + except SchemaNotReplicatedError as e: + return self._send_schema_and_retry(e, self.set, key, document) + document.value = value_data + request = vector_collection_put_if_absent_codec.encode_request( + self.name, + key_data, + document, + ) + return self._invoke_on_key(request, key_data, response_handler=handler) + + def _transform_document(self, doc: 
Optional[Document]) -> Optional[Document]: + if doc is not None: + if doc.value is not None: + doc.value = self._to_object(doc.value) + for vec in doc.vectors: + vec.type = VectorType(vec.type) + return doc + + +async def create_vector_collection_proxy(service_name, name, context): + return VectorCollection(service_name, name, context) diff --git a/hazelcast/internal/asyncio_reactor.py b/hazelcast/internal/asyncio_reactor.py index a44d656449..8565bc0cf3 100644 --- a/hazelcast/internal/asyncio_reactor.py +++ b/hazelcast/internal/asyncio_reactor.py @@ -17,7 +17,6 @@ class AsyncioReactor: def __init__(self, loop: AbstractEventLoop | None = None): - self._is_live = False self._loop = loop or asyncio.get_running_loop() self._bytes_sent = 0 self._bytes_received = 0 @@ -25,14 +24,6 @@ def __init__(self, loop: AbstractEventLoop | None = None): def add_timer(self, delay, callback): return self._loop.call_later(delay, callback) - def start(self): - self._is_live = True - - def shutdown(self): - if not self._is_live: - return - # TODO: cancel tasks - async def connection_factory( self, connection_manager, connection_id, address: Address, network_config, message_callback ): @@ -70,6 +61,7 @@ def __init__( self._address = address self._config = config self._proto = None + self.connected_address = address @classmethod async def create_and_connect( @@ -105,7 +97,14 @@ async def _create_connection(self, config, address): ssl=ssl_context, server_hostname=server_hostname, ) - _sock, self._proto = res + sock, self._proto = res + if hasattr(sock, "_ssl_protocol"): + sock = sock._ssl_protocol._transport._sock + else: + sock = sock._sock + sockname = sock.getsockname() + host, port = sockname[0], sockname[1] + self.local_address = Address(host, port) def _write(self, buf): self._proto.write(buf) @@ -174,6 +173,9 @@ def __init__(self, conn: AsyncioConnection): self._write_buf_size = 0 self._recv_buf = None self._alive = True + # asyncio tasks are weakly referenced + # storing tasks here 
in order not to lose them midway + self._tasks: set = set() def connection_made(self, transport: transports.BaseTransport): self._transport = transport @@ -184,7 +186,9 @@ def connection_made(self, transport: transports.BaseTransport): def connection_lost(self, exc): self._alive = False - self._conn._loop.create_task(self._conn.close_connection(str(exc), None)) + task = self._conn._loop.create_task(self._conn.close_connection(str(exc), None)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) return False def close(self): diff --git a/hazelcast/internal/asyncio_statistics.py b/hazelcast/internal/asyncio_statistics.py new file mode 100644 index 0000000000..71377481b7 --- /dev/null +++ b/hazelcast/internal/asyncio_statistics.py @@ -0,0 +1,394 @@ +import asyncio +import logging +import os + +from hazelcast.core import CLIENT_TYPE +from hazelcast.internal.asyncio_invocation import Invocation +from hazelcast.metrics import MetricsCompressor, MetricDescriptor, ValueType, ProbeUnit +from hazelcast.protocol.codec import client_statistics_codec +from hazelcast.util import current_time_in_millis, to_millis, to_nanos, current_time +from hazelcast import __version__ + +try: + # psutil does not support type hints + import psutil # type: ignore[import] + + _PSUTIL_ENABLED = True +except ImportError: + _PSUTIL_ENABLED = False + +_logger = logging.getLogger(__name__) + +_NEAR_CACHE_CATEGORY_PREFIX = "nc." 
+_ATTRIBUTE_SEPARATOR = "," +_KEY_VALUE_SEPARATOR = "=" +_EMPTY_ATTRIBUTE_VALUE = "" + +_NEAR_CACHE_DESCRIPTOR_PREFIX = "nearcache" +_NEAR_CACHE_DESCRIPTOR_DISCRIMINATOR = "name" + +_TCP_METRICS_PREFIX = "tcp" + + +class Statistics: + def __init__( + self, client, config, reactor, connection_manager, invocation_service, near_cache_manager + ): + self._client = client + self._reactor = reactor + self._connection_manager = connection_manager + self._invocation_service = invocation_service + self._near_cache_manager = near_cache_manager + self._enabled = config.statistics_enabled + self._period = config.statistics_period + self._statistics_task = None + self._registered_system_gauges = {} + self._registered_process_gauges = {} + + async def start(self): + if not self._enabled: + return + self._register_gauges() + + async def _statistics_task(): + await asyncio.sleep(self._period) + if not self._client.lifecycle_service.is_running(): + return + try: + await self._collect_and_send_stats() + finally: + self._statistics_task = asyncio.create_task(_statistics_task()) + + self._statistics_task = asyncio.create_task(_statistics_task()) + _logger.info("Client statistics enabled with the period of %s seconds.", self._period) + + def shutdown(self): + if self._statistics_task: + self._statistics_task.cancel() + + def _register_gauges(self): + if not _PSUTIL_ENABLED: + _logger.warning( + "Statistics collection is enabled, but psutil is not found. " + "Runtime and system related metrics will not be collected." 
+ ) + return + + self._register_system_gauge( + "os.totalPhysicalMemorySize", + lambda: psutil.virtual_memory().total, + ) + self._register_system_gauge( + "os.freePhysicalMemorySize", + lambda: psutil.virtual_memory().free, + ) + self._register_system_gauge( + "os.committedVirtualMemorySize", + lambda: psutil.virtual_memory().used, + ) + self._register_system_gauge( + "os.totalSwapSpaceSize", + lambda: psutil.swap_memory().total, + ) + self._register_system_gauge( + "os.freeSwapSpaceSize", + lambda: psutil.swap_memory().free, + ) + self._register_system_gauge( + "os.systemLoadAverage", + lambda: os.getloadavg()[0], + ValueType.DOUBLE, + ) + self._register_system_gauge( + "runtime.availableProcessors", + lambda: psutil.cpu_count(), + ) + + self._register_process_gauge( + "runtime.usedMemory", + lambda p: p.memory_info().rss, + ) + self._register_process_gauge( + "os.openFileDescriptorCount", + lambda p: p.num_fds(), + ) + self._register_process_gauge( + "os.maxFileDescriptorCount", + lambda p: p.rlimit(psutil.RLIMIT_NOFILE)[1], + ) + self._register_process_gauge( + "os.processCpuTime", + lambda p: to_nanos(sum(p.cpu_times())), + ) + self._register_process_gauge( + "runtime.uptime", + lambda p: to_millis(current_time() - p.create_time()), + ) + + def _register_system_gauge(self, gauge_name, gauge_fn, value_type=ValueType.LONG): + # Try a gauge function read, we will register it if it succeeds. + try: + gauge_fn() + self._registered_system_gauges[gauge_name] = (gauge_fn, value_type) + except Exception as e: + _logger.debug( + "Unable to register the system related gauge %s. Error: %s", gauge_name, e + ) + + def _register_process_gauge(self, gauge_name, gauge_fn, value_type=ValueType.LONG): + # Try a gauge function read, we will register it if it succeeds. 
+ try: + process = psutil.Process() + gauge_fn(process) + self._registered_process_gauges[gauge_name] = (gauge_fn, value_type) + except Exception as e: + _logger.debug( + "Unable to register the process related gauge %s. Error: %s", gauge_name, e + ) + + async def _collect_and_send_stats(self): + connection = self._connection_manager.get_random_connection() + if not connection: + _logger.debug("Cannot send client statistics to the server. No connection found.") + return + + collection_timestamp = current_time_in_millis() + attributes = [] + compressor = MetricsCompressor() + + self._add_client_attributes(attributes, connection) + self._add_near_cache_metrics(attributes, compressor) + self._add_system_and_process_metrics(attributes, compressor) + self._add_tcp_metrics(compressor) + await self._send_stats( + collection_timestamp, "".join(attributes), compressor.generate_blob(), connection + ) + + async def _send_stats(self, collection_timestamp, attributes, metrics_blob, connection): + request = client_statistics_codec.encode_request( + collection_timestamp, attributes, metrics_blob + ) + invocation = Invocation(request, connection=connection) + await self._invocation_service.ainvoke(invocation) + + def _add_system_and_process_metrics(self, attributes, compressor): + if not _PSUTIL_ENABLED: + # Nothing to do if psutil is not found + return + + for gauge_name, (gauge_fn, value_type) in self._registered_system_gauges.items(): + try: + value = gauge_fn() + self._add_system_or_process_metric( + attributes, compressor, gauge_name, value, value_type + ) + except: + _logger.exception("Error while collecting '%s'.", gauge_name) + + if not self._registered_process_gauges: + # Do not create the process object if no process-related + # metric is registered. 
+ return + + process = psutil.Process() + for gauge_name, (gauge_fn, value_type) in self._registered_process_gauges.items(): + try: + value = gauge_fn(process) + self._add_system_or_process_metric( + attributes, compressor, gauge_name, value, value_type + ) + except: + _logger.exception("Error while collecting '%s'.", gauge_name) + + def _add_system_or_process_metric(self, attributes, compressor, gauge_name, value, value_type): + # We don't have any metrics that do not have prefix. + # Necessary care must be taken when we will send simple + # named metrics. + prefix, metric_name = gauge_name.rsplit(".", 1) + descriptor = MetricDescriptor(metric=metric_name, prefix=prefix) + self._add_metric(compressor, descriptor, value, value_type) + self._add_attribute(attributes, gauge_name, value) + + def _add_client_attributes(self, attributes, connection): + self._add_attribute(attributes, "lastStatisticsCollectionTime", current_time_in_millis()) + self._add_attribute(attributes, "enterprise", "false") + self._add_attribute(attributes, "clientType", CLIENT_TYPE) + self._add_attribute(attributes, "clientVersion", __version__) + self._add_attribute( + attributes, "clusterConnectionTimestamp", to_millis(connection.start_time) + ) + + local_address = connection.local_address + local_address = str(local_address.host) + ":" + str(local_address.port) + self._add_attribute(attributes, "clientAddress", local_address) + self._add_attribute(attributes, "clientName", self._client.name) + + def _add_near_cache_metrics(self, attributes, compressor): + for near_cache in self._near_cache_manager.list_near_caches(): + nc_name = near_cache.name + nc_name_with_prefix = self._get_name_with_prefix(nc_name) + nc_name_with_prefix.append(".") + nc_name_with_prefix = "".join(nc_name_with_prefix) + + near_cache_stats = near_cache.get_statistics() + self._add_near_cache_metric( + attributes, + compressor, + "creationTime", + to_millis(near_cache_stats["creation_time"]), + ValueType.LONG, + 
ProbeUnit.MS, + nc_name, + nc_name_with_prefix, + ) + + self._add_near_cache_metric( + attributes, + compressor, + "evictions", + near_cache_stats["evictions"], + ValueType.LONG, + ProbeUnit.COUNT, + nc_name, + nc_name_with_prefix, + ) + + self._add_near_cache_metric( + attributes, + compressor, + "hits", + near_cache_stats["hits"], + ValueType.LONG, + ProbeUnit.COUNT, + nc_name, + nc_name_with_prefix, + ) + + self._add_near_cache_metric( + attributes, + compressor, + "misses", + near_cache_stats["misses"], + ValueType.LONG, + ProbeUnit.COUNT, + nc_name, + nc_name_with_prefix, + ) + + self._add_near_cache_metric( + attributes, + compressor, + "ownedEntryCount", + near_cache_stats["owned_entry_count"], + ValueType.LONG, + ProbeUnit.COUNT, + nc_name, + nc_name_with_prefix, + ) + + self._add_near_cache_metric( + attributes, + compressor, + "expirations", + near_cache_stats["expirations"], + ValueType.LONG, + ProbeUnit.COUNT, + nc_name, + nc_name_with_prefix, + ) + + self._add_near_cache_metric( + attributes, + compressor, + "invalidations", + near_cache_stats["invalidations"], + ValueType.LONG, + ProbeUnit.COUNT, + nc_name, + nc_name_with_prefix, + ) + + self._add_near_cache_metric( + attributes, + compressor, + "invalidationRequests", + near_cache_stats["invalidation_requests"], + ValueType.LONG, + ProbeUnit.COUNT, + nc_name, + nc_name_with_prefix, + ) + + self._add_near_cache_metric( + attributes, + compressor, + "ownedEntryMemoryCost", + near_cache_stats["owned_entry_memory_cost"], + ValueType.LONG, + ProbeUnit.BYTES, + nc_name, + nc_name_with_prefix, + ) + + def _add_near_cache_metric( + self, attributes, compressor, metric, value, value_type, unit, nc_name, nc_name_with_prefix + ): + descriptor = MetricDescriptor( + metric=metric, + prefix=_NEAR_CACHE_DESCRIPTOR_PREFIX, + discriminator=_NEAR_CACHE_DESCRIPTOR_DISCRIMINATOR, + discriminator_value=nc_name, + unit=unit, + ) + try: + self._add_metric(compressor, descriptor, value, value_type) + 
self._add_attribute(attributes, metric, value, nc_name_with_prefix) + except: + _logger.exception( + "Error while collecting %s metric for near cache '%s'.", metric, nc_name + ) + + def _add_tcp_metrics(self, compressor): + self._add_tcp_metric(compressor, "bytesSend", self._reactor._bytes_sent) + self._add_tcp_metric(compressor, "bytesReceived", self._reactor._bytes_received) + + def _add_tcp_metric( + self, compressor, metric, value, value_type=ValueType.LONG, unit=ProbeUnit.BYTES + ): + descriptor = MetricDescriptor( + metric=metric, + prefix=_TCP_METRICS_PREFIX, + unit=unit, + ) + try: + self._add_metric(compressor, descriptor, value, value_type) + except: + _logger.exception("Error while collecting '%s.%s'.", _TCP_METRICS_PREFIX, metric) + + def _add_metric(self, compressor, descriptor, value, value_type): + if value_type == ValueType.LONG: + compressor.add_long(descriptor, value) + elif value_type == ValueType.DOUBLE: + compressor.add_double(descriptor, value) + else: + raise ValueError("Unexpected type: " + value_type) + + def _add_attribute(self, attributes, name, value, key_prefix=None): + if len(attributes) != 0: + attributes.append(_ATTRIBUTE_SEPARATOR) + + if key_prefix: + attributes.append(key_prefix) + + attributes.append(name) + attributes.append(_KEY_VALUE_SEPARATOR) + attributes.append(str(value)) + + def _get_name_with_prefix(self, name): + return [_NEAR_CACHE_CATEGORY_PREFIX, self._escape_special_characters(name)] + + def _escape_special_characters(self, name): + escaped_name = ( + name.replace("\\", "\\\\").replace(",", "\\,").replace(".", "\\.").replace("=", "\\=") + ) + return escaped_name[1:] if name[0] == "/" else escaped_name diff --git a/tests/integration/asyncio/cluster_test.py b/tests/integration/asyncio/cluster_test.py new file mode 100644 index 0000000000..681a4b4f6f --- /dev/null +++ b/tests/integration/asyncio/cluster_test.py @@ -0,0 +1,317 @@ +import os +import tempfile +import unittest + +import pytest + +from hazelcast.asyncio 
import HazelcastClient +from hazelcast.util import RandomLB, RoundRobinLB +from tests.integration.asyncio.base import HazelcastTestCase, SingleMemberTestCase +from tests.util import ( + random_string, + event_collector, + skip_if_client_version_older_than, + compare_client_version, +) + +try: + from hazelcast.core import EndpointQualifier, ProtocolType +except ImportError: + # Added in 5.0 version of the client. + pass + + +class ClusterTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + rc = None + + def setUp(self): + self.rc = self.create_rc() + self.cluster = self.create_cluster(self.rc) + + def create_config(self): + return { + "cluster_name": self.cluster.id, + } + + async def asyncTearDown(self): + await self.shutdown_all_clients() + self.rc.terminateCluster(self.cluster.id) + self.rc.exit() + + async def test_initial_membership_listener(self): + events = [] + + def member_added(m): + events.append(m) + + config = self.create_config() + config["membership_listeners"] = [(member_added, None)] + member = self.cluster.start_member() + await self.create_client(config) + self.assertEqual(len(events), 1) + self.assertEqual(str(events[0].uuid), member.uuid) + self.assertEqual(events[0].address, member.address) + + async def test_for_existing_members(self): + events = [] + + def member_added(member): + events.append(member) + + member = self.cluster.start_member() + config = self.create_config() + client = await self.create_client(config) + client.cluster_service.add_listener(member_added, fire_for_existing=True) + self.assertEqual(len(events), 1) + self.assertEqual(str(events[0].uuid), member.uuid) + self.assertEqual(events[0].address, member.address) + + async def test_member_added(self): + events = [] + + def member_added(member): + events.append(member) + + self.cluster.start_member() + config = self.create_config() + client = await self.create_client(config) + client.cluster_service.add_listener(member_added, fire_for_existing=True) + new_member = 
self.cluster.start_member() + + def assertion(): + self.assertEqual(len(events), 2) + self.assertEqual(str(events[1].uuid), new_member.uuid) + self.assertEqual(events[1].address, new_member.address) + + await self.assertTrueEventually(assertion) + + async def test_member_removed(self): + events = [] + + def member_removed(member): + events.append(member) + + self.cluster.start_member() + member_to_remove = self.cluster.start_member() + config = self.create_config() + client = await self.create_client(config) + client.cluster_service.add_listener(member_removed=member_removed) + member_to_remove.shutdown() + + def assertion(): + self.assertEqual(len(events), 1) + self.assertEqual(str(events[0].uuid), member_to_remove.uuid) + self.assertEqual(events[0].address, member_to_remove.address) + + await self.assertTrueEventually(assertion) + + async def test_exception_in_membership_listener(self): + def listener(_): + raise RuntimeError("error") + + config = self.create_config() + config["membership_listeners"] = [(listener, listener)] + self.cluster.start_member() + await self.create_client(config) + + async def test_cluster_service_get_members(self): + self.cluster.start_member() + config = self.create_config() + client = await self.create_client(config) + self.assertEqual(1, len(client.cluster_service.get_members())) + + async def test_cluster_service_get_members_with_selector(self): + member = self.cluster.start_member() + config = self.create_config() + client = await self.create_client(config) + self.assertEqual( + 0, len(client.cluster_service.get_members(lambda m: member.address != m.address)) + ) + + +class LoadBalancersWithRealClusterTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + cls.cluster = cls.create_cluster(cls.rc, None) + cls.member1 = cls.cluster.start_member() + cls.member2 = cls.cluster.start_member() + cls.addresses = [cls.member1.address, cls.member2.address] + + @classmethod 
+ def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + async def test_random_load_balancer(self): + client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, load_balancer=RandomLB() + ) + self.assertTrue(client.lifecycle_service.is_running()) + lb = client._load_balancer + self.assertTrue(isinstance(lb, RandomLB)) + self.assertCountEqual( + self.addresses, list(map(lambda m: m.address, self._get_members_from_lb(lb))) + ) + for _ in range(10): + self.assertTrue(lb.next().address in self.addresses) + await client.shutdown() + + async def test_round_robin_load_balancer(self): + client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, load_balancer=RoundRobinLB() + ) + self.assertTrue(client.lifecycle_service.is_running()) + lb = client._load_balancer + self.assertTrue(isinstance(lb, RoundRobinLB)) + self.assertCountEqual( + self.addresses, list(map(lambda m: m.address, self._get_members_from_lb(lb))) + ) + for i in range(10): + self.assertEqual(self.addresses[i % len(self.addresses)], lb.next().address) + await client.shutdown() + + @staticmethod + def _get_members_from_lb(lb): + # For backward-compatibility + members = lb._members + if isinstance(members, list): + return members + # 4.2+ + return members.members + + +@pytest.mark.enterprise +class HotRestartEventTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + @classmethod + def setUpClass(cls): + tmp_dir = tempfile.gettempdir() + cls.tmp_dir = os.path.join(tmp_dir, "hr-test-" + random_string()) + + def setUp(self): + self.rc = self.create_rc() + self.cluster = self.create_cluster_keep_cluster_name(self.rc, self.get_config(5701)) + self.client = None + + async def asyncTearDown(self): + if self.client: + await self.client.shutdown() + self.rc.terminateCluster(self.cluster.id) + self.rc.exit() + + async def test_when_member_started_with_another_port_and_the_same_uuid(self): + member = self.cluster.start_member() + self.client 
= await HazelcastClient.create_and_start(cluster_name=self.cluster.id) + added_listener = event_collector() + removed_listener = event_collector() + self.client.cluster_service.add_listener( + member_added=added_listener, member_removed=removed_listener + ) + self.rc.shutdownCluster(self.cluster.id) + # now stop cluster, restart it with the same name and then start member with port 5702 + self.cluster = self.create_cluster_keep_cluster_name(self.rc, self.get_config(5702)) + self.cluster.start_member() + + def assertion(): + self.assertEqual(1, len(added_listener.events)) + self.assertEqual(1, len(removed_listener.events)) + + await self.assertTrueEventually(assertion) + members = self.client.cluster_service.get_members() + self.assertEqual(1, len(members)) + self.assertEqual(member.uuid, str(members[0].uuid)) + + async def test_when_member_started_with_the_same_address(self): + skip_if_client_version_older_than(self, "4.2") + old_member = self.cluster.start_member() + self.client = await HazelcastClient.create_and_start(cluster_name=self.cluster.id) + members_added = [] + members_removed = [] + self.client.cluster_service.add_listener( + lambda m: members_added.append(m), lambda m: members_removed.append(m) + ) + self.rc.shutdownMember(self.cluster.id, old_member.uuid) + new_member = self.cluster.start_member() + + def assertion(): + self.assertEqual(1, len(members_added)) + self.assertEqual(new_member.uuid, str(members_added[0].uuid)) + self.assertEqual(1, len(members_removed)) + self.assertEqual(old_member.uuid, str(members_removed[0].uuid)) + + await self.assertTrueEventually(assertion) + members = self.client.cluster_service.get_members() + self.assertEqual(1, len(members)) + self.assertEqual(new_member.uuid, str(members[0].uuid)) + + def get_config(self, port): + return """ + + hot-restart-test + + %s + + + %s + + """ % ( + port, + self.tmp_dir, + ) + + +_SERVER_PORT = 5701 +_CLIENT_PORT = 5702 +_SERVER_WITH_CLIENT_ENDPOINT = """ + + + + %s + + + %s + + + +""" 
% ( + _SERVER_PORT, + _CLIENT_PORT, +) + + +@unittest.skipIf( + compare_client_version("5.0") < 0, "Tests the features added in 5.0 version of the client" +) +class AdvancedNetworkConfigTest(SingleMemberTestCase): + @classmethod + def configure_cluster(cls): + return _SERVER_WITH_CLIENT_ENDPOINT + + @classmethod + def configure_client(cls, config): + config["cluster_members"] = ["localhost:%s" % _CLIENT_PORT] + config["cluster_name"] = cls.cluster.id + return config + + def test_member_list(self): + members = self.client.cluster_service.get_members() + self.assertEqual(1, len(members)) + member = members[0] + # Make sure member address is assigned to client endpoint port + self.assertEqual(_CLIENT_PORT, member.address.port) + + # Make sure there are mappings for CLIENT and MEMBER endpoints + self.assertEqual(2, len(member.address_map)) + self.assertEqual( + _SERVER_PORT, member.address_map.get(EndpointQualifier(ProtocolType.MEMBER, None)).port + ) + self.assertEqual( + _CLIENT_PORT, + member.address_map.get(EndpointQualifier(ProtocolType.CLIENT, None)).port, + ) diff --git a/tests/integration/asyncio/connection_strategy_test.py b/tests/integration/asyncio/connection_strategy_test.py new file mode 100644 index 0000000000..ccd51d267d --- /dev/null +++ b/tests/integration/asyncio/connection_strategy_test.py @@ -0,0 +1,101 @@ +import unittest + +from hazelcast.asyncio import HazelcastClient +from hazelcast.config import ReconnectMode +from hazelcast.errors import ClientOfflineError, HazelcastClientNotActiveError +from hazelcast.lifecycle import LifecycleState +from tests.integration.asyncio.base import HazelcastTestCase +from tests.util import random_string + + +class ConnectionStrategyTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + + @classmethod + def tearDownClass(cls): + cls.rc.exit() + + def setUp(self): + self.client = None + self.cluster = None + + async def asyncTearDown(self): + if 
self.client: + await self.client.shutdown() + self.client = None + if self.cluster: + self.rc.terminateCluster(self.cluster.id) + self.cluster = None + + async def test_off_reconnect_mode(self): + self.cluster = self.rc.createCluster(None, None) + member = self.rc.startMember(self.cluster.id) + + def collector(): + events = [] + + def on_state_change(event): + if event == LifecycleState.SHUTDOWN: + events.append(event) + + on_state_change.events = events + return on_state_change + + event_collector = collector() + + self.client = await HazelcastClient.create_and_start( + cluster_members=["localhost:5701"], + cluster_name=self.cluster.id, + reconnect_mode=ReconnectMode.OFF, + lifecycle_listeners=[event_collector], + ) + m = await self.client.get_map(random_string()) + # no exception at this point + await m.put(1, 1) + self.rc.shutdownMember(self.cluster.id, member.uuid) + await self.assertTrueEventually(lambda: self.assertEqual(1, len(event_collector.events))) + with self.assertRaises(HazelcastClientNotActiveError): + await m.put(1, 1) + + async def test_async_reconnect_mode(self): + import logging + + logging.basicConfig(level=logging.DEBUG) + self.cluster = self.rc.createCluster(None, None) + member = self.rc.startMember(self.cluster.id) + + def collector(event_type): + events = [] + + def on_state_change(event): + if event == event_type: + events.append(event) + + on_state_change.events = events + return on_state_change + + disconnected_collector = collector(LifecycleState.DISCONNECTED) + self.client = await HazelcastClient.create_and_start( + cluster_members=["localhost:5701"], + cluster_name=self.cluster.id, + reconnect_mode=ReconnectMode.ASYNC, + lifecycle_listeners=[disconnected_collector], + ) + m = await self.client.get_map(random_string()) + # no exception at this point + await m.put(1, 1) + self.rc.shutdownMember(self.cluster.id, member.uuid) + await self.assertTrueEventually( + lambda: self.assertEqual(1, len(disconnected_collector.events)) + ) + with 
self.assertRaises(ClientOfflineError): + await m.put(1, 1) + connected_collector = collector(LifecycleState.CONNECTED) + self.client.lifecycle_service.add_listener(connected_collector) + self.rc.startMember(self.cluster.id) + await self.assertTrueEventually( + lambda: self.assertEqual(1, len(connected_collector.events)) + ) + await m.put(1, 1) diff --git a/tests/integration/asyncio/hazelcast_json_value_test.py b/tests/integration/asyncio/hazelcast_json_value_test.py new file mode 100644 index 0000000000..2193d4d2b3 --- /dev/null +++ b/tests/integration/asyncio/hazelcast_json_value_test.py @@ -0,0 +1,74 @@ +from hazelcast.core import HazelcastJsonValue +from hazelcast.predicate import greater, equal +from tests.integration.asyncio.base import SingleMemberTestCase + + +class HazelcastJsonValueWithMapTest(SingleMemberTestCase): + @classmethod + def setUpClass(cls): + super(HazelcastJsonValueWithMapTest, cls).setUpClass() + cls.json_str = '{"key": "value"}' + cls.json_obj = {"key": "value"} + + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map("json-test") + + async def asyncTearDown(self): + await self.map.destroy() + + async def test_storing_hazelcast_json_value_as_key(self): + json_value = HazelcastJsonValue(self.json_str) + await self.map.put(json_value, 0) + self.assertEqual(0, await self.map.get(json_value)) + + async def test_storing_hazelcast_json_value_as_value(self): + json_value = HazelcastJsonValue(self.json_str) + await self.map.put(0, json_value) + self.assertEqual(json_value.to_string(), (await self.map.get(0)).to_string()) + + async def test_storing_hazelcast_json_value_with_invalid_str(self): + json_value = HazelcastJsonValue('{"a') + await self.map.put(0, json_value) + self.assertEqual(json_value.to_string(), (await self.map.get(0)).to_string()) + + async def 
test_querying_over_keys_with_hazelcast_json_value(self): + json_value = HazelcastJsonValue({"a": 1}) + json_value2 = HazelcastJsonValue({"a": 3}) + await self.map.put(json_value, 1) + await self.map.put(json_value2, 2) + results = await self.map.key_set(greater("__key.a", 2)) + self.assertEqual(1, len(results)) + self.assertEqual(json_value2.to_string(), results[0].to_string()) + + async def test_querying_nested_attr_over_keys_with_hazelcast_json_value(self): + json_value = HazelcastJsonValue({"a": 1, "b": {"c": "d"}}) + json_value2 = HazelcastJsonValue({"a": 2, "b": {"c": "e"}}) + await self.map.put(json_value, 1) + await self.map.put(json_value2, 2) + results = await self.map.key_set(equal("__key.b.c", "d")) + self.assertEqual(1, len(results)) + self.assertEqual(json_value.to_string(), results[0].to_string()) + + async def test_querying_over_values_with_hazelcast_json_value(self): + json_value = HazelcastJsonValue({"a": 1}) + json_value2 = HazelcastJsonValue({"a": 3}) + await self.map.put(1, json_value) + await self.map.put(2, json_value2) + results = await self.map.values(greater("a", 2)) + self.assertEqual(1, len(results)) + self.assertEqual(json_value2.to_string(), results[0].to_string()) + + async def test_querying_nested_attr_over_values_with_hazelcast_json_value(self): + json_value = HazelcastJsonValue({"a": 1, "b": {"c": "d"}}) + json_value2 = HazelcastJsonValue({"a": 2, "b": {"c": "e"}}) + await self.map.put(1, json_value) + await self.map.put(2, json_value2) + results = await self.map.values(equal("b.c", "d")) + self.assertEqual(1, len(results)) + self.assertEqual(json_value.to_string(), results[0].to_string()) diff --git a/tests/integration/asyncio/heartbeat_test.py b/tests/integration/asyncio/heartbeat_test.py new file mode 100644 index 0000000000..5fbdd72b3f --- /dev/null +++ b/tests/integration/asyncio/heartbeat_test.py @@ -0,0 +1,97 @@ +import asyncio +import threading +import unittest + +from hazelcast.asyncio import HazelcastClient +from 
hazelcast.core import Address +from tests.integration.asyncio.base import HazelcastTestCase +from tests.integration.asyncio.util import open_connection_to_address, wait_for_partition_table + + +class HeartbeatTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + rc = None + + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + + @classmethod + def tearDownClass(cls): + cls.rc.exit() + + async def asyncSetUp(self): + self.cluster = self.create_cluster(self.rc) + self.member = self.rc.startMember(self.cluster.id) + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + heartbeat_interval=0.5, + heartbeat_timeout=2, + ) + + async def asyncTearDown(self): + await self.client.shutdown() + self.rc.shutdownCluster(self.cluster.id) + + async def test_heartbeat_stopped_and_restored(self): + member2 = await asyncio.to_thread(self.rc.startMember, self.cluster.id) + addr = Address(member2.host, member2.port) + await wait_for_partition_table(self.client) + await open_connection_to_address(self.client, member2.uuid) + + def connection_collector(): + connections = [] + + def collector(c, *_): + connections.append(c) + + collector.connections = connections + return collector + + connection_added_collector = connection_collector() + connection_removed_collector = connection_collector() + self.client._connection_manager.add_listener( + connection_added_collector, connection_removed_collector + ) + assertion_succeeded = False + + async def run(): + nonlocal assertion_succeeded + # It is possible for client to override the set last_read_time + # of the connection, in case of the periodically sent heartbeat + # requests getting responses, right after we try to set a new + # value to it, before the next iteration of the heartbeat manager. + # In this case, the connection won't be closed, and the test would + # fail. To avoid it, we will try multiple times. 
+ for i in range(10): + if assertion_succeeded: + # We have successfully simulated heartbeat loss + return + + for connection in list(self.client._connection_manager.active_connections.values()): + if connection.remote_address == addr: + connection.last_read_time -= 2 + break + + await asyncio.sleep((i + 1) * 0.1) + + asyncio.create_task(run()) + + async def assert_heartbeat_stopped_and_restored(): + nonlocal assertion_succeeded + self.assertGreaterEqual(len(connection_added_collector.connections), 1) + self.assertGreaterEqual(len(connection_removed_collector.connections), 1) + + stopped_connection = connection_removed_collector.connections[0] + restored_connection = connection_added_collector.connections[0] + + self.assertEqual( + stopped_connection.connected_address, + Address(member2.host, member2.port), + ) + self.assertEqual( + restored_connection.connected_address, + Address(member2.host, member2.port), + ) + assertion_succeeded = True + + await self.assertTrueEventually(assert_heartbeat_stopped_and_restored) diff --git a/tests/integration/asyncio/invocation_test.py b/tests/integration/asyncio/invocation_test.py new file mode 100644 index 0000000000..97c95b96fb --- /dev/null +++ b/tests/integration/asyncio/invocation_test.py @@ -0,0 +1,64 @@ +import asyncio +import time +import unittest + +from mock import MagicMock + +from hazelcast.asyncio import HazelcastClient +from hazelcast.errors import OperationTimeoutError +from hazelcast.internal.asyncio_invocation import Invocation +from hazelcast.protocol.client_message import OutboundMessage +from hazelcast.serialization import LE_INT +from tests.integration.asyncio.base import HazelcastTestCase + + +class InvocationTimeoutTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + cls.cluster = cls.create_cluster(cls.rc, None) + cls.member = cls.cluster.start_member() + + @classmethod + def tearDownClass(cls): + 
cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + async def asyncSetUp(self): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, invocation_timeout=1 + ) + + async def asyncTearDown(self): + await self.client.shutdown() + + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + config["invocation_timeout"] = 1 + return config + + async def test_invocation_timeout(self): + request = OutboundMessage(bytearray(22), True) + invocation_service = self.client._invocation_service + invocation = Invocation(request, partition_id=1) + + def mock(*_): + time.sleep(2) + return False + + invocation_service._invoke_on_partition_owner = MagicMock(side_effect=mock) + invocation_service._invoke_on_random_connection = MagicMock(return_value=False) + invocation_service.invoke(invocation) + with self.assertRaises(OperationTimeoutError): + await invocation.future + + async def test_invocation_not_timed_out_when_there_is_no_exception(self): + buf = bytearray(22) + LE_INT.pack_into(buf, 0, 22) + request = OutboundMessage(buf, True) + invocation_service = self.client._invocation_service + invocation = Invocation(request) + invocation_service.invoke(invocation) + await asyncio.sleep(2) + self.assertFalse(invocation.future.done()) + self.assertEqual(1, len(invocation_service._pending)) diff --git a/tests/integration/asyncio/lifecycle_test.py b/tests/integration/asyncio/lifecycle_test.py new file mode 100644 index 0000000000..4efa014053 --- /dev/null +++ b/tests/integration/asyncio/lifecycle_test.py @@ -0,0 +1,101 @@ +import unittest + +from hazelcast.lifecycle import LifecycleState +from tests.integration.asyncio.base import HazelcastTestCase +from tests.util import event_collector + + +class LifecycleTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + rc = None + + def setUp(self): + self.rc = self.create_rc() + self.cluster = self.create_cluster(self.rc) + + async def asyncTearDown(self): + await 
self.shutdown_all_clients() + self.rc.exit() + + async def test_lifecycle_listener_receives_events_in_order(self): + collector = event_collector() + self.cluster.start_member() + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "lifecycle_listeners": [ + collector, + ], + } + ) + await client.shutdown() + self.assertEqual( + collector.events, + [ + LifecycleState.STARTING, + LifecycleState.STARTED, + LifecycleState.CONNECTED, + LifecycleState.SHUTTING_DOWN, + LifecycleState.DISCONNECTED, + LifecycleState.SHUTDOWN, + ], + ) + + async def test_lifecycle_listener_receives_events_in_order_after_startup(self): + self.cluster.start_member() + collector = event_collector() + client = await self.create_client( + { + "cluster_name": self.cluster.id, + } + ) + client.lifecycle_service.add_listener(collector) + await client.shutdown() + self.assertEqual( + collector.events, + [LifecycleState.SHUTTING_DOWN, LifecycleState.DISCONNECTED, LifecycleState.SHUTDOWN], + ) + + async def test_lifecycle_listener_receives_disconnected_event(self): + member = self.cluster.start_member() + collector = event_collector() + client = await self.create_client( + { + "cluster_name": self.cluster.id, + } + ) + client.lifecycle_service.add_listener(collector) + member.shutdown() + + def assertion(): + self.assertEqual(collector.events, [LifecycleState.DISCONNECTED]) + + await self.assertTrueEventually(assertion) + + await client.shutdown() + + async def test_remove_lifecycle_listener(self): + collector = event_collector() + self.cluster.start_member() + client = await self.create_client( + { + "cluster_name": self.cluster.id, + } + ) + registration_id = client.lifecycle_service.add_listener(collector) + client.lifecycle_service.remove_listener(registration_id) + await client.shutdown() + self.assertEqual(collector.events, []) + + async def test_exception_in_listener(self): + def listener(_): + raise RuntimeError("error") + + self.cluster.start_member() + await 
self.create_client( + { + "cluster_name": self.cluster.id, + "lifecycle_listeners": [ + listener, + ], + } + ) diff --git a/tests/integration/asyncio/listener_test.py b/tests/integration/asyncio/listener_test.py new file mode 100644 index 0000000000..da13a25438 --- /dev/null +++ b/tests/integration/asyncio/listener_test.py @@ -0,0 +1,121 @@ +import asyncio +import unittest + +from parameterized import parameterized + +from tests.integration.asyncio.base import HazelcastTestCase +from tests.integration.asyncio.util import ( + generate_key_owned_by_instance, + wait_for_partition_table, +) +from tests.util import ( + random_string, + event_collector, +) + +LISTENER_TYPES = [ + ( + "smart", + True, + ), + ( + "non-smart", + False, + ), +] + + +class ListenerRemoveMemberTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + def setUp(self): + self.rc = self.create_rc() + self.cluster = self.create_cluster(self.rc, None) + self.m1 = self.cluster.start_member() + self.m2 = self.cluster.start_member() + self.client_config = { + "cluster_name": self.cluster.id, + "heartbeat_interval": 1.0, + } + self.collector = event_collector() + + async def asyncTearDown(self): + await self.shutdown_all_clients() + self.rc.terminateCluster(self.cluster.id) + self.rc.exit() + + async def test_remove_member_smart(self): + await self._remove_member_test(True) + + async def test_remove_member_unisocket(self): + await self._remove_member_test(False) + + async def _remove_member_test(self, is_smart): + self.client_config["smart_routing"] = is_smart + client = await self.create_client(self.client_config) + await wait_for_partition_table(client) + key_m1 = generate_key_owned_by_instance(client, self.m1.uuid) + random_map = await client.get_map(random_string()) + await random_map.add_entry_listener(added_func=self.collector) + await asyncio.to_thread(self.m1.shutdown) + await random_map.put(key_m1, "value2") + + def assert_event(): + self.assertEqual(1, len(self.collector.events)) + + await 
self.assertTrueEventually(assert_event) + + +class ListenerAddMemberTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + def setUp(self): + self.rc = self.create_rc() + self.cluster = self.create_cluster(self.rc, None) + self.m1 = self.cluster.start_member() + self.client_config = { + "cluster_name": self.cluster.id, + } + self.collector = event_collector() + + async def asyncTearDown(self): + await self.shutdown_all_clients() + self.rc.terminateCluster(self.cluster.id) + self.rc.exit() + + async def test_add_member_smart(self): + await self._add_member_test(True) + + async def test_add_member_unisocket(self): + await self._add_member_test(True) + + async def _add_member_test(self, is_smart): + self.client_config["smart_routing"] = is_smart + client = await self.create_client(self.client_config) + random_map = await client.get_map(random_string()) + await random_map.add_entry_listener(added_func=self.collector, updated_func=self.collector) + m2 = await asyncio.to_thread(self.cluster.start_member) + await wait_for_partition_table(client) + key_m2 = generate_key_owned_by_instance(client, m2.uuid) + assertion_succeeded = False + + async def run(): + nonlocal assertion_succeeded + # When a new connection is added, we add the existing + # listeners to it, but we do it non-blocking. So, it might + # be the case that, the listener registration request is + # sent to the new member, but not completed yet. + # So, we might not get an event for the put. To avoid this, + # we will put multiple times. 
+ for i in range(10): + if assertion_succeeded: + # We have successfully got an event + return + + await random_map.put(key_m2, f"value-{i}") + await asyncio.sleep((i + 1) * 0.1) + + asyncio.create_task(run()) + + def assert_event(): + nonlocal assertion_succeeded + self.assertGreaterEqual(len(self.collector.events), 1) + assertion_succeeded = True + + await self.assertTrueEventually(assert_event) diff --git a/tests/integration/asyncio/predicate_test.py b/tests/integration/asyncio/predicate_test.py new file mode 100644 index 0000000000..4b7473a6f3 --- /dev/null +++ b/tests/integration/asyncio/predicate_test.py @@ -0,0 +1,571 @@ +import os +import unittest + +from hazelcast.predicate import ( + equal, + and_, + between, + less, + less_or_equal, + greater, + greater_or_equal, + or_, + not_equal, + not_, + like, + ilike, + regex, + sql, + true, + false, + in_, + instance_of, + paging, +) +from hazelcast.serialization.api import Portable, IdentifiedDataSerializable +from hazelcast.util import IterationType +from tests.integration.asyncio.base import SingleMemberTestCase, HazelcastTestCase +from tests.integration.backward_compatible.util import ( + write_string_to_writer, + read_string_from_reader, +) +from tests.util import random_string, get_abs_path +from hazelcast.asyncio import HazelcastClient + + +class PredicateTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def fill_map(self, count=10): + m = {"key-%d" % x: "value-%d" % x for x in range(0, count)} + await self.map.put_all(m) + return m + + async def fill_map_numeric(self, count=100): + m = {n: n for n in range(count)} + await self.map.put_all(m) + + async def test_key_set(self): + await self.fill_map() + 
key_set = await self.map.key_set() + list(key_set) + key_set_list = list(key_set) + assert key_set_list[0] + + async def test_sql(self): + await self.fill_map() + predicate = sql("this == 'value-1'") + self.assertCountEqual(await self.map.key_set(predicate), ["key-1"]) + + async def test_and(self): + await self.fill_map() + predicate = and_(equal("this", "value-1"), equal("this", "value-2")) + self.assertCountEqual(await self.map.key_set(predicate), []) + + async def test_or(self): + await self.fill_map() + predicate = or_(equal("this", "value-1"), equal("this", "value-2")) + self.assertCountEqual(await self.map.key_set(predicate), ["key-1", "key-2"]) + + async def test_not(self): + await self.fill_map(count=3) + predicate = not_(equal("this", "value-1")) + self.assertCountEqual(await self.map.key_set(predicate), ["key-0", "key-2"]) + + async def test_between(self): + await self.fill_map_numeric() + predicate = between("this", 1, 20) + self.assertCountEqual(await self.map.key_set(predicate), list(range(1, 21))) + + async def test_equal(self): + await self.fill_map() + predicate = equal("this", "value-1") + self.assertCountEqual(await self.map.key_set(predicate), ["key-1"]) + + async def test_not_equal(self): + await self.fill_map(count=3) + predicate = not_equal("this", "value-1") + self.assertCountEqual(await self.map.key_set(predicate), ["key-0", "key-2"]) + + async def test_in(self): + await self.fill_map_numeric(count=10) + predicate = in_("this", 1, 5, 7) + self.assertCountEqual(await self.map.key_set(predicate), [1, 5, 7]) + + async def test_less_than(self): + await self.fill_map_numeric() + predicate = less("this", 10) + self.assertCountEqual(await self.map.key_set(predicate), list(range(0, 10))) + + async def test_less_than_or_equal(self): + await self.fill_map_numeric() + predicate = less_or_equal("this", 10) + self.assertCountEqual(await self.map.key_set(predicate), list(range(0, 11))) + + async def test_greater_than(self): + await self.fill_map_numeric() 
+ predicate = greater("this", 10) + self.assertCountEqual(await self.map.key_set(predicate), list(range(11, 100))) + + async def test_greater_than_or_equal(self): + await self.fill_map_numeric() + predicate = greater_or_equal("this", 10) + self.assertCountEqual(await self.map.key_set(predicate), list(range(10, 100))) + + async def test_like(self): + await self.map.put("key-1", "a_value") + await self.map.put("key-2", "b_value") + await self.map.put("key-3", "aa_value") + await self.map.put("key-4", "AA_value") + predicate = like("this", "a%") + self.assertCountEqual(await self.map.key_set(predicate), ["key-1", "key-3"]) + + async def test_ilike(self): + await self.map.put("key-1", "a_value") + await self.map.put("key-2", "b_value") + await self.map.put("key-3", "AA_value") + predicate = ilike("this", "a%") + self.assertCountEqual(await self.map.key_set(predicate), ["key-1", "key-3"]) + + async def test_regex(self): + await self.map.put("key-1", "car") + await self.map.put("key-2", "cry") + await self.map.put("key-3", "giraffe") + predicate = regex("this", "c[ar].*") + self.assertCountEqual(await self.map.key_set(predicate), ["key-1", "key-2"]) + + async def test_instance_of(self): + await self.map.put("key-1", True) + await self.map.put("key-2", 5) + await self.map.put("key-3", "str") + predicate = instance_of("java.lang.Boolean") + self.assertCountEqual(await self.map.key_set(predicate), ["key-1"]) + + async def test_true(self): + m = await self.fill_map() + predicate = true() + self.assertCountEqual(await self.map.key_set(predicate), list(m.keys())) + + async def test_false(self): + await self.fill_map() + predicate = false() + self.assertCountEqual(await self.map.key_set(predicate), []) + + async def test_paging(self): + await self.fill_map_numeric() + predicate = paging(less("this", 4), 2) + self.assertCountEqual([0, 1], await self.map.key_set(predicate)) + predicate.next_page() + self.assertCountEqual([2, 3], await self.map.key_set(predicate)) + 
predicate.next_page() + self.assertCountEqual([], await self.map.key_set(predicate)) + + +class SimplePortable(Portable): + def __init__(self, field=None): + self.field = field + + def write_portable(self, writer): + writer.write_int("field", self.field) + + def read_portable(self, reader): + self.field = reader.read_int("field") + + def get_factory_id(self): + return 1 + + def get_class_id(self): + return 1 + + +class PredicatePortableTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + config["portable_factories"] = {1: {1: SimplePortable}} + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def fill_map(self, count=1000): + m = {x: SimplePortable(x) for x in range(0, count)} + await self.map.put_all(m) + return m + + async def test_predicate_portable_key(self): + _map = await self.fill_map() + map_keys = list(_map.keys()) + predicate = sql("field >= 900") + entries = await self.map.entry_set(predicate) + self.assertEqual(len(entries), 100) + for k, v in entries: + self.assertGreaterEqual(v.field, 900) + self.assertIn(k, map_keys) + + +class NestedPredicatePortableTest(SingleMemberTestCase): + class Body(Portable): + def __init__(self, name=None, limb=None): + self.name = name + self.limb = limb + + def get_class_id(self): + return 1 + + def get_factory_id(self): + return 1 + + def get_class_version(self): + return 15 + + def write_portable(self, writer): + write_string_to_writer(writer, "name", self.name) + writer.write_portable("limb", self.limb) + + def read_portable(self, reader): + self.name = read_string_from_reader(reader, "name") + self.limb = reader.read_portable("limb") + + def __eq__(self, other): + return isinstance(other, self.__class__) and (self.name, self.limb) == ( + other.name, + other.limb, + 
) + + class Limb(Portable): + def __init__(self, name=None): + self.name = name + + def get_class_id(self): + return 2 + + def get_factory_id(self): + return 1 + + def get_class_version(self): + return 2 + + def write_portable(self, writer): + write_string_to_writer(writer, "name", self.name) + + def read_portable(self, reader): + self.name = read_string_from_reader(reader, "name") + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.name == other.name + + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + config["portable_factories"] = { + 1: { + 1: NestedPredicatePortableTest.Body, + 2: NestedPredicatePortableTest.Limb, + }, + } + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + await self.map.put( + 1, NestedPredicatePortableTest.Body("body1", NestedPredicatePortableTest.Limb("hand")) + ) + await self.map.put( + 2, NestedPredicatePortableTest.Body("body2", NestedPredicatePortableTest.Limb("leg")) + ) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_adding_indexes(self): + # single-attribute index + await self.map.add_index(attributes=["name"]) + # nested-attribute index + await self.map.add_index(attributes=["limb.name"]) + + async def test_single_attribute_query_portable_predicates(self): + predicate = equal("limb.name", "hand") + values = await self.map.values(predicate) + self.assertEqual(1, len(values)) + self.assertEqual("body1", values[0].name) + + async def test_nested_attribute_query_sql_predicate(self): + predicate = sql("limb.name == 'leg'") + values = await self.map.values(predicate) + self.assertEqual(1, len(values)) + self.assertEqual("body2", values[0].name) + + +class PagingPredicateTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + rc = None + cluster = None + client = None + map = None + + @classmethod + def 
setUpClass(cls): + cls.rc = cls.create_rc() + cls.cluster = cls.create_cluster(cls.rc, cls.configure_cluster()) + cls.cluster.start_member() + cls.cluster.start_member() + + @classmethod + def tearDownClass(cls): + cls.rc.shutdownCluster(cls.cluster.id) + cls.rc.exit() + + async def asyncSetUp(self): + self.client = await HazelcastClient.create_and_start(cluster_name=self.cluster.id) + self.map = await self.client.get_map(random_string()) + await self.map.clear() + + async def asyncTearDown(self): + await self.map.destroy() + await self.client.shutdown() + + @staticmethod + def configure_cluster(): + current_directory = os.path.dirname(__file__) + dir_path = os.path.dirname(current_directory) + path = os.path.join(dir_path, "backward_compatible/proxy/hazelcast.xml") + with open(path, "r") as f: + return f.read() + + def test_with_inner_paging_predicate(self): + predicate = paging(true(), 1) + + with self.assertRaises(TypeError): + paging(predicate, 1) + + def test_with_non_positive_page_size(self): + with self.assertRaises(ValueError): + paging(true(), 0) + + with self.assertRaises(ValueError): + paging(true(), -1) + + def test_previous_page_when_index_is_zero(self): + predicate = paging(true(), 2) + self.assertEqual(0, predicate.previous_page()) + self.assertEqual(0, predicate.previous_page()) + + async def test_entry_set_with_paging_predicate(self): + await self.fill_map(3) + entry_set = await self.map.entry_set(paging(greater_or_equal("this", 2), 1)) + self.assertEqual(len(entry_set), 1) + self.assertEqual(entry_set[0], ("key-2", 2)) + + async def test_key_set_with_paging_predicate(self): + await self.fill_map(3) + key_set = await self.map.key_set(paging(greater_or_equal("this", 2), 1)) + self.assertEqual(len(key_set), 1) + self.assertEqual(key_set[0], "key-2") + + async def test_values_with_paging_predicate(self): + await self.fill_map(3) + values = await self.map.values(paging(greater_or_equal("this", 2), 1)) + self.assertEqual(len(values), 1) + 
self.assertEqual(values[0], 2) + + async def test_with_none_inner_predicate(self): + await self.fill_map(3) + predicate = paging(None, 10) + self.assertEqual(await self.map.values(predicate), [0, 1, 2]) + + async def test_first_page(self): + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + self.assertEqual(await self.map.values(predicate), [40, 41]) + + async def test_next_page(self): + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + predicate.next_page() + self.assertEqual(await self.map.values(predicate), [42, 43]) + + async def test_set_page(self): + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + predicate.page = 4 + self.assertEqual(await self.map.values(predicate), [48, 49]) + + def test_get_page(self): + predicate = paging(greater_or_equal("this", 40), 2) + predicate.page = 4 + self.assertEqual(predicate.page, 4) + + def test_page_size(self): + predicate = paging(greater_or_equal("this", 40), 2) + self.assertEqual(predicate.page_size, 2) + + async def test_previous_page(self): + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + predicate.page = 4 + predicate.previous_page() + self.assertEqual(await self.map.values(predicate), [46, 47]) + + async def test_get_4th_then_previous_page(self): + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + predicate.page = 4 + await self.map.values(predicate) + predicate.previous_page() + self.assertEqual(await self.map.values(predicate), [46, 47]) + + async def test_get_3rd_then_next_page(self): + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + predicate.page = 3 + await self.map.values(predicate) + predicate.next_page() + self.assertEqual(await self.map.values(predicate), [48, 49]) + + async def test_set_nonexistent_page(self): + # Trying to get page 10, which is out of range, should return empty list. 
+ await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + predicate.page = 10 + self.assertEqual(await self.map.values(predicate), []) + + async def test_nonexistent_previous_page(self): + # Trying to get previous page while already at first page should return first page. + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + predicate.previous_page() + self.assertEqual(await self.map.values(predicate), [40, 41]) + + async def test_nonexistent_next_page(self): + # Trying to get next page while already at last page should return empty list. + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + predicate.page = 4 + predicate.next_page() + self.assertEqual(await self.map.values(predicate), []) + + async def test_get_half_full_last_page(self): + # Page size set to 2, but last page only has 1 element. + await self.fill_map() + predicate = paging(greater_or_equal("this", 41), 2) + predicate.page = 4 + self.assertEqual(await self.map.values(predicate), [49]) + + async def test_reset(self): + await self.fill_map() + predicate = paging(greater_or_equal("this", 40), 2) + self.assertEqual(await self.map.values(predicate), [40, 41]) + predicate.next_page() + self.assertEqual(await self.map.values(predicate), [42, 43]) + predicate.reset() + self.assertEqual(await self.map.values(predicate), [40, 41]) + + async def test_empty_map(self): + # Empty map should return empty list. 
+ predicate = paging(greater_or_equal("this", 30), 2) + self.assertEqual(await self.map.values(predicate), []) + + async def test_equal_values_paging(self): + await self.fill_map() + # keys[50 - 99], values[0 - 49]: + m = {"key-%d" % i: i - 50 for i in range(50, 100)} + await self.map.put_all(m) + predicate = paging(less_or_equal("this", 8), 5) + self.assertEqual(await self.map.values(predicate), [0, 0, 1, 1, 2]) + predicate.next_page() + self.assertEqual(await self.map.values(predicate), [2, 3, 3, 4, 4]) + predicate.next_page() + self.assertEqual(await self.map.values(predicate), [5, 5, 6, 6, 7]) + predicate.next_page() + self.assertEqual(await self.map.values(predicate), [7, 8, 8]) + + async def test_entry_set_with_custom_comparator(self): + m = await self.fill_map() + predicate = paging(less("this", 10), 5, CustomComparator(1, IterationType.KEY)) + + def entries(start, end): + return list( + sorted( + map(lambda k: (k, m[k]), filter(lambda k: start <= m[k] < end, m)), + key=lambda e: e[1], + reverse=True, + ) + ) + + self.assertEqual(entries(5, 10), await self.map.entry_set(predicate)) + predicate.next_page() + self.assertEqual(entries(0, 5), await self.map.entry_set(predicate)) + predicate.next_page() + self.assertEqual([], await self.map.entry_set(predicate)) + + async def test_key_set_with_custom_comparator(self): + m = await self.fill_map() + predicate = paging(less("this", 10), 5, CustomComparator(1, IterationType.KEY)) + keys = list(sorted(m.keys(), key=lambda k: m[k])) + self.assertEqual(keys[9:4:-1], await self.map.key_set(predicate)) + predicate.next_page() + self.assertEqual(keys[4::-1], await self.map.key_set(predicate)) + predicate.next_page() + self.assertEqual([], await self.map.key_set(predicate)) + + async def test_values_with_custom_comparator(self): + m = await self.fill_map() + predicate = paging(less("this", 10), 5, CustomComparator(1, IterationType.KEY)) + values = list(sorted(m.values())) + self.assertEqual(values[9:4:-1], await 
self.map.values(predicate)) + predicate.next_page() + self.assertEqual(values[4::-1], await self.map.values(predicate)) + predicate.next_page() + self.assertEqual([], await self.map.values(predicate)) + + async def fill_map(self, count=50): + m = {"key-%d" % x: x for x in range(count)} + await self.map.put_all(m) + return m + + +class CustomComparator(IdentifiedDataSerializable): + """ + For type: + + - 0 -> lexicographical order + - 1 -> reverse lexicographical + - 2 -> length increasing order + + Iteration type is same as the ``hazelcast.util.IterationType`` + """ + + def __init__(self, order, iteration_type): + self.order = order + self.iteration_type = iteration_type + + def write_data(self, object_data_output): + object_data_output.write_int(self.order) + object_data_output.write_int(self.iteration_type) + + def read_data(self, object_data_input): + pass + + def get_factory_id(self): + return 66 + + def get_class_id(self): + return 2 diff --git a/tests/integration/asyncio/proxy/map_nearcache_test.py b/tests/integration/asyncio/proxy/map_nearcache_test.py new file mode 100644 index 0000000000..9665ad1a15 --- /dev/null +++ b/tests/integration/asyncio/proxy/map_nearcache_test.py @@ -0,0 +1,193 @@ +import asyncio +import os +import unittest + +from hazelcast.config import ReconnectMode +from hazelcast.errors import ClientOfflineError +from hazelcast.lifecycle import LifecycleState +from hazelcast.predicate import true +from tests.hzrc.ttypes import Lang + +from tests.integration.asyncio.base import SingleMemberTestCase, HazelcastTestCase +from tests.util import random_string, skip_if_client_version_older_than +from hazelcast.asyncio import HazelcastClient + + +class MapTest(SingleMemberTestCase): + @classmethod + def configure_cluster(cls): + path = os.path.abspath(__file__) + dir_path = os.path.dirname(path) + with open(os.path.join(dir_path, "../../backward_compatible/proxy/hazelcast.xml")) as f: + return f.read() + + @classmethod + def configure_client(cls, 
config): + cls.map_name = random_string() + config["cluster_name"] = cls.cluster.id + config["near_caches"] = {cls.map_name: {}} + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(self.map_name) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_put_get(self): + key = "key" + value = "value" + await self.map.put(key, value) + value2 = await self.map.get(key) + value3 = await self.map.get(key) + self.assertEqual(value, value2) + self.assertEqual(value, value3) + self.assertEqual(1, self.map._near_cache._hits) + self.assertEqual(1, self.map._near_cache._misses) + + async def test_put_get_remove(self): + key = "key" + value = "value" + await self.map.put(key, value) + value2 = await self.map.get(key) + value3 = await self.map.get(key) + await self.map.remove(key) + self.assertEqual(value, value2) + self.assertEqual(value, value3) + self.assertEqual(1, self.map._near_cache._hits) + self.assertEqual(1, self.map._near_cache._misses) + self.assertEqual(0, len(self.map._near_cache)) + + async def test_remove_all(self): + skip_if_client_version_older_than(self, "5.6.0") + + await self.fill_map_and_near_cache(10) + await self.map.remove_all(predicate=true()) + self.assertEqual(0, len(self.map._near_cache)) + + async def test_invalidate_single_key(self): + await self.fill_map_and_near_cache(10) + initial_cache_size = len(self.map._near_cache) + script = """map = instance_0.getMap("{}");map.remove("key-5")""".format(self.map.name) + response = await asyncio.to_thread( + self.rc.executeOnController, self.cluster.id, script, Lang.PYTHON + ) + self.assertTrue(response.success) + self.assertEqual(initial_cache_size, 10) + + def assertion(): + self.assertEqual(9, len(self.map._near_cache)) + + await self.assertTrueEventually(assertion, timeout=30) + + async def test_invalidate_nonexist_key(self): + await self.fill_map_and_near_cache(10) + initial_cache_size 
= len(self.map._near_cache) + script = ( + """ + var map = instance_0.getMap("%s"); + map.put("key-99","x"); + map.put("key-NonExist","x"); + map.remove("key-NonExist");""" + % self.map.name + ) + + response = self.rc.executeOnController(self.cluster.id, script, Lang.JAVASCRIPT) + self.assertTrue(response.success) + self.assertEqual(initial_cache_size, 10) + + async def assertion(): + self.assertEqual(await self.map.size(), 11) + self.assertEqual(len(self.map._near_cache), 10) + + await self.assertTrueEventually(assertion) + + async def test_invalidate_multiple_keys(self): + await self.fill_map_and_near_cache(10) + initial_cache_size = len(self.map._near_cache) + script = """map = instance_0.getMap("{}");map.clear()""".format(self.map.name) + response = await asyncio.to_thread( + self.rc.executeOnController, self.cluster.id, script, Lang.PYTHON + ) + self.assertTrue(response.success) + self.assertEqual(initial_cache_size, 10) + + def assertion(): + self.assertEqual(0, len(self.map._near_cache)) + + await self.assertTrueEventually(assertion, timeout=60) + + async def fill_map_and_near_cache(self, count=10): + fill_content = {"key-%d" % x: "value-%d" % x for x in range(0, count)} + for k, v in fill_content.items(): + await self.map.put(k, v) + for k, v in fill_content.items(): + await self.map.get(k) + return fill_content + + +ENTRY_COUNT = 100 + + +class NonStopNearCacheTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + async def asyncSetUp(self): + rc = self.create_rc() + cluster = self.create_cluster(rc, self.read_cluster_config()) + cluster.start_member() + + def event_collector(): + events = [] + + def collector(e): + if e == LifecycleState.DISCONNECTED: + events.append(e) + + collector.events = events + return collector + + collector = event_collector() + + client = await HazelcastClient.create_and_start( + cluster_name=cluster.id, + reconnect_mode=ReconnectMode.ASYNC, + near_caches={"map": {}}, + lifecycle_listeners=[collector], + ) + + map = await 
client.get_map("map") + for i in range(ENTRY_COUNT): + await map.put(i, i) + + # Populate the near cache + for i in range(ENTRY_COUNT): + await map.get(i) + + rc.terminateCluster(cluster.id) + rc.exit() + await self.assertTrueEventually(lambda: self.assertEqual(1, len(collector.events))) + self.client = client + self.map = map + + async def asyncTearDown(self): + await self.client.shutdown() + + async def test_get_existing_key_from_cache_when_the_cluster_is_down(self): + for i in range(ENTRY_COUNT): + self.assertEqual(i, await self.map.get(i)) + + async def test_get_non_existing_key_from_cache_when_the_cluster_is_down(self): + with self.assertRaises(ClientOfflineError): + await self.map.get(ENTRY_COUNT) + + async def test_put_to_map_when_the_cluster_is_down(self): + with self.assertRaises(ClientOfflineError): + await self.map.put(ENTRY_COUNT, ENTRY_COUNT) + + @staticmethod + def read_cluster_config(): + path = os.path.abspath(__file__) + dir_path = os.path.dirname(path) + with open(os.path.join(dir_path, "../../backward_compatible/proxy/hazelcast.xml")) as f: + return f.read() diff --git a/tests/integration/asyncio/proxy/map_test.py b/tests/integration/asyncio/proxy/map_test.py index b63ae9e0fe..2b4a5d5360 100644 --- a/tests/integration/asyncio/proxy/map_test.py +++ b/tests/integration/asyncio/proxy/map_test.py @@ -40,7 +40,6 @@ from hazelcast.core import HazelcastJsonValue from hazelcast.config import IndexType, IntType -from hazelcast.errors import HazelcastError from hazelcast.predicate import greater_or_equal, less_or_equal, sql, paging, true from hazelcast.internal.asyncio_proxy.map import EntryEventType from hazelcast.serialization.api import IdentifiedDataSerializable diff --git a/tests/integration/asyncio/proxy/vector_collection_test.py b/tests/integration/asyncio/proxy/vector_collection_test.py new file mode 100644 index 0000000000..d6db835782 --- /dev/null +++ b/tests/integration/asyncio/proxy/vector_collection_test.py @@ -0,0 +1,339 @@ +import os +import 
unittest + +import pytest + +import hazelcast.errors +from tests.integration.asyncio.base import SingleMemberTestCase +from tests.util import ( + random_string, + compare_client_version, + skip_if_server_version_older_than, + skip_if_client_version_older_than, +) + +try: + from hazelcast.vector import IndexConfig, Metric, Document, Vector, Type +except ImportError: + # backward compatibility + pass + + +@unittest.skipIf( + compare_client_version("5.5") < 0, "Tests the features added in 5.5 version of the client" +) +@pytest.mark.enterprise +class VectorCollectionTest(SingleMemberTestCase): + @classmethod + def configure_cluster(cls): + path = os.path.abspath(__file__) + dir_path = os.path.dirname(path) + with open(os.path.join(dir_path, "../../backward_compatible/proxy/hazelcast.xml")) as f: + return f.read() + + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + skip_if_server_version_older_than(self, self.client, "5.5") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)] + ) + self.vector_collection = await self.client.get_vector_collection(name) + + async def asyncTearDown(self): + await self.vector_collection.destroy() + await super().asyncTearDown() + + async def test_set(self): + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + + async def test_get(self): + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + got_doc = await self.vector_collection.get("k1") + self.assert_document_equal(got_doc, doc) + + async def test_put(self): + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + doc_old = await self.vector_collection.put("k1", doc) + 
self.assertIsNone(doc_old) + doc2 = Document("v1", Vector("vector", Type.DENSE, [0.4, 0.5, 0.6])) + doc_old = await self.vector_collection.put("k1", doc2) + self.assert_document_equal(doc_old, doc) + + async def test_delete(self): + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + await self.vector_collection.delete("k1") + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + + async def test_remove(self): + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + doc2 = await self.vector_collection.remove("k1") + self.assert_document_equal(doc, doc2) + + async def test_put_all(self): + doc1 = self.doc1("v1", [0.1, 0.2, 0.3]) + doc2 = self.doc1("v1", [0.2, 0.3, 0.4]) + await self.vector_collection.put_all( + { + "k1": doc1, + "k2": doc2, + } + ) + k1 = await self.vector_collection.get("k1") + self.assert_document_equal(k1, doc1) + k2 = await self.vector_collection.get("k2") + self.assert_document_equal(k2, doc2) + + async def test_clear(self): + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + doc = Document("v1", self.vec1([0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + await self.vector_collection.clear() + doc = await self.vector_collection.get("k1") + self.assertIsNone(doc) + + async def test_optimize(self): + doc = Document("v1", self.vec1([0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + # it is hard to observe results of optimize, so just test that the invocation works + await self.vector_collection.optimize() + + async def test_optimize_with_name(self): + doc = Document("v1", self.vec1([0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + # it is hard to observe results of optimize, so just test that the 
invocation works + await self.vector_collection.optimize("vector") + + async def test_search_near_vector_include_all(self): + target_doc = self.doc1("v1", [0.3, 0.4, 0.5]) + await self.vector_collection.put_all( + { + "k1": self.doc1("v1", [0.1, 0.2, 0.3]), + "k2": self.doc1("v1", [0.2, 0.3, 0.4]), + "k3": target_doc, + } + ) + result = await self.vector_collection.search_near_vector( + self.vec1([0.2, 0.2, 0.3]), limit=1, include_vectors=True, include_value=True + ) + self.assertEqual(1, len(result)) + self.assert_document_equal(target_doc, result[0]) + self.assertAlmostEqual(0.9973459243774414, result[0].score) + + async def test_search_near_vector_include_none(self): + target_doc = self.doc1("v1", [0.3, 0.4, 0.5]) + await self.vector_collection.put_all( + { + "k1": self.doc1("v1", [0.1, 0.2, 0.3]), + "k2": self.doc1("v1", [0.2, 0.3, 0.4]), + "k3": target_doc, + } + ) + result = await self.vector_collection.search_near_vector( + self.vec1([0.2, 0.2, 0.3]), limit=1, include_vectors=False, include_value=False + ) + self.assertEqual(1, len(result)) + result1 = result[0] + self.assertAlmostEqual(0.9973459243774414, result1.score) + self.assertIsNone(result1.value) + self.assertIsNone(result1.vectors) + + async def test_search_near_vector_hint(self): + # not empty collection is needed for search to do something + doc = Document("v1", self.vec1([0.1, 0.2, 0.3])) + await self.vector_collection.set("k1", doc) + # trigger validation error to check if hint was sent + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.vector_collection.search_near_vector( + self.vec1([0.2, 0.2, 0.3]), + limit=1, + include_vectors=False, + include_value=False, + hints={"partitionLimit": "-1"}, + ) + + async def test_size(self): + self.assertEqual(await self.vector_collection.size(), 0) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await self.vector_collection.put("k1", doc) + self.assertEqual(await self.vector_collection.size(), 1) + await 
self.vector_collection.clear() + self.assertEqual(await self.vector_collection.size(), 0) + + async def test_backup_count_valid_values_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=2, async_backup_count=2 + ) + await self.client.get_vector_collection(name) + + async def test_backup_count_max_value_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=6 + ) + await self.client.get_vector_collection(name) + + async def test_backup_count_min_value_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=0 + ) + await self.client.get_vector_collection(name) + + async def test_backup_count_more_than_max_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + # check that the parameter is used by ensuring that it is validated on server side + # there is no simple way to check number of backups + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + backup_count=7, + async_backup_count=0, + ) + + async def test_backup_count_less_than_min_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], backup_count=-1 + ) + + async def test_async_backup_count_max_value_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + 
await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + backup_count=0, + async_backup_count=6, + ) + await self.client.get_vector_collection(name) + + async def test_async_backup_count_min_value_pass(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], async_backup_count=0 + ) + await self.client.get_vector_collection(name) + + async def test_async_backup_count_more_than_max_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + # check that the parameter is used by ensuring that it is validated on server side + # there is no simple way to check number of backups + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + backup_count=0, + async_backup_count=7, + ) + + async def test_async_backup_count_less_than_min_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + async_backup_count=-1, + ) + + async def test_sync_and_async_backup_count_more_than_max_value_fail(self): + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + with self.assertRaises(hazelcast.errors.IllegalArgumentError): + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + backup_count=4, + async_backup_count=3, + ) + + async def test_merge_policy_can_be_sent(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + 
merge_policy="DiscardMergePolicy", + merge_batch_size=1000, + ) + # validation happens when the collection proxy is created + await self.client.get_vector_collection(name) + + async def test_wrong_merge_policy_fails(self): + skip_if_client_version_older_than(self, "6.0") + skip_if_server_version_older_than(self, self.client, "6.0") + name = random_string() + with self.assertRaises(hazelcast.errors.InvalidConfigurationError): + await self.client.create_vector_collection_config( + name, [IndexConfig("vector", Metric.COSINE, 3)], merge_policy="non-existent" + ) + # validation happens when the collection proxy is created + await self.client.get_vector_collection(name) + + async def test_split_brain_name_can_be_sent(self): + skip_if_client_version_older_than(self, "6.0") + name = random_string() + await self.client.create_vector_collection_config( + name, + [IndexConfig("vector", Metric.COSINE, 3)], + # wrong name will be ignored + split_brain_protection_name="non-existent", + ) + col = await self.client.get_vector_collection(name) + doc = Document("v1", Vector("vector", Type.DENSE, [0.1, 0.2, 0.3])) + await col.set("k1", doc) + + def assert_document_equal(self, doc1, doc2) -> None: + self.assertEqual(doc1.value, doc2.value) + self.assertEqual(len(doc1.vectors), len(doc2.vectors)) + # currently there's a bug on the server-side about vector names. 
async def test_start_client_before_member(self):
    """Check that a client started before any member connects once one appears.

    A background task starts the first member about one second after the
    client begins connecting; ``create_client`` is then expected to succeed
    within its 5 second ``cluster_connect_timeout``.
    """

    async def start_member_later():
        await asyncio.sleep(1.0)
        await asyncio.to_thread(self.cluster.start_member)

    # Keep a reference to the task. The event loop only holds weak
    # references to tasks, so a fire-and-forget task may be garbage
    # collected before it finishes, and its exception would go unnoticed.
    member_starter = asyncio.create_task(start_member_later())
    try:
        await self.create_client(
            {
                "cluster_name": self.cluster.id,
                "cluster_connect_timeout": 5.0,
            }
        )
    finally:
        # Awaiting the task reports a failed member start through this test
        # and guarantees the task does not outlive it.
        await member_starter
async def asyncTearDown(self):
    """Shut down a client left behind by the test, then release ``self.rc``.

    ``self.rc.exit()`` runs even when the client shutdown raises, so a
    failing shutdown cannot leave the controller connection open.
    """
    try:
        if self.client:
            # If the test is failed, and we couldn't shutdown
            # the client, try to shutdown here to make sure that
            # we are not going to affect other tests. If the client
            # is already shutdown, then this is basically no-op.
            await self.client.shutdown()
    finally:
        self.rc.exit()
async def _verify_listeners_after_client_disconnected(self, member_address, client_address):
    """Check that an entry listener still fires after the only member is replaced.

    The member is shut down, the test waits twice the configured heartbeat
    interval, a new member is started, and the listener registered before
    the shutdown must eventually observe an added entry.

    Args:
        member_address: Address passed to ``_create_cluster_config``.
        client_address: Address the client is told to connect to.
    """
    heartbeat_seconds = 2
    # The cluster and member calls are synchronous. They run in a worker
    # thread, as in _verify_connection_count_after_reconnect, so the event
    # loop (and with it the client) keeps running while they execute.
    cluster = await asyncio.to_thread(
        self.create_cluster,
        self.rc,
        self._create_cluster_config(member_address, heartbeat_seconds),
    )
    member = await asyncio.to_thread(cluster.start_member)
    client = await HazelcastClient.create_and_start(
        cluster_name=cluster.id,
        cluster_members=[client_address],
        cluster_connect_timeout=sys.maxsize,
    )
    # Stored so that asyncTearDown can shut the client down if the test fails
    self.client = client
    test_map = await client.get_map("test")
    event_count = AtomicInteger()
    await test_map.add_entry_listener(
        added_func=lambda _: event_count.get_and_increment(), include_value=False
    )
    await self.assertTrueEventually(
        lambda: self.assertEqual(1, len(client._connection_manager.active_connections))
    )
    await asyncio.to_thread(member.shutdown)
    await asyncio.sleep(2 * heartbeat_seconds)
    await asyncio.to_thread(cluster.start_member)

    async def assertion():
        # Remove first so that the following put is an "added" event
        await test_map.remove(1)
        await test_map.put(1, 2)
        self.assertNotEqual(0, event_count.get())

    await self.assertTrueEventually(assertion)

    await client.shutdown()
    await asyncio.to_thread(self.rc.terminateCluster, cluster.id)
class ShutdownTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase):
    """Client shutdown behaviour while invocations are pending or members are gone."""

    rc = None

    def setUp(self):
        self.rc = self.create_rc()
        self.cluster = self.create_cluster(self.rc)

    async def asyncTearDown(self):
        await self.shutdown_all_clients()
        self.rc.terminateCluster(self.cluster.id)
        self.rc.exit()

    async def test_shutdown_not_hang_on_member_closed(self):
        """Invocations must fail with HazelcastClientNotActiveError once the member is gone."""
        member = self.cluster.start_member()
        client = await self.create_client(
            {
                "cluster_name": self.cluster.id,
                "cluster_connect_timeout": 5.0,
            }
        )
        my_map = await client.get_map("test")
        await my_map.put("key", "value")
        member.shutdown()
        with self.assertRaises(HazelcastClientNotActiveError):
            # Keep invoking until the client gives up on the cluster
            while True:
                await my_map.get("key")

    async def test_invocations_finalised_when_client_shutdowns(self):
        """Shut the client down while invocations are in flight; every task must finish."""
        self.cluster.start_member()
        client = await self.create_client(
            {
                "cluster_name": self.cluster.id,
            }
        )
        m = await client.get_map("test")
        await m.put("key", "value")

        async def run():
            for _ in range(1000):
                try:
                    await m.get("key")
                except Exception:
                    # Invocations are expected to fail once the client is shut
                    # down; the test only requires that they complete.
                    pass

        async with asyncio.TaskGroup() as tg:
            for _ in range(10):
                tg.create_task(run())

            # Leaving the TaskGroup block waits for all tasks, so the shutdown
            # has to happen inside it to overlap with running invocations.
            # Yield once first so that every task has issued its first request.
            await asyncio.sleep(0)
            await client.shutdown()
async def test_map_smart_listener_local_only(self):
    """A single put must be delivered to the entry listener exactly once."""
    # Named test_map rather than map to avoid shadowing the builtin
    test_map = await self.client.get_map(random_string())
    await test_map.add_entry_listener(added_func=self.collector)
    await test_map.put("key", "value")
    await self.assert_event_received_once()
async def test_statistics_period(self):
    """Statistics read after one more period must differ from the earlier ones."""
    client = await HazelcastClient.create_and_start(
        cluster_name=self.cluster.id,
        cluster_connect_timeout=30.0,
        statistics_enabled=True,
        statistics_period=self.STATS_PERIOD,
    )
    # The cluster is shared by all tests of the class; shut the client down
    # even when an assertion fails so it does not linger into later tests.
    try:
        client_uuid = client._connection_manager.client_uuid
        await asyncio.sleep(2 * self.STATS_PERIOD)
        response1 = await self.wait_for_statistics_collection(client_uuid)
        await asyncio.sleep(2 * self.STATS_PERIOD)
        response2 = await self.wait_for_statistics_collection(client_uuid)
        self.assertNotEqual(response1, response2)
    finally:
        await client.shutdown()
result.count("clientName=" + client.name)) + self.assertEqual(1, result.count("lastStatisticsCollectionTime=")) + self.assertEqual(1, result.count("enterprise=false")) + self.assertEqual(1, result.count("clientType=" + CLIENT_TYPE)) + self.assertEqual(1, result.count("clientVersion=" + __version__)) + self.assertEqual(1, result.count("clusterConnectionTimestamp=")) + self.assertEqual(1, result.count("clientAddress=" + local_address)) + self.assertEqual(1, result.count("nc." + map_name + ".creationTime")) + self.assertEqual(1, result.count("nc." + map_name + ".evictions")) + self.assertEqual(1, result.count("nc." + map_name + ".hits")) + self.assertEqual(1, result.count("nc." + map_name + ".misses")) + self.assertEqual(1, result.count("nc." + map_name + ".ownedEntryCount")) + self.assertEqual(1, result.count("nc." + map_name + ".expirations")) + self.assertEqual(1, result.count("nc." + map_name + ".invalidations")) + self.assertEqual(1, result.count("nc." + map_name + ".invalidationRequests")) + self.assertEqual(1, result.count("nc." + map_name + ".ownedEntryMemoryCost")) + # Check OS and runtime statistics. We cannot know what kind of statistics will be available + # in different platforms. 
async def test_near_cache_stats(self):
    """Near cache counters reported in the statistics must follow map operations."""
    map_name = random_string()
    client = await HazelcastClient.create_and_start(
        cluster_name=self.cluster.id,
        cluster_connect_timeout=30.0,
        statistics_enabled=True,
        statistics_period=self.STATS_PERIOD,
        near_caches={
            map_name: {},
        },
    )

    async def assert_near_cache_stats(expected):
        # Wait for a fresh collection, then require each "nc.<map>.<stat>=<value>"
        # entry to occur exactly once in the reported statistics.
        await asyncio.sleep(2 * self.STATS_PERIOD)
        response = await self.wait_for_statistics_collection(client_uuid)
        result = response.result.decode("utf-8")
        for stat, value in expected.items():
            self.assertEqual(1, result.count("nc." + map_name + "." + stat + "=" + str(value)))

    # The cluster is shared by all tests of the class; shut the client down
    # even when an assertion fails so it does not linger into later tests.
    try:
        client_uuid = client._connection_manager.client_uuid
        test_map = await client.get_map(map_name)
        await assert_near_cache_stats(
            {
                "evictions": 0,
                "hits": 0,
                "misses": 0,
                "ownedEntryCount": 0,
                "expirations": 0,
                "invalidations": 0,
                "invalidationRequests": 0,
            }
        )
        await test_map.put(1, 2)  # invalidation request
        await test_map.get(1)  # cache miss
        await test_map.get(1)  # cache hit
        await test_map.put(1, 3)  # invalidation + invalidation request
        await test_map.get(1)  # cache miss
        await assert_near_cache_stats(
            {
                "evictions": 0,
                "hits": 1,
                "misses": 2,
                "ownedEntryCount": 1,
                "expirations": 0,
                "invalidations": 1,
                "invalidationRequests": 2,
            }
        )
    finally:
        await client.shutdown()
def get_runtime_and_system_metrics(self, client):
    """Return the OS and runtime statistic names expected for ``client``.

    Args:
        client: Client whose configuration is used to build a ``Statistics``
            instance.

    Returns:
        The result of ``_get_os_and_runtime_stats()`` when that method
        exists, otherwise an iterator chaining the registered system and
        process gauges.
    """
    s = Statistics(client, client._config, None, None, None, None)
    try:
        # Compatibility for <4.2.1 clients
        return s._get_os_and_runtime_stats()
    except AttributeError:
        # The legacy method is absent. Only this case triggers the fallback,
        # so KeyboardInterrupt/SystemExit and genuine failures still propagate.
        return itertools.chain(s._registered_system_gauges, s._registered_process_gauges)