diff --git a/tests/integration/standard/test_scylla_cloud.py b/tests/integration/standard/test_scylla_cloud.py index d1a22f882..02302e9be 100644 --- a/tests/integration/standard/test_scylla_cloud.py +++ b/tests/integration/standard/test_scylla_cloud.py @@ -1,15 +1,24 @@ +import json import logging import os.path from unittest import TestCase from ccmlib.utils.ssl_utils import generate_ssl_stores -from ccmlib.utils.sni_proxy import refresh_certs, get_cluster_info, start_sni_proxy, create_cloud_config +from ccmlib.utils.sni_proxy import refresh_certs, start_sni_proxy, create_cloud_config, NodeInfo -from tests.integration import use_cluster +from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy, ConstantReconnectionPolicy +from tests.integration import use_cluster, PROTOCOL_VERSION from cassandra.cluster import Cluster, TwistedConnection -from cassandra.io.libevreactor import LibevConnection -supported_connection_classes = [LibevConnection, TwistedConnection] +supported_connection_classes = [TwistedConnection] + +try: + from cassandra.io.libevreactor import LibevConnection + supported_connection_classes += [LibevConnection] +except ImportError: + pass + + try: from cassandra.io.asyncorereactor import AsyncoreConnection supported_connection_classes += [AsyncoreConnection] @@ -22,6 +31,32 @@ # need to run them with specific configuration like `gevent.monkey.patch_all()` or under async functions # unsupported_connection_classes = [GeventConnection, AsyncioConnection, EventletConnection] +LOGGER = logging.getLogger(__name__) + + +def get_cluster_info(cluster, port=9142): + session = Cluster( + contact_points=list(map(lambda node: node.address(), cluster.nodelist())), protocol_version=PROTOCOL_VERSION, + load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()), + reconnection_policy=ConstantReconnectionPolicy(5) + ).connect() + + nodes_info = [] + + for row in session.execute('select host_id, broadcast_address, data_center from system.local'): + if row[0] and row[1]: + nodes_info.append(NodeInfo(address=row[1], + port=port, + host_id=row[0], + data_center=row[2])) + + for row in session.execute('select host_id, broadcast_address, data_center from system.local'): + nodes_info.append(NodeInfo(address=row[1], + port=port, + host_id=row[0], + data_center=row[2])) + + return nodes_info class ScyllaCloudConfigTests(TestCase):