Merge pull request #399 from scylladb/dk/invalidate-tablets
Invalidate tablets when table or keyspace is deleted
dkropachev authored Jan 5, 2025
2 parents d62eb38 + be0bcf2 commit fdfc7df
Showing 3 changed files with 147 additions and 24 deletions.
13 changes: 12 additions & 1 deletion cassandra/metadata.py
@@ -26,6 +26,7 @@
import struct
import random
import itertools
from typing import Optional

murmur3 = None
try:
@@ -168,10 +169,13 @@ def _rebuild_all(self, parser):
current_keyspaces = set()
for keyspace_meta in parser.get_all_keyspaces():
current_keyspaces.add(keyspace_meta.name)
old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None)
old_keyspace_meta: Optional[KeyspaceMetadata] = self.keyspaces.get(keyspace_meta.name, None)
self.keyspaces[keyspace_meta.name] = keyspace_meta
if old_keyspace_meta:
self._keyspace_updated(keyspace_meta.name)
for table_name in old_keyspace_meta.tables.keys():
if table_name not in keyspace_meta.tables:
self._table_removed(keyspace_meta.name, table_name)
else:
self._keyspace_added(keyspace_meta.name)
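The table-removal hunk above works by diffing: any table present in the cached keyspace metadata but absent from the freshly parsed one is treated as dropped and routed through `_table_removed`. A minimal standalone sketch of that set logic, with made-up table names:

```python
# Sketch of the table-diff step in _rebuild_all, using hypothetical names.
old_tables = {"table1", "table2"}   # old_keyspace_meta.tables.keys()
new_tables = {"table1"}             # keyspace_meta.tables
for table_name in old_tables:
    if table_name not in new_tables:
        # the real code calls self._table_removed(keyspace_meta.name, table_name)
        print("dropped:", table_name)   # -> dropped: table2
```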

@@ -265,17 +269,22 @@ def _drop_aggregate(self, keyspace, aggregate):
except KeyError:
pass

def _table_removed(self, keyspace, table):
self._tablets.drop_tablets(keyspace, table)

def _keyspace_added(self, ksname):
if self.token_map:
self.token_map.rebuild_keyspace(ksname, build_if_absent=False)

def _keyspace_updated(self, ksname):
if self.token_map:
self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
self._tablets.drop_tablets(ksname)

def _keyspace_removed(self, ksname):
if self.token_map:
self.token_map.remove_keyspace(ksname)
self._tablets.drop_tablets(ksname)

def rebuild_token_map(self, partitioner, token_map):
"""
@@ -340,11 +349,13 @@ def add_or_return_host(self, host):
return host, True

def remove_host(self, host):
self._tablets.drop_tablets_by_host_id(host.host_id)
with self._hosts_lock:
self._host_id_by_endpoint.pop(host.endpoint, False)
return bool(self._hosts.pop(host.host_id, False))

def remove_host_by_host_id(self, host_id, endpoint=None):
self._tablets.drop_tablets_by_host_id(host_id)
with self._hosts_lock:
if endpoint and self._host_id_by_endpoint[endpoint] == host_id:
self._host_id_by_endpoint.pop(endpoint, False)
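Taken together, the hooks in this file give three invalidation paths into the tablet cache: a dropped table, a changed or dropped keyspace, and a removed host. A condensed sketch of who calls what; the `Tablets({})` construction and the wiring are illustrative assumptions, not the exact call sites:

```python
from uuid import uuid4
from cassandra.tablets import Tablets

tablets = Tablets({})                     # Metadata holds one cache instance
tablets.drop_tablets("ks1", "t1")         # _table_removed: one table dropped
tablets.drop_tablets("ks1")               # _keyspace_updated / _keyspace_removed
tablets.drop_tablets_by_host_id(uuid4())  # remove_host / remove_host_by_host_id
```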
35 changes: 35 additions & 0 deletions cassandra/tablets.py
@@ -1,4 +1,6 @@
from threading import Lock
from typing import Optional
from uuid import UUID


class Tablet(object):
@@ -32,6 +34,12 @@ def from_row(first_token, last_token, replicas):
return tablet
return None

def replica_contains_host_id(self, uuid: UUID) -> bool:
for replica in self.replicas:
if replica[0] == uuid:
return True
return False
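Each entry in `replicas` is a `(host_id, shard)` pair, which is why the helper compares only `replica[0]`. A quick illustrative check; the constructor arguments (`first_token`, `last_token`, `replicas`) are assumed to match `Tablet.__init__` in this module, and the tokens and IDs are made up:

```python
from uuid import uuid4
from cassandra.tablets import Tablet

host_a, host_b = uuid4(), uuid4()
tablet = Tablet(first_token=-100, last_token=100,
                replicas=[(host_a, 0), (host_b, 1)])  # (host_id, shard) pairs
assert tablet.replica_contains_host_id(host_a)
assert not tablet.replica_contains_host_id(uuid4())
```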


class Tablets(object):
_lock = None
@@ -51,6 +59,33 @@ def get_tablet_for_key(self, keyspace, table, t):
return tablet[id]
return None

def drop_tablets(self, keyspace: str, table: Optional[str] = None):
with self._lock:
if table is not None:
self._tablets.pop((keyspace, table), None)
return

to_be_deleted = []
for key in self._tablets.keys():
if key[0] == keyspace:
to_be_deleted.append(key)

for key in to_be_deleted:
del self._tablets[key]
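Usage sketch for `drop_tablets`: passing both keys drops a single `(keyspace, table)` entry, while omitting the table sweeps every entry for that keyspace. Names below are hypothetical, and the `Tablets({})` construction is an assumption:

```python
from cassandra.tablets import Tablets

tablets = Tablets({})
tablets.drop_tablets("test1", "table1")  # removes only the ("test1", "table1") entry
tablets.drop_tablets("test1")            # removes every ("test1", *) entry
```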

def drop_tablets_by_host_id(self, host_id: Optional[UUID]):
if host_id is None:
return
with self._lock:
for key, tablets in self._tablets.items():
to_be_deleted = []
for tablet_id, tablet in enumerate(tablets):
if tablet.replica_contains_host_id(host_id):
to_be_deleted.append(tablet_id)

for tablet_id in reversed(to_be_deleted):
tablets.pop(tablet_id)
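Note that `drop_tablets_by_host_id` pops the collected indices in reverse order, so removing an element never shifts an index that is still pending. The pattern in isolation:

```python
items = ["a", "b", "c", "d"]
to_be_deleted = [1, 3]             # ascending indices to remove
for i in reversed(to_be_deleted):
    items.pop(i)                   # pop 3 first, so index 1 still points at "b"
assert items == ["a", "c"]
```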

def add_tablet(self, keyspace, table, tablet):
with self._lock:
tablets_for_table = self._tablets.setdefault((keyspace, table), [])
123 changes: 100 additions & 23 deletions tests/integration/experiments/test_tablets.py
@@ -1,40 +1,45 @@
import time
import unittest

import pytest
import os

from cassandra.cluster import Cluster
from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy

from tests.integration import PROTOCOL_VERSION, use_cluster
from tests.unit.test_host_connection_pool import LOGGER

CCM_CLUSTER = None

def setup_module():
use_cluster('tablets', [3], start=True)
global CCM_CLUSTER

CCM_CLUSTER = use_cluster('tablets', [3], start=True)

class TestTabletsIntegration(unittest.TestCase):

class TestTabletsIntegration:
@classmethod
def setup_class(cls):
cls.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
reconnection_policy=ConstantReconnectionPolicy(1))
cls.session = cls.cluster.connect()
cls.create_ks_and_cf(cls)
cls.create_ks_and_cf(cls.session)
cls.create_data(cls.session)

@classmethod
def teardown_class(cls):
cls.cluster.shutdown()

def verify_same_host_in_tracing(self, results):
def verify_hosts_in_tracing(self, results, expected):
traces = results.get_query_trace()
events = traces.events
host_set = set()
for event in events:
LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
host_set.add(event.source)

self.assertEqual(len(host_set), 1)
self.assertIn('locally', "\n".join([event.description for event in events]))
assert len(host_set) == expected
assert 'locally' in "\n".join([event.description for event in events])

trace_id = results.response_future.get_query_trace_ids()[0]
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
@@ -44,8 +49,12 @@ def verify_same_host_in_tracing(self, results):
LOGGER.info("TRACE EVENT: %s %s", event.source, event.activity)
host_set.add(event.source)

self.assertEqual(len(host_set), 1)
self.assertIn('locally', "\n".join([event.activity for event in events]))
assert len(host_set) == expected
assert 'locally' in "\n".join([event.activity for event in events])

def get_tablet_record(self, query):
metadata = self.session.cluster.metadata
return metadata._tablets.get_tablet_for_key(query.keyspace, query.table, metadata.token_map.token_class.from_key(query.routing_key))
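`get_tablet_record` peeks into the driver's private tablet cache: it turns the statement's routing key into a token and asks the cache which tablet, if any, owns it. Roughly the following, given a connected `session` (assumed here) and the `test1.table1` schema created below; a hedged sketch, not part of the test:

```python
bound = session.prepare(
    "SELECT pk, ck, v FROM test1.table1 WHERE pk = ?").bind([2])
metadata = session.cluster.metadata
token = metadata.token_map.token_class.from_key(bound.routing_key)
tablet = metadata._tablets.get_tablet_for_key(bound.keyspace, bound.table, token)
# `tablet` stays None until the driver has learned this tablet's mapping.
```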

def verify_same_shard_in_tracing(self, results):
traces = results.get_query_trace()
@@ -55,8 +64,8 @@
LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
shard_set.add(event.thread_name)

self.assertEqual(len(shard_set), 1)
self.assertIn('locally', "\n".join([event.description for event in events]))
assert len(shard_set) == 1
assert 'locally' in "\n".join([event.description for event in events])

trace_id = results.response_future.get_query_trace_ids()[0]
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
@@ -66,27 +75,28 @@
LOGGER.info("TRACE EVENT: %s %s", event.thread, event.activity)
shard_set.add(event.thread)

self.assertEqual(len(shard_set), 1)
self.assertIn('locally', "\n".join([event.activity for event in events]))
assert len(shard_set) == 1
assert 'locally' in "\n".join([event.activity for event in events])

def create_ks_and_cf(self):
self.session.execute(
@classmethod
def create_ks_and_cf(cls, session):
session.execute(
"""
DROP KEYSPACE IF EXISTS test1
"""
)
self.session.execute(
session.execute(
"""
CREATE KEYSPACE test1
WITH replication = {
'class': 'NetworkTopologyStrategy',
'replication_factor': 1
'replication_factor': 2
} AND tablets = {
'initial': 8
}
""")

self.session.execute(
session.execute(
"""
CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
""")
@@ -110,7 +120,7 @@ def query_data_shard_select(self, session, verify_in_tracing=True):

bound = prepared.bind([(2)])
results = session.execute(bound, trace=True)
self.assertEqual(results, [(2, 2, 0)])
assert results == [(2, 2, 0)]
if verify_in_tracing:
self.verify_same_shard_in_tracing(results)

@@ -122,9 +132,9 @@ def query_data_host_select(self, session, verify_in_tracing=True):

bound = prepared.bind([(2)])
results = session.execute(bound, trace=True)
self.assertEqual(results, [(2, 2, 0)])
assert results == [(2, 2, 0)]
if verify_in_tracing:
self.verify_same_host_in_tracing(results)
self.verify_hosts_in_tracing(results, 1)

def query_data_shard_insert(self, session, verify_in_tracing=True):
prepared = session.prepare(
@@ -146,7 +156,7 @@ def query_data_host_insert(self, session, verify_in_tracing=True):
bound = prepared.bind([(52), (1), (2)])
results = session.execute(bound, trace=True)
if verify_in_tracing:
self.verify_same_host_in_tracing(results)
self.verify_hosts_in_tracing(results, 2)

def test_tablets(self):
self.query_data_host_select(self.session)
@@ -155,3 +165,70 @@ def test_tablets(self):
def test_tablets_shard_awareness(self):
self.query_data_shard_select(self.session)
self.query_data_shard_insert(self.session)

def test_tablets_invalidation_drop_ks_while_reconnecting(self):
def recreate_while_reconnecting(_):
# Kill control connection
conn = self.session.cluster.control_connection._connection
self.session.cluster.control_connection._connection = None
conn.close()

# Drop and recreate ks and table to trigger tablets invalidation
self.create_ks_and_cf(self.cluster.connect())

# Start control connection
self.session.cluster.control_connection._reconnect()

self.run_tablets_invalidation_test(recreate_while_reconnecting)

def test_tablets_invalidation_drop_ks(self):
def drop_ks(_):
# Drop and recreate ks and table to trigger tablets invalidation
self.create_ks_and_cf(self.cluster.connect())
time.sleep(3)

self.run_tablets_invalidation_test(drop_ks)

@pytest.mark.last
def test_tablets_invalidation_decommission_non_cc_node(self):
def decommission_non_cc_node(rec):
            # Decommission a node that holds a replica of the tablet to trigger tablets invalidation
for node in CCM_CLUSTER.nodes.values():
if self.cluster.control_connection._connection.endpoint.address == node.network_interfaces["storage"][0]:
                    # Skip the node that the control connection is connected to
continue
for replica in rec.replicas:
if str(replica[0]) == str(node.node_hostid):
node.decommission()
break
else:
continue
break
else:
assert False, "failed to find node to decommission"
time.sleep(10)

self.run_tablets_invalidation_test(decommission_non_cc_node)


def run_tablets_invalidation_test(self, invalidate):
        # Make sure the driver holds tablet info by landing a query
        # on a host that is not in the replica set
bound = self.session.prepare(
"""
SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
""").bind([(2)])

rec = None
for host in self.cluster.metadata.all_hosts():
self.session.execute(bound, host=host)
rec = self.get_tablet_record(bound)
if rec is not None:
break

assert rec is not None, "failed to find tablet record"

invalidate(rec)

# Check if tablets information was purged
assert self.get_tablet_record(bound) is None, "tablet was not deleted, invalidation did not work"
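All three invalidation tests funnel into this method. The first loop relies on ScyllaDB piggybacking tablet ownership onto responses when a request lands on a non-replica node, so querying every host guarantees at least one miss that populates the driver's cache. Condensed, the invariant under test is (using the class's own helpers, not standalone code):

```python
rec = self.get_tablet_record(bound)            # learned from an off-replica query
assert rec is not None
invalidate(rec)                                # drop ks / reconnect / decommission
assert self.get_tablet_record(bound) is None   # mapping purged from the cache
```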
