diff --git a/include/matx/core/sparse_tensor_format.h b/include/matx/core/sparse_tensor_format.h index 6de4396d..f5b733ef 100644 --- a/include/matx/core/sparse_tensor_format.h +++ b/include/matx/core/sparse_tensor_format.h @@ -252,59 +252,46 @@ template class SparseTensorFormat { return false; } - template - static CRD __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ * + template + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ static void dim2lvl(const CRD *dims, CRD *lvls, bool asSize) { - // Lambda for dim2lvl translation. - auto loop_fun = [&dims, &lvls, &asSize](auto ic) { - constexpr int idx = decltype(ic)::value; - if constexpr (LVL >= (idx + 1)) { - using ftype = std::tuple_element_t; - if constexpr (ftype::expr::op == LvlOp::Id) { - lvls[idx] = dims[ftype::expr::di]; - } else if constexpr (ftype::expr::op == LvlOp::Div) { - lvls[idx] = dims[ftype::expr::di] / ftype::expr::cj; - } else if constexpr (ftype::expr::op == LvlOp::Mod) { - lvls[idx] = asSize ? ftype::expr::cj - : (dims[ftype::expr::di] % ftype::expr::cj); - } - } - }; - // Assumes LVL <= 5. - static_assert(LVL <= 5); - loop_fun(std::integral_constant{}); - loop_fun(std::integral_constant{}); - loop_fun(std::integral_constant{}); - loop_fun(std::integral_constant{}); - loop_fun(std::integral_constant{}); - return lvls; + using ftype = std::tuple_element_t; + if constexpr (ftype::expr::op == LvlOp::Id) { + lvls[L] = dims[ftype::expr::di]; + } else if constexpr (ftype::expr::op == LvlOp::Div) { + lvls[L] = dims[ftype::expr::di] / ftype::expr::cj; + } else if constexpr (ftype::expr::op == LvlOp::Mod) { + lvls[L] = + asSize ? ftype::expr::cj : (dims[ftype::expr::di] % ftype::expr::cj); + } + if constexpr (L + 1 < LVL) { + dim2lvl(dims, lvls, asSize); + } } - template - static CRD __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ * + template + __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ static void lvl2dim(const CRD *lvls, CRD *dims) { - // Lambda for lvl2dim translation. - auto loop_fun = [&lvls, &dims](auto ic) { - constexpr int idx = decltype(ic)::value; - if constexpr (LVL >= (idx + 1)) { - using ftype = std::tuple_element_t; - if constexpr (ftype::expr::op == LvlOp::Id) { - dims[ftype::expr::di] = lvls[idx]; - } else if constexpr (ftype::expr::op == LvlOp::Div) { - dims[ftype::expr::di] = lvls[idx] * ftype::expr::cj; - } else if constexpr (ftype::expr::op == LvlOp::Mod) { - dims[ftype::expr::di] += lvls[idx]; // update (seen second) - } - } - }; - // Assumes LVL <= 5. - static_assert(LVL <= 5); - loop_fun(std::integral_constant{}); - loop_fun(std::integral_constant{}); - loop_fun(std::integral_constant{}); - loop_fun(std::integral_constant{}); - loop_fun(std::integral_constant{}); - return dims; + using ftype = std::tuple_element_t; + if constexpr (ftype::expr::op == LvlOp::Id) { + dims[ftype::expr::di] = lvls[L]; + } else if constexpr (ftype::expr::op == LvlOp::Div) { + dims[ftype::expr::di] = lvls[L] * ftype::expr::cj; + } else if constexpr (ftype::expr::op == LvlOp::Mod) { + dims[ftype::expr::di] += lvls[L]; // update (seen second) + } + if constexpr (L + 1 < LVL) { + lvl2dim(lvls, dims); + } + } + + template static void printLevel() { + using ftype = std::tuple_element_t; + std::cout << " " << ftype::toString(); + if constexpr (L + 1 < LVL) { + std::cout << ","; + printLevel(); + } } static void print() { @@ -315,40 +302,7 @@ template class SparseTensorFormat { std::cout << ","; } std::cout << " ) -> ("; - // Assumes LVL <= 5. - static_assert(LVL <= 5); - if constexpr (LVL > 1) { - using ftype = std::tuple_element_t<0, LVLSPECS>; - std::cout << " " << ftype::toString(); - if constexpr (LVL != 1) { - std::cout << ","; - } - } - if constexpr (LVL >= 2) { - using ftype = std::tuple_element_t<1, LVLSPECS>; - std::cout << " " << ftype::toString(); - if constexpr (LVL > 2) { - std::cout << ","; - } - } - if constexpr (LVL >= 3) { - using ftype = std::tuple_element_t<2, LVLSPECS>; - std::cout << " " << ftype::toString(); - if constexpr (LVL > 3) { - std::cout << ","; - } - } - if constexpr (LVL >= 4) { - using ftype = std::tuple_element_t<3, LVLSPECS>; - std::cout << " " << ftype::toString(); - if constexpr (LVL > 4) { - std::cout << ","; - } - } - if constexpr (LVL >= 5) { - using ftype = std::tuple_element_t<4, LVLSPECS>; - std::cout << " " << ftype::toString(); - } + printLevel(); std::cout << " )" << std::endl; }; };