Skip to content

Commit

Permalink
Use watchdog in MCP
Browse files Browse the repository at this point in the history
  • Loading branch information
abhaybd committed Aug 24, 2023
1 parent 6999f26 commit 5c5a406
Show file tree
Hide file tree
Showing 8 changed files with 154 additions and 20 deletions.
13 changes: 13 additions & 0 deletions src/network/MissionControlProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using websocket::msghandler_t;
using websocket::validator_t;

const std::chrono::milliseconds TELEM_REPORT_PERIOD = 100ms;
const std::chrono::milliseconds HEARTBEAT_TIMEOUT_PERIOD = 3000ms;

// TODO: possibly use frozen::string for this so we don't have to use raw char ptrs
// request keys
Expand Down Expand Up @@ -285,6 +286,14 @@ void MissionControlProtocol::handleConnection() {
}
}

void MissionControlProtocol::handleHeartbeatTimedOut() {
this->stopAndShutdownPowerRepeat();
robot::emergencyStop();
log(LOG_ERROR, "Heartbeat timed out! Emergency stopping.\n");
Globals::E_STOP = true;
Globals::armIKEnabled = false;
}

void MissionControlProtocol::startPowerRepeat() {
// note: take care to lock mutexes in a consistent order
std::lock_guard<std::mutex> flagLock(_joint_repeat_running_mutex);
Expand Down Expand Up @@ -377,6 +386,10 @@ MissionControlProtocol::MissionControlProtocol(SingleClientWSServer& server)
this->addDisconnectionHandler(
std::bind(&MissionControlProtocol::stopAndShutdownPowerRepeat, this));

this->setHeartbeatTimedOutHandler(
HEARTBEAT_TIMEOUT_PERIOD,
std::bind(&MissionControlProtocol::handleHeartbeatTimedOut, this));

this->_streaming_running = true;
this->_streaming_thread = std::thread(&MissionControlProtocol::videoStreamTask, this);

Expand Down
1 change: 1 addition & 0 deletions src/network/MissionControlProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class MissionControlProtocol : public WebSocketProtocol { // TODO: add documenta
void sendJointPositionReport(const std::string& jointName, int32_t position);
void sendRoverPos();
void handleConnection();
void handleHeartbeatTimedOut();
void startPowerRepeat();
void stopAndShutdownPowerRepeat();
void setRequestedJointPower(jointid_t joint, double power);
Expand Down
11 changes: 11 additions & 0 deletions src/network/websocket/WebSocketProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ void WebSocketProtocol::addDisconnectionHandler(const connhandler_t& handler) {
disconnectionHandlers.push_back(handler);
}

void WebSocketProtocol::setHeartbeatTimedOutHandler(std::chrono::milliseconds timeout,
const heartbeattimeouthandler_t& handler) {
heartbeatInfo = {timeout, handler};
}

void WebSocketProtocol::clientConnected() {
for (const auto& f : connectionHandlers) {
f();
Expand All @@ -58,6 +63,12 @@ void WebSocketProtocol::clientDisconnected() {
}
}

void WebSocketProtocol::heartbeatTimedOut() {
if (heartbeatInfo.has_value()) {
heartbeatInfo->second();
}
}

void WebSocketProtocol::processMessage(const json& obj) const {
if (obj.contains(TYPE_KEY)) {
std::string messageType = obj[TYPE_KEY];
Expand Down
51 changes: 36 additions & 15 deletions src/network/websocket/WebSocketProtocol.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
#pragma once

#include <chrono>
#include <functional>
#include <map>
#include <optional>
#include <string>

#include <nlohmann/json.hpp>

namespace net{
namespace net {
namespace websocket {

using nlohmann::json;

typedef std::function<void(const json&)> msghandler_t;
typedef std::function<bool(const json&)> validator_t;
typedef std::function<void()> connhandler_t;
typedef std::function<void()> heartbeattimeouthandler_t;

/**
* @brief Defines a protocol which will be served at an endpoint of a server.
Expand Down Expand Up @@ -85,9 +88,38 @@ class WebSocketProtocol {

void addDisconnectionHandler(const connhandler_t& handler);

/**
* @brief Set the handler that's called when the heartbeat times out.
*
* If the heartbeat is reestablished after timing out, and then times out again, this
* handler will be called again.
*
* @param timeout The heartbeat timeout.
* @param handler The handler to call when timed out.
*/
void setHeartbeatTimedOutHandler(std::chrono::milliseconds timeout,
const heartbeattimeouthandler_t& handler);

/**
* @brief Get the protocol path of the endpoint this protocol is served on.
*
* @return The protocol path, of the form "/foo/bar".
*/
std::string getProtocolPath() const;

private:
friend class SingleClientWSServer;

std::string protocolPath;
std::map<std::string, msghandler_t> handlerMap;
std::map<std::string, validator_t> validatorMap;
std::vector<connhandler_t> connectionHandlers;
std::vector<connhandler_t> disconnectionHandlers;
std::optional<std::pair<std::chrono::milliseconds, heartbeattimeouthandler_t>>
heartbeatInfo;

/**
* @brief Process the given JSON object that was sent to this protocol's endpoint.
* Generally, this shouldn't be used by client code.
*
* @param obj The JSON object to be processed by this protocol. It is expected to have a
* "type" key.
Expand All @@ -96,29 +128,18 @@ class WebSocketProtocol {

/**
* @brief Invoke all connection handlers for this protocol.
* Generally, this shouldn't be used by client code.
*/
void clientConnected();

/**
* @brief Invoke all disconnection handlers for this protocol.
* Generally, this shouldn't be used by client code.
*/
void clientDisconnected();

/**
* @brief Get the protocol path of the endpoint this protocol is served on.
*
* @return The protocol path, of the form "/foo/bar".
* @brief Invoke the heartbeat timeout handlers for this protocol.
*/
std::string getProtocolPath() const;

private:
std::string protocolPath;
std::map<std::string, msghandler_t> handlerMap;
std::map<std::string, validator_t> validatorMap;
std::vector<connhandler_t> connectionHandlers;
std::vector<connhandler_t> disconnectionHandlers;
void heartbeatTimedOut();
};

} // namespace websocket
Expand Down
48 changes: 46 additions & 2 deletions src/network/websocket/WebSocketServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "../../Constants.h"
#include "../../log.h"
#include "../../utils/core.h"

#include <string>
namespace net {
Expand All @@ -26,6 +27,8 @@ SingleClientWSServer::SingleClientWSServer(const std::string& serverName, uint16
server.set_validate_handler([&](connection_hdl hdl) { return this->validate(hdl); });
server.set_message_handler(
[&](connection_hdl hdl, message_t msg) { this->onMessage(hdl, msg); });
server.set_pong_handler(
[&](connection_hdl hdl, std::string payload) { this->onPong(hdl, payload); });
}

SingleClientWSServer::~SingleClientWSServer() {
Expand Down Expand Up @@ -80,6 +83,27 @@ bool SingleClientWSServer::addProtocol(std::unique_ptr<WebSocketProtocol> protoc
std::string path = protocol->getProtocolPath();
if (protocolMap.find(path) == protocolMap.end()) {
protocolMap.emplace(path, std::move(protocol));
auto& protocolData = protocolMap.at(path);
const auto& heartbeatInfo = protocolData.protocol->heartbeatInfo;
if (heartbeatInfo.has_value()) {
auto eventID =
pingScheduler.scheduleEvent(heartbeatInfo->first / 2, [this, path]() {
auto& pd = this->protocolMap.at(path);
std::lock_guard lock(pd.mutex);
if (pd.client.has_value()) {
log(LOG_DEBUG, "Ping!\n");
server.ping(pd.client.value(), path);
}
});
std::lock_guard l(protocolData.mutex);
// util::Watchdog is non-copyable and non-movable, so we must create in-place
// Since we want to create a member field of the pair in-place, it gets complicated
// so we have to use piecewise_construct to allow us to separately initialize all
// pair fields in-place
protocolData.heartbeatInfo.emplace(std::piecewise_construct,
std::tuple<decltype(eventID)>{eventID},
util::pairToTuple(heartbeatInfo.value()));
}
return true;
} else {
return false;
Expand Down Expand Up @@ -147,8 +171,15 @@ void SingleClientWSServer::onClose(connection_hdl hdl) {
serverName.c_str(), path.c_str(), client.c_str());

auto& protocolData = protocolMap.at(path);
protocolData.client.reset();
protocolData.protocol->clientDisconnected();
{
std::lock_guard lock(protocolData.mutex);
protocolData.client.reset();
if (protocolData.heartbeatInfo.has_value()) {
pingScheduler.removeEvent(protocolData.heartbeatInfo->first);
protocolData.heartbeatInfo.reset();
}
protocolData.protocol->clientDisconnected();
}
}

void SingleClientWSServer::onMessage(connection_hdl hdl, message_t message) {
Expand All @@ -162,5 +193,18 @@ void SingleClientWSServer::onMessage(connection_hdl hdl, message_t message) {
json obj = json::parse(jsonStr);
protocolMap.at(path).protocol->processMessage(obj);
}

void SingleClientWSServer::onPong(connection_hdl hdl, const std::string& payload) {
log(LOG_DEBUG, "Pong from %s\n", payload.c_str());
auto conn = server.get_con_from_hdl(hdl);

assert(protocolMap.find(payload) != protocolMap.end());

auto& pd = protocolMap.at(payload);
std::lock_guard lock(pd.mutex);
if (pd.heartbeatInfo.has_value()) {
pd.heartbeatInfo->second.feed();
}
}
} // namespace websocket
} // namespace net
8 changes: 7 additions & 1 deletion src/network/websocket/WebSocketServer.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "../../utils/scheduler.h"
#include "WebSocketProtocol.h"

#include <functional>
Expand Down Expand Up @@ -96,8 +97,11 @@ class SingleClientWSServer {
class ProtocolData {
public:
ProtocolData(std::unique_ptr<WebSocketProtocol> protocol);
std::unique_ptr<WebSocketProtocol> protocol;
const std::unique_ptr<WebSocketProtocol> protocol;
std::optional<connection_hdl> client;
std::optional<std::pair<util::PeriodicScheduler<>::eventid_t, util::Watchdog<>>>
heartbeatInfo;
std::mutex mutex;
};

std::string serverName;
Expand All @@ -106,11 +110,13 @@ class SingleClientWSServer {
bool isRunning;
std::map<std::string, ProtocolData> protocolMap;
std::thread serverThread;
util::PeriodicScheduler<> pingScheduler;

bool validate(connection_hdl hdl);
void onOpen(connection_hdl hdl);
void onClose(connection_hdl hdl);
void onMessage(connection_hdl hdl, message_t message);
void onPong(connection_hdl hdl, const std::string& payload);
void serverTask();
};
} // namespace websocket
Expand Down
14 changes: 14 additions & 0 deletions src/utils/core.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#pragma once

#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include <frozen/string.h>

Expand Down Expand Up @@ -53,4 +55,16 @@ template <typename T> std::string to_string(const T& val) {
*/
frozen::string freezeStr(const std::string& str);

/**
* @brief Converts a pair to a tuple. Elements are copied to the returned tuple.
*
* @tparam T The type of the first element.
* @tparam U The type of the second element.
* @param pair The pair to convert to a tuple.
* @return std::tuple<T, U> The converted tuple.
*/
template <typename T, typename U> std::tuple<T, U> pairToTuple(const std::pair<T, U>& pair) {
return std::tuple<T, U>(pair.first, pair.second);
}

} // namespace util
28 changes: 26 additions & 2 deletions src/utils/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,30 @@ class PeriodicScheduler : private impl::Notifiable {
std::unordered_set<eventid_t> toRemove;
};

/**
* @brief Implements a thread-safe watchdog.
*
* A watchdog is a timer that is periodically reset (fed) by the client code. If the client
* fails to feed the watchdog for some duration, then the watchdog is "starved", and the
* callback is invoked. This is useful for implementing things such as heartbeats.
*
* @tparam Clock The clock to use for timing.
*
* @see https://en.wikipedia.org/wiki/Watchdog_timer
*/
template <typename Clock = std::chrono::steady_clock>
class Watchdog : private impl::Notifiable {
public:
/**
* @brief Construct a new Watchdog.
*
* @param duration The timeout duration. If not fed for at least this long, then the
* callback is invoked.
* @param callback The callback to invoke when the watchdog starves.
* @param keepCallingOnDeath If true, keep invoking @p callback every @p duration
* milliseconds until fed again. Otherwise, only call @p callback when starved, and do not
* call again until being reset and subsequently starved again.
*/
explicit Watchdog(std::chrono::milliseconds duration,
const std::function<void()>& callback, bool keepCallingOnDeath = false)
: duration(duration), callback(callback), keepCallingOnDeath(keepCallingOnDeath),
Expand All @@ -189,15 +210,18 @@ class Watchdog : private impl::Notifiable {
{
std::lock_guard lock(mutex);
quitting = true;
// wake up the thread with the cv so it can quit
fed = true;
}
cv.notify_all();
thread.join();
}

Watchdog& operator=(const Watchdog&) = delete;

/**
* @brief Feed the watchdog.
*
* Call at least once per period, or the watchdog starves.
*/
void feed() {
{
std::lock_guard lock(mutex);
Expand Down

0 comments on commit 5c5a406

Please sign in to comment.