Skip to content

Commit

Permalink
re-add old mean-API and add deprecated tags
Browse files Browse the repository at this point in the history
  • Loading branch information
mfoerste4 committed Jan 16, 2025
1 parent 5dfce6a commit e580ecf
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
19 changes: 19 additions & 0 deletions cpp/include/raft/stats/detail/mean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ void mean(Type* mu, const Type* data, IdxType D, IdxType N, bool rowMajor, cudaS
raft::mul_const_op<Type>(ratio));
}

template <typename Type, typename IdxType = int>
[[deprecated]] void mean(
Type* mu, const Type* data, IdxType D, IdxType N, bool sample, bool rowMajor, cudaStream_t stream)
{
Type ratio = Type(1) / ((sample) ? Type(N - 1) : Type(N));
raft::linalg::reduce(mu,
data,
D,
N,
Type(0),
rowMajor,
false,
stream,
false,
raft::identity_op(),
raft::add_op(),
raft::mul_const_op<Type>(ratio));
}

} // namespace detail
} // namespace stats
} // namespace raft
62 changes: 62 additions & 0 deletions cpp/include/raft/stats/mean.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,31 @@ void mean(Type* mu, const Type* data, IdxType D, IdxType N, bool rowMajor, cudaS
detail::mean(mu, data, D, N, rowMajor, stream);
}

/**
* @brief Compute mean of the input matrix
*
* Mean operation is assumed to be performed on a given column.
* Note: This call is deprecated, please use `mean` call without `sample` parameter.
*
* @tparam Type: the data type
* @tparam IdxType Integer type used to for addressing
* @param mu: the output mean vector
* @param data: the input matrix
* @param D: number of columns of data
* @param N: number of rows of data
* @param sample: whether to evaluate sample mean or not. In other words,
* whether
* to normalize the output using N-1 or N, for true or false, respectively
* @param rowMajor: whether the input data is row or col major
* @param stream: cuda stream
*/
template <typename Type, typename IdxType = int>
[[deprecated("'sample' parameter deprecated")]] void mean(
Type* mu, const Type* data, IdxType D, IdxType N, bool sample, bool rowMajor, cudaStream_t stream)
{
detail::mean(mu, data, D, N, sample, rowMajor, stream);
}

/**
* @defgroup stats_mean Mean
* @{
Expand Down Expand Up @@ -83,6 +108,43 @@ void mean(raft::resources const& handle,
resource::get_cuda_stream(handle));
}

/**
* @brief Compute mean of the input matrix
*
* Mean operation is assumed to be performed on a given column.
* Note: This call is deprecated, please use `mean` call without `sample` parameter.
*
* @tparam value_t the data type
* @tparam idx_t index type
* @tparam layout_t Layout type of the input matrix.
* @param[in] handle the raft handle
* @param[in] data: the input matrix
* @param[out] mu: the output mean vector
* @param[in] sample: whether to evaluate sample mean or not. In other words, whether
* to normalize the output using N-1 or N, for true or false, respectively
*/
template <typename value_t, typename idx_t, typename layout_t>
[[deprecated("'sample' parameter deprecated")]] void mean(
raft::resources const& handle,
raft::device_matrix_view<const value_t, idx_t, layout_t> data,
raft::device_vector_view<value_t, idx_t> mu,
bool sample)
{
static_assert(
std::is_same_v<layout_t, raft::row_major> || std::is_same_v<layout_t, raft::col_major>,
"Data layout not supported");
RAFT_EXPECTS(data.extent(1) == mu.extent(0), "Size mismatch between data and mu");
RAFT_EXPECTS(mu.is_exhaustive(), "mu must be contiguous");
RAFT_EXPECTS(data.is_exhaustive(), "data must be contiguous");
detail::mean(mu.data_handle(),
data.data_handle(),
data.extent(1),
data.extent(0),
sample,
std::is_same_v<layout_t, raft::row_major>,
resource::get_cuda_stream(handle));
}

/** @} */ // end group stats_mean

}; // namespace stats
Expand Down

0 comments on commit e580ecf

Please sign in to comment.