From 69191acc057e1df747db648d8dea9241fd5ddf85 Mon Sep 17 00:00:00 2001
From: Dmitry Kropachev
Date: Sat, 4 Jan 2025 08:28:23 -0400
Subject: [PATCH 1/2] test_tablets: stop using unittest

unittest interferes with pytest test ordering, which the next commit
relies on.
---
 tests/integration/experiments/test_tablets.py | 26 ++++++++-----------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/tests/integration/experiments/test_tablets.py b/tests/integration/experiments/test_tablets.py
index 98e65c538..7ba3fcb48 100644
--- a/tests/integration/experiments/test_tablets.py
+++ b/tests/integration/experiments/test_tablets.py
@@ -1,7 +1,3 @@
-import time
-import unittest
-import pytest
-import os
 from cassandra.cluster import Cluster
 from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy
 
@@ -11,7 +7,7 @@ def setup_module():
     use_cluster('tablets', [3], start=True)
 
 
-class TestTabletsIntegration(unittest.TestCase):
+class TestTabletsIntegration:
     @classmethod
     def setup_class(cls):
         cls.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
@@ -33,8 +29,8 @@ def verify_same_host_in_tracing(self, results):
             LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
             host_set.add(event.source)
 
-        self.assertEqual(len(host_set), 1)
-        self.assertIn('locally', "\n".join([event.description for event in events]))
+        assert len(host_set) == 1
+        assert 'locally' in "\n".join([event.description for event in events])
 
         trace_id = results.response_future.get_query_trace_ids()[0]
         traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
@@ -44,8 +40,8 @@ def verify_same_host_in_tracing(self, results):
             LOGGER.info("TRACE EVENT: %s %s", event.source, event.activity)
             host_set.add(event.source)
 
-        self.assertEqual(len(host_set), 1)
-        self.assertIn('locally', "\n".join([event.activity for event in events]))
+        assert len(host_set) == 1
+        assert 'locally' in "\n".join([event.activity for event in events])
 
     def verify_same_shard_in_tracing(self, results):
         traces = results.get_query_trace()
@@ -55,8 +51,8 @@ def verify_same_shard_in_tracing(self, results):
             LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
             shard_set.add(event.thread_name)
 
-        self.assertEqual(len(shard_set), 1)
-        self.assertIn('locally', "\n".join([event.description for event in events]))
+        assert len(shard_set) == 1
+        assert 'locally' in "\n".join([event.description for event in events])
 
         trace_id = results.response_future.get_query_trace_ids()[0]
         traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
@@ -66,8 +62,8 @@ def verify_same_shard_in_tracing(self, results):
             LOGGER.info("TRACE EVENT: %s %s", event.thread, event.activity)
             shard_set.add(event.thread)
 
-        self.assertEqual(len(shard_set), 1)
-        self.assertIn('locally', "\n".join([event.activity for event in events]))
+        assert len(shard_set) == 1
+        assert 'locally' in "\n".join([event.activity for event in events])
 
     def create_ks_and_cf(self):
         self.session.execute(
@@ -110,7 +106,7 @@ def query_data_shard_select(self, session, verify_in_tracing=True):
 
         bound = prepared.bind([(2)])
         results = session.execute(bound, trace=True)
-        self.assertEqual(results, [(2, 2, 0)])
+        assert results == [(2, 2, 0)]
         if verify_in_tracing:
             self.verify_same_shard_in_tracing(results)
 
@@ -122,7 +118,7 @@ def query_data_host_select(self, session, verify_in_tracing=True):
 
         bound = prepared.bind([(2)])
         results = session.execute(bound, trace=True)
-        self.assertEqual(results, [(2, 2, 0)])
+        assert results == [(2, 2, 0)]
         if verify_in_tracing:
             self.verify_same_host_in_tracing(results)
 

From be0bcf2a31ecd6bb61bd70e68310955bda8f3f54 Mon Sep 17 00:00:00 2001
From: Dmitry Kropachev
Date: Fri, 3 Jan 2025 20:18:12 -0400
Subject: [PATCH 2/2] Invalidate tablets when table or keyspace is deleted

Delete tablets for a table or keyspace when it is dropped. When a host
is removed from the cluster, delete all tablets that have this host as
a replica. Make sure this also works while the control connection is
reconnecting.
---
 cassandra/metadata.py                         |  13 ++-
 cassandra/tablets.py                          |  35 ++++++
 tests/integration/experiments/test_tablets.py | 105 ++++++++++++++++--
 3 files changed, 140 insertions(+), 13 deletions(-)

diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index 18d424978..30bcf8165 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -26,6 +26,7 @@
 import struct
 import random
 import itertools
+from typing import Optional
 
 murmur3 = None
 try:
@@ -168,10 +169,13 @@ def _rebuild_all(self, parser):
         current_keyspaces = set()
         for keyspace_meta in parser.get_all_keyspaces():
             current_keyspaces.add(keyspace_meta.name)
-            old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None)
+            old_keyspace_meta: Optional[KeyspaceMetadata] = self.keyspaces.get(keyspace_meta.name, None)
             self.keyspaces[keyspace_meta.name] = keyspace_meta
             if old_keyspace_meta:
                 self._keyspace_updated(keyspace_meta.name)
+                for table_name in old_keyspace_meta.tables.keys():
+                    if table_name not in keyspace_meta.tables:
+                        self._table_removed(keyspace_meta.name, table_name)
             else:
                 self._keyspace_added(keyspace_meta.name)
 
@@ -265,6 +269,9 @@ def _drop_aggregate(self, keyspace, aggregate):
         except KeyError:
             pass
 
+    def _table_removed(self, keyspace, table):
+        self._tablets.drop_tablets(keyspace, table)
+
     def _keyspace_added(self, ksname):
         if self.token_map:
             self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
@@ -272,10 +279,12 @@ def _keyspace_added(self, ksname):
     def _keyspace_updated(self, ksname):
         if self.token_map:
             self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
+        self._tablets.drop_tablets(ksname)
 
     def _keyspace_removed(self, ksname):
         if self.token_map:
             self.token_map.remove_keyspace(ksname)
+        self._tablets.drop_tablets(ksname)
 
     def rebuild_token_map(self, partitioner, token_map):
         """
@@ -340,11 +349,13 @@ def add_or_return_host(self, host):
             return host, True
 
     def remove_host(self, host):
+        self._tablets.drop_tablets_by_host_id(host.host_id)
         with self._hosts_lock:
             self._host_id_by_endpoint.pop(host.endpoint, False)
             return bool(self._hosts.pop(host.host_id, False))
 
     def remove_host_by_host_id(self, host_id, endpoint=None):
+        self._tablets.drop_tablets_by_host_id(host_id)
         with self._hosts_lock:
             if endpoint and self._host_id_by_endpoint[endpoint] == host_id:
                 self._host_id_by_endpoint.pop(endpoint, False)
diff --git a/cassandra/tablets.py b/cassandra/tablets.py
index 61394eace..457ee93ca 100644
--- a/cassandra/tablets.py
+++ b/cassandra/tablets.py
@@ -1,4 +1,6 @@
 from threading import Lock
+from typing import Optional
+from uuid import UUID
 
 
 class Tablet(object):
@@ -32,6 +34,12 @@ def from_row(first_token, last_token, replicas):
             return tablet
         return None
 
+    def replica_contains_host_id(self, uuid: UUID) -> bool:
+        for replica in self.replicas:
+            if replica[0] == uuid:
+                return True
+        return False
+
 
 class Tablets(object):
     _lock = None
@@ -51,6 +59,33 @@ def get_tablet_for_key(self, keyspace, table, t):
             return tablet[id]
         return None
 
+    def drop_tablets(self, keyspace: str, table: Optional[str] = None):
+        with self._lock:
+            if table is not None:
+                self._tablets.pop((keyspace, table), None)
+                return
+
+            to_be_deleted = []
+            for key in self._tablets.keys():
+                if key[0] == keyspace:
+                    to_be_deleted.append(key)
+
+            for key in to_be_deleted:
+                del self._tablets[key]
+
+    def drop_tablets_by_host_id(self, host_id: Optional[UUID]):
+        if host_id is None:
+            return
+        with self._lock:
+            for key, tablets in self._tablets.items():
+                to_be_deleted = []
+                for tablet_id, tablet in enumerate(tablets):
+                    if tablet.replica_contains_host_id(host_id):
+                        to_be_deleted.append(tablet_id)
+
+                for tablet_id in reversed(to_be_deleted):
+                    tablets.pop(tablet_id)
+
     def add_tablet(self, keyspace, table, tablet):
         with self._lock:
             tablets_for_table = self._tablets.setdefault((keyspace, table), [])
diff --git a/tests/integration/experiments/test_tablets.py b/tests/integration/experiments/test_tablets.py
index 7ba3fcb48..79dd16660 100644
--- a/tests/integration/experiments/test_tablets.py
+++ b/tests/integration/experiments/test_tablets.py
@@ -1,11 +1,20 @@
+import time
+
+import pytest
+
 from cassandra.cluster import Cluster
 from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy
 
 from tests.integration import PROTOCOL_VERSION, use_cluster
 from tests.unit.test_host_connection_pool import LOGGER
 
+CCM_CLUSTER = None
+
 def setup_module():
-    use_cluster('tablets', [3], start=True)
+    global CCM_CLUSTER
+
+    CCM_CLUSTER = use_cluster('tablets', [3], start=True)
+
 
 class TestTabletsIntegration:
     @classmethod
@@ -14,14 +23,14 @@ def setup_class(cls):
                               load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
                               reconnection_policy=ConstantReconnectionPolicy(1))
         cls.session = cls.cluster.connect()
-        cls.create_ks_and_cf(cls)
+        cls.create_ks_and_cf(cls.session)
         cls.create_data(cls.session)
 
     @classmethod
     def teardown_class(cls):
         cls.cluster.shutdown()
 
-    def verify_same_host_in_tracing(self, results):
+    def verify_hosts_in_tracing(self, results, expected):
         traces = results.get_query_trace()
         events = traces.events
         host_set = set()
@@ -29,7 +38,7 @@ def verify_hosts_in_tracing(self, results, expected):
             LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
             host_set.add(event.source)
 
-        assert len(host_set) == 1
+        assert len(host_set) == expected
         assert 'locally' in "\n".join([event.description for event in events])
 
         trace_id = results.response_future.get_query_trace_ids()[0]
@@ -40,9 +49,13 @@ def verify_hosts_in_tracing(self, results, expected):
             LOGGER.info("TRACE EVENT: %s %s", event.source, event.activity)
             host_set.add(event.source)
 
-        assert len(host_set) == 1
+        assert len(host_set) == expected
         assert 'locally' in "\n".join([event.activity for event in events])
 
+    def get_tablet_record(self, query):
+        metadata = self.session.cluster.metadata
+        return metadata._tablets.get_tablet_for_key(query.keyspace, query.table, metadata.token_map.token_class.from_key(query.routing_key))
+
     def verify_same_shard_in_tracing(self, results):
         traces = results.get_query_trace()
         events = traces.events
@@ -65,24 +78,25 @@ def verify_same_shard_in_tracing(self, results):
         assert len(shard_set) == 1
         assert 'locally' in "\n".join([event.activity for event in events])
 
-    def create_ks_and_cf(self):
-        self.session.execute(
+    @classmethod
+    def create_ks_and_cf(cls, session):
+        session.execute(
             """
             DROP KEYSPACE IF EXISTS test1
             """
         )
-        self.session.execute(
+        session.execute(
             """
             CREATE KEYSPACE test1
             WITH replication = {
                 'class': 'NetworkTopologyStrategy',
-                'replication_factor': 1
+                'replication_factor': 2
             } AND tablets = {
                 'initial': 8
             }
             """)
 
-        self.session.execute(
+        session.execute(
             """
             CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
             """)
@@ -120,7 +134,7 @@ def query_data_host_select(self, session, verify_in_tracing=True):
         results = session.execute(bound, trace=True)
         assert results == [(2, 2, 0)]
         if verify_in_tracing:
-            self.verify_same_host_in_tracing(results)
+            self.verify_hosts_in_tracing(results, 1)
 
     def query_data_shard_insert(self, session, verify_in_tracing=True):
         prepared = session.prepare(
@@ -142,7 +156,7 @@ def query_data_host_insert(self, session, verify_in_tracing=True):
         bound = prepared.bind([(52), (1), (2)])
         results = session.execute(bound, trace=True)
         if verify_in_tracing:
-            self.verify_same_host_in_tracing(results)
+            self.verify_hosts_in_tracing(results, 2)
 
     def test_tablets(self):
         self.query_data_host_select(self.session)
@@ -151,3 +165,70 @@ def test_tablets(self):
     def test_tablets_shard_awareness(self):
         self.query_data_shard_select(self.session)
         self.query_data_shard_insert(self.session)
+
+    def test_tablets_invalidation_drop_ks_while_reconnecting(self):
+        def recreate_while_reconnecting(_):
+            # Kill control connection
+            conn = self.session.cluster.control_connection._connection
+            self.session.cluster.control_connection._connection = None
+            conn.close()
+
+            # Drop and recreate ks and table to trigger tablets invalidation
+            self.create_ks_and_cf(self.cluster.connect())
+
+            # Start control connection
+            self.session.cluster.control_connection._reconnect()
+
+        self.run_tablets_invalidation_test(recreate_while_reconnecting)
+
+    def test_tablets_invalidation_drop_ks(self):
+        def drop_ks(_):
+            # Drop and recreate ks and table to trigger tablets invalidation
+            self.create_ks_and_cf(self.cluster.connect())
+            time.sleep(3)
+
+        self.run_tablets_invalidation_test(drop_ks)
+
+    @pytest.mark.last
+    def test_tablets_invalidation_decommission_non_cc_node(self):
+        def decommission_non_cc_node(rec):
+            # Decommission a node that owns a tablet replica, skipping the CC node
+            for node in CCM_CLUSTER.nodes.values():
+                if self.cluster.control_connection._connection.endpoint.address == node.network_interfaces["storage"][0]:
+                    # Ignore node that control connection is connected to
+                    continue
+                for replica in rec.replicas:
+                    if str(replica[0]) == str(node.node_hostid):
+                        node.decommission()
+                        break
+                else:
+                    continue
+                break
+            else:
+                assert False, "failed to find node to decommission"
+            time.sleep(10)
+
+        self.run_tablets_invalidation_test(decommission_non_cc_node)
+
+
+    def run_tablets_invalidation_test(self, invalidate):
+        # Make sure the driver holds tablet info by landing the query
+        # on every host until one of them populates the tablet cache
+        bound = self.session.prepare(
+            """
+            SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
+            """).bind([(2)])
+
+        rec = None
+        for host in self.cluster.metadata.all_hosts():
+            self.session.execute(bound, host=host)
+            rec = self.get_tablet_record(bound)
+            if rec is not None:
+                break
+
+        assert rec is not None, "failed to find tablet record"
+
+        invalidate(rec)
+
+        # Check if tablets information was purged
+        assert self.get_tablet_record(bound) is None, "tablet was not deleted, invalidation did not work"
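
Note (not part of the series): a minimal sketch of how the new invalidation
helpers in cassandra/tablets.py are meant to behave. It uses only the
Tablet.from_row/Tablets.add_tablet API visible in the hunks above and the
(host_id, shard) replica tuples implied by replica_contains_host_id();
Tablets({}) is assumed to be a valid way to build an empty registry (as
Metadata appears to do), and the token values are arbitrary:

    from uuid import uuid4

    from cassandra.tablets import Tablet, Tablets

    host_a, host_b = uuid4(), uuid4()

    # The registry is keyed by (keyspace, table) tuples.
    tablets = Tablets({})
    tablets.add_tablet("ks1", "t1", Tablet.from_row(-100, 0, [(host_a, 0)]))
    tablets.add_tablet("ks1", "t2", Tablet.from_row(-100, 0, [(host_b, 0)]))
    tablets.add_tablet("ks2", "t1", Tablet.from_row(-100, 0, [(host_a, 0)]))

    # Table drop: only the (ks1, t1) entry is removed.
    tablets.drop_tablets("ks1", "t1")

    # Keyspace drop: every remaining ks1 entry is removed.
    tablets.drop_tablets("ks1")

    # Host removal: tablets replicated on host_a are dropped from their
    # per-table lists (here the one under (ks2, t1)).
    tablets.drop_tablets_by_host_id(host_a)

    # A missing host_id is a no-op by design.
    tablets.drop_tablets_by_host_id(None)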