1111#include " TMVA/RModel.hxx"
1212#include " TMVA/SOFIE_common.hxx"
1313
14- namespace TMVA {
15- namespace Experimental {
16- namespace SOFIE {
14+ namespace TMVA ::Experimental::SOFIE {
1715
1816namespace {
1917const std::string SP = " " ;
@@ -356,7 +354,7 @@ std::string RModel::AllocateIntermediateMemory(std::span<const std::string_view>
356354 std::string typeName = ConvertTypeToString (GetTensorType (name));
357355 code << " \n // Allocating memory for intermediate tensor " << name << " with size " << size << " bytes" ;
358356 code << " \n "
359- << typeName << " * tensor_ " << name << " = reinterpret_cast<" << typeName
357+ << typeName << " * " << AddTensorMember ( name) << " = reinterpret_cast<" << typeName
360358 << " *>(fIntermediateMemoryPool.data() + " << location << " );\n " ;
361359 };
362360
@@ -546,6 +544,8 @@ void RModel::Initialize(const std::map<std::string, size_t> & inputParams, bool
546544 }
547545 fIntermediateTensorInfos .clear ();
548546 fDynamicTensorInfos .clear ();
547+ fDataMembers .clear ();
548+ fPointerMemberNames .clear ();
549549
550550
551551 // loop on inputs and see if shape can be full specified
@@ -692,7 +692,8 @@ void RModel::InitializeSubGraph(std::shared_ptr<RModel> graph) {
692692// Function to generate the code for declaring and initializing constant tensors
693693// This is for tensors which are not part of weight files and can be created from the Constant operator
694694template <typename T>
695- std::string GenerateConstantTensorCode (const std::pair<std::string, InitializedTensor> &t)
695+ std::string GenerateConstantTensorCode (const std::pair<std::string, InitializedTensor> &t,
696+ std::function<std::string(std::string const &)> addTensorMember)
696697{
697698 std::stringstream strs;
698699 std::string type = ConvertTypeToString (t.second .type ());
@@ -714,15 +715,15 @@ std::string GenerateConstantTensorCode(const std::pair<std::string, InitializedT
714715 } while (sameData && idx < length);
715716 }
716717 if (allocateOnStack) {
717- strs << type << " tensor_ " << t.first << " [" << length << " ] = " << ConvertValuesToString (length, data) << " ;\n " ;
718+ strs << type << " " << addTensorMember ( t.first ) << " [" << length << " ] = " << ConvertValuesToString (length, data) << " ;\n " ;
718719 } else {
719720 strs << " std::vector<" << type << " > fTensor_" << t.first << " = " ;
720721 if (sameData)
721722 strs << " std::vector<" << type << " >(" << length << " , " << ConvertValToString (data[0 ]) << " );\n " ;
722723 else {
723724 strs << ConvertValuesToString (length, data) << " ;\n " ;
724725 }
725- strs << type << " * tensor_ " + t.first + " = fTensor_" + t.first + " .data();\n " ;
726+ strs << type << " * " + addTensorMember ( t.first ) + " = fTensor_" + t.first + " .data();\n " ;
726727 }
727728 return strs.str ();
728729}
@@ -736,11 +737,12 @@ void RModel::GenerateInitializedTensorInfo()
736737 for (auto &i : fInitializedTensors ) {
737738 if (i.second .IsNotWritable ()) continue ;
738739 if (!fUseWeightFile || i.second .IsConstantTensor () || !i.second .IsWeightTensor () ) {
740+ auto addTensorMember = [this ](std::string const &name) -> std::string { return this ->AddTensorMember (name); };
739741 if (i.second .type () == ETensorType::FLOAT) {
740- fGC += GenerateConstantTensorCode<float >(i);
742+ fGC += GenerateConstantTensorCode<float >(i, addTensorMember );
741743 fConstantTensorSize += ConvertShapeToLength (i.second .shape ()) * 4 ;
742744 } else if (i.second .type () == ETensorType::INT64) {
743- fGC += GenerateConstantTensorCode<int64_t >(i);
745+ fGC += GenerateConstantTensorCode<int64_t >(i, addTensorMember );
744746 fConstantTensorSize += ConvertShapeToLength (i.second .shape ()) * 8 ;
745747 }
746748
@@ -749,7 +751,7 @@ void RModel::GenerateInitializedTensorInfo()
749751 size_t length = ConvertShapeToLength (i.second .shape ());
750752 if (i.second .type () == ETensorType::FLOAT) {
751753 fGC += " std::vector<float> fTensor_" + i.first + " = std::vector<float>(" + std::to_string (length) + " );\n " ;
752- fGC += " float * tensor_ " + i.first + " = fTensor_" + i.first + " .data();\n " ;
754+ fGC += " float * " + AddTensorMember ( i.first ) + " = fTensor_" + i.first + " .data();\n " ;
753755 fWeightsTensorSize += ConvertShapeToLength (i.second .shape ()) * 4 ;
754756 }
755757 }
@@ -774,7 +776,7 @@ void RModel::GenerateIntermediateTensorInfo() {
774776 bool is_alias = (IsAliasTensor (i.first ));
775777 if (i.second .type == ETensorType::BOOL && !is_alias) {
776778 tensor_declaration_block += " std::vector<std::uint8_t> fTensor_" + i.first + " = std::vector<std::uint8_t>(" + std::to_string (ConvertShapeToLength (i.second .shape )) + " );\n " ;
777- tensor_declaration_block += " std::uint8_t * tensor_ " + i.first + " = fTensor_" + i.first + " .data();\n " ;
779+ tensor_declaration_block += " std::uint8_t * " + AddTensorMember ( i.first ) + " = fTensor_" + i.first + " .data();\n " ;
778780 continue ;
779781 }
780782 bool is_extended = (fOptimizationLevel == OptimizationLevel::kExtended );
@@ -788,22 +790,22 @@ void RModel::GenerateIntermediateTensorInfo() {
788790
789791 if (i.second .type == ETensorType::FLOAT) {
790792 tensor_declaration_block += " std::vector<float> fTensor_" + i.first + " = std::vector<float>(" + std::to_string (length) + " );\n " ;
791- tensor_declaration_block += " float * tensor_ " + i.first + " = fTensor_" + i.first + " .data();\n " ;
793+ tensor_declaration_block += " float * " + AddTensorMember ( i.first ) + " = fTensor_" + i.first + " .data();\n " ;
792794 fOtherTensorSize += 4 * length;
793795 }
794796 else if (i.second .type == ETensorType::DOUBLE) {
795797 tensor_declaration_block += " std::vector<double> fTensor_" + i.first + " = std::vector<double>(" + std::to_string (length) + " );\n " ;
796- tensor_declaration_block += " double * tensor_ " + i.first + " = fTensor_" + i.first + " .data();\n " ;
798+ tensor_declaration_block += " double * " + AddTensorMember ( i.first ) + " = fTensor_" + i.first + " .data();\n " ;
797799 fOtherTensorSize += 8 * length;
798800 }
799801 else if (i.second .type == ETensorType::INT64) {
800802 tensor_declaration_block += " std::vector<int64_t> fTensor_" + i.first + " = std::vector<int64_t>(" + std::to_string (length) + " );\n " ;
801- tensor_declaration_block += " int64_t * tensor_ " + i.first + " = fTensor_" + i.first + " .data();\n " ;
803+ tensor_declaration_block += " int64_t * " + AddTensorMember ( i.first ) + " = fTensor_" + i.first + " .data();\n " ;
802804 fOtherTensorSize += 8 * length;
803805 }
804806 }
805807 if (is_alias) {
806- tensor_declaration_block += ConvertTypeToString (i.second .type ) + " * tensor_ " + i.first + " = nullptr;\n " ;
808+ tensor_declaration_block += ConvertTypeToString (i.second .type ) + " * " + AddTensorMember ( i.first ) + " = nullptr;\n " ;
807809 }
808810
809811 }
@@ -816,7 +818,7 @@ void RModel::GenerateIntermediateTensorInfo() {
816818 if (!fDynamicTensorInfos .empty ()) {
817819 fGC += " //--- declare the dynamic tensors\n " ;
818820 for (auto &i : fDynamicTensorInfos ) {
819- fGC += ConvertTypeToString (i.second .type ) + " * tensor_ " + i.first + " = nullptr;\n " ;
821+ fGC += ConvertTypeToString (i.second .type ) + " * " + AddTensorMember ( i.first ) + " = nullptr;\n " ;
820822 }
821823 fGC += " //--- dynamic tensors pool\n " ;
822824 fGC += " std::vector<char> fDynamicMemoryPool;\n " ;
@@ -995,9 +997,9 @@ void RModel::GenerateOutput()
995997 if (!doInferArgs.empty ())
996998 doInferArgs += " ," ;
997999 for (std::string const &name : fOutputTensorNames ) {
998- bool isIntermediate = fIntermediateTensorInfos .count (name) > 0 ;
999- std::string n = isIntermediate ? std::to_string (ConvertShapeToLength (GetTensorShape (name)))
1000- : ConvertDimShapeToLength (GetDynamicTensorShape (name));
1000+ bool isDynamic = fDynamicTensorInfos .count (name) > 0 ;
1001+ std::string n = !isDynamic ? std::to_string (ConvertShapeToLength (GetTensorShape (name)))
1002+ : ConvertDimShapeToLength (GetDynamicTensorShape (name));
10011003 fGC += SP + " std::vector<" + typeForOutput (GetTensorType (name)) + " > output_tensor_" + name + " (" + n + " );\n " ;
10021004 doInferArgs += " output_tensor_" + name + " .data()," ;
10031005 }
@@ -1060,7 +1062,7 @@ void RModel::GenerateSessionCode()
10601062 doInferSignature.back () = ' ' ;
10611063
10621064 if (fUseSession && !fIsGNNComponent ) {
1063- doInferSignature = sessionName + " const * session, " + doInferSignature;
1065+ doInferSignature = sessionName + " * session, " + doInferSignature;
10641066 }
10651067
10661068 doInferSignature = " void doInfer(" + doInferSignature + " )" ;
@@ -1205,22 +1207,14 @@ void RModel::GenerateSessionCode()
12051207 fGC += " \n " ;
12061208
12071209 if (fUseSession && !fIsGNNComponent ) {
1208- fGC += " auto const& sess = session[0];\n " ;
1209- std::vector<std::string> names;
1210- for (auto const & it: fInitializedTensors ) {
1211- names.push_back (it.first );
1210+ fGC += " auto & sess = session[0];\n " ;
1211+ for (auto const & name: fDataMembers ) {
1212+ fGC += " auto & " + name + " = sess." + name + " ;\n " ;
12121213 }
1213- for (auto const & it: fIntermediateTensorInfos ) {
1214- names.push_back (it.first );
1215- }
1216- std::vector<std::string> added;
1217- for (auto const & name : names) {
1214+ for (auto const & name: fPointerMemberNames ) {
12181215 auto found = std::find (fOutputTensorNames .begin (), fOutputTensorNames .end (), name);
1219- auto found2 = std::find (added.begin (), added.end (), name);
1220- // Output tensors are passed directly via the function call
1221- if (found == fOutputTensorNames .end () && found2 == added.end ()) {
1216+ if (found == fOutputTensorNames .end ()) {
12221217 fGC += " auto & tensor_" + name + " = sess.tensor_" + name + " ;\n " ;
1223- added.push_back (name);
12241218 }
12251219 }
12261220 fGC += " \n " ;
@@ -1238,6 +1232,17 @@ void RModel::GenerateSessionCode()
12381232 fGC += (fOperators [op_idx]->Generate (std::to_string (op_idx)));
12391233 }
12401234
1235+ if (fUseSession && !fIsGNNComponent ) {
1236+ for (auto const & name: fPointerMemberNames ) {
1237+ auto found = std::find (fOutputTensorNames .begin (), fOutputTensorNames .end (), name);
1238+ if (IsConstantTensor (name) && found != fOutputTensorNames .end ()) {
1239+ std::string t = " sess.tensor_" + name;
1240+ fGC += " std::copy(std::begin(" + t + " ), std::end(" + t + " ), tensor_" + name + " );\n " ;
1241+ }
1242+ }
1243+ fGC += " \n " ;
1244+ }
1245+
12411246 fGC += " }\n " ;
12421247}
12431248
@@ -1657,6 +1662,4 @@ void RModel::Streamer(TBuffer &R__b) {
16571662 }
16581663}
16591664
1660- }// SOFIE
1661- }// Experimental
1662- }// TMVA
1665+ } // namespace SOFIE::Experimental::TMVA
0 commit comments