Commit b56ed8e

Port GPU kernel for Householder transformation to FFI.

PiperOrigin-RevId: 666305682

dfm authored and jax authors committed Aug 22, 2024
1 parent 0b4f64e commit b56ed8e
Showing 4 changed files with 140 additions and 35 deletions.
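Background, not part of the commit itself: orgqr (ungqr for complex types) is the LAPACK-style routine that assembles the explicit Q factor of a QR decomposition from the elementary Householder reflectors produced by geqrf. Given reflector vectors v_i and scalar factors tau_i, it computes

    Q = H_1 H_2 ... H_k,    where    H_i = I - tau_i * v_i * v_i^H

and k is the number of reflectors (the trailing dimension of the tau input in the diff below).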
2 changes: 2 additions & 0 deletions jaxlib/gpu/gpu_kernels.cc
@@ -51,6 +51,8 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA",
GeqrfFfi);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
+XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
+                         OrgqrFfi);
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA");
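Note how the file above keeps both registration paths side by side during the FFI migration: the legacy custom-call macro still registers the old Orgqr kernel, while the typed FFI macro registers the new handler under a separate name. Schematically (both lines appear in the file itself):

    // Legacy custom-call target: an opaque function pointer keyed by name.
    XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA");
    // Typed FFI handler: registered against the FFI API for the CUDA platform,
    // with its signature declared via XLA_FFI_DEFINE_HANDLER_SYMBOL elsewhere.
    XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
                             OrgqrFfi);

Keeping both registered lets callers migrate to the _ffi target incrementally.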
1 change: 1 addition & 0 deletions jaxlib/gpu/solver.cc
@@ -477,6 +477,7 @@ nb::dict Registrations() {

dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi);
dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi);
+  dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);

return dict;
}
171 changes: 136 additions & 35 deletions jaxlib/gpu/solver_kernels_ffi.cc
@@ -51,13 +51,13 @@ inline absl::StatusOr<T*> AllocateWorkspace(ffi::ScratchAllocator& scratch,
} // namespace

#define SOLVER_DISPATCH_IMPL(impl, ...) \
-  if (dataType == ffi::DataType::F32) {         \
+  if (dataType == ffi::F32) {                   \
return impl<float>(__VA_ARGS__); \
-  } else if (dataType == ffi::DataType::F64) {  \
+  } else if (dataType == ffi::F64) {            \
return impl<double>(__VA_ARGS__); \
-  } else if (dataType == ffi::DataType::C64) {  \
+  } else if (dataType == ffi::C64) {            \
return impl<gpuComplex>(__VA_ARGS__); \
-  } else if (dataType == ffi::DataType::C128) { \
+  } else if (dataType == ffi::C128) {           \
return impl<gpuDoubleComplex>(__VA_ARGS__); \
}

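For reference, each SOLVER_DISPATCH_IMPL use expands to a fall-through if/else chain over the runtime element type, which is why every dispatcher in this file can follow the macro with its own "unsupported element type" error return. Roughly (a hand expansion, not compiler output):

    // Hand expansion of
    // SOLVER_DISPATCH_IMPL(OrgqrImpl, batch, rows, cols, size, stream,
    //                      scratch, a, tau, out):
    if (dataType == ffi::F32) {
      return OrgqrImpl<float>(batch, rows, cols, size, stream, scratch, a,
                              tau, out);
    } else if (dataType == ffi::F64) {
      return OrgqrImpl<double>(batch, rows, cols, size, stream, scratch, a,
                               tau, out);
    } else if (dataType == ffi::C64) {
      return OrgqrImpl<gpuComplex>(batch, rows, cols, size, stream, scratch,
                                   a, tau, out);
    } else if (dataType == ffi::C128) {
      return OrgqrImpl<gpuDoubleComplex>(batch, rows, cols, size, stream,
                                         scratch, a, tau, out);
    }
    // If no branch matched, control falls through to the caller's
    // "Unsupported element type" return statement.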
@@ -94,8 +94,8 @@ template <typename T>
ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
gpuStream_t stream, ffi::ScratchAllocator& scratch,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
-                     ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
-                     ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+                     ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
+                     ffi::Result<ffi::Buffer<ffi::S32>> info) {
FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));

@@ -110,13 +110,12 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols,
auto ipiv_data = ipiv->typed_data();
auto info_data = info->typed_data();
if (a_data != out_data) {
-    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
-        gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols,
-                       gpuMemcpyDeviceToDevice, stream)));
+    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
}

int ipiv_step = std::min(m, n);
-  for (int i = 0; i < batch; ++i) {
+  for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(GetrfKernel<T>::Run(
handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data));
out_data += m * n;
@@ -147,8 +146,8 @@ template <typename T>
ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
ffi::ScratchAllocator& scratch, ffi::AnyBuffer a,
ffi::Result<ffi::AnyBuffer> out,
-                            ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
-                            ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+                            ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
+                            ffi::Result<ffi::Buffer<ffi::S32>> info) {
FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream));
FFI_ASSIGN_OR_RETURN(auto batch_ptrs,
@@ -159,9 +158,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,
auto ipiv_data = ipiv->typed_data();
auto info_data = info->typed_data();
if (a_data != out_data) {
-    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
-        gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols,
-                       gpuMemcpyDeviceToDevice, stream)));
+    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
}

MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch,
@@ -176,8 +174,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream,

ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
-                         ffi::Result<ffi::Buffer<ffi::DataType::S32>> ipiv,
-                         ffi::Result<ffi::Buffer<ffi::DataType::S32>> info) {
+                         ffi::Result<ffi::Buffer<ffi::S32>> ipiv,
+                         ffi::Result<ffi::Buffer<ffi::S32>> info) {
auto dataType = a.element_type();
if (dataType != out->element_type()) {
return ffi::Error::InvalidArgument(
@@ -201,15 +199,14 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
}
} // namespace

-XLA_FFI_DEFINE_HANDLER_SYMBOL(
-    GetrfFfi, GetrfDispatch,
-    ffi::Ffi::Bind()
-        .Ctx<ffi::PlatformStream<gpuStream_t>>()
-        .Ctx<ffi::ScratchAllocator>()
-        .Arg<ffi::AnyBuffer>()                   // a
-        .Ret<ffi::AnyBuffer>()                   // out
-        .Ret<ffi::Buffer<ffi::DataType::S32>>()  // ipiv
-        .Ret<ffi::Buffer<ffi::DataType::S32>>()  // info
+XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch,
+                              ffi::Ffi::Bind()
+                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                                  .Ctx<ffi::ScratchAllocator>()
+                                  .Arg<ffi::AnyBuffer>()         // a
+                                  .Ret<ffi::AnyBuffer>()         // out
+                                  .Ret<ffi::Buffer<ffi::S32>>()  // ipiv
+                                  .Ret<ffi::Buffer<ffi::S32>>()  // info
);

// QR decomposition: geqrf
@@ -264,14 +261,13 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
auto out_data = static_cast<T*>(out->untyped_data());
auto tau_data = static_cast<T*>(tau->untyped_data());
if (a_data != out_data) {
-    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
-        gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols,
-                       gpuMemcpyDeviceToDevice, stream)));
+    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
}

int out_step = m * n;
int tau_step = std::min(m, n);
-  for (int i = 0; i < batch; ++i) {
+  for (auto i = 0; i < batch; ++i) {
FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel<T>::Run(
handle.get(), m, n, out_data, tau_data, workspace, lwork, info));
out_data += out_step;
@@ -284,8 +280,8 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols,
template <> \
struct GeqrfBatchedKernel<type> { \
static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \
-                            type** tau, int* info, int batch) {              \
-      return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch));      \
+                            type** tau, int* info, int batch) {         \
+      return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \
} \
}

@@ -314,9 +310,8 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols,
auto out_data = out->untyped_data();
auto tau_data = tau->untyped_data();
if (a_data != out_data) {
-    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(
-        gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols,
-                       gpuMemcpyDeviceToDevice, stream)));
+    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
}

MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch,
@@ -369,6 +364,112 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch,
.Ret<ffi::AnyBuffer>() // tau
);

+// Householder transformations: orgqr
+
+namespace {
+#define ORGQR_KERNEL_IMPL(type, name)                                        \
+  template <>                                                                \
+  struct OrgqrKernel<type> {                                                 \
+    static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m, \
+                                          int n, int k) {                    \
+      int lwork;                                                             \
+      JAX_RETURN_IF_ERROR(JAX_AS_STATUS(                                     \
+          name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m,       \
+                            /*tau=*/nullptr, &lwork)));                      \
+      return lwork;                                                          \
+    }                                                                        \
+    static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \
+                            type* a, type* tau, type* workspace, int lwork,  \
+                            int* info) {                                     \
+      return JAX_AS_STATUS(                                                  \
+          name(handle, m, n, k, a, m, tau, workspace, lwork, info));         \
+    }                                                                        \
+  }
+
+template <typename T>
+struct OrgqrKernel;
+ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr);
+ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr);
+ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr);
+ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr);
+#undef ORGQR_KERNEL_IMPL
+
+template <typename T>
+ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size,
+                     gpuStream_t stream, ffi::ScratchAllocator& scratch,
+                     ffi::AnyBuffer a, ffi::AnyBuffer tau,
+                     ffi::Result<ffi::AnyBuffer> out) {
+  FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow<int>(rows));
+  FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow<int>(cols));
+  FFI_ASSIGN_OR_RETURN(auto k, MaybeCastNoOverflow<int>(size));
+
+  FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream));
+  FFI_ASSIGN_OR_RETURN(int lwork,
+                       OrgqrKernel<T>::BufferSize(handle.get(), m, n, k));
+
+  FFI_ASSIGN_OR_RETURN(auto workspace,
+                       AllocateWorkspace<T>(scratch, lwork, "orgqr"));
+  // Note: We ignore the returned value of info because it is only used for
+  // shape checking (which we already do ourselves), but it is expected to be
+  // in device memory, so we need to allocate it.
+  FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace<int>(scratch, 1, "orgqr"));
+
+  auto a_data = static_cast<T*>(a.untyped_data());
+  auto tau_data = static_cast<T*>(tau.untyped_data());
+  auto out_data = static_cast<T*>(out->untyped_data());
+  if (a_data != out_data) {
+    FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync(
+        out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)));
+  }
+
+  int out_step = m * n;
+  for (auto i = 0; i < batch; ++i) {
+    FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel<T>::Run(
+        handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info));
+    out_data += out_step;
+    tau_data += k;
+  }
+  return ffi::Error::Success();
+}
+
+ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
+                         ffi::AnyBuffer a, ffi::AnyBuffer tau,
+                         ffi::Result<ffi::AnyBuffer> out) {
+  auto dataType = a.element_type();
+  if (dataType != tau.element_type() || dataType != out->element_type()) {
+    return ffi::Error::InvalidArgument(
+        "The inputs and outputs to orgqr must have the same element type");
+  }
+  FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]),
+                       SplitBatch2D(a.dimensions()));
+  FFI_ASSIGN_OR_RETURN((auto [tau_batch, size]),
+                       SplitBatch1D(tau.dimensions()));
+  if (tau_batch != batch) {
+    return ffi::Error::InvalidArgument(
+        "The batch dimensions of the inputs to orgqr must match");
+  }
+  if (size > cols) {
+    return ffi::Error::InvalidArgument(
+        "The trailing dimension of the tau input to orgqr must be less than or "
+        "equal to the number of columns of the input matrix");
+  }
+  FFI_RETURN_IF_ERROR(
+      CheckShape(out->dimensions(), {batch, rows, cols}, "out", "orgqr"));
+  SOLVER_DISPATCH_IMPL(OrgqrImpl, batch, rows, cols, size, stream, scratch, a,
+                       tau, out);
+  return ffi::Error::InvalidArgument("Unsupported element type for orgqr");
+}
+}  // namespace
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,
+                              ffi::Ffi::Bind()
+                                  .Ctx<ffi::PlatformStream<gpuStream_t>>()
+                                  .Ctx<ffi::ScratchAllocator>()
+                                  .Arg<ffi::AnyBuffer>()  // a
+                                  .Arg<ffi::AnyBuffer>()  // tau
+                                  .Ret<ffi::AnyBuffer>()  // out
+);
+
#undef SOLVER_DISPATCH_IMPL

} // namespace JAX_GPU_NAMESPACE
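Aside: the kernel-struct pattern in the new code is easiest to read with one specialization expanded by hand. ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr) produces, modulo formatting, the following, pairing the workspace-size query with the actual solver call; the complex instantiations map to the cuSOLVER ungqr entry points, since orgqr is the real-typed name in LAPACK:

    template <>
    struct OrgqrKernel<float> {
      static absl::StatusOr<int> BufferSize(gpusolverDnHandle_t handle, int m,
                                            int n, int k) {
        int lwork;
        // Query cuSOLVER for the workspace size; no device buffers are
        // touched, so A and tau may be null here.
        JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverDnSorgqr_bufferSize(
            handle, m, n, k, /*A=*/nullptr, /*lda=*/m, /*tau=*/nullptr,
            &lwork)));
        return lwork;
      }
      static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k,
                              float* a, float* tau, float* workspace,
                              int lwork, int* info) {
        // In place: a holds the geqrf output on entry and Q on exit.
        return JAX_AS_STATUS(gpusolverDnSorgqr(
            handle, m, n, k, a, /*lda=*/m, tau, workspace, lwork, info));
      }
    };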
1 change: 1 addition & 0 deletions jaxlib/gpu/solver_kernels_ffi.h
@@ -24,6 +24,7 @@ namespace JAX_GPU_NAMESPACE {

XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi);
XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi);
+XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);

} // namespace JAX_GPU_NAMESPACE
} // namespace jax
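Taken together, the four files follow the usual jaxlib plumbing for exposing a new FFI kernel; schematically (all four lines appear in the diff above):

    // solver_kernels_ffi.h: declare the handler symbol.
    XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi);

    // solver_kernels_ffi.cc: define the symbol, binding OrgqrDispatch to its
    // signature (stream and scratch-allocator contexts, a and tau inputs,
    // one output buffer).
    XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, ...);

    // gpu_kernels.cc: register the handler with XLA under its call-target name.
    XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA",
                             OrgqrFfi);

    // solver.cc: expose the handler to Python through the registrations dict.
    dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi);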