Skip to content

Commit

Permalink
Preparations for SSL support
Browse files Browse the repository at this point in the history
  • Loading branch information
resetius committed Dec 10, 2023
1 parent 59cd1e7 commit dc4cc72
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 55 deletions.
11 changes: 7 additions & 4 deletions server/server.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "coroio/all.hpp"
#include <coroio/all.hpp>
#include <csignal>
#include <timesource.h>
#include <raft.h>
Expand Down Expand Up @@ -38,8 +38,8 @@ int main(int argc, char** argv) {
if (host.Id == id) {
myHost = host;
} else {
nodes[host.Id] = std::make_shared<TNode<TPoller>>(
loop.Poller(),
nodes[host.Id] = std::make_shared<TNode<TPoller::TSocket>>(
[&](const NNet::TAddress& addr) { return TPoller::TSocket(addr, loop.Poller()); },
std::to_string(host.Id),
NNet::TAddress{host.Address, host.Port},
timeSource);
Expand All @@ -51,7 +51,10 @@ int main(int argc, char** argv) {
}

auto raft = std::make_shared<TRaft>(myHost.Id, nodes);
TRaftServer server(loop.Poller(), NNet::TAddress{myHost.Address, myHost.Port}, raft, nodes, timeSource);
TPoller::TSocket socket(NNet::TAddress{myHost.Address, myHost.Port}, loop.Poller());
socket.Bind();
socket.Listen();
TRaftServer server(loop.Poller(), std::move(socket), raft, nodes, timeSource);
server.Serve();
loop.Loop();
return 0;
Expand Down
63 changes: 29 additions & 34 deletions src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ NNet::TValueTask<TMessageHolder<TMessage>> TMessageReader<TSocket>::Read() {
co_return mes;
}

template<typename TPoller>
void TNode<TPoller>::Send(TMessageHolder<TMessage> message) {
template<typename TSocket>
void TNode<TSocket>::Send(TMessageHolder<TMessage> message) {
Messages.emplace_back(std::move(message));
}

template<typename TPoller>
void TNode<TPoller>::Drain() {
template<typename TSocket>
void TNode<TSocket>::Drain() {
if (!Connected) {
Connect();
return;
Expand All @@ -65,8 +65,8 @@ void TNode<TPoller>::Drain() {
}
}

template<typename TPoller>
NNet::TVoidSuspendedTask TNode<TPoller>::DoDrain() {
template<typename TSocket>
NNet::TVoidSuspendedTask TNode<TSocket>::DoDrain() {
try {
while (!Messages.empty()) {
auto tosend = std::move(Messages); Messages.clear();
Expand All @@ -81,21 +81,21 @@ NNet::TVoidSuspendedTask TNode<TPoller>::DoDrain() {
co_return;
}

template<typename TPoller>
void TNode<TPoller>::Connect() {
template<typename TSocket>
void TNode<TSocket>::Connect() {
if (Address && (!Connector || Connector.done())) {
if (Connector && Connector.done()) {
Connector.destroy();
}

Socket = typename TPoller::TSocket(*Address, Poller);
Socket = SocketFactory(*Address);
Connected = false;
Connector = DoConnect();
}
}

template<typename TPoller>
NNet::TVoidSuspendedTask TNode<TPoller>::DoConnect() {
template<typename TSocket>
NNet::TVoidSuspendedTask TNode<TSocket>::DoConnect() {
std::cout << "Connecting " << Name << "\n";
while (!Connected) {
try {
Expand All @@ -106,18 +106,15 @@ NNet::TVoidSuspendedTask TNode<TPoller>::DoConnect() {
} catch (const std::exception& ex) {
std::cout << "Error on connect: " << Name << " " << ex.what() << "\n";
}
if (!Connected) {
co_await Poller.Sleep(std::chrono::milliseconds(1000));
}
}
co_return;
}

template<typename TPoller>
NNet::TVoidTask TRaftServer<TPoller>::InboundConnection(typename TPoller::TSocket socket) {
template<typename TSocket>
NNet::TVoidTask TRaftServer<TSocket>::InboundConnection(TSocket socket) {
try {
auto client = std::make_shared<TNode<TPoller>>(
Poller, "client", std::move(socket), TimeSource
auto client = std::make_shared<TNode<TSocket>>(
"client", std::move(socket), TimeSource
);
Nodes.insert(client);
while (true) {
Expand All @@ -133,25 +130,21 @@ NNet::TVoidTask TRaftServer<TPoller>::InboundConnection(typename TPoller::TSocke
co_return;
}

template<typename TPoller>
void TRaftServer<TPoller>::Serve() {
template<typename TSocket>
void TRaftServer<TSocket>::Serve() {
Idle();
InboundServe();
}

template<typename TPoller>
void TRaftServer<TPoller>::DrainNodes() {
template<typename TSocket>
void TRaftServer<TSocket>::DrainNodes() {
for (const auto& node : Nodes) {
node->Drain();
}
}

template<typename TPoller>
NNet::TVoidTask TRaftServer<TPoller>::InboundServe() {
std::cout << "Bind\n";
Socket.Bind();
std::cout << "Listen\n";
Socket.Listen();
template<typename TSocket>
NNet::TVoidTask TRaftServer<TSocket>::InboundServe() {
while (true) {
auto client = co_await Socket.Accept();
std::cout << "Accepted\n";
Expand All @@ -160,8 +153,8 @@ NNet::TVoidTask TRaftServer<TPoller>::InboundServe() {
co_return;
}

template<typename TPoller>
void TRaftServer<TPoller>::DebugPrint() {
template<typename TSocket>
void TRaftServer<TSocket>::DebugPrint() {
auto* state = Raft->GetState();
auto* volatileState = Raft->GetVolatileState();
if (Raft->CurrentStateName() == EState::LEADER) {
Expand Down Expand Up @@ -197,8 +190,8 @@ void TRaftServer<TPoller>::DebugPrint() {
}
}

template<typename TPoller>
NNet::TVoidTask TRaftServer<TPoller>::Idle() {
template<typename TSocket>
NNet::TVoidTask TRaftServer<TSocket>::Idle() {
auto t0 = TimeSource->Now();
auto dt = std::chrono::milliseconds(2000);
auto sleep = std::chrono::milliseconds(100);
Expand All @@ -215,5 +208,7 @@ NNet::TVoidTask TRaftServer<TPoller>::Idle() {
co_return;
}

template class TRaftServer<NNet::TDefaultPoller>;
template class TRaftServer<NNet::TPoll>;
template class TRaftServer<NNet::TSocket>;
#ifdef __linux__
template class TRaftServer<NNet::TUringSocket>;
#endif
33 changes: 16 additions & 17 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,26 @@ struct THost {
}
};

template<typename TPoller>
template<typename TSocket>
class TNode: public INode {
public:
TNode(TPoller& poller, const std::string& name, NNet::TAddress address, const std::shared_ptr<ITimeSource>& ts)
: Poller(poller)
, Name(name)
TNode(const std::function<TSocket(const NNet::TAddress&)> factory, const std::string& name, NNet::TAddress address, const std::shared_ptr<ITimeSource>& ts)
: Name(name)
, Address(address)
, TimeSource(ts)
, SocketFactory(factory)
{ }

TNode(TPoller& poller, const std::string& name, typename TPoller::TSocket socket, const std::shared_ptr<ITimeSource>& ts)
: Poller(poller)
, Name(name)
TNode(const std::string& name, TSocket socket, const std::shared_ptr<ITimeSource>& ts)
: Name(name)
, Socket(std::move(socket))
, Connected(true)
, TimeSource(ts)
{ }

void Send(TMessageHolder<TMessage> message) override;
void Drain() override;
typename TPoller::TSocket& Sock() {
TSocket& Sock() {
return Socket;
}

Expand All @@ -98,11 +97,11 @@ class TNode: public INode {
NNet::TVoidSuspendedTask DoDrain();
NNet::TVoidSuspendedTask DoConnect();

TPoller& Poller;
std::string Name;
std::optional<NNet::TAddress> Address;
std::shared_ptr<ITimeSource> TimeSource;
typename TPoller::TSocket Socket;
std::function<TSocket(const NNet::TAddress&)> SocketFactory;
TSocket Socket;
bool Connected = false;

std::coroutine_handle<> Drainer;
Expand All @@ -111,17 +110,17 @@ class TNode: public INode {
std::vector<TMessageHolder<TMessage>> Messages;
};

template<typename TPoller>
template<typename TSocket>
class TRaftServer {
public:
TRaftServer(
TPoller& poller,
NNet::TAddress address,
typename TSocket::TPoller& poller,
TSocket socket,
const std::shared_ptr<TRaft>& raft,
const TNodeDict& nodes,
const std::shared_ptr<ITimeSource>& ts)
: Poller(poller)
, Socket(std::move(address), Poller)
, Socket(std::move(socket))
, Raft(raft)
, TimeSource(ts)
{
Expand All @@ -134,13 +133,13 @@ class TRaftServer {

private:
NNet::TVoidTask InboundServe();
NNet::TVoidTask InboundConnection(typename TPoller::TSocket socket);
NNet::TVoidTask InboundConnection(TSocket socket);
NNet::TVoidTask Idle();
void DrainNodes();
void DebugPrint();

TPoller& Poller;
typename TPoller::TSocket Socket;
typename TSocket::TPoller& Poller;
TSocket Socket;
std::shared_ptr<TRaft> Raft;
std::unordered_set<std::shared_ptr<INode>> Nodes;
std::shared_ptr<ITimeSource> TimeSource;
Expand Down

0 comments on commit dc4cc72

Please sign in to comment.