Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion tmva/sofie/inc/TMVA/RModel.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ private:
std::vector<std::string> fDimShapeNames; // parameter names used to define the shapes
std::vector<std::string> fOutputTensorNames;
std::vector<std::string> fInputTensorNames; // input tensor names using ONNX order
std::vector<std::string> fDataMembers;
std::vector<std::string> fPointerMemberNames;


inline std::string AddTensorMember(std::string const &name) {
fPointerMemberNames.push_back(name);
return "tensor_" + name;
}

std::vector<std::unique_ptr<ROperator>> fOperators;

Expand Down Expand Up @@ -63,6 +68,11 @@ public:
std::vector<Dim> GetDimTensorShape(const std::string & name) const;
std::vector<Dim> GetDynamicTensorShape(const std::string & name) const ;

inline std::string AddDataMember(std::string const &name) {
fDataMembers.push_back(name);
return name;
}

// get the values for the tensor representing a shape
const std::vector<Dim> & GetShapeTensorValues(const std::string & tensor_name) const;

Expand Down
2 changes: 2 additions & 0 deletions tmva/sofie/inc/TMVA/ROperator_LSTM.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ template <typename T> class ROperator_LSTM final : public ROperator {

std::string fType; ///< Type of the tensors

int fCounter = 0;

public:
/*! Default constructor of ROperator_LSTM */
ROperator_LSTM() {}
Expand Down
42 changes: 38 additions & 4 deletions tmva/sofie/inc/TMVA/ROperator_LSTM.icc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ auto ROperator_LSTM<T>::ShapeInference(std::vector<std::vector<size_t>> input)
}
}

namespace Internal {

inline int &lstmCounter() {
static int counter = 0;
return counter;
}

}

template<typename T>
auto ROperator_LSTM<T>::Initialize(RModel& model)
-> void {
Expand Down Expand Up @@ -230,13 +239,37 @@ auto ROperator_LSTM<T>::Initialize(RModel& model)
fAttrActivations = {"Sigmoid", "Tanh", "Tanh"};
}
}

// Register session data members
fCounter = Internal::lstmCounter()++;
std::string opName = "op_lstm" + std::to_string(fCounter);
if (fAttrLayout != 0) {
model.AddDataMember("fVec_" + opName + "_input");
model.AddDataMember("fVec_" + opName + "_initial_hidden_state");
model.AddDataMember("fVec_" + opName + "_initial_cell_state");
}
model.AddDataMember("fVec_" + opName + "_ff_input_gate");
model.AddDataMember("fVec_" + opName + "_ff_output_gate");
model.AddDataMember("fVec_" + opName + "_ff_cell_gate");
if (fAttrInputForget == 0)
model.AddDataMember("fVec_" + opName + "_ff_forget_gate");
model.AddDataMember("fVec_" + opName + "_input_gate");
model.AddDataMember("fVec_" + opName + "_output_gate");
model.AddDataMember("fVec_" + opName + "_cell_gate");
if (fAttrInputForget == 0)
model.AddDataMember("fVec_" + opName + "_forget_gate");
model.AddDataMember("fVec_" + opName + "_cell_state");
model.AddDataMember("fVec_" + opName + "_new_cell_state");
if (fAttrLayout != 0 || fNY.empty()) {
model.AddDataMember("fVec_" + opName + "_hidden_state");
}
}

// generate code for Session data members (e.g. internal vectors)
template <typename T>
std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string opName)
std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string /*opName*/)
{
opName = "op_" + opName;
std::string opName = "op_lstm" + std::to_string(fCounter);
std::stringstream out;

size_t num_directions = fShapeW[0];
Expand Down Expand Up @@ -280,9 +313,10 @@ std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string opName)
}

template<typename T>
auto ROperator_LSTM<T>::Generate(std::string OpName)
auto ROperator_LSTM<T>::Generate(std::string /*OpName*/)
-> std::string {
OpName = "op_" + OpName;
//OpName = "op_" + OpName;
std::string OpName = "op_lstm" + std::to_string(fCounter);
std::stringstream out;

size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
Expand Down
2 changes: 2 additions & 0 deletions tmva/sofie/inc/TMVA/ROperator_RNN.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ template <typename T> class ROperator_RNN final : public ROperator {

std::string fType; ///< Type of the tensors

int fCounter = 0;

public:
/*! Default constructor of ROperator_RNN */
ROperator_RNN() {}
Expand Down
31 changes: 27 additions & 4 deletions tmva/sofie/inc/TMVA/ROperator_RNN.icc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ auto ROperator_RNN<T>::ShapeInference(std::vector<std::vector<size_t>> input)
}
}

namespace Internal {

inline int &rnnCounter() {
static int counter = 0;
return counter;
}

}

template <typename T>
auto ROperator_RNN<T>::Initialize(RModel& model)
-> void {
Expand Down Expand Up @@ -183,13 +192,26 @@ auto ROperator_RNN<T>::Initialize(RModel& model)
}
// Add needed standard library headers
model.AddNeededStdLib("cmath");

// Register session data members
fCounter = Internal::rnnCounter()++;
std::string opName = "op_rnn" + std::to_string(fCounter);
if (fAttrLayout != 0) {
model.AddDataMember("fVec_" + opName + "_input");
model.AddDataMember("fVec_" + opName + "_initial_hidden_state");
}
model.AddDataMember("fVec_" + opName + "_feedforward");

if (fAttrLayout != 0 || fNY.empty()) {
model.AddDataMember("fVec_" + opName + "_hidden_state");
}
}

// generate code for Session data members (e.g. internal vectors)
template <typename T>
std::string ROperator_RNN<T>::GenerateSessionMembersCode(std::string opName)
std::string ROperator_RNN<T>::GenerateSessionMembersCode(std::string /*opName*/)
{
opName = "op_" + opName;
std::string opName = "op_rnn" + std::to_string(fCounter);
std::stringstream out;

size_t num_directions = fShapeW[0];
Expand Down Expand Up @@ -218,9 +240,10 @@ std::string ROperator_RNN<T>::GenerateSessionMembersCode(std::string opName)

//////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
auto ROperator_RNN<T>::Generate(std::string OpName)
auto ROperator_RNN<T>::Generate(std::string /*OpName*/)
-> std::string {
OpName = "op_" + OpName;
//OpName = "op_" + OpName;
std::string OpName = "op_rnn" + std::to_string(fCounter);
std::stringstream out;

size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
Expand Down
2 changes: 2 additions & 0 deletions tmva/sofie/inc/TMVA/ROperator_Random.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ public:
for (auto & p : fParams)
std::cout << p.first << " : " << p.second << std::endl;
}

model.AddDataMember("fRndmEngine");
}
// generate declaration code for random number generators
std::string GenerateDeclCode() override {
Expand Down
30 changes: 15 additions & 15 deletions tmva/sofie/inc/TMVA/SOFIE_common.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -679,23 +679,23 @@ void col2im(const Dtype* data_col, const int channels,
//std::cout << "finishing col2imp" << std::endl;
}

// Used at the end of infer() to fill the return object.
template <class T>
void FillOutput(T const *arr, std::vector<T> &out, std::size_t n)
{
out.resize(n);
for (std::size_t i = 0; i < n; ++i) {
out[i] = arr[i];
}
}

} // end namespace UTILITY

namespace BLAS{
extern "C" void sgemm_(const char * transa, const char * transb, const int * m, const int * n, const int * k,
const float * alpha, const float * A, const int * lda, const float * B, const int * ldb,
const float * beta, float * C, const int * ldc);
}//BLAS
namespace BLAS {

extern "C" void saxpy_(const int *n, const float *alpha, const float *x, const int *incx, float *y, const int *incy);

extern "C" void scopy_(const int *n, const float *x, const int *incx, float *y, const int *incy);

extern "C" void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
const float *alpha, const float *A, const int *lda, const float *B, const int *ldb,
const float *beta, float *C, const int *ldc);

extern "C" void sgemv_(const char *trans, const int *m, const int *n, const float *alpha, const float *A,
const int *lda, const float *X, const int *incx, const float *beta, const float *Y,
const int *incy);

} // namespace BLAS


struct GNN_Data {
Expand Down
Loading
Loading