diff --git a/client/client.cpp b/client/client.cpp index f5ef84c..a2b8589 100644 --- a/client/client.cpp +++ b/client/client.cpp @@ -42,6 +42,7 @@ TVoidTask Client(TPoller& poller, TSocket socket) { TLine line; TCommandRequest header; + header.Flags = TCommandRequest::EWrite; header.Type = static_cast(TCommandRequest::MessageType); auto lineReader = TLineReader(input, 2*1024); auto byteWriter = TByteWriter(socket); diff --git a/examples/kv.cpp b/examples/kv.cpp index 2c76b4b..0f6569c 100644 --- a/examples/kv.cpp +++ b/examples/kv.cpp @@ -18,7 +18,7 @@ struct TResultValue: public TCommandResponse { char Data[0]; }; -TMessageHolder TKv::Read(TMessageHolder message) { +TMessageHolder TKv::Read(TMessageHolder message, uint64_t index) { auto readKv = message.Cast(); std::string_view k(readKv->Data, readKv->KeySize); auto it = H.find(std::string(k)); diff --git a/examples/kv.h b/examples/kv.h index 1c9bad0..2db952e 100644 --- a/examples/kv.h +++ b/examples/kv.h @@ -7,7 +7,7 @@ class TKv: public IRsm { public: - TMessageHolder Read(TMessageHolder message) override; + TMessageHolder Read(TMessageHolder message, uint64_t index) override; void Write(TMessageHolder message) override; TMessageHolder Prepare(TMessageHolder message, uint64_t term) override; diff --git a/src/messages.h b/src/messages.h index 516727e..a039cc3 100644 --- a/src/messages.h +++ b/src/messages.h @@ -82,10 +82,15 @@ static_assert(sizeof(TAppendEntriesResponse) == sizeof(TMessageEx) + 16); struct TCommandRequest: public TMessage { static constexpr EMessageType MessageType = EMessageType::COMMAND_REQUEST; + enum EFlags { + ENone = 0, + EWrite = 1, + }; + uint32_t Flags = ENone; char Data[0]; }; -static_assert(sizeof(TCommandRequest) == sizeof(TMessage)); +static_assert(sizeof(TCommandRequest) == sizeof(TMessage) + 4); struct TCommandResponse: public TMessage { static constexpr EMessageType MessageType = EMessageType::COMMAND_RESPONSE; diff --git a/src/raft.cpp b/src/raft.cpp index cd6cabe..79ef3c5 100644 --- a/src/raft.cpp +++ b/src/raft.cpp @@ -24,9 +24,9 @@ static uint32_t rand_(uint32_t* seed) { } // namespace -TMessageHolder TDummyRsm::Read(TMessageHolder message) +TMessageHolder TDummyRsm::Read(TMessageHolder message, uint64_t index) { - return {}; + return NewHoldedMessage(TCommandResponse {.Index = index}); } void TDummyRsm::Write(TMessageHolder message) @@ -253,12 +253,13 @@ void TRaft::OnAppendEntries(TMessageHolder message) { void TRaft::OnCommandRequest(TMessageHolder command, const std::shared_ptr& replyTo) { auto& log = State->Log; - auto entry = Rsm->Prepare(std::move(command), State->CurrentTerm); - log.emplace_back(std::move(entry)); + if (command->Flags & TCommandRequest::EWrite) { + auto entry = Rsm->Prepare(std::move(command), State->CurrentTerm); + log.emplace_back(std::move(entry)); + } auto index = log.size(); if (replyTo) { - auto mes = NewHoldedMessage(TCommandResponse {.Index = index}); - waiting.emplace(TWaiting{mes->Index, mes, replyTo}); + waiting.emplace(TWaiting{index, std::move(command), replyTo}); } } @@ -378,7 +379,13 @@ void TRaft::ProcessWaiting() { auto lastApplied = VolatileState->LastApplied; while (!waiting.empty() && waiting.top().Index <= lastApplied) { auto w = waiting.top(); waiting.pop(); - w.ReplyTo->Send(std::move(w.Message)); + TMessageHolder reply; + if (w.Command->Flags & TCommandRequest::EWrite) { + reply = NewHoldedMessage(TCommandResponse {.Index = w.Index}); + } else { + reply = Rsm->Read(std::move(w.Command), w.Index); + } + w.ReplyTo->Send(std::move(reply)); } } diff --git a/src/raft.h b/src/raft.h index 682ffbf..4df1f17 100644 --- a/src/raft.h +++ b/src/raft.h @@ -20,13 +20,13 @@ struct INode { // CommandRequest -> Read? -> CurrentIndex (fixate) >= CommittedIndex -> CommandResponse struct IRsm { virtual ~IRsm() = default; - virtual TMessageHolder Read(TMessageHolder message) = 0; + virtual TMessageHolder Read(TMessageHolder message, uint64_t index) = 0; virtual void Write(TMessageHolder message) = 0; virtual TMessageHolder Prepare(TMessageHolder message, uint64_t term) = 0; }; struct TDummyRsm: public IRsm { - TMessageHolder Read(TMessageHolder message) override; + TMessageHolder Read(TMessageHolder message, uint64_t index) override; void Write(TMessageHolder message) override; TMessageHolder Prepare(TMessageHolder message, uint64_t term) override; }; @@ -158,7 +158,7 @@ class TRaft { struct TWaiting { uint64_t Index; - TMessageHolder Message; + TMessageHolder Command; std::shared_ptr ReplyTo; bool operator< (const TWaiting& other) const { return Index > other.Index;