Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: unsubscribe pub/sub connections after cluster migration #4529

Open
wants to merge 1 commit into
base: kpr1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/facade/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <absl/container/flat_hash_set.h>

#include <memory>
#include <string_view>

#include "core/heap_size.h"
#include "facade/acl_commands_def.h"
Expand Down Expand Up @@ -34,6 +35,10 @@ class ConnectionContext {

virtual size_t UsedMemory() const;

// Removes the subscription to `channel` from this context.
// The facade-level default intentionally does nothing; derived contexts that track
// subscriptions override it.
virtual void Unsubscribe(std::string_view channel) {}

// connection state / properties.
bool conn_closing : 1;
bool req_auth : 1;
Expand Down
14 changes: 13 additions & 1 deletion src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ using nonstd::make_unexpected;

namespace facade {


namespace {

void SendProtocolError(RedisParser::Result pres, SinkReplyBuilder* builder) {
Expand Down Expand Up @@ -468,6 +467,17 @@ void Connection::AsyncOperations::operator()(const AclUpdateMessage& msg) {

void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) {
RedisReplyBuilder* rbuilder = (RedisReplyBuilder*)builder;

if (pub_msg.should_unsubscribe) {
rbuilder->StartCollection(3, RedisReplyBuilder::CollectionType::PUSH);
rbuilder->SendBulkString("unsubscribe");
rbuilder->SendBulkString(pub_msg.channel);
rbuilder->SendLong(0);
auto* cntx = self->cntx();
cntx->Unsubscribe(pub_msg.channel);
return;
}

unsigned i = 0;
array<string_view, 4> arr;
if (pub_msg.pattern.empty()) {
Expand All @@ -476,8 +486,10 @@ void Connection::AsyncOperations::operator()(const PubMessage& pub_msg) {
arr[i++] = "pmessage";
arr[i++] = pub_msg.pattern;
}

arr[i++] = pub_msg.channel;
arr[i++] = pub_msg.message;

rbuilder->SendBulkStrArr(absl::Span<string_view>{arr.data(), i},
RedisReplyBuilder::CollectionType::PUSH);
}
Expand Down
1 change: 1 addition & 0 deletions src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class Connection : public util::Connection {
std::string pattern{}; // non-empty for pattern subscriber
std::shared_ptr<char[]> buf; // stores channel name and message
std::string_view channel, message; // channel and message parts from buf
// When true, the pub/sub handler does not deliver `message`; it sends an "unsubscribe"
// push for `channel` and removes that subscription from the connection context.
bool should_unsubscribe = false;
};

// Pipeline message, accumulated Redis command to be executed.
Expand Down
99 changes: 94 additions & 5 deletions src/server/channel_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ extern "C" {
#include <absl/container/fixed_array.h>

#include "base/logging.h"
#include "server/cluster/slot_set.h"
#include "server/cluster_support.h"
#include "server/engine_shard_set.h"
#include "server/server_state.h"

Expand All @@ -26,7 +28,7 @@ bool Matches(string_view pattern, string_view channel) {
}

// Build functor for sending messages to connection
auto BuildSender(string_view channel, facade::ArgRange messages) {
auto BuildSender(string_view channel, facade::ArgRange messages, bool unsubscribe = false) {
absl::FixedArray<string_view, 1> views(messages.Size());
size_t messages_size = accumulate(messages.begin(), messages.end(), 0,
[](int sum, string_view str) { return sum + str.size(); });
Expand All @@ -43,11 +45,11 @@ auto BuildSender(string_view channel, facade::ArgRange messages) {
}
}

return [channel, buf = std::move(buf), views = std::move(views)](facade::Connection* conn,
string pattern) {
return [channel, buf = std::move(buf), views = std::move(views), unsubscribe](
facade::Connection* conn, string pattern) {
string_view channel_view{buf.get(), channel.size()};
for (std::string_view message_view : views)
conn->SendPubMessageAsync({std::move(pattern), buf, channel_view, message_view});
conn->SendPubMessageAsync({std::move(pattern), buf, channel_view, message_view, unsubscribe});
};
}

Expand Down Expand Up @@ -153,7 +155,6 @@ unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange m
auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
ChannelStore::Subscriber::ByThreadId);
while (it != subscribers_ptr->end() && it->Thread() == idx) {
// if ptr->cntx() is null, a connection might have closed or be in the process of closing
if (auto* ptr = it->Get(); ptr && ptr->cntx() != nullptr)
send(ptr, it->pattern);
it++;
Expand Down Expand Up @@ -203,6 +204,45 @@ size_t ChannelStore::PatternCount() const {
return patterns_->size();
}

// Removes every channel whose key slot is contained in `deleted_slots` from the channel
// store and unsubscribes the connections subscribed to those channels.
// Does nothing when `deleted_slots` is empty.
void ChannelStore::UnsubscribeAfterClusterSlotMigration(const cluster::SlotSet& deleted_slots) {
  if (deleted_slots.Empty())
    return;

  // Removal-only, non-pattern updater without an owning connection context, tagged with
  // the pool index of the calling proactor thread.
  const uint32_t thread_id = util::ProactorBase::me()->GetPoolIndex();
  ChannelStoreUpdater updater(/*pattern=*/false, /*to_add=*/false, /*cntx=*/nullptr, thread_id);

  // Collect all channels that hash into one of the deleted slots.
  for (const auto& entry : *channels_) {
    const auto& name = entry.first;
    if (deleted_slots.Contains(KeySlot(name)))
      updater.Record(name);
  }

  updater.ApplyAndUnsubscribe();
}

void ChannelStore::UnsubscribeConnectionsFromDeletedSlots(std::vector<std::string_view> channels,
uint32_t idx) {
const bool should_unsubscribe = true;
for (auto channel : channels) {
facade::ArgSlice slice{std::string_view{}};
auto send = BuildSender(channel, slice, should_unsubscribe);

auto subscribers = FetchSubscribers(channel);
auto it = lower_bound(subscribers.begin(), subscribers.end(), idx,
ChannelStore::Subscriber::ByThreadId);
while (it != subscribers.end() && it->Thread() == idx) {
// if ptr->cntx() is null, a connection might have closed or be in the process of closing
if (auto* ptr = it->Get(); ptr && ptr->cntx() != nullptr) {
DCHECK(it->pattern.empty());
send(ptr, it->pattern);
}
++it;
}
}
}

ChannelStoreUpdater::ChannelStoreUpdater(bool pattern, bool to_add, ConnectionContext* cntx,
uint32_t thread_id)
: pattern_{pattern}, to_add_{to_add}, cntx_{cntx}, thread_id_{thread_id} {
Expand Down Expand Up @@ -302,4 +342,53 @@ void ChannelStoreUpdater::Apply() {
delete ptr;
}

// Removes every recorded channel from the channel store and unsubscribes its subscribers.
//
// Steps:
//  1. Under the control block mutex, copy the current channel map, erase the recorded
//     channels from the copy and publish a new ChannelStore built on top of it.
//  2. On every proactor thread, unsubscribe the affected connections and switch the
//     thread-local channel store to the replacement.
//  3. Delete the previous channel map, the previous store and the detached subscriber maps.
//
// Must only be used on a removal-only, non-pattern updater without a connection context
// (enforced by the DCHECKs below). Does nothing when no channel was recorded.
void ChannelStoreUpdater::ApplyAndUnsubscribe() {
  DCHECK(to_add_ == false);
  DCHECK(pattern_ == false);
  DCHECK(cntx_ == nullptr);

  if (ops_.empty()) {
    return;
  }

  // Wait for other updates to finish, lock the control block and update store pointer.
  auto& cb = ChannelStore::control_block;
  cb.update_mu.lock();
  auto* store = cb.most_recent.load(memory_order_relaxed);

  // Deep copy, we will remove channels
  auto* target = new ChannelStore::ChannelMap{*store->channels_};

  for (auto key : ops_) {
    auto it = target->find(key);
    // ops_ was recorded before update_mu was taken, so a recorded channel is not
    // guaranteed to still exist in the most recent store. Skip it instead of
    // dereferencing and erasing end().
    if (it == target->end()) {
      continue;
    }
    // Keep the subscriber map alive until all threads switched to the replacement.
    freelist_.push_back(it->second.Get());
    target->erase(it);
  }

  // Prepare replacement.
  auto* replacement = new ChannelStore{target, store->patterns_};

  // Update control block and unlock it.
  cb.most_recent.store(replacement, memory_order_relaxed);
  cb.update_mu.unlock();

  // Update thread local references. Readers fetch subscribers via FetchSubscribers,
  // which runs without preemption, and store references to them in self container Subscriber
  // structs. This means that any point on the other thread is safe to update the channel store.
  // Each thread first unsubscribes its own connections from the removed channels and
  // then installs the replacement; once all threads are done, nothing queued in the
  // freelist is in use anymore.
  shard_set->pool()->AwaitFiberOnAll([this](unsigned, util::ProactorBase*) {
    ServerState::tlocal()->UnsubscribeSlotsAndUpdateChannelStore(
        ops_, ChannelStore::control_block.most_recent.load(memory_order_relaxed));
  });

  // Delete previous map and channel store.
  delete store->channels_;
  delete store;

  for (auto ptr : freelist_)
    delete ptr;
}

} // namespace dfly
15 changes: 15 additions & 0 deletions src/server/channel_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ namespace dfly {

class ChannelStoreUpdater;

namespace cluster {
class SlotSet;
}

// ChannelStore manages PUB/SUB subscriptions.
//
// Updates are carried out via RCU (read-copy-update). Each thread stores a pointer to ChannelStore
Expand Down Expand Up @@ -61,8 +65,13 @@ class ChannelStore {
std::vector<Subscriber> FetchSubscribers(std::string_view channel) const;

std::vector<std::string> ListChannels(const std::string_view pattern) const;

size_t PatternCount() const;

// Removes all channels whose key slot is contained in deleted_slots and unsubscribes
// the connections subscribed to them. No-op if deleted_slots is empty.
void UnsubscribeAfterClusterSlotMigration(const cluster::SlotSet& deleted_slots);

// Queues an "unsubscribe" notification for each subscriber of the given channels
// that belongs to the thread with index idx.
void UnsubscribeConnectionsFromDeletedSlots(std::vector<std::string_view> channels, uint32_t idx);

// Destroy current instance and delete it.
static void Destroy();

Expand Down Expand Up @@ -128,6 +137,12 @@ class ChannelStoreUpdater {
void Record(std::string_view key);
void Apply();

// Used for cluster when slots migrate. We need to:
// 1. Remove the recorded channels from a copy of the channel map.
// 2. Update the control block pointer to a store built on that copy.
// 3. Unsubscribe all the connections from each removed channel on every thread.
void ApplyAndUnsubscribe();

private:
using ChannelMap = ChannelStore::ChannelMap;

Expand Down
4 changes: 4 additions & 0 deletions src/server/cluster/cluster_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "facade/dragonfly_connection.h"
#include "facade/error.h"
#include "server/acl/acl_commands_def.h"
#include "server/channel_store.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/dflycmd.h"
Expand Down Expand Up @@ -625,6 +626,9 @@ void ClusterFamily::DflyClusterConfig(CmdArgList args, SinkReplyBuilder* builder
auto deleted_slots = (before.GetRemovedSlots(after)).ToSlotRanges();
deleted_slots.Merge(outgoing_migrations.slot_ranges);
DeleteSlots(deleted_slots);
auto* channel_store = ServerState::tlocal()->channel_store();
auto deleted = SlotSet(deleted_slots);
channel_store->UnsubscribeAfterClusterSlotMigration(deleted);
LOG_IF(INFO, !deleted_slots.Empty())
<< "Flushing newly unowned slots: " << deleted_slots.ToString();
WriteFlushSlotsToJournal(deleted_slots);
Expand Down
10 changes: 10 additions & 0 deletions src/server/conn_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,16 @@ size_t ConnectionContext::UsedMemory() const {
return facade::ConnectionContext::UsedMemory() + dfly::HeapSize(conn_state);
}

// Removes `channel` from the set of channels this connection is subscribed to.
// When that leaves the connection without any subscription, the subscription state is
// released and the `subscriptions` counter is decremented.
void ConnectionContext::Unsubscribe(std::string_view channel) {
  auto* sinfo = conn_state.subscribe_info.get();

  // Nothing to remove if the connection holds no subscription state.
  // NOTE(review): presumably reachable when the client dropped all of its subscriptions
  // before this request was processed - confirm against the caller.
  if (sinfo == nullptr) {
    return;
  }

  sinfo->channels.erase(channel);
  if (sinfo->IsEmpty()) {
    conn_state.subscribe_info.reset();
    DCHECK_GE(subscriptions, 1u);
    --subscriptions;
  }
}

vector<unsigned> ConnectionContext::ChangeSubscriptions(CmdArgList channels, bool pattern,
bool to_add, bool to_reply) {
vector<unsigned> result(to_reply ? channels.size() : 0, 0);
Expand Down
2 changes: 2 additions & 0 deletions src/server/conn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,8 @@ class ConnectionContext : public facade::ConnectionContext {

size_t UsedMemory() const override;

// Removes `channel` from this connection's subscribed channels and releases the
// subscription state once it becomes empty.
void Unsubscribe(std::string_view channel) override;

// Whether this connection is a connection from a replica to its master.
// This flag is true only on replica side, where we need to setup a special ConnectionContext
// instance that helps applying commands coming from master.
Expand Down
11 changes: 9 additions & 2 deletions src/server/server_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ extern "C" {
#include "base/logging.h"
#include "facade/conn_context.h"
#include "facade/dragonfly_connection.h"
#include "server/channel_store.h"
#include "server/journal/journal.h"
#include "util/listener_interface.h"

Expand Down Expand Up @@ -261,8 +262,8 @@ void ServerState::ConnectionsWatcherFb(util::ListenerInterface* main) {
is_replica = dfly_conn->cntx()->replica_conn;
}

if ((phase == Phase::READ_SOCKET || dfly_conn->IsSending()) &&
!is_replica && dfly_conn->idle_time() > timeout) {
if ((phase == Phase::READ_SOCKET || dfly_conn->IsSending()) && !is_replica &&
dfly_conn->idle_time() > timeout) {
conn_refs.push_back(dfly_conn->Borrow());
}
};
Expand All @@ -285,4 +286,10 @@ void ServerState::ConnectionsWatcherFb(util::ListenerInterface* main) {
}
}

// Unsubscribes this thread's connections from `channels` through the channel store that
// is currently installed on this thread, then installs `replacement` in its place.
void ServerState::UnsubscribeSlotsAndUpdateChannelStore(std::vector<std::string_view> channels,
                                                        ChannelStore* replacement) {
  // The notifications must go through the store that is still installed on this thread.
  auto* const current_store = channel_store_;
  current_store->UnsubscribeConnectionsFromDeletedSlots(channels, thread_index_);

  channel_store_ = replacement;
}

} // end of namespace dfly
3 changes: 3 additions & 0 deletions src/server/server_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ class ServerState { // public struct - to allow initialization.
channel_store_ = replacement;
}

// Unsubscribes this thread's connections from `channels` using the current channel store,
// then replaces the thread-local channel store pointer with `replacement`.
void UnsubscribeSlotsAndUpdateChannelStore(std::vector<std::string_view> channels,
ChannelStore* replacement);

bool ShouldLogSlowCmd(unsigned latency_usec) const;

Stats stats;
Expand Down
48 changes: 48 additions & 0 deletions tests/dragonfly/cluster_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2901,3 +2901,51 @@ async def test_cluster_sharded_pub_sub(df_factory: DflyInstanceFactory):
await c_nodes[0].execute_command("SPUBLISH kostas new_message")
message = consumer.get_sharded_message(target_node=node_a)
assert message == {"type": "unsubscribe", "pattern": None, "channel": b"kostas", "data": 0}


@dfly_args({"proactor_threads": 2, "cluster_mode": "yes"})
async def test_cluster_sharded_pub_sub_migration(df_factory: DflyInstanceFactory):
    """Sharded subscribers are unsubscribed once their channel's slot migrates away.

    Node 0 starts out owning every slot and a consumer SSUBSCRIBEs to "kostas" on it.
    After all slots are migrated to node 1, node 0 must answer SSUBSCRIBE with MOVED
    and the consumer must have been sent an "unsubscribe" push for that channel.
    """
    instances = [df_factory.create(port=next(next_port)) for _ in range(2)]
    df_factory.start_all(instances)

    c_nodes = [instance.client() for instance in instances]

    nodes = [(await create_node_info(instance)) for instance in instances]
    nodes[0].slots = [(0, 16383)]
    nodes[1].slots = []

    async def push_current_config():
        # Send the configuration described by `nodes` to every node.
        config = json.dumps(generate_config(nodes))
        await push_config(config, [node.client for node in nodes])

    await push_current_config()

    # Set up the consumer with a sharded subscription on the node owning the slot.
    node_a = ClusterNode("localhost", instances[0].port)
    node_b = ClusterNode("localhost", instances[1].port)

    consumer_client = RedisCluster(startup_nodes=[node_a, node_b])
    consumer = consumer_client.pubsub()
    consumer.ssubscribe("kostas")

    # Start migrating every slot from the first node to the second one.
    migration = MigrationInfo("127.0.0.1", nodes[1].instance.port, [(0, 16383)], nodes[1].id)
    nodes[0].migrations.append(migration)
    await push_current_config()

    await wait_for_status(nodes[0].client, nodes[1].id, "FINISHED")

    # Finalize: the second node now owns all slots and the migration entry is dropped.
    nodes[0].migrations = []
    nodes[0].slots = []
    nodes[1].slots = [(0, 16383)]
    logging.debug("remove finished migrations")
    await push_current_config()

    # Channel "kostas" maps to slot 2833, which now belongs to the second node.
    with pytest.raises(redis.exceptions.ResponseError) as moved_error:
        await c_nodes[0].execute_command("SSUBSCRIBE kostas")

    expected_error = f"MOVED 2833 127.0.0.1:{instances[1].port}"
    assert str(moved_error.value) == expected_error

    # First pending push: confirmation of the consumer's initial SSUBSCRIBE.
    message = consumer.get_sharded_message(target_node=node_a)
    assert message == {"type": "subscribe", "pattern": None, "channel": b"kostas", "data": 1}
    # Second pending push: the unsubscribe triggered by the slot migration.
    message = consumer.get_sharded_message(target_node=node_a)
    assert message == {"type": "unsubscribe", "pattern": None, "channel": b"kostas", "data": 0}