diff --git a/src/brayns/core/jsonv2/JsonValue.cpp b/src/brayns/core/jsonv2/JsonValue.cpp index e4e534ab3..abb5566b7 100644 --- a/src/brayns/core/jsonv2/JsonValue.cpp +++ b/src/brayns/core/jsonv2/JsonValue.cpp @@ -56,7 +56,7 @@ const JsonArray &getArray(const JsonValue &json) } catch (const Poco::Exception &e) { - throw JsonException(e.displayText()); + throw JsonException(e.message()); } } @@ -68,7 +68,7 @@ const JsonObject &getObject(const JsonValue &json) } catch (const Poco::Exception &e) { - throw JsonException(e.displayText()); + throw JsonException(e.message()); } } @@ -88,7 +88,7 @@ JsonValue parseJson(const std::string &data) } catch (const Poco::Exception &e) { - throw JsonException(e.displayText()); + throw JsonException(e.message()); } } } diff --git a/src/brayns/core/utils/Binary.h b/src/brayns/core/utils/Binary.h index 189acff47..5c266f042 100644 --- a/src/brayns/core/utils/Binary.h +++ b/src/brayns/core/utils/Binary.h @@ -58,7 +58,7 @@ T swapBytes(T value) inline std::string_view extractBytes(std::string_view &bytes, std::size_t count) { auto extracted = bytes.substr(0, count); - bytes.remove_prefix(count); + bytes.remove_prefix(extracted.size()); return extracted; } diff --git a/src/brayns/core/utils/Log.cpp b/src/brayns/core/utils/Log.cpp index b57623ebb..5a6f7b16d 100644 --- a/src/brayns/core/utils/Log.cpp +++ b/src/brayns/core/utils/Log.cpp @@ -21,19 +21,6 @@ #include "Log.h" -#include - -namespace -{ -using namespace brayns; - -Logger consoleLogger() -{ - auto handler = [](const auto &record) { std::cout << toString(record) << '\n'; }; - return Logger("Brayns", LogLevel::Info, handler); -} -} - namespace brayns { void Log::setLevel(LogLevel level) @@ -46,5 +33,5 @@ void Log::disable() setLevel(LogLevel::Off); } -Logger Log::_logger = consoleLogger(); +Logger Log::_logger = createConsoleLogger("Brayns"); } // namespace brayns diff --git a/src/brayns/core/utils/Logger.cpp b/src/brayns/core/utils/Logger.cpp index e38bf2fc3..e44732e59 100644 --- a/src/brayns/core/utils/Logger.cpp +++ b/src/brayns/core/utils/Logger.cpp @@ -21,6 +21,7 @@ #include "Logger.h" +#include #include namespace brayns @@ -71,4 +72,10 @@ bool Logger::isEnabled(LogLevel level) const { return level >= _level; } + +Logger createConsoleLogger(std::string name) +{ + auto handler = [](const auto &record) { std::cout << toString(record) << '\n'; }; + return Logger(std::move(name), LogLevel::Info, handler); +} } diff --git a/src/brayns/core/utils/Logger.h b/src/brayns/core/utils/Logger.h index d86c712e5..9a862a422 100644 --- a/src/brayns/core/utils/Logger.h +++ b/src/brayns/core/utils/Logger.h @@ -122,4 +122,6 @@ class Logger LogLevel _level; LogHandler _handler; }; + +Logger createConsoleLogger(std::string name); } diff --git a/src/brayns/core/websocket/WebSocket.cpp b/src/brayns/core/websocket/WebSocket.cpp index 959aefee7..2dd3b78dd 100644 --- a/src/brayns/core/websocket/WebSocket.cpp +++ b/src/brayns/core/websocket/WebSocket.cpp @@ -40,6 +40,11 @@ WebSocketFrame receiveFrame(Poco::Net::WebSocket &websocket) websocket.receiveFrame(buffer, flags); + if (flags == 0 && buffer.size() == 0) + { + throw WebSocketClosed("Empty frame received"); + } + auto finalFrame = flags & Poco::Net::WebSocket::FRAME_FLAG_FIN; auto opcode = flags & Poco::Net::WebSocket::FRAME_OP_BITMASK; @@ -79,7 +84,7 @@ WebSocketStatus getStatus(int errorCode) WebSocketException websocketException(const Poco::Exception &e) { auto status = getStatus(e.code()); - auto message = e.displayText(); + const auto &message = e.message(); return WebSocketException(status, message); } } @@ -102,6 +107,12 @@ WebSocket::WebSocket(const Poco::Net::WebSocket &websocket): { } +std::size_t WebSocket::getMaxFrameSize() const +{ + auto size = _websocket.getMaxPayloadSize(); + return static_cast(size); +} + WebSocketFrame WebSocket::receive() { try @@ -147,9 +158,4 @@ void WebSocket::close(WebSocketStatus status, std::string_view message) { } } - -void WebSocket::close(const WebSocketException &e) -{ - close(e.getStatus(), e.what()); -} } diff --git a/src/brayns/core/websocket/WebSocket.h b/src/brayns/core/websocket/WebSocket.h index a44c08ff8..5f834dca0 100644 --- a/src/brayns/core/websocket/WebSocket.h +++ b/src/brayns/core/websocket/WebSocket.h @@ -48,6 +48,12 @@ class WebSocketException : public std::runtime_error WebSocketStatus _status; }; +class WebSocketClosed : public std::runtime_error +{ +public: + using runtime_error::runtime_error; +}; + enum class WebSocketOpcode { Continuation = Poco::Net::WebSocket::FRAME_OP_CONT, @@ -77,10 +83,10 @@ class WebSocket public: explicit WebSocket(const Poco::Net::WebSocket &websocket); + std::size_t getMaxFrameSize() const; WebSocketFrame receive(); void send(const WebSocketFrameView &frame); void close(WebSocketStatus status, std::string_view message = {}); - void close(const WebSocketException &e); private: Poco::Net::WebSocket _websocket; diff --git a/src/brayns/core/websocket/WebSocketManager.cpp b/src/brayns/core/websocket/WebSocketHandler.cpp similarity index 58% rename from src/brayns/core/websocket/WebSocketManager.cpp rename to src/brayns/core/websocket/WebSocketHandler.cpp index 82f2e426d..97502d49d 100644 --- a/src/brayns/core/websocket/WebSocketManager.cpp +++ b/src/brayns/core/websocket/WebSocketHandler.cpp @@ -19,15 +19,17 @@ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ -#include "WebSocketManager.h" +#include "WebSocketHandler.h" #include #include +#include + namespace { using namespace brayns::experimental; -using namespace brayns; +using brayns::Logger; struct WebSocketBuffer { @@ -73,18 +75,18 @@ void onBinary(const WebSocketFrame &frame, WebSocketBuffer &buffer, Logger &logg buffer.binary = true; } -void onClose(const WebSocketConnection &websocket, Logger &logger) +void onClose(WebSocket &websocket, Logger &logger) { logger.info("Close frame received, sending normal close frame"); - websocket.closeOk(); + websocket.close(WebSocketStatus::NormalClose); } -void onPing(const WebSocketConnection &websocket, Logger &logger) +void onPing(const WebSocketFrame &frame, WebSocket &websocket, Logger &logger) { logger.info("Ping frame received, sending pong frame"); - websocket.send({.opcode = WebSocketOpcode::Pong}); + websocket.send({.opcode = WebSocketOpcode::Pong, .data = frame.data}); } void onPong(Logger &logger) @@ -92,19 +94,52 @@ void onPong(Logger &logger) logger.info("Pong frame received, ignoring"); } -void respond(const WebSocketConnection &websocket, const RawResponse &response) +void respond(ClientId clientId, WebSocket &websocket, Logger &logger, const RawResponse &response) { - websocket.send({ - .opcode = response.binary ? WebSocketOpcode::Binary : WebSocketOpcode::Text, - .data = response.data, - }); + auto data = response.data; + + logger.info("Sending response of {} bytes to client {}", data.size(), clientId); + + if (!response.binary) + { + logger.debug("Text response data: {}", data); + } + + auto opcode = response.binary ? WebSocketOpcode::Binary : WebSocketOpcode::Text; + + auto maxFrameSize = websocket.getMaxFrameSize(); + + while (true) + { + auto chunk = extractBytes(data, maxFrameSize); + + logger.info("Sending websocket frame of {} bytes", data.size()); + + try + { + websocket.send({opcode, chunk}); + } + catch (const WebSocketException &e) + { + logger.warn("Failed to send websocket frame: {}", e.what()); + websocket.close(e.getStatus(), e.what()); + return; + } + catch (...) + { + logger.error("Unexpected error while sending websocket frame"); + websocket.close(WebSocketStatus::UnexpectedCondition, "Internal error"); + return; + } + + if (data.empty()) + { + return; + } + } } -void runClientLoop( - ClientId clientId, - const WebSocketConnection &websocket, - const WebSocketListener &listener, - Logger &logger) +void runClientLoop(ClientId clientId, WebSocket &websocket, const WebSocketListener &listener, Logger &logger) { auto buffer = WebSocketBuffer(); @@ -127,12 +162,13 @@ void runClientLoop( onClose(websocket, logger); return; case WebSocketOpcode::Ping: - onPing(websocket, logger); + onPing(frame, websocket, logger); continue; case WebSocketOpcode::Pong: onPong(logger); continue; default: + logger.error("Unexpected invalid opcode: {}", static_cast(frame.opcode)); throw WebSocketException(WebSocketStatus::UnexpectedCondition, "Unexpected invalid opcode"); } @@ -141,37 +177,55 @@ void runClientLoop( continue; } - listener.onRequest({ + auto request = RawRequest{ .clientId = clientId, .data = std::exchange(buffer.data, {}), .binary = buffer.binary, - .respond = [=](const auto &response) { respond(websocket, response); }, - }); + .respond = [=, &logger](const auto &response) mutable { respond(clientId, websocket, logger, response); }, + }; + + logger.info("Received request of {} bytes from client {}", request.data.size(), clientId); + + if (!request.binary) + { + logger.debug("Text request data: {}", request.data); + } + + listener.onRequest(request); } } } namespace brayns::experimental { -WebSocketManager::WebSocketManager(WebSocketListener listener, Logger &logger): +WebSocketHandler::WebSocketHandler(WebSocketListener listener, Logger &logger): _listener(std::move(listener)), _logger(&logger) { } -void WebSocketManager::handle(const WebSocketConnection &websocket) +void WebSocketHandler::handle(WebSocket &websocket) { auto clientId = _clientIds.next(); + _listener.onConnect(clientId); + try { runClientLoop(clientId, websocket, _listener, *_logger); } - catch (...) + catch (const WebSocketClosed &e) + { + _logger->warn("WebSocket closed by peer: {}", e.what()); + } + catch (const WebSocketException &e) { - _logger->error("Unexpected error in websocket client loop"); + _logger->warn("Error while processing websocket: '{}'", e.what()); + websocket.close(e.getStatus(), e.what()); } + _listener.onDisconnect(clientId); + _clientIds.recycle(clientId); } } diff --git a/src/brayns/core/websocket/WebSocketManager.h b/src/brayns/core/websocket/WebSocketHandler.h similarity index 84% rename from src/brayns/core/websocket/WebSocketManager.h rename to src/brayns/core/websocket/WebSocketHandler.h index 13f6a7b26..e8866ceeb 100644 --- a/src/brayns/core/websocket/WebSocketManager.h +++ b/src/brayns/core/websocket/WebSocketHandler.h @@ -32,13 +32,6 @@ namespace brayns::experimental { -struct WebSocketConnection -{ - std::function receive; - std::function send; - std::function closeOk; -}; - using ClientId = std::uint32_t; struct RawResponse @@ -62,12 +55,12 @@ struct WebSocketListener std::function onRequest; }; -class WebSocketManager +class WebSocketHandler { public: - explicit WebSocketManager(WebSocketListener listener, Logger &logger); + explicit WebSocketHandler(WebSocketListener listener, Logger &logger); - void handle(const WebSocketConnection &websocket); + void handle(WebSocket &websocket); private: WebSocketListener _listener; diff --git a/src/brayns/core/websocket/WebSocketServer.cpp b/src/brayns/core/websocket/WebSocketServer.cpp index e98757c8e..d6a4b3dda 100644 --- a/src/brayns/core/websocket/WebSocketServer.cpp +++ b/src/brayns/core/websocket/WebSocketServer.cpp @@ -83,20 +83,15 @@ class WebSocketRequestHandler : public Poco::Net::HTTPRequestHandler return; } - _logger->info("Upgrade complete, host {} is now connected", request.getHost()); + _logger->info("Upgrade complete, {} is now connected", request.getHost()); try { - _handler(*websocket); - } - catch (const WebSocketException &e) - { - _logger->warn("Error while processing websocket: '{}'", e.what()); - websocket->close(e); + _handler.handle(*websocket); } catch (...) { - _logger->error("Internal error while processing websocket"); + _logger->error("Unexpected error in websocket request handler"); websocket->close(WebSocketStatus::UnexpectedCondition, "Internal error"); } } @@ -120,7 +115,7 @@ class WebSocketRequestHandler : public Poco::Net::HTTPRequestHandler } catch (const Poco::Net::WebSocketException &e) { - auto message = e.displayText(); + const auto &message = e.message(); _logger->warn("Upgrade failed: '{}'", message); @@ -161,7 +156,7 @@ class RequestHandlerFactory : public Poco::Net::HTTPRequestHandlerFactory return new HealthcheckRequestHandler(*_logger); } - if (uri == "") + if (uri == "/") { return new WebSocketRequestHandler(_handler, _maxFrameSize, *_logger); } @@ -240,7 +235,7 @@ WebSocketServer::~WebSocketServer() WebSocketServer startWebSocketServer(const WebSocketServerSettings &settings, WebSocketHandler handler, Logger &logger) { - if (settings.maxFrameSize > std::numeric_limits::max()) + if (settings.maxFrameSize > static_cast(std::numeric_limits::max())) { throw std::runtime_error("Max frame size cannot be above 2 ** 31"); } @@ -261,7 +256,7 @@ WebSocketServer startWebSocketServer(const WebSocketServerSettings &settings, We } catch (const Poco::Exception &e) { - throw std::runtime_error(fmt::format("Failed to start websocket server: {}", e.displayText())); + throw std::runtime_error(fmt::format("Failed to start websocket server: {}", e.message())); } } } diff --git a/src/brayns/core/websocket/WebSocketServer.h b/src/brayns/core/websocket/WebSocketServer.h index 49772b1eb..7b6f6d43d 100644 --- a/src/brayns/core/websocket/WebSocketServer.h +++ b/src/brayns/core/websocket/WebSocketServer.h @@ -22,6 +22,7 @@ #pragma once #include +#include #include #include @@ -30,29 +31,28 @@ #include #include "WebSocket.h" +#include "WebSocketHandler.h" namespace brayns::experimental { struct SslSettings { std::string privateKeyFile; - std::string privateKeyPassphrase; std::string certificateFile; - std::string caLocation; + std::string caLocation = {}; + std::string privateKeyPassphrase = {}; }; struct WebSocketServerSettings { - std::string host; - std::uint16_t port; - std::size_t maxThreadCount; - std::size_t queueSize; - std::size_t maxFrameSize; - std::optional ssl; + std::string host = "localhost"; + std::uint16_t port = 5000; + std::size_t maxThreadCount = 2; + std::size_t queueSize = 2; + std::size_t maxFrameSize = std::numeric_limits::max(); + std::optional ssl = std::nullopt; }; -using WebSocketHandler = std::function; - class WebSocketServer { public: diff --git a/tests/core/websocket/TestWebSocket.cpp b/tests/core/websocket/TestWebSocket.cpp new file mode 100644 index 000000000..266fda782 --- /dev/null +++ b/tests/core/websocket/TestWebSocket.cpp @@ -0,0 +1,71 @@ +/* Copyright (c) 2015-2024, EPFL/Blue Brain Project + * All rights reserved. Do not distribute without permission. + * Responsible author: Nadir Roman Guerrero + * + * This file is part of Brayns + * + * This library is free software; you can redistribute it and/or modify it under + * the terms of the GNU Lesser General Public License version 3.0 as published + * by the Free Software Foundation. + * + * This library is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS + * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more + * details. + * + * You should have received a copy of the GNU Lesser General Public License + * along with this library; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ + +#include + +#include +#include +#include + +#include +#include + +using namespace brayns::experimental; +using brayns::createConsoleLogger; +using brayns::Logger; +using brayns::LogLevel; + +TEST_CASE("WebSocket") +{ + SUBCASE("Server") + { + /*auto logger = createConsoleLogger("Test"); + logger.setLevel(LogLevel::Debug); + + auto mutex = std::mutex(); + auto condition = std::condition_variable(); + + auto onConnect = [&](auto clientId) { logger.info("Client {} connected", clientId); }; + + auto onDisconnect = [&](auto clientId) + { + logger.info("Client {} disconnected", clientId); + auto lock = std::lock_guard(mutex); + condition.notify_all(); + }; + + auto onRequest = [&](const auto &request) + { + logger.info("Request of {} bytes", request.data.size()); + // request.respond({request.data, request.binary}); + }; + + auto listener = WebSocketListener{onConnect, onDisconnect, onRequest}; + + auto handler = WebSocketHandler(listener, logger); + + auto settings = WebSocketServerSettings{}; + + auto server = startWebSocketServer(settings, handler, logger); + + auto lock = std::unique_lock(mutex); + condition.wait(lock);*/ + } +}