diff --git a/.github/workflows/build-linux.yml b/.github/workflows/build-linux.yml index 82e3fe1..2f9c5c8 100644 --- a/.github/workflows/build-linux.yml +++ b/.github/workflows/build-linux.yml @@ -39,4 +39,4 @@ jobs: - name: Test working-directory: ${{ runner.workspace }}/build - run: ctest --timeout 30 -C ${{ matrix.build-type }} \ No newline at end of file + run: ctest --timeout 30 -j 4 -C ${{ matrix.build-type }} \ No newline at end of file diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index 38e5568..0e4f2e4 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -38,4 +38,4 @@ jobs: - name: Test working-directory: ${{ runner.workspace }}/build - run: ctest --timeout 30 -C ${{ matrix.build-type }} \ No newline at end of file + run: ctest --timeout 30 -j 4 -C ${{ matrix.build-type }} \ No newline at end of file diff --git a/CMakeSettings.json b/CMakeSettings.json index 856034c..1e83a26 100644 --- a/CMakeSettings.json +++ b/CMakeSettings.json @@ -10,7 +10,7 @@ "installRoot": "${projectDir}\\out\\install\\${name}", "cmakeCommandArgs": "", "buildCommandArgs": "", - "ctestCommandArgs": "--timeout 30" + "ctestCommandArgs": "--timeout 30 -j 4" }, { "name": "2. Windows x86 Release", @@ -22,7 +22,7 @@ "installRoot": "${projectDir}\\out\\install\\${name}", "cmakeCommandArgs": "", "buildCommandArgs": "", - "ctestCommandArgs": "--timeout 30" + "ctestCommandArgs": "--timeout 30 -j 4" }, { "name": "3. Windows x64 Debug", @@ -34,7 +34,7 @@ "installRoot": "${projectDir}\\out\\install\\${name}", "cmakeCommandArgs": "", "buildCommandArgs": "", - "ctestCommandArgs": "--timeout 30" + "ctestCommandArgs": "--timeout 30 -j 4" }, { "name": "4. Windows x64 Release", @@ -46,7 +46,7 @@ "installRoot": "${projectDir}\\out\\install\\${name}", "cmakeCommandArgs": "", "buildCommandArgs": "", - "ctestCommandArgs": "--timeout 30" + "ctestCommandArgs": "--timeout 30 -j 4" }, { "name": "5. Linux x86 Debug", @@ -56,7 +56,7 @@ "remoteCopySourcesExclusionList": [ ".vs", ".git", "out" ], "cmakeCommandArgs": "-DLIBNETWRK_LINUX_ARCHITECTURE:STRING=x86", "buildCommandArgs": "", - "ctestCommandArgs": "--timeout 30", + "ctestCommandArgs": "--timeout 30 -j 4", "inheritEnvironments": [ "linux_x86" ], "intelliSenseMode": "linux-gcc-x86", "remoteMachineName": "${defaultRemoteMachineName}", @@ -76,7 +76,7 @@ "remoteCopySourcesExclusionList": [ ".vs", ".git", "out" ], "cmakeCommandArgs": "-DLIBNETWRK_LINUX_ARCHITECTURE:STRING=x86", "buildCommandArgs": "", - "ctestCommandArgs": "--timeout 30", + "ctestCommandArgs": "--timeout 30 -j 4", "inheritEnvironments": [ "linux_x86" ], "intelliSenseMode": "linux-gcc-x86", "remoteMachineName": "${defaultRemoteMachineName}", @@ -96,7 +96,7 @@ "remoteCopySourcesExclusionList": [ ".vs", ".git", "out" ], "cmakeCommandArgs": "", "buildCommandArgs": "", - "ctestCommandArgs": "--timeout 30", + "ctestCommandArgs": "--timeout 30 -j 4", "inheritEnvironments": [ "linux_x64" ], "intelliSenseMode": "linux-gcc-x64", "remoteMachineName": "${defaultRemoteMachineName}", @@ -116,7 +116,7 @@ "remoteCopySourcesExclusionList": [ ".vs", ".git", "out" ], "cmakeCommandArgs": "", "buildCommandArgs": "", - "ctestCommandArgs": "--timeout 30", + "ctestCommandArgs": "--timeout 30 -j 4", "inheritEnvironments": [ "linux_x64" ], "intelliSenseMode": "linux-gcc-x64", "remoteMachineName": "${defaultRemoteMachineName}", diff --git a/examples/tcp_echo/tcp_echo_service.cpp b/examples/tcp_echo/tcp_echo_service.cpp index a68f9de..895b12d 100644 --- a/examples/tcp_echo/tcp_echo_service.cpp +++ b/examples/tcp_echo/tcp_echo_service.cpp @@ -7,22 +7,26 @@ using namespace libnetwrk; class tcp_echo_service : public tcp_service { public: - tcp_echo_service() : tcp_service() {} + tcp_echo_service() : tcp_service() { + set_message_callback([this](auto command, auto message) { + ev_message(command, message); + }); + } - void ev_message(owned_message_t& msg) override { + void ev_message(command_t command, owned_message_t* msg) { message_t response; - switch (msg.msg.command()) { + switch (command) { case commands::c2s_echo: { std::string text; - msg.msg >> text; + msg->msg >> text; - LIBNETWRK_INFO(this->name, "{}:{}\t{}", - msg.sender->get_ip().c_str(), msg.sender->get_port(), text); + LIBNETWRK_INFO(get_name(), "{}:{}\t{}", + msg->sender->get_ip().c_str(), msg->sender->get_port(), text); response.set_command(commands::s2c_echo); response << text; - msg.sender->send(response); + msg->sender->send(response); break; } default: break; diff --git a/include/libnetwrk/net/core/base_service.hpp b/include/libnetwrk/net/core/base_service.hpp deleted file mode 100644 index d033362..0000000 --- a/include/libnetwrk/net/core/base_service.hpp +++ /dev/null @@ -1,402 +0,0 @@ -#pragma once - -#include "libnetwrk/net/core/context.hpp" -#include "libnetwrk/net/core/base_client_connection.hpp" - -#include -#include -#include - -namespace libnetwrk { - template - class base_service : public context> { - public: - // This service type - using service_t = base_service; - - // Connection type for this service - using connection_t = base_client_connection; - - // Context type for this service - using context_t = context; - - // Message type - using message_t = connection_t::message_t; - - // Owned message type for this service - using owned_message_t = connection_t::owned_message_t; - - // Send predicate - using send_predicate = std::function)>; - - private: - // Timer type - using timer_t = asio::steady_timer; - - public: - uint8_t gc_freq_sec = 15U; - - public: - base_service() = delete; - base_service(const service_t&) = delete; - base_service(service_t&&) = default; - - service_t& operator=(const service_t&) = delete; - service_t& operator=(service_t&&) = default; - - base_service(const std::string& name) - : context_t(name) {} - - public: - /* - Get service status. - */ - bool running() { - return this->m_status == service_status::started; - } - - /* - Start service. - */ - bool start(const char* host, const unsigned short port) { - if (this->m_status != service_status::stopped) - return false; - - this->m_status = service_status::starting; - - bool started = impl_start(host, port); - - if (started) { - ev_service_started(); - this->m_status = service_status::started; - } - else { - this->m_status = service_status::stopped; - } - - return started; - } - - /* - Stop service. - */ - virtual void stop() = 0; - - /* - Queue up a function to run. - */ - void queue_async_job(std::function const& lambda) { - asio::post(*(this->io_context), lambda); - } - - /* - Send a message to client. - */ - void send(std::shared_ptr client, message_t& message, libnetwrk::send_flags flags = libnetwrk::send_flags::none) { - if (!client || client->is_connected()) return; - - client->send(message, flags); - } - - /* - Send a message to clients. - Predicate can be used to filter clients. - */ - void send_all(message_t& message, libnetwrk::send_flags flags = libnetwrk::send_flags::none, send_predicate predicate = nullptr) { - std::shared_ptr> outgoing_message; - - if (LIBNETWRK_FLAG_SET(flags, libnetwrk::send_flags::keep_message)) { - outgoing_message = std::make_shared>(message); - } - else { - outgoing_message = std::make_shared>(std::move(message)); - } - - std::lock_guard guard(m_connections_mutex); - for (auto& client : m_connections) { - if (!client || !client->is_connected()) continue; - if (predicate && !predicate(client)) continue; - - client->send(outgoing_message); - } - } - - protected: - uint64_t m_ids = 0U; - - std::list> m_connections; - std::mutex m_connections_mutex; - - protected: - virtual ~base_service() { - if (this->m_status == service_status::stopped) - return; - - teardown(); - this->m_status = service_status::stopped; - }; - - protected: - /* - Called when the service was successfully started. - */ - virtual void ev_service_started() {}; - - /* - Called when service stopped. - */ - virtual void ev_service_stopped() {}; - - /* - Called when processing messages. - */ - virtual void ev_message(owned_message_t& msg) override {}; - - /* - Called before client is fully accepted. - Allows performing checks on client before accepting (blacklist, whitelist). - */ - virtual bool ev_before_client_connected(std::shared_ptr client) { return true; }; - - /* - Called when a client has connected. - */ - virtual void ev_client_connected(std::shared_ptr client) {}; - - /* - Called when a client has disconnected. - */ - virtual void ev_client_disconnected(std::shared_ptr client) {}; - - protected: - /* - Service start implementation. - */ - virtual bool impl_start(const char* host, const unsigned short port) = 0; - - /* - Pre process message data before writing. - */ - virtual void pre_process_message(message_t::buffer_t& buffer) override {} - - /* - Post process message data after reading. - */ - virtual void post_process_message(message_t::buffer_t& buffer) override {} - - protected: - void teardown() { - if (m_gc_timer) - m_gc_timer->cancel(); - - if (m_gc_future.valid()) - m_gc_future.wait(); - - /* - Close all connections and signal coroutines to stop - */ - stop_all_connections(); - - /* - Wait for all coroutines to stop - */ - wait_for_coroutines_to_stop(); - - if (this->io_context && !this->io_context->stopped()) - this->io_context->stop(); - - if (m_context_thread.joinable()) - m_context_thread.join(); - - { - std::lock_guard guard(this->m_incoming_mutex); - - this->m_incoming_messages = {}; - this->m_incoming_system_messages = {}; - } - - this->m_cv.notify_all(); - - if (this->m_process_messages_thread.joinable()) - this->m_process_messages_thread.join(); - - LIBNETWRK_INFO(this->name, "Stopped."); - }; - - void start_context() { - m_gc_future = asio::co_spawn(*this->io_context, co_gc(), asio::use_future); - m_context_thread = std::thread([this] { this->io_context->run(); }); - } - - private: - std::thread m_context_thread; - std::unique_ptr m_gc_timer; - std::future m_gc_future; - - private: - bool internal_process_message() override final { - try { - owned_message_t message; - - { - std::lock_guard guard(this->m_incoming_mutex); - - if (!this->m_incoming_system_messages.empty()) { - message = this->m_incoming_system_messages.front(); - this->m_incoming_system_messages.pop(); - } - else { - message = this->m_incoming_messages.front(); - this->m_incoming_messages.pop(); - } - } - - if (message.msg.head.type == message_type::system) { - ev_system_message(message); - } - else { - ev_message(message); - } - } - catch (const std::exception& e) { - (void)e; - - LIBNETWRK_ERROR(this->name, "Failed to process message. | {}", e.what()); - return false; - } - catch (...) { - LIBNETWRK_ERROR(this->name, "Failed to process message. | Critical fail."); - return false; - } - - return true; - } - - /* - Client disconnected callback from connection. - */ - void internal_ev_client_disconnected(std::shared_ptr client) override final { - LIBNETWRK_INFO(this->name, "{}: Client disconnected.", client->id()); - } - - void ev_system_message(owned_message_t& msg) override final { - system_command command = static_cast(msg.msg.command()); - - switch (command) { - case system_command::c2s_verify: return on_system_verify_message(msg); - default: return; - } - } - - void on_system_verify_message(owned_message_t& msg) { - LIBNETWRK_DEBUG(this->name, "Received verify response."); - - auth::answer_t answer{}; - msg.msg >> answer; - - if (!auth::is_correct(msg.sender->auth_question, answer)) - return msg.sender->stop(); - - msg.sender->is_authenticated.store(true); - - message_t response{}; - response.head.type = message_type::system; - response.head.command = static_cast(system_command::s2c_verify_ok); - - msg.sender->send(response); - } - - private: - void impl_send(std::shared_ptr& client, std::shared_ptr message) { - if (client && client->is_connected()) - client->send(message); - } - - void stop_all_connections() { - std::lock_guard guard(m_connections_mutex); - - for (auto& client : m_connections) { - if (!client) continue; - - client->stop(); - } - } - - void wait_for_coroutines_to_stop() { - bool running; - - while (true) { - { - std::lock_guard guard(m_connections_mutex); - - running = std::any_of(m_connections.begin(), m_connections.end(), - [](auto& client) { - return client && client->active_operations != 0; - } - ); - } - - if (!running) - break; - - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - } - } - - asio::awaitable co_gc() { - auto current_executor = co_await asio::this_coro::executor; - m_gc_timer = std::make_unique(current_executor, std::chrono::seconds(gc_freq_sec)); - - size_t count_before = 0U; - size_t count_after = 0U; - - while (true) { - auto [ec] = co_await m_gc_timer->async_wait(asio::as_tuple(asio::use_awaitable)); - - if (ec) { - if (ec != asio::error::operation_aborted) { - LIBNETWRK_ERROR(this->name, "Failed to run GC. | {}", ec.message()); - } - - break; - } - - { - std::lock_guard guard(m_connections_mutex); - - count_before = m_connections.size(); - - m_connections.remove_if([this](auto& client) { - if (!client) - return true; - - if (!client->is_connected()) { - ev_client_disconnected(client); - return true; - } - - if (!client->is_authenticated.load()) { - uint64_t timestamp = std::chrono::duration_cast( - std::chrono::system_clock::now().time_since_epoch()).count(); - - if (timestamp > client->auth_timeout_timestamp) { - client->stop(); - LIBNETWRK_VERBOSE(this->name, "{}: Auth timeout. Client disconnected.", client->id()); - ev_client_disconnected(client); - return true; - } - } - - return false; - }); - - count_after = m_connections.size(); - } - - LIBNETWRK_VERBOSE(this->name, "GC tc: {} rc: {}", count_before, count_before - count_after); - - m_gc_timer->expires_after(std::chrono::seconds(gc_freq_sec)); - } - } - }; -} diff --git a/include/libnetwrk/net/core/client/client.hpp b/include/libnetwrk/net/core/client/client.hpp index 252954a..4a7af7d 100644 --- a/include/libnetwrk/net/core/client/client.hpp +++ b/include/libnetwrk/net/core/client/client.hpp @@ -157,12 +157,10 @@ namespace libnetwrk { private: void wait_for_coroutines_to_stop() { - while (true) { - if (!m_comp_connection.connection) break; - if (m_comp_connection.connection->active_operations == 0) break; + if (!m_comp_connection.connection) return; + if (!m_comp_connection.connection->cancel_cv.has_active_operations()) return; - std::this_thread::sleep_for(std::chrono::milliseconds(5)); - } + m_comp_connection.connection->cancel_cv.wait_for_end(); } }; } diff --git a/include/libnetwrk/net/core/client/client_comp_connection.hpp b/include/libnetwrk/net/core/client/client_comp_connection.hpp index 71da7b4..9907ad1 100644 --- a/include/libnetwrk/net/core/client/client_comp_connection.hpp +++ b/include/libnetwrk/net/core/client/client_comp_connection.hpp @@ -25,7 +25,7 @@ namespace libnetwrk { public: void create_connection() { - connection = std::make_shared(*m_context.io_context); + connection = std::make_shared(m_context.io_context); } void establish_connection(const endpoint_t& endpoint) { diff --git a/include/libnetwrk/net/core/client/client_connection_internal.hpp b/include/libnetwrk/net/core/client/client_connection_internal.hpp index b2e16b2..48d81fa 100644 --- a/include/libnetwrk/net/core/client/client_connection_internal.hpp +++ b/include/libnetwrk/net/core/client/client_connection_internal.hpp @@ -2,6 +2,7 @@ #include "libnetwrk/net/core/client/client_connection.hpp" #include "libnetwrk/net/core/misc/coroutine_cv.hpp" +#include "libnetwrk/net/core/enums.hpp" #include #include @@ -32,15 +33,13 @@ namespace libnetwrk { : base_t(context), write_cv(context), cancel_cv(context) { is_authenticated = false; - active_operations = 0U; } connection_t& operator=(const connection_t&) = delete; connection_t& operator=(connection_t&&) = default; public: - std::atomic_bool is_authenticated; - std::atomic_ushort active_operations; + std::atomic_bool is_authenticated; coroutine_cv write_cv; coroutine_cv cancel_cv; diff --git a/include/libnetwrk/net/core/client/client_context.hpp b/include/libnetwrk/net/core/client/client_context.hpp index f04f1ad..a6ccfc2 100644 --- a/include/libnetwrk/net/core/client/client_context.hpp +++ b/include/libnetwrk/net/core/client/client_context.hpp @@ -9,6 +9,10 @@ namespace libnetwrk { using cb_disconnect_t = std::function; public: - cb_disconnect_t cb_disconnect; + cb_disconnect_t cb_disconnect; + + public: + client_context() + : libnetwrk_context() {} }; } diff --git a/include/libnetwrk/net/core/enums.hpp b/include/libnetwrk/net/core/enums.hpp index c3d4710..6dc137f 100644 --- a/include/libnetwrk/net/core/enums.hpp +++ b/include/libnetwrk/net/core/enums.hpp @@ -17,4 +17,9 @@ namespace libnetwrk { none = 0, keep_message = 1 << 0 // Make message reusable (copy instead of move upon sending) }; + + enum disconnect_code : uint8_t { + unspecified = 0, + authentication_failed = 1 + }; } diff --git a/include/libnetwrk/net/core/libnetwrk_context.hpp b/include/libnetwrk/net/core/libnetwrk_context.hpp index 6c86077..51ac282 100644 --- a/include/libnetwrk/net/core/libnetwrk_context.hpp +++ b/include/libnetwrk/net/core/libnetwrk_context.hpp @@ -18,25 +18,29 @@ namespace libnetwrk { using io_context_t = asio::io_context; using connection_internal_t = tn_connection; - using connection_t = typename tn_connection::base_t; - using command_t = typename connection_t::command_t; - using endpoint_t = typename connection_internal_t::endpoint_t; - using message_t = typename connection_t::message_t; - using owned_message_t = typename connection_t::owned_message_t; - using outgoing_message_t = typename connection_t::outgoing_message_t; - using buffer_t = typename message_t::buffer_t; + using connection_t = connection_internal_t::base_t; + using command_t = connection_t::command_t; + using endpoint_t = connection_internal_t::endpoint_t; + using message_t = connection_t::message_t; + using owned_message_t = connection_t::owned_message_t; + using outgoing_message_t = connection_t::outgoing_message_t; + using buffer_t = message_t::buffer_t; using cb_message_t = std::function; using cb_system_message_t = std::function; - using cb_connect_t = std::function)>; - using cb_internal_disconnect_t = std::function)>; + using cb_connect_t = std::function)>; + using cb_internal_disconnect_t = std::function)>; using cb_pre_process_message_t = std::function; using cb_post_process_message_t = std::function; public: - std::string name = ""; - std::atomic_uint8_t status = libnetwrk::service_status::stopped; - std::unique_ptr io_context; + libnetwrk_context() + : io_context(1) {} + + public: + std::string name = ""; + std::atomic_uint8_t status = libnetwrk::service_status::stopped; + io_context_t io_context; cb_message_t cb_message; cb_system_message_t cb_system_message; @@ -50,20 +54,18 @@ namespace libnetwrk { return status == service_status::started; } - void create_io_context() { - io_context = std::make_unique(1); - } - void start_io_context() { - m_io_context_thread = std::thread([this] { this->io_context->run(); }); + m_io_context_thread = std::thread([this] { + this->io_context.run(); + }); } void stop_io_context() { - if (io_context && !io_context->stopped()) - io_context->stop(); + if (!io_context.stopped()) + io_context.stop(); - if (m_io_context_thread.joinable()) - m_io_context_thread.join(); + if(m_io_context_thread.joinable()) + m_io_context_thread.join(); } private: diff --git a/include/libnetwrk/net/core/misc/coroutine_cv.hpp b/include/libnetwrk/net/core/misc/coroutine_cv.hpp index f67749a..8ac1ef5 100644 --- a/include/libnetwrk/net/core/misc/coroutine_cv.hpp +++ b/include/libnetwrk/net/core/misc/coroutine_cv.hpp @@ -3,6 +3,8 @@ #include "asio.hpp" #include "libnetwrk/net/core/context.hpp" +#include + namespace libnetwrk { /* Condition variable for coroutines. @@ -13,22 +15,40 @@ namespace libnetwrk { coroutine_cv(const coroutine_cv&) = delete; coroutine_cv(coroutine_cv&&) = default; - coroutine_cv(work_context& context) - : m_timer(*context.io_context, asio::steady_timer::time_point::max()) - {} - coroutine_cv(asio::io_context& context) - : m_timer(context, asio::steady_timer::time_point::max()) {} + : m_io_context(context), m_timer(context, asio::steady_timer::duration::max()) + {} coroutine_cv& operator=(const coroutine_cv&) = delete; coroutine_cv& operator=(coroutine_cv&&) = default; public: + /* + Wait for notify or expire. + */ asio::awaitable wait() { + m_operations++; co_await m_timer.async_wait(asio::as_tuple(asio::use_awaitable)); + m_operations--; co_return; } + /* + Wait for all operations to finish. + */ + void wait_for_end() { + while (true) { + if (m_operations == 0 || m_io_context.stopped()) + break; + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } + + bool has_active_operations() { + return m_operations != 0; + } + void notify_one() { m_timer.cancel_one(); } @@ -38,6 +58,23 @@ namespace libnetwrk { } private: - asio::steady_timer m_timer; + asio::io_context& m_io_context; + asio::steady_timer m_timer; + std::atomic_uint16_t m_operations = 0U; + + private: + asio::awaitable co_wait_for_end() { + asio::system_timer timer{ co_await asio::this_coro::executor }; + + while (true) { + timer.expires_after(std::chrono::milliseconds{ 5 }); + auto [ec] = co_await timer.async_wait(asio::as_tuple(asio::use_awaitable)); + + if (ec || m_operations == 0) + break; + } + + co_return; + } }; } diff --git a/include/libnetwrk/net/core/service/service.hpp b/include/libnetwrk/net/core/service/service.hpp new file mode 100644 index 0000000..3dbd10c --- /dev/null +++ b/include/libnetwrk/net/core/service/service.hpp @@ -0,0 +1,183 @@ +#pragma once + +#include "libnetwrk/net/core/service/service_context.hpp" +#include "libnetwrk/net/core/service/service_comp_connection.hpp" +#include "libnetwrk/net/core/service/service_comp_message.hpp" +#include "libnetwrk/net/core/service/service_comp_system_message.hpp" +#include "libnetwrk/net/core/service/service_connection_internal.hpp" + +#include + +namespace libnetwrk { + template + class service { + public: + using context_t = service_context>; + using comp_connection_t = service_comp_connection; + using comp_message_t = service_comp_message; + using comp_system_message_t = service_comp_system_message; + + using service_t = service; + using connection_t = context_t::connection_t; + using message_t = context_t::message_t; + using owned_message_t = context_t::owned_message_t; + + public: + service() = delete; + service(const service_t&) = delete; + service(service_t&&) = default; + + service(const std::string& name) + : m_comp_connection(m_context), + m_comp_message(m_context, m_comp_connection), + m_comp_system_message(m_context) + { + m_context.name = name; + + m_context.cb_internal_disconnect = [this](auto connection) { + connection->stop(); + + if (connection->disconnect_code == libnetwrk::disconnect_code::authentication_failed) { + LIBNETWRK_VERBOSE(m_context.name, "{}: Auth timeout. Disconnecting client.", connection->get_id()); + } + else { + LIBNETWRK_VERBOSE(m_context.name, "{}: Client disconnected.", connection->get_id()); + } + }; + } + + service_t& operator=(const service_t&) = delete; + service_t& operator=(service_t&&) = default; + + public: + bool is_running() { + return m_context.is_running(); + } + + bool start(const std::string& host, const uint16_t port) { + if (m_context.status != service_status::stopped) + return false; + + m_context.status = service_status::starting; + + bool started = start_impl(host, port); + + if (started) { + if (m_context.cb_start) + m_context.cb_start(); + + m_context.status = service_status::started; + } + else { + m_context.status = service_status::stopped; + } + + return started; + } + + void stop() { + if (m_context.status != service_status::started) + return; + + m_context.status = service_status::stopping; + + this->teardown(); + + m_context.status = service_status::stopped; + + if (m_context.cb_stop) + m_context.cb_stop(); + } + + void send(std::shared_ptr client, message_t& message, libnetwrk::send_flags flags = libnetwrk::send_flags::none) { + m_comp_message.send(client, message, flags); + } + + void send_all(message_t& message, libnetwrk::send_flags flags = libnetwrk::send_flags::none, + comp_message_t::send_predicate_t predicate = nullptr) + { + m_comp_message.send_all(message, flags, predicate); + } + + bool process_message() { + return m_comp_message.process_message(); + } + + bool process_messages() { + return m_comp_message.process_messages(); + } + + bool process_messages_async() { + return m_comp_message.process_messages_async(); + } + + public: + const std::string& get_name() const { + return m_context.name; + } + + service_settings& get_settings() { + return m_context.settings; + } + + void set_message_callback(context_t::cb_message_t cb) { + if (!m_context.cb_message) + m_context.cb_message = cb; + } + + void set_pre_process_message_callback(context_t::cb_pre_process_message_t cb) { + if (!m_context.cb_pre_process_message) + m_context.cb_pre_process_message = cb; + } + + void set_post_process_message_callback(context_t::cb_post_process_message_t cb) { + if (!m_context.cb_post_process_message) + m_context.cb_post_process_message = cb; + } + + void set_connect_callback(context_t::cb_connect_t cb) { + if (!m_context.cb_connect) + m_context.cb_connect = cb; + } + + void set_disconnect_callback(context_t::cb_disconnect_t cb) { + if (!m_context.cb_disconnect) + m_context.cb_disconnect = cb; + } + + protected: + context_t m_context; + comp_connection_t m_comp_connection; + comp_message_t m_comp_message; + comp_system_message_t m_comp_system_message; + + protected: + virtual void teardown() { + m_context.cancel_cv.notify_all(); + m_context.cancel_cv.wait_for_end(); + + stop_all_connections(); + + m_context.stop_io_context(); + m_comp_message.stop_processing_messages(); + } + + virtual bool start_impl(const std::string& host, const uint16_t port) { + return false; + } + + private: + void stop_all_connections() { + std::lock_guard guard(m_comp_connection.connections_mutex); + + for (auto& client : m_comp_connection.connections) { + if (!client) continue; + + client->stop(); + + if (client->cancel_cv.has_active_operations()) + client->cancel_cv.wait_for_end(); + } + } + }; +} \ No newline at end of file diff --git a/include/libnetwrk/net/core/service/service_comp_connection.hpp b/include/libnetwrk/net/core/service/service_comp_connection.hpp new file mode 100644 index 0000000..39f440a --- /dev/null +++ b/include/libnetwrk/net/core/service/service_comp_connection.hpp @@ -0,0 +1,99 @@ +#pragma once + +#include "asio.hpp" +#include "asio/experimental/awaitable_operators.hpp" +#include "libnetwrk/net/core/service/service_context.hpp" +#include "libnetwrk/net/core/misc/coroutine_cv.hpp" + +#include +#include + +namespace libnetwrk { + template + class service_comp_connection { + public: + using context_t = tn_context; + using connection_t = typename tn_context::connection_internal_t; + + public: + std::list> connections; + std::mutex connections_mutex; + uint64_t id_counter = 0U; + + public: + service_comp_connection(context_t& context) + : m_context(context) {} + + public: + std::shared_ptr create_connection() { + return std::make_shared(m_context.io_context); + } + + void accept_connection(std::shared_ptr connection) { + std::lock_guard guard(connections_mutex); + connections.push_back(connection); + } + + void start_gc() { + using namespace asio::experimental::awaitable_operators; + + asio::co_spawn(m_context.io_context, co_gc() || m_context.cancel_cv.wait(), [this](auto, auto) { + LIBNETWRK_VERBOSE(m_context.name, "Stopped GC."); + }); + } + + private: + context_t& m_context; + + private: + asio::awaitable co_gc() { + asio::steady_timer timer(m_context.io_context, std::chrono::seconds(m_context.settings.gc_freq_sec)); + + size_t count_before = 0U; + size_t count_after = 0U; + + while (true) { + auto [ec] = co_await timer.async_wait(asio::as_tuple(asio::use_awaitable)); + + if (ec) + break; + + { + std::lock_guard guard(connections_mutex); + + count_before = connections.size(); + + connections.remove_if([this](auto& client) { + if (!client) + return true; + + if (!client->is_connected() && !client->cancel_cv.has_active_operations()) { + if (m_context.cb_disconnect) + m_context.cb_disconnect(client, client->disconnect_code); + + return true; + } + + if (!client->is_authenticated) { + uint64_t timestamp = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); + + if (timestamp > client->auth_request_deadline) { + client->disconnect_code = libnetwrk::disconnect_code::authentication_failed; + + if (m_context.cb_internal_disconnect) + m_context.cb_internal_disconnect(client); + } + } + + return false; + }); + + count_after = connections.size(); + } + + LIBNETWRK_VERBOSE(m_context.name, "GC tc: {} rc: {}", count_after, count_before - count_after); + timer.expires_after(std::chrono::seconds(m_context.settings.gc_freq_sec)); + } + } + }; +} \ No newline at end of file diff --git a/include/libnetwrk/net/core/service/service_comp_message.hpp b/include/libnetwrk/net/core/service/service_comp_message.hpp new file mode 100644 index 0000000..2d530cd --- /dev/null +++ b/include/libnetwrk/net/core/service/service_comp_message.hpp @@ -0,0 +1,57 @@ +#pragma once + +#include "libnetwrk/net/core/service/service_context.hpp" +#include "libnetwrk/net/core/service/service_comp_connection.hpp" +#include "libnetwrk/net/core/shared/shared_comp_message.hpp" +#include "libnetwrk/net/core/enums.hpp" + +#include + +namespace libnetwrk { + template + class service_comp_message : public shared_comp_message { + public: + using context_t = tn_context; + using comp_connection_t = service_comp_connection; + using connection_t = context_t::connection_t; + using message_t = context_t::message_t; + using outgoing_message_t = context_t::outgoing_message_t; + + using send_predicate_t = std::function)>; + + public: + service_comp_message(context_t& context, comp_connection_t& comp_connection) + : shared_comp_message(context), + m_comp_connection(comp_connection) + {} + + public: + void send(std::shared_ptr client, message_t& message, libnetwrk::send_flags flags) { + if (!client || client->is_connected()) return; + + client->send(message, flags); + } + + void send_all(message_t& message, libnetwrk::send_flags flags, send_predicate_t predicate) { + std::shared_ptr outgoing_message; + + if (LIBNETWRK_FLAG_SET(flags, libnetwrk::send_flags::keep_message)) { + outgoing_message = std::make_shared(message); + } + else { + outgoing_message = std::make_shared(std::move(message)); + } + + std::lock_guard guard(m_comp_connection.connections_mutex); + for (auto& client : m_comp_connection.connections) { + if (!client || !client->is_connected()) continue; + if (predicate && !predicate(client)) continue; + + client->direct_send(outgoing_message); + } + } + + private: + comp_connection_t& m_comp_connection; + }; +} \ No newline at end of file diff --git a/include/libnetwrk/net/core/service/service_comp_system_message.hpp b/include/libnetwrk/net/core/service/service_comp_system_message.hpp new file mode 100644 index 0000000..c1393b7 --- /dev/null +++ b/include/libnetwrk/net/core/service/service_comp_system_message.hpp @@ -0,0 +1,69 @@ +#pragma once + +#include "libnetwrk/net/core/service/service_context.hpp" +#include "libnetwrk/net/core/auth.hpp" + +namespace libnetwrk { + template + class service_comp_system_message { + public: + using context_t = tn_context; + using connection_internal_t = context_t::connection_internal_t; + using message_t = context_t::message_t; + using owned_message_t = context_t::owned_message_t; + + public: + service_comp_system_message(context_t& context) + : m_context(context) + { + m_context.cb_system_message = [this](auto command, auto message) { + ev_system_message(command, message); + }; + } + + void send_auth_message(std::shared_ptr connection) { + connection->auth_request = auth::generate_auth_question(); + + message_t request{}; + request.head.type = message_type::system; + request.head.command = static_cast(system_command::s2c_verify); + request << connection->auth_request; + + connection->send(request); + + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(m_context.settings.auth_deadline_sec); + connection->auth_request_deadline = std::chrono::system_clock::to_time_t(deadline); + } + + private: + context_t& m_context; + + private: + void ev_system_message(system_command command, owned_message_t* message) { + switch (command) { + case system_command::c2s_verify: return on_system_verify_message(message); + default: return; + } + } + + void on_system_verify_message(owned_message_t* message) { + LIBNETWRK_DEBUG(m_context.name, "Received verify response."); + + auto sender = std::static_pointer_cast(message->sender); + + auth::answer_t answer{}; + message->msg >> answer; + + if (!auth::is_correct(sender->auth_request, answer)) + return sender->stop(); + + sender->is_authenticated = true; + + message_t response{}; + response.head.type = message_type::system; + response.head.command = static_cast(system_command::s2c_verify_ok); + + sender->send(response); + } + }; +} \ No newline at end of file diff --git a/include/libnetwrk/net/core/service/service_connection.hpp b/include/libnetwrk/net/core/service/service_connection.hpp new file mode 100644 index 0000000..2baed82 --- /dev/null +++ b/include/libnetwrk/net/core/service/service_connection.hpp @@ -0,0 +1,41 @@ +#pragma once + +#include "libnetwrk/net/core/shared/shared_connection.hpp" + +namespace libnetwrk { + template + class service_connection : public shared_connection { + public: + using base_t = shared_connection; + using io_context_t = base_t::io_context_t; + using command_t = base_t::command_t; + using serialize_t = base_t::serialize_t; + using connection_t = service_connection; + using message_t = base_t::message_t; + using owned_message_t = owned_message; + using outgoing_message_t = base_t::outgoing_message_t; + + public: + service_connection() = delete; + service_connection(const connection_t&) = delete; + service_connection(connection_t&&) = default; + + service_connection(io_context_t& context) + : base_t(context) {} + + connection_t& operator=(const connection_t&) = delete; + connection_t& operator=(connection_t&&) = default; + + public: + virtual void stop() override { + base_t::stop(); + } + + protected: + virtual void notify() override {} + + virtual void direct_send(const std::shared_ptr outgoing_message) override { + base_t::direct_send(outgoing_message); + } + }; +} \ No newline at end of file diff --git a/include/libnetwrk/net/core/service/service_connection_internal.hpp b/include/libnetwrk/net/core/service/service_connection_internal.hpp new file mode 100644 index 0000000..503d7e8 --- /dev/null +++ b/include/libnetwrk/net/core/service/service_connection_internal.hpp @@ -0,0 +1,103 @@ +#pragma once + +#include "libnetwrk/net/core/service/service_connection.hpp" +#include "libnetwrk/net/core/auth.hpp" + +#include + +namespace libnetwrk { + template + class service_connection_internal : public service_connection { + public: + using base_t = service_connection; + using io_context_t = base_t::io_context_t; + using command_t = base_t::command_t; + using serialize_t = base_t::serialize_t; + using connection_t = service_connection_internal; + using endpoint_t = typename tn_socket::endpoint_t; + using message_t = base_t::message_t; + using owned_message_t = base_t::owned_message_t; + using outgoing_message_t = base_t::outgoing_message_t; + + public: + service_connection_internal() = delete; + service_connection_internal(const connection_t&) = delete; + service_connection_internal(connection_t&&) = default; + + service_connection_internal(io_context_t& context) + : base_t(context), write_cv(context), cancel_cv(context) + { + is_authenticated = false; + auth_request = {}; + auth_request_deadline = 0U; + disconnect_code = libnetwrk::disconnect_code::unspecified; + } + + connection_t& operator=(const connection_t&) = delete; + connection_t& operator=(connection_t&&) = default; + + public: + std::atomic_bool is_authenticated; + auth::question_t auth_request; + uint64_t auth_request_deadline; + libnetwrk::disconnect_code disconnect_code; + + coroutine_cv write_cv; + coroutine_cv cancel_cv; + + public: + bool wait_for_messages() { + std::unique_lock lock(this->m_outgoing_mutex); + return this->m_outgoing_system_messages.empty() && this->m_outgoing_messages.empty(); + } + + bool has_user_messages() { + return !this->m_outgoing_messages.empty(); + } + + bool has_system_messages() { + return !this->m_outgoing_system_messages.empty(); + } + + std::queue>& get_user_messages() { return this->m_outgoing_messages; } + std::queue>& get_system_messages() { return this->m_outgoing_system_messages; } + std::mutex& get_outgoing_mutex() { return this->m_outgoing_mutex; } + + public: + void stop() override final { + base_t::stop(); + write_cv.notify_all(); + cancel_cv.notify_all(); + } + + void direct_send(const std::shared_ptr outgoing_message) override final { + base_t::direct_send(outgoing_message); + } + + public: + void connect(const endpoint_t& endpoint) { + this->m_socket.connect(endpoint); + } + + tn_socket& get_socket() { + return this->m_socket; + } + + void set_id(uint64_t id) { + this->m_id = id; + } + + asio::awaitable co_read_message(message_t& recv_message, std::error_code& ec) { + return base_t::base_t::co_read_message(recv_message, ec); + } + + asio::awaitable co_write_message(std::shared_ptr message, std::error_code& ec) { + return base_t::base_t::co_write_message(message, ec); + } + + protected: + void notify() override final { + write_cv.notify_one(); + } + }; +} \ No newline at end of file diff --git a/include/libnetwrk/net/core/service/service_context.hpp b/include/libnetwrk/net/core/service/service_context.hpp new file mode 100644 index 0000000..9be28f8 --- /dev/null +++ b/include/libnetwrk/net/core/service/service_context.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "libnetwrk/net/core/libnetwrk_context.hpp" +#include "libnetwrk/net/core/service/service_connection.hpp" +#include "libnetwrk/net/core/misc/coroutine_cv.hpp" +#include "libnetwrk/net/core/system_commands.hpp" +#include "libnetwrk/net/core/enums.hpp" + +#include + +namespace libnetwrk { + struct service_settings { + uint8_t gc_freq_sec = 15U; + uint8_t auth_deadline_sec = 10U; + }; + + template + class service_context : public libnetwrk_context { + public: + using connection_t = typename tn_connection::base_t; + + using cb_start_t = std::function; + using cb_stop_t = std::function; + using cb_before_connect_t = std::function)>; + using cb_disconnect_t = std::function, libnetwrk::disconnect_code)>; + + public: + service_settings settings; + coroutine_cv cancel_cv; + + cb_start_t cb_start; + cb_stop_t cb_stop; + cb_before_connect_t cb_before_connect; + cb_disconnect_t cb_disconnect; + + public: + service_context() + : libnetwrk_context(), + cancel_cv(this->io_context) + {} + }; +} diff --git a/include/libnetwrk/net/core/shared/shared_comp_message.hpp b/include/libnetwrk/net/core/shared/shared_comp_message.hpp index 9130e45..62b458c 100644 --- a/include/libnetwrk/net/core/shared/shared_comp_message.hpp +++ b/include/libnetwrk/net/core/shared/shared_comp_message.hpp @@ -51,16 +51,14 @@ namespace libnetwrk { void start_connection_read_and_write(std::shared_ptr connection) { using namespace asio::experimental::awaitable_operators; - asio::co_spawn(*m_context.io_context, this->co_read(connection) || connection->cancel_cv.wait(), + asio::co_spawn(m_context.io_context, this->co_read(connection) || connection->cancel_cv.wait(), [this, connection](auto, auto) { - connection->active_operations--; LIBNETWRK_DEBUG(m_context.name, "{}: Stopped reading messages.", connection->get_id()); } ); - asio::co_spawn(*m_context.io_context, this->co_write(connection) || connection->cancel_cv.wait(), + asio::co_spawn(m_context.io_context, this->co_write(connection) || connection->cancel_cv.wait(), [this, connection](auto, auto) { - connection->active_operations--; LIBNETWRK_DEBUG(m_context.name, "{}: Stopped writing messages.", connection->get_id()); } ); @@ -95,7 +93,6 @@ namespace libnetwrk { asio::awaitable co_read(std::shared_ptr connection) { std::error_code ec = {}; - connection->active_operations++; LIBNETWRK_DEBUG(m_context.name, "Started reading messages."); while (true) { @@ -111,7 +108,7 @@ namespace libnetwrk { LIBNETWRK_ERROR(m_context.name, "Failed during read. | {}", ec.message()); } - connection->stop(); + //connection->stop(); if (m_context.cb_internal_disconnect) m_context.cb_internal_disconnect(connection); @@ -148,7 +145,6 @@ namespace libnetwrk { asio::awaitable co_write(std::shared_ptr connection) { std::error_code ec = {}; - connection->active_operations++; LIBNETWRK_DEBUG(m_context.name, "Started writing messages."); while (true) { @@ -208,7 +204,7 @@ namespace libnetwrk { LIBNETWRK_ERROR(m_context.name, "Failed during write. | {}", ec.message()); } - connection->stop(); + //connection->stop(); if (m_context.cb_internal_disconnect) m_context.cb_internal_disconnect(connection); diff --git a/include/libnetwrk/net/tcp/socket.hpp b/include/libnetwrk/net/tcp/socket.hpp index 7cb477c..0f44584 100644 --- a/include/libnetwrk/net/tcp/socket.hpp +++ b/include/libnetwrk/net/tcp/socket.hpp @@ -56,8 +56,13 @@ namespace libnetwrk::tcp { } public: + native_socket_t& native() { + return m_socket; + } + void close() { - m_socket.close(); + if (m_socket.is_open()) + m_socket.close(); } template diff --git a/include/libnetwrk/net/tcp/tcp_client.hpp b/include/libnetwrk/net/tcp/tcp_client.hpp index d048b69..1f2a87f 100644 --- a/include/libnetwrk/net/tcp/tcp_client.hpp +++ b/include/libnetwrk/net/tcp/tcp_client.hpp @@ -36,11 +36,8 @@ namespace libnetwrk::tcp { private: bool connect_impl(const std::string& host, const uint16_t port) override final { try { - // Create context - this->m_context.create_io_context(); - // Create resolver - tcp_resolver resolver(*this->m_context.io_context); + tcp_resolver resolver(this->m_context.io_context); // Resolve hostname asio::ip::tcp::endpoint ep; diff --git a/include/libnetwrk/net/tcp/tcp_service.hpp b/include/libnetwrk/net/tcp/tcp_service.hpp index 99db63c..07e5b45 100644 --- a/include/libnetwrk/net/tcp/tcp_service.hpp +++ b/include/libnetwrk/net/tcp/tcp_service.hpp @@ -3,156 +3,90 @@ #include "libnetwrk/net/default_service_desc.hpp" #include "libnetwrk/net/tcp/socket.hpp" #include "libnetwrk/net/tcp/tcp_resolver.hpp" -#include "libnetwrk/net/core/base_service.hpp" +#include "libnetwrk/net/core/service/service.hpp" #include "libnetwrk/net/core/serialization/bin_serialize.hpp" #include #include namespace libnetwrk::tcp { - template - requires is_libnetwrk_service_desc - class tcp_service : public libnetwrk::base_service { + template + requires is_libnetwrk_service_desc + class tcp_service : public libnetwrk::service { public: - // This service type - using service_t = tcp_service; - - // Base service type - using base_service_t = libnetwrk::base_service; - - // Connection type for this service - using connection_t = base_service_t::connection_t; - - // Message type - using message_t = base_service_t::message_t; - - // Owned message type for this service - using owned_message_t = base_service_t::owned_message_t; - - // Command type for this service - using command_t = typename Desc::command_t; - - private: - // Client acceptor type - using acceptor_t = asio::ip::tcp::acceptor; + using base_t = libnetwrk::service; + using connection_t = base_t::context_t::connection_t; + using connection_internal_t = base_t::context_t::connection_internal_t; + using command_t = typename tn_desc::command_t; + using message_t = base_t::message_t; + using owned_message_t = base_t::owned_message_t; public: tcp_service(const std::string& name = "TCP service") - : base_service_t(name) {}; + : base_t(name), m_acceptor(this->m_context.io_context) {}; virtual ~tcp_service() { - auto status = this->m_status.load(); - - if (status == service_status::stopped || status == service_status::stopping) - return; - - this->m_status = service_status::stopping; - teardown(); + this->m_context.status = service_status::stopping; + this->teardown(); } - public: - /* - Stop service. - */ - void stop() override final { - if (this->m_status != service_status::started) - return; - - this->m_status = service_status::stopping; - - teardown(); - base_service_t::teardown(); - - this->m_status = service_status::stopped; - - ev_service_stopped(); + uint16_t get_port() { + return m_acceptor.local_endpoint().port(); } protected: - std::unique_ptr m_acceptor; - - protected: - /* - Called when the service was successfully started. - */ - virtual void ev_service_started() override {}; - - /* - Called when service stopped. - */ - virtual void ev_service_stopped() override {}; - - /* - Called when processing messages. - */ - virtual void ev_message(owned_message_t& msg) override {}; - - /* - Called before client is fully accepted. - Allows performing checks on client before accepting (blacklist, whitelist). - */ - virtual bool ev_before_client_connected(std::shared_ptr client) override { return true; }; - - /* - Called when a client has connected. - */ - virtual void ev_client_connected(std::shared_ptr client) override {}; - - /* - Called when a client has disconnected. - */ - virtual void ev_client_disconnected(std::shared_ptr client) override {}; + using acceptor_t = asio::ip::tcp::acceptor; protected: - /* - Pre process message data before writing. - */ - virtual void pre_process_message(message_t::buffer_t& buffer) override {} - - /* - Post process message data after reading. - */ - virtual void post_process_message(message_t::buffer_t& buffer) override {} + acceptor_t m_acceptor; private: - // Native socket type for this service - using native_socket_t = libnetwrk::tcp::socket::native_socket_t; + void teardown() override final { + if (m_acceptor.is_open()) + m_acceptor.close(); - private: - std::future m_listening_future; + base_t::teardown(); + }; - private: - bool impl_start(const char* host, const unsigned short port) override final { - try { - // Create ASIO context - this->io_context = std::make_unique(1); + bool start_impl(const std::string& host, const uint16_t port) override final { + using namespace asio::experimental::awaitable_operators; + try { // Create resolver - tcp_resolver resolver(*this->io_context); + tcp_resolver resolver(this->m_context.io_context); // Resolve hostname asio::ip::tcp::endpoint ep; if (!resolver.get_endpoint(host, port, ep)) throw libnetwrk_exception("Failed to resolve hostname."); - // Create ASIO acceptor - m_acceptor = std::make_unique(*(this->io_context), ep); + // Open acceptor + m_acceptor.open(asio::ip::tcp::v4()); + m_acceptor.set_option(acceptor_t::reuse_address(true)); + m_acceptor.bind(ep); + m_acceptor.listen(); + + // Start listening + asio::co_spawn(this->m_context.io_context, co_listen() || this->m_context.cancel_cv.wait(), [this](auto, auto) { + LIBNETWRK_INFO(this->m_context.name, "Stopped listening."); + }); - m_listening_future = asio::co_spawn(*this->io_context, internal_listen(), asio::use_future); + // Start GC + this->m_comp_connection.start_gc(); - // Start ASIO context - this->start_context(); + // Start context + this->m_context.start_io_context(); } catch (const std::exception& e) { (void)e; - LIBNETWRK_ERROR(this->name, "Failed to start listening. | {}", e.what()); - stop(); + LIBNETWRK_ERROR(this->m_context.name, "Failed to start listening. | {}", e.what()); + this->teardown(); return false; } catch (...) { - LIBNETWRK_ERROR(this->name, "Failed to start listening. | Critical fail."); - stop(); + LIBNETWRK_ERROR(this->m_context.name, "Failed to start listening. | Critical fail."); + this->teardown(); return false; } @@ -160,63 +94,49 @@ namespace libnetwrk::tcp { } private: - void teardown() { - if (m_acceptor && m_acceptor->is_open()) - m_acceptor->close(); - - if (m_listening_future.valid()) - m_listening_future.wait(); - - m_acceptor.reset(); - }; - - asio::awaitable internal_listen() { + asio::awaitable co_listen() { auto current_executor = co_await asio::this_coro::executor; - LIBNETWRK_INFO(this->name, "Listening for connections on {}:{}.", - m_acceptor->local_endpoint().address().to_string(), - m_acceptor->local_endpoint().port()); + LIBNETWRK_INFO(this->m_context.name, "Listening for connections on {}:{}.", + m_acceptor.local_endpoint().address().to_string(), + m_acceptor.local_endpoint().port()); while (true) { - auto [ec, socket] = co_await m_acceptor->async_accept(asio::as_tuple(asio::use_awaitable)); + auto connection = std::make_shared(this->m_context.io_context); + + auto [ec] = co_await m_acceptor.async_accept(connection->get_socket().native(), + asio::as_tuple(asio::use_awaitable)); if (ec) { if (ec != asio::error::operation_aborted) { - LIBNETWRK_ERROR(this->name, "Failed to accept connection. | {}: {}", ec.value(), ec.message()); + LIBNETWRK_ERROR(this->m_context.name, "Failed to accept connection. | {}: {}", ec.value(), ec.message()); } break; } - asio::co_spawn(current_executor, internal_accept(std::move(socket)), asio::detached); + asio::co_spawn(current_executor, co_accept(connection), asio::detached); } - - LIBNETWRK_INFO(this->name, "Stopped listening."); } - asio::awaitable internal_accept(native_socket_t socket) { - LIBNETWRK_VERBOSE(this->name, "Attempted connection from {}:{}.", - socket.remote_endpoint().address().to_string(), - socket.remote_endpoint().port()); + asio::awaitable co_accept(std::shared_ptr connection) { + LIBNETWRK_VERBOSE(this->m_context.name, "Attempted connection from {}:{}.", + connection->get_ip(), connection->get_port()); - auto new_connection = std::make_shared(*this, std::move(socket)); - - if (ev_before_client_connected(new_connection)) { - { - std::lock_guard guard(this->m_connections_mutex); - this->m_connections.push_back(new_connection); - this->m_connections.back()->id() = ++this->m_ids; - this->m_connections.back()->start(); - } + if (!this->m_context.cb_before_connect || this->m_context.cb_before_connect(std::static_pointer_cast(connection))) { + this->m_comp_connection.accept_connection(connection); + connection->set_id(++this->m_comp_connection.id_counter); + this->m_comp_message.start_connection_read_and_write(connection); + this->m_comp_system_message.send_auth_message(connection); - ev_client_connected(new_connection); + if (this->m_context.cb_connect) + this->m_context.cb_connect(std::static_pointer_cast(connection)); - LIBNETWRK_INFO(this->name, "Connection success from {}:{}.", - new_connection->get_ip(), - new_connection->get_port()); + LIBNETWRK_INFO(this->m_context.name, "Connection success from {}:{}.", + connection->get_ip(), connection->get_port()); } else { - LIBNETWRK_WARNING(this->name, "Connection denied."); + LIBNETWRK_WARNING(this->m_context.name, "Connection denied."); } co_return; diff --git a/test/test_service_client.cpp b/test/test_service_client.cpp index 788e181..161dc5d 100644 --- a/test/test_service_client.cpp +++ b/test/test_service_client.cpp @@ -17,28 +17,42 @@ struct service_desc { class test_service : public tcp_service { public: - test_service() : tcp_service() {} + bool client_connected = false; + bool client_disconnected = false; + libnetwrk::disconnect_code dc_code = libnetwrk::disconnect_code::unspecified; + + test_service() : tcp_service() { + set_connect_callback([this](auto) { + client_connected = true; + }); + + set_disconnect_callback([this](auto, auto code) { + client_disconnected = true; + dc_code = code; + }); + } size_t connections() { - return m_connections.size(); + return m_comp_connection.connections.size(); } }; TEST(service_client, auth_timeout) { test_service service; - service.gc_freq_sec = 1; - service.start("127.0.0.1", 21205); + service.get_settings().gc_freq_sec = 1; + service.get_settings().auth_deadline_sec = 4; + service.start("127.0.0.1", 0); tcp_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); service.process_messages_async(); - std::this_thread::sleep_for(std::chrono::milliseconds(7500)); - - EXPECT_TRUE(service.connections() == 1); - - std::this_thread::sleep_for(std::chrono::milliseconds(7500)); + while (service.connections() != 0 || client.is_connected()) { + std::this_thread::sleep_for(std::chrono::milliseconds(20)); + } - EXPECT_TRUE(service.connections() == 0); + EXPECT_TRUE(service.client_connected); + EXPECT_TRUE(service.client_disconnected); + EXPECT_TRUE(service.dc_code == libnetwrk::disconnect_code::authentication_failed); } \ No newline at end of file diff --git a/test/test_tcp_command_type.cpp b/test/test_tcp_command_type.cpp index d898c98..05d68e9 100644 --- a/test/test_tcp_command_type.cpp +++ b/test/test_tcp_command_type.cpp @@ -20,18 +20,22 @@ class basic_service : public tcp_service> { public: using base_t = tcp_service>; - basic_service() : tcp_service>() {} + basic_service() : tcp_service>() { + this->set_message_callback([this](auto command, auto message) { + ev_message(command, message); + }); + } std::string ping = ""; - void ev_message(base_t::owned_message_t& message) override { - switch (message.msg.command()) { + void ev_message(base_t::command_t command, base_t::owned_message_t* msg) { + switch (command) { case base_t::command_t::c2s_ping: { - message.msg >> ping; + msg->msg >> ping; libnetwrk::message> response(base_t::command_t::s2c_pong); response << std::string("pOnG"); - message.sender->send(response); + msg->sender->send(response); break; } default: break; @@ -74,11 +78,11 @@ enum class commands_uint8 : unsigned char { TEST(commands, uint8) { ASSERT_NO_THROW(basic_service service); basic_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); ASSERT_NO_THROW(basic_client client); basic_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); basic_client::message_t message(basic_client::command_t::c2s_ping); message << std::string("PiNg"); @@ -105,11 +109,11 @@ enum class commands_uint16 : unsigned short { TEST(commands, uint16) { ASSERT_NO_THROW(basic_service service); basic_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); ASSERT_NO_THROW(basic_client client); basic_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); basic_client::message_t message(basic_client::command_t::c2s_ping); message << std::string("PiNg"); @@ -136,11 +140,11 @@ enum class commands_uint32 : unsigned int { TEST(commands, uint32) { ASSERT_NO_THROW(basic_service service); basic_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); ASSERT_NO_THROW(basic_client client); basic_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); basic_client::message_t message(basic_client::command_t::c2s_ping); message << std::string("PiNg"); @@ -167,11 +171,11 @@ enum class commands_uint64 : unsigned long long { TEST(commands, uint64) { ASSERT_NO_THROW(basic_service service); basic_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); ASSERT_NO_THROW(basic_client client); basic_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); basic_client::message_t message(basic_client::command_t::c2s_ping); message << std::string("PiNg"); @@ -198,11 +202,11 @@ enum class commands_int8 : char { TEST(commands, int8) { ASSERT_NO_THROW(basic_service service); basic_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); ASSERT_NO_THROW(basic_client client); basic_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); basic_client::message_t message(basic_client::command_t::c2s_ping); message << std::string("PiNg"); @@ -229,11 +233,11 @@ enum class commands_int16 : short { TEST(commands, int16) { ASSERT_NO_THROW(basic_service service); basic_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); ASSERT_NO_THROW(basic_client client); basic_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); basic_client::message_t message(basic_client::command_t::c2s_ping); message << std::string("PiNg"); @@ -260,11 +264,11 @@ enum class commands_int32 : int { TEST(commands, int32) { ASSERT_NO_THROW(basic_service service); basic_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); ASSERT_NO_THROW(basic_client client); basic_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); basic_client::message_t message(basic_client::command_t::c2s_ping); message << std::string("PiNg"); @@ -291,11 +295,11 @@ enum class commands_int64 : long long { TEST(commands, int64) { ASSERT_NO_THROW(basic_service service); basic_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); ASSERT_NO_THROW(basic_client client); basic_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); basic_client::message_t message(basic_client::command_t::c2s_ping); message << std::string("PiNg"); diff --git a/test/test_tcp_service.cpp b/test/test_tcp_service.cpp index d18d561..05ef1fa 100644 --- a/test/test_tcp_service.cpp +++ b/test/test_tcp_service.cpp @@ -16,26 +16,26 @@ TEST(tcp_service, create) { ASSERT_NO_THROW(libnetwrk::tcp::tcp_service service); libnetwrk::tcp::tcp_service service; - EXPECT_FALSE(service.running()); + EXPECT_FALSE(service.is_running()); } TEST(tcp_service, start_ip) { libnetwrk::tcp::tcp_service service; - EXPECT_TRUE(service.start("127.0.0.1", 21205)); - EXPECT_TRUE(service.running()); + EXPECT_TRUE(service.start("127.0.0.1", 0)); + EXPECT_TRUE(service.is_running()); } TEST(tcp_service, start_twice) { libnetwrk::tcp::tcp_service service; - EXPECT_TRUE(service.start("127.0.0.1", 21205)); - EXPECT_FALSE(service.start("127.0.0.1", 21205)); - EXPECT_TRUE(service.running()); + EXPECT_TRUE(service.start("127.0.0.1", 0)); + EXPECT_FALSE(service.start("127.0.0.1", 0)); + EXPECT_TRUE(service.is_running()); } TEST(tcp_service, stop) { libnetwrk::tcp::tcp_service service; - service.start("127.0.0.1", 21205); - EXPECT_TRUE(service.running()); + service.start("127.0.0.1", 0); + EXPECT_TRUE(service.is_running()); service.stop(); - EXPECT_FALSE(service.running()); + EXPECT_FALSE(service.is_running()); } diff --git a/test/test_tcp_service_client.cpp b/test/test_tcp_service_client.cpp index b26f92d..501e54b 100644 --- a/test/test_tcp_service_client.cpp +++ b/test/test_tcp_service_client.cpp @@ -30,29 +30,33 @@ struct service_desc { class test_service : public tcp_service { public: - test_service() : tcp_service() {} + test_service() : tcp_service() { + set_message_callback([this](auto command, auto message) { + ev_message(command, message); + }); + } bool client_said_hello = false; bool client_said_echo = false; bool client_said_broadcast = false; std::string ping = ""; - void ev_message(owned_message_t& msg) override { + void ev_message(command_t command, owned_message_t* msg) { message_t response; - switch (msg.msg.command()) { + switch (command) { case commands::c2s_hello: client_said_hello = true; break; case commands::c2s_echo: client_said_echo = true; response.set_command(commands::s2c_echo); - msg.sender->send(response); + msg->sender->send(response); break; case commands::c2s_ping: - msg.msg >> ping; + msg->msg >> ping; response.set_command(commands::s2c_pong); response << std::string("pOnG"); - msg.sender->send(response); + msg->sender->send(response); break; case commands::c2s_broadcast: client_said_broadcast = true; @@ -62,13 +66,13 @@ class test_service : public tcp_service { case commands::c2s_send_sync_success: response.set_command(commands::s2c_send_sync_success); response << std::string("success"); - msg.sender->send(response); + msg->sender->send(response); break; case commands::c2s_send_sync_fail: response.set_command(commands::s2c_send_sync_fail); response << std::string("fail"); std::this_thread::sleep_for(std::chrono::milliseconds(5500)); - msg.sender->send(response); + msg->sender->send(response); break; default: break; @@ -76,11 +80,11 @@ class test_service : public tcp_service { } bool is_correct_id(uint32_t index, uint64_t id) { - if (index > m_connections.size() - 1) return false; + if (index > m_comp_connection.connections.size() - 1) return false; - auto front = m_connections.begin(); + auto front = m_comp_connection.connections.begin(); std::advance(front, index); - return (*front)->id() == id; + return (*front)->get_id() == id; } }; @@ -116,18 +120,30 @@ class test_client : public tcp_client { class test_service_pp : public tcp_service { public: - test_service_pp() : tcp_service() {} + test_service_pp() : tcp_service() { + set_message_callback([this](auto command, auto message) { + ev_message(command, message); + }); + + set_pre_process_message_callback([this](auto buffer) { + pre_process_message(buffer); + }); + + set_post_process_message_callback([this](auto buffer) { + post_process_message(buffer); + }); + } std::string ping = ""; - void ev_message(owned_message_t& msg) override { + void ev_message(command_t command, owned_message_t* msg) { message_t response; - switch (msg.msg.command()) { + switch (command) { case commands::c2s_ping: - msg.msg >> ping; + msg->msg >> ping; response.set_command(commands::s2c_pong); response << std::string("pOnG"); - msg.sender->send(response); + msg->sender->send(response); break; default: break; @@ -135,18 +151,18 @@ class test_service_pp : public tcp_service { } protected: - void pre_process_message(message_t::buffer_t& buffer) override final { - for (uint8_t& byte : buffer.underlying()) { + void pre_process_message(message_t::buffer_t* buffer) { + for (uint8_t& byte : buffer->underlying()) { byte ^= 69; } - buffer.underlying().push_back(155); + buffer->underlying().push_back(155); } - void post_process_message(message_t::buffer_t& buffer) override final { - buffer.underlying().resize(buffer.size() - 1); + void post_process_message(message_t::buffer_t* buffer) { + buffer->underlying().resize(buffer->size() - 1); - for (uint8_t& byte : buffer.underlying()) { + for (uint8_t& byte : buffer->underlying()) { byte ^= 69; } } @@ -201,27 +217,27 @@ class test_client_pp : public tcp_client { TEST(tcp_service_client, connect) { { test_service service; - EXPECT_TRUE(service.start("127.0.0.1", 21205)); + EXPECT_TRUE(service.start("127.0.0.1", 0)); test_client client; - EXPECT_TRUE(client.connect("127.0.0.1", 21205)); + EXPECT_TRUE(client.connect("127.0.0.1", service.get_port())); } { test_service service; - EXPECT_TRUE(service.start("localhost", 21205)); + EXPECT_TRUE(service.start("localhost", 0)); test_client client; - EXPECT_TRUE(client.connect("localhost", 21205)); + EXPECT_TRUE(client.connect("localhost", service.get_port())); } } TEST(tcp_service_client, hello) { test_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); test_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); test_client::message_t msg(commands::c2s_hello); client.send(msg); @@ -236,10 +252,10 @@ TEST(tcp_service_client, hello) { TEST(tcp_service_client, echo) { test_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); test_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); test_client::message_t msg(commands::c2s_echo); client.send(msg); @@ -255,10 +271,10 @@ TEST(tcp_service_client, echo) { TEST(tcp_service_client, ping_pong) { test_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); test_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); test_client::message_t msg(commands::c2s_ping); msg << std::string("PiNg"); @@ -275,13 +291,13 @@ TEST(tcp_service_client, ping_pong) { TEST(tcp_service_client, broadcast) { test_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); test_client client1; - EXPECT_TRUE(client1.connect("127.0.0.1", 21205) == true); + EXPECT_TRUE(client1.connect("127.0.0.1", service.get_port()) == true); test_client client2; - EXPECT_TRUE(client2.connect("127.0.0.1", 21205) == true); + EXPECT_TRUE(client2.connect("127.0.0.1", service.get_port()) == true); std::this_thread::sleep_for(std::chrono::milliseconds(2500)); @@ -304,10 +320,10 @@ TEST(tcp_service_client, broadcast) { TEST(tcp_service_client, ping_pong_pre_post_process) { test_service_pp service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); test_client_pp client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); test_client_pp::message_t msg(commands::c2s_ping); msg << std::string("PiNg"); @@ -324,15 +340,15 @@ TEST(tcp_service_client, ping_pong_pre_post_process) { TEST(tcp_service_client, two_clients_out_of_order) { test_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); service.process_messages_async(); test_client client1; - EXPECT_TRUE(client1.connect("127.0.0.1", 21205) == true); + EXPECT_TRUE(client1.connect("127.0.0.1", service.get_port()) == true); client1.process_messages_async(); test_client client2; - EXPECT_TRUE(client2.connect("127.0.0.1", 21205) == true); + EXPECT_TRUE(client2.connect("127.0.0.1", service.get_port()) == true); client2.process_messages_async(); std::this_thread::sleep_for(std::chrono::milliseconds(2500)); @@ -359,11 +375,11 @@ TEST(tcp_service_client, two_clients_out_of_order) { TEST(tcp_service_client, send_keep_message) { test_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); service.process_messages_async(); test_client client; - EXPECT_TRUE(client.connect("127.0.0.1", 21205) == true); + EXPECT_TRUE(client.connect("127.0.0.1", service.get_port()) == true); client.process_messages_async(); std::this_thread::sleep_for(std::chrono::milliseconds(2500)); diff --git a/test/test_tcp_talk.cpp b/test/test_tcp_talk.cpp index ae8a5b3..e4563e5 100644 --- a/test/test_tcp_talk.cpp +++ b/test/test_tcp_talk.cpp @@ -25,21 +25,25 @@ struct service_desc { class test_service : public tcp_service { public: - test_service() : tcp_service() {} + test_service() : tcp_service() { + set_message_callback([this](auto command, auto message) { + ev_message(command, message); + }); + } - void ev_message(owned_message_t& msg) override { + void ev_message(command_t command, owned_message_t* msg) { message_t response; std::string received; - msg.msg >> received; + msg->msg >> received; - switch (msg.msg.command()) { + switch (command) { case commands::c2s_msg1: { EXPECT_TRUE(received == "request_1"); response.set_command(commands::s2c_msg1); response << "response_1"; - msg.sender->send(response); + msg->sender->send(response); break; } case commands::c2s_msg2: { @@ -47,7 +51,7 @@ class test_service : public tcp_service { response.set_command(commands::s2c_msg2); response << "response_2"; - msg.sender->send(response); + msg->sender->send(response); break; } case commands::c2s_msg3: { @@ -55,7 +59,7 @@ class test_service : public tcp_service { response.set_command(commands::s2c_msg3); response << "response_3"; - msg.sender->send(response); + msg->sender->send(response); break; } case commands::c2s_msg4: { @@ -63,7 +67,7 @@ class test_service : public tcp_service { response.set_command(commands::s2c_msg4); response << "response_4"; - msg.sender->send(response); + msg->sender->send(response); break; } @@ -122,12 +126,14 @@ class test_client : public tcp_client { } }; +#include + TEST(tcp_talk, talking) { test_service service; - service.start("127.0.0.1", 21205); + service.start("127.0.0.1", 0); test_client client; - client.connect("127.0.0.1", 21205); + client.connect("127.0.0.1", service.get_port()); test_client::message_t msg(test_client::command_t::c2s_msg1); msg << "request_1";