Skip to content

Commit

Permalink
Merge pull request #1824 from bstatcomp/cl_kernel_generator_unary_ops
Browse files Browse the repository at this point in the history
Add unary operations minus and logical negation to kernel generator
  • Loading branch information
t4c1 authored Apr 8, 2020
2 parents f02c6df + 2e9a962 commit 11742e2
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 0 deletions.
1 change: 1 addition & 0 deletions stan/math/opencl/kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
#include <stan/math/opencl/kernel_generator/scalar.hpp>
#include <stan/math/opencl/kernel_generator/binary_operation.hpp>
#include <stan/math/opencl/kernel_generator/unary_function_cl.hpp>
#include <stan/math/opencl/kernel_generator/unary_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/block.hpp>
#include <stan/math/opencl/kernel_generator/select.hpp>
#include <stan/math/opencl/kernel_generator/rowwise_reduction.hpp>
Expand Down
170 changes: 170 additions & 0 deletions stan/math/opencl/kernel_generator/unary_operation_cl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#ifndef STAN_MATH_OPENCL_KERNEL_GENERATOR_UNARY_OPERATION_CL_HPP
#define STAN_MATH_OPENCL_KERNEL_GENERATOR_UNARY_OPERATION_CL_HPP
#ifdef STAN_OPENCL

#include <stan/math/prim/meta.hpp>
#include <stan/math/opencl/err.hpp>
#include <stan/math/opencl/matrix_cl_view.hpp>
#include <stan/math/opencl/kernel_generator/operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/as_operation_cl.hpp>
#include <stan/math/opencl/kernel_generator/is_valid_expression.hpp>
#include <string>
#include <type_traits>
#include <set>
#include <utility>

namespace stan {
namespace math {

/**
* Represents a unary operation in kernel generator expressions.
* @tparam Derived derived type
* @tparam T type of argument
*/
template <typename Derived, typename T, typename Scal>
class unary_operation_cl
: public operation_cl<Derived, typename std::remove_reference_t<T>::Scalar,
T> {
public:
using Scalar = Scal;
using base = operation_cl<Derived, Scalar, T>;
using base::var_name;

/**
* Constructor
* @param a argument expression
* @param fun function
*/
unary_operation_cl(T&& a, const std::string& op)
: base(std::forward<T>(a)), op_(op) {}

/**
* generates kernel code for this expression.
* @param i row index variable name
* @param j column index variable name
* @param var_name_arg variable name of the nested expression
* @return part of kernel with code for this expression
*/
inline kernel_parts generate(const std::string& i, const std::string& j,
const std::string& var_name_arg) const {
kernel_parts res{};
res.body = type_str<Scalar>() + " " + var_name + " = " + op_ + var_name_arg
+ ";\n";
return res;
}

protected:
std::string op_;
};

/**
* Represents a logical negation in kernel generator expressions.
* @tparam Derived derived type
* @tparam T type of argument
*/
template <typename T>
class logical_negation_
: public unary_operation_cl<logical_negation_<T>, T, bool> {
static_assert(
std::is_integral<typename std::remove_reference_t<T>::Scalar>::value,
"logical_negation: argument must be expression with integral "
"or boolean return type!");
using base = unary_operation_cl<logical_negation_<T>, T, bool>;
using base::arguments_;

public:
/**
* Constructor
* @param a argument expression
*/
explicit logical_negation_(T&& a) : base(std::forward<T>(a), "!") {}

/**
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return logical_negation_<std::remove_reference_t<decltype(arg_copy)>>{
std::move(arg_copy)};
}

/**
* View of a matrix that would be the result of evaluating this expression.
* @return view
*/
inline matrix_cl_view view() const { return matrix_cl_view::Entire; }
};

/**
* Logical negation of a kernel generator expression.
*
* @tparam T type of the argument
* @param a argument expression
* @return logical negation of given expression
*/
template <typename T,
require_all_valid_expressions_and_none_scalar_t<T>* = nullptr>
inline logical_negation_<as_operation_cl_t<T>> operator!(T&& a) {
return logical_negation_<as_operation_cl_t<T>>(
as_operation_cl(std::forward<T>(a)));
}

/**
* Represents an unary minus operation in kernel generator expressions.
* @tparam Derived derived type
* @tparam T type of argument
*/
template <typename T>
class unary_minus_
: public unary_operation_cl<unary_minus_<T>, T,
typename std::remove_reference_t<T>::Scalar> {
using base = unary_operation_cl<unary_minus_<T>, T,
typename std::remove_reference_t<T>::Scalar>;
using base::arguments_;

public:
/**
* Constructor
* @param a argument expression
*/
explicit unary_minus_(T&& a) : base(std::forward<T>(a), "-") {}

/**
* Creates a deep copy of this expression.
* @return copy of \c *this
*/
inline auto deep_copy() const {
auto&& arg_copy = this->template get_arg<0>().deep_copy();
return unary_minus_<std::remove_reference_t<decltype(arg_copy)>>{
std::move(arg_copy)};
}

/**
* View of a matrix that would be the result of evaluating this expression.
* @return view
*/
inline matrix_cl_view view() const {
return this->template get_arg<0>().view();
}
};

/**
* Unary minus of a kernel generator expression.
*
* @tparam T type of the argument
* @param a argument expression
* @return unary minus of given expression
*/
template <typename T,
require_all_valid_expressions_and_none_scalar_t<T>* = nullptr>
inline unary_minus_<as_operation_cl_t<T>> operator-(T&& a) {
return unary_minus_<as_operation_cl_t<T>>(
as_operation_cl(std::forward<T>(a)));
}

} // namespace math
} // namespace stan

#endif
#endif
43 changes: 43 additions & 0 deletions test/unit/math/opencl/kernel_generator/unary_operation_cl_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifdef STAN_OPENCL

#include <stan/math/prim.hpp>
#include <stan/math/opencl/kernel_generator.hpp>
#include <stan/math/opencl/matrix_cl.hpp>
#include <stan/math/opencl/copy.hpp>
#include <test/unit/math/opencl/kernel_generator/reference_kernel.hpp>
#include <stan/math.hpp>
#include <gtest/gtest.h>
#include <string>

#define EXPECT_MATRIX_NEAR(A, B, DELTA) \
for (int i = 0; i < A.size(); i++) \
EXPECT_NEAR(A(i), B(i), DELTA);

TEST(KernelGenerator, logical_negation_test) {
using stan::math::matrix_cl;
Eigen::Matrix<bool, Eigen::Dynamic, Eigen::Dynamic> m1(3, 3);
m1 << true, false, true, true, false, false, true, false, false;

matrix_cl<bool> m1_cl(m1);
matrix_cl<bool> res_cl = !m1_cl;

Eigen::Matrix<bool, Eigen::Dynamic, Eigen::Dynamic> res
= stan::math::from_matrix_cl(res_cl);
Eigen::Matrix<bool, Eigen::Dynamic, Eigen::Dynamic> correct = !m1.array();
EXPECT_MATRIX_NEAR(correct, res, 1e-9);
}

TEST(KernelGenerator, unary_minus_test) {
using stan::math::matrix_cl;
Eigen::MatrixXd m1(3, 3);
m1 << 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9;

matrix_cl<double> m1_cl(m1);
matrix_cl<double> res_cl = -m1_cl;

Eigen::MatrixXd res = stan::math::from_matrix_cl(res_cl);
Eigen::MatrixXd correct = -m1;
EXPECT_MATRIX_NEAR(correct, res, 1e-9);
}

#endif

0 comments on commit 11742e2

Please sign in to comment.