diff --git a/.gitignore b/.gitignore index baed48d16..61459d94c 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,9 @@ git_version.* versioninfo.txt rtprcv-* _bwelogs* + +# Executables +smb +UnitTest2 +UnitTest +LoadTest diff --git a/CMakeLists.txt b/CMakeLists.txt index c8d3cd7e1..e3f3a672b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -577,6 +577,7 @@ set(TEST_FILES test/api/ParserTest.cpp test/memory/MapTest.cpp test/memory/PoolAllocatorTest.cpp + test/memory/PoolBufferTest.cpp test/memory/RingAllocatorTest.cpp test/utils/StringTokenizerTest.cpp test/utils/TrackerTest.cpp @@ -651,6 +652,7 @@ set(TEST_FILES test/memory/StackMapTest.cpp test/bridge/ActiveMediaListTestLevels.h test/bridge/MixerTest.cpp + test/bridge/DataChannelMessageSizeTest.cpp test/bridge/VideoNackReceiveJobTest.cpp test/utils/LogSpamTest.cpp test/utils/FunctionTest.cpp diff --git a/api/DataChannelMessage.h b/api/DataChannelMessage.h index 965eb21aa..af7475520 100644 --- a/api/DataChannelMessage.h +++ b/api/DataChannelMessage.h @@ -1,5 +1,6 @@ #pragma once +#include "memory/PoolBuffer.h" #include "utils/StringBuilder.h" #if ENABLE_LEGACY_API @@ -12,7 +13,8 @@ namespace api namespace DataChannelMessage { -inline void makeEndpointMessage(utils::StringBuilder<2048>& outMessage, +template +inline void makeEndpointMessage(utils::StringBuilder& outMessage, const std::string& toEndpointId, const std::string& fromEndpointId, const char* message) @@ -33,6 +35,69 @@ inline void makeEndpointMessage(utils::StringBuilder<2048>& outMessage, #endif } +inline memory::UniquePoolBuffer makeUniqueEndpointMessageBuffer( + const std::string& toEndpointId, + const std::string& fromEndpointId, + const memory::UniquePoolBuffer& payload) +{ +#if ENABLE_LEGACY_API + return legacyapi::DataChannelMessage::makeUniqueEndpointMessageBuffer(toEndpointId, fromEndpointId, payload); +#else + constexpr const char* TO_STRING = "{\"type\":\"EndpointMessage\",\"to\":\""; + constexpr const char* FROM_STRING = "\",\"from\":\""; + constexpr const char* MSG_STRING = "\",\"payload\":"; + constexpr const char* TAIL_STRING = "}"; + + constexpr std::size_t overhead_len = std::char_traits::length(TO_STRING) + + std::char_traits::length(FROM_STRING) + + std::char_traits::length(MSG_STRING) + + std::char_traits::length(TAIL_STRING); + + const std::size_t extraLen = toEndpointId.length() + fromEndpointId.length() + payload->getLength(); + auto buffer = memory::makeUniquePoolBuffer(payload->getAllocator(), overhead_len + extraLen); + if (!buffer) + { + return buffer; + } + + auto written = buffer->write(TO_STRING, std::char_traits::length(TO_STRING), 0); + written += buffer->write(toEndpointId.c_str(), toEndpointId.length(), written); + written += buffer->write(FROM_STRING, std::char_traits::length(FROM_STRING), written); + written += buffer->write(fromEndpointId.c_str(), fromEndpointId.length(), written); + written += buffer->write(MSG_STRING, std::char_traits::length(MSG_STRING), written); + written += buffer->write(*payload.get(), written); + written += buffer->write(TAIL_STRING, std::char_traits::length(TAIL_STRING), written); + + assert(written == buffer->getLength()); + return buffer; +#endif +} + +template +inline void makeLoggableStringFromBuffer(memory::Array& outArray, memory::UniquePoolBuffer& payload) +{ + if (!payload) + { + return; + } + outArray.clear(); + bool ellipsisNeeded = payload->getLength() > T - 1; + + const size_t maxCStrLength = std::min(payload->getLength(), T - 1); + outArray.resize(maxCStrLength + 1); + const auto read = payload->copyTo(const_cast(reinterpret_cast(outArray.data())), 0, maxCStrLength); + assert(read == maxCStrLength); + outArray[maxCStrLength] = '\0'; + + // Indicate that message was incompletely logged + if (ellipsisNeeded && T >= 4) + { + outArray[T - 2] = '.'; + outArray[T - 3] = '.'; + outArray[T - 4] = '.'; + } +} + inline void makeDominantSpeaker(utils::StringBuilder<256>& outMessage, const char* endpointId) { #if ENABLE_LEGACY_API diff --git a/api/Generator.cpp b/api/Generator.cpp index a7eeb2c2c..c4dff2822 100644 --- a/api/Generator.cpp +++ b/api/Generator.cpp @@ -210,6 +210,7 @@ nlohmann::json generateAllocateEndpointResponse(const EndpointDescription& chann const auto& data = channelsDescription.data.get(); nlohmann::json dataJson; dataJson["port"] = data.port; + dataJson["max-message-size"] = data.maxMessageSize; responseJson["data"] = dataJson; } diff --git a/api/RtcDescriptors.h b/api/RtcDescriptors.h index 746bbb812..03a09dc72 100644 --- a/api/RtcDescriptors.h +++ b/api/RtcDescriptors.h @@ -140,5 +140,6 @@ struct Video struct Data { uint32_t port; + uint32_t maxMessageSize; }; } // namespace api diff --git a/bridge/Mixer.cpp b/bridge/Mixer.cpp index c62af31cc..84cfa57b3 100644 --- a/bridge/Mixer.cpp +++ b/bridge/Mixer.cpp @@ -1835,14 +1835,17 @@ void Mixer::sendEndpointMessage(const std::string& toEndpointId, const utils::SimpleJson& message) { assert(fromEndpointIdHash); - if (message.size() >= memory::AudioPacket::maxLength()) + if (message.size() >= _config.sctp.maxMessageSize) { + logger::warn("Endpoint message too big, len %zu", + "MixerManager", + message.size() + ); return; } - auto& audioAllocator = _engineMixer->getAudioAllocator(); - auto packet = memory::makeUniquePacket(audioAllocator, message.jsonBegin(), message.size()); - reinterpret_cast(packet->get())[message.size()] = 0; // null terminated in packet + auto& packetAllocator = _engineMixer->getMainAllocator(); + auto buffer = memory::makeUniquePoolBuffer(packetAllocator, message.jsonBegin(), message.size()); std::lock_guard locker(_configurationLock); @@ -1857,7 +1860,7 @@ void Mixer::sendEndpointMessage(const std::string& toEndpointId, toEndpointIdHash = dataStreamItr->second->endpointIdHash; } - _engineMixer->asyncSendEndpointMessage(toEndpointIdHash, fromEndpointIdHash, packet); + _engineMixer->asyncSendEndpointMessage(toEndpointIdHash, fromEndpointIdHash, buffer); } RecordingStream* Mixer::findRecordingStream(const std::string& recordingId) diff --git a/bridge/MixerManager.cpp b/bridge/MixerManager.cpp index 83300bd0b..9fa6e719b 100644 --- a/bridge/MixerManager.cpp +++ b/bridge/MixerManager.cpp @@ -5,7 +5,7 @@ #include "bridge/Mixer.h" #include "bridge/MixerJobs.h" #include "bridge/VideoStream.h" -#include "bridge/engine/Engine.h" + #include "bridge/engine/EngineAudioStream.h" #include "bridge/engine/EngineBarbell.h" #include "bridge/engine/EngineDataStream.h" @@ -492,24 +492,34 @@ void MixerManager::freeVideoPacketCache(EngineMixer& mixer, uint32_t ssrc, size_ mixerItr->second->freeVideoPacketCache(ssrc, endpointIdHash); } -void MixerManager::sctpReceived(EngineMixer& mixer, memory::UniquePacket msgPacket, size_t endpointIdHash) +void MixerManager::sctpReceived(EngineMixer& mixer, memory::UniquePoolBuffer message, size_t endpointIdHash) { - auto& sctpHeader = webrtc::streamMessageHeader(*msgPacket); + // HEADER: SctpStreamMessageHeader prepended to payload + // Need to get full message instead of first chunk only to form JSON from it. + constexpr size_t MAX_BUFFER_SIZE = 8192; + if (message->size() > MAX_BUFFER_SIZE) { + logger::warn("Received large SCTP message(size = %zu, max allowed = %zu). Dropping.", "MixerManager", message->getLength(), MAX_BUFFER_SIZE); + return; + } + + char continousBuffer[message->size()]; + message->copyTo(continousBuffer, 0, message->size()); + + auto& sctpHeader = *reinterpret_cast(continousBuffer); if (sctpHeader.payloadProtocol == webrtc::DataChannelPpid::WEBRTC_ESTABLISH) { // create command with this packet to send the binary data -> engine -> WebRtcDataStream belonging to this // transport - mixer.asyncHandleSctpControl(endpointIdHash, msgPacket); + mixer.asyncHandleSctpControl(endpointIdHash, message); return; // do not free packet as we passed it on } else if (sctpHeader.payloadProtocol == webrtc::DataChannelPpid::WEBRTC_STRING) { - std::string body(reinterpret_cast(sctpHeader.data()), msgPacket->getLength() - sizeof(sctpHeader)); + std::string body(sctpHeader.getMessage(), message->getLength() - sizeof(sctpHeader)); try { - auto json = utils::SimpleJson::create(reinterpret_cast(sctpHeader.data()), - msgPacket->getLength() - sizeof(sctpHeader)); + auto json = utils::SimpleJson::create(sctpHeader.getMessage(), message->getLength() - sizeof(sctpHeader)); if (api::DataChannelMessageParser::isPinnedEndpointsChanged(json)) { @@ -573,7 +583,7 @@ void MixerManager::sctpReceived(EngineMixer& mixer, memory::UniquePacket msgPack logger::debug("received unexpected DataChannel payload protocol, %u, len %zu", "MixerManager", sctpHeader.payloadProtocol, - msgPacket->getLength()); + message->getLength()); } } diff --git a/bridge/MixerManager.h b/bridge/MixerManager.h index 7fb454206..b687ab1ea 100644 --- a/bridge/MixerManager.h +++ b/bridge/MixerManager.h @@ -5,6 +5,7 @@ #include "bridge/Stats.h" #include "bridge/engine/EngineMixer.h" #include "bridge/engine/EngineStats.h" +#include "bridge/engine/Engine.h" #include "concurrency/MpmcQueue.h" #include "memory/PacketPoolAllocator.h" #include "utils/Pacer.h" @@ -14,11 +15,6 @@ #include #include -namespace bridge -{ -class Engine; -} - namespace utils { class IdGenerator; @@ -122,7 +118,9 @@ class MixerManager : public MixerManagerAsync void allocateVideoPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash) override; void allocateRecordingRtpPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash) override; void videoStreamRemoved(EngineMixer& engineMixer, const EngineVideoStream& videoStream) override; - void sctpReceived(EngineMixer& mixer, memory::UniquePacket msgPacket, size_t endpointIdHash) override; + void sctpReceived(EngineMixer& mixer, + memory::UniquePoolBuffer message, + size_t endpointIdHash) override; void dataStreamRemoved(EngineMixer& mixer, const EngineDataStream& dataStream) override; void freeRecordingRtpPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash) override; void barbellRemoved(EngineMixer& mixer, const EngineBarbell& barbell) override; diff --git a/bridge/MixerManagerAsync.cpp b/bridge/MixerManagerAsync.cpp index 3a74894c6..b32858de8 100644 --- a/bridge/MixerManagerAsync.cpp +++ b/bridge/MixerManagerAsync.cpp @@ -52,12 +52,14 @@ bool MixerManagerAsync::asyncVideoStreamRemoved(EngineMixer& engineMixer, const utils::bind(&MixerManagerAsync::videoStreamRemoved, this, std::ref(engineMixer), std::cref(videoStream))); } -bool MixerManagerAsync::asyncSctpReceived(EngineMixer& mixer, memory::UniquePacket& msgPacket, size_t endpointIdHash) +bool MixerManagerAsync::asyncSctpReceived(EngineMixer& mixer, + memory::UniquePoolBuffer& message, + size_t endpointIdHash) { return post(utils::bind(&MixerManagerAsync::sctpReceived, this, std::ref(mixer), - utils::moveParam(msgPacket), + utils::moveParam(message), endpointIdHash)); } diff --git a/bridge/MixerManagerAsync.h b/bridge/MixerManagerAsync.h index 858b154c1..2cd30bd2e 100644 --- a/bridge/MixerManagerAsync.h +++ b/bridge/MixerManagerAsync.h @@ -1,6 +1,7 @@ #pragma once #include "bridge/engine/EndpointId.h" #include "memory/PacketPoolAllocator.h" +#include "memory/PoolBuffer.h" #include "utils/Function.h" #include @@ -44,7 +45,9 @@ class MixerManagerAsync virtual void allocateVideoPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash) = 0; virtual void allocateRecordingRtpPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash) = 0; virtual void videoStreamRemoved(EngineMixer& engineMixer, const EngineVideoStream& videoStream) = 0; - virtual void sctpReceived(EngineMixer& mixer, memory::UniquePacket msgPacket, size_t endpointIdHash) = 0; + virtual void sctpReceived(EngineMixer& mixer, + memory::UniquePoolBuffer message, + size_t endpointIdHash) = 0; virtual void dataStreamRemoved(EngineMixer& mixer, const EngineDataStream& dataStream) = 0; virtual void freeRecordingRtpPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash) = 0; virtual void barbellRemoved(EngineMixer& mixer, const EngineBarbell& barbell) = 0; @@ -61,7 +64,9 @@ class MixerManagerAsync bool asyncAllocateVideoPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash); bool asyncAllocateRecordingRtpPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash); bool asyncVideoStreamRemoved(EngineMixer& engineMixer, const EngineVideoStream& videoStream); - bool asyncSctpReceived(EngineMixer& mixer, memory::UniquePacket& msgPacket, size_t endpointIdHash); + bool asyncSctpReceived(EngineMixer& mixer, + memory::UniquePoolBuffer& msgBuffer, + size_t endpointIdHash); bool asyncDataStreamRemoved(EngineMixer& mixer, const EngineDataStream& dataStream); bool asyncFreeRecordingRtpPacketCache(EngineMixer& mixer, uint32_t ssrc, size_t endpointIdHash); bool asyncBarbellRemoved(EngineMixer& mixer, const EngineBarbell& barbell); diff --git a/bridge/endpointActions/BarbellActions.cpp b/bridge/endpointActions/BarbellActions.cpp index 4aa82246b..22a46601d 100644 --- a/bridge/endpointActions/BarbellActions.cpp +++ b/bridge/endpointActions/BarbellActions.cpp @@ -110,6 +110,7 @@ httpd::Response generateBarbellResponse(ActionContext* context, api::Data responseData; responseData.port = 5000; + responseData.maxMessageSize = 2048; channelsDescription.data = responseData; const auto responseBody = api::Generator::generateAllocateBarbellResponse(channelsDescription); diff --git a/bridge/endpointActions/ConferenceActions.cpp b/bridge/endpointActions/ConferenceActions.cpp index edb9c909d..916f2e4c2 100644 --- a/bridge/endpointActions/ConferenceActions.cpp +++ b/bridge/endpointActions/ConferenceActions.cpp @@ -272,6 +272,7 @@ httpd::Response generateAllocateEndpointResponse(ActionContext* context, } responseData.port = streamDescription.sctpPort.isSet() ? streamDescription.sctpPort.get() : 5000; + responseData.maxMessageSize = mixer.getConfig().sctp.maxMessageSize; channelsDescription.data.set(responseData); } diff --git a/bridge/engine/Engine.h b/bridge/engine/Engine.h index 387d8efda..0477035ba 100644 --- a/bridge/engine/Engine.h +++ b/bridge/engine/Engine.h @@ -28,14 +28,15 @@ class Engine public: Engine(jobmanager::JobManager& backgroundJobQueue); Engine(jobmanager::JobManager& backgroundJobQueue, std::thread&& externalThread); + virtual ~Engine() = default; - void setMessageListener(MixerManagerAsync* messageListener); + virtual void setMessageListener(MixerManagerAsync* messageListener); void stop(); void run(); - bool post(utils::Function&& task) { return _tasks.push(std::move(task)); } + virtual bool post(utils::Function&& task) { return _tasks.push(std::move(task)); } - concurrency::SynchronizationContext getSynchronizationContext() + virtual concurrency::SynchronizationContext getSynchronizationContext() { return concurrency::SynchronizationContext(_tasks); } @@ -62,8 +63,8 @@ class Engine void updateStats(uint64_t& statsPollTime, EngineStats::EngineStats& currentStatSample, uint64_t timestamp); public: - bool asyncAddMixer(EngineMixer* engineMixer); - bool asyncRemoveMixer(EngineMixer* engineMixer); + virtual bool asyncAddMixer(EngineMixer* engineMixer); + virtual bool asyncRemoveMixer(EngineMixer* engineMixer); private: void addMixer(EngineMixer* engineMixer); diff --git a/bridge/engine/EngineMixer.cpp b/bridge/engine/EngineMixer.cpp index f6eb3adbb..bf226429d 100644 --- a/bridge/engine/EngineMixer.cpp +++ b/bridge/engine/EngineMixer.cpp @@ -1,5 +1,5 @@ -#include "bridge/engine/EngineMixer.h" #include "api/DataChannelMessage.h" +#include "bridge/engine/EngineMixer.h" #include "bridge/MixerManagerAsync.h" #include "bridge/engine/ActiveMediaList.h" #include "bridge/engine/EngineAudioStream.h" @@ -10,9 +10,11 @@ #include "bridge/engine/VideoNackReceiveJob.h" #include "config/Config.h" #include "logger/Logger.h" +#include "memory/Array.h" #include "rtp/RtcpFeedback.h" #include "rtp/RtpHeader.h" #include "transport/Transport.h" +#include "webrtc/DataChannel.h" using namespace bridge; @@ -532,19 +534,15 @@ void EngineMixer::onConnected(transport::RtcTransport* sender) } } -void EngineMixer::handleSctpControl(const size_t endpointIdHash, memory::UniquePacket packet) +void EngineMixer::handleSctpControl(const size_t endpointIdHash, + memory::UniquePoolBuffer message) { - auto& header = webrtc::streamMessageHeader(*packet); + // HEADER: SctpStreamMessageHeader prepended to payload auto* dataStream = _engineDataStreams.getItem(endpointIdHash); if (dataStream) { const bool wasOpen = dataStream->stream.isOpen(); - dataStream->stream.onSctpMessage(&dataStream->transport, - header.id, - header.sequenceNumber, - header.payloadProtocol, - header.data(), - packet->getLength() - sizeof(header)); + dataStream->stream.onSctpMessageBuffer(&dataStream->transport, message); if (!wasOpen && dataStream->stream.isOpen()) { @@ -559,11 +557,48 @@ void EngineMixer::handleSctpControl(const size_t endpointIdHash, memory::UniqueP } } +void EngineMixer::sendEndpointMessageTo(EngineDataStream* toDataStream, + const EngineDataStream* fromDataStream, + const memory::UniquePoolBuffer& payload, + bool shouldLog) +{ + if (!toDataStream || !toDataStream->stream.isOpen()) + { + return; + } + + auto endpointMessageBuffer = + api::DataChannelMessage::makeUniqueEndpointMessageBuffer(toDataStream->endpointId, fromDataStream->endpointId, payload); + const int length = endpointMessageBuffer->getLength(); + + if (length > 0 && (size_t)length < _config.sctp.maxMessageSize) + { + if (shouldLog) + { + memory::Array loggableBuffer; + api::DataChannelMessage::makeLoggableStringFromBuffer(loggableBuffer, endpointMessageBuffer); + logger::debug("Endpoint message %s -> %s: %s", + _loggableId.c_str(), + fromDataStream->endpointId.c_str(), + toDataStream->endpointId.c_str(), + loggableBuffer.data()); + } + toDataStream->stream.sendMessage(webrtc::DataChannelPpid::WEBRTC_STRING, std::move(endpointMessageBuffer)); + } + else + { + logger::warn("Failed to format endpoint message or buffer too small for %s -> %s.", + _loggableId.c_str(), + fromDataStream->endpointId.c_str(), + toDataStream->endpointId.c_str()); + } +} + void EngineMixer::sendEndpointMessage(const size_t toEndpointIdHash, const size_t fromEndpointIdHash, - memory::UniqueAudioPacket packet) + memory::UniquePoolBuffer buffer) { - if (!fromEndpointIdHash || !packet) + if (!fromEndpointIdHash || !buffer) { assert(false); return; @@ -575,49 +610,27 @@ void EngineMixer::sendEndpointMessage(const size_t toEndpointIdHash, return; } - auto message = reinterpret_cast(packet->get()); - utils::StringBuilder<2048> endpointMessage; - if (toEndpointIdHash) { auto* toDataStream = _engineDataStreams.getItem(toEndpointIdHash); - if (!toDataStream || !toDataStream->stream.isOpen()) - { - return; - } - - api::DataChannelMessage::makeEndpointMessage(endpointMessage, - toDataStream->endpointId, - fromDataStream->endpointId, - message); - - toDataStream->stream.sendString(endpointMessage.get(), endpointMessage.getLength()); - logger::debug("Endpoint message %lu -> %lu: %s", - _loggableId.c_str(), - fromEndpointIdHash, - toEndpointIdHash, - endpointMessage.get()); + sendEndpointMessageTo(toDataStream, fromDataStream, buffer, true); } else { - logger::debug("Broadcast Endpoint message from %lu %s", + memory::Array loggableBuffer; + api::DataChannelMessage::makeLoggableStringFromBuffer(loggableBuffer, buffer); + logger::debug("Broadcast Endpoint message from %s: %s", _loggableId.c_str(), - fromEndpointIdHash, - endpointMessage.get()); + fromDataStream->endpointId.c_str(), + loggableBuffer.data()); + for (auto& dataStreamEntry : _engineDataStreams) { - if (dataStreamEntry.first == fromEndpointIdHash || !dataStreamEntry.second->stream.isOpen()) + if (dataStreamEntry.first == fromEndpointIdHash) { continue; } - - endpointMessage.clear(); - api::DataChannelMessage::makeEndpointMessage(endpointMessage, - dataStreamEntry.second->endpointId, - fromDataStream->endpointId, - message); - - dataStreamEntry.second->stream.sendString(endpointMessage.get(), endpointMessage.getLength()); + sendEndpointMessageTo(dataStreamEntry.second, fromDataStream, buffer, false); } } } @@ -650,21 +663,21 @@ void EngineMixer::onSctpMessage(transport::RtcTransport* sender, size_t length) { assert(sender); - if (EngineBarbell::isFromBarbell(sender->getTag())) + + auto buffer = webrtc::makeUniqueSctpMessage(streamId, payloadProtocol, data, length, _sendAllocator); + if (!buffer) { - auto packet = webrtc::makeUniquePacket(streamId, payloadProtocol, data, length, _sendAllocator); - _incomingBarbellSctp.push(IncomingPacketInfo(std::move(packet), sender)); + logger::error("Unable to allocate sctp message, sender %p, length %lu", _loggableId.c_str(), sender, length); return; } - auto packet = webrtc::makeUniquePacket(streamId, payloadProtocol, data, length, _sendAllocator); - if (!packet) + if (EngineBarbell::isFromBarbell(sender->getTag())) { - logger::error("Unable to allocate sctp message, sender %p, length %lu", _loggableId.c_str(), sender, length); + _incomingBarbellSctp.push(IncomingSctpMessageInfo(std::move(buffer), sender)); return; } - _messageListener.asyncSctpReceived(*this, packet, sender->getEndpointIdHash()); + _messageListener.asyncSctpReceived(*this, buffer, sender->getEndpointIdHash()); } /** @@ -1672,9 +1685,10 @@ bool EngineMixer::asyncAddDataSteam(EngineDataStream* engineDataStream) return post(utils::bind(&EngineMixer::addDataSteam, this, engineDataStream)); } -bool EngineMixer::asyncHandleSctpControl(const size_t endpointIdHash, memory::UniquePacket& packet) +bool EngineMixer::asyncHandleSctpControl(const size_t endpointIdHash, + memory::UniquePoolBuffer& message) { - return post(utils::bind(&EngineMixer::handleSctpControl, this, endpointIdHash, utils::moveParam(packet))); + return post(utils::bind(&EngineMixer::handleSctpControl, this, endpointIdHash, utils::moveParam(message))); } void EngineMixer::onIceReceived(transport::RtcTransport* transport, uint64_t timestamp) diff --git a/bridge/engine/EngineMixer.h b/bridge/engine/EngineMixer.h index 6aaa02436..987e87641 100644 --- a/bridge/engine/EngineMixer.h +++ b/bridge/engine/EngineMixer.h @@ -12,6 +12,7 @@ #include "memory/AudioPacketPoolAllocator.h" #include "memory/Map.h" #include "memory/PacketPoolAllocator.h" +#include "memory/PoolBuffer.h" #include "transport/RtcTransport.h" #include #include @@ -182,7 +183,7 @@ class EngineMixer : public transport::DataReceiver bool asyncPinEndpoint(const size_t endpointIdHash, const size_t targetEndpointIdHash); bool asyncSendEndpointMessage(const size_t toEndpointIdHash, const size_t fromEndpointIdHash, - memory::UniqueAudioPacket& packet); + memory::UniquePoolBuffer& buffer); bool asyncAddRecordingStream(EngineRecordingStream* engineRecordingStream); bool asyncAddTransportToRecordingStream(const size_t streamIdHash, transport::RecordingTransport& transport, @@ -198,7 +199,8 @@ class EngineMixer : public transport::DataReceiver bool asyncRemoveTransportFromRecordingStream(const size_t streamIdHash, const size_t endpointIdHash); bool asyncAddBarbell(EngineBarbell* barbell); bool asyncRemoveBarbell(size_t idHash); - bool asyncHandleSctpControl(const size_t endpointIdHash, memory::UniquePacket& packet); + bool asyncHandleSctpControl(const size_t endpointIdHash, + memory::UniquePoolBuffer& message); bool asyncRemoveRecordingStream(const EngineRecordingStream& engineRecordingStream); private: // impl async interface @@ -221,7 +223,11 @@ class EngineMixer : public transport::DataReceiver void pinEndpoint(const size_t endpointIdHash, const size_t targetEndpointIdHash); void sendEndpointMessage(const size_t toEndpointIdHash, const size_t fromEndpointIdHash, - memory::UniqueAudioPacket packet); + memory::UniquePoolBuffer buffer); + void sendEndpointMessageTo(EngineDataStream* toDataStream, + const EngineDataStream* fromDataStream, + const memory::UniquePoolBuffer& payload, + bool shouldLog); void recordingStart(EngineRecordingStream& stream, const RecordingDescription& desc); void stopRecording(EngineRecordingStream& stream, const RecordingDescription& desc); void updateRecordingStreamModalities(EngineRecordingStream& engineRecordingStream, @@ -235,7 +241,7 @@ class EngineMixer : public transport::DataReceiver void removeTransportFromRecordingStream(const size_t streamIdHash, const size_t endpointIdHash); void addBarbell(EngineBarbell* barbell); void removeBarbell(size_t idHash); - void handleSctpControl(const size_t endpointIdHash, memory::UniquePacket packet); + void handleSctpControl(const size_t endpointIdHash, memory::UniquePoolBuffer message); public: // private but called from helper method void removeStream(const EngineVideoStream* engineVideoStream); @@ -354,6 +360,7 @@ class EngineMixer : public transport::DataReceiver }; using IncomingPacketInfo = IncomingPacketAggregate; + using IncomingSctpMessageInfo = IncomingPacketAggregate>; std::string _id; logger::LoggableId _loggableId; @@ -362,7 +369,7 @@ class EngineMixer : public transport::DataReceiver concurrency::SynchronizationContext _engineSyncContext; MixerManagerAsync& _messageListener; - concurrency::MpmcQueue _incomingBarbellSctp; + concurrency::MpmcQueue _incomingBarbellSctp; concurrency::MpmcQueue _incomingForwarderAudioRtp; concurrency::MpmcQueue _incomingRtcp; concurrency::MpmcQueue _incomingForwarderVideoRtp; @@ -518,8 +525,7 @@ class EngineMixer : public transport::DataReceiver void onBarbellUserMediaMap(size_t barbellIdHash, const char* message); void onBarbellMinUplinkEstimate(size_t barbellIdHash, const char* message); void onBarbellDataChannelEstablish(size_t barbellIdHash, - webrtc::SctpStreamMessageHeader& header, - size_t packetSize); + memory::UniquePoolBuffer message); //// diff --git a/bridge/engine/EngineMixerBarbell.cpp b/bridge/engine/EngineMixerBarbell.cpp index cd72875f6..54f58dfb9 100644 --- a/bridge/engine/EngineMixerBarbell.cpp +++ b/bridge/engine/EngineMixerBarbell.cpp @@ -9,6 +9,7 @@ #include "bridge/engine/ProcessMissingVideoPacketsJob.h" #include "bridge/engine/VideoForwarderRewriteAndSendJob.h" #include "bridge/engine/VideoNackReceiveJob.h" +#include "memory/Array.h" #include "rtp/RtcpFeedback.h" #include "utils/SimpleJson.h" #include "utils/StringBuilder.h" @@ -412,9 +413,9 @@ void EngineMixer::onBarbellMinUplinkEstimate(size_t barbellIdHash, const char* m } void EngineMixer::onBarbellDataChannelEstablish(size_t barbellIdHash, - webrtc::SctpStreamMessageHeader& header, - size_t packetSize) + memory::UniquePoolBuffer message) { + // HEADER: SctpStreamMessageHeader prepended to payload auto barbell = _engineBarbells.getItem(barbellIdHash); if (!barbell) { @@ -422,12 +423,7 @@ void EngineMixer::onBarbellDataChannelEstablish(size_t barbellIdHash, } const auto state = barbell->dataChannel.getState(); - barbell->dataChannel.onSctpMessage(&barbell->transport, - header.id, - header.sequenceNumber, - header.payloadProtocol, - header.data(), - header.getMessageLength(packetSize)); + barbell->dataChannel.onSctpMessageBuffer(&barbell->transport, message); const auto newState = barbell->dataChannel.getState(); if (state != newState && newState == webrtc::WebRtcDataStream::State::OPEN) @@ -511,35 +507,55 @@ SsrcInboundContext* EngineMixer::emplaceBarbellInboundSsrcContext(const uint32_t void EngineMixer::processBarbellSctp(const uint64_t timestamp) { - for (IncomingPacketInfo packetInfo; _incomingBarbellSctp.pop(packetInfo);) + for (IncomingSctpMessageInfo messageInfo; _incomingBarbellSctp.pop(messageInfo);) { - auto header = reinterpret_cast(packetInfo.packet()->get()); + _lastReceiveTimeOnBarbellTransports = timestamp; + auto& buffer = messageInfo.packet(); + if (!buffer || buffer->empty()) + { + continue; + } + if (buffer->getLength() < sizeof(webrtc::SctpStreamMessageHeader)) + { + continue; + } - if (header->payloadProtocol == webrtc::DataChannelPpid::WEBRTC_STRING) + constexpr size_t MAX_BUFFER_SIZE = 8192; + if (buffer->size() > MAX_BUFFER_SIZE) { + logger::warn("Large barbell SCTP message (size = %zu, max allowed = %zu). Dropping.", _loggableId.c_str(), buffer->getLength(), MAX_BUFFER_SIZE); + continue; + } + + char continousBuffer[buffer->size()]; + buffer->copyTo(continousBuffer, 0, buffer->size()); + + auto* sctpHeader = const_cast( + reinterpret_cast(continousBuffer)); + + if (sctpHeader->payloadProtocol == webrtc::DataChannelPpid::WEBRTC_STRING) { - auto message = reinterpret_cast(header->data()); - const auto messageLength = header->getMessageLength(packetInfo.packet()->getLength()); + const char* message = reinterpret_cast(sctpHeader->data()); + const auto messageLength = buffer->getLength() - sizeof(*sctpHeader); if (messageLength == 0) { - return; + continue; } auto messageJson = utils::SimpleJson::create(message, messageLength); if (api::DataChannelMessageParser::isUserMediaMap(messageJson)) { - onBarbellUserMediaMap(packetInfo.transport()->getEndpointIdHash(), message); + onBarbellUserMediaMap(messageInfo.transport()->getEndpointIdHash(), message); } else if (api::DataChannelMessageParser::isMinUplinkBitrate(messageJson)) { - onBarbellMinUplinkEstimate(packetInfo.transport()->getEndpointIdHash(), message); + onBarbellMinUplinkEstimate(messageInfo.transport()->getEndpointIdHash(), message); } } - else if (header->payloadProtocol == webrtc::DataChannelPpid::WEBRTC_ESTABLISH) + else if (sctpHeader->payloadProtocol == webrtc::DataChannelPpid::WEBRTC_ESTABLISH) { - onBarbellDataChannelEstablish(packetInfo.transport()->getEndpointIdHash(), - *header, - packetInfo.packet()->getLength()); + onBarbellDataChannelEstablish(messageInfo.transport()->getEndpointIdHash(), + std::move(buffer)); } } } diff --git a/bridge/engine/EngineMixerVideo.cpp b/bridge/engine/EngineMixerVideo.cpp index a36b3bf0f..b9153341d 100644 --- a/bridge/engine/EngineMixerVideo.cpp +++ b/bridge/engine/EngineMixerVideo.cpp @@ -938,12 +938,12 @@ bool EngineMixer::asyncPinEndpoint(const size_t endpointIdHash, const size_t tar bool EngineMixer::asyncSendEndpointMessage(const size_t toEndpointIdHash, const size_t fromEndpointIdHash, - memory::UniqueAudioPacket& packet) + memory::UniquePoolBuffer& buffer) { return post(utils::bind(&EngineMixer::sendEndpointMessage, this, toEndpointIdHash, fromEndpointIdHash, - utils::moveParam(packet))); + utils::moveParam(buffer))); } } // namespace bridge diff --git a/config/Config.h b/config/Config.h index b907e4f87..9d74a3603 100644 --- a/config/Config.h +++ b/config/Config.h @@ -49,6 +49,7 @@ class Config : public ConfigReader // fix SCTP port to 5000 to support old CS CFG_PROP(bool, fixedPort, true); CFG_PROP(uint32_t, bufferSize, 50 * 1024); + CFG_PROP(uint32_t, maxMessageSize, 2048); CFG_GROUP_END(sctp); CFG_GROUP() diff --git a/doc/api/READMEapi.md b/doc/api/READMEapi.md index 0b41d951b..00a0b30a7 100644 --- a/doc/api/READMEapi.md +++ b/doc/api/READMEapi.md @@ -185,7 +185,8 @@ POST /conferences/{conferenceId}/{endpointId} ] }, "data": { - "port": 5000 + "port": 5000, + "max-message-size": 2048 } } ``` diff --git a/legacyapi/DataChannelMessage.h b/legacyapi/DataChannelMessage.h index 26d31290b..107ca3c95 100644 --- a/legacyapi/DataChannelMessage.h +++ b/legacyapi/DataChannelMessage.h @@ -1,5 +1,6 @@ #pragma once +#include "memory/PoolBuffer.h" #include "utils/StringBuilder.h" namespace legacyapi @@ -8,7 +9,8 @@ namespace legacyapi namespace DataChannelMessage { -inline void makeEndpointMessage(utils::StringBuilder<2048>& outMessage, +template +inline void makeEndpointMessage(utils::StringBuilder& outMessage, const std::string& toEndpointId, const std::string& fromEndpointId, const char* message) @@ -25,6 +27,40 @@ inline void makeEndpointMessage(utils::StringBuilder<2048>& outMessage, outMessage.append("}"); } +inline memory::UniquePoolBuffer makeUniqueEndpointMessageBuffer( + const std::string& toEndpointId, + const std::string& fromEndpointId, + const memory::UniquePoolBuffer& payload) +{ + constexpr const char* TO_STRING = "{\"colibriClass\":\"EndpointMessage\",\"to\":\""; + constexpr const char* FROM_STRING = "\",\"from\":\""; + constexpr const char* MSG_STRING = "\",\"msgPayload\":"; + constexpr const char* TAIL_STRING = "}"; + + constexpr std::size_t overhead_len = std::char_traits::length(TO_STRING) + + std::char_traits::length(FROM_STRING) + + std::char_traits::length(MSG_STRING) + + std::char_traits::length(TAIL_STRING); + + const std::size_t extraLen = toEndpointId.length() + fromEndpointId.length() + payload->getLength(); + auto buffer = memory::makeUniquePoolBuffer(payload->getAllocator(), overhead_len + extraLen); + if (!buffer) + { + return buffer; + } + + auto written = buffer->copyFrom(TO_STRING, std::char_traits::length(TO_STRING), 0); + written += buffer->copyFrom(toEndpointId.c_str(), toEndpointId.length(), written); + written += buffer->copyFrom(FROM_STRING, std::char_traits::length(FROM_STRING), written); + written += buffer->copyFrom(fromEndpointId.c_str(), fromEndpointId.length(), written); + written += buffer->copyFrom(MSG_STRING, std::char_traits::length(MSG_STRING), written); + written += buffer->copyFrom(*payload.get(), written); + written += buffer->copyFrom(TAIL_STRING, std::char_traits::length(TAIL_STRING), written); + + assert(written == buffer->getLength()); + return buffer; +} + inline void makeDominantSpeakerChange(utils::StringBuilder<256>& outMessage, const char* endpointId) { outMessage.append("{\"colibriClass\":\"DominantSpeakerEndpointChangeEvent\", \"dominantSpeakerEndpoint\":\""); diff --git a/memory/Array.h b/memory/Array.h index 27714bf67..155af8b5c 100644 --- a/memory/Array.h +++ b/memory/Array.h @@ -80,6 +80,7 @@ class Array size_t capacity() const { return _capacity; } size_t size() const { return _size; } + size_t resize(size_t size) { if (size > _capacity) { return _size; } _size = size; return _size; } bool empty() const { return _size == 0; } void clear() diff --git a/memory/PoolAllocator.h b/memory/PoolAllocator.h index 21424d153..67bd71120 100644 --- a/memory/PoolAllocator.h +++ b/memory/PoolAllocator.h @@ -112,6 +112,8 @@ class PoolAllocator size_t size() const { return _count.load(std::memory_order_relaxed); } size_t countAllocatedItems() const { return _originalElementCount - size(); } + size_t getElementSize() const { return ELEMENT_SIZE; } + void* allocate() { concurrency::StackItem* item = nullptr; diff --git a/memory/PoolBuffer.h b/memory/PoolBuffer.h new file mode 100644 index 000000000..08e022132 --- /dev/null +++ b/memory/PoolBuffer.h @@ -0,0 +1,493 @@ +#pragma once + +#include "memory/PoolAllocator.h" +#include "memory/PacketPoolAllocator.h" +#include "memory/Array.h" +#include +#include + +namespace memory +{ +//template +struct ReadonlyMemoryBuffer +{ + ReadonlyMemoryBuffer() : data(nullptr), length(0) {} + + const void* data; + size_t length; + char storage[8192]; +}; + +template +class PoolBuffer +{ +public: + struct Deleter + { + Deleter() : _allocator(nullptr) {} + explicit Deleter(TPoolAllocator& allocator) : _allocator(&allocator) {} + + void operator()(PoolBuffer* p) + { + if (p) + { + p->~PoolBuffer(); + if (_allocator) + { + _allocator->free(p); + } + } + } + + private: + TPoolAllocator* _allocator; + }; + + explicit PoolBuffer(TPoolAllocator& allocator) + : _allocator(allocator), + _masterChunk(nullptr), + _size(0), + _numChunks(0), + _externalMasterChunkSize(0), + _firstChunkSize(0), + _firstChunkIsInMaster(false) + {} + + explicit PoolBuffer(TPoolAllocator& allocator, void* preallocatedMasterChunk, size_t _masterChunkSize) + : _allocator(allocator), + _masterChunk(preallocatedMasterChunk), + _size(0), + _numChunks(0), + _externalMasterChunkSize(_masterChunkSize), // This buffer does NOT own masterChunk + _firstChunkSize(0), + _firstChunkIsInMaster(false) + {} + + PoolBuffer(PoolBuffer&& other) noexcept + : _allocator(other._allocator), + _masterChunk(other._masterChunk), + _size(other._size), + _numChunks(other._numChunks), + _externalMasterChunkSize(other._externalMasterChunkSize), + _firstChunkSize(other._firstChunkSize), + _firstChunkIsInMaster(other._firstChunkIsInMaster) + { + other._masterChunk = nullptr; + other._size = 0; + other._numChunks = 0; + other._firstChunkSize = 0; + other._firstChunkIsInMaster = false; + } + + PoolBuffer& operator=(PoolBuffer&& other) noexcept + { + if (this != &other) + { + clear(); + _masterChunk = other._masterChunk; + _externalMasterChunkSize = other._externalMasterChunkSize; + _size = other._size; + _numChunks = other._numChunks; + _firstChunkSize = other._firstChunkSize; + _firstChunkIsInMaster = other._firstChunkIsInMaster; + + other._masterChunk = nullptr; + other._externalMasterChunkSize = 0; + other._size = 0; + other._numChunks = 0; + other._firstChunkSize = 0; + other._firstChunkIsInMaster = false; + } + return *this; + } + + ~PoolBuffer() { clear(); } + + PoolBuffer(const PoolBuffer&) = delete; + PoolBuffer& operator=(const PoolBuffer&) = delete; + + bool allocate(size_t size) + { + if (size > capacity()) + { + clear(); + const auto elementSize = _allocator.getElementSize(); + auto masterChunkCapacity = _externalMasterChunkSize > 0 ? _externalMasterChunkSize : elementSize; + + if (size == 0) + { + _size = 0; + return true; + } + + if (0 == _externalMasterChunkSize || !_masterChunk) + { + _masterChunk = _allocator.allocate(); + if (!_masterChunk) + { + return false; + } + _externalMasterChunkSize = 0; + masterChunkCapacity = elementSize; + } + + size_t numChunks = (size + elementSize - 1) / elementSize; + if (size <= masterChunkCapacity) + { + numChunks = 1; + } + size_t pointersAreaSize = numChunks * sizeof(void*); + + if (masterChunkCapacity > pointersAreaSize) + { + const size_t firstChunkSize = masterChunkCapacity - pointersAreaSize; + if (size > firstChunkSize) + { + numChunks = 1 + (size - firstChunkSize + elementSize - 1) / elementSize; + } + else + { + numChunks = 1; + } + + pointersAreaSize = numChunks * sizeof(void*); + if (masterChunkCapacity > pointersAreaSize) + { + _firstChunkSize = masterChunkCapacity - pointersAreaSize; + _numChunks = numChunks; + _firstChunkIsInMaster = true; + } + else + { + _numChunks = (size + elementSize - 1) / elementSize; + } + } + else + { + _numChunks = numChunks; + } + + if (masterChunkCapacity < _numChunks * sizeof(void*)) + { + logger::error("master chunk is too small to hold chunk pointers. capacity %zu, required %zu", + "PoolBuffer", + masterChunkCapacity, + _numChunks * sizeof(void*)); + _numChunks = 0; + if (_masterChunk && !_externalMasterChunkSize) + { + _allocator.free(_masterChunk); + _masterChunk = nullptr; + } + return false; + } + + void** chunkPointers = reinterpret_cast(_masterChunk); + size_t i = 0; + if (_firstChunkIsInMaster) + { + chunkPointers[0] = reinterpret_cast(_masterChunk) + _numChunks * sizeof(void*); + i = 1; + } + + for (; i < _numChunks; ++i) + { + chunkPointers[i] = _allocator.allocate(); + if (!chunkPointers[i]) + { + for (size_t j = _firstChunkIsInMaster ? 1 : 0; j < i; ++j) + { + _allocator.free(chunkPointers[j]); + } + if (0 == _externalMasterChunkSize) + { + _allocator.free(_masterChunk); + } + _masterChunk = nullptr; + _numChunks = 0; + return false; + } + } + } + _size = size; + return true; + } + + void clear() + { + if (_masterChunk) + { + void** chunkPointers = reinterpret_cast(_masterChunk); + for (size_t i = _firstChunkIsInMaster ? 1 : 0; i < _numChunks; ++i) + { + if (chunkPointers[i]) + { + _allocator.free(chunkPointers[i]); + } + } + if (0 == _externalMasterChunkSize) + { + _allocator.free(_masterChunk); + } + } + if (0 == _externalMasterChunkSize) + { + _masterChunk = nullptr; + } + + _numChunks = 0; + _size = 0; + _firstChunkIsInMaster = false; + _firstChunkSize = 0; + } + + size_t size() const { return _size; } + size_t getLength() const { return _size; } + size_t capacity() const + { + if (_firstChunkIsInMaster) + { + return _firstChunkSize + (_numChunks - 1) * _allocator.getElementSize(); + } + return _numChunks * _allocator.getElementSize(); + } + bool empty() const { return _size == 0 || _numChunks == 0 || _masterChunk == nullptr; } + size_t getChunkCount() const { return _numChunks; } + bool isMultiChunk() const { return _numChunks > 1; } + + template + size_t copyFrom(const PoolBuffer& sourceBuffer, + size_t sourceOffset, + size_t len, + size_t destinationOffset = 0) + { + const auto destSize = size(); + const auto srcSize = sourceBuffer.size(); + if (destinationOffset >= destSize || sourceOffset >= srcSize) + { + return 0; + } + + const size_t bytesToCopy = std::min({len, srcSize - sourceOffset, destSize - destinationOffset}); + if (bytesToCopy == 0) + { + return 0; + } + + size_t totalBytesCopied = 0; + size_t currentSrcOffset = sourceOffset; + + auto callback = [&](uint8_t* block, size_t blockSize) { + const auto copied = sourceBuffer.copyTo(block, currentSrcOffset, blockSize); + totalBytesCopied += copied; + currentSrcOffset += copied; + return copied == blockSize; // Continue only if we copied the whole block + }; + + forEachBlock(destinationOffset, bytesToCopy, callback); + + return totalBytesCopied; + } + + size_t copyFrom(const PoolBuffer& src, size_t destinationOffset = 0) { + return copyFrom(src, 0, src.size(), destinationOffset); + } + + size_t copyFrom(const void* source, size_t len, size_t destinationOffset = 0) + { + const uint8_t* sourceData = static_cast(source); + size_t bytesCopied = 0; + const size_t remainingToCopy = std::min(len, _size > destinationOffset ? _size - destinationOffset : 0); + + auto callback = [&](uint8_t* block, size_t blockSize) { + std::memcpy(block, sourceData + bytesCopied, blockSize); + bytesCopied += blockSize; + return true; + }; + + forEachBlock(destinationOffset, remainingToCopy, callback); + return bytesCopied; + } + + size_t copyTo(void* destination, size_t sourceOffset, size_t count) const + { + if (!destination) + { + return 0; + } + + size_t bytesCopied = 0; + auto callback = [&](const uint8_t* block, size_t blockSize) { + std::memcpy(static_cast(destination) + bytesCopied, block, blockSize); + bytesCopied += blockSize; + return true; + }; + + forEachBlock(sourceOffset, count, callback); + return bytesCopied; + } + + TPoolAllocator& getAllocator() const { return _allocator; } + + bool isNullTerminated() const { + if (_size == 0) { + return false; + } + void** chunkPointers = reinterpret_cast(_masterChunk); + const auto chunkAndOffset = getChunkAndOffset(_size - 1); + const auto data = chunkPointers[chunkAndOffset.first]; + return reinterpret_cast(data)[chunkAndOffset.second] == '\0'; + } + +private: + size_t getChunkSize(size_t chunkIndex) const + { + if (_firstChunkIsInMaster && chunkIndex == 0) + { + return _firstChunkSize; + } + return _allocator.getElementSize(); + } + + std::pair getChunkAndOffset(size_t absoluteOffset) const + { + const auto elementSize = _allocator.getElementSize(); + if (_firstChunkIsInMaster && _numChunks > 0) + { + if (absoluteOffset < _firstChunkSize) + { + return {0, absoluteOffset}; + } + else + { + return {1 + (absoluteOffset - _firstChunkSize) / elementSize, + (absoluteOffset - _firstChunkSize) % elementSize}; + } + } + else + { + return {absoluteOffset / elementSize, absoluteOffset % elementSize}; + } + } + + template + void forEachBlockImpl(size_t offset, size_t count, F& callback) const + { + size_t processedBytes = 0; + const size_t remainingToProcess = std::min(count, _size > offset ? _size - offset : 0); + + if (!_masterChunk || remainingToProcess == 0) + { + return; + } + + auto chunkInfo = getChunkAndOffset(offset); + size_t chunkIndex = chunkInfo.first; + size_t offsetInChunk = chunkInfo.second; + + auto** chunkPointers = reinterpret_cast(_masterChunk); + + while (processedBytes < remainingToProcess && chunkIndex < _numChunks) + { + const auto currentChunkSize = getChunkSize(chunkIndex); + const size_t toProcessInChunk = + std::min(remainingToProcess - processedBytes, currentChunkSize - offsetInChunk); + + if (!callback(static_cast(static_cast(chunkPointers[chunkIndex]) + offsetInChunk), + toProcessInChunk)) + { + break; + } + + processedBytes += toProcessInChunk; + offsetInChunk = 0; + chunkIndex++; + } + } + + template + void forEachBlock(size_t offset, size_t count, F& callback) + { + forEachBlockImpl(offset, count, callback); + } + + template + void forEachBlock(size_t offset, size_t count, F& callback) const + { + forEachBlockImpl(offset, count, callback); + } + + TPoolAllocator& _allocator; + void* _masterChunk; + size_t _size; + size_t _numChunks; + size_t _externalMasterChunkSize; + size_t _firstChunkSize; + bool _firstChunkIsInMaster; +}; + +template +using UniquePoolBuffer = std::unique_ptr, typename PoolBuffer::Deleter>; + +template +inline UniquePoolBuffer makeUniquePoolBuffer(TPoolAllocator& allocator, + size_t length) +{ + auto pointer = allocator.allocate(); + assert(pointer); + if (!pointer) + { + logger::error("Unable to allocate pool buffer, no space left in pool %s", + "PoolBuffer", + allocator.getName().c_str()); + return UniquePoolBuffer(); + } + + constexpr auto poolBufferAlignment = alignof(PoolBuffer); + const auto alignedPoolBufferSize = + (sizeof(PoolBuffer) + poolBufferAlignment - 1) & ~(poolBufferAlignment - 1); + + const auto elementSize = allocator.getElementSize(); + const auto storageSize = elementSize > alignedPoolBufferSize ? elementSize - alignedPoolBufferSize : 0; + const auto maxChunkCount = elementSize > sizeof(void*) ? elementSize / sizeof(void*) : 0; + const auto optChunkCount = storageSize > sizeof(void*) ? storageSize / sizeof(void*) : 0; + + // Check that we can theoretically fit enough pointers to chunk into master chunk to accomodate all data. + if (maxChunkCount * elementSize < length) { + logger::error("Unable to allocate pool buffer, master chunk it too small %s", + "PoolBuffer", + allocator.getName().c_str()); + allocator.free(pointer); + return UniquePoolBuffer(); + } + + // Try to fit master chunk in already allocated pointer right after PoolBuffer*. + void* masterChunk = reinterpret_cast(pointer) + alignedPoolBufferSize; + auto buffer = (optChunkCount * elementSize >= length) + ? new (pointer) PoolBuffer(allocator, masterChunk, storageSize) + : new (pointer) PoolBuffer(allocator); + + auto smartBuffer = + UniquePoolBuffer(buffer, typename PoolBuffer::Deleter(allocator)); + + if (!smartBuffer->allocate(length)) + { + return UniquePoolBuffer(); + } + return smartBuffer; +} + +template +inline UniquePoolBuffer makeUniquePoolBuffer(TPoolAllocator& allocator, + const void* data, + size_t length) +{ + auto buffer = makeUniquePoolBuffer(allocator, length); + if (buffer && data && buffer->copyFrom(data, length) != length) + { + return nullptr; + } + + return buffer; +} +} \ No newline at end of file diff --git a/test/bridge/ActiveMediaListTest.cpp b/test/bridge/ActiveMediaListTest.cpp index 8380ca1df..f4a340cb1 100644 --- a/test/bridge/ActiveMediaListTest.cpp +++ b/test/bridge/ActiveMediaListTest.cpp @@ -88,7 +88,8 @@ class ActiveMediaListTest : public ::testing::Test _timers = std::make_unique(4096 * 8); _jobManager = std::make_unique(*_timers); _jobQueue = std::make_unique(*_jobManager); - _transport = std::make_unique(*_jobQueue); + _allocator = std::make_unique(16, "dummy"); + _transport = std::make_unique(*_jobQueue, *_allocator); _activeMediaList = std::make_unique(1, _audioSsrcs, _videoSsrcs, defaultLastN, audioLastN, 18); @@ -131,6 +132,7 @@ class ActiveMediaListTest : public ::testing::Test std::unique_ptr _jobManager; std::unique_ptr _jobQueue; std::unique_ptr _transport; + std::unique_ptr _allocator; concurrency::MpmcHashmap32 _engineAudioStreams; concurrency::MpmcHashmap32 _engineVideoStreams; diff --git a/test/bridge/ApiRequestHandlerTest.cpp b/test/bridge/ApiRequestHandlerTest.cpp index b4410cb9b..b09e9f062 100644 --- a/test/bridge/ApiRequestHandlerTest.cpp +++ b/test/bridge/ApiRequestHandlerTest.cpp @@ -330,13 +330,22 @@ TEST_F(ApiRequestHandlerTest, allocateEndpointWithVideoFieldWhenVideoIsDisabledS auto audioIt = responseJson.find("audio"); auto bundleTransportIt = responseJson.find("bundle-transport"); - auto data = responseJson.find("data"); + auto dataIt = responseJson.find("data"); // Check if answer contains the only 3 expected (audio, bundle-transport and data ) EXPECT_EQ(3, responseJson.size()); EXPECT_NE(responseJson.end(), audioIt); EXPECT_NE(responseJson.end(), bundleTransportIt); - EXPECT_NE(responseJson.end(), data); + EXPECT_NE(responseJson.end(), dataIt); + + const auto dataJson = *dataIt; + const auto portIt = dataJson.find("port"); + EXPECT_NE(dataJson.end(), portIt); + EXPECT_EQ(5000, portIt->get()); + + const auto maxMessageSizeIt = dataJson.find("max-message-size"); + EXPECT_NE(dataJson.end(), maxMessageSizeIt); + EXPECT_EQ(_mixerManagerSpyResources->config.sctp.maxMessageSize, maxMessageSizeIt->get()); // Because this test is explicit to test that the video is not present. Let's check it explicitly EXPECT_EQ(responseJson.end(), responseJson.find("video")); diff --git a/test/bridge/DataChannelMessageSizeTest.cpp b/test/bridge/DataChannelMessageSizeTest.cpp new file mode 100644 index 000000000..35aacaacd --- /dev/null +++ b/test/bridge/DataChannelMessageSizeTest.cpp @@ -0,0 +1,546 @@ +#include "api/DataChannelMessage.h" +#include "api/DataChannelMessageParser.h" +#include "memory/PacketPoolAllocator.h" +#include "memory/PoolBuffer.h" +#include "bridge/AudioStream.h" +#include "bridge/DataStream.h" +#include "bridge/Mixer.h" +#include "bridge/VideoStream.h" +#include "bridge/engine/EngineAudioStream.h" +#include "bridge/engine/EngineBarbell.h" +#include "bridge/engine/EngineDataStream.h" +#include "bridge/engine/EngineMixer.h" +#include "bridge/engine/EngineRecordingStream.h" +#include "bridge/engine/EngineVideoStream.h" +#include "config/Config.h" +#include "jobmanager/JobManager.h" +#include "jobmanager/TimerQueue.h" +#include "memory/AudioPacketPoolAllocator.h" +#include "memory/PoolBuffer.h" +#include "memory/PacketPoolAllocator.h" +#include "mocks/MixerManagerAsyncMock.h" +#include "mocks/RtcTransportMock.h" +#include "mocks/TransportFactoryMock.h" +#include "utils/Function.h" +#include "utils/IdGenerator.h" +#include "utils/Optional.h" +#include "utils/SsrcGenerator.h" +#include "utils/StdExtensions.h" +#include "utils/SimpleJson.h" +#include "webrtc/DataChannel.h" +#include +#include +#include +#include +#include + +#include "bridge/MixerManager.h" + +using namespace ::testing; +using namespace ::test; +using namespace ::bridge; + +namespace +{ +struct JobManagerProcessor +{ + JobManagerProcessor(jobmanager::TimerQueue& timeQueue) : jobManager(timeQueue, 512) {} + ~JobManagerProcessor() { dropAll(); } + + void process() + { + jobmanager::MultiStepJob* job; + while ((job = jobManager.pop()) != nullptr) + { + if (!job->runStep()) + { + jobManager.freeJob(job); + } + else + { + pendingJobs.push_back(job); + } + } + } + + void activatePendingJobs() { + for (auto job : pendingJobs) { + jobManager.addJobItem(job); + } + pendingJobs.clear(); + } + + void dropAll() + { + jobmanager::MultiStepJob* job; + while ((job = jobManager.pop()) != nullptr) + { + jobManager.freeJob(job); + } + + for (auto* pendingJob : pendingJobs) + { + jobManager.freeJob(pendingJob); + } + pendingJobs.clear(); + } + + jobmanager::JobManager& getJobManager() { return jobManager; } + +private: + jobmanager::JobManager jobManager; + std::vector pendingJobs; +}; + +struct MixerTestScope +{ + MixerTestScope() + : transportFactoryMock(std::make_shared>()), + engineTaskQueue(512), + engineSyncContext(engineTaskQueue), + timeQueue(64), + jobManagerProcessor(timeQueue), + backgroundJobManagerProcessor(timeQueue), + mainPacketAllocator(128 * 1024, "MainAllocator-test"), + sendPacketAllocator(32 * 1024, "SendAllocator-test"), + audioPacketAllocator(4 * 1024, "AudioAllocator-test") + { + } + + std::shared_ptr> transportFactoryMock; + StrictMock mixerManagerAsyncMock; + concurrency::MpmcQueue engineTaskQueue; + concurrency::SynchronizationContext engineSyncContext; + jobmanager::TimerQueue timeQueue; + JobManagerProcessor jobManagerProcessor; + JobManagerProcessor backgroundJobManagerProcessor; + memory::PacketPoolAllocator mainPacketAllocator; + memory::PacketPoolAllocator sendPacketAllocator; + memory::AudioPacketPoolAllocator audioPacketAllocator; +}; +} // namespace + +class DataChannelMessageSizeTest : public ::testing::Test +{ +public: + DataChannelMessageSizeTest() : _testScope(std::make_unique()) {} + +protected: + struct DataChannelEndpoints + { + std::shared_ptr> transport0; + std::shared_ptr> transport1; + const std::string endpointId0 = "endpoint-0"; + const std::string endpointId1 = "endpoint-1"; + size_t endpointId0Hash; + size_t endpointId1Hash; + }; + + DataChannelEndpoints createDataChannelEndpoints() + { + DataChannelEndpoints endpoints; + endpoints.endpointId0Hash = utils::hash{}(endpoints.endpointId0); + endpoints.endpointId1Hash = utils::hash{}(endpoints.endpointId1); + + endpoints.transport0 = std::make_shared>(); + endpoints.transport1 = std::make_shared>(); + + _testScope->transportFactoryMock->willReturnByDefaultForAll(nullptr); + EXPECT_CALL(*_testScope->transportFactoryMock, create(_, endpoints.endpointId0Hash, _, _, _, _, _, _)) + .WillOnce(Return(endpoints.transport0)); + EXPECT_CALL(*_testScope->transportFactoryMock, create(_, endpoints.endpointId1Hash, _, _, _, _, _, _)) + .WillOnce(Return(endpoints.transport1)); + + return endpoints; + } + + void connectDataStreams(Mixer& mixer, const DataChannelEndpoints& endpoints) + { + mixer.addBundleTransportIfNeeded(endpoints.endpointId0, ice::IceRole::CONTROLLING, false, false); + mixer.addBundleTransportIfNeeded(endpoints.endpointId1, ice::IceRole::CONTROLLING, false, false); + + std::string dataStreamId; + ASSERT_TRUE(mixer.addBundledDataStream(dataStreamId, endpoints.endpointId0)); + ASSERT_TRUE(mixer.addBundledDataStream(dataStreamId, endpoints.endpointId1)); + mixer.configureDataStream(endpoints.endpointId0, 5000); + mixer.configureDataStream(endpoints.endpointId1, 5000); + mixer.addDataStreamToEngine(endpoints.endpointId0); + mixer.addDataStreamToEngine(endpoints.endpointId1); + } + + void openDataChannel(Mixer& mixer, const std::string& endpointsId) + { + alignas(memory::Packet) const char webRtcOpen[] = + "\x03\x00\x00\x00\x00\x00\x00\x00\x00\x12\x00\x00\x77\x65\x62\x72" + "\x74\x63\x2d\x64\x61\x74\x61\x63\x68\x61\x6e\x6e\x65\x6c\x00\x00"; + + webrtc::SctpStreamMessageHeader header = {webrtc::DataChannelPpid::WEBRTC_ESTABLISH, 0, 0}; + auto buffer = memory::makeUniquePoolBuffer(_testScope->mainPacketAllocator, sizeof(webrtc::SctpStreamMessageHeader) + sizeof(webRtcOpen)); + buffer->copyFrom(&header, sizeof(webrtc::SctpStreamMessageHeader), 0); + buffer->copyFrom(webRtcOpen, sizeof(webRtcOpen) - 1, sizeof(webrtc::SctpStreamMessageHeader)); + + auto* dataStream = mixer.getEngineDataStream(endpointsId); + ASSERT_NE(nullptr, dataStream); + dataStream->stream.onSctpMessageBuffer(&dataStream->transport, buffer); + } + + void openDataChannels(Mixer& mixer, const DataChannelEndpoints& endpoints) + { + openDataChannel(mixer, endpoints.endpointId0); + openDataChannel(mixer, endpoints.endpointId1); + } + + void gracefullyTerminateMixerManager(bridge::Mixer* mixer, + bridge::MixerManager& mixerManager, + JobManagerProcessor& backgroundJobManager) + { + auto id = mixer->getId(); + mixerManager.remove(id); + processAllEngineQueue(); // will be removed on engine thread + backgroundJobManager.process(); // and event posted to background thread of MixerManager + + std::thread backgroundProcessor([&backgroundJobManager]() { + for (int i = 0; i < 300; ++i) + { + backgroundJobManager.process(); + if (i % 10 == 0) + { + backgroundJobManager.activatePendingJobs(); + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + }); + + mixerManager.stop(); + backgroundProcessor.join(); + } + + + void SetUp() override + { + _config.sctp.maxMessageSize = 4096; + std::vector audioSsrcs = {1, 2, 3}; + std::vector videoSsrcs; + std::vector videoPinSsrcs; + + auto engineMixer = std::make_unique("test-mixer", + _testScope->jobManagerProcessor.getJobManager(), + _testScope->engineSyncContext, + _testScope->backgroundJobManagerProcessor.getJobManager(), + _testScope->mixerManagerAsyncMock, + 0, + _config, + _testScope->sendPacketAllocator, + _testScope->audioPacketAllocator, + _testScope->mainPacketAllocator, + audioSsrcs, + videoSsrcs, + 0); + + _mixer = std::make_unique("test-mixer", + 1, + *_testScope->transportFactoryMock, + _testScope->backgroundJobManagerProcessor.getJobManager(), + std::move(engineMixer), + _idGenerator, + _ssrcGenerator, + _config, + audioSsrcs, + videoSsrcs, + videoPinSsrcs, + VideoCodecSpec::makeVp8(), + true); + } + + void TearDown() override + { + utils::Function func; + while (_testScope->engineTaskQueue.pop(func)) + { + } + _testScope->jobManagerProcessor.dropAll(); + _testScope->backgroundJobManagerProcessor.dropAll(); + + _mixer.reset(); + _testScope.reset(); + } + + void processAllEngineQueue() + { + utils::Function func; + while (_testScope->engineTaskQueue.pop(func)) + { + func(); + } + } + + config::Config _config; + std::unique_ptr _testScope; + std::unique_ptr _mixer; + utils::IdGenerator _idGenerator; + utils::SsrcGenerator _ssrcGenerator; +}; + +struct SendMessageSizeTestParam +{ + size_t payloadSize; + int expectedSendSctpCalls; +}; + +class DataChannelSendMessageSizeTest : public DataChannelMessageSizeTest, + public WithParamInterface +{ +}; + +TEST_P(DataChannelSendMessageSizeTest, sendLargeEndpointMessage) +{ + const auto param = GetParam(); + auto endpoints = createDataChannelEndpoints(); + + ON_CALL(*endpoints.transport0, getAllocator()).WillByDefault(ReturnRef(_testScope->mainPacketAllocator)); + ON_CALL(*endpoints.transport1, getAllocator()).WillByDefault(ReturnRef(_testScope->mainPacketAllocator)); + + // 1. ARRANGE: setup data channels and message size + connectDataStreams(*_mixer, endpoints); + processAllEngineQueue(); + openDataChannels(*_mixer, endpoints); + std::string largeMessage(param.payloadSize, 'a'); + + // 2. ACT: create a message of size close to or exceeding max and attempt to send it via 'sendEndpointMessage' + utils::StringBuilder<8192> builder; + std::string quotedPayload = "\"" + largeMessage + "\""; + api::DataChannelMessage::makeEndpointMessage(builder, endpoints.endpointId1, endpoints.endpointId0, quotedPayload.c_str()); + auto expectedJson = utils::SimpleJson::create(builder.get(), builder.getLength()); + + const auto payloadJson = api::DataChannelMessageParser::getEndpointMessagePayload(expectedJson); + ASSERT_FALSE(payloadJson.isNone()); + + // 3. ASSERT: if message size is smaller than _config.sctp.maxMessageSize = 4096 send should succeed, otherwise fail + auto& expect = EXPECT_CALL(*endpoints.transport1, sendSctp(_, _, _)) + .Times(param.expectedSendSctpCalls); + if (param.expectedSendSctpCalls > 0) + { + expect.WillOnce(Invoke( + [&](uint16_t streamId, uint32_t protocolId, memory::UniquePoolBuffer buffer) { + char continuousBuffer[buffer->getLength()]; + buffer->copyTo(continuousBuffer, 0, buffer->getLength()); + std::string sentData(continuousBuffer, buffer->getLength()); + + EXPECT_EQ(std::string(expectedJson.jsonBegin(), expectedJson.size()), sentData); + return true; + })); + } + + _mixer->sendEndpointMessage(endpoints.endpointId1, endpoints.endpointId0Hash, payloadJson); + processAllEngineQueue(); +} + +INSTANTIATE_TEST_SUITE_P(DataChannelMessageSize, + DataChannelSendMessageSizeTest, + Values(SendMessageSizeTestParam{4000, 1}, // size of endpoint message < 4096 - send should happen (1 time) + SendMessageSizeTestParam{4096, 0}), // size of endpoint message > 4096 - send should fail (happen 0 times) + [](const testing::TestParamInfo& info) { + if (info.param.expectedSendSctpCalls > 0) + { + return "Succeeds"; + } + return "Fails"; + }); + +struct ForwardMessageSizeTestParam +{ + size_t payloadSize; + int expectedSendSctpCalls; +}; + +class MockEngine : public bridge::Engine +{ +public: + MockEngine(jobmanager::JobManager& backgroundJobQueue) : bridge::Engine(backgroundJobQueue, {}) {} + + MOCK_METHOD(bool, post, (utils::Function && task), (override)); + MOCK_METHOD(concurrency::SynchronizationContext, getSynchronizationContext, (), (override)); +}; + +class DataChannelForwardMessageSizeTest : public DataChannelMessageSizeTest, + public WithParamInterface +{ +}; + +TEST_P(DataChannelForwardMessageSizeTest, forwardLargeEndpointMessage) +{ + const auto param = GetParam(); + + // 1. ARRANGE: setup mixer, engine, data channels and message size + jobmanager::TimerQueue timerQueue(1024); + JobManagerProcessor backgroundJobManager(timerQueue); + + concurrency::MpmcQueue& engineQueue = _testScope->engineTaskQueue; + NiceMock engine(backgroundJobManager.getJobManager()); + ON_CALL(engine, post(_)).WillByDefault(Invoke([&engineQueue](utils::Function&& task) { + return engineQueue.push(std::move(task)); + })); + ON_CALL(engine, getSynchronizationContext()).WillByDefault(Return(concurrency::SynchronizationContext(_testScope->engineTaskQueue))); + + bridge::MixerManager mixerManager(_idGenerator, + _ssrcGenerator, + _testScope->jobManagerProcessor.getJobManager(), + backgroundJobManager.getJobManager(), + *_testScope->transportFactoryMock, + engine, + _config, + _testScope->mainPacketAllocator, + _testScope->sendPacketAllocator, + _testScope->audioPacketAllocator); + auto mixer = mixerManager.create(utils::Optional(5), true, false); + + auto endpoints = createDataChannelEndpoints(); + + ON_CALL(*endpoints.transport0, getAllocator()).WillByDefault(ReturnRef(_testScope->mainPacketAllocator)); + ON_CALL(*endpoints.transport1, getAllocator()).WillByDefault(ReturnRef(_testScope->mainPacketAllocator)); + + ON_CALL(*endpoints.transport0, getTag()).WillByDefault(Return("tag-transport0")); + ON_CALL(*endpoints.transport1, getTag()).WillByDefault(Return("tag-transport1")); + ON_CALL(*endpoints.transport0, getEndpointIdHash()).WillByDefault(Return(endpoints.endpointId0Hash)); + ON_CALL(*endpoints.transport1, getEndpointIdHash()).WillByDefault(Return(endpoints.endpointId1Hash)); + + connectDataStreams(*mixer, endpoints); + + backgroundJobManager.process(); + + openDataChannels(*mixer, endpoints); + + std::string largeMessage(param.payloadSize, 'a'); + + // 2. ACT: create a message of size close to or exceeding max and attempt to FORWARD it via 'onSctpMessage' + // FORWARD: engine mixer receives the in 'onSctpMessage' and later transport sends/forwards it via 'sendSctp' + utils::StringBuilder<8192> builder; + std::string quotedPayload = "\"" + largeMessage + "\""; + api::DataChannelMessage::makeEndpointMessage(builder, endpoints.endpointId1, endpoints.endpointId0, quotedPayload.c_str()); + const auto message = builder.get(); + + // 3. ASSERT: if message size is smaller than _config.sctp.maxMessageSize = 4096 FORWARD should succeed, otherwise fail + auto& expect = EXPECT_CALL(*endpoints.transport1, sendSctp(_, _, _)) + .Times(param.expectedSendSctpCalls); + if (param.expectedSendSctpCalls > 0) + { + expect.WillOnce(Invoke([&](uint16_t streamId, uint32_t protocolId, memory::UniquePoolBuffer buffer) { + char continuousBuffer[buffer->getLength()]; + buffer->copyTo(continuousBuffer, 0, buffer->getLength()); + std::string sentData(continuousBuffer, buffer->getLength()); + + EXPECT_EQ(message, sentData); + return true; + })); + } + + mixer->getEngineMixer()->onSctpMessage(&mixer->getEngineDataStream(endpoints.endpointId0)->transport, + 0, + 0, + webrtc::DataChannelPpid::WEBRTC_STRING, + message, + builder.getLength()); + + backgroundJobManager.process(); + processAllEngineQueue(); + + gracefullyTerminateMixerManager(mixer, mixerManager, backgroundJobManager); +} + +INSTANTIATE_TEST_SUITE_P(DataChannelMessageSize, + DataChannelForwardMessageSizeTest, + Values(ForwardMessageSizeTestParam{4000, 1}, // size of endpoint message < 4096 - send should happen (1 time) + ForwardMessageSizeTestParam{4097, 0}), // size of endpoint message > 4096 - send should fail (happen 0 times) + [](const testing::TestParamInfo& info) { + if (info.param.expectedSendSctpCalls > 0) + { + return "Succeeds"; + } + return "Fails"; + }); + +TEST(DataChannelMessageTest, makeLoggableStringFromBuffer_smallBufferEllipsis) +{ + // Test case for T < 4 where ellipsis is needed + // This targets the potential buffer underflow before the fix. + + memory::PacketPoolAllocator testAllocator(1024, "TestAllocator"); + + // Test with T = 3, payload "abcde" + memory::Array outArray3; + auto payload = memory::makeUniquePoolBuffer(testAllocator, 5); + payload->copyFrom("abcde", 5, 0); + + // Call the function - it should not crash + api::DataChannelMessage::makeLoggableStringFromBuffer(outArray3, payload); + + // Verify it's null-terminated and no crash + ASSERT_EQ(outArray3.size(), 3); + ASSERT_EQ(std::string(outArray3.data()), "ab"); + + // Test with T = 2, payload "abcde" + memory::Array outArray2; + auto payload2 = memory::makeUniquePoolBuffer(testAllocator, 5); + payload2->copyFrom("abcde", 5, 0); + + api::DataChannelMessage::makeLoggableStringFromBuffer(outArray2, payload); + ASSERT_EQ(outArray2.size(), 2); + ASSERT_EQ(std::string(outArray2.data()), "a"); + + // Test with T = 1, payload "abcde" + memory::Array outArray1; + + api::DataChannelMessage::makeLoggableStringFromBuffer(outArray1, payload); + ASSERT_EQ(outArray1.size(), 1); + ASSERT_EQ(std::string(outArray1.data()), ""); +} + +TEST(DataChannelMessageTest, makeLoggableStringFromBuffer_bigBufferEllipsis) +{ + memory::PacketPoolAllocator testAllocator(1024, "TestAllocator"); + std::string payloadString = "abcdefghjklmnopqrstuvwxyz"; + + // Test with T = 10 + memory::Array outArray10; + auto payload = memory::makeUniquePoolBuffer(testAllocator, payloadString.length()); + payload->copyFrom(payloadString.c_str(), payloadString.length(), 0); + + // Call the function - it should not crash + api::DataChannelMessage::makeLoggableStringFromBuffer(outArray10, payload); + + // Verify it's null-terminated and no crash + ASSERT_EQ(outArray10.size(), 10); + + std::string resultStr10(outArray10.data()); + + ASSERT_EQ(resultStr10.length(), 9); + ASSERT_EQ(resultStr10, "abcdef..."); + + // Test with T = 5 + memory::Array outArray5; + + // Call the function - it should not crash + api::DataChannelMessage::makeLoggableStringFromBuffer(outArray5, payload); + + // Verify it's null-terminated and no crash + ASSERT_EQ(outArray5.size(), 5); + + std::string resultStr5(outArray5.data()); + + ASSERT_EQ(resultStr5.length(), 4); + ASSERT_EQ(resultStr5, "a..."); + + // Test with T = 4 + memory::Array outArray4; + + // Call the function - it should not crash + api::DataChannelMessage::makeLoggableStringFromBuffer(outArray4, payload); + + // Verify it's null-terminated and no crash + ASSERT_EQ(outArray4.size(), 4); + + std::string resultStr4(outArray4.data()); + + ASSERT_EQ(resultStr4.length(), 3); + ASSERT_EQ(resultStr4, "..."); +} \ No newline at end of file diff --git a/test/bridge/DummyRtcTransport.h b/test/bridge/DummyRtcTransport.h index e970b9be3..d3de2391f 100644 --- a/test/bridge/DummyRtcTransport.h +++ b/test/bridge/DummyRtcTransport.h @@ -5,7 +5,15 @@ class DummyRtcTransport : public transport::RtcTransport { public: - DummyRtcTransport(jobmanager::JobQueue& jobQueue) : _loggableId(""), _endpointIdHash(1), _jobQueue(jobQueue) {} + DummyRtcTransport(jobmanager::JobQueue& jobQueue, memory::PacketPoolAllocator& allocator) + : _loggableId(""), + _endpointIdHash(1), + _jobQueue(jobQueue), + _allocator(allocator) + { + } + + memory::PacketPoolAllocator& getAllocator() override { return _allocator; } bool isInitialized() const override { return true; } const logger::LoggableId& getLoggableId() const override { return _loggableId; } @@ -93,7 +101,12 @@ class DummyRtcTransport : public transport::RtcTransport void getReportSummary(std::unordered_map& outReportSummary) const override {} - bool sendSctp(uint16_t streamId, uint32_t protocolId, const void* data, uint16_t length) override { return true; } + bool sendSctp(uint16_t streamId, + uint32_t protocolId, + memory::UniquePoolBuffer buffer) override + { + return true; + } uint16_t allocateOutboundSctpStream() override { return 0; } const transport::SocketAddress& getRemotePeer() const override { return _socketAddress; } @@ -117,6 +130,7 @@ class DummyRtcTransport : public transport::RtcTransport logger::LoggableId _loggableId; size_t _endpointIdHash; jobmanager::JobQueue& _jobQueue; + memory::PacketPoolAllocator& _allocator; std::atomic_uint32_t _jobCounter; private: diff --git a/test/bridge/MixerTest.cpp b/test/bridge/MixerTest.cpp index 8c53a01e4..a434cd77c 100644 --- a/test/bridge/MixerTest.cpp +++ b/test/bridge/MixerTest.cpp @@ -167,7 +167,7 @@ struct MixerTestScope timeQueue(64), wtJobManagerProcessor(timeQueue), backgroundJobManagerProcessor(timeQueue), - packetAllocator(4096, "MixerTestPoolAllocator"), + mainPacketAllocator(4096, "MixerTestPoolAllocator"), audioPacketAllocator(4096, "MixerTestAudioPoolAllocator") { } @@ -179,7 +179,7 @@ struct MixerTestScope jobmanager::TimerQueue timeQueue; JobManagerProcessor wtJobManagerProcessor; JobManagerProcessor backgroundJobManagerProcessor; - memory::PacketPoolAllocator packetAllocator; + memory::PacketPoolAllocator mainPacketAllocator; memory::AudioPacketPoolAllocator audioPacketAllocator; }; @@ -207,9 +207,9 @@ class MixerTest : public ::testing::Test _testScope->mixerManagerAsyncMock, LOCAL_VIDEO_SRC, _config, - _testScope->packetAllocator, + _testScope->mainPacketAllocator, _testScope->audioPacketAllocator, - _testScope->packetAllocator, + _testScope->mainPacketAllocator, audioSsrc, videoSsrcs, LAST_N); diff --git a/test/bridge/VideoNackReceiveJobTest.cpp b/test/bridge/VideoNackReceiveJobTest.cpp index 96c0fc9a9..1cfe4c62e 100644 --- a/test/bridge/VideoNackReceiveJobTest.cpp +++ b/test/bridge/VideoNackReceiveJobTest.cpp @@ -26,9 +26,9 @@ class VideoNackReceiveJobTest : public ::testing::Test _timers = std::make_unique(4096 * 8); _jobManager = std::make_unique(*_timers); _jobQueue = std::make_unique(*_jobManager); - _transport = std::make_unique(*_jobQueue); - _allocator = std::make_unique(16, "VideoNackReceiveJobTest"); + _transport = std::make_unique(*_jobQueue, *_allocator); + _mainOutboundContext = std::make_unique(mediaSsrc, *_allocator, VP8_RTP_MAP, bridge::RtpMap::EMPTY); diff --git a/test/include/mocks/EngineMixerSpy.h b/test/include/mocks/EngineMixerSpy.h index afe082547..2b5513cb8 100644 --- a/test/include/mocks/EngineMixerSpy.h +++ b/test/include/mocks/EngineMixerSpy.h @@ -43,6 +43,8 @@ struct EngineMixerSpy : public bridge::EngineMixer // Make IncomingPacketInfo visible using IncomingPacketInfo = bridge::EngineMixer::IncomingPacketInfo; + using IncomingSctpMessagePacketInfo = bridge::EngineMixer::IncomingSctpMessageInfo; + // using parent constructor using bridge::EngineMixer::EngineMixer; @@ -76,7 +78,7 @@ struct EngineMixerSpy : public bridge::EngineMixer } public: - concurrency::MpmcQueue& spyIncomingBarbellSctp() { return _incomingBarbellSctp; }; + concurrency::MpmcQueue& spyIncomingBarbellSctp() { return _incomingBarbellSctp; }; concurrency::MpmcQueue& spyIncomingForwarderAudioRtp() { return _incomingForwarderAudioRtp; }; concurrency::MpmcQueue& spyIncomingRtcp() { return _incomingRtcp; }; concurrency::MpmcQueue& spyIncomingForwarderVideoRtp() { return _incomingForwarderVideoRtp; }; diff --git a/test/include/mocks/MixerManagerAsyncMock.h b/test/include/mocks/MixerManagerAsyncMock.h index 6be8660a1..f97909001 100644 --- a/test/include/mocks/MixerManagerAsyncMock.h +++ b/test/include/mocks/MixerManagerAsyncMock.h @@ -39,7 +39,7 @@ class MixerManagerAsyncMock : public bridge::MixerManagerAsync MOCK_METHOD(void, sctpReceived, - (bridge::EngineMixer & mixer, memory::UniquePacket msgPacket, size_t endpointIdHash), + (bridge::EngineMixer & mixer, memory::UniquePoolBuffer message, size_t endpointIdHash), (override)); MOCK_METHOD(void, diff --git a/test/include/mocks/RtcTransportMock.h b/test/include/mocks/RtcTransportMock.h index be5b4c236..781a1bddc 100644 --- a/test/include/mocks/RtcTransportMock.h +++ b/test/include/mocks/RtcTransportMock.h @@ -87,13 +87,13 @@ class RtcTransportMock : public TransportMock MOCK_METHOD(uint64_t, getLastReceivedPacketTimestamp, (), (const override)); - MOCK_METHOD(bool, - sendSctp, - (uint16_t streamId, uint32_t protocolId, const void* data, uint16_t length), - (override)); + using SctpBuffer = memory::UniquePoolBuffer; + MOCK_METHOD(bool, sendSctp, (uint16_t streamId, uint32_t protocolId, SctpBuffer buffer), (override)); MOCK_METHOD(uint16_t, allocateOutboundSctpStream, (), (override)); + MOCK_METHOD(memory::PacketPoolAllocator&, getAllocator, (), (override)); + MOCK_METHOD(void, getSdesKeys, (std::vector & sdesKeys), (const override)); MOCK_METHOD(void, asyncSetRemoteSdesKey, (const srtp::AesKey& key), (override)); }; diff --git a/test/integration/emulator/SfuClient.h b/test/integration/emulator/SfuClient.h index 53983e8db..5fe699977 100644 --- a/test/integration/emulator/SfuClient.h +++ b/test/integration/emulator/SfuClient.h @@ -657,7 +657,8 @@ class SfuClient : public transport::DataReceiver const void* data, size_t length) override { - _dataStream->onSctpMessage(sender, streamId, streamSequenceNumber, payloadProtocol, data, length); + auto buffer = webrtc::makeUniqueSctpMessage(streamId, payloadProtocol, data, length, _allocator); + _dataStream->onSctpMessageBuffer(sender, buffer); } void onRecControlReceived(transport::RecordingTransport* sender, diff --git a/test/memory/PoolBufferTest.cpp b/test/memory/PoolBufferTest.cpp new file mode 100644 index 000000000..8e2fa3f4b --- /dev/null +++ b/test/memory/PoolBufferTest.cpp @@ -0,0 +1,492 @@ +#include "memory/PoolBuffer.h" +#include "memory/PoolAllocator.h" +#include "memory/Array.h" +#include "test/macros.h" +#include +#include +#include + +TEST(PoolBuffer, create) +{ + memory::PoolAllocator<128> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + EXPECT_TRUE(buffer.empty()); + EXPECT_EQ(buffer.size(), 0); + EXPECT_EQ(buffer.capacity(), 0); +} + +TEST(PoolBuffer, allocateSmall) +{ + memory::PoolAllocator<128> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + EXPECT_TRUE(buffer.allocate(64)); + EXPECT_FALSE(buffer.empty()); + EXPECT_EQ(buffer.size(), 64); + + // We expect 64 bytes totall fit into master chunk + EXPECT_EQ(buffer.capacity(), 128 - sizeof(void*)); +#if ENABLE_ALLOCATOR_METRICS + EXPECT_EQ(allocator.countAllocatedItems(), 1); +#endif +} + +TEST(PoolBuffer, allocateExact) +{ + // Leave space for single pointer to chunk[0] in the master chunk - + // all data than should fit into the single buffer. + memory::PoolAllocator<128 + sizeof(void*)> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + EXPECT_TRUE(buffer.allocate(128)); + EXPECT_EQ(buffer.size(), 128); + EXPECT_EQ(buffer.capacity(), 128); +#if ENABLE_ALLOCATOR_METRICS + EXPECT_EQ(allocator.countAllocatedItems(), 1); +#endif +} + +TEST(PoolBuffer, allocateMultipleChunks) +{ + memory::PoolAllocator<128> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + EXPECT_TRUE(buffer.allocate(300)); + EXPECT_EQ(buffer.size(), 300); + // 3 chunks needed: 2 full + 1 in the master chunk, slightly smaller + EXPECT_EQ(buffer.capacity(), (3 - 1) * 128 + (128 - sizeof(void*) * 3)); +#if ENABLE_ALLOCATOR_METRICS + EXPECT_EQ(allocator.countAllocatedItems(), 3); +#endif +} + +TEST(PoolBuffer, allocateFail) +{ + memory::PoolAllocator<128> allocator(2, "test"); + const auto actualElementCount = allocator.size(); + const size_t sizeToRequest = actualElementCount * 128 + 1; + + memory::PoolBuffer buffer(allocator); + + EXPECT_FALSE(buffer.allocate(sizeToRequest)); + EXPECT_TRUE(buffer.empty()); + EXPECT_EQ(buffer.size(), 0); + EXPECT_EQ(buffer.capacity(), 0); +#if ENABLE_ALLOCATOR_METRICS + EXPECT_EQ(allocator.countAllocatedItems(), 0); +#endif +} + +TEST(PoolBuffer, writeAndRead) +{ + memory::PoolAllocator<128> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + const size_t dataSize = 300; + EXPECT_TRUE(buffer.allocate(dataSize)); + + std::vector sourceData(dataSize); + for (size_t i = 0; i < dataSize; ++i) + { + sourceData[i] = static_cast(i); + } + + EXPECT_EQ(buffer.copyFrom(sourceData.data(), sourceData.size()), dataSize); + + char destinationData[dataSize]; + EXPECT_EQ(buffer.copyTo(destinationData, 0 , dataSize), dataSize); + + for (size_t i = 0; i < dataSize; ++i) + { + EXPECT_EQ(sourceData[i], static_cast(destinationData[i])); + } +} + +TEST(PoolBuffer, writeAndReadWithOffset) +{ + memory::PoolAllocator<128> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + const size_t bufferSize = 400; + EXPECT_TRUE(buffer.allocate(bufferSize)); + + std::vector sourceData(150); + for (size_t i = 0; i < sourceData.size(); ++i) + { + sourceData[i] = static_cast(i); + } + + const size_t writeOffset = 130; // Cross chunk boundary + EXPECT_EQ(buffer.copyFrom(sourceData.data(), sourceData.size(), writeOffset), sourceData.size()); + + char readData[150]; + EXPECT_EQ(buffer.copyTo(readData, writeOffset, sizeof(readData)), sizeof(readData)); + + for (size_t i = 0; i < sourceData.size(); ++i) + { + EXPECT_EQ(sourceData[i], static_cast(readData[i])); + } +} + +TEST(PoolBuffer, move) +{ +// This test use allocator.countAllocatedItems extensively. +#if !ENABLE_ALLOCATOR_METRICS + GTEST_SKIP(); +#endif + memory::PoolAllocator<128> allocator(10, "test"); + memory::PoolBuffer buffer1(allocator); + EXPECT_TRUE(buffer1.allocate(300)); + EXPECT_EQ(allocator.countAllocatedItems(), 3); + + memory::PoolBuffer buffer2 = std::move(buffer1); + EXPECT_EQ(buffer2.size(), 300); + EXPECT_GE(buffer2.capacity(), 300); + EXPECT_EQ(allocator.countAllocatedItems(), 3); + EXPECT_EQ(buffer1.size(), 0); // NOLINT + + buffer2.clear(); + EXPECT_EQ(allocator.countAllocatedItems(), 0); +} + +TEST(PoolBuffer, isNullTerminated) +{ + memory::PoolAllocator<128 + sizeof(void*)> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + // Empty buffer + EXPECT_TRUE(buffer.allocate(0)); + EXPECT_FALSE(buffer.isNullTerminated()); + buffer.clear(); + EXPECT_FALSE(buffer.isNullTerminated()); + + // Non-null terminated + const std::string s1 = "123456789a"; + EXPECT_TRUE(buffer.allocate(s1.length())); + buffer.copyFrom(s1.c_str(), s1.length()); + EXPECT_FALSE(buffer.isNullTerminated()); + + // Null terminated + const std::string s2 = "123456789"; + EXPECT_TRUE(buffer.allocate(s2.length() + 1)); + buffer.copyFrom(s2.c_str(), s2.length() + 1); + EXPECT_TRUE(buffer.isNullTerminated()); + + // Null at end of chunk + std::vector testData3(128, 'a'); + testData3[127] = '\0'; + EXPECT_TRUE(buffer.allocate(128)); + EXPECT_FALSE(buffer.isMultiChunk()); + buffer.copyFrom(testData3.data(), 128); + EXPECT_TRUE(buffer.isNullTerminated()); + + // Subview from larger buffer + char testData4[] = {'1', '2', '3', '4', '5', '\0', '6', '7', '8', '\0', 'A'}; + EXPECT_TRUE(buffer.allocate(sizeof(testData4))); + buffer.copyFrom(testData4, sizeof(testData4)); + EXPECT_FALSE(buffer.isNullTerminated()); +} + +TEST(PoolBuffer, deleter) +{ +// This test use allocator.countAllocatedItems extensively. +#if !ENABLE_ALLOCATOR_METRICS + GTEST_SKIP(); +#endif + memory::PoolAllocator<128> allocator(5, "test"); + EXPECT_EQ(allocator.countAllocatedItems(), 0); + + { + auto buffer = memory::makeUniquePoolBuffer(allocator, 20); + EXPECT_TRUE(buffer); + + // PoolBuffer aligned, (48 bytes), + // Master chunk contianing single pointer to data chunk 8 bytes, + // datachunk payload itself (20 bytes) = 76. + // All fits the single allocator's buffer of 128 bytes. + EXPECT_EQ(allocator.countAllocatedItems(), 1); + } + EXPECT_EQ(allocator.countAllocatedItems(), 0); + + { + auto buffer = memory::makeUniquePoolBuffer(allocator, 3 * 128); + EXPECT_TRUE(buffer); + EXPECT_EQ(allocator.countAllocatedItems(), 3 + 1); // 3 chunks of 128 bytes + 1 'master chunk' and PoolBuffer + } + + EXPECT_EQ(allocator.countAllocatedItems(), 0); + + auto buffer2 = memory::makeUniquePoolBuffer(allocator, 3 * 128); + EXPECT_TRUE(buffer2); + EXPECT_EQ(allocator.countAllocatedItems(), 3 + 1); // 3 chunks of 128 bytes + 1 'master chunk' and PoolBuffer +} + +TEST(PoolBuffer, copyToAndNullTermination) +{ + memory::PoolAllocator<128> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + // Empty buffer + { + EXPECT_TRUE(buffer.allocate(0)); + EXPECT_FALSE(buffer.isNullTerminated()); + EXPECT_FALSE(buffer.isMultiChunk()); + EXPECT_EQ(buffer.getLength(), 0); + } + + // Single chunk, not null-terminated + { + const std::string testData = "single chunk test"; + EXPECT_TRUE(buffer.allocate(testData.length())); + buffer.copyFrom(testData.c_str(), testData.length()); + EXPECT_FALSE(buffer.isNullTerminated()); + + char readData[buffer.getLength()]; + EXPECT_EQ(buffer.getLength(), testData.length()); + EXPECT_EQ(buffer.copyTo(readData, 0, testData.length()), buffer.getLength()); + EXPECT_EQ(std::memcmp(readData, testData.c_str(), testData.length()), 0); + } + + // Single chunk, null-terminated + { + const std::string testData = "single chunk nullterm"; + EXPECT_TRUE(buffer.allocate(testData.length() + 1)); + buffer.copyFrom(testData.c_str(), testData.length() + 1); + EXPECT_TRUE(buffer.isNullTerminated()); + + char readData[buffer.getLength()]; + EXPECT_EQ(buffer.getLength(), testData.length() + 1); + EXPECT_EQ(buffer.copyTo(readData, 0, testData.length() + 1), buffer.getLength()); + EXPECT_EQ(std::string(readData), testData); + } + + // Multi-chunk, not null-terminated + { + std::vector testData(150, 'm'); + EXPECT_TRUE(buffer.allocate(testData.size())); + buffer.copyFrom(testData.data(), testData.size()); + EXPECT_FALSE(buffer.isNullTerminated()); + + char readData[buffer.getLength()]; + EXPECT_EQ(buffer.getLength(), testData.size()); + EXPECT_EQ(buffer.copyTo(readData, 0, testData.size()), buffer.getLength()); + EXPECT_EQ(std::memcmp(readData, testData.data(), testData.size()), 0); + } + + // Multi-chunk, null-terminated + { + std::vector testData(15, 'n'); + testData.back() = '\0'; + EXPECT_TRUE(buffer.allocate(testData.size())); + buffer.copyFrom(testData.data(), testData.size()); + EXPECT_TRUE(buffer.isNullTerminated()); + + char readData[buffer.getLength()]; + EXPECT_EQ(buffer.getLength(), testData.size()); + EXPECT_EQ(buffer.copyTo(readData, 0, testData.size()), buffer.getLength()); + EXPECT_EQ(std::memcmp(readData, testData.data(), testData.size()), 0); + } +} + +TEST(PoolBuffer, copy) +{ + memory::PoolAllocator<128> allocator(10, "test"); + memory::PoolBuffer buffer(allocator); + + const size_t dataSize = 300; + EXPECT_TRUE(buffer.allocate(dataSize)); + + std::vector sourceData(dataSize); + for (size_t i = 0; i < dataSize; ++i) + { + sourceData[i] = static_cast(i); + } + EXPECT_EQ(buffer.copyFrom(sourceData.data(), sourceData.size()), dataSize); + + // Test copying the full buffer + std::vector destData(dataSize); + EXPECT_EQ(buffer.copyTo(destData.data(), 0, dataSize), dataSize); + EXPECT_EQ(sourceData, destData); + + // Test copying a portion from the beginning + std::fill(destData.begin(), destData.end(), 0); + EXPECT_EQ(buffer.copyTo(destData.data(), 0, 100), 100); + EXPECT_TRUE(std::equal(sourceData.begin(), sourceData.begin() + 100, destData.begin())); + + // Test copying a portion from the middle, crossing a chunk boundary + std::fill(destData.begin(), destData.end(), 0); + const size_t copyOffset = 100; + const size_t copySize = 150; + EXPECT_EQ(buffer.copyTo(destData.data(), copyOffset, copySize), copySize); + EXPECT_TRUE(std::equal(sourceData.begin() + copyOffset, sourceData.begin() + copyOffset + copySize, destData.begin())); + + // Test copying with a count that goes over the end + std::fill(destData.begin(), destData.end(), 0); + EXPECT_EQ(buffer.copyTo(destData.data(), 200, 200), 100); + EXPECT_TRUE(std::equal(sourceData.begin() + 200, sourceData.end(), destData.begin())); + + // Test copying with an offset that is out of bounds + EXPECT_EQ(buffer.copyTo(destData.data(), dataSize, 1), 0); + EXPECT_EQ(buffer.copyTo(destData.data(), dataSize + 1, 1), 0); + + // Test copying to a nullptr destination + EXPECT_EQ(buffer.copyTo(nullptr, 0, 1), 0); + + // Test copying 0 bytes + EXPECT_EQ(buffer.copyTo(destData.data(), 0, 0), 0); + + // Test copying from an empty buffer + memory::PoolBuffer emptyBuffer(allocator); + EXPECT_TRUE(emptyBuffer.allocate(0)); + EXPECT_EQ(emptyBuffer.copyTo(destData.data(), 0, 1), 0); +} + +TEST(PoolBuffer, writeFromPoolBuffer) +{ + // Use different chunk sizes for src and dest to test more complex scenarios + memory::PoolAllocator<100> srcAllocator(20, "srcTest"); + memory::PoolAllocator<128> destAllocator(20, "destTest"); + + auto srcBuffer = memory::makeUniquePoolBuffer(srcAllocator, 1000); + ASSERT_TRUE(srcBuffer); + auto destBuffer = memory::makeUniquePoolBuffer(destAllocator, 1000); + ASSERT_TRUE(destBuffer); + + // Fill srcBuffer with some data + std::vector sourceData(1000); + for (size_t i = 0; i < sourceData.size(); ++i) + { + sourceData[i] = static_cast(i % 256); + } + srcBuffer->copyFrom(sourceData.data(), sourceData.size()); + + // Case 1: Simple full copy + { + SCOPED_TRACE("Case 1: Simple full copy"); + std::vector zeros(1000, 0); + destBuffer->copyFrom(zeros.data(), zeros.size()); + + EXPECT_EQ(destBuffer->copyFrom(*srcBuffer, 0, srcBuffer->size(), 0), srcBuffer->size()); + + std::vector destData(1000); + destBuffer->copyTo(destData.data(), 0, destData.size()); + EXPECT_EQ(sourceData, destData); + } + + // Case 2: Copy with source and destination offsets, crossing chunk boundaries + { + SCOPED_TRACE("Case 2: Offsets and boundary crossing"); + std::vector zeros(1000, 0); + destBuffer->copyFrom(zeros.data(), zeros.size()); + + const size_t srcOffset = 50; // start in 1st chunk of src + const size_t destOffset = 150; // start in 2nd chunk of dest + const size_t len = 300; // will cross boundaries for both + + size_t bytesWritten = destBuffer->copyFrom(*srcBuffer, srcOffset, len, destOffset); + EXPECT_EQ(bytesWritten, len); + + // Check area before write + std::vector beforeData(destOffset); + destBuffer->copyTo(beforeData.data(), 0, destOffset); + EXPECT_TRUE(std::all_of(beforeData.begin(), beforeData.end(), [](uint8_t i) { return i == 0; })); + + // Check written data + std::vector writtenData(len); + destBuffer->copyTo(writtenData.data(), destOffset, len); + EXPECT_TRUE(std::equal(sourceData.begin() + srcOffset, sourceData.begin() + srcOffset + len, writtenData.begin())); + + // Check area after write + const size_t afterOffset = destOffset + len; + if (destBuffer->size() > afterOffset) + { + const size_t afterLen = destBuffer->size() - afterOffset; + std::vector afterData(afterLen); + destBuffer->copyTo(afterData.data(), afterOffset, afterLen); + EXPECT_TRUE(std::all_of(afterData.begin(), afterData.end(), [](uint8_t i) { return i == 0; })); + } + } + + // Case 3: Partial copy, len is smaller than available data + { + SCOPED_TRACE("Case 3: Partial copy"); + std::vector zeros(1000, 0); + destBuffer->copyFrom(zeros.data(), zeros.size()); + const size_t srcOffset = 10; + const size_t destOffset = 20; + const size_t len = 50; + + size_t bytesWritten = destBuffer->copyFrom(*srcBuffer, srcOffset, len, destOffset); + EXPECT_EQ(bytesWritten, len); + + std::vector destData(len); + destBuffer->copyTo(destData.data(), destOffset, len); + EXPECT_TRUE(std::equal(sourceData.begin() + srcOffset, sourceData.begin() + srcOffset + len, destData.begin())); + } + + // Case 4: len is larger than available in src + { + SCOPED_TRACE("Case 4: Read past source boundary"); + const size_t srcOffset = 900; + const size_t destOffset = 0; + const size_t len = 200; // only 100 bytes available in src + const size_t expectedWrite = 100; + + size_t bytesWritten = destBuffer->copyFrom(*srcBuffer, srcOffset, len, destOffset); + EXPECT_EQ(bytesWritten, expectedWrite); + + std::vector destData(expectedWrite); + destBuffer->copyTo(destData.data(), destOffset, expectedWrite); + EXPECT_TRUE(std::equal(sourceData.begin() + srcOffset, sourceData.end(), destData.begin())); + } + + // Case 5: len is larger than available in dest + { + SCOPED_TRACE("Case 5: Write past destination boundary"); + const size_t srcOffset = 0; + const size_t destOffset = 950; + const size_t len = 100; // only 50 bytes available in dest + const size_t expectedWrite = 50; + + size_t bytesWritten = destBuffer->copyFrom(*srcBuffer, srcOffset, len, destOffset); + EXPECT_EQ(bytesWritten, expectedWrite); + + std::vector destData(expectedWrite); + destBuffer->copyTo(destData.data(), destOffset, expectedWrite); + EXPECT_TRUE( + std::equal(sourceData.begin() + srcOffset, sourceData.begin() + srcOffset + expectedWrite, destData.begin())); + } + + // Case 6: Zero-length copy + { + SCOPED_TRACE("Case 6: Zero-length copy"); + EXPECT_EQ(destBuffer->copyFrom(*srcBuffer, 10, 0, 10), 0); + } + + // Case 7: Out-of-bounds offsets + { + SCOPED_TRACE("Case 7: Out-of-bounds offsets"); + EXPECT_EQ(destBuffer->copyFrom(*srcBuffer, srcBuffer->size(), 1, 0), 0); + EXPECT_EQ(destBuffer->copyFrom(*srcBuffer, 0, 1, destBuffer->size()), 0); + EXPECT_EQ(destBuffer->copyFrom(*srcBuffer, srcBuffer->size() + 1, 1, 0), 0); + EXPECT_EQ(destBuffer->copyFrom(*srcBuffer, 0, 1, destBuffer->size() + 1), 0); + } + + // Case 8: Copy from empty source buffer + { + SCOPED_TRACE("Case 8: Empty source"); + auto emptySrc = memory::makeUniquePoolBuffer(srcAllocator, 0); + ASSERT_TRUE(emptySrc); + EXPECT_EQ(destBuffer->copyFrom(*emptySrc, 0, 1, 0), 0); + } + + // Case 9: Copy to empty dest buffer + { + SCOPED_TRACE("Case 9: Empty destination"); + auto emptyDest = memory::makeUniquePoolBuffer(destAllocator, 0); + ASSERT_TRUE(emptyDest); + EXPECT_EQ(emptyDest->copyFrom(*srcBuffer, 0, 1, 0), 0); + } +} + diff --git a/test/sctp/SctpTransferTests.cpp b/test/sctp/SctpTransferTests.cpp index 68a833b3b..bae273d5d 100644 --- a/test/sctp/SctpTransferTests.cpp +++ b/test/sctp/SctpTransferTests.cpp @@ -26,7 +26,7 @@ namespace struct SctpTransferTestFixture : public ::testing::Test { - SctpTransferTestFixture() : _timestamp(utils::Time::getAbsoluteTime()) {} + SctpTransferTestFixture() : _timestamp(utils::Time::getAbsoluteTime()), _mainPacketAllocator(500 * 1024, "main-allocator") {} void establishConnection(SctpEndpoint& A, SctpEndpoint& B) { @@ -59,9 +59,10 @@ struct SctpTransferTestFixture : public ::testing::Test } uint64_t _timestamp; sctp::SctpConfig _config; + memory::PacketPoolAllocator _mainPacketAllocator; }; -TEST_F(SctpTransferTestFixture, send500K) +TEST_F(SctpTransferTestFixture, send250K) { using namespace sctptest; SctpEndpoint A(5000, _config, _timestamp, 250); @@ -69,15 +70,14 @@ TEST_F(SctpTransferTestFixture, send500K) establishConnection(A, B); - const int DATA_SIZE = 450 * 1024; + const int DATA_SIZE = 250 * 1024; std::array data; std::memset(data.data(), 0xAA, DATA_SIZE); const auto startTime = _timestamp; - A._session->sendMessage(A.getStreamId(), - webrtc::DataChannelPpid::WEBRTC_BINARY, - data.data(), - DATA_SIZE, - _timestamp); + + auto buffer = memory::makeUniquePoolBuffer( _mainPacketAllocator, data.data(), DATA_SIZE); + + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_BINARY, buffer, _timestamp); for (int i = 0; i < 280000 && B.getReceivedSize() < DATA_SIZE; ++i) { _timestamp += 1 * utils::Time::ms; @@ -113,10 +113,12 @@ TEST_F(SctpTransferTestFixture, sendMany) auto toSend = std::max(15, rand() % data.size()); totalSent += toSend; ++totalSentCount; + + auto buffer = memory::makeUniquePoolBuffer(_mainPacketAllocator, data.data(), toSend); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_BINARY, - data.data(), - toSend, + buffer, _timestamp); } _timestamp += 1 * utils::Time::ms; @@ -149,14 +151,16 @@ TEST_F(SctpTransferTestFixture, withLoss20) A._sendQueue.setLossRate(0.01); establishConnection(A, B); - const int DATA_SIZE = 450 * 1024; + const int DATA_SIZE = 250 * 1024; std::array data; std::memset(data.data(), 0xcd, DATA_SIZE); const auto startTime = _timestamp; + + auto buffer = memory::makeUniquePoolBuffer(_mainPacketAllocator, data.data(), DATA_SIZE); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_BINARY, - data.data(), - DATA_SIZE, + buffer, _timestamp); const auto endTime = _timestamp + 4 * utils::Time::sec; while (static_cast(endTime - _timestamp) > 0 && B.getReceivedSize() == 0) @@ -207,14 +211,16 @@ TEST_F(SctpTransferTestFixture, lostSacks) establishConnection(A, B); - const int DATA_SIZE = 450 * 1024; + const int DATA_SIZE = 250 * 1024; std::array data; std::memset(data.data(), 0xcd, DATA_SIZE); // const auto startTime = _timestamp; + + auto buffer = memory::makeUniquePoolBuffer(_mainPacketAllocator, data.data(), DATA_SIZE); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_BINARY, - data.data(), - DATA_SIZE, + buffer, _timestamp); for (int i = 0; i < 10000; ++i) { @@ -245,14 +251,15 @@ TEST_F(SctpTransferTestFixture, zeroRecvWindow) establishConnection(A, B); - const int DATA_SIZE = 450 * 1024; + const int DATA_SIZE = 250 * 1024; std::array data; std::memset(data.data(), 0xcd, DATA_SIZE); + auto buffer = memory::makeUniquePoolBuffer(_mainPacketAllocator, data.data(), DATA_SIZE); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_BINARY, - data.data(), - DATA_SIZE, + buffer, _timestamp); for (int i = 0; i < 10000; ++i) { @@ -296,8 +303,11 @@ TEST_F(SctpTransferTestFixture, sendEmptyMessage) const auto startTime = _timestamp; size_t totalSent = 0; size_t totalSentCount = 0; + + auto buffer = memory::makeUniquePoolBuffer(_mainPacketAllocator, data.data(), 0); + const bool sendResult = - A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_STRING, data.data(), 0, _timestamp); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_STRING, buffer, _timestamp); EXPECT_TRUE(sendResult); for (int i = 0; i < 30000 && B.getReceivedSize() < DATA_SIZE; ++i) @@ -329,8 +339,10 @@ TEST_F(SctpTransferTestFixture, sendMtuMessage) std::array data; std::memset(data.data(), 0xdd, data.size()); + auto buffer0 = memory::makeUniquePoolBuffer(_mainPacketAllocator, data.data(), mtu); + const bool sendResult = - A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_STRING, data.data(), mtu, _timestamp); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_STRING, buffer0, _timestamp); EXPECT_TRUE(sendResult); for (int i = 0; i < 3000 && B.getReceivedSize() < mtu; ++i) @@ -346,10 +358,11 @@ TEST_F(SctpTransferTestFixture, sendMtuMessage) EXPECT_EQ(B.getReceivedSize(), mtu); EXPECT_EQ(B.getReceivedMessageCount(), 1); + auto buffer1 = memory::makeUniquePoolBuffer(_mainPacketAllocator, data.data(), mtu * 2); + EXPECT_TRUE(A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_STRING, - data.data(), - mtu * 2, + buffer1, _timestamp)); for (int i = 0; i < 3000 && B.getReceivedSize() < DATA_SIZE; ++i) @@ -373,23 +386,26 @@ class DummySctpTransport : public webrtc::DataStreamTransport { uint16_t streamId = 0; uint32_t protocol = 0; - char data[256]; - uint16_t length = 0; + memory::UniquePoolBuffer buffer; }; - DummySctpTransport() : _loggableId("dummy") {} + DummySctpTransport(memory::PacketPoolAllocator& allocator) : _loggableId("dummy"), _allocator(allocator) {} - bool sendSctp(uint16_t streamId, uint32_t protocolId, const void* data, uint16_t length) override + bool sendSctp(uint16_t streamId, + uint32_t protocolId, + memory::UniquePoolBuffer buffer) override { - SctpInfo info{streamId, protocolId}; - info.length = length; - std::strncpy(info.data, reinterpret_cast(data), length); - _sendQueue.push(info); + SctpInfo info; + info.streamId = streamId; + info.protocol = protocolId; + info.buffer = std::move(buffer); + _sendQueue.push(std::move(info)); return true; } uint16_t allocateOutboundSctpStream() override { return 0; } + memory::PacketPoolAllocator& getAllocator() override { return _allocator; } logger::LoggableId _loggableId; - + memory::PacketPoolAllocator& _allocator; std::queue _sendQueue; }; @@ -400,7 +416,7 @@ TEST_F(SctpTransferTestFixture, sctpReorder) SctpEndpoint B(5001, _config, _timestamp, 250); establishConnection(A, B); - DummySctpTransport fakeTransport; + DummySctpTransport fakeTransport(_mainPacketAllocator); webrtc::WebRtcDataStream dataStream(2, fakeTransport); alignas(memory::Packet) const char webRtcOpen[] = "\x03\x00\x00\x00\x00\x00\x00\x00\x00\x12\x00\x00\x77\x65\x62\x72" @@ -422,20 +438,28 @@ TEST_F(SctpTransferTestFixture, sctpReorder) "\x34\x64\x32\x38\x2d\x62\x65\x36\x35\x2d\x39\x30\x33\x65\x39\x38" "\x63\x62\x37\x65\x66\x32\x22\x5d\x7d\x00\x00\x00"; + auto buffer0 = memory::makeUniquePoolBuffer(_mainPacketAllocator, sizeof(webRtcOpen)); + buffer0->copyFrom(webRtcOpen, sizeof(webRtcOpen) - 1, 0); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_ESTABLISH, - webRtcOpen, - sizeof(webRtcOpen) - 1, + buffer0, _timestamp); + + auto buffer1 = memory::makeUniquePoolBuffer(_mainPacketAllocator, sizeof(msg1)); + buffer1->copyFrom(msg1, sizeof(msg1) - 1, 0); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_STRING, - msg1, - sizeof(msg1) - 1, + buffer1, _timestamp); + + auto buffer2 = memory::makeUniquePoolBuffer(_mainPacketAllocator, sizeof(msg2)); + buffer2->copyFrom(msg2, sizeof(msg2) - 1, 0); + A._session->sendMessage(A.getStreamId(), webrtc::DataChannelPpid::WEBRTC_STRING, - msg2, - sizeof(msg2) - 1, + buffer2, _timestamp); A.process(); auto sctpOpen = A._sendQueue.pop(); @@ -461,10 +485,21 @@ TEST_F(SctpTransferTestFixture, sctpReorder) _timestamp += B.process(); auto sack3 = B._sendQueue.pop(); - dataStream.onSctpMessage(&fakeTransport, 0, 55, 0x32, sctpOpen->get() + 28, sctpOpen->getLength() - 28); - auto openAckMsg = fakeTransport._sendQueue.front(); + auto buffer = memory::makeUniquePoolBuffer(this->_mainPacketAllocator, sctpOpen->getLength() - 28); + webrtc::SctpStreamMessageHeader header = { + .payloadProtocol = 0x32, + .id = 0, + .sequenceNumber = 55, + }; + + buffer->copyFrom(&header, sizeof(webrtc::SctpStreamMessageHeader), 0); + buffer->copyFrom(sctpOpen->get() + 28, sctpOpen->getLength() - 28, sizeof(webrtc::SctpStreamMessageHeader)); + + dataStream.onSctpMessageBuffer(&fakeTransport, buffer); + auto openAckMsg = std::move(fakeTransport._sendQueue.front()); fakeTransport._sendQueue.pop(); - B._session->sendMessage(A.getStreamId(), openAckMsg.protocol, openAckMsg.data, openAckMsg.length, _timestamp); + + B._session->sendMessage(A.getStreamId(), openAckMsg.protocol, openAckMsg.buffer, _timestamp); B.process(); auto openAck = B._sendQueue.pop(); EXPECT_NE(sack3, nullptr); diff --git a/test/transport/SctpIntegrationTest.cpp b/test/transport/SctpIntegrationTest.cpp index 7f4dd05d9..6ea30cb82 100644 --- a/test/transport/SctpIntegrationTest.cpp +++ b/test/transport/SctpIntegrationTest.cpp @@ -69,13 +69,15 @@ struct ClientPair : public TransportClientPair { if (payloadProtocol != webrtc::DataChannelPpid::WEBRTC_STRING) { + auto buffer = memory::makeUniquePoolBuffer(_sendAllocator, length); + buffer->copyFrom(data, length, 0); if (sender == _transport1.get()) { - _stream1.onSctpMessage(sender, streamId, streamSequenceNumber, payloadProtocol, data, length); + _stream1.onSctpMessageBuffer(sender, buffer); } else { - _stream2.onSctpMessage(sender, streamId, streamSequenceNumber, payloadProtocol, data, length); + _stream2.onSctpMessageBuffer(sender, buffer); } } diff --git a/test/transport/SctpTest.cpp b/test/transport/SctpTest.cpp index c8e0db933..11cb6c8ca 100644 --- a/test/transport/SctpTest.cpp +++ b/test/transport/SctpTest.cpp @@ -15,6 +15,7 @@ #include "utils/Pacer.h" #include "utils/Time.h" #include "webrtc/DataChannel.h" +#include "webrtc/WebRtcDataStream.h" #include #include #include @@ -121,10 +122,9 @@ struct ClientPair : public transport::DataReceiver { ++_messagesSent; _messageBytesSent += std::strlen(theMessage); - _transport1->sendSctp(_streamId, - webrtc::DataChannelPpid::WEBRTC_STRING, - theMessage, - std::strlen(theMessage)); + + auto buffer = memory::makeUniquePoolBuffer(_sendAllocator, theMessage, std::strlen(theMessage)); + _transport1->sendSctp(_streamId,webrtc::DataChannelPpid::WEBRTC_STRING, std::move(buffer)); } for (int j = 0; j < 1; ++j) diff --git a/transport/TransportImpl.cpp b/transport/TransportImpl.cpp index 079de35d2..28749e294 100644 --- a/transport/TransportImpl.cpp +++ b/transport/TransportImpl.cpp @@ -1,4 +1,5 @@ #include "TransportImpl.h" +#include "transport/sctp/SctpConfig.h" #include "api/utils.h" #include "bwe/BandwidthEstimator.h" #include "config/Config.h" @@ -10,6 +11,7 @@ #include "logger/Logger.h" #include "logger/PacketLogger.h" #include "memory/AudioPacketPoolAllocator.h" +#include "memory/PoolBuffer.h" #include "rtp/RtcpFeedback.h" #include "rtp/RtpHeader.h" #include "sctp/SctpAssociation.h" @@ -164,64 +166,41 @@ class DtlsSetRemoteJob : public jobmanager::CountedJob class SctpSendJob : public jobmanager::CountedJob { - struct SctpDataChunk - { - uint32_t payloadProtocol; - uint16_t id; - - void* data() { return &id + 1; } - const void* data() const { return &id + 1; }; - }; - public: SctpSendJob(sctp::SctpAssociation& association, uint16_t streamId, uint32_t protocolId, - const void* data, - uint16_t length, - memory::PacketPoolAllocator& allocator, + memory::UniquePoolBuffer message, jobmanager::JobQueue& jobQueue, TransportImpl& transport) : CountedJob(transport.getJobCounter()), _jobQueue(jobQueue), _sctpAssociation(association), - _packet(memory::makeUniquePacket(allocator)), + _streamId(streamId), + _protocolId(protocolId), + _message(std::move(message)), _transport(transport) { - if (_packet) - { - if (sizeof(SctpDataChunk) + length > memory::Packet::size) - { - logger::error("sctp message too big %u", _transport.getLoggableId().c_str(), length); - return; - } - - auto* header = reinterpret_cast(_packet->get()); - header->id = streamId; - header->payloadProtocol = protocolId; - std::memcpy(header->data(), data, length); - _packet->setLength(sizeof(SctpDataChunk) + length); - } - else + if (!_message) { logger::error("failed to create packet for outbound sctp", transport.getLoggableId().c_str()); + return; } } void run() override { - if (!_packet) + if (!_message) { return; } auto timestamp = utils::Time::getAbsoluteTime(); auto currentTimeout = _sctpAssociation.nextTimeout(timestamp); - auto& header = *reinterpret_cast(_packet->get()); - if (!_sctpAssociation.sendMessage(header.id, - header.payloadProtocol, - header.data(), - _packet->getLength() - sizeof(header), + + if (!_sctpAssociation.sendMessage(_streamId, + _protocolId, + _message, timestamp)) { if (_transport.isConnected()) @@ -238,7 +217,7 @@ class SctpSendJob : public jobmanager::CountedJob logger::info("SCTP message sent too soon. sctp state %s, %zuB", _transport.getLoggableId().c_str(), toString(_sctpAssociation.getState()), - _packet->getLength()); + _message->size()); } } @@ -256,7 +235,9 @@ class SctpSendJob : public jobmanager::CountedJob private: jobmanager::JobQueue& _jobQueue; sctp::SctpAssociation& _sctpAssociation; - memory::UniquePacket _packet; + const uint16_t _streamId; + const uint32_t _protocolId; + memory::UniquePoolBuffer _message; TransportImpl& _transport; }; @@ -2231,25 +2212,23 @@ uint16_t TransportImpl::allocateOutboundSctpStream() return 0xFFFFu; } -bool TransportImpl::sendSctp(const uint16_t streamId, - const uint32_t protocolId, - const void* data, - const uint16_t length) +bool TransportImpl::sendSctp(uint16_t streamId, + uint32_t protocolId, + memory::UniquePoolBuffer message) { if (!_remoteSctpPort.isSet() || !_sctpAssociation) { logger::warn("SCTP not established yet.", _loggableId.c_str()); return false; } - - if (length == 0 || 2 * sizeof(uint32_t) + length > memory::Packet::size) + const auto maxSize = _sctpServerPort->getConfig().maxMessageSize; + if (message->getLength() == 0 || message->getLength() > maxSize) { - logger::error("sctp message invalid size %u", getLoggableId().c_str(), length); + logger::error("sctp message invalid size %zu", getLoggableId().c_str(), message->getLength()); return false; } - _jobQueue - .addJob(*_sctpAssociation, streamId, protocolId, data, length, _mainAllocator, _jobQueue, *this); + _jobQueue.addJob(*_sctpAssociation, streamId, protocolId, std::move(message), _jobQueue, *this); return true; } diff --git a/transport/TransportImpl.h b/transport/TransportImpl.h index 281375e02..ca099b398 100644 --- a/transport/TransportImpl.h +++ b/transport/TransportImpl.h @@ -156,6 +156,8 @@ class TransportImpl : public RtcTransport, jobmanager::JobQueue& getJobQueue() override { return _jobQueue; } + memory::PacketPoolAllocator& getAllocator() override { return _mainAllocator; } + uint32_t getSenderLossCount() const override; uint32_t getUplinkEstimateKbps() const override; uint32_t getDownlinkEstimateKbps() const override; @@ -177,7 +179,9 @@ class TransportImpl : public RtcTransport, uint32_t rtpFrequency) override; void setAbsSendTimeExtensionId(uint8_t extensionId) override; - bool sendSctp(uint16_t streamId, uint32_t protocolId, const void* data, uint16_t length) override; + bool sendSctp(uint16_t streamId, + uint32_t protocolId, + memory::UniquePoolBuffer message) override; uint16_t allocateOutboundSctpStream() override; void setSctp(uint16_t localPort, uint16_t remotePort) override; void connectSctp() override; diff --git a/transport/sctp/SctpAssociation.h b/transport/sctp/SctpAssociation.h index e300ec026..5b8fd1a80 100644 --- a/transport/sctp/SctpAssociation.h +++ b/transport/sctp/SctpAssociation.h @@ -2,6 +2,7 @@ #include #include #include +#include "memory/PoolBuffer.h" namespace sctp { struct SctpConfig; @@ -47,8 +48,7 @@ class SctpAssociation virtual uint16_t allocateStream() = 0; virtual bool sendMessage(uint16_t streamId, uint32_t payloadProtocol, - const void* payloadData, - size_t length, + memory::UniquePoolBuffer& payloadData, uint64_t timestamp) = 0; virtual size_t outboundPendingSize() const = 0; virtual int64_t nextTimeout(uint64_t timestamp) = 0; diff --git a/transport/sctp/SctpAssociationImpl.cpp b/transport/sctp/SctpAssociationImpl.cpp index 008593d85..67433fea1 100644 --- a/transport/sctp/SctpAssociationImpl.cpp +++ b/transport/sctp/SctpAssociationImpl.cpp @@ -6,6 +6,7 @@ #include "memory/Array.h" #include "utils/MersienneRandom.h" #include "utils/Time.h" +#include "webrtc/WebRtcDataStream.h" #define SCTP_LOG_ENABLE 0 @@ -128,7 +129,8 @@ SctpAssociationImpl::SentDataChunk::SentDataChunk(uint16_t streamId_, bool fragmentBegin_, bool fragmentEnd_, uint32_t transmissionSequenceNumber, - const void* payload, + memory::UniquePoolBuffer& buffer, + size_t offset, size_t size_) : transmitTime(0), size(size_), @@ -143,7 +145,8 @@ SctpAssociationImpl::SentDataChunk::SentDataChunk(uint16_t streamId_, reserved0(0), reserved1(0) { - std::memcpy(data(), payload, size_); + const auto copied = buffer->copyTo(data(), offset, size_); + assert(copied == size); } SctpAssociationImpl::ReceivedDataChunk::ReceivedDataChunk(const PayloadDataChunk& chunk, uint64_t timestamp) @@ -379,11 +382,11 @@ void SctpAssociationImpl::onCookieEcho(const SctpPacket& cookieEcho, const uint6 bool SctpAssociationImpl::sendMessage(uint16_t streamId, uint32_t payloadProtocol, - const void* payloadData, - size_t length, + memory::UniquePoolBuffer& sctpMessage, uint64_t timestamp) { auto streamIt = _streams.find(streamId); + size_t length = sctpMessage->size(); if (_state < State::ESTABLISHED || streamIt == _streams.cend()) { logger::warn("SCTP stream not open yet %u, count %zu", _loggableId.c_str(), streamId, _streams.size()); @@ -402,7 +405,6 @@ bool SctpAssociationImpl::sendMessage(uint16_t streamId, return false; } auto& streamState = streamIt->second; - auto* payloadBytes = reinterpret_cast(payloadData); size_t writtenBytes = 0; for (size_t i = 0; i < pktCount; ++i) { @@ -414,7 +416,8 @@ bool SctpAssociationImpl::sendMessage(uint16_t streamId, i == 0, i == (pktCount - 1), _local.tsn, - payloadBytes + writtenBytes, + sctpMessage, + writtenBytes, toWrite); assert(chunk); diff --git a/transport/sctp/SctpAssociationImpl.h b/transport/sctp/SctpAssociationImpl.h index be96ed2e0..34e576c1f 100644 --- a/transport/sctp/SctpAssociationImpl.h +++ b/transport/sctp/SctpAssociationImpl.h @@ -2,6 +2,7 @@ #include "SctpTimer.h" #include "Sctprotocol.h" #include "logger/Logger.h" +#include "memory/PoolBuffer.h" #include "memory/RingAllocator.h" #include "utils/MersienneRandom.h" #include @@ -44,7 +45,8 @@ class SctpAssociationImpl : public SctpAssociation bool fragmentBegin, bool fragmentEnd, uint32_t transmissionSequenceNumber, - const void* payload, + memory::UniquePoolBuffer& buffer, + size_t offset, size_t size); uint64_t transmitTime; @@ -119,8 +121,7 @@ class SctpAssociationImpl : public SctpAssociation uint16_t allocateStream() override; bool sendMessage(uint16_t streamId, uint32_t payloadProtocol, - const void* payloadData, - size_t length, + memory::UniquePoolBuffer& sctpMessage, uint64_t timestamp) override; size_t outboundPendingSize() const override; int64_t nextTimeout(uint64_t timestamp) override; diff --git a/transport/sctp/SctpConfig.h b/transport/sctp/SctpConfig.h index c907fd9a1..f3428eefd 100644 --- a/transport/sctp/SctpConfig.h +++ b/transport/sctp/SctpConfig.h @@ -50,6 +50,8 @@ struct SctpConfig uint32_t max = 4096; } mtu; + uint32_t maxMessageSize = 2048; + size_t transmitBufferSize = 512 * 1024; size_t receiveBufferSize = 512 * 1024; }; diff --git a/webrtc/DataStreamTransport.h b/webrtc/DataStreamTransport.h index 00a279d23..b8e186767 100644 --- a/webrtc/DataStreamTransport.h +++ b/webrtc/DataStreamTransport.h @@ -1,5 +1,6 @@ #pragma once +#include "memory/PoolBuffer.h" #include "memory/PacketPoolAllocator.h" namespace webrtc @@ -8,7 +9,10 @@ class DataStreamTransport { public: // Expects packet with SctpStreamMessageHeader - virtual bool sendSctp(uint16_t streamId, uint32_t protocolId, const void* data, uint16_t length) = 0; + virtual bool sendSctp(uint16_t streamId, + uint32_t protocolId, + memory::UniquePoolBuffer buffer) = 0; virtual uint16_t allocateOutboundSctpStream() = 0; + virtual memory::PacketPoolAllocator& getAllocator() = 0; }; } // namespace webrtc diff --git a/webrtc/WebRtcDataStream.cpp b/webrtc/WebRtcDataStream.cpp index 5f28fb14c..c0b175018 100644 --- a/webrtc/WebRtcDataStream.cpp +++ b/webrtc/WebRtcDataStream.cpp @@ -2,6 +2,7 @@ #include "logger/Logger.h" #include "webrtc/DataChannel.h" #include "webrtc/DataStreamTransport.h" +#include "memory/PoolBuffer.h" namespace webrtc { @@ -27,35 +28,61 @@ uint16_t WebRtcDataStream::open(const std::string& label) auto& message = DataChannelOpenMessage::create(data, label); _state = State::OPENING; - _transport.sendSctp(_streamId, DataChannelPpid::WEBRTC_ESTABLISH, data, message.size()); + auto buffer = memory::makeUniquePoolBuffer(_transport.getAllocator(), data, message.size()); + _transport.sendSctp(_streamId, DataChannelPpid::WEBRTC_ESTABLISH, std::move(buffer)); return _streamId; } void WebRtcDataStream::sendString(const char* string, const size_t length) { - _transport.sendSctp(_streamId, DataChannelPpid::WEBRTC_STRING, string, length); + auto buffer = memory::makeUniquePoolBuffer(_transport.getAllocator(), string, length); + _transport.sendSctp(_streamId, DataChannelPpid::WEBRTC_STRING, std::move(buffer)); } void WebRtcDataStream::sendData(const void* data, size_t length) { - _transport.sendSctp(_streamId, DataChannelPpid::WEBRTC_BINARY, data, length); + auto buffer = memory::makeUniquePoolBuffer(_transport.getAllocator(), data, length); + _transport.sendSctp(_streamId, DataChannelPpid::WEBRTC_BINARY, std::move(buffer)); } -void WebRtcDataStream::onSctpMessage(webrtc::DataStreamTransport* sender, - uint16_t streamId, - uint16_t streamSequenceNumber, - uint32_t payloadProtocol, - const void* data, - size_t length) +void WebRtcDataStream::sendMessage(uint32_t protocolId, memory::UniquePoolBuffer message) { + _transport.sendSctp(_streamId, protocolId, std::move(message)); +} + +void WebRtcDataStream::onSctpMessageBuffer(webrtc::DataStreamTransport* sender, + memory::UniquePoolBuffer& message + ) { + // HEADER: SctpStreamMessageHeader prepended to payload + assert(message->size() >= sizeof(SctpStreamMessageHeader)); + if (message->size() < sizeof(SctpStreamMessageHeader)) + { + return; + } + if (message->size() > 8192) { + logger::warn("SCTP message too big, len %zu", + _loggableId, + message->size() + ); + return; + } + + char continousBuffer[message->size()]; + message->copyTo(continousBuffer, 0, message->size()); + + const auto& header = *reinterpret_cast(continousBuffer); + const auto& payloadProtocol = header.payloadProtocol; + const auto& streamId = header.id; + const auto length = message->getLength() - sizeof(SctpStreamMessageHeader); + if (payloadProtocol == DataChannelPpid::WEBRTC_STRING) { - std::string command(reinterpret_cast(data), length); + std::string command(header.getMessage(), length); logger::debug("received on stream %u message %s", _loggableId, streamId, command.c_str()); if (_listener) { - _listener->onWebRtcDataString(reinterpret_cast(data), length); + _listener->onWebRtcDataString(header.getMessage(), length); } } if (payloadProtocol != webrtc::DataChannelPpid::WEBRTC_ESTABLISH) @@ -65,20 +92,21 @@ void WebRtcDataStream::onSctpMessage(webrtc::DataStreamTransport* sender, if (_state == State::CLOSED) { - auto msg = reinterpret_cast(data); + auto msg = reinterpret_cast(header.data()); if (msg->messageType == webrtc::DataChannelMessageType::DATA_CHANNEL_OPEN) { _state = State::OPEN; _streamId = streamId; _label = msg->getLabel(); uint8_t ack[] = {webrtc::DATA_CHANNEL_ACK}; - sender->sendSctp(streamId, webrtc::DataChannelPpid::WEBRTC_ESTABLISH, ack, 1); + auto buffer = memory::makeUniquePoolBuffer(_transport.getAllocator(), ack, 1); + sender->sendSctp(streamId, DataChannelPpid::WEBRTC_ESTABLISH, std::move(buffer)); logger::info("Data channel open. stream %u", _loggableId, streamId); } } else if (_state == State::OPENING) { - auto* message = reinterpret_cast(data); + auto* message = reinterpret_cast(header.data()); if (length > 0 && message[0] == DataChannelMessageType::DATA_CHANNEL_ACK) { _state = State::OPEN; @@ -87,39 +115,35 @@ void WebRtcDataStream::onSctpMessage(webrtc::DataStreamTransport* sender, } } -memory::UniquePacket makeUniquePacket(uint16_t streamId, +memory::UniquePoolBuffer makeUniqueSctpMessage(uint16_t streamId, uint32_t payloadProtocol, const void* message, size_t messageSize, memory::PacketPoolAllocator& allocator) { - assert(sizeof(SctpStreamMessageHeader) + messageSize <= memory::Packet::size); - if (sizeof(SctpStreamMessageHeader) + messageSize > memory::Packet::size) - { - return nullptr; + auto needNullTermination = payloadProtocol == WEBRTC_STRING && message && messageSize > 0 && ((char*)message)[messageSize - 1] == '\0'; + if (payloadProtocol == WEBRTC_STRING_EMPTY || payloadProtocol == WEBRTC_BINARY_EMPTY) { + needNullTermination = false; + messageSize = 0; } - - auto packet = memory::makeUniquePacket(allocator); - if (!packet) + auto buffer = memory::makeUniquePoolBuffer(allocator, sizeof(SctpStreamMessageHeader) + messageSize + (needNullTermination ? 1 : 0)); + if (!buffer) { - return packet; + return buffer; } - auto* header = reinterpret_cast(packet->get()); - header->id = streamId; - header->sequenceNumber = 0; - header->payloadProtocol = payloadProtocol; - std::memcpy(header->data(), message, messageSize); - auto* s = reinterpret_cast(header->data()); - if (messageSize > 0 && s[messageSize - 1] != 0) - { - s[messageSize] = 0; - ++messageSize; - } + SctpStreamMessageHeader header = { + .payloadProtocol = payloadProtocol, + .id = streamId, + .sequenceNumber = 0, + }; + buffer->copyFrom(&header, sizeof(SctpStreamMessageHeader), 0); + buffer->copyFrom(message, messageSize, sizeof(SctpStreamMessageHeader)); - packet->setLength(sizeof(SctpStreamMessageHeader) + messageSize); + if (needNullTermination) { + buffer->copyFrom("\0", 1, sizeof(SctpStreamMessageHeader) + messageSize); + } - return packet; + return buffer; } - } // namespace webrtc diff --git a/webrtc/WebRtcDataStream.h b/webrtc/WebRtcDataStream.h index 52d6aa600..5fa4ce45c 100644 --- a/webrtc/WebRtcDataStream.h +++ b/webrtc/WebRtcDataStream.h @@ -1,6 +1,7 @@ #pragma once #include "memory/PacketPoolAllocator.h" +#include "memory/PoolBuffer.h" #include #include @@ -35,16 +36,13 @@ class WebRtcDataStream bool isOpen() const { return _state == State::OPEN; } void sendString(const char* string, const size_t length); void sendData(const void* data, size_t length); + void sendMessage(uint32_t protocolId, memory::UniquePoolBuffer message); uint16_t getStreamId() const { return _streamId; }; std::string getLabel() const { return _label; } - void onSctpMessage(webrtc::DataStreamTransport* sender, - uint16_t streamId, - uint16_t streamSequenceNumber, - uint32_t payloadProtocol, - const void* data, - size_t length); + void onSctpMessageBuffer(webrtc::DataStreamTransport* sender, + memory::UniquePoolBuffer& message); State getState() const { return _state; } @@ -72,12 +70,8 @@ struct SctpStreamMessageHeader }; static_assert(sizeof(SctpStreamMessageHeader) == 8, "Misalignment of SctpStreamMessageHeader"); -inline const SctpStreamMessageHeader& streamMessageHeader(const memory::Packet& p) -{ - return reinterpret_cast(*p.get()); -} - -memory::UniquePacket makeUniquePacket(uint16_t streamId, +// Craeats sctp message with SctpStreamMessageHeader +memory::UniquePoolBuffer makeUniqueSctpMessage(uint16_t streamId, uint32_t payloadProtocol, const void* message, size_t messageSize,