From b56ed8eeddc5794f3981832a38b6bcc195eb20f8 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Thu, 22 Aug 2024 05:22:39 -0700 Subject: [PATCH] Port GPU kernel for Householder transformation to FFI. PiperOrigin-RevId: 666305682 --- jaxlib/gpu/gpu_kernels.cc | 2 + jaxlib/gpu/solver.cc | 1 + jaxlib/gpu/solver_kernels_ffi.cc | 171 ++++++++++++++++++++++++------- jaxlib/gpu/solver_kernels_ffi.h | 1 + 4 files changed, 140 insertions(+), 35 deletions(-) diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index 1814641bb4fb..3841393654a8 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -51,6 +51,8 @@ XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", GeqrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", + OrgqrFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc index 4ee7a9f1dbf7..fee1c1014c75 100644 --- a/jaxlib/gpu/solver.cc +++ b/jaxlib/gpu/solver.cc @@ -477,6 +477,7 @@ nb::dict Registrations() { dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); + dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); return dict; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 91124a847121..2b1f5552977f 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -51,13 +51,13 @@ inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, } // namespace #define SOLVER_DISPATCH_IMPL(impl, ...) \ - if (dataType == ffi::DataType::F32) { \ + if (dataType == ffi::F32) { \ return impl(__VA_ARGS__); \ - } else if (dataType == ffi::DataType::F64) { \ + } else if (dataType == ffi::F64) { \ return impl(__VA_ARGS__); \ - } else if (dataType == ffi::DataType::C64) { \ + } else if (dataType == ffi::C64) { \ return impl(__VA_ARGS__); \ - } else if (dataType == ffi::DataType::C128) { \ + } else if (dataType == ffi::C128) { \ return impl(__VA_ARGS__); \ } @@ -94,8 +94,8 @@ template ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { + ffi::Result> ipiv, + ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); @@ -110,13 +110,12 @@ ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, - gpuMemcpyDeviceToDevice, stream))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); } int ipiv_step = std::min(m, n); - for (int i = 0; i < batch; ++i) { + for (auto i = 0; i < batch; ++i) { FFI_RETURN_IF_ERROR_STATUS(GetrfKernel::Run( handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data)); out_data += m * n; @@ -147,8 +146,8 @@ template ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { + ffi::Result> ipiv, + ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(auto batch_ptrs, @@ -159,9 +158,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols, - gpuMemcpyDeviceToDevice, stream))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); } MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch, @@ -176,8 +174,8 @@ ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { + ffi::Result> ipiv, + ffi::Result> info) { auto dataType = a.element_type(); if (dataType != out->element_type()) { return ffi::Error::InvalidArgument( @@ -201,15 +199,14 @@ ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, } } // namespace -XLA_FFI_DEFINE_HANDLER_SYMBOL( - GetrfFfi, GetrfDispatch, - ffi::Ffi::Bind() - .Ctx>() - .Ctx() - .Arg() // a - .Ret() // out - .Ret>() // ipiv - .Ret>() // info +XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // a + .Ret() // out + .Ret>() // ipiv + .Ret>() // info ); // QR decomposition: geqrf @@ -264,14 +261,13 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, auto out_data = static_cast(out->untyped_data()); auto tau_data = static_cast(tau->untyped_data()); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, - gpuMemcpyDeviceToDevice, stream))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); } int out_step = m * n; int tau_step = std::min(m, n); - for (int i = 0; i < batch; ++i) { + for (auto i = 0; i < batch; ++i) { FFI_RETURN_IF_ERROR_STATUS(GeqrfKernel::Run( handle.get(), m, n, out_data, tau_data, workspace, lwork, info)); out_data += out_step; @@ -284,8 +280,8 @@ ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, template <> \ struct GeqrfBatchedKernel { \ static absl::Status Run(gpublasHandle_t handle, int m, int n, type** a, \ - type** tau, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \ + type** tau, int* info, int batch) { \ + return JAX_AS_STATUS(name(handle, m, n, a, m, tau, info, batch)); \ } \ } @@ -314,9 +310,8 @@ ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols, auto out_data = out->untyped_data(); auto tau_data = tau->untyped_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, - gpuMemcpyDeviceToDevice, stream))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); } MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch, @@ -369,6 +364,112 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch, .Ret() // tau ); +// Householder transformations: orgqr + +namespace { +#define ORGQR_KERNEL_IMPL(type, name) \ + template <> \ + struct OrgqrKernel { \ + static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ + int n, int k) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + name##_bufferSize(handle, m, n, k, /*A=*/nullptr, /*lda=*/m, \ + /*tau=*/nullptr, &lwork))); \ + return lwork; \ + } \ + static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, int k, \ + type* a, type* tau, type* workspace, int lwork, \ + int* info) { \ + return JAX_AS_STATUS( \ + name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \ + } \ + } + +template +struct OrgqrKernel; +ORGQR_KERNEL_IMPL(float, gpusolverDnSorgqr); +ORGQR_KERNEL_IMPL(double, gpusolverDnDorgqr); +ORGQR_KERNEL_IMPL(gpuComplex, gpusolverDnCungqr); +ORGQR_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZungqr); +#undef ORGQR_KERNEL_IMPL + +template +ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer a, ffi::AnyBuffer tau, + ffi::Result out) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto k, MaybeCastNoOverflow(size)); + + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(int lwork, + OrgqrKernel::BufferSize(handle.get(), m, n, k)); + + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "orgqr")); + // Note: We ignore the returned value of info because it is only used for + // shape checking (which we already do ourselves), but it is expected to be + // in device memory, so we need to allocate it. + FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace(scratch, 1, "orgqr")); + + auto a_data = static_cast(a.untyped_data()); + auto tau_data = static_cast(tau.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + } + + int out_step = m * n; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(OrgqrKernel::Run( + handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info)); + out_data += out_step; + tau_data += k; + } + return ffi::Error::Success(); +} + +ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + ffi::AnyBuffer a, ffi::AnyBuffer tau, + ffi::Result out) { + auto dataType = a.element_type(); + if (dataType != tau.element_type() || dataType != out->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to orgqr must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [tau_batch, size]), + SplitBatch1D(tau.dimensions())); + if (tau_batch != batch) { + return ffi::Error::InvalidArgument( + "The batch dimensions of the inputs to orgqr must match"); + } + if (size > cols) { + return ffi::Error::InvalidArgument( + "The trailing dimension of the tau input to orgqr must be less than or " + "equal to the number of columns of the input matrix"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "orgqr")); + SOLVER_DISPATCH_IMPL(OrgqrImpl, batch, rows, cols, size, stream, scratch, a, + tau, out); + return ffi::Error::InvalidArgument("Unsupported element type for orgqr"); +} +} // namespace + +XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // a + .Arg() // tau + .Ret() // out +); + #undef SOLVER_DISPATCH_IMPL } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index d9c3da47655a..7dbc7454c2e6 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -24,6 +24,7 @@ namespace JAX_GPU_NAMESPACE { XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi); XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); } // namespace JAX_GPU_NAMESPACE } // namespace jax