Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented simple kv storage #12

Merged
merged 1 commit into from
Mar 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 138 additions & 11 deletions examples/kv.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
#include <string_view>

#include <server.h>

#include "kv.h"

struct TWriteKv: public TLogEntry {
struct TWriteKvEntry: public TLogEntry {
uint16_t KeySize;
uint16_t ValSize;
char Data[0];
};

struct TReadKv: public TCommandRequest {
struct TWriteKv: public TCommandRequest {
uint16_t KeySize;
uint16_t ValSize;
char Data[0];
};

struct TResultValue: public TCommandResponse {
uint16_t ValSize;
struct TReadKv: public TCommandRequest {
uint16_t KeySize;
char Data[0];
};

Expand All @@ -23,20 +26,20 @@ TMessageHolder<TMessage> TKv::Read(TMessageHolder<TCommandRequest> message, uint
std::string_view k(readKv->Data, readKv->KeySize);
auto it = H.find(std::string(k));
if (it == H.end()) {
auto res = NewHoldedMessage<TResultValue>(sizeof(TResultValue));
res->ValSize = -1;
auto res = NewHoldedMessage<TCommandResponse>(sizeof(TCommandResponse));
res->Index = index;
return res;
} else {
auto res = NewHoldedMessage<TResultValue>(sizeof(TResultValue)+it->second.size());
res->ValSize = it->second.size();
memcpy(res->Data, it->second.data(), res->ValSize);
auto res = NewHoldedMessage<TCommandResponse>(sizeof(TCommandResponse)+it->second.size());
res->Index = index;
memcpy(res->Data, it->second.data(), it->second.size());
return res;
}
}

void TKv::Write(TMessageHolder<TLogEntry> message, uint64_t index) {
if (index < LastAppliedIndex) {
auto writeKv = message.Cast<TWriteKv>();
if (LastAppliedIndex < index) {
auto writeKv = message.Cast<TWriteKvEntry>();
std::string_view k(writeKv->Data, writeKv->KeySize);
std::string_view v(writeKv->Data + writeKv->KeySize, writeKv->ValSize);
H[std::string(k)] = std::string(v);
Expand All @@ -52,6 +55,130 @@ TMessageHolder<TLogEntry> TKv::Prepare(TMessageHolder<TCommandRequest> command,
return entry;
}

template<typename TPoller, typename TSocket>
NNet::TVoidTask Client(TPoller& poller, TSocket socket) {
using TFileHandle = typename TPoller::TFileHandle;
TFileHandle input{0, poller}; // stdin
co_await socket.Connect();
std::cout << "Connected\n";

NNet::TLine line;
TCommandRequest header;
header.Flags = TCommandRequest::EWrite;
header.Type = static_cast<uint32_t>(TCommandRequest::MessageType);
auto lineReader = NNet::TLineReader(input, 2*1024);
auto byteWriter = NNet::TByteWriter(socket);
const char* sep = " \t\r\n";

try {
while ((line = co_await lineReader.Read())) {
std::string strLine;
strLine += std::string_view(line.Part1.data(), line.Part1.size());
strLine += std::string_view(line.Part2.data(), line.Part2.size());
auto prefix = strtok((char*)strLine.data(), sep);
TMessageHolder<TMessage> req;

if (!strcmp(prefix, "set")) {
auto key = strtok(nullptr, sep);
auto val = strtok(nullptr, sep);
auto keySize = strlen(key);
auto valSize = strlen(val);
auto mes = NewHoldedMessage<TWriteKv>(sizeof(TWriteKv) + keySize + valSize);
mes->Flags = TCommandRequest::EWrite;
mes->KeySize = keySize;
mes->ValSize = valSize;
memcpy(mes->Data, key, keySize);
memcpy(mes->Data+keySize, val, valSize);
req = mes;
} else if (!strcmp(prefix, "get")) {
auto key = strtok(nullptr, sep);
auto size = strlen(key);
auto mes = NewHoldedMessage<TReadKv>(sizeof(TReadKv) + size);
mes->KeySize = size;
memcpy(mes->Data, key, size);
req = mes;
} else {
std::cout << "Cannot parse command: " << strLine << "\n";
}

co_await TMessageWriter(socket).Write(std::move(req));
auto reply = co_await TMessageReader(socket).Read();
auto res = reply.template Cast<TCommandResponse>();
auto len = res->Len - sizeof(TCommandResponse);
std::string_view data(res->Data, len);
std::cout << "Ok, commitIndex: " << res->Index << " "
<< data << "\n";
}
} catch (const std::exception& ex) {
std::cout << "Exception: " << ex.what() << "\n";
}
co_return;
}

void usage(const char* prog) {
std::cerr << prog << "--client|--server --id myid --node ip:port:id [--node ip:port:id ...]" << "\n";
exit(0);
}

int main(int argc, char** argv) {
signal(SIGPIPE, SIG_IGN);
std::vector<THost> hosts;
THost myHost;
TNodeDict nodes;
uint32_t id = 0;
bool server = false;
for (int i = 1; i < argc; i++) {
if (!strcmp(argv[i], "--server")) {
server = true;
} else if (!strcmp(argv[i], "--node") && i < argc - 1) {
// address:port:id
hosts.push_back(THost{argv[++i]});
} else if (!strcmp(argv[i], "--id") && i < argc - 1) {
id = atoi(argv[++i]);
} else if (!strcmp(argv[i], "--help")) {
usage(argv[0]);
}
}

using TPoller = NNet::TDefaultPoller;
std::shared_ptr<ITimeSource> timeSource = std::make_shared<TTimeSource>();
NNet::TLoop<TPoller> loop;

if (server) {
for (auto& host : hosts) {
if (!host) {
std::cerr << "Empty host\n"; return 1;
}
if (host.Id == id) {
myHost = host;
} else {
nodes[host.Id] = std::make_shared<TNode<TPoller::TSocket>>(
[&](const NNet::TAddress& addr) { return TPoller::TSocket(addr, loop.Poller()); },
std::to_string(host.Id),
NNet::TAddress{host.Address, host.Port},
timeSource);
}
}

if (!myHost) {
std::cerr << "Host not found\n"; return 1;
}

std::shared_ptr<IRsm> rsm = std::make_shared<TKv>();
auto raft = std::make_shared<TRaft>(rsm, myHost.Id, nodes);
TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller());
socket.Bind();
socket.Listen();
TRaftServer server(loop.Poller(), std::move(socket), raft, nodes, timeSource);
server.Serve();
loop.Loop();
} else {
NNet::TAddress addr{hosts[0].Address, hosts[0].Port};
NNet::TSocket socket(std::move(addr), loop.Poller());

Client(loop.Poller(), std::move(socket));
loop.Loop();
}

return 0;
}
Loading