Skip to content

Commit 9ea0f01

Browse files
committed
Continue
1 parent 65b184a commit 9ea0f01

File tree

7 files changed

+121
-45
lines changed

7 files changed

+121
-45
lines changed

tmva/sofie/inc/TMVA/RModel.hxx

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@ private:
3434
std::vector<std::string> fDimShapeNames; // parameter names used to define the shapes
3535
std::vector<std::string> fOutputTensorNames;
3636
std::vector<std::string> fInputTensorNames; // input tensor names using ONNX order
37+
std::vector<std::string> fDataMembers;
38+
std::vector<std::string> fPointerMemberNames;
3739

38-
40+
inline std::string AddTensorMember(std::string const &name) {
41+
fPointerMemberNames.push_back(name);
42+
return "tensor_" + name;
43+
}
3944

4045
std::vector<std::unique_ptr<ROperator>> fOperators;
4146

@@ -63,6 +68,11 @@ public:
6368
std::vector<Dim> GetDimTensorShape(const std::string & name) const;
6469
std::vector<Dim> GetDynamicTensorShape(const std::string & name) const ;
6570

71+
inline std::string AddDataMember(std::string const &name) {
72+
fDataMembers.push_back(name);
73+
return name;
74+
}
75+
6676
// get the values for the tensor representing a shape
6777
const std::vector<Dim> & GetShapeTensorValues(const std::string & tensor_name) const;
6878

tmva/sofie/inc/TMVA/ROperator_LSTM.hxx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ template <typename T> class ROperator_LSTM final : public ROperator {
5757

5858
std::string fType; ///< Type of the tensors
5959

60+
int fCounter = 0;
61+
6062
public:
6163
/*! Default constructor of ROperator_LSTM */
6264
ROperator_LSTM() {}

tmva/sofie/inc/TMVA/ROperator_LSTM.icc

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@ auto ROperator_LSTM<T>::ShapeInference(std::vector<std::vector<size_t>> input)
3636
}
3737
}
3838

39+
namespace Internal {
40+
41+
inline int &lstmCounter() {
42+
static int counter = 0;
43+
return counter;
44+
}
45+
46+
}
47+
3948
template<typename T>
4049
auto ROperator_LSTM<T>::Initialize(RModel& model)
4150
-> void {
@@ -230,13 +239,37 @@ auto ROperator_LSTM<T>::Initialize(RModel& model)
230239
fAttrActivations = {"Sigmoid", "Tanh", "Tanh"};
231240
}
232241
}
242+
243+
// Register session data members
244+
fCounter = Internal::lstmCounter()++;
245+
std::string opName = "op_lstm" + std::to_string(fCounter);
246+
if (fAttrLayout != 0) {
247+
model.AddDataMember("fVec_" + opName + "_input");
248+
model.AddDataMember("fVec_" + opName + "_initial_hidden_state");
249+
model.AddDataMember("fVec_" + opName + "_initial_cell_state");
250+
}
251+
model.AddDataMember("fVec_" + opName + "_ff_input_gate");
252+
model.AddDataMember("fVec_" + opName + "_ff_output_gate");
253+
model.AddDataMember("fVec_" + opName + "_ff_cell_gate");
254+
if (fAttrInputForget == 0)
255+
model.AddDataMember("fVec_" + opName + "_ff_forget_gate");
256+
model.AddDataMember("fVec_" + opName + "_input_gate");
257+
model.AddDataMember("fVec_" + opName + "_output_gate");
258+
model.AddDataMember("fVec_" + opName + "_cell_gate");
259+
if (fAttrInputForget == 0)
260+
model.AddDataMember("fVec_" + opName + "_forget_gate");
261+
model.AddDataMember("fVec_" + opName + "_cell_state");
262+
model.AddDataMember("fVec_" + opName + "_new_cell_state");
263+
if (fAttrLayout != 0 || fNY.empty()) {
264+
model.AddDataMember("fVec_" + opName + "_hidden_state");
265+
}
233266
}
234267

235268
// generate code for Session data members (e.g. internal vectors)
236269
template <typename T>
237-
std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string opName)
270+
std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string /*opName*/)
238271
{
239-
opName = "op_" + opName;
272+
std::string opName = "op_lstm" + std::to_string(fCounter);
240273
std::stringstream out;
241274

242275
size_t num_directions = fShapeW[0];
@@ -280,9 +313,10 @@ std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string opName)
280313
}
281314

282315
template<typename T>
283-
auto ROperator_LSTM<T>::Generate(std::string OpName)
316+
auto ROperator_LSTM<T>::Generate(std::string /*OpName*/)
284317
-> std::string {
285-
OpName = "op_" + OpName;
318+
//OpName = "op_" + OpName;
319+
std::string OpName = "op_lstm" + std::to_string(fCounter);
286320
std::stringstream out;
287321

288322
size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];

tmva/sofie/inc/TMVA/ROperator_RNN.hxx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ template <typename T> class ROperator_RNN final : public ROperator {
4949

5050
std::string fType; ///< Type of the tensors
5151

52+
int fCounter = 0;
53+
5254
public:
5355
/*! Default constructor of ROperator_RNN */
5456
ROperator_RNN() {}

tmva/sofie/inc/TMVA/ROperator_RNN.icc

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ auto ROperator_RNN<T>::ShapeInference(std::vector<std::vector<size_t>> input)
3434
}
3535
}
3636

37+
namespace Internal {
38+
39+
inline int &rnnCounter() {
40+
static int counter = 0;
41+
return counter;
42+
}
43+
44+
}
45+
3746
template <typename T>
3847
auto ROperator_RNN<T>::Initialize(RModel& model)
3948
-> void {
@@ -183,13 +192,26 @@ auto ROperator_RNN<T>::Initialize(RModel& model)
183192
}
184193
// Add needed standard library headers
185194
model.AddNeededStdLib("cmath");
195+
196+
// Register session data members
197+
fCounter = Internal::rnnCounter()++;
198+
std::string opName = "op_rnn" + std::to_string(fCounter);
199+
if (fAttrLayout != 0) {
200+
model.AddDataMember("fVec_" + opName + "_input");
201+
model.AddDataMember("fVec_" + opName + "_initial_hidden_state");
202+
}
203+
model.AddDataMember("fVec_" + opName + "_feedforward");
204+
205+
if (fAttrLayout != 0 || fNY.empty()) {
206+
model.AddDataMember("fVec_" + opName + "_hidden_state");
207+
}
186208
}
187209

188210
// generate code for Session data members (e.g. internal vectors)
189211
template <typename T>
190-
std::string ROperator_RNN<T>::GenerateSessionMembersCode(std::string opName)
212+
std::string ROperator_RNN<T>::GenerateSessionMembersCode(std::string /*opName*/)
191213
{
192-
opName = "op_" + opName;
214+
std::string opName = "op_rnn" + std::to_string(fCounter);
193215
std::stringstream out;
194216

195217
size_t num_directions = fShapeW[0];
@@ -218,9 +240,10 @@ std::string ROperator_RNN<T>::GenerateSessionMembersCode(std::string opName)
218240

219241
//////////////////////////////////////////////////////////////////////////////////////////////////
220242
template<typename T>
221-
auto ROperator_RNN<T>::Generate(std::string OpName)
243+
auto ROperator_RNN<T>::Generate(std::string /*OpName*/)
222244
-> std::string {
223-
OpName = "op_" + OpName;
245+
//OpName = "op_" + OpName;
246+
std::string OpName = "op_rnn" + std::to_string(fCounter);
224247
std::stringstream out;
225248

226249
size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];

tmva/sofie/inc/TMVA/ROperator_Random.hxx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ public:
8888
for (auto & p : fParams)
8989
std::cout << p.first << " : " << p.second << std::endl;
9090
}
91+
92+
model.AddDataMember("fRndmEngine");
9193
}
9294
// generate declaration code for random number generators
9395
std::string GenerateDeclCode() override {

tmva/sofie/src/RModel.cxx

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
#include "TMVA/RModel.hxx"
1212
#include "TMVA/SOFIE_common.hxx"
1313

14-
namespace TMVA {
15-
namespace Experimental {
16-
namespace SOFIE {
14+
namespace TMVA::Experimental::SOFIE {
1715

1816
namespace {
1917
const std::string SP = " ";
@@ -356,7 +354,7 @@ std::string RModel::AllocateIntermediateMemory(std::span<const std::string_view>
356354
std::string typeName = ConvertTypeToString(GetTensorType(name));
357355
code << "\n // Allocating memory for intermediate tensor " << name << " with size " << size << " bytes";
358356
code << "\n"
359-
<< typeName << "* tensor_" << name << " = reinterpret_cast<" << typeName
357+
<< typeName << "* " << AddTensorMember(name) << " = reinterpret_cast<" << typeName
360358
<< "*>(fIntermediateMemoryPool.data() + " << location << ");\n";
361359
};
362360

@@ -546,6 +544,8 @@ void RModel::Initialize(const std::map<std::string, size_t> & inputParams, bool
546544
}
547545
fIntermediateTensorInfos.clear();
548546
fDynamicTensorInfos.clear();
547+
fDataMembers.clear();
548+
fPointerMemberNames.clear();
549549

550550

551551
// loop on inputs and see if shape can be full specified
@@ -692,7 +692,8 @@ void RModel::InitializeSubGraph(std::shared_ptr<RModel> graph) {
692692
// Function to generate the code for declaring and initializing constant tensors
693693
// This is for tensors which are not part of weight files and can be created from the Constant operator
694694
template <typename T>
695-
std::string GenerateConstantTensorCode(const std::pair<std::string, InitializedTensor> &t)
695+
std::string GenerateConstantTensorCode(const std::pair<std::string, InitializedTensor> &t,
696+
std::function<std::string(std::string const &)> addTensorMember)
696697
{
697698
std::stringstream strs;
698699
std::string type = ConvertTypeToString(t.second.type());
@@ -714,15 +715,15 @@ std::string GenerateConstantTensorCode(const std::pair<std::string, InitializedT
714715
} while (sameData && idx < length);
715716
}
716717
if (allocateOnStack) {
717-
strs << type << " tensor_" << t.first << "[" << length << "] = " << ConvertValuesToString(length, data) << ";\n";
718+
strs << type << " " << addTensorMember(t.first) << "[" << length << "] = " << ConvertValuesToString(length, data) << ";\n";
718719
} else {
719720
strs << "std::vector<" << type << "> fTensor_" << t.first << " = ";
720721
if (sameData)
721722
strs << "std::vector<" << type << ">(" << length << ", " << ConvertValToString(data[0]) << ");\n";
722723
else {
723724
strs << ConvertValuesToString(length, data) << ";\n";
724725
}
725-
strs << type << " * tensor_" + t.first + " = fTensor_" + t.first + ".data();\n";
726+
strs << type << " * " + addTensorMember(t.first) + " = fTensor_" + t.first + ".data();\n";
726727
}
727728
return strs.str();
728729
}
@@ -736,11 +737,12 @@ void RModel::GenerateInitializedTensorInfo()
736737
for (auto &i : fInitializedTensors) {
737738
if (i.second.IsNotWritable()) continue;
738739
if (!fUseWeightFile || i.second.IsConstantTensor() || !i.second.IsWeightTensor() ) {
740+
auto addTensorMember = [this](std::string const &name) -> std::string { return this->AddTensorMember(name); };
739741
if (i.second.type() == ETensorType::FLOAT) {
740-
fGC += GenerateConstantTensorCode<float>(i);
742+
fGC += GenerateConstantTensorCode<float>(i, addTensorMember);
741743
fConstantTensorSize += ConvertShapeToLength(i.second.shape()) * 4;
742744
} else if (i.second.type() == ETensorType::INT64) {
743-
fGC += GenerateConstantTensorCode<int64_t>(i);
745+
fGC += GenerateConstantTensorCode<int64_t>(i, addTensorMember);
744746
fConstantTensorSize += ConvertShapeToLength(i.second.shape()) * 8;
745747
}
746748

@@ -749,7 +751,7 @@ void RModel::GenerateInitializedTensorInfo()
749751
size_t length = ConvertShapeToLength(i.second.shape());
750752
if (i.second.type() == ETensorType::FLOAT) {
751753
fGC += "std::vector<float> fTensor_" + i.first + " = std::vector<float>(" + std::to_string(length) + ");\n";
752-
fGC += "float * tensor_" + i.first + " = fTensor_" + i.first + ".data();\n";
754+
fGC += "float * " + AddTensorMember(i.first) + " = fTensor_" + i.first + ".data();\n";
753755
fWeightsTensorSize += ConvertShapeToLength(i.second.shape()) * 4;
754756
}
755757
}
@@ -774,7 +776,7 @@ void RModel::GenerateIntermediateTensorInfo() {
774776
bool is_alias = (IsAliasTensor(i.first));
775777
if (i.second.type == ETensorType::BOOL && !is_alias) {
776778
tensor_declaration_block += "std::vector<std::uint8_t> fTensor_" + i.first + " = std::vector<std::uint8_t>(" + std::to_string(ConvertShapeToLength(i.second.shape)) + ");\n";
777-
tensor_declaration_block += "std::uint8_t * tensor_" + i.first + " = fTensor_" + i.first + ".data();\n";
779+
tensor_declaration_block += "std::uint8_t * " + AddTensorMember(i.first) + " = fTensor_" + i.first + ".data();\n";
778780
continue;
779781
}
780782
bool is_extended = (fOptimizationLevel == OptimizationLevel::kExtended);
@@ -788,22 +790,22 @@ void RModel::GenerateIntermediateTensorInfo() {
788790

789791
if (i.second.type == ETensorType::FLOAT) {
790792
tensor_declaration_block += "std::vector<float> fTensor_" + i.first + " = std::vector<float>(" + std::to_string(length) + ");\n";
791-
tensor_declaration_block += "float * tensor_" + i.first + " = fTensor_" + i.first + ".data();\n";
793+
tensor_declaration_block += "float * " + AddTensorMember(i.first) + " = fTensor_" + i.first + ".data();\n";
792794
fOtherTensorSize += 4 * length;
793795
}
794796
else if (i.second.type == ETensorType::DOUBLE) {
795797
tensor_declaration_block += "std::vector<double> fTensor_" + i.first + " = std::vector<double>(" + std::to_string(length) + ");\n";
796-
tensor_declaration_block += "double * tensor_" + i.first + " = fTensor_" + i.first + ".data();\n";
798+
tensor_declaration_block += "double * " + AddTensorMember(i.first) + " = fTensor_" + i.first + ".data();\n";
797799
fOtherTensorSize += 8 * length;
798800
}
799801
else if (i.second.type == ETensorType::INT64) {
800802
tensor_declaration_block += "std::vector<int64_t> fTensor_" + i.first + " = std::vector<int64_t>(" + std::to_string(length) + ");\n";
801-
tensor_declaration_block += "int64_t * tensor_" + i.first + " = fTensor_" + i.first + ".data();\n";
803+
tensor_declaration_block += "int64_t * " + AddTensorMember(i.first) + " = fTensor_" + i.first + ".data();\n";
802804
fOtherTensorSize += 8 * length;
803805
}
804806
}
805807
if (is_alias) {
806-
tensor_declaration_block += ConvertTypeToString(i.second.type) + " * tensor_" + i.first + " = nullptr;\n";
808+
tensor_declaration_block += ConvertTypeToString(i.second.type) + " * " + AddTensorMember(i.first) + " = nullptr;\n";
807809
}
808810

809811
}
@@ -816,7 +818,7 @@ void RModel::GenerateIntermediateTensorInfo() {
816818
if (!fDynamicTensorInfos.empty()) {
817819
fGC += "//--- declare the dynamic tensors\n";
818820
for (auto &i : fDynamicTensorInfos) {
819-
fGC += ConvertTypeToString(i.second.type) + " * tensor_" + i.first + " = nullptr;\n";
821+
fGC += ConvertTypeToString(i.second.type) + " * " + AddTensorMember(i.first) + " = nullptr;\n";
820822
}
821823
fGC += "//--- dynamic tensors pool\n";
822824
fGC += "std::vector<char> fDynamicMemoryPool;\n";
@@ -995,9 +997,9 @@ void RModel::GenerateOutput()
995997
if (!doInferArgs.empty())
996998
doInferArgs += ",";
997999
for (std::string const &name : fOutputTensorNames) {
998-
bool isIntermediate = fIntermediateTensorInfos.count(name) > 0;
999-
std::string n = isIntermediate ? std::to_string(ConvertShapeToLength(GetTensorShape(name)))
1000-
: ConvertDimShapeToLength(GetDynamicTensorShape(name));
1000+
bool isDynamic = fDynamicTensorInfos.count(name) > 0;
1001+
std::string n = !isDynamic ? std::to_string(ConvertShapeToLength(GetTensorShape(name)))
1002+
: ConvertDimShapeToLength(GetDynamicTensorShape(name));
10011003
fGC += SP + "std::vector<" + typeForOutput(GetTensorType(name)) + " > output_tensor_" + name + "(" + n + ");\n";
10021004
doInferArgs += " output_tensor_" + name + ".data(),";
10031005
}
@@ -1060,7 +1062,7 @@ void RModel::GenerateSessionCode()
10601062
doInferSignature.back() = ' ';
10611063

10621064
if (fUseSession && !fIsGNNComponent) {
1063-
doInferSignature = sessionName + " const* session, " + doInferSignature;
1065+
doInferSignature = sessionName + " * session, " + doInferSignature;
10641066
}
10651067

10661068
doInferSignature = "void doInfer(" + doInferSignature + ")";
@@ -1205,22 +1207,14 @@ void RModel::GenerateSessionCode()
12051207
fGC += "\n";
12061208

12071209
if (fUseSession && !fIsGNNComponent) {
1208-
fGC += " auto const& sess = session[0];\n";
1209-
std::vector<std::string> names;
1210-
for (auto const& it: fInitializedTensors) {
1211-
names.push_back(it.first);
1210+
fGC += " auto & sess = session[0];\n";
1211+
for (auto const& name: fDataMembers) {
1212+
fGC += " auto & " + name + " = sess." + name + ";\n";
12121213
}
1213-
for (auto const& it: fIntermediateTensorInfos) {
1214-
names.push_back(it.first);
1215-
}
1216-
std::vector<std::string> added;
1217-
for (auto const& name : names) {
1214+
for (auto const& name: fPointerMemberNames) {
12181215
auto found = std::find(fOutputTensorNames.begin(), fOutputTensorNames.end(), name);
1219-
auto found2 = std::find(added.begin(), added.end(), name);
1220-
// Output tensors are passed directly via the function call
1221-
if(found == fOutputTensorNames.end() && found2 == added.end()) {
1216+
if(found == fOutputTensorNames.end()) {
12221217
fGC += " auto & tensor_" + name + " = sess.tensor_" + name + ";\n";
1223-
added.push_back(name);
12241218
}
12251219
}
12261220
fGC += "\n";
@@ -1238,6 +1232,17 @@ void RModel::GenerateSessionCode()
12381232
fGC += (fOperators[op_idx]->Generate(std::to_string(op_idx)));
12391233
}
12401234

1235+
if (fUseSession && !fIsGNNComponent) {
1236+
for (auto const& name: fPointerMemberNames) {
1237+
auto found = std::find(fOutputTensorNames.begin(), fOutputTensorNames.end(), name);
1238+
if(IsConstantTensor(name) && found != fOutputTensorNames.end()) {
1239+
std::string t = "sess.tensor_" + name;
1240+
fGC += " std::copy(std::begin(" + t + "), std::end(" + t + "), tensor_" + name + ");\n";
1241+
}
1242+
}
1243+
fGC += "\n";
1244+
}
1245+
12411246
fGC += "}\n";
12421247
}
12431248

@@ -1657,6 +1662,4 @@ void RModel::Streamer(TBuffer &R__b) {
16571662
}
16581663
}
16591664

1660-
}//SOFIE
1661-
}//Experimental
1662-
}//TMVA
1665+
} // namespace SOFIE::Experimental::TMVA

0 commit comments

Comments
 (0)