Skip to content

Commit

Permalink
DPL: avoid TMessage usage
Browse files Browse the repository at this point in the history
TMessage does not allow for non owned buffers, so we end up having
an extra buffer in private memory for (de)serializing. Using TBufferFile
directly allows to avoid that, so this moves the whole ROOT serialization
support in DPL to use it.
  • Loading branch information
ktf committed Feb 24, 2024
1 parent fc7baac commit 8e51310
Show file tree
Hide file tree
Showing 15 changed files with 191 additions and 130 deletions.
15 changes: 3 additions & 12 deletions Detectors/Base/include/DetectorsBase/Detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
#include <typeinfo>
#include <type_traits>
#include <string>
#include <TMessage.h>
#include "CommonUtils/ShmManager.h"
#include "CommonUtils/ShmAllocator.h"
#include <sys/shm.h>
Expand All @@ -42,9 +41,7 @@

#include <fairmq/FwdDecls.h>

namespace o2
{
namespace base
namespace o2::base
{

/// This is the basic class for any AliceO2 detector module, whether it is
Expand Down Expand Up @@ -260,17 +257,12 @@ T decodeShmMessage(fair::mq::Parts& dataparts, int index, bool*& busy)
}

// this goes into the source
void attachMessageBufferToParts(fair::mq::Parts& parts, fair::mq::Channel& channel,
void* data, size_t size, void (*func_ptr)(void* data, void* hint), void* hint);
void attachMessageBufferToParts(fair::mq::Parts& parts, fair::mq::Channel& channel, void* data, TClass* cl);

template <typename Container>
void attachTMessage(Container const& hits, fair::mq::Channel& channel, fair::mq::Parts& parts)
{
TMessage* tmsg = new TMessage();
tmsg->WriteObjectAny((void*)&hits, TClass::GetClass(typeid(hits)));
attachMessageBufferToParts(
parts, channel, tmsg->Buffer(), tmsg->BufferSize(),
[](void* data, void* hint) { delete static_cast<TMessage*>(hint); }, tmsg);
attachMessageBufferToParts(parts, channel, (void*)&hits, TClass::GetClass(typeid(hits)));
}

void* decodeTMessageCore(fair::mq::Parts& dataparts, int index);
Expand Down Expand Up @@ -746,7 +738,6 @@ class DetImpl : public o2::base::Detector

ClassDefOverride(DetImpl, 0);
};
} // namespace base
} // namespace o2

#endif
34 changes: 17 additions & 17 deletions Detectors/Base/src/Detector.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "DetectorsBase/MaterialManager.h"
#include "DetectorsCommonDataFormats/DetID.h"
#include "Field/MagneticField.h"
#include "Framework/TMessageSerializer.h"
#include "TString.h" // for TString
#include "TGeoManager.h"

Expand Down Expand Up @@ -196,16 +197,18 @@ int Detector::registerSensitiveVolumeAndGetVolID(std::string const& name)
#include <fairmq/Message.h>
#include <fairmq/Parts.h>
#include <fairmq/Channel.h>
namespace o2
{
namespace base
namespace o2::base
{
// this goes into the source
void attachMessageBufferToParts(fair::mq::Parts& parts, fair::mq::Channel& channel, void* data, size_t size,
void (*free_func)(void* data, void* hint), void* hint)
{
std::unique_ptr<fair::mq::Message> message(channel.NewMessage(data, size, free_func, hint));
parts.AddPart(std::move(message));
void attachMessageBufferToParts(fair::mq::Parts& parts, fair::mq::Channel& channel, void* data, TClass* cl) {
auto msg = channel.Transport()->CreateMessage(4096, fair::mq::Alignment{64});
// This will serialize the data directly into the message buffer, without any further
// buffer or copying. Notice how the message will have 8 bytes of header and then
// the serialized data as TBufferFile. In principle one could construct a serialized TMessage payload
// however I did not manage to get it to work for every case.
o2::framework::FairOutputTBuffer buffer(*msg);
o2::framework::TMessageSerializer::serialize(buffer, data, cl);
parts.AddPart(std::move(msg));
}
void attachDetIDHeaderMessage(int id, fair::mq::Channel& channel, fair::mq::Parts& parts)
{
Expand Down Expand Up @@ -246,17 +249,14 @@ void* decodeShmCore(fair::mq::Parts& dataparts, int index, bool*& busy)

void* decodeTMessageCore(fair::mq::Parts& dataparts, int index)
{
class TMessageWrapper : public TMessage
{
public:
TMessageWrapper(void* buf, Int_t len) : TMessage(buf, len) { ResetBit(kIsOwner); }
~TMessageWrapper() override = default;
};
auto rawmessage = std::move(dataparts.At(index));
auto message = std::make_unique<TMessageWrapper>(rawmessage->GetData(), rawmessage->GetSize());
return message.get()->ReadObjectAny(message.get()->GetClass());
o2::framework::FairInputTBuffer buffer((char*)rawmessage->GetData(), rawmessage->GetSize());
buffer.InitMap();
auto *cl = buffer.ReadClass();
buffer.SetBufferOffset(0);
buffer.ResetMap();
return buffer.ReadObjectAny(cl);
}

} // namespace base
} // namespace o2
ClassImp(o2::base::Detector);
2 changes: 2 additions & 0 deletions Framework/AnalysisSupport/src/AODJAlienReaderHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
#include "Framework/AlgorithmSpec.h"
#include "Framework/Logger.h"
#include <Monitoring/Monitoring.h>

#include <uv.h>
class TFile;

namespace o2::framework::readers
{
Expand Down
1 change: 1 addition & 0 deletions Framework/Core/include/Framework/DataAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ class DataAllocator
} else if constexpr (has_root_dictionary<T>::value == true || is_specialization_v<T, ROOTSerialized> == true) {
// Serialize a snapshot of an object with root dictionary
payloadMessage = proxy.createOutputMessage(routeIndex);
payloadMessage->Rebuild(4096, {64});
if constexpr (is_specialization_v<T, ROOTSerialized> == true) {
// Explicitely ROOT serialize a snapshot of object.
// An object wrapped into type `ROOTSerialized` is explicitely marked to be ROOT serialized
Expand Down
13 changes: 10 additions & 3 deletions Framework/Core/include/Framework/DataRefUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,15 @@ struct DataRefUtils {
throw runtime_error("Attempt to extract a TMessage from non-ROOT serialised message");
}

typename RSS::FairTMessage ftm(const_cast<char*>(ref.payload), payloadSize);
auto* storedClass = ftm.GetClass();
typename RSS::FairInputTBuffer ftm(const_cast<char*>(ref.payload), payloadSize);
auto* requestedClass = RSS::TClass::GetClass(typeid(T));
ftm.InitMap();
auto* storedClass = ftm.ReadClass();
// should always have the class description if has_root_dictionary is true
assert(requestedClass != nullptr);

ftm.SetBufferOffset(0);
ftm.ResetMap();
auto* object = ftm.ReadObjectAny(storedClass);
if (object == nullptr) {
throw runtime_error_f("Failed to read object with name %s from message using ROOT serialization.",
Expand Down Expand Up @@ -146,7 +149,11 @@ struct DataRefUtils {
throw runtime_error("ROOT serialization not supported, dictionary not found for data type");
}

typename RSS::FairTMessage ftm(const_cast<char*>(ref.payload), payloadSize);
typename RSS::FairInputTBuffer ftm(const_cast<char*>(ref.payload), payloadSize);
ftm.InitMap();
auto *classInfo = ftm.ReadClass();
ftm.SetBufferOffset(0);
ftm.ResetMap();
result.reset(static_cast<wrapped*>(ftm.ReadObjectAny(cl)));
if (result.get() == nullptr) {
throw runtime_error_f("Unable to extract class %s", cl == nullptr ? "<name not available>" : cl->GetName());
Expand Down
3 changes: 3 additions & 0 deletions Framework/Core/include/Framework/RootMessageContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ class RootSerializedObject : public MessageContext::ContextObject
fair::mq::Parts finalize() final
{
assert(mParts.Size() == 1);
if (mPayloadMsg->GetSize() < sizeof(char*)) {
mPayloadMsg->Rebuild(4096, {64});
}
TMessageSerializer::Serialize(*mPayloadMsg, mObject.get(), nullptr);
mParts.AddPart(std::move(mPayloadMsg));
return ContextObject::finalize();
Expand Down
3 changes: 2 additions & 1 deletion Framework/Core/include/Framework/RootSerializationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ namespace o2::framework
/// compiler.
struct RootSerializationSupport {
using TClass = ::TClass;
using FairTMessage = o2::framework::FairTMessage;
using FairInputTBuffer = o2::framework::FairInputTBuffer;
using FairOutputBuffer = o2::framework::FairOutputTBuffer;
using TObject = ::TObject;
};

Expand Down
130 changes: 59 additions & 71 deletions Framework/Core/include/Framework/TMessageSerializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
#include "Framework/RuntimeError.h"

#include <TList.h>
#include <TMessage.h>
#include <TBufferFile.h>
#include <TObjArray.h>
#include <TStreamerInfo.h>
#include <gsl/util>
#include <gsl/span>
#include <gsl/narrow>
Expand All @@ -28,67 +27,76 @@

namespace o2::framework
{
class FairTMessage;
class FairOutputTBuffer;
class FairInputTBuffer;

// utilities to produce a span over a byte buffer held by various message types
// this is to avoid littering code with casts and conversions (span has a signed index type(!))
gsl::span<std::byte> as_span(const FairTMessage& msg);
gsl::span<std::byte> as_span(const FairInputTBuffer& msg);
gsl::span<std::byte> as_span(const FairOutputTBuffer& msg);
gsl::span<std::byte> as_span(const fair::mq::Message& msg);

class FairTMessage : public TMessage
// A TBufferFile which we can use to serialise data to a FairMQ message.
class FairOutputTBuffer : public TBufferFile
{
public:
using TMessage::TMessage;
FairTMessage() : TMessage(kMESS_OBJECT) {}
FairTMessage(void* buf, Int_t len) : TMessage(buf, len) { ResetBit(kIsOwner); }
FairTMessage(gsl::span<std::byte> buf) : TMessage(buf.data(), buf.size()) { ResetBit(kIsOwner); }
// This is to serialise data to FairMQ. We embed the pointer to the message
// in the data itself, so that we can use it to reallocate the message if needed.
// The FairMQ message retains ownership of the data.
// When deserialising the root object, keep in mind one needs to skip the 8 bytes
// for the pointer.
FairOutputTBuffer(fair::mq::Message& msg)
: TBufferFile(TBuffer::kWrite, msg.GetSize() - sizeof(char*), embedInItself(msg), false, fairMQrealloc)
{
}
// Helper function to keep track of the FairMQ message that holds the data
// in the data itself. We can use this to make sure the message can be reallocated
// even if we simply have a pointer to the data. Hopefully ROOT will not play dirty
// with us.
void* embedInItself(fair::mq::Message& msg);
// helper function to clean up the object holding the data after it is transported.
static void free(void* /*data*/, void* hint);
static char* fairMQrealloc(char* oldData, size_t newSize, size_t oldSize);
};

struct TMessageSerializer {
using StreamerList = std::vector<TVirtualStreamerInfo*>;
using CompressionLevel = int;
class FairInputTBuffer : public TBufferFile
{
public:
// This is to serialise data to FairMQ. The provided message is expeted to have 8 bytes
// of overhead, where the source embedded the pointer for the reallocation.
// Notice this will break if the sender and receiver are not using the same
// size for a pointer.
FairInputTBuffer(char * data, size_t size)
: TBufferFile(TBuffer::kRead, size-sizeof(char*), data + sizeof(char*), false, nullptr)
{
}
};

static void Serialize(fair::mq::Message& msg, const TObject* input,
CompressionLevel compressionLevel = -1);
struct TMessageSerializer {
static void Serialize(fair::mq::Message& msg, const TObject* input);

template <typename T>
static void Serialize(fair::mq::Message& msg, const T* input, const TClass* cl, //
CompressionLevel compressionLevel = -1);
static void Serialize(fair::mq::Message& msg, const T* input, const TClass* cl);

template <typename T = TObject>
static void Deserialize(const fair::mq::Message& msg, std::unique_ptr<T>& output);

static void serialize(FairTMessage& msg, const TObject* input,
CompressionLevel compressionLevel = -1);
static void serialize(o2::framework::FairOutputTBuffer& msg, const TObject* input);

template <typename T>
static void serialize(FairTMessage& msg, const T* input, //
const TClass* cl,
CompressionLevel compressionLevel = -1);
static void serialize(o2::framework::FairOutputTBuffer& msg, const T* input, const TClass* cl);

template <typename T = TObject>
static std::unique_ptr<T> deserialize(gsl::span<std::byte> buffer);
template <typename T = TObject>
static inline std::unique_ptr<T> deserialize(std::byte* buffer, size_t size);
static inline std::unique_ptr<T> deserialize(FairInputTBuffer & buffer);
};

inline void TMessageSerializer::serialize(FairTMessage& tm, const TObject* input,
CompressionLevel compressionLevel)
inline void TMessageSerializer::serialize(FairOutputTBuffer& tm, const TObject* input)
{
return serialize(tm, input, nullptr, compressionLevel);
return serialize(tm, input, nullptr);
}

template <typename T>
inline void TMessageSerializer::serialize(FairTMessage& tm, const T* input, //
const TClass* cl, CompressionLevel compressionLevel)
inline void TMessageSerializer::serialize(FairOutputTBuffer& tm, const T* input, const TClass* cl)
{
if (compressionLevel >= 0) {
// if negative, skip to use ROOT default
tm.SetCompressionLevel(compressionLevel);
}

// TODO: check what WriateObject and WriteObjectAny are doing
if (cl == nullptr) {
tm.WriteObject(input);
Expand All @@ -98,7 +106,7 @@ inline void TMessageSerializer::serialize(FairTMessage& tm, const T* input, //
}

template <typename T>
inline std::unique_ptr<T> TMessageSerializer::deserialize(gsl::span<std::byte> buffer)
inline std::unique_ptr<T> TMessageSerializer::deserialize(FairInputTBuffer & buffer)
{
TClass* tgtClass = TClass::GetClass(typeid(T));
if (tgtClass == nullptr) {
Expand All @@ -107,61 +115,41 @@ inline std::unique_ptr<T> TMessageSerializer::deserialize(gsl::span<std::byte> b
// FIXME: we need to add consistency check for buffer data to be serialized
// at the moment, TMessage might simply crash if an invalid or inconsistent
// buffer is provided
FairTMessage tm(buffer);
TClass* serializedClass = tm.GetClass();
buffer.InitMap();
TClass* serializedClass = buffer.ReadClass();
buffer.SetBufferOffset(0);
buffer.ResetMap();
if (serializedClass == nullptr) {
throw runtime_error_f("can not read class info from buffer");
}
if (tgtClass != serializedClass && serializedClass->GetBaseClass(tgtClass) == nullptr) {
throw runtime_error_f("can not convert serialized class %s into target class %s",
tm.GetClass()->GetName(),
serializedClass->GetName(),
tgtClass->GetName());
}
return std::unique_ptr<T>(reinterpret_cast<T*>(tm.ReadObjectAny(serializedClass)));
return std::unique_ptr<T>(reinterpret_cast<T*>(buffer.ReadObjectAny(serializedClass)));
}

template <typename T>
inline std::unique_ptr<T> TMessageSerializer::deserialize(std::byte* buffer, size_t size)
inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const TObject* input)
{
return deserialize<T>(gsl::span<std::byte>(buffer, gsl::narrow<gsl::span<std::byte>::size_type>(size)));
}

inline void FairTMessage::free(void* /*data*/, void* hint)
{
std::default_delete<FairTMessage> deleter;
deleter(static_cast<FairTMessage*>(hint));
}

inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const TObject* input,
TMessageSerializer::CompressionLevel compressionLevel)
{
std::unique_ptr<FairTMessage> tm = std::make_unique<FairTMessage>(kMESS_OBJECT);

serialize(*tm, input, input->Class(), compressionLevel);

msg.Rebuild(tm->Buffer(), tm->BufferSize(), FairTMessage::free, tm.get());
tm.release();
FairOutputTBuffer output(msg);
serialize(output, input, input->Class());
}

template <typename T>
inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const T* input, //
const TClass* cl, //
TMessageSerializer::CompressionLevel compressionLevel)
inline void TMessageSerializer::Serialize(fair::mq::Message& msg, const T* input, const TClass* cl)
{
std::unique_ptr<FairTMessage> tm = std::make_unique<FairTMessage>(kMESS_OBJECT);

serialize(*tm, input, cl, compressionLevel);

msg.Rebuild(tm->Buffer(), tm->BufferSize(), FairTMessage::free, tm.get());
tm.release();
FairOutputTBuffer output(msg);
serialize(output, input, cl);
}

template <typename T>
inline void TMessageSerializer::Deserialize(const fair::mq::Message& msg, std::unique_ptr<T>& output)
{
// we know the message will not be modified by this,
// so const_cast should be OK here(IMHO).
output = deserialize(as_span(msg));
FairInputTBuffer input(static_cast<char*>(msg.GetData()), static_cast<int>(msg.GetSize()));
output = deserialize(input);
}

// gsl::narrow is used to do a runtime narrowing check, this might be a bit paranoid,
Expand All @@ -171,7 +159,7 @@ inline gsl::span<std::byte> as_span(const fair::mq::Message& msg)
return gsl::span<std::byte>{static_cast<std::byte*>(msg.GetData()), gsl::narrow<gsl::span<std::byte>::size_type>(msg.GetSize())};
}

inline gsl::span<std::byte> as_span(const FairTMessage& msg)
inline gsl::span<std::byte> as_span(const FairInputTBuffer& msg)
{
return gsl::span<std::byte>{reinterpret_cast<std::byte*>(msg.Buffer()),
gsl::narrow<gsl::span<std::byte>::size_type>(msg.BufferSize())};
Expand Down
Loading

0 comments on commit 8e51310

Please sign in to comment.