diff --git a/src/raft.cpp b/src/raft.cpp index a1d790c..0224e89 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -8,9 +8,7 @@ TRaft::TRaft(int node, const TNodeDict& nodes, const std::shared_ptr &message) { - return Follower(now, message); - }) + , StateName(EState::FOLLOWER) , LastTime(TimeSource->now()) { } @@ -52,7 +50,7 @@ void TRaft::ApplyResult(uint64_t now, std::unique_ptr result, INode* re Nodes[m->Dst]->Send(m); } } - if (result->NextStateFunc) { - StateFunc = result->NextStateFunc; + if (result->NextStateName) { + StateName = static_cast(result->NextStateName); } } diff --git a/src/raft.h b/src/raft.h index 7824d05..b2daa01 100644 --- a/src/raft.h +++ b/src/raft.h @@ -66,12 +66,18 @@ struct TVolatileState { struct TResult; +enum class EState: int { + CANDIDATE = 1, + FOLLOWER = 2, + LEADER = 3, +}; + using TStateFunc = std::function(uint64_t now, const TMessageHolder& message)>; struct TResult { std::unique_ptr NextState; std::unique_ptr NextVolatileState; - TStateFunc NextStateFunc; + int NextStateName; bool UpdateLastTime; TMessageHolder Message; std::vector> Messages; @@ -84,6 +90,10 @@ class TRaft { void Process(const TMessageHolder& message, INode* replyTo = nullptr); void ApplyResult(uint64_t now, std::unique_ptr result, INode* replyTo = nullptr); + EState CurrentStateName() const { + return StateName; + } + private: std::unique_ptr Follower(uint64_t now, const TMessageHolder& message); @@ -96,6 +106,6 @@ class TRaft { std::unique_ptr State; std::unique_ptr VolatileState; - TStateFunc StateFunc; + EState StateName; uint64_t LastTime; }; diff --git a/test/test_raft.cpp b/test/test_raft.cpp index 4b63e06..36b7a6b 100644 --- a/test/test_raft.cpp +++ b/test/test_raft.cpp @@ -132,6 +132,7 @@ void test_message_send_recv(void** state) { void test_initial(void**) { auto raft = MakeRaft(); + assert_true(raft->CurrentStateName() == EState::FOLLOWER); } int main() {