diff --git a/include/operators/Where.h b/include/operators/Where.h index 727ac4d..28721c3 100644 --- a/include/operators/Where.h +++ b/include/operators/Where.h @@ -35,9 +35,19 @@ template class Where : public baseOperator { // bool getAttribute(OPATTR attrName, int& obj) ; - void compute(void) { - // CHANGE return-type and args - // AND ADD YOUR FUNCTIONAL CODE HERE + tensor compute(tensor &B, tensor &X, tensor &Y) { + if (typeid(B) != typeid(tensor)) + throw std::invalid_argument( + "tensor types not appropriate for Where operator."); + if (X.shape() != Y.shape() || X.shape() != B.shape() || Y.shape() != B.shape()) + throw std::invalid_argument( + "tensor dimenions not appropriate for Where operator."); + + tensor result(X.shape(), X.name()); + for (size_t i = 0; i < X.length(); i++) + result[i] = B[i] ? X[i] : Y[i]; + + return result; } }; } // namespace dnnc \ No newline at end of file diff --git a/src/operators/Add.cpp b/src/operators/Add.cpp index d32c8be..2aa98e3 100644 --- a/src/operators/Add.cpp +++ b/src/operators/Add.cpp @@ -25,7 +25,7 @@ using namespace dnnc; using namespace Eigen; -//#define DNNC_ADD_TEST 1 +#define DNNC_ADD_TEST 1 #ifdef DNNC_ADD_TEST #include @@ -37,7 +37,7 @@ int main() { tensor b(2, 3); b.load(d2); - Add m("localOpName", 0x0); + Add m("localOpName"); auto result = m.compute(a, b); std::cout << result; diff --git a/src/operators/Where.cpp b/src/operators/Where.cpp index b74cd15..97602ee 100644 --- a/src/operators/Where.cpp +++ b/src/operators/Where.cpp @@ -26,9 +26,30 @@ using namespace dnnc; using namespace Eigen; +#define DNNC_WHERE_TEST 1 #ifdef DNNC_WHERE_TEST #include int main() { - // ADD YOUR TEST CODE HERE + + int d1[8] = {1, 2, 3, 4, 5, 6}; + int d2[8] = {9, 8, 7, 6, 5, 4}; + bool d3[8] = {1, 0, 0, 1, 0, 1}; + + tensor a(2, 4); + a.load(d1); + + tensor b(2, 4); + b.load(d2); + + tensor c(2, 4); + c.load(d3); + + Where w("localOpName"); + auto result = w.compute(c, a, b); + + std::cout << result; + std::cout << "\n"; + + return 0; } #endif diff --git a/src/operators/a.out b/src/operators/a.out old mode 100644 new mode 100755 index 575f4eb..348ea72 Binary files a/src/operators/a.out and b/src/operators/a.out differ