Skip to content

Commit

Permalink
remove LVL depth restriction with constexpr templating (#834)
Browse files Browse the repository at this point in the history
  • Loading branch information
aartbik authored Jan 22, 2025
1 parent 9b941ce commit 9dd538a
Showing 1 changed file with 37 additions and 83 deletions.
120 changes: 37 additions & 83 deletions include/matx/core/sparse_tensor_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,59 +252,46 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {
return false;
}

template <typename CRD>
static CRD __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ *
template <typename CRD, int L = 0>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ static void
dim2lvl(const CRD *dims, CRD *lvls, bool asSize) {
// Lambda for dim2lvl translation.
auto loop_fun = [&dims, &lvls, &asSize](auto ic) {
constexpr int idx = decltype(ic)::value;
if constexpr (LVL >= (idx + 1)) {
using ftype = std::tuple_element_t<idx, LVLSPECS>;
if constexpr (ftype::expr::op == LvlOp::Id) {
lvls[idx] = dims[ftype::expr::di];
} else if constexpr (ftype::expr::op == LvlOp::Div) {
lvls[idx] = dims[ftype::expr::di] / ftype::expr::cj;
} else if constexpr (ftype::expr::op == LvlOp::Mod) {
lvls[idx] = asSize ? ftype::expr::cj
: (dims[ftype::expr::di] % ftype::expr::cj);
}
}
};
// Assumes LVL <= 5.
static_assert(LVL <= 5);
loop_fun(std::integral_constant<int, 0>{});
loop_fun(std::integral_constant<int, 1>{});
loop_fun(std::integral_constant<int, 2>{});
loop_fun(std::integral_constant<int, 3>{});
loop_fun(std::integral_constant<int, 4>{});
return lvls;
using ftype = std::tuple_element_t<L, LVLSPECS>;
if constexpr (ftype::expr::op == LvlOp::Id) {
lvls[L] = dims[ftype::expr::di];
} else if constexpr (ftype::expr::op == LvlOp::Div) {
lvls[L] = dims[ftype::expr::di] / ftype::expr::cj;
} else if constexpr (ftype::expr::op == LvlOp::Mod) {
lvls[L] =
asSize ? ftype::expr::cj : (dims[ftype::expr::di] % ftype::expr::cj);
}
if constexpr (L + 1 < LVL) {
dim2lvl<CRD, L + 1>(dims, lvls, asSize);
}
}

template <typename CRD>
static CRD __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ *
template <typename CRD, int L = 0>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ static void
lvl2dim(const CRD *lvls, CRD *dims) {
// Lambda for lvl2dim translation.
auto loop_fun = [&lvls, &dims](auto ic) {
constexpr int idx = decltype(ic)::value;
if constexpr (LVL >= (idx + 1)) {
using ftype = std::tuple_element_t<idx, LVLSPECS>;
if constexpr (ftype::expr::op == LvlOp::Id) {
dims[ftype::expr::di] = lvls[idx];
} else if constexpr (ftype::expr::op == LvlOp::Div) {
dims[ftype::expr::di] = lvls[idx] * ftype::expr::cj;
} else if constexpr (ftype::expr::op == LvlOp::Mod) {
dims[ftype::expr::di] += lvls[idx]; // update (seen second)
}
}
};
// Assumes LVL <= 5.
static_assert(LVL <= 5);
loop_fun(std::integral_constant<int, 0>{});
loop_fun(std::integral_constant<int, 1>{});
loop_fun(std::integral_constant<int, 2>{});
loop_fun(std::integral_constant<int, 3>{});
loop_fun(std::integral_constant<int, 4>{});
return dims;
using ftype = std::tuple_element_t<L, LVLSPECS>;
if constexpr (ftype::expr::op == LvlOp::Id) {
dims[ftype::expr::di] = lvls[L];
} else if constexpr (ftype::expr::op == LvlOp::Div) {
dims[ftype::expr::di] = lvls[L] * ftype::expr::cj;
} else if constexpr (ftype::expr::op == LvlOp::Mod) {
dims[ftype::expr::di] += lvls[L]; // update (seen second)
}
if constexpr (L + 1 < LVL) {
lvl2dim<CRD, L + 1>(lvls, dims);
}
}

template <int L = 0> static void printLevel() {
using ftype = std::tuple_element_t<L, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (L + 1 < LVL) {
std::cout << ",";
printLevel<L + 1>();
}
}

static void print() {
Expand All @@ -315,40 +302,7 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {
std::cout << ",";
}
std::cout << " ) -> (";
// Assumes LVL <= 5.
static_assert(LVL <= 5);
if constexpr (LVL > 1) {
using ftype = std::tuple_element_t<0, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL != 1) {
std::cout << ",";
}
}
if constexpr (LVL >= 2) {
using ftype = std::tuple_element_t<1, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 2) {
std::cout << ",";
}
}
if constexpr (LVL >= 3) {
using ftype = std::tuple_element_t<2, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 3) {
std::cout << ",";
}
}
if constexpr (LVL >= 4) {
using ftype = std::tuple_element_t<3, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 4) {
std::cout << ",";
}
}
if constexpr (LVL >= 5) {
using ftype = std::tuple_element_t<4, LVLSPECS>;
std::cout << " " << ftype::toString();
}
printLevel();
std::cout << " )" << std::endl;
};
};
Expand Down

0 comments on commit 9dd538a

Please sign in to comment.