diff --git a/src/TcpRelayStatisticsInfo.cpp b/src/TcpRelayStatisticsInfo.cpp index e50124a..d66164c 100644 --- a/src/TcpRelayStatisticsInfo.cpp +++ b/src/TcpRelayStatisticsInfo.cpp @@ -92,8 +92,8 @@ TcpRelayStatisticsInfo::SessionInfo::SessionInfo(const std::shared_ptr &s) : SessionInfo(s.lock()) { -} +TcpRelayStatisticsInfo::SessionInfo::SessionInfo(const std::weak_ptr &s) + : SessionInfo(s.lock()) {} void TcpRelayStatisticsInfo::SessionInfo::updateTargetInfo(const std::shared_ptr &s) { if (const auto &p = s) { @@ -109,6 +109,7 @@ void TcpRelayStatisticsInfo::SessionInfo::updateTargetInfo(const std::shared_ptr void TcpRelayStatisticsInfo::addSession(size_t index, const std::shared_ptr &s) { BOOST_ASSERT(s); + std::lock_guard lgG{mtx}; if (auto ptr = s) { if (upstreamIndex.find(index) == upstreamIndex.end()) { upstreamIndex.try_emplace(index, std::make_shared()); @@ -136,6 +137,7 @@ void TcpRelayStatisticsInfo::addSession(size_t index, const std::shared_ptr &s) { BOOST_ASSERT(s); + std::lock_guard lgG{mtx}; if (auto ptr = s) { const std::string &addr = ptr->getClientEndpointAddrString(); if (clientIndex.find(addr) == clientIndex.end()) { @@ -164,6 +166,7 @@ void TcpRelayStatisticsInfo::addSessionClient(const std::shared_ptr &s) { BOOST_ASSERT(s); + std::lock_guard lgG{mtx}; if (auto ptr = s) { const std::string &addr = ptr->getListenEndpointAddrString(); if (listenIndex.find(addr) == listenIndex.end()) { @@ -192,6 +195,7 @@ void TcpRelayStatisticsInfo::addSessionListen(const std::shared_ptr &s) { BOOST_ASSERT(s); + std::lock_guard lgG{mtx}; if (auto ptr = s) { BOOST_ASSERT(ptr->authUser); size_t id = ptr->authUser->id; @@ -223,6 +227,7 @@ void TcpRelayStatisticsInfo::addSessionAuthUser(const std::shared_ptr s) { BOOST_ASSERT(s); + std::lock_guard lgG{mtx}; if (s) { { auto ui = upstreamIndex.find(s->getNowServer()->index); @@ -276,6 +281,7 @@ void TcpRelayStatisticsInfo::updateSessionInfo(std::shared_ptr } std::shared_ptr TcpRelayStatisticsInfo::getInfo(size_t index) { + std::lock_guard lg{mtx}; auto n = upstreamIndex.find(index); if (n != upstreamIndex.end()) { return n->second; @@ -285,6 +291,7 @@ std::shared_ptr TcpRelayStatisticsInfo::getInfo(si } std::shared_ptr TcpRelayStatisticsInfo::getInfoClient(const std::string &addr) { + std::lock_guard lg{mtx}; auto n = clientIndex.find(addr); if (n != clientIndex.end()) { return n->second; @@ -294,6 +301,7 @@ std::shared_ptr TcpRelayStatisticsInfo::getInfoCli } std::shared_ptr TcpRelayStatisticsInfo::getInfoListen(const std::string &addr) { + std::lock_guard lg{mtx}; auto n = listenIndex.find(addr); if (n != listenIndex.end()) { return n->second; @@ -303,6 +311,7 @@ std::shared_ptr TcpRelayStatisticsInfo::getInfoLis } std::shared_ptr TcpRelayStatisticsInfo::getInfoAuthUser(size_t id) { + std::lock_guard lg{mtx}; auto n = authUserIndex.find(id); BOOST_ASSERT(n != authUserIndex.end()); if (n != authUserIndex.end()) { @@ -313,6 +322,7 @@ std::shared_ptr TcpRelayStatisticsInfo::getInfoAut } void TcpRelayStatisticsInfo::removeExpiredSession(size_t index) { + std::lock_guard lg{mtx}; auto p = getInfo(index); if (p) { p->removeExpiredSession(); @@ -320,6 +330,7 @@ void TcpRelayStatisticsInfo::removeExpiredSession(size_t index) { } void TcpRelayStatisticsInfo::removeExpiredSessionClient(const std::string &addr) { + std::lock_guard lg{mtx}; auto p = getInfoClient(addr); if (p) { p->removeExpiredSession(); @@ -327,6 +338,7 @@ void TcpRelayStatisticsInfo::removeExpiredSessionClient(const std::string &addr) } void TcpRelayStatisticsInfo::removeExpiredSessionListen(const std::string &addr) { + std::lock_guard lg{mtx}; auto p = getInfoListen(addr); if (p) { p->removeExpiredSession(); @@ -334,6 +346,7 @@ void TcpRelayStatisticsInfo::removeExpiredSessionListen(const std::string &addr) } void TcpRelayStatisticsInfo::removeExpiredSessionAuthUser(size_t id) { + std::lock_guard lg{mtx}; auto p = getInfoAuthUser(id); if (p) { p->removeExpiredSession(); @@ -341,6 +354,7 @@ void TcpRelayStatisticsInfo::removeExpiredSessionAuthUser(size_t id) { } void TcpRelayStatisticsInfo::addByteUp(size_t index, size_t b) { + std::lock_guard lg{mtx}; auto p = getInfo(index); if (p) { p->byteUp += b; @@ -348,6 +362,7 @@ void TcpRelayStatisticsInfo::addByteUp(size_t index, size_t b) { } void TcpRelayStatisticsInfo::addByteUpClient(const std::string &addr, size_t b) { + std::lock_guard lg{mtx}; auto p = getInfoClient(addr); if (p) { p->byteUp += b; @@ -355,6 +370,7 @@ void TcpRelayStatisticsInfo::addByteUpClient(const std::string &addr, size_t b) } void TcpRelayStatisticsInfo::addByteUpListen(const std::string &addr, size_t b) { + std::lock_guard lg{mtx}; auto p = getInfoListen(addr); if (p) { p->byteUp += b; @@ -362,6 +378,7 @@ void TcpRelayStatisticsInfo::addByteUpListen(const std::string &addr, size_t b) } void TcpRelayStatisticsInfo::addByteUpAuthUser(size_t id, size_t b) { + std::lock_guard lg{mtx}; auto p = getInfoAuthUser(id); if (p) { p->byteUp += b; @@ -369,6 +386,7 @@ void TcpRelayStatisticsInfo::addByteUpAuthUser(size_t id, size_t b) { } void TcpRelayStatisticsInfo::addByteDown(size_t index, size_t b) { + std::lock_guard lg{mtx}; auto p = getInfo(index); if (p) { p->byteDown += b; @@ -376,6 +394,7 @@ void TcpRelayStatisticsInfo::addByteDown(size_t index, size_t b) { } void TcpRelayStatisticsInfo::addByteDownClient(const std::string &addr, size_t b) { + std::lock_guard lg{mtx}; auto p = getInfoClient(addr); if (p) { p->byteDown += b; @@ -383,6 +402,7 @@ void TcpRelayStatisticsInfo::addByteDownClient(const std::string &addr, size_t b } void TcpRelayStatisticsInfo::addByteDownListen(const std::string &addr, size_t b) { + std::lock_guard lg{mtx}; auto p = getInfoListen(addr); if (p) { p->byteDown += b; @@ -390,6 +410,7 @@ void TcpRelayStatisticsInfo::addByteDownListen(const std::string &addr, size_t b } void TcpRelayStatisticsInfo::addByteDownAuthUser(size_t id, size_t b) { + std::lock_guard lg{mtx}; auto p = getInfoAuthUser(id); if (p) { p->byteDown += b; @@ -397,6 +418,7 @@ void TcpRelayStatisticsInfo::addByteDownAuthUser(size_t id, size_t b) { } void TcpRelayStatisticsInfo::calcByteAll() { + std::lock_guard lg{mtx}; for (auto &a: upstreamIndex) { a.second->calcByte(); } @@ -412,6 +434,7 @@ void TcpRelayStatisticsInfo::calcByteAll() { } void TcpRelayStatisticsInfo::removeExpiredSessionAll() { + std::lock_guard lg{mtx}; for (auto &a: upstreamIndex) { a.second->removeExpiredSession(); } @@ -427,6 +450,7 @@ void TcpRelayStatisticsInfo::removeExpiredSessionAll() { } void TcpRelayStatisticsInfo::closeAllSession(size_t index) { + std::lock_guard lg{mtx}; auto p = getInfo(index); if (p) { p->closeAllSession(); @@ -434,6 +458,7 @@ void TcpRelayStatisticsInfo::closeAllSession(size_t index) { } void TcpRelayStatisticsInfo::closeAllSessionClient(const std::string &addr) { + std::lock_guard lg{mtx}; auto p = getInfoClient(addr); if (p) { p->closeAllSession(); @@ -441,6 +466,7 @@ void TcpRelayStatisticsInfo::closeAllSessionClient(const std::string &addr) { } void TcpRelayStatisticsInfo::closeAllSessionListen(const std::string &addr) { + std::lock_guard lg{mtx}; auto p = getInfoListen(addr); if (p) { p->closeAllSession(); @@ -448,6 +474,7 @@ void TcpRelayStatisticsInfo::closeAllSessionListen(const std::string &addr) { } void TcpRelayStatisticsInfo::closeAllSessionAuthUser(size_t id) { + std::lock_guard lg{mtx}; auto p = getInfoAuthUser(id); if (p) { p->closeAllSession(); @@ -455,6 +482,7 @@ void TcpRelayStatisticsInfo::closeAllSessionAuthUser(size_t id) { } void TcpRelayStatisticsInfo::connectCountAdd(size_t index) { + std::lock_guard lg{mtx}; auto p = getInfo(index); if (p) { p->connectCountAdd(); @@ -462,6 +490,7 @@ void TcpRelayStatisticsInfo::connectCountAdd(size_t index) { } void TcpRelayStatisticsInfo::connectCountAddClient(const std::string &addr) { + std::lock_guard lg{mtx}; auto p = getInfoClient(addr); if (p) { p->connectCountAdd(); @@ -469,6 +498,7 @@ void TcpRelayStatisticsInfo::connectCountAddClient(const std::string &addr) { } void TcpRelayStatisticsInfo::connectCountAddListen(const std::string &addr) { + std::lock_guard lg{mtx}; auto p = getInfoListen(addr); if (p) { p->connectCountAdd(); @@ -476,6 +506,7 @@ void TcpRelayStatisticsInfo::connectCountAddListen(const std::string &addr) { } void TcpRelayStatisticsInfo::connectCountAddAuthUser(size_t id) { + std::lock_guard lg{mtx}; auto p = getInfoAuthUser(id); if (p) { p->connectCountAdd(); @@ -483,6 +514,7 @@ void TcpRelayStatisticsInfo::connectCountAddAuthUser(size_t id) { } void TcpRelayStatisticsInfo::connectCountSub(size_t index) { + std::lock_guard lg{mtx}; auto p = getInfo(index); if (p) { p->connectCountSub(); @@ -490,6 +522,7 @@ void TcpRelayStatisticsInfo::connectCountSub(size_t index) { } void TcpRelayStatisticsInfo::connectCountSubClient(const std::string &addr) { + std::lock_guard lg{mtx}; auto p = getInfoClient(addr); if (p) { p->connectCountSub(); @@ -497,6 +530,7 @@ void TcpRelayStatisticsInfo::connectCountSubClient(const std::string &addr) { } void TcpRelayStatisticsInfo::connectCountSubListen(const std::string &addr) { + std::lock_guard lg{mtx}; auto p = getInfoListen(addr); if (p) { p->connectCountSub(); @@ -504,24 +538,29 @@ void TcpRelayStatisticsInfo::connectCountSubListen(const std::string &addr) { } void TcpRelayStatisticsInfo::connectCountSubAuthUser(size_t id) { + std::lock_guard lg{mtx}; auto p = getInfoAuthUser(id); if (p) { p->connectCountSub(); } } -std::map> &TcpRelayStatisticsInfo::getUpstreamIndex() { - return upstreamIndex; +std::map> TcpRelayStatisticsInfo::getUpstreamIndex() { + std::lock_guard lg{mtx}; + return decltype(upstreamIndex){upstreamIndex.begin(), upstreamIndex.end()}; } -std::map> &TcpRelayStatisticsInfo::getClientIndex() { - return clientIndex; +std::map> TcpRelayStatisticsInfo::getClientIndex() { + std::lock_guard lg{mtx}; + return decltype(clientIndex){clientIndex.begin(), clientIndex.end()}; } -std::map> &TcpRelayStatisticsInfo::getListenIndex() { - return listenIndex; +std::map> TcpRelayStatisticsInfo::getListenIndex() { + std::lock_guard lg{mtx}; + return decltype(listenIndex){listenIndex.begin(), listenIndex.end()}; } -std::map> &TcpRelayStatisticsInfo::getAuthUserIndex() { - return authUserIndex; +std::map> TcpRelayStatisticsInfo::getAuthUserIndex() { + std::lock_guard lg{mtx}; + return decltype(authUserIndex){authUserIndex.begin(), authUserIndex.end()}; } diff --git a/src/TcpRelayStatisticsInfo.h b/src/TcpRelayStatisticsInfo.h index e981b4b..2d6628d 100644 --- a/src/TcpRelayStatisticsInfo.h +++ b/src/TcpRelayStatisticsInfo.h @@ -97,6 +97,9 @@ class TcpRelayStatisticsInfo : public std::enable_shared_from_this &s); }; @@ -156,9 +159,11 @@ class TcpRelayStatisticsInfo : public std::enable_shared_from_this> upstreamIndex; // clientEndpointAddrString "ip" @@ -184,15 +195,15 @@ class TcpRelayStatisticsInfo : public std::enable_shared_from_thisindex - std::map> &getUpstreamIndex(); + std::map> getUpstreamIndex(); // ClientEndpointAddrString : (127.0.0.1) - std::map> &getClientIndex(); + std::map> getClientIndex(); // ListenEndpointAddrString : (127.0.0.1:661133) - std::map> &getListenIndex(); + std::map> getListenIndex(); - std::map> &getAuthUserIndex(); + std::map> getAuthUserIndex(); public: void addSession(size_t index, const std::shared_ptr &s); diff --git a/src/UpstreamPool.cpp b/src/UpstreamPool.cpp index 27bb82b..109ca7f 100644 --- a/src/UpstreamPool.cpp +++ b/src/UpstreamPool.cpp @@ -94,6 +94,7 @@ void UpstreamPool::setConfig(std::shared_ptr configLoader) { } void UpstreamPool::forceSetLastUseUpstreamIndex(size_t i) { + std::lock_guard lg{mtx}; if (i >= 0 && i < _pool.size()) { lastUseUpstreamIndex = i; } @@ -174,6 +175,7 @@ auto UpstreamPool::getServerByHint( size_t &_lastUseUpstreamIndex, const size_t &relayId, bool dontFallbackToGlobal) -> UpstreamServerRef { + std::lock_guard lg{mtx}; RuleEnum __upstreamSelectRule = _upstreamSelectRule; @@ -231,6 +233,7 @@ auto UpstreamPool::getServerByHint( } auto UpstreamPool::getServerGlobal(const size_t &relayId) -> UpstreamServerRef { + std::lock_guard lg{mtx}; const auto &_upstreamSelectRule = _configLoader->config.upstreamSelectRule; auto &_lastUseUpstreamIndex = lastUseUpstreamIndex; return getServerByHint(_upstreamSelectRule, _lastUseUpstreamIndex, relayId, true); @@ -274,6 +277,7 @@ void UpstreamPool::startAdditionTimer() { } void UpstreamPool::startCheckTimer() { + std::lock_guard lg{mtx}; if (tcpCheckerTimer && connectCheckerTimer) { return; } @@ -320,6 +324,7 @@ std::string UpstreamPool::print() { } void UpstreamPool::stop() { + std::lock_guard lg{mtx}; endAdditionTimer(); endCheckTimer(); if (auto ptr = forceCheckerTimer.lock()) { @@ -346,6 +351,7 @@ void UpstreamPool::do_tcpCheckerTimer_impl() { auto p = std::to_string(a->port); auto t = tcpTest->createTest(a->host, p, maxDelayTime); t->run([t, a](std::chrono::milliseconds ping) { + std::lock_guard lg{a->mtx}; // on ok if (a->isOffline) { a->lastConnectFailed = false; @@ -357,6 +363,7 @@ void UpstreamPool::do_tcpCheckerTimer_impl() { }, [t, a](std::string reason) { boost::ignore_unused(reason); + std::lock_guard lg{a->mtx}; // ok error a->isOffline = true; a->lastOnlinePing = std::chrono::milliseconds{-1}; @@ -373,6 +380,7 @@ void UpstreamPool::do_tcpCheckerOne_impl(UpstreamServerRef a) { auto p = std::to_string(a->port); auto t = tcpTest->createTest(a->host, p); t->run([t, a](std::chrono::milliseconds ping) { + std::lock_guard lg{a->mtx}; // on ok if (a->isOffline) { a->lastConnectFailed = false; @@ -384,6 +392,7 @@ void UpstreamPool::do_tcpCheckerOne_impl(UpstreamServerRef a) { }, [t, a](std::string reason) { boost::ignore_unused(reason); + std::lock_guard lg{a->mtx}; // ok error a->isOffline = true; a->lastOnlinePing = std::chrono::milliseconds{-1}; @@ -487,6 +496,7 @@ void UpstreamPool::do_connectCheckerTimer_impl() { maxDelayTime ); t->run([t, a](std::chrono::milliseconds ping, ConnectTestHttpsSession::SuccessfulInfo info) { + std::lock_guard lg{a->mtx}; // on ok // BOOST_LOG_S5B(trace) << "SuccessfulInfo:" << info; a->lastConnectTime = UpstreamTimePointNow(); @@ -500,6 +510,7 @@ void UpstreamPool::do_connectCheckerTimer_impl() { }, [t, a](std::string reason) { boost::ignore_unused(reason); + std::lock_guard lg{a->mtx}; // ok error a->lastConnectFailed = true; a->lastConnectPing = std::chrono::milliseconds{-1}; @@ -524,6 +535,7 @@ void UpstreamPool::do_connectCheckerOne_impl(UpstreamServerRef a) { R"(\)" ); t->run([t, a](std::chrono::milliseconds ping, ConnectTestHttpsSession::SuccessfulInfo info) { + std::lock_guard lg{a->mtx}; // on ok // BOOST_LOG_S5B(trace) << "SuccessfulInfo:" << info; a->lastConnectTime = UpstreamTimePointNow(); @@ -537,6 +549,7 @@ void UpstreamPool::do_connectCheckerOne_impl(UpstreamServerRef a) { }, [t, a](std::string reason) { boost::ignore_unused(reason); + std::lock_guard lg{a->mtx}; // ok error a->lastConnectFailed = true; a->lastConnectPing = std::chrono::milliseconds{-1}; @@ -565,6 +578,7 @@ void UpstreamPool::do_connectCheckerTimer() { } void UpstreamPool::forceCheckNow() { + std::lock_guard lg{mtx}; if (_configLoader->config.disableConnectTest) { return; } @@ -578,6 +592,7 @@ void UpstreamPool::forceCheckNow() { } void UpstreamPool::forceCheckOne(size_t index) { + std::lock_guard lg{mtx}; if (_configLoader->config.disableConnectTest) { return; } @@ -609,9 +624,11 @@ void UpstreamPool::do_forceCheckNow(std::shared_ptr _forceChec } void UpstreamPool::updateLastConnectComeTime() { + std::lock_guard lg{mtx}; this->lastConnectComeTime = UpstreamTimePointNow(); } UpstreamTimePoint UpstreamPool::getLastConnectComeTime() { + std::lock_guard lg{mtx}; return this->lastConnectComeTime; } diff --git a/src/UpstreamPool.h b/src/UpstreamPool.h index 060335c..c97956c 100644 --- a/src/UpstreamPool.h +++ b/src/UpstreamPool.h @@ -50,6 +50,8 @@ UpstreamTimePoint UpstreamTimePointNow(); std::string printUpstreamTimePoint(UpstreamTimePoint p); struct UpstreamServer : public std::enable_shared_from_this { + std::recursive_mutex mtx; + std::string host; uint16_t port; std::string name; @@ -105,6 +107,7 @@ struct UpstreamServer : public std::enable_shared_from_this { using UpstreamServerRef = std::shared_ptr; class UpstreamPool : public std::enable_shared_from_this { + std::recursive_mutex mtx; boost::asio::any_io_executor ex; std::deque _pool; @@ -166,13 +169,14 @@ class UpstreamPool : public std::enable_shared_from_this { std::weak_ptr forceCheckerTimer; std::shared_ptr additionTimer; -public: +protected: void endCheckTimer(); - void startCheckTimer(); - std::string print(); +public: + void startCheckTimer(); + void stop(); void forceCheckNow(); diff --git a/src/main.cpp b/src/main.cpp index b8316e1..bd06539 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -93,21 +93,20 @@ int main(int argc, const char *argv[]) { try { boost::asio::io_context ioc; - boost::asio::any_io_executor ex = boost::asio::make_strand(ioc); auto configLoader = std::make_shared(); configLoader->load(config_file); configLoader->print(); - auto tcpTest = std::make_shared(ex); - auto connectTestHttps = std::make_shared(ex); + auto tcpTest = std::make_shared(boost::asio::make_strand(ioc)); + auto connectTestHttps = std::make_shared(boost::asio::make_strand(ioc)); auto authClientManager = std::make_shared(configLoader->shared_from_this()); - auto upstreamPool = std::make_shared(ex, tcpTest, connectTestHttps); + auto upstreamPool = std::make_shared(boost::asio::make_strand(ioc), tcpTest, connectTestHttps); upstreamPool->setConfig(configLoader); - auto tcpRelay = std::make_shared(ex, configLoader, upstreamPool, authClientManager); + auto tcpRelay = std::make_shared(boost::asio::make_strand(ioc), configLoader, upstreamPool, authClientManager); auto stateMonitor = std::make_shared( boost::asio::make_strand(ioc), configLoader, upstreamPool, tcpRelay);