From 33523238bf3a0bffc948dd81c3c3fe834927c2a7 Mon Sep 17 00:00:00 2001 From: Alexey Ozeritskiy Date: Sun, 26 Nov 2023 21:04:33 +0300 Subject: [PATCH] Use shared_ptr --- src/raft.cpp | 4 ++-- src/raft.h | 4 ++-- src/server.cpp | 20 ++++++++++---------- src/server.h | 9 ++++++--- 4 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/raft.cpp b/src/raft.cpp index 7190efc..b8e9b32 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -375,7 +375,7 @@ void TRaft::Become(EState newStateName) { } } -void TRaft::Process(TMessageHolder message, INode* replyTo) { +void TRaft::Process(TMessageHolder message, const std::shared_ptr& replyTo) { auto now = TimeSource->Now(); if (message.IsEx()) { @@ -404,7 +404,7 @@ void TRaft::Process(TMessageHolder message, INode* replyTo) { ApplyResult(now, std::move(result), replyTo); } -void TRaft::ApplyResult(ITimeSource::Time now, std::unique_ptr result, INode* replyTo) { +void TRaft::ApplyResult(ITimeSource::Time now, std::unique_ptr result, const std::shared_ptr& replyTo) { if (!result) { return; } diff --git a/src/raft.h b/src/raft.h index 9d5d555..3ebb19a 100644 --- a/src/raft.h +++ b/src/raft.h @@ -69,8 +69,8 @@ class TRaft { public: TRaft(int node, const TNodeDict& nodes, const std::shared_ptr& ts); - void Process(TMessageHolder message, INode* replyTo = nullptr); - void ApplyResult(ITimeSource::Time now, std::unique_ptr result, INode* replyTo = nullptr); + void Process(TMessageHolder message, const std::shared_ptr& replyTo = {}); + void ApplyResult(ITimeSource::Time now, std::unique_ptr result, const std::shared_ptr& replyTo = {}); // ut EState CurrentStateName() const { diff --git a/src/server.cpp b/src/server.cpp index e39de77..9130e35 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -85,6 +85,10 @@ void TNode::Send(const TMessageHolder& message) { } void TNode::Drain() { + if (!Connected) { + Connect(); + return; + } if (!Drainer || Drainer.done()) { if (Drainer && Drainer.done()) { Drainer.destroy(); @@ -94,10 +98,6 @@ void TNode::Drain() { } NNet::TTestTask TNode::DoDrain() { - if (!Connected) { - Connect(); - co_return; - } auto tosend = std::move(Messages); try { for (auto&& m : tosend) { @@ -105,7 +105,6 @@ NNet::TTestTask TNode::DoDrain() { } } catch (const std::exception& ex) { std::cout << "Error on write: " << ex.what() << "\n"; - Connect(); } Messages.clear(); co_return; @@ -143,13 +142,14 @@ NNet::TTestTask TNode::DoConnect() { NNet::TSimpleTask TRaftServer::InboundConnection(NNet::TSocket socket) { try { - TClientNode client; + auto client = std::make_shared(); + Nodes.insert(client); while (true) { auto mes = co_await TReader(socket).Read(); std::cout << "Got message " << mes->Type << "\n"; - Raft->Process(std::move(mes), &client); - if (!client.Messages.empty()) { - auto tosend = std::move(client.Messages); client.Messages.clear(); + Raft->Process(std::move(mes), client); + if (!client->Messages.empty()) { + auto tosend = std::move(client->Messages); client->Messages.clear(); for (auto&& mes : tosend) { co_await TWriter(socket).Write(std::move(mes)); } @@ -168,7 +168,7 @@ void TRaftServer::Serve() { } void TRaftServer::DrainNodes() { - for (auto [id, node] : Nodes) { + for (const auto& node : Nodes) { node->Drain(); } } diff --git a/src/server.h b/src/server.h index aad578d..1794b8e 100644 --- a/src/server.h +++ b/src/server.h @@ -195,9 +195,12 @@ class TRaftServer { : Poller(poller) , Socket(std::move(address), Poller) , Raft(raft) - , Nodes(nodes) , TimeSource(ts) - { } + { + for (const auto& [_, node] : nodes) { + Nodes.insert(node); + } + } void Serve(); @@ -210,6 +213,6 @@ class TRaftServer { NNet::TPoll& Poller; NNet::TPoll::TSocket Socket; std::shared_ptr Raft; - TNodeDict Nodes; + std::unordered_set> Nodes; std::shared_ptr TimeSource; };