Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add heartbeat to Mission Control Protocol #261

Merged
merged 10 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/network/MissionControlProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using websocket::msghandler_t;
using websocket::validator_t;

const std::chrono::milliseconds TELEM_REPORT_PERIOD = 100ms;
const std::chrono::milliseconds HEARTBEAT_TIMEOUT_PERIOD = 3000ms;

// TODO: possibly use frozen::string for this so we don't have to use raw char ptrs
// request keys
Expand Down Expand Up @@ -285,6 +286,14 @@ void MissionControlProtocol::handleConnection() {
}
}

void MissionControlProtocol::handleHeartbeatTimedOut() {
this->stopAndShutdownPowerRepeat();
robot::emergencyStop();
log(LOG_ERROR, "Heartbeat timed out! Emergency stopping.\n");
Globals::E_STOP = true;
Globals::armIKEnabled = false;
}

void MissionControlProtocol::startPowerRepeat() {
// note: take care to lock mutexes in a consistent order
std::lock_guard<std::mutex> flagLock(_joint_repeat_running_mutex);
Expand Down Expand Up @@ -377,6 +386,10 @@ MissionControlProtocol::MissionControlProtocol(SingleClientWSServer& server)
this->addDisconnectionHandler(
std::bind(&MissionControlProtocol::stopAndShutdownPowerRepeat, this));

this->setHeartbeatTimedOutHandler(
HEARTBEAT_TIMEOUT_PERIOD,
std::bind(&MissionControlProtocol::handleHeartbeatTimedOut, this));

this->_streaming_running = true;
this->_streaming_thread = std::thread(&MissionControlProtocol::videoStreamTask, this);

Expand Down
1 change: 1 addition & 0 deletions src/network/MissionControlProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class MissionControlProtocol : public WebSocketProtocol { // TODO: add documenta
void sendJointPositionReport(const std::string& jointName, int32_t position);
void sendRoverPos();
void handleConnection();
void handleHeartbeatTimedOut();
void startPowerRepeat();
void stopAndShutdownPowerRepeat();
void setRequestedJointPower(jointid_t joint, double power);
Expand Down
11 changes: 11 additions & 0 deletions src/network/websocket/WebSocketProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ void WebSocketProtocol::addDisconnectionHandler(const connhandler_t& handler) {
disconnectionHandlers.push_back(handler);
}

void WebSocketProtocol::setHeartbeatTimedOutHandler(std::chrono::milliseconds timeout,
const heartbeattimeouthandler_t& handler) {
heartbeatInfo = {timeout, handler};
}

void WebSocketProtocol::clientConnected() {
for (const auto& f : connectionHandlers) {
f();
Expand All @@ -58,6 +63,12 @@ void WebSocketProtocol::clientDisconnected() {
}
}

void WebSocketProtocol::heartbeatTimedOut() {
if (heartbeatInfo.has_value()) {
heartbeatInfo->second();
}
}

void WebSocketProtocol::processMessage(const json& obj) const {
if (obj.contains(TYPE_KEY)) {
std::string messageType = obj[TYPE_KEY];
Expand Down
51 changes: 36 additions & 15 deletions src/network/websocket/WebSocketProtocol.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
#pragma once

#include <chrono>
#include <functional>
#include <map>
#include <optional>
#include <string>

#include <nlohmann/json.hpp>

namespace net{
namespace net {
namespace websocket {

using nlohmann::json;

typedef std::function<void(const json&)> msghandler_t;
typedef std::function<bool(const json&)> validator_t;
typedef std::function<void()> connhandler_t;
typedef std::function<void()> heartbeattimeouthandler_t;

/**
* @brief Defines a protocol which will be served at an endpoint of a server.
Expand Down Expand Up @@ -85,9 +88,38 @@ class WebSocketProtocol {

void addDisconnectionHandler(const connhandler_t& handler);

/**
* @brief Set the handler that's called when the heartbeat times out.
*
* If the heartbeat is reestablished after timing out, and then times out again, this
* handler will be called again.
*
* @param timeout The heartbeat timeout.
abhaybd marked this conversation as resolved.
Show resolved Hide resolved
* @param handler The handler to call when timed out.
*/
void setHeartbeatTimedOutHandler(std::chrono::milliseconds timeout,
const heartbeattimeouthandler_t& handler);

/**
* @brief Get the protocol path of the endpoint this protocol is served on.
*
* @return The protocol path, of the form "/foo/bar".
*/
std::string getProtocolPath() const;

private:
friend class SingleClientWSServer;

std::string protocolPath;
std::map<std::string, msghandler_t> handlerMap;
std::map<std::string, validator_t> validatorMap;
std::vector<connhandler_t> connectionHandlers;
std::vector<connhandler_t> disconnectionHandlers;
std::optional<std::pair<std::chrono::milliseconds, heartbeattimeouthandler_t>>
heartbeatInfo;
abhaybd marked this conversation as resolved.
Show resolved Hide resolved

/**
* @brief Process the given JSON object that was sent to this protocol's endpoint.
* Generally, this shouldn't be used by client code.
*
* @param obj The JSON object to be processed by this protocol. It is expected to have a
* "type" key.
Expand All @@ -96,29 +128,18 @@ class WebSocketProtocol {

/**
* @brief Invoke all connection handlers for this protocol.
* Generally, this shouldn't be used by client code.
*/
void clientConnected();

/**
* @brief Invoke all disconnection handlers for this protocol.
* Generally, this shouldn't be used by client code.
*/
void clientDisconnected();

/**
* @brief Get the protocol path of the endpoint this protocol is served on.
*
* @return The protocol path, of the form "/foo/bar".
* @brief Invoke the heartbeat timeout handlers for this protocol.
*/
std::string getProtocolPath() const;

private:
std::string protocolPath;
std::map<std::string, msghandler_t> handlerMap;
std::map<std::string, validator_t> validatorMap;
std::vector<connhandler_t> connectionHandlers;
std::vector<connhandler_t> disconnectionHandlers;
void heartbeatTimedOut();
};

} // namespace websocket
Expand Down
48 changes: 46 additions & 2 deletions src/network/websocket/WebSocketServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "../../Constants.h"
#include "../../log.h"
#include "../../utils/core.h"

#include <string>
namespace net {
Expand All @@ -26,6 +27,8 @@ SingleClientWSServer::SingleClientWSServer(const std::string& serverName, uint16
server.set_validate_handler([&](connection_hdl hdl) { return this->validate(hdl); });
server.set_message_handler(
[&](connection_hdl hdl, message_t msg) { this->onMessage(hdl, msg); });
server.set_pong_handler(
[&](connection_hdl hdl, std::string payload) { this->onPong(hdl, payload); });
}

SingleClientWSServer::~SingleClientWSServer() {
Expand Down Expand Up @@ -80,6 +83,27 @@ bool SingleClientWSServer::addProtocol(std::unique_ptr<WebSocketProtocol> protoc
std::string path = protocol->getProtocolPath();
if (protocolMap.find(path) == protocolMap.end()) {
protocolMap.emplace(path, std::move(protocol));
auto& protocolData = protocolMap.at(path);
const auto& heartbeatInfo = protocolData.protocol->heartbeatInfo;
if (heartbeatInfo.has_value()) {
auto eventID =
pingScheduler.scheduleEvent(heartbeatInfo->first / 2, [this, path]() {
auto& pd = this->protocolMap.at(path);
abhaybd marked this conversation as resolved.
Show resolved Hide resolved
std::lock_guard lock(pd.mutex);
if (pd.client.has_value()) {
log(LOG_DEBUG, "Ping!\n");
server.ping(pd.client.value(), path);
}
});
std::lock_guard l(protocolData.mutex);
// util::Watchdog is non-copyable and non-movable, so we must create in-place
// Since we want to create a member field of the pair in-place, it gets complicated
// so we have to use piecewise_construct to allow us to separately initialize all
// pair fields in-place
protocolData.heartbeatInfo.emplace(std::piecewise_construct,
std::tuple<decltype(eventID)>{eventID},
util::pairToTuple(heartbeatInfo.value()));
}
return true;
} else {
return false;
Expand Down Expand Up @@ -147,8 +171,15 @@ void SingleClientWSServer::onClose(connection_hdl hdl) {
serverName.c_str(), path.c_str(), client.c_str());

auto& protocolData = protocolMap.at(path);
protocolData.client.reset();
protocolData.protocol->clientDisconnected();
{
std::lock_guard lock(protocolData.mutex);
protocolData.client.reset();
if (protocolData.heartbeatInfo.has_value()) {
pingScheduler.removeEvent(protocolData.heartbeatInfo->first);
protocolData.heartbeatInfo.reset();
}
protocolData.protocol->clientDisconnected();
}
}

void SingleClientWSServer::onMessage(connection_hdl hdl, message_t message) {
Expand All @@ -162,5 +193,18 @@ void SingleClientWSServer::onMessage(connection_hdl hdl, message_t message) {
json obj = json::parse(jsonStr);
protocolMap.at(path).protocol->processMessage(obj);
}

void SingleClientWSServer::onPong(connection_hdl hdl, const std::string& payload) {
log(LOG_DEBUG, "Pong from %s\n", payload.c_str());
auto conn = server.get_con_from_hdl(hdl);

assert(protocolMap.find(payload) != protocolMap.end());
abhaybd marked this conversation as resolved.
Show resolved Hide resolved

auto& pd = protocolMap.at(payload);
std::lock_guard lock(pd.mutex);
if (pd.heartbeatInfo.has_value()) {
pd.heartbeatInfo->second.feed();
}
}
} // namespace websocket
} // namespace net
8 changes: 7 additions & 1 deletion src/network/websocket/WebSocketServer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "../../utils/scheduler.h"
#include "WebSocketProtocol.h"

#include <functional>
Expand Down Expand Up @@ -96,8 +97,11 @@ class SingleClientWSServer {
class ProtocolData {
public:
ProtocolData(std::unique_ptr<WebSocketProtocol> protocol);
std::unique_ptr<WebSocketProtocol> protocol;
const std::unique_ptr<WebSocketProtocol> protocol;
std::optional<connection_hdl> client;
std::optional<std::pair<util::PeriodicScheduler<>::eventid_t, util::Watchdog<>>>
heartbeatInfo;
std::mutex mutex;
};

std::string serverName;
Expand All @@ -106,11 +110,13 @@ class SingleClientWSServer {
bool isRunning;
std::map<std::string, ProtocolData> protocolMap;
std::thread serverThread;
util::PeriodicScheduler<> pingScheduler;

bool validate(connection_hdl hdl);
void onOpen(connection_hdl hdl);
void onClose(connection_hdl hdl);
void onMessage(connection_hdl hdl, message_t message);
void onPong(connection_hdl hdl, const std::string& payload);
void serverTask();
};
} // namespace websocket
Expand Down
14 changes: 14 additions & 0 deletions src/utils/core.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#pragma once

#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <frozen/string.h>

Expand Down Expand Up @@ -53,4 +55,16 @@ template <typename T> std::string to_string(const T& val) {
*/
frozen::string freezeStr(const std::string& str);

/**
* @brief Converts a pair to a tuple. Elements are copied to the returned tuple.
*
* @tparam T The type of the first element.
* @tparam U The type of the second element.
* @param pair The pair to convert to a tuple.
* @return std::tuple<T, U> The converted tuple.
*/
template <typename T, typename U> std::tuple<T, U> pairToTuple(const std::pair<T, U>& pair) {
abhaybd marked this conversation as resolved.
Show resolved Hide resolved
return std::tuple<T, U>(pair.first, pair.second);
}

} // namespace util
Loading