From 8e513105561b7b1a03503d0cfa78b4ffae0b7e5a Mon Sep 17 00:00:00 2001 From: Giulio Eulisse <10544+ktf@users.noreply.github.com> Date: Sat, 24 Feb 2024 00:21:18 +0100 Subject: [PATCH] DPL: avoid TMessage usage TMessage does not allow for non owned buffers, so we end up having an extra buffer in private memory for (de)serializing. Using TBufferFile directly allows to avoid that, so this moves the whole ROOT serialization support in DPL to use it. --- .../Base/include/DetectorsBase/Detector.h | 15 +- Detectors/Base/src/Detector.cxx | 34 ++--- .../src/AODJAlienReaderHelpers.h | 2 + .../Core/include/Framework/DataAllocator.h | 1 + .../Core/include/Framework/DataRefUtils.h | 13 +- .../include/Framework/RootMessageContext.h | 3 + .../Framework/RootSerializationSupport.h | 3 +- .../include/Framework/TMessageSerializer.h | 130 ++++++++---------- Framework/Core/src/CommonDataProcessors.cxx | 4 +- Framework/Core/src/TMessageSerializer.cxx | 31 +++++ Framework/Core/test/test_DataRefUtils.cxx | 28 +++- .../Core/test/test_TMessageSerializer.cxx | 34 +++-- Framework/Utils/test/test_RootTreeWriter.cxx | 1 + Utilities/Mergers/src/ObjectStore.cxx | 11 +- Utilities/Mergers/test/benchmark_Types.cxx | 11 +- 15 files changed, 191 insertions(+), 130 deletions(-) diff --git a/Detectors/Base/include/DetectorsBase/Detector.h b/Detectors/Base/include/DetectorsBase/Detector.h index 6acfa4f5cc46c..4dd0452f2c059 100644 --- a/Detectors/Base/include/DetectorsBase/Detector.h +++ b/Detectors/Base/include/DetectorsBase/Detector.h @@ -29,7 +29,6 @@ #include #include #include -#include #include "CommonUtils/ShmManager.h" #include "CommonUtils/ShmAllocator.h" #include @@ -42,9 +41,7 @@ #include -namespace o2 -{ -namespace base +namespace o2::base { /// This is the basic class for any AliceO2 detector module, whether it is @@ -260,17 +257,12 @@ T decodeShmMessage(fair::mq::Parts& dataparts, int index, bool*& busy) } // this goes into the source -void attachMessageBufferToParts(fair::mq::Parts& parts, fair::mq::Channel& channel, - void* data, size_t size, void (*func_ptr)(void* data, void* hint), void* hint); +void attachMessageBufferToParts(fair::mq::Parts& parts, fair::mq::Channel& channel, void* data, TClass* cl); template void attachTMessage(Container const& hits, fair::mq::Channel& channel, fair::mq::Parts& parts) { - TMessage* tmsg = new TMessage(); - tmsg->WriteObjectAny((void*)&hits, TClass::GetClass(typeid(hits))); - attachMessageBufferToParts( - parts, channel, tmsg->Buffer(), tmsg->BufferSize(), - [](void* data, void* hint) { delete static_cast(hint); }, tmsg); + attachMessageBufferToParts(parts, channel, (void*)&hits, TClass::GetClass(typeid(hits))); } void* decodeTMessageCore(fair::mq::Parts& dataparts, int index); @@ -746,7 +738,6 @@ class DetImpl : public o2::base::Detector ClassDefOverride(DetImpl, 0); }; -} // namespace base } // namespace o2 #endif diff --git a/Detectors/Base/src/Detector.cxx b/Detectors/Base/src/Detector.cxx index 3168e0e84e1f2..3dccf732517b6 100644 --- a/Detectors/Base/src/Detector.cxx +++ b/Detectors/Base/src/Detector.cxx @@ -17,6 +17,7 @@ #include "DetectorsBase/MaterialManager.h" #include "DetectorsCommonDataFormats/DetID.h" #include "Field/MagneticField.h" +#include "Framework/TMessageSerializer.h" #include "TString.h" // for TString #include "TGeoManager.h" @@ -196,16 +197,18 @@ int Detector::registerSensitiveVolumeAndGetVolID(std::string const& name) #include #include #include -namespace o2 -{ -namespace base +namespace o2::base { // this goes into the source -void attachMessageBufferToParts(fair::mq::Parts& parts, fair::mq::Channel& channel, void* data, size_t size, - void (*free_func)(void* data, void* hint), void* hint) -{ - std::unique_ptr message(channel.NewMessage(data, size, free_func, hint)); - parts.AddPart(std::move(message)); +void attachMessageBufferToParts(fair::mq::Parts& parts, fair::mq::Channel& channel, void* data, TClass* cl) { + auto msg = channel.Transport()->CreateMessage(4096, fair::mq::Alignment{64}); + // This will serialize the data directly into the message buffer, without any further + // buffer or copying. Notice how the message will have 8 bytes of header and then + // the serialized data as TBufferFile. In principle one could construct a serialized TMessage payload + // however I did not manage to get it to work for every case. + o2::framework::FairOutputTBuffer buffer(*msg); + o2::framework::TMessageSerializer::serialize(buffer, data, cl); + parts.AddPart(std::move(msg)); } void attachDetIDHeaderMessage(int id, fair::mq::Channel& channel, fair::mq::Parts& parts) { @@ -246,17 +249,14 @@ void* decodeShmCore(fair::mq::Parts& dataparts, int index, bool*& busy) void* decodeTMessageCore(fair::mq::Parts& dataparts, int index) { - class TMessageWrapper : public TMessage - { - public: - TMessageWrapper(void* buf, Int_t len) : TMessage(buf, len) { ResetBit(kIsOwner); } - ~TMessageWrapper() override = default; - }; auto rawmessage = std::move(dataparts.At(index)); - auto message = std::make_unique(rawmessage->GetData(), rawmessage->GetSize()); - return message.get()->ReadObjectAny(message.get()->GetClass()); + o2::framework::FairInputTBuffer buffer((char*)rawmessage->GetData(), rawmessage->GetSize()); + buffer.InitMap(); + auto *cl = buffer.ReadClass(); + buffer.SetBufferOffset(0); + buffer.ResetMap(); + return buffer.ReadObjectAny(cl); } -} // namespace base } // namespace o2 ClassImp(o2::base::Detector); diff --git a/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h b/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h index 655e4b6c0b439..4b9fd710aca14 100644 --- a/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h +++ b/Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h @@ -16,7 +16,9 @@ #include "Framework/AlgorithmSpec.h" #include "Framework/Logger.h" #include + #include +class TFile; namespace o2::framework::readers { diff --git a/Framework/Core/include/Framework/DataAllocator.h b/Framework/Core/include/Framework/DataAllocator.h index 8151d2f83c6c6..029e922aeb90b 100644 --- a/Framework/Core/include/Framework/DataAllocator.h +++ b/Framework/Core/include/Framework/DataAllocator.h @@ -359,6 +359,7 @@ class DataAllocator } else if constexpr (has_root_dictionary::value == true || is_specialization_v == true) { // Serialize a snapshot of an object with root dictionary payloadMessage = proxy.createOutputMessage(routeIndex); + payloadMessage->Rebuild(4096, {64}); if constexpr (is_specialization_v == true) { // Explicitely ROOT serialize a snapshot of object. // An object wrapped into type `ROOTSerialized` is explicitely marked to be ROOT serialized diff --git a/Framework/Core/include/Framework/DataRefUtils.h b/Framework/Core/include/Framework/DataRefUtils.h index defd10244bca5..e59f986f09250 100644 --- a/Framework/Core/include/Framework/DataRefUtils.h +++ b/Framework/Core/include/Framework/DataRefUtils.h @@ -71,12 +71,15 @@ struct DataRefUtils { throw runtime_error("Attempt to extract a TMessage from non-ROOT serialised message"); } - typename RSS::FairTMessage ftm(const_cast(ref.payload), payloadSize); - auto* storedClass = ftm.GetClass(); + typename RSS::FairInputTBuffer ftm(const_cast(ref.payload), payloadSize); auto* requestedClass = RSS::TClass::GetClass(typeid(T)); + ftm.InitMap(); + auto* storedClass = ftm.ReadClass(); // should always have the class description if has_root_dictionary is true assert(requestedClass != nullptr); + ftm.SetBufferOffset(0); + ftm.ResetMap(); auto* object = ftm.ReadObjectAny(storedClass); if (object == nullptr) { throw runtime_error_f("Failed to read object with name %s from message using ROOT serialization.", @@ -146,7 +149,11 @@ struct DataRefUtils { throw runtime_error("ROOT serialization not supported, dictionary not found for data type"); } - typename RSS::FairTMessage ftm(const_cast(ref.payload), payloadSize); + typename RSS::FairInputTBuffer ftm(const_cast(ref.payload), payloadSize); + ftm.InitMap(); + auto *classInfo = ftm.ReadClass(); + ftm.SetBufferOffset(0); + ftm.ResetMap(); result.reset(static_cast(ftm.ReadObjectAny(cl))); if (result.get() == nullptr) { throw runtime_error_f("Unable to extract class %s", cl == nullptr ? "" : cl->GetName()); diff --git a/Framework/Core/include/Framework/RootMessageContext.h b/Framework/Core/include/Framework/RootMessageContext.h index bef60ebbbf9f9..b1124880cf30f 100644 --- a/Framework/Core/include/Framework/RootMessageContext.h +++ b/Framework/Core/include/Framework/RootMessageContext.h @@ -72,6 +72,9 @@ class RootSerializedObject : public MessageContext::ContextObject fair::mq::Parts finalize() final { assert(mParts.Size() == 1); + if (mPayloadMsg->GetSize() < sizeof(char*)) { + mPayloadMsg->Rebuild(4096, {64}); + } TMessageSerializer::Serialize(*mPayloadMsg, mObject.get(), nullptr); mParts.AddPart(std::move(mPayloadMsg)); return ContextObject::finalize(); diff --git a/Framework/Core/include/Framework/RootSerializationSupport.h b/Framework/Core/include/Framework/RootSerializationSupport.h index cbf7408b13c7d..a44093f9c02bf 100644 --- a/Framework/Core/include/Framework/RootSerializationSupport.h +++ b/Framework/Core/include/Framework/RootSerializationSupport.h @@ -21,7 +21,8 @@ namespace o2::framework /// compiler. struct RootSerializationSupport { using TClass = ::TClass; - using FairTMessage = o2::framework::FairTMessage; + using FairInputTBuffer = o2::framework::FairInputTBuffer; + using FairOutputBuffer = o2::framework::FairOutputTBuffer; using TObject = ::TObject; }; diff --git a/Framework/Core/include/Framework/TMessageSerializer.h b/Framework/Core/include/Framework/TMessageSerializer.h index 1f08b456c0218..ca18eb21abfa1 100644 --- a/Framework/Core/include/Framework/TMessageSerializer.h +++ b/Framework/Core/include/Framework/TMessageSerializer.h @@ -16,9 +16,8 @@ #include "Framework/RuntimeError.h" #include -#include +#include #include -#include #include #include #include @@ -28,67 +27,76 @@ namespace o2::framework { -class FairTMessage; +class FairOutputTBuffer; +class FairInputTBuffer; // utilities to produce a span over a byte buffer held by various message types // this is to avoid littering code with casts and conversions (span has a signed index type(!)) -gsl::span as_span(const FairTMessage& msg); +gsl::span as_span(const FairInputTBuffer& msg); +gsl::span as_span(const FairOutputTBuffer& msg); gsl::span as_span(const fair::mq::Message& msg); -class FairTMessage : public TMessage +// A TBufferFile which we can use to serialise data to a FairMQ message. +class FairOutputTBuffer : public TBufferFile { public: - using TMessage::TMessage; - FairTMessage() : TMessage(kMESS_OBJECT) {} - FairTMessage(void* buf, Int_t len) : TMessage(buf, len) { ResetBit(kIsOwner); } - FairTMessage(gsl::span buf) : TMessage(buf.data(), buf.size()) { ResetBit(kIsOwner); } + // This is to serialise data to FairMQ. We embed the pointer to the message + // in the data itself, so that we can use it to reallocate the message if needed. + // The FairMQ message retains ownership of the data. + // When deserialising the root object, keep in mind one needs to skip the 8 bytes + // for the pointer. + FairOutputTBuffer(fair::mq::Message& msg) + : TBufferFile(TBuffer::kWrite, msg.GetSize() - sizeof(char*), embedInItself(msg), false, fairMQrealloc) + { + } + // Helper function to keep track of the FairMQ message that holds the data + // in the data itself. We can use this to make sure the message can be reallocated + // even if we simply have a pointer to the data. Hopefully ROOT will not play dirty + // with us. + void* embedInItself(fair::mq::Message& msg); // helper function to clean up the object holding the data after it is transported. - static void free(void* /*data*/, void* hint); + static char* fairMQrealloc(char* oldData, size_t newSize, size_t oldSize); }; -struct TMessageSerializer { - using StreamerList = std::vector; - using CompressionLevel = int; +class FairInputTBuffer : public TBufferFile +{ + public: + // This is to serialise data to FairMQ. The provided message is expeted to have 8 bytes + // of overhead, where the source embedded the pointer for the reallocation. + // Notice this will break if the sender and receiver are not using the same + // size for a pointer. + FairInputTBuffer(char * data, size_t size) + : TBufferFile(TBuffer::kRead, size-sizeof(char*), data + sizeof(char*), false, nullptr) + { + } +}; - static void Serialize(fair::mq::Message& msg, const TObject* input, - CompressionLevel compressionLevel = -1); +struct TMessageSerializer { + static void Serialize(fair::mq::Message& msg, const TObject* input); template - static void Serialize(fair::mq::Message& msg, const T* input, const TClass* cl, // - CompressionLevel compressionLevel = -1); + static void Serialize(fair::mq::Message& msg, const T* input, const TClass* cl); template static void Deserialize(const fair::mq::Message& msg, std::unique_ptr& output); - static void serialize(FairTMessage& msg, const TObject* input, - CompressionLevel compressionLevel = -1); + static void serialize(o2::framework::FairOutputTBuffer& msg, const TObject* input); template - static void serialize(FairTMessage& msg, const T* input, // - const TClass* cl, - CompressionLevel compressionLevel = -1); + static void serialize(o2::framework::FairOutputTBuffer& msg, const T* input, const TClass* cl); template - static std::unique_ptr deserialize(gsl::span buffer); - template - static inline std::unique_ptr deserialize(std::byte* buffer, size_t size); + static inline std::unique_ptr deserialize(FairInputTBuffer & buffer); }; -inline void TMessageSerializer::serialize(FairTMessage& tm, const TObject* input, - CompressionLevel compressionLevel) +inline void TMessageSerializer::serialize(FairOutputTBuffer& tm, const TObject* input) { - return serialize(tm, input, nullptr, compressionLevel); + return serialize(tm, input, nullptr); } template -inline void TMessageSerializer::serialize(FairTMessage& tm, const T* input, // - const TClass* cl, CompressionLevel compressionLevel) +inline void TMessageSerializer::serialize(FairOutputTBuffer& tm, const T* input, const TClass* cl) { - if (compressionLevel >= 0) { - // if negative, skip to use ROOT default - tm.SetCompressionLevel(compressionLevel); - } - // TODO: check what WriateObject and WriteObjectAny are doing if (cl == nullptr) { tm.WriteObject(input); @@ -98,7 +106,7 @@ inline void TMessageSerializer::serialize(FairTMessage& tm, const T* input, // } template -inline std::unique_ptr TMessageSerializer::deserialize(gsl::span buffer) +inline std::unique_ptr TMessageSerializer::deserialize(FairInputTBuffer & buffer) { TClass* tgtClass = TClass::GetClass(typeid(T)); if (tgtClass == nullptr) { @@ -107,53 +115,32 @@ inline std::unique_ptr TMessageSerializer::deserialize(gsl::span b // FIXME: we need to add consistency check for buffer data to be serialized // at the moment, TMessage might simply crash if an invalid or inconsistent // buffer is provided - FairTMessage tm(buffer); - TClass* serializedClass = tm.GetClass(); + buffer.InitMap(); + TClass* serializedClass = buffer.ReadClass(); + buffer.SetBufferOffset(0); + buffer.ResetMap(); if (serializedClass == nullptr) { throw runtime_error_f("can not read class info from buffer"); } if (tgtClass != serializedClass && serializedClass->GetBaseClass(tgtClass) == nullptr) { throw runtime_error_f("can not convert serialized class %s into target class %s", - tm.GetClass()->GetName(), + serializedClass->GetName(), tgtClass->GetName()); } - return std::unique_ptr(reinterpret_cast(tm.ReadObjectAny(serializedClass))); + return std::unique_ptr(reinterpret_cast(buffer.ReadObjectAny(serializedClass))); } -template -inline std::unique_ptr TMessageSerializer::deserialize(std::byte* buffer, size_t size) +inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const TObject* input) { - return deserialize(gsl::span(buffer, gsl::narrow::size_type>(size))); -} - -inline void FairTMessage::free(void* /*data*/, void* hint) -{ - std::default_delete deleter; - deleter(static_cast(hint)); -} - -inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const TObject* input, - TMessageSerializer::CompressionLevel compressionLevel) -{ - std::unique_ptr tm = std::make_unique(kMESS_OBJECT); - - serialize(*tm, input, input->Class(), compressionLevel); - - msg.Rebuild(tm->Buffer(), tm->BufferSize(), FairTMessage::free, tm.get()); - tm.release(); + FairOutputTBuffer output(msg); + serialize(output, input, input->Class()); } template -inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const T* input, // - const TClass* cl, // - TMessageSerializer::CompressionLevel compressionLevel) +inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const T* input, const TClass* cl) { - std::unique_ptr tm = std::make_unique(kMESS_OBJECT); - - serialize(*tm, input, cl, compressionLevel); - - msg.Rebuild(tm->Buffer(), tm->BufferSize(), FairTMessage::free, tm.get()); - tm.release(); + FairOutputTBuffer output(msg); + serialize(output, input, cl); } template @@ -161,7 +148,8 @@ inline void TMessageSerializer::Deserialize(const fair::mq::Message& msg, std::u { // we know the message will not be modified by this, // so const_cast should be OK here(IMHO). - output = deserialize(as_span(msg)); + FairInputTBuffer input(static_cast(msg.GetData()), static_cast(msg.GetSize())); + output = deserialize(input); } // gsl::narrow is used to do a runtime narrowing check, this might be a bit paranoid, @@ -171,7 +159,7 @@ inline gsl::span as_span(const fair::mq::Message& msg) return gsl::span{static_cast(msg.GetData()), gsl::narrow::size_type>(msg.GetSize())}; } -inline gsl::span as_span(const FairTMessage& msg) +inline gsl::span as_span(const FairInputTBuffer& msg) { return gsl::span{reinterpret_cast(msg.Buffer()), gsl::narrow::size_type>(msg.BufferSize())}; diff --git a/Framework/Core/src/CommonDataProcessors.cxx b/Framework/Core/src/CommonDataProcessors.cxx index 48a3eb1da95b9..02ef5c7bc5b3c 100644 --- a/Framework/Core/src/CommonDataProcessors.cxx +++ b/Framework/Core/src/CommonDataProcessors.cxx @@ -141,9 +141,9 @@ DataProcessorSpec CommonDataProcessors::getOutputObjHistSink(std::vector(ref.payload), static_cast(datah->payloadSize)); + FairInputTBuffer tm(const_cast(ref.payload), static_cast(datah->payloadSize)); InputObject obj; - obj.kind = tm.GetClass(); + obj.kind = tm.ReadClass(); if (obj.kind == nullptr) { LOG(error) << "Cannot read class info from buffer."; return; diff --git a/Framework/Core/src/TMessageSerializer.cxx b/Framework/Core/src/TMessageSerializer.cxx index 5388a6d716cda..9f09c3ade0089 100644 --- a/Framework/Core/src/TMessageSerializer.cxx +++ b/Framework/Core/src/TMessageSerializer.cxx @@ -9,7 +9,38 @@ // granted to it by virtue of its status as an Intergovernmental Organization // or submit itself to any jurisdiction. #include +#include #include #include using namespace o2::framework; + +void* FairOutputTBuffer::embedInItself(fair::mq::Message& msg) { + // The first bytes of the message are used to store the pointer to the message itself + // so that we can reallocate it if needed. + if (sizeof(char*) > msg.GetSize()) { + throw std::runtime_error("Message size too small to embed pointer"); + } + char* data = reinterpret_cast(msg.GetData()); + char* ptr = reinterpret_cast(&msg); + std::memcpy(data, ptr, sizeof(char*)); + return data + sizeof(char*); +} + +// Reallocation function. Get the message pointer from the data and call Rebuild. +char *FairOutputTBuffer::fairMQrealloc(char *oldData, size_t newSize, size_t oldSize) { + auto* msg = reinterpret_cast(oldData - sizeof(char*)); + if (newSize <= msg->GetSize()) { + // no need to reallocate, the message is already big enough + return oldData; + } + // Create a shallow copy of the message + fair::mq::MessagePtr oldMsg = msg->GetTransport()->CreateMessage(); + oldMsg->Copy(*msg); + // Copy the old data while rebuilding. Reference counting should make + // sure the old message is not deleted until the new one is ready. + msg->Rebuild(newSize, fair::mq::Alignment{64}); + memcpy(msg->GetData(), oldMsg->GetData(), oldSize); + + return reinterpret_cast(msg->GetData()) + sizeof(char*); +} diff --git a/Framework/Core/test/test_DataRefUtils.cxx b/Framework/Core/test/test_DataRefUtils.cxx index 37da7912bfe8b..081adc81ebf69 100644 --- a/Framework/Core/test/test_DataRefUtils.cxx +++ b/Framework/Core/test/test_DataRefUtils.cxx @@ -21,17 +21,37 @@ using namespace o2::framework; +TEST_CASE("PureRootTest") { + TBufferFile buffer(TBuffer::kWrite); + TObjString s("test"); + buffer.WriteObject(&s); + + TBufferFile buffer2(TBuffer::kRead, buffer.BufferSize(), buffer.Buffer(), false); + buffer2.SetReadMode(); + buffer2.InitMap(); + TClass *storedClass = buffer2.ReadClass(); + // ReadClass advances the buffer, so we need to reset it. + buffer2.SetBufferOffset(0); + buffer2.ResetMap(); + REQUIRE(storedClass != nullptr); + auto *outS = (TObjString*)buffer2.ReadObjectAny(storedClass); + REQUIRE(outS != nullptr); + REQUIRE(outS->GetString() == "test"); +} + // Simple test to do root deserialization. TEST_CASE("TestRootSerialization") { DataRef ref; - TMessage* tm = new TMessage(kMESS_OBJECT); + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(4096); + FairOutputTBuffer tm(*msg); auto sOrig = std::make_unique("test"); - tm->WriteObject(sOrig.get()); + tm << sOrig.get(); o2::header::DataHeader dh; dh.payloadSerializationMethod = o2::header::gSerializationMethodROOT; - ref.payload = tm->Buffer(); - dh.payloadSize = tm->BufferSize(); + ref.payload = (char*)msg->GetData(); + dh.payloadSize = (size_t)msg->GetSize(); ref.header = reinterpret_cast(&dh); // Check by using the same type diff --git a/Framework/Core/test/test_TMessageSerializer.cxx b/Framework/Core/test/test_TMessageSerializer.cxx index bc5f817400a44..395b3779421a2 100644 --- a/Framework/Core/test/test_TMessageSerializer.cxx +++ b/Framework/Core/test/test_TMessageSerializer.cxx @@ -11,6 +11,7 @@ #include "Framework/TMessageSerializer.h" #include "Framework/RuntimeError.h" +#include #include "TestClasses.h" #include #include @@ -49,14 +50,14 @@ TEST_CASE("TestTMessageSerializer") array.SetOwner(); array.Add(new TNamed(testname, testtitle)); - FairTMessage msg; - TMessageSerializer::serialize(msg, &array); + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(4096); + FairOutputTBuffer buffer(*msg); + TMessageSerializer::serialize(buffer, &array); - auto buf = as_span(msg); - REQUIRE(buf.size() == msg.BufferSize()); - REQUIRE(static_cast(buf.data()) == static_cast(msg.Buffer())); + FairInputTBuffer msg2((char*)msg->GetData(), msg->GetSize()); // test deserialization with TObject as target class (default) - auto out = TMessageSerializer::deserialize(buf); + auto out = TMessageSerializer::deserialize(msg2); auto* outarr = dynamic_cast(out.get()); REQUIRE(out.get() == outarr); @@ -66,9 +67,9 @@ TEST_CASE("TestTMessageSerializer") REQUIRE(named->GetTitle() == std::string(testtitle)); // test deserialization with a wrong target class and check the exception - REQUIRE_THROWS_AS(TMessageSerializer::deserialize(buf), o2::framework::RuntimeErrorRef); + REQUIRE_THROWS_AS(TMessageSerializer::deserialize(msg2), o2::framework::RuntimeErrorRef); - REQUIRE_THROWS_MATCHES(TMessageSerializer::deserialize(buf), o2::framework::RuntimeErrorRef, + REQUIRE_THROWS_MATCHES(TMessageSerializer::deserialize(msg2), o2::framework::RuntimeErrorRef, ExceptionMatcher("can not convert serialized class TObjArray into target class TNamed")); } @@ -87,23 +88,29 @@ TEST_CASE("TestTMessageSerializer_NonTObject") TClass* cl = TClass::GetClass("std::vector"); REQUIRE(cl != nullptr); - FairTMessage msg; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(4096); + FairOutputTBuffer buffer(*msg); char* in = reinterpret_cast(&data); - TMessageSerializer::serialize(msg, in, cl); + TMessageSerializer::serialize(buffer, in, cl); + FairInputTBuffer msg2((char*)msg->GetData(), msg->GetSize()); - auto out = TMessageSerializer::deserialize>(as_span(msg)); + auto out = TMessageSerializer::deserialize>(msg2); REQUIRE(out); REQUIRE((*out.get()).size() == 2); REQUIRE((*out.get())[0] == o2::test::Polymorphic(0xaffe)); REQUIRE((*out.get())[1] == o2::test::Polymorphic(0xd00f)); // test deserialization with a wrong target class and check the exception - REQUIRE_THROWS_AS(TMessageSerializer::deserialize(as_span(msg)), RuntimeErrorRef); + REQUIRE_THROWS_AS(TMessageSerializer::deserialize(msg2), RuntimeErrorRef); } TEST_CASE("TestTMessageSerializer_InvalidBuffer") { const char* buffer = "this is for sure not a serialized ROOT object"; + auto transport = fair::mq::TransportFactory::CreateTransportFactory("zeromq"); + auto msg = transport->CreateMessage(strlen(buffer) + 8); + memcpy((char*)msg->GetData() + 8, buffer, strlen(buffer)); // test deserialization of invalid buffer and check the exception // FIXME: at the moment, TMessage fails directly with a segfault, which it shouldn't do /* @@ -119,5 +126,6 @@ TEST_CASE("TestTMessageSerializer_InvalidBuffer") struct Dummy { }; auto matcher = ExceptionMatcher("class is not ROOT-serializable: ZL22CATCH2_INTERNAL_TEST_4vE5Dummy"); - REQUIRE_THROWS_MATCHES(TMessageSerializer::deserialize((std::byte*)buffer, strlen(buffer)), o2::framework::RuntimeErrorRef, matcher); + FairInputTBuffer msg2((char*)msg->GetData(), msg->GetSize()); + REQUIRE_THROWS_MATCHES(TMessageSerializer::deserialize(msg2), o2::framework::RuntimeErrorRef, matcher); } diff --git a/Framework/Utils/test/test_RootTreeWriter.cxx b/Framework/Utils/test/test_RootTreeWriter.cxx index 3194508f3d775..62e1eb62cb4f1 100644 --- a/Framework/Utils/test/test_RootTreeWriter.cxx +++ b/Framework/Utils/test/test_RootTreeWriter.cxx @@ -179,6 +179,7 @@ TEST_CASE("test_RootTreeWriter") auto createSerializedMessage = [&transport, &store](DataHeader&& dh, auto& data) { fair::mq::MessagePtr payload = transport->CreateMessage(); + payload->Rebuild(4096, {64}); auto* cl = TClass::GetClass(typeid(decltype(data))); TMessageSerializer().Serialize(*payload, &data, cl); dh.payloadSize = payload->GetSize(); diff --git a/Utilities/Mergers/src/ObjectStore.cxx b/Utilities/Mergers/src/ObjectStore.cxx index e88358507c31e..3bb49f1dfc9d8 100644 --- a/Utilities/Mergers/src/ObjectStore.cxx +++ b/Utilities/Mergers/src/ObjectStore.cxx @@ -38,7 +38,7 @@ static std::string concat(Args&&... arguments) return std::move(ss.str()); } -void* readObject(const TClass* type, o2::framework::FairTMessage& ftm) +void* readObject(const TClass* type, o2::framework::FairInputTBuffer& ftm) { using namespace std::string_view_literals; auto* object = ftm.ReadObjectAny(type); @@ -60,7 +60,7 @@ MergeInterface* castToMergeInterface(bool inheritsFromTObject, void* object, TCl return objectAsMergeInterface; } -std::optional extractVector(o2::framework::FairTMessage& ftm, const TClass* storedClass) +std::optional extractVector(o2::framework::FairInputTBuffer& ftm, const TClass* storedClass) { if (!storedClass->InheritsFrom(TClass::GetClass(typeid(VectorOfRawTObjects)))) { return std::nullopt; @@ -88,11 +88,14 @@ ObjectStore extractObjectFrom(const framework::DataRef& ref) throw std::runtime_error(concat(errorPrefix, "It is not ROOT-serialized"sv)); } - o2::framework::FairTMessage ftm(const_cast(ref.payload), o2::framework::DataRefUtils::getPayloadSize(ref)); - auto* storedClass = ftm.GetClass(); + o2::framework::FairInputTBuffer ftm(const_cast(ref.payload), o2::framework::DataRefUtils::getPayloadSize(ref)); + ftm.InitMap(); + auto* storedClass = ftm.ReadClass(); if (storedClass == nullptr) { throw std::runtime_error(concat(errorPrefix, "Unknown stored class"sv)); } + ftm.SetBufferOffset(0); + ftm.ResetMap(); if (const auto extractedVector = extractVector(ftm, storedClass)) { return extractedVector.value(); diff --git a/Utilities/Mergers/test/benchmark_Types.cxx b/Utilities/Mergers/test/benchmark_Types.cxx index 790fd329185ea..736685c5746b8 100644 --- a/Utilities/Mergers/test/benchmark_Types.cxx +++ b/Utilities/Mergers/test/benchmark_Types.cxx @@ -165,11 +165,16 @@ auto measure = [](Measurement m, auto* o, auto* i) -> double { tm->WriteObject(o); start = std::chrono::high_resolution_clock::now(); - o2::framework::FairTMessage ftm(const_cast(tm->Buffer()), tm->BufferSize()); - auto* storedClass = ftm.GetClass(); + // Needed to take into account that FairInputTBuffer expects the first 8 bytes to be the + // allocator pointer, which is not present in the TMessage buffer. + o2::framework::FairInputTBuffer ftm(const_cast(tm->Buffer() - 8), tm->BufferSize() + 8); + ftm.InitMap(); + auto* storedClass = ftm.ReadClass(); if (storedClass == nullptr) { throw std::runtime_error("Unknown stored class"); } + ftm.SetBufferOffset(0); + ftm.ResetMap(); auto* tObjectClass = TClass::GetClass(typeid(TObject)); if (!storedClass->InheritsFrom(tObjectClass)) { @@ -738,4 +743,4 @@ int main(int argc, const char* argv[]) file.close(); return 0; -} \ No newline at end of file +}