Skip to content

Commit

Permalink
add scal to tile_extension
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro committed Nov 16, 2023
1 parent 9b381e3 commit 93c5f5a
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions include/dlaf/blas/tile_extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,23 @@ auto add(const dlaf::internal::Policy<B>& p, Sender&& s);
template <Backend B>
auto add(const dlaf::internal::Policy<B>& p);

/// Scales a tile in place, i.e. computes tile = beta * tile.
///
/// @param beta scalar factor applied to the elements of @p tile,
/// @param tile tile whose elements are overwritten with the scaled values.
///
/// This overload blocks until completion of the algorithm.
template <Backend B, class T, Device D>
void scal(T beta, const matrix::Tile<T, D>& tile);

/// This overload takes a policy argument and a sender which must send all required arguments for the
/// algorithm. Returns a sender which signals a connected receiver when the algorithm is done.
///
/// It participates in overload resolution only if @p Sender satisfies
/// pika::execution::experimental::is_sender_v.
///
/// NOTE(review): the values sent by @p s are presumably the scalar factor and the tile of the blocking
/// scal overload, in that order - confirm against the sender-algorithm helper macros.
template <Backend B, typename Sender,
typename = std::enable_if_t<pika::execution::experimental::is_sender_v<Sender>>>
auto scal(const dlaf::internal::Policy<B>& p, Sender&& s);

/// This overload partially applies the algorithm with a policy for later use with operator| with a
/// sender on the left-hand side.
///
/// Intended usage is of the form `sender | scal(policy)`, where `sender` sends the arguments required
/// by the algorithm.
template <Backend B>
auto scal(const dlaf::internal::Policy<B>& p);

#else

namespace internal {
Expand All @@ -78,9 +95,33 @@ void add(T alpha, const matrix::Tile<const T, Device::GPU>& tile_b,
#endif

// Wraps the add overload set in a callable object.
// NOTE(review): presumably this is what defines internal::add_o, which is handed to
// DLAF_MAKE_SENDER_ALGORITHM_OVERLOADS further down - confirm against the macro definition.
DLAF_MAKE_CALLABLE_OBJECT(add);

/// Scales a CPU tile in place: tile = beta * tile.
///
/// The scaling is delegated to a gemm call with inner dimension k = 0 and alpha = 0: the product term
/// is empty (the A and B operands are never dereferenced, hence the null pointers), so only the
/// beta * C update of the output operand remains.
///
/// @param beta scalar factor applied to the elements of @p tile,
/// @param tile tile overwritten with the scaled values.
template <class T>
void scal(T beta, const matrix::Tile<T, Device::CPU>& tile) {
  // Kept alive for the whole BLAS call (presumably forces single-threaded BLAS execution).
  common::internal::SingleThreadedBlasScope single;

  const auto n_rows = tile.size().rows();
  const auto n_cols = tile.size().cols();
  const auto ld_tile = tile.ld();

  // Leading dimensions of the unused operands only have to be formally valid for gemm.
  blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans, n_rows, n_cols, 0, T(0),
             nullptr, ld_tile, nullptr, 1, beta, tile.ptr(), ld_tile);
}

#ifdef DLAF_WITH_GPU
/// Scales a GPU tile in place: tile = beta * tile.
///
/// As for the CPU variant, the scaling is expressed as a gemm with inner dimension k = 0 and
/// alpha = 0, so that the A and B operands (passed as null pointers) do not contribute and only the
/// beta * C update of the output operand remains.
///
/// @param handle cuBLAS handle used to issue the gemm call,
/// @param beta scalar factor applied to the elements of @p tile,
/// @param tile tile overwritten with the scaled values.
template <class T>
void scal(cublasHandle_t handle, T beta, const matrix::Tile<T, Device::GPU>& tile) {
  using util::blasToCublasCast;

  // gemm takes its scalars by pointer, so the zero alpha needs an addressable object.
  const T alpha = 0;

  const auto n_rows = to_int(tile.size().rows());
  const auto n_cols = to_int(tile.size().cols());
  const auto ld_tile = to_int(tile.ld());

  gpublas::internal::Gemm<T>::call(handle, CUBLAS_OP_N, CUBLAS_OP_N, n_rows, n_cols, 0,
                                   blasToCublasCast(&alpha), blasToCublasCast<T*>(nullptr), ld_tile,
                                   blasToCublasCast<T*>(nullptr), to_int(1), blasToCublasCast(&beta),
                                   blasToCublasCast(tile.ptr()), ld_tile);
}
#endif

// Wraps the scal overload set (CPU and, when enabled, GPU) in a callable object.
// NOTE(review): presumably this is what defines internal::scal_o, which is handed to
// DLAF_MAKE_SENDER_ALGORITHM_OVERLOADS below - confirm against the macro definition.
DLAF_MAKE_CALLABLE_OBJECT(scal);
}

// Generate the public overloads of add and scal on top of the internal callable objects.
// NOTE(review): scal is registered with the Blas dispatch type, which presumably makes the GPU
// backend supply the cublasHandle_t expected as first argument by the GPU scal implementation, while
// add uses the Plain dispatch type - confirm against dlaf::internal::TransformDispatchType.
DLAF_MAKE_SENDER_ALGORITHM_OVERLOADS(dlaf::internal::TransformDispatchType::Plain, add, internal::add_o)
DLAF_MAKE_SENDER_ALGORITHM_OVERLOADS(dlaf::internal::TransformDispatchType::Blas, scal, internal::scal_o)

#endif
}
Expand Down

0 comments on commit 93c5f5a

Please sign in to comment.