diff --git a/src/network/MissionControlProtocol.cpp b/src/network/MissionControlProtocol.cpp index 53a992761..30e854aca 100644 --- a/src/network/MissionControlProtocol.cpp +++ b/src/network/MissionControlProtocol.cpp @@ -30,6 +30,7 @@ using websocket::msghandler_t; using websocket::validator_t; const std::chrono::milliseconds TELEM_REPORT_PERIOD = 100ms; +const std::chrono::milliseconds HEARTBEAT_TIMEOUT_PERIOD = 3000ms; // TODO: possibly use frozen::string for this so we don't have to use raw char ptrs // request keys @@ -285,6 +286,14 @@ void MissionControlProtocol::handleConnection() { } } +void MissionControlProtocol::handleHeartbeatTimedOut() { + this->stopAndShutdownPowerRepeat(); + robot::emergencyStop(); + log(LOG_ERROR, "Heartbeat timed out! Emergency stopping.\n"); + Globals::E_STOP = true; + Globals::armIKEnabled = false; +} + void MissionControlProtocol::startPowerRepeat() { // note: take care to lock mutexes in a consistent order std::lock_guard flagLock(_joint_repeat_running_mutex); @@ -377,6 +386,10 @@ MissionControlProtocol::MissionControlProtocol(SingleClientWSServer& server) this->addDisconnectionHandler( std::bind(&MissionControlProtocol::stopAndShutdownPowerRepeat, this)); + this->setHeartbeatTimedOutHandler( + HEARTBEAT_TIMEOUT_PERIOD, + std::bind(&MissionControlProtocol::handleHeartbeatTimedOut, this)); + this->_streaming_running = true; this->_streaming_thread = std::thread(&MissionControlProtocol::videoStreamTask, this); diff --git a/src/network/MissionControlProtocol.h b/src/network/MissionControlProtocol.h index d9748a6e1..30fe49afe 100644 --- a/src/network/MissionControlProtocol.h +++ b/src/network/MissionControlProtocol.h @@ -69,6 +69,7 @@ class MissionControlProtocol : public WebSocketProtocol { // TODO: add documenta void sendJointPositionReport(const std::string& jointName, int32_t position); void sendRoverPos(); void handleConnection(); + void handleHeartbeatTimedOut(); void startPowerRepeat(); void stopAndShutdownPowerRepeat(); void setRequestedJointPower(jointid_t joint, double power); diff --git a/src/network/websocket/WebSocketProtocol.cpp b/src/network/websocket/WebSocketProtocol.cpp index 4b50de4c6..904c9d373 100644 --- a/src/network/websocket/WebSocketProtocol.cpp +++ b/src/network/websocket/WebSocketProtocol.cpp @@ -46,6 +46,11 @@ void WebSocketProtocol::addDisconnectionHandler(const connhandler_t& handler) { disconnectionHandlers.push_back(handler); } +void WebSocketProtocol::setHeartbeatTimedOutHandler(std::chrono::milliseconds timeout, + const heartbeattimeouthandler_t& handler) { + heartbeatInfo = {timeout, handler}; +} + void WebSocketProtocol::clientConnected() { for (const auto& f : connectionHandlers) { f(); @@ -58,6 +63,12 @@ void WebSocketProtocol::clientDisconnected() { } } +void WebSocketProtocol::heartbeatTimedOut() { + if (heartbeatInfo.has_value()) { + heartbeatInfo->second(); + } +} + void WebSocketProtocol::processMessage(const json& obj) const { if (obj.contains(TYPE_KEY)) { std::string messageType = obj[TYPE_KEY]; diff --git a/src/network/websocket/WebSocketProtocol.h b/src/network/websocket/WebSocketProtocol.h index 7b4e4473c..58a4e6bf3 100644 --- a/src/network/websocket/WebSocketProtocol.h +++ b/src/network/websocket/WebSocketProtocol.h @@ -1,12 +1,14 @@ #pragma once +#include #include #include +#include #include #include -namespace net{ +namespace net { namespace websocket { using nlohmann::json; @@ -14,6 +16,7 @@ using nlohmann::json; typedef std::function msghandler_t; typedef std::function validator_t; typedef std::function connhandler_t; +typedef std::function heartbeattimeouthandler_t; /** * @brief Defines a protocol which will be served at an endpoint of a server. @@ -85,9 +88,38 @@ class WebSocketProtocol { void addDisconnectionHandler(const connhandler_t& handler); + /** + * @brief Set the handler that's called when the heartbeat times out. + * + * If the heartbeat is reestablished after timing out, and then times out again, this + * handler will be called again. + * + * @param timeout The heartbeat timeout. + * @param handler The handler to call when timed out. + */ + void setHeartbeatTimedOutHandler(std::chrono::milliseconds timeout, + const heartbeattimeouthandler_t& handler); + + /** + * @brief Get the protocol path of the endpoint this protocol is served on. + * + * @return The protocol path, of the form "/foo/bar". + */ + std::string getProtocolPath() const; + +private: + friend class SingleClientWSServer; + + std::string protocolPath; + std::map handlerMap; + std::map validatorMap; + std::vector connectionHandlers; + std::vector disconnectionHandlers; + std::optional> + heartbeatInfo; + /** * @brief Process the given JSON object that was sent to this protocol's endpoint. - * Generally, this shouldn't be used by client code. * * @param obj The JSON object to be processed by this protocol. It is expected to have a * "type" key. @@ -96,29 +128,18 @@ class WebSocketProtocol { /** * @brief Invoke all connection handlers for this protocol. - * Generally, this shouldn't be used by client code. */ void clientConnected(); /** * @brief Invoke all disconnection handlers for this protocol. - * Generally, this shouldn't be used by client code. */ void clientDisconnected(); /** - * @brief Get the protocol path of the endpoint this protocol is served on. - * - * @return The protocol path, of the form "/foo/bar". + * @brief Invoke the heartbeat timeout handlers for this protocol. */ - std::string getProtocolPath() const; - -private: - std::string protocolPath; - std::map handlerMap; - std::map validatorMap; - std::vector connectionHandlers; - std::vector disconnectionHandlers; + void heartbeatTimedOut(); }; } // namespace websocket diff --git a/src/network/websocket/WebSocketServer.cpp b/src/network/websocket/WebSocketServer.cpp index 1cf57b836..05850ce24 100644 --- a/src/network/websocket/WebSocketServer.cpp +++ b/src/network/websocket/WebSocketServer.cpp @@ -2,6 +2,7 @@ #include "../../Constants.h" #include "../../log.h" +#include "../../utils/core.h" #include namespace net { @@ -26,6 +27,8 @@ SingleClientWSServer::SingleClientWSServer(const std::string& serverName, uint16 server.set_validate_handler([&](connection_hdl hdl) { return this->validate(hdl); }); server.set_message_handler( [&](connection_hdl hdl, message_t msg) { this->onMessage(hdl, msg); }); + server.set_pong_handler( + [&](connection_hdl hdl, std::string payload) { this->onPong(hdl, payload); }); } SingleClientWSServer::~SingleClientWSServer() { @@ -80,6 +83,27 @@ bool SingleClientWSServer::addProtocol(std::unique_ptr protoc std::string path = protocol->getProtocolPath(); if (protocolMap.find(path) == protocolMap.end()) { protocolMap.emplace(path, std::move(protocol)); + auto& protocolData = protocolMap.at(path); + const auto& heartbeatInfo = protocolData.protocol->heartbeatInfo; + if (heartbeatInfo.has_value()) { + auto eventID = + pingScheduler.scheduleEvent(heartbeatInfo->first / 2, [this, path]() { + auto& pd = this->protocolMap.at(path); + std::lock_guard lock(pd.mutex); + if (pd.client.has_value()) { + log(LOG_DEBUG, "Ping!\n"); + server.ping(pd.client.value(), path); + } + }); + std::lock_guard l(protocolData.mutex); + // util::Watchdog is non-copyable and non-movable, so we must create in-place + // Since we want to create a member field of the pair in-place, it gets complicated + // so we have to use piecewise_construct to allow us to separately initialize all + // pair fields in-place + protocolData.heartbeatInfo.emplace(std::piecewise_construct, + std::tuple{eventID}, + util::pairToTuple(heartbeatInfo.value())); + } return true; } else { return false; @@ -147,8 +171,15 @@ void SingleClientWSServer::onClose(connection_hdl hdl) { serverName.c_str(), path.c_str(), client.c_str()); auto& protocolData = protocolMap.at(path); - protocolData.client.reset(); - protocolData.protocol->clientDisconnected(); + { + std::lock_guard lock(protocolData.mutex); + protocolData.client.reset(); + if (protocolData.heartbeatInfo.has_value()) { + pingScheduler.removeEvent(protocolData.heartbeatInfo->first); + protocolData.heartbeatInfo.reset(); + } + protocolData.protocol->clientDisconnected(); + } } void SingleClientWSServer::onMessage(connection_hdl hdl, message_t message) { @@ -162,5 +193,18 @@ void SingleClientWSServer::onMessage(connection_hdl hdl, message_t message) { json obj = json::parse(jsonStr); protocolMap.at(path).protocol->processMessage(obj); } + +void SingleClientWSServer::onPong(connection_hdl hdl, const std::string& payload) { + log(LOG_DEBUG, "Pong from %s\n", payload.c_str()); + auto conn = server.get_con_from_hdl(hdl); + + assert(protocolMap.find(payload) != protocolMap.end()); + + auto& pd = protocolMap.at(payload); + std::lock_guard lock(pd.mutex); + if (pd.heartbeatInfo.has_value()) { + pd.heartbeatInfo->second.feed(); + } +} } // namespace websocket } // namespace net diff --git a/src/network/websocket/WebSocketServer.h b/src/network/websocket/WebSocketServer.h index f6a787ec7..f44e025c7 100644 --- a/src/network/websocket/WebSocketServer.h +++ b/src/network/websocket/WebSocketServer.h @@ -1,5 +1,6 @@ #pragma once +#include "../../utils/scheduler.h" #include "WebSocketProtocol.h" #include @@ -96,8 +97,11 @@ class SingleClientWSServer { class ProtocolData { public: ProtocolData(std::unique_ptr protocol); - std::unique_ptr protocol; + const std::unique_ptr protocol; std::optional client; + std::optional::eventid_t, util::Watchdog<>>> + heartbeatInfo; + std::mutex mutex; }; std::string serverName; @@ -106,11 +110,13 @@ class SingleClientWSServer { bool isRunning; std::map protocolMap; std::thread serverThread; + util::PeriodicScheduler<> pingScheduler; bool validate(connection_hdl hdl); void onOpen(connection_hdl hdl); void onClose(connection_hdl hdl); void onMessage(connection_hdl hdl, message_t message); + void onPong(connection_hdl hdl, const std::string& payload); void serverTask(); }; } // namespace websocket diff --git a/src/utils/core.h b/src/utils/core.h index 809bd9dfe..ea126e50f 100644 --- a/src/utils/core.h +++ b/src/utils/core.h @@ -1,8 +1,10 @@ #pragma once #include +#include #include #include +#include #include @@ -53,4 +55,16 @@ template std::string to_string(const T& val) { */ frozen::string freezeStr(const std::string& str); +/** + * @brief Converts a pair to a tuple. Elements are copied to the returned tuple. + * + * @tparam T The type of the first element. + * @tparam U The type of the second element. + * @param pair The pair to convert to a tuple. + * @return std::tuple The converted tuple. + */ +template std::tuple pairToTuple(const std::pair& pair) { + return std::tuple(pair.first, pair.second); +} + } // namespace util diff --git a/src/utils/scheduler.h b/src/utils/scheduler.h index 30b1139b2..bb07a9038 100644 --- a/src/utils/scheduler.h +++ b/src/utils/scheduler.h @@ -175,9 +175,30 @@ class PeriodicScheduler : private impl::Notifiable { std::unordered_set toRemove; }; +/** + * @brief Implements a thread-safe watchdog. + * + * A watchdog is a timer that is periodically reset (fed) by the client code. If the client + * fails to feed the watchdog for some duration, then the watchdog is "starved", and the + * callback is invoked. This is useful for implementing things such as heartbeats. + * + * @tparam Clock The clock to use for timing. + * + * @see https://en.wikipedia.org/wiki/Watchdog_timer + */ template class Watchdog : private impl::Notifiable { public: + /** + * @brief Construct a new Watchdog. + * + * @param duration The timeout duration. If not fed for at least this long, then the + * callback is invoked. + * @param callback The callback to invoke when the watchdog starves. + * @param keepCallingOnDeath If true, keep invoking @p callback every @p duration + * milliseconds until fed again. Otherwise, only call @p callback when starved, and do not + * call again until being reset and subsequently starved again. + */ explicit Watchdog(std::chrono::milliseconds duration, const std::function& callback, bool keepCallingOnDeath = false) : duration(duration), callback(callback), keepCallingOnDeath(keepCallingOnDeath), @@ -189,8 +210,6 @@ class Watchdog : private impl::Notifiable { { std::unique_lock lock(mutex); quitting = true; - // wake up the thread with the cv so it can quit - fed = true; } cv.notify_one(); thread.join(); @@ -198,6 +217,11 @@ class Watchdog : private impl::Notifiable { Watchdog& operator=(const Watchdog&) = delete; + /** + * @brief Feed the watchdog. + * + * Call at least once per period, or the watchdog starves. + */ void feed() { { std::lock_guard lock(mutex);