-
-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1824 from bstatcomp/cl_kernel_generator_unary_ops
Add unary operations minus and logical negation to kernel generator
- Loading branch information
Showing
3 changed files
with
214 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
170 changes: 170 additions & 0 deletions
170
stan/math/opencl/kernel_generator/unary_operation_cl.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_UNARY_OPERATION_CL_HPP | ||
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_UNARY_OPERATION_CL_HPP | ||
#ifdef STAN_OPENCL | ||
|
||
#include <stan/math/prim/meta.hpp> | ||
#include <stan/math/opencl/err.hpp> | ||
#include <stan/math/opencl/matrix_cl_view.hpp> | ||
#include <stan/math/opencl/kernel_generator/operation_cl.hpp> | ||
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp> | ||
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp> | ||
#include <string> | ||
#include <type_traits> | ||
#include <set> | ||
#include <utility> | ||
|
||
namespace stan { | ||
namespace math { | ||
|
||
/** | ||
* Represents a unary operation in kernel generator expressions. | ||
* @tparam Derived derived type | ||
* @tparam T type of argument | ||
*/ | ||
template <typename Derived, typename T, typename Scal> | ||
class unary_operation_cl | ||
: public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar, | ||
T> { | ||
public: | ||
using Scalar = Scal; | ||
using base = operation_cl<Derived, Scalar, T>; | ||
using base::var_name; | ||
|
||
/** | ||
* Constructor | ||
* @param a argument expression | ||
* @param fun function | ||
*/ | ||
unary_operation_cl(T&& a, const std::string& op) | ||
: base(std::forward<T>(a)), op_(op) {} | ||
|
||
/** | ||
* generates kernel code for this expression. | ||
* @param i row index variable name | ||
* @param j column index variable name | ||
* @param var_name_arg variable name of the nested expression | ||
* @return part of kernel with code for this expression | ||
*/ | ||
inline kernel_parts generate(const std::string& i, const std::string& j, | ||
const std::string& var_name_arg) const { | ||
kernel_parts res{}; | ||
res.body = type_str<Scalar>() + " " + var_name + " = " + op_ + var_name_arg | ||
+ ";\n"; | ||
return res; | ||
} | ||
|
||
protected: | ||
std::string op_; | ||
}; | ||
|
||
/** | ||
* Represents a logical negation in kernel generator expressions. | ||
* @tparam Derived derived type | ||
* @tparam T type of argument | ||
*/ | ||
template <typename T> | ||
class logical_negation_ | ||
: public unary_operation_cl<logical_negation_<T>, T, bool> { | ||
static_assert( | ||
std::is_integral<typename std::remove_reference_t<T>::Scalar>::value, | ||
"logical_negation: argument must be expression with integral " | ||
"or boolean return type!"); | ||
using base = unary_operation_cl<logical_negation_<T>, T, bool>; | ||
using base::arguments_; | ||
|
||
public: | ||
/** | ||
* Constructor | ||
* @param a argument expression | ||
*/ | ||
explicit logical_negation_(T&& a) : base(std::forward<T>(a), "!") {} | ||
|
||
/** | ||
* Creates a deep copy of this expression. | ||
* @return copy of \c *this | ||
*/ | ||
inline auto deep_copy() const { | ||
auto&& arg_copy = this->template get_arg<0>().deep_copy(); | ||
return logical_negation_<std::remove_reference_t<decltype(arg_copy)>>{ | ||
std::move(arg_copy)}; | ||
} | ||
|
||
/** | ||
* View of a matrix that would be the result of evaluating this expression. | ||
* @return view | ||
*/ | ||
inline matrix_cl_view view() const { return matrix_cl_view::Entire; } | ||
}; | ||
|
||
/** | ||
* Logical negation of a kernel generator expression. | ||
* | ||
* @tparam T type of the argument | ||
* @param a argument expression | ||
* @return logical negation of given expression | ||
*/ | ||
template <typename T, | ||
require_all_valid_expressions_and_none_scalar_t<T>* = nullptr> | ||
inline logical_negation_<as_operation_cl_t<T>> operator!(T&& a) { | ||
return logical_negation_<as_operation_cl_t<T>>( | ||
as_operation_cl(std::forward<T>(a))); | ||
} | ||
|
||
/** | ||
* Represents an unary minus operation in kernel generator expressions. | ||
* @tparam Derived derived type | ||
* @tparam T type of argument | ||
*/ | ||
template <typename T> | ||
class unary_minus_ | ||
: public unary_operation_cl<unary_minus_<T>, T, | ||
typename std::remove_reference_t<T>::Scalar> { | ||
using base = unary_operation_cl<unary_minus_<T>, T, | ||
typename std::remove_reference_t<T>::Scalar>; | ||
using base::arguments_; | ||
|
||
public: | ||
/** | ||
* Constructor | ||
* @param a argument expression | ||
*/ | ||
explicit unary_minus_(T&& a) : base(std::forward<T>(a), "-") {} | ||
|
||
/** | ||
* Creates a deep copy of this expression. | ||
* @return copy of \c *this | ||
*/ | ||
inline auto deep_copy() const { | ||
auto&& arg_copy = this->template get_arg<0>().deep_copy(); | ||
return unary_minus_<std::remove_reference_t<decltype(arg_copy)>>{ | ||
std::move(arg_copy)}; | ||
} | ||
|
||
/** | ||
* View of a matrix that would be the result of evaluating this expression. | ||
* @return view | ||
*/ | ||
inline matrix_cl_view view() const { | ||
return this->template get_arg<0>().view(); | ||
} | ||
}; | ||
|
||
/** | ||
* Unary minus of a kernel generator expression. | ||
* | ||
* @tparam T type of the argument | ||
* @param a argument expression | ||
* @return unary minus of given expression | ||
*/ | ||
template <typename T, | ||
require_all_valid_expressions_and_none_scalar_t<T>* = nullptr> | ||
inline unary_minus_<as_operation_cl_t<T>> operator-(T&& a) { | ||
return unary_minus_<as_operation_cl_t<T>>( | ||
as_operation_cl(std::forward<T>(a))); | ||
} | ||
|
||
} // namespace math | ||
} // namespace stan | ||
|
||
#endif | ||
#endif |
43 changes: 43 additions & 0 deletions
43
test/unit/math/opencl/kernel_generator/unary_operation_cl_test.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#ifdef STAN_OPENCL | ||
|
||
#include <stan/math/prim.hpp> | ||
#include <stan/math/opencl/kernel_generator.hpp> | ||
#include <stan/math/opencl/matrix_cl.hpp> | ||
#include <stan/math/opencl/copy.hpp> | ||
#include <test/unit/math/opencl/kernel_generator/reference_kernel.hpp> | ||
#include <stan/math.hpp> | ||
#include <gtest/gtest.h> | ||
#include <string> | ||
|
||
#define EXPECT_MATRIX_NEAR(A, B, DELTA) \ | ||
for (int i = 0; i < A.size(); i++) \ | ||
EXPECT_NEAR(A(i), B(i), DELTA); | ||
|
||
TEST(KernelGenerator, logical_negation_test) { | ||
using stan::math::matrix_cl; | ||
Eigen::Matrix<bool, Eigen::Dynamic, Eigen::Dynamic> m1(3, 3); | ||
m1 << true, false, true, true, false, false, true, false, false; | ||
|
||
matrix_cl<bool> m1_cl(m1); | ||
matrix_cl<bool> res_cl = !m1_cl; | ||
|
||
Eigen::Matrix<bool, Eigen::Dynamic, Eigen::Dynamic> res | ||
= stan::math::from_matrix_cl(res_cl); | ||
Eigen::Matrix<bool, Eigen::Dynamic, Eigen::Dynamic> correct = !m1.array(); | ||
EXPECT_MATRIX_NEAR(correct, res, 1e-9); | ||
} | ||
|
||
TEST(KernelGenerator, unary_minus_test) { | ||
using stan::math::matrix_cl; | ||
Eigen::MatrixXd m1(3, 3); | ||
m1 << 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9; | ||
|
||
matrix_cl<double> m1_cl(m1); | ||
matrix_cl<double> res_cl = -m1_cl; | ||
|
||
Eigen::MatrixXd res = stan::math::from_matrix_cl(res_cl); | ||
Eigen::MatrixXd correct = -m1; | ||
EXPECT_MATRIX_NEAR(correct, res, 1e-9); | ||
} | ||
|
||
#endif |