From 0dd2f8b4930f2e404e765d251d0fa39f02bceb5c Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 29 Apr 2024 18:18:30 +0200 Subject: [PATCH 1/3] Add a clang-format pre-commit and workflow --- .github/workflows/formatting.yml | 16 ++ .pre-commit-config.yaml | 5 + include/checks.h | 103 +++++------ include/fft.h | 116 +++++------- include/grid_descriptor_mgr.h | 38 ++-- include/halo.h | 24 +-- include/helpers.h | 24 +-- include/jaxdecomp.h | 115 ++++++------ include/logger.hpp | 109 +++++------ include/perfostep.hpp | 177 +++++++----------- src/fft.cu | 292 ++++++++++++----------------- src/grid_descriptor_mgr.cc | 79 +++----- src/halo.cu | 63 +++---- src/jaxdecomp.cc | 307 ++++++++++++------------------- 14 files changed, 603 insertions(+), 865 deletions(-) diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index 97cd358..684c9c7 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -19,3 +19,19 @@ jobs: python -m pip install pre-commit - name: Run pre-commit run: python -m pre_commit run --all-files + formatting-check: + name: Formatting Check + runs-on: ubuntu-latest + strategy: + matrix: + path: + - 'src' + - 'include' + steps: + - uses: actions/checkout@v3 + - name: Run clang-format style check for C/C++/Protobuf programs. + uses: jidicula/clang-format-action@v4.11.0 + with: + clang-format-version: '13' + check-path: ${{ matrix.path }} + fallback-style: 'LLVM' # optional diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f44eaca..75959b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,3 +15,8 @@ repos: hooks: - id: isort name: isort (python) +- repo: https://github.com/pre-commit/mirrors-clang-format + rev: v18.1.4 + hooks: + - id: clang-format + name: clang-format diff --git a/include/checks.h b/include/checks.h index b97ac28..7832094 100644 --- a/include/checks.h +++ b/include/checks.h @@ -26,69 +26,64 @@ using namespace std; #define E_NOTIMPL ((HRESULT)0x80004001L) // Macro to check for CUDA errors -#define CHECK_CUDA_EXIT(call) \ - do { \ - cudaError_t err = call; \ - if (err != cudaSuccess) { \ - printf("CUDA error at %s %d: %s\n", __FILE__, __LINE__, \ - cudaGetErrorString(err)); \ - exit(EXIT_FAILURE); \ - } \ +#define CHECK_CUDA_EXIT(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + printf("CUDA error at %s %d: %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ } while (0) // Error checking macros -#define CHECK_CUDECOMP_EXIT(call) \ - do { \ - cudecompResult_t err = call; \ - if (CUDECOMP_RESULT_SUCCESS != err) { \ - fprintf(stderr, "%s:%d CUDECOMP error. (error code %d)\n", __FILE__, \ - __LINE__, err); \ - throw exception(); \ - } \ +#define CHECK_CUDECOMP_EXIT(call) \ + do { \ + cudecompResult_t err = call; \ + if (CUDECOMP_RESULT_SUCCESS != err) { \ + fprintf(stderr, "%s:%d CUDECOMP error. (error code %d)\n", __FILE__, __LINE__, err); \ + throw exception(); \ + } \ } while (false) -#define CHECK_CUFFT_EXIT(call) \ - do { \ - cufftResult_t err = call; \ - if (CUFFT_SUCCESS != err) { \ - fprintf(stderr, "%s:%d CUFFT error. (error code %d)\n", __FILE__, \ - __LINE__, err); \ - throw exception(); \ - } \ +#define CHECK_CUFFT_EXIT(call) \ + do { \ + cufftResult_t err = call; \ + if (CUFFT_SUCCESS != err) { \ + fprintf(stderr, "%s:%d CUFFT error. (error code %d)\n", __FILE__, __LINE__, err); \ + throw exception(); \ + } \ } while (false) -#define CHECK_MPI_EXIT(call) \ - { \ - int err = call; \ - if (0 != err) { \ - char error_str[MPI_MAX_ERROR_STRING]; \ - int len; \ - MPI_Error_string(err, error_str, &len); \ - if (error_str) { \ - fprintf(stderr, "%s:%d MPI error. (%s)\n", __FILE__, __LINE__, \ - error_str); \ - } else { \ - fprintf(stderr, "%s:%d MPI error. (error code %d)\n", __FILE__, \ - __LINE__, err); \ - } \ - exit(EXIT_FAILURE); \ - } \ - } \ +#define CHECK_MPI_EXIT(call) \ + { \ + int err = call; \ + if (0 != err) { \ + char error_str[MPI_MAX_ERROR_STRING]; \ + int len; \ + MPI_Error_string(err, error_str, &len); \ + if (error_str) { \ + fprintf(stderr, "%s:%d MPI error. (%s)\n", __FILE__, __LINE__, error_str); \ + } else { \ + fprintf(stderr, "%s:%d MPI error. (error code %d)\n", __FILE__, __LINE__, err); \ + } \ + exit(EXIT_FAILURE); \ + } \ + } \ while (false) -#define HR2STR(hr) \ - ((hr == S_OK) ? "S_OK" \ - : (hr == S_FALSE) ? "S_FALSE" \ - : (hr == E_ABORT) ? "E_ABORT" \ - : (hr == E_ACCESSDENIED) ? "E_ACCESSDENIED" \ - : (hr == E_FAIL) ? "E_FAIL" \ - : (hr == E_HANDLE) ? "E_HANDLE" \ - : (hr == E_INVALIDARG) ? "E_INVALIDARG" \ - : (hr == E_NOINTERFACE) ? "E_NOINTERFACE" \ - : (hr == E_NOTIMPL) ? "E_NOTIMPL" \ - : (hr == E_OUTOFMEMORY) ? "E_OUTOFMEMORY" \ - : (hr == E_POINTER) ? "E_POINTER" \ - : (hr == E_UNEXPECTED) ? "E_UNEXPECTED" \ +#define HR2STR(hr) \ + ((hr == S_OK) ? "S_OK" \ + : (hr == S_FALSE) ? "S_FALSE" \ + : (hr == E_ABORT) ? "E_ABORT" \ + : (hr == E_ACCESSDENIED) ? "E_ACCESSDENIED" \ + : (hr == E_FAIL) ? "E_FAIL" \ + : (hr == E_HANDLE) ? "E_HANDLE" \ + : (hr == E_INVALIDARG) ? "E_INVALIDARG" \ + : (hr == E_NOINTERFACE) ? "E_NOINTERFACE" \ + : (hr == E_NOTIMPL) ? "E_NOTIMPL" \ + : (hr == E_OUTOFMEMORY) ? "E_OUTOFMEMORY" \ + : (hr == E_POINTER) ? "E_POINTER" \ + : (hr == E_UNEXPECTED) ? "E_UNEXPECTED" \ : "Unknown HRESULT") #endif // _JAX_DECOMP_CHECKS_H_ diff --git a/include/fft.h b/include/fft.h index edafe67..403d10d 100644 --- a/include/fft.h +++ b/include/fft.h @@ -1,8 +1,8 @@ #ifndef _JAX_DECOMP_FFT_H_ #define _JAX_DECOMP_FFT_H_ -#include "logger.hpp" #include "checks.h" +#include "logger.hpp" #include #include // has to be included before cuda/std/complex #include @@ -20,12 +20,8 @@ static cufftType get_cufft_type_c2r(double) { return CUFFT_Z2D; } static cufftType get_cufft_type_c2r(float) { return CUFFT_C2R; } static cufftType get_cufft_type_c2c(double) { return CUFFT_Z2Z; } static cufftType get_cufft_type_c2c(float) { return CUFFT_C2C; } -static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex) { - return CUDECOMP_FLOAT_COMPLEX; -} -static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex) { - return CUDECOMP_DOUBLE_COMPLEX; -} +static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex) { return CUDECOMP_FLOAT_COMPLEX; } +static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex) { return CUDECOMP_DOUBLE_COMPLEX; } namespace jaxdecomp { enum Decomposition { slab_XY, slab_YZ, pencil, unknown }; @@ -61,13 +57,12 @@ class fftDescriptor { // To make it trivially copyable fftDescriptor() = default; - fftDescriptor(const fftDescriptor &other) = default; - fftDescriptor &operator=(const fftDescriptor &other) = default; + fftDescriptor(const fftDescriptor& other) = default; + fftDescriptor& operator=(const fftDescriptor& other) = default; // Create a compare operator to be used in the unordered_map (a hash is also // created in the bottom of the file) - bool operator==(const fftDescriptor &other) const { - if (double_precision != other.double_precision || - gdims[0] != other.gdims[0] || gdims[1] != other.gdims[1] || + bool operator==(const fftDescriptor& other) const { + if (double_precision != other.double_precision || gdims[0] != other.gdims[0] || gdims[1] != other.gdims[1] || gdims[2] != other.gdims[2] || decomposition != other.decomposition) { return false; } @@ -78,8 +73,8 @@ class fftDescriptor { // Initialize the descriptor from the grid configuration // this is used for subsequent ffts to find the Executor that was already // defined - fftDescriptor(cudecompGridDescConfig_t &config, const bool &double_precision, - const bool &iForward, const bool &iAdjoint) + fftDescriptor(cudecompGridDescConfig_t& config, const bool& double_precision, const bool& iForward, + const bool& iAdjoint) : double_precision(double_precision), config(config) { gdims[0] = config.gdims[0]; gdims[1] = config.gdims[1]; @@ -101,14 +96,12 @@ template class FourierExecutor { FourierExecutor() : m_Tracer("JAXDECOMP") {} ~FourierExecutor(); - HRESULT Initialize(cudecompHandle_t handle, cudecompGridDescConfig_t config, - size_t &work_size, fftDescriptor &fft_descriptor); + HRESULT Initialize(cudecompHandle_t handle, cudecompGridDescConfig_t config, size_t& work_size, + fftDescriptor& fft_descriptor); - HRESULT forward(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, void **buffers); + HRESULT forward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, void** buffers); - HRESULT backward(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, void **buffers); + HRESULT backward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, void** buffers); private: AsyncLogger m_Tracer; @@ -137,69 +130,54 @@ template class FourierExecutor { int64_t m_WorkSize; // Internal functions - HRESULT InitializePencils(cudecompGridDescConfig_t &iGridConfig, - cudecompPencilInfo_t &x_pencil_info, - cudecompPencilInfo_t &y_pencil_info, - cudecompPencilInfo_t &z_pencil_info, - int64_t &work_size, const bool &is_contiguous); - - HRESULT InitializeSlabXY(cudecompGridDescConfig_t &iGridConfig, - cudecompPencilInfo_t &x_pencil_info, - cudecompPencilInfo_t &y_pencil_info, - cudecompPencilInfo_t &z_pencil_info, - int64_t &work_size, const bool &is_contiguous); - - HRESULT InitializeSlabYZ(cudecompGridDescConfig_t &iGridConfig, - cudecompPencilInfo_t &x_pencil_info, - cudecompPencilInfo_t &y_pencil_info, - cudecompPencilInfo_t &z_pencil_info, - int64_t &work_size, const bool &is_contiguous); - - HRESULT forwardXY(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, complex_t *output, - complex_t *work_buffer); - - HRESULT backwardXY(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, complex_t *output, - complex_t *work_buffer); - - HRESULT forwardYZ(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, complex_t *output, - complex_t *work_buffer); - - HRESULT backwardYZ(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, complex_t *output, - complex_t *work_buffer); - - HRESULT forwardPencil(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, - complex_t *output, complex_t *work_buffer); - - HRESULT backwardPencil(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, - complex_t *output, complex_t *work_buffer); + HRESULT InitializePencils(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info, + cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, + int64_t& work_size, const bool& is_contiguous); + + HRESULT InitializeSlabXY(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info, + cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, int64_t& work_size, + const bool& is_contiguous); + + HRESULT InitializeSlabYZ(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info, + cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, int64_t& work_size, + const bool& is_contiguous); + + HRESULT forwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input, + complex_t* output, complex_t* work_buffer); + + HRESULT backwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input, + complex_t* output, complex_t* work_buffer); + + HRESULT forwardYZ(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input, + complex_t* output, complex_t* work_buffer); + + HRESULT backwardYZ(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input, + complex_t* output, complex_t* work_buffer); + + HRESULT forwardPencil(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input, + complex_t* output, complex_t* work_buffer); + + HRESULT backwardPencil(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input, + complex_t* output, complex_t* work_buffer); HRESULT clearPlans(); // DEBUG ONLY ... I WARN YOU - void inspect_device_array(complex_t *data, int size, cudaStream_t stream); + void inspect_device_array(complex_t* data, int size, cudaStream_t stream); }; } // namespace jaxdecomp namespace std { template <> struct hash { - std::size_t operator()(const jaxdecomp::fftDescriptor &descriptor) const { + std::size_t operator()(const jaxdecomp::fftDescriptor& descriptor) const { // Only hash The double precision and the gdims and pdims // If adjoint is changed then the plan should be the same // adjoint is to be used to execute the backward plan - static const size_t xy_hash = - std::hash()(jaxdecomp::Decomposition::slab_XY); + static const size_t xy_hash = std::hash()(jaxdecomp::Decomposition::slab_XY); - size_t hash = std::hash()(descriptor.double_precision) ^ - std::hash()(descriptor.gdims[0]) ^ - std::hash()(descriptor.gdims[1]) ^ - std::hash()(descriptor.gdims[2]) ^ + size_t hash = std::hash()(descriptor.double_precision) ^ std::hash()(descriptor.gdims[0]) ^ + std::hash()(descriptor.gdims[1]) ^ std::hash()(descriptor.gdims[2]) ^ std::hash()(descriptor.decomposition); return hash; } diff --git a/include/grid_descriptor_mgr.h b/include/grid_descriptor_mgr.h index eb82974..e014111 100644 --- a/include/grid_descriptor_mgr.h +++ b/include/grid_descriptor_mgr.h @@ -2,37 +2,37 @@ #ifndef GRIDDESCRIPTORMANAGER_H #define GRIDDESCRIPTORMANAGER_H -#include "logger.hpp" #include "checks.h" #include "fft.h" #include "halo.h" +#include "logger.hpp" #include #include -#include #include +#include namespace jaxdecomp { class GridDescriptorManager { public: - static GridDescriptorManager &getInstance() { + static GridDescriptorManager& getInstance() { static GridDescriptorManager instance; // Guaranteed to be destroyed. // Instantiated on first use. return instance; } - HRESULT createFFTExecutor(fftDescriptor &descriptor, size_t &work_size, - std::shared_ptr> &executor); + HRESULT createFFTExecutor(fftDescriptor& descriptor, size_t& work_size, + std::shared_ptr>& executor); - HRESULT createFFTExecutor(fftDescriptor &descriptor, size_t &work_size, - std::shared_ptr> &executor); + HRESULT createFFTExecutor(fftDescriptor& descriptor, size_t& work_size, + std::shared_ptr>& executor); - HRESULT createHaloExecutor(haloDescriptor_t &descriptor, size_t &work_size, - std::shared_ptr> &executor); + HRESULT createHaloExecutor(haloDescriptor_t& descriptor, size_t& work_size, + std::shared_ptr>& executor); - HRESULT createHaloExecutor(haloDescriptor_t &descriptor, size_t &work_size, - std::shared_ptr> &executor); + HRESULT createHaloExecutor(haloDescriptor_t& descriptor, size_t& work_size, + std::shared_ptr>& executor); inline cudecompHandle_t getHandle() const { return m_Handle; } @@ -48,21 +48,17 @@ class GridDescriptorManager { cudecompHandle_t m_Handle; - std::unordered_map>, - std::hash, std::equal_to<>> + std::unordered_map>, std::hash, std::equal_to<>> m_Descriptors64; - std::unordered_map>, - std::hash, std::equal_to<>> + std::unordered_map>, std::hash, std::equal_to<>> m_Descriptors32; - std::unordered_map>> - m_HaloDescriptors64; - std::unordered_map>> - m_HaloDescriptors32; + std::unordered_map>> m_HaloDescriptors64; + std::unordered_map>> m_HaloDescriptors32; public: - GridDescriptorManager(GridDescriptorManager const &) = delete; - void operator=(GridDescriptorManager const &) = delete; + GridDescriptorManager(GridDescriptorManager const&) = delete; + void operator=(GridDescriptorManager const&) = delete; }; } // namespace jaxdecomp diff --git a/include/halo.h b/include/halo.h index 6eca529..70bd68f 100644 --- a/include/halo.h +++ b/include/halo.h @@ -18,18 +18,14 @@ class haloDescriptor_t { cudecompGridDescConfig_t config; // Descriptor for the grid haloDescriptor_t() = default; - haloDescriptor_t(const haloDescriptor_t &other) = default; + haloDescriptor_t(const haloDescriptor_t& other) = default; ~haloDescriptor_t() = default; - bool operator==(const haloDescriptor_t &other) const { - return (double_precision == other.double_precision && - halo_extents == other.halo_extents && - halo_periods == other.halo_periods && axis == other.axis && - config.gdims[0] == other.config.gdims[0] && - config.gdims[1] == other.config.gdims[1] && - config.gdims[2] == other.config.gdims[2] && - config.pdims[0] == other.config.pdims[0] && - config.pdims[1] == other.config.pdims[1]); + bool operator==(const haloDescriptor_t& other) const { + return (double_precision == other.double_precision && halo_extents == other.halo_extents && + halo_periods == other.halo_periods && axis == other.axis && config.gdims[0] == other.config.gdims[0] && + config.gdims[1] == other.config.gdims[1] && config.gdims[2] == other.config.gdims[2] && + config.pdims[0] == other.config.pdims[0] && config.pdims[1] == other.config.pdims[1]); } }; @@ -41,10 +37,8 @@ template class HaloExchange { // Grid descriptors are handled by the GridDescriptorManager ~HaloExchange() = default; - HRESULT get_halo_descriptor(cudecompHandle_t handle, size_t &work_size, - haloDescriptor_t &halo_desc); - HRESULT halo_exchange(cudecompHandle_t handle, haloDescriptor_t desc, - cudaStream_t stream, void **buffers); + HRESULT get_halo_descriptor(cudecompHandle_t handle, size_t& work_size, haloDescriptor_t& halo_desc); + HRESULT halo_exchange(cudecompHandle_t handle, haloDescriptor_t desc, cudaStream_t stream, void** buffers); private: cudecompGridDesc_t m_GridConfig; @@ -58,7 +52,7 @@ template class HaloExchange { namespace std { template <> struct hash { - std::size_t operator()(const jaxdecomp::haloDescriptor_t &descriptor) const { + std::size_t operator()(const jaxdecomp::haloDescriptor_t& descriptor) const { std::size_t h1 = std::hash{}(descriptor.double_precision); h1 ^= std::hash{}(descriptor.halo_extents[0]); h1 ^= std::hash{}(descriptor.halo_extents[1]); diff --git a/include/helpers.h b/include/helpers.h index 1211089..210ce57 100644 --- a/include/helpers.h +++ b/include/helpers.h @@ -7,11 +7,10 @@ namespace jaxdecomp { // https://en.cppreference.com/w/cpp/numeric/bit_cast template -typename std::enable_if::value && +typename std::enable_if::value && std::is_trivially_copyable::value, To>::type -bit_cast(const From &src) noexcept { +bit_cast(const From& src) noexcept { static_assert(std::is_trivially_constructible::value, "This implementation additionally requires destination type to " "be trivially constructible"); @@ -21,24 +20,21 @@ bit_cast(const From &src) noexcept { return dst; } -template std::string PackDescriptorAsString(const T &descriptor) { - return std::string(bit_cast(&descriptor), sizeof(T)); +template std::string PackDescriptorAsString(const T& descriptor) { + return std::string(bit_cast(&descriptor), sizeof(T)); } -template pybind11::bytes PackDescriptor(const T &descriptor) { +template pybind11::bytes PackDescriptor(const T& descriptor) { return pybind11::bytes(PackDescriptorAsString(descriptor)); } -template -const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) { - if (opaque_len != sizeof(T)) { - throw std::runtime_error("Invalid opaque object size"); - } - return bit_cast(opaque); +template const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) { + if (opaque_len != sizeof(T)) { throw std::runtime_error("Invalid opaque object size"); } + return bit_cast(opaque); } -template pybind11::capsule EncapsulateFunction(T *fn) { - return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); +template pybind11::capsule EncapsulateFunction(T* fn) { + return pybind11::capsule(bit_cast(fn), "xla._CUSTOM_CALL_TARGET"); } } // namespace jaxdecomp diff --git a/include/jaxdecomp.h b/include/jaxdecomp.h index 395f2c0..ffbfb46 100644 --- a/include/jaxdecomp.h +++ b/include/jaxdecomp.h @@ -1,75 +1,64 @@ #ifndef _JAX_DECOMP_H_ #define _JAX_DECOMP_H_ +#include "checks.h" +#include #include #include -#include -#include "checks.h" -namespace jaxdecomp -{ - /** - * @brief A data structure defining configuration options for grid descriptor creation. - * Slightly adapted version of cudecompGridDescConfig_t which can be automatically translated by pybind11 - */ - typedef struct - { - // Grid information - std::array gdims; ///< dimensions of global data grid - // std::array gdims_dist; ///< dimensions of global data grid to use for distribution - std::array pdims; ///< dimensions of process grid +namespace jaxdecomp { +/** + * @brief A data structure defining configuration options for grid descriptor creation. + * Slightly adapted version of cudecompGridDescConfig_t which can be automatically translated by pybind11 + */ +typedef struct { + // Grid information + std::array gdims; ///< dimensions of global data grid + // std::array gdims_dist; ///< dimensions of global data grid to use for distribution + std::array pdims; ///< dimensions of process grid - // Transpose settings - cudecompTransposeCommBackend_t transpose_comm_backend; ///< communication backend to use for transpose communication - ///< (default: CUDECOMP_TRANSPOSE_COMM_MPI_P2P) - // bool transpose_axis_contiguous[3]; ///< flag (by axis) indicating if memory should be contiguous along pencil axis - // /< (default: [false, false, false]) + // Transpose settings + cudecompTransposeCommBackend_t transpose_comm_backend; ///< communication backend to use for transpose communication + ///< (default: CUDECOMP_TRANSPOSE_COMM_MPI_P2P) + // bool transpose_axis_contiguous[3]; ///< flag (by axis) indicating if memory should be contiguous along pencil axis + // /< (default: [false, false, false]) - // Halo settings - cudecompHaloCommBackend_t - halo_comm_backend; ///< communication backend to use for halo communication (default: CUDECOMP_HALO_COMM_MPI) + // Halo settings + cudecompHaloCommBackend_t + halo_comm_backend; ///< communication backend to use for halo communication (default: CUDECOMP_HALO_COMM_MPI) - } decompGridDescConfig_t; - void cudecompGridDescConfigSet(cudecompGridDescConfig_t *config, - const decompGridDescConfig_t *source) - { - // Initialize the config with the defaults - CHECK_CUDECOMP_EXIT(cudecompGridDescConfigSetDefaults(config)); - for (int i = 0; i < 3; i++) - config->gdims[i] = source->gdims[i]; - for (int i = 0; i < 2; i++) - config->pdims[i] = source->pdims[i]; - for (int i = 0; i < 3; i++) - config->transpose_axis_contiguous[i] = true; - config->halo_comm_backend = source->halo_comm_backend; - config->transpose_comm_backend = source->transpose_comm_backend; - }; +} decompGridDescConfig_t; +void cudecompGridDescConfigSet(cudecompGridDescConfig_t* config, const decompGridDescConfig_t* source) { + // Initialize the config with the defaults + CHECK_CUDECOMP_EXIT(cudecompGridDescConfigSetDefaults(config)); + for (int i = 0; i < 3; i++) config->gdims[i] = source->gdims[i]; + for (int i = 0; i < 2; i++) config->pdims[i] = source->pdims[i]; + for (int i = 0; i < 3; i++) config->transpose_axis_contiguous[i] = true; + config->halo_comm_backend = source->halo_comm_backend; + config->transpose_comm_backend = source->transpose_comm_backend; +}; - /** - * @brief A data structure containing geometry information about a pencil data buffer. - * Slightly adapted version of cudecompPencilInfo_t which can be automatically translated by pybind11 - */ - typedef struct - { - std::array shape; ///< pencil shape (in local order, including halo elements) - std::array lo; ///< lower bound coordinates (in local order, excluding halo elements) - std::array hi; ///< upper bound coordinates (in local order, excluding halo elements) - std::array order; ///< data layout order (e.g. 2,1,0 means memory is ordered Z,Y,X) - std::array halo_extents; ///< halo extents by dimension (in global order) - int64_t size; ///< number of elements in pencil (including halo elements) - } decompPencilInfo_t; - void decompPencilInfoSet(decompPencilInfo_t *info, - const cudecompPencilInfo_t *source) - { - for (int i = 0; i < 3; i++) - { - info->hi[i] = source->hi[i]; - info->lo[i] = source->lo[i]; - info->halo_extents[i] = source->halo_extents[i]; - info->order[i] = source->order[i]; - info->shape[i] = source->shape[i]; - } - info->size = source->size; - }; +/** + * @brief A data structure containing geometry information about a pencil data buffer. + * Slightly adapted version of cudecompPencilInfo_t which can be automatically translated by pybind11 + */ +typedef struct { + std::array shape; ///< pencil shape (in local order, including halo elements) + std::array lo; ///< lower bound coordinates (in local order, excluding halo elements) + std::array hi; ///< upper bound coordinates (in local order, excluding halo elements) + std::array order; ///< data layout order (e.g. 2,1,0 means memory is ordered Z,Y,X) + std::array halo_extents; ///< halo extents by dimension (in global order) + int64_t size; ///< number of elements in pencil (including halo elements) +} decompPencilInfo_t; +void decompPencilInfoSet(decompPencilInfo_t* info, const cudecompPencilInfo_t* source) { + for (int i = 0; i < 3; i++) { + info->hi[i] = source->hi[i]; + info->lo[i] = source->lo[i]; + info->halo_extents[i] = source->halo_extents[i]; + info->order[i] = source->order[i]; + info->shape[i] = source->shape[i]; + } + info->size = source->size; }; +}; // namespace jaxdecomp #endif diff --git a/include/logger.hpp b/include/logger.hpp index 46e17c1..19eda1b 100644 --- a/include/logger.hpp +++ b/include/logger.hpp @@ -39,10 +39,10 @@ #include #include #include +#include #include #include #include -#include // #ifdef MPI_VERSION // #include @@ -50,19 +50,17 @@ class AsyncLogger { public: - AsyncLogger(const std::string &name) - : name(name), bufferSize(10 * 1024 * 1024), buffer(""), traceInfo(false), - traceVerbose(false), traceToConsole(true) { - static const char *traceEnv = std::getenv("ASYNC_TRACE"); + AsyncLogger(const std::string& name) + : name(name), bufferSize(10 * 1024 * 1024), buffer(""), traceInfo(false), traceVerbose(false), + traceToConsole(true) { + static const char* traceEnv = std::getenv("ASYNC_TRACE"); if (traceEnv != nullptr) { std::string traceString = traceEnv; size_t pos = traceString.find(name); - if (pos != std::string::npos) { - traceInfo = true; - } + if (pos != std::string::npos) { traceInfo = true; } } - static const char *traceEnvVerb = std::getenv("ASYNC_TRACE_VERBOSE"); + static const char* traceEnvVerb = std::getenv("ASYNC_TRACE_VERBOSE"); if (traceEnvVerb != nullptr) { std::string traceString = traceEnvVerb; size_t pos = traceString.find(name); @@ -72,12 +70,10 @@ class AsyncLogger { } } - static const char *bufferSizeEnv = std::getenv("ASYNC_TRACE_MAX_BUFFER"); - if (bufferSizeEnv != nullptr) { - bufferSize = std::atoi(bufferSizeEnv); - } + static const char* bufferSizeEnv = std::getenv("ASYNC_TRACE_MAX_BUFFER"); + if (bufferSizeEnv != nullptr) { bufferSize = std::atoi(bufferSizeEnv); } - static const char *outputDirEnv = std::getenv("ASYNC_TRACE_OUTPUT_DIR"); + static const char* outputDirEnv = std::getenv("ASYNC_TRACE_OUTPUT_DIR"); if (outputDirEnv != nullptr) { outputDir = outputDirEnv; // Ensure the output directory exists @@ -85,24 +81,24 @@ class AsyncLogger { traceToConsole = false; } - static const char *traceToConsoleEnv = std::getenv("ASYNC_TRACE_CONSOLE"); + static const char* traceToConsoleEnv = std::getenv("ASYNC_TRACE_CONSOLE"); if (traceToConsoleEnv != nullptr) { traceToConsole = std::atoi(traceToConsoleEnv) != 0; traceToConsole = true; } - static const char *nobufferEnv = std::getenv("ASYNC_TRACE_NOBUFFER"); + static const char* nobufferEnv = std::getenv("ASYNC_TRACE_NOBUFFER"); if (nobufferEnv != nullptr) { nobuffer = std::atoi(nobufferEnv) != 0; nobuffer = true; } -// This requires MPI to be already initialized, which happens only later -// #ifdef MPI_VERSION -// MPI_Comm_rank(MPI_COMM_WORLD, &rank); -// #endif + // This requires MPI to be already initialized, which happens only later + // #ifdef MPI_VERSION + // MPI_Comm_rank(MPI_COMM_WORLD, &rank); + // #endif } - AsyncLogger &startTraceInfo() { + AsyncLogger& startTraceInfo() { if (traceInfo || traceVerbose) { std::ostringstream ss; addTimestamp(ss); @@ -113,7 +109,7 @@ class AsyncLogger { return *this; } - AsyncLogger &startTraceVerbose() { + AsyncLogger& startTraceVerbose() { if (traceInfo || traceVerbose) { std::ostringstream ss; addTimestamp(ss); @@ -124,48 +120,40 @@ class AsyncLogger { return *this; } - template AsyncLogger &operator<<(const T &value) { + template AsyncLogger& operator<<(const T& value) { if (traceInfo || traceVerbose) { std::ostringstream ss; ss << value; buffer += ss.str(); - if (buffer.size() >= bufferSize || nobuffer) { - flush(); - } + if (buffer.size() >= bufferSize || nobuffer) { flush(); } } return *this; } // Specialization for bool - AsyncLogger &operator<<(bool value) { + AsyncLogger& operator<<(bool value) { if (traceInfo || traceVerbose) { std::ostringstream ss; ss << std::boolalpha << value; buffer += ss.str(); - if (buffer.size() >= bufferSize || nobuffer) { - flush(); - } + if (buffer.size() >= bufferSize || nobuffer) { flush(); } } return *this; } // Specialization for std::endl - AsyncLogger &operator<<(std::ostream &(*manipulator)(std::ostream &)) { + AsyncLogger& operator<<(std::ostream& (*manipulator)(std::ostream&)) { if (traceInfo || traceVerbose) { std::ostringstream ss; ss << manipulator; buffer += ss.str(); - if (buffer.size() >= bufferSize || nobuffer) { - flush(); - } + if (buffer.size() >= bufferSize || nobuffer) { flush(); } } return *this; } ~AsyncLogger() { - if (traceInfo || traceVerbose) { - flush(); - } + if (traceInfo || traceVerbose) { flush(); } } bool getTraceInfo() const { return traceInfo; } @@ -196,9 +184,9 @@ class AsyncLogger { ss << "Call stack:" << std::endl; const int max_frames = 64; - void *frame_ptrs[max_frames]; + void* frame_ptrs[max_frames]; int num_frames = backtrace(frame_ptrs, max_frames); - char **symbols = backtrace_symbols(frame_ptrs, num_frames); + char** symbols = backtrace_symbols(frame_ptrs, num_frames); if (symbols == nullptr) { buffer += "Error retrieving backtrace symbols." + std::string("\n"); @@ -209,8 +197,7 @@ class AsyncLogger { // Demangle the C++ function name size_t size; int status; - char *demangled = - abi::__cxa_demangle(symbols[i], nullptr, &size, &status); + char* demangled = abi::__cxa_demangle(symbols[i], nullptr, &size, &status); if (status == 0) { ss << demangled << std::endl; @@ -225,19 +212,15 @@ class AsyncLogger { buffer += ss.str(); - if (buffer.size() >= bufferSize || nobuffer) { - flush(); - } + if (buffer.size() >= bufferSize || nobuffer) { flush(); } } } private: - void addTimestamp(std::ostringstream &stream) { + void addTimestamp(std::ostringstream& stream) { auto now = std::chrono::system_clock::now(); auto timePoint = std::chrono::system_clock::to_time_t(now); - auto milliseconds = std::chrono::duration_cast( - now.time_since_epoch()) % - 1000; + auto milliseconds = std::chrono::duration_cast(now.time_since_epoch()) % 1000; std::tm tm; #ifdef _WIN32 @@ -246,9 +229,8 @@ class AsyncLogger { localtime_r(&timePoint, &tm); #endif - stream << "[" << tm.tm_year + 1900 << "/" << tm.tm_mon + 1 << "/" - << tm.tm_mday << " " << tm.tm_hour << ":" << tm.tm_min << ":" - << tm.tm_sec << ":" << milliseconds.count() << "] "; + stream << "[" << tm.tm_year + 1900 << "/" << tm.tm_mon + 1 << "/" << tm.tm_mday << " " << tm.tm_hour << ":" + << tm.tm_min << ":" << tm.tm_sec << ":" << milliseconds.count() << "] "; } std::string name; @@ -262,25 +244,20 @@ class AsyncLogger { int rank = -1; }; -#define StartTraceInfo(logger) \ - if (logger.getTraceInfo()) \ - logger.startTraceInfo() +#define StartTraceInfo(logger) \ + if (logger.getTraceInfo()) logger.startTraceInfo() -#define TraceInfo(logger) \ - if (logger.getTraceInfo()) \ - logger +#define TraceInfo(logger) \ + if (logger.getTraceInfo()) logger -#define PrintStack(logger) \ - if (logger.getTraceInfo()) \ - logger.addStackTrace() +#define PrintStack(logger) \ + if (logger.getTraceInfo()) logger.addStackTrace() -#define StartTraceVerbose(logger) \ - if (logger.getTraceVerbose()) \ - logger.startTraceVerbose() +#define StartTraceVerbose(logger) \ + if (logger.getTraceVerbose()) logger.startTraceVerbose() -#define TraceVerbose(logger) \ - if (logger.getTraceVerbose()) \ - logger +#define TraceVerbose(logger) \ + if (logger.getTraceVerbose()) logger #endif // ASYNC_LOGGER_HPP diff --git a/include/perfostep.hpp b/include/perfostep.hpp index 16c1861..f8219e8 100644 --- a/include/perfostep.hpp +++ b/include/perfostep.hpp @@ -73,15 +73,13 @@ typedef std::map Reports; class AbstractPerfostep { public: - virtual void Start(const std::string &iReport, const ColumnNames &iCol) = 0; + virtual void Start(const std::string& iReport, const ColumnNames& iCol) = 0; virtual double Stop() = 0; - virtual void Report(const bool &iPrintTotal = false) const = 0; - virtual void PrintToMarkdown(const char *ifilename, - const bool &iPrintTotal = false) const = 0; - virtual void PrintToCSV(const char *ifilename, - const bool &iPrintTotal = false) const = 0; - virtual void Switch(const std::string &iReport, const ColumnNames &iCol) = 0; + virtual void Report(const bool& iPrintTotal = false) const = 0; + virtual void PrintToMarkdown(const char* ifilename, const bool& iPrintTotal = false) const = 0; + virtual void PrintToCSV(const char* ifilename, const bool& iPrintTotal = false) const = 0; + virtual void Switch(const std::string& iReport, const ColumnNames& iCol) = 0; virtual ~AbstractPerfostep() {} protected: @@ -91,61 +89,46 @@ class AbstractPerfostep { class BasePerfostep : public AbstractPerfostep { public: - void Report(const bool &iPrintTotal = false) const override { - if (m_Reports.size() == 0) - return; + void Report(const bool& iPrintTotal = false) const override { + if (m_Reports.size() == 0) return; std::cout << "Reporting : " << std::endl; std::cout << "For parameters: " << std::endl; - for (const auto &entry : m_ColNames) { - std::cout << std::get<0>(entry) << " : " << std::get<1>(entry) - << std::endl; + for (const auto& entry : m_ColNames) { + std::cout << std::get<0>(entry) << " : " << std::get<1>(entry) << std::endl; } - for (const auto &entry : m_Reports) { - std::cout << std::get<0>(entry) << " : " << std::get<1>(entry) << "ms " - << std::endl; + for (const auto& entry : m_Reports) { + std::cout << std::get<0>(entry) << " : " << std::get<1>(entry) << "ms " << std::endl; } } - void PrintToMarkdown(const char *filename, - const bool &iPrintTotal = false) const override { - if (m_Reports.size() == 0) - return; + void PrintToMarkdown(const char* filename, const bool& iPrintTotal = false) const override { + if (m_Reports.size() == 0) return; std::ofstream file(filename, std::ios::app); - if (!file.is_open()) { - throw std::runtime_error("Failed to open file: " + std::string(filename)); - } + if (!file.is_open()) { throw std::runtime_error("Failed to open file: " + std::string(filename)); } if (file.tellp() == 0) { // Check if file is empty file << "| Task | "; - for (const auto &entry : m_ColNames) { - file << std::get<0>(entry) << " | "; - } + for (const auto& entry : m_ColNames) { file << std::get<0>(entry) << " | "; } file << "Elapsed Time (ms) |" << std::endl; // Header names for columns // For the Task column file << "| ---- | "; // For the other columns - for (const auto &entry : m_ColNames) { - file << std::string(entry.first.length(), '-') << " | "; - } + for (const auto& entry : m_ColNames) { file << std::string(entry.first.length(), '-') << " | "; } // For the elapsed time column file << " ---------------- |" << std::endl; } std::string colvalues; - for (const auto &col : m_ColNames) { - colvalues += std::get<1>(col) + " | "; - } + for (const auto& col : m_ColNames) { colvalues += std::get<1>(col) + " | "; } - for (const auto &entry : m_Reports) { - file << "| " << std::get<0>(entry) << " | " << colvalues - << std::get<1>(entry) << " |" << std::endl; + for (const auto& entry : m_Reports) { + file << "| " << std::get<0>(entry) << " | " << colvalues << std::get<1>(entry) << " |" << std::endl; } - if (iPrintTotal) - file << "| Total | " << colvalues << GetTotal() << " |" << std::endl; + if (iPrintTotal) file << "| Total | " << colvalues << GetTotal() << " |" << std::endl; file.close(); } /** @@ -153,41 +136,31 @@ class BasePerfostep : public AbstractPerfostep { * format to a file. * @param filename The name of the file to write the CSV data to. */ - void PrintToCSV(const char *filename, - const bool &iPrintTotal) const override { - if (m_Reports.size() == 0) - return; + void PrintToCSV(const char* filename, const bool& iPrintTotal) const override { + if (m_Reports.size() == 0) return; std::ofstream file(filename, std::ios::app); // Open file in append mode - if (!file.is_open()) { - throw std::runtime_error("Failed to open file: " + std::string(filename)); - } + if (!file.is_open()) { throw std::runtime_error("Failed to open file: " + std::string(filename)); } if (file.tellp() == 0) { // Check if file is empty file << "Task,"; - for (const auto &entry : m_ColNames) { - file << std::get<0>(entry) << ","; - } + for (const auto& entry : m_ColNames) { file << std::get<0>(entry) << ","; } file << "Elapsed Time (ms)" << std::endl; // Header names for columns } std::string colvalues; - for (const auto &col : m_ColNames) { - colvalues += std::get<1>(col) + ","; - } + for (const auto& col : m_ColNames) { colvalues += std::get<1>(col) + ","; } - for (const auto &entry : m_Reports) { - file << std::get<0>(entry) << "," << colvalues << std::get<1>(entry) - << std::endl; + for (const auto& entry : m_Reports) { + file << std::get<0>(entry) << "," << colvalues << std::get<1>(entry) << std::endl; } - if (iPrintTotal) - file << "Total," << colvalues << GetTotal() << std::endl; + if (iPrintTotal) file << "Total," << colvalues << GetTotal() << std::endl; file.close(); } - void Switch(const std::string &iReport, const ColumnNames &iCol) override { + void Switch(const std::string& iReport, const ColumnNames& iCol) override { Stop(); Start(iReport, iCol); } @@ -196,22 +169,17 @@ class BasePerfostep : public AbstractPerfostep { double GetTotal() const { double total = std::accumulate( m_Reports.begin(), m_Reports.end(), 0.0, - [](double sum, const std::tuple &entry) { - return sum + std::get<1>(entry); - }); + [](double sum, const std::tuple& entry) { return sum + std::get<1>(entry); }); return total; } }; -typedef std::vector< - std::tuple>> - StartTimes; +typedef std::vector>> StartTimes; class PerfostepChrono : public BasePerfostep { public: - void Start(const std::string &iReport, const ColumnNames &iCol) override { - m_StartTimes.push_back( - std::make_tuple(iReport, high_resolution_clock::now())); + void Start(const std::string& iReport, const ColumnNames& iCol) override { + m_StartTimes.push_back(std::make_tuple(iReport, high_resolution_clock::now())); m_ColNames = iCol; } @@ -229,20 +197,17 @@ class PerfostepChrono : public BasePerfostep { ~PerfostepChrono() { if (m_StartTimes.size() > 0) { - std::cerr << "Warning: There are still start times not stopped" - << std::endl; + std::cerr << "Warning: There are still start times not stopped" << std::endl; // print message for each start time - for (const auto &entry : m_StartTimes) { - std::cerr << "Start time for " << std::get<0>(entry) - << " is not stopped" << std::endl; + for (const auto& entry : m_StartTimes) { + std::cerr << "Start time for " << std::get<0>(entry) << " is not stopped" << std::endl; } } } private: - StartTimes m_StartTimes; /**< The start time of the measurement. */ - time_point - m_EndTime; /**< The end time of the measurement. */ + StartTimes m_StartTimes; /**< The start time of the measurement. */ + time_point m_EndTime; /**< The end time of the measurement. */ }; #ifdef ENABLE_NVTX @@ -251,11 +216,10 @@ class PerfostepChrono : public BasePerfostep { class PerfostepNVTX : public BasePerfostep { public: // ColumnNames are not used in NVTX - void Start(const std::string &iReport, const ColumnNames &iCol) override { + void Start(const std::string& iReport, const ColumnNames& iCol) override { static constexpr int ncolors_ = 8; - static constexpr int colors_[ncolors_] = {0x3366CC, 0xDC3912, 0xFF9900, - 0x109618, 0x990099, 0x3B3EAC, - 0x0099C6, 0xDD4477}; + static constexpr int colors_[ncolors_] = {0x3366CC, 0xDC3912, 0xFF9900, 0x109618, + 0x990099, 0x3B3EAC, 0x0099C6, 0xDD4477}; std::string range_name(iReport); std::hash hash_fn; int color = colors_[hash_fn(range_name) % ncolors_]; @@ -278,10 +242,8 @@ class PerfostepNVTX : public BasePerfostep { } ~PerfostepNVTX() { if (nvtx_ranges > 0) { - std::cerr << "Warning: There are still start times not stopped" - << std::endl; - for (int i = 0; i < nvtx_ranges; i++) - nvtxRangePop(); + std::cerr << "Warning: There are still start times not stopped" << std::endl; + for (int i = 0; i < nvtx_ranges; i++) nvtxRangePop(); } } @@ -299,7 +261,7 @@ class PerfostepCUDA : public BasePerfostep { public: PerfostepCUDA() { cudaEventCreate(&m_EndEvent); } - void Start(const std::string &iReport, const ColumnNames &iCol) override { + void Start(const std::string& iReport, const ColumnNames& iCol) override { cudaEvent_t m_StartEvent; cudaEventCreate(&m_StartEvent); cudaEventRecord(m_StartEvent); @@ -311,8 +273,7 @@ class PerfostepCUDA : public BasePerfostep { cudaEventRecord(m_EndEvent); cudaEventSynchronize(m_EndEvent); float elapsed; - cudaEventElapsedTime(&elapsed, std::get<1>(m_StartEvents.back()), - m_EndEvent); + cudaEventElapsedTime(&elapsed, std::get<1>(m_StartEvents.back()), m_EndEvent); double m_ElapsedTime = static_cast(elapsed); cudaEventDestroy(std::get<1>(m_StartEvents.back())); m_Reports[std::get<0>(m_StartEvents.back())] = m_ElapsedTime; @@ -323,13 +284,10 @@ class PerfostepCUDA : public BasePerfostep { ~PerfostepCUDA() { if (m_StartEvents.size() > 0) { - std::cerr << "Warning: There are still start events not stopped" - << std::endl; + std::cerr << "Warning: There are still start events not stopped" << std::endl; std::for_each( m_StartEvents.cbegin(), m_StartEvents.cend(), - [](const std::tuple &entry) { - cudaEventDestroy(std::get<1>(entry)); - }); + [](const std::tuple& entry) { cudaEventDestroy(std::get<1>(entry)); }); } cudaEventDestroy(m_EndEvent); } @@ -344,7 +302,7 @@ class PerfostepCUDA : public BasePerfostep { class Perfostep { public: Perfostep() { - static const char *env = std::getenv("ENABLE_PERFO_STEP"); + static const char* env = std::getenv("ENABLE_PERFO_STEP"); if (env != nullptr) { std::string envStr(env); if (envStr == "TIMER") { @@ -355,8 +313,7 @@ class Perfostep { m_Perfostep = std::make_unique(); m_EnablePerfoStep = true; #else - throw std::runtime_error( - "NVTX is not available. Please install NVTX to use it."); + throw std::runtime_error("NVTX is not available. Please install NVTX to use it."); #endif } else if (envStr == "CUDA") { #ifdef ENABLE_CUDA @@ -367,39 +324,31 @@ class Perfostep { "or compile using nvcc to use it."); #endif } else { - throw std::runtime_error( - "Invalid value for ENABLE_PERFO_STEP: " + envStr + - ". Possible values are TIMER, NVTX, or " - "CUDA."); + throw std::runtime_error("Invalid value for ENABLE_PERFO_STEP: " + envStr + + ". Possible values are TIMER, NVTX, or " + "CUDA."); } } } - void Start(const std::string &iReport, const ColumnNames &iCol = {}) { - if (m_EnablePerfoStep) - m_Perfostep->Start(iReport, iCol); + void Start(const std::string& iReport, const ColumnNames& iCol = {}) { + if (m_EnablePerfoStep) m_Perfostep->Start(iReport, iCol); } double Stop() { - if (m_EnablePerfoStep) - return m_Perfostep->Stop(); + if (m_EnablePerfoStep) return m_Perfostep->Stop(); return 0.0; } - void Report(const bool &iPrintTotal = false) const { - if (m_EnablePerfoStep) - m_Perfostep->Report(iPrintTotal); + void Report(const bool& iPrintTotal = false) const { + if (m_EnablePerfoStep) m_Perfostep->Report(iPrintTotal); } - void PrintToMarkdown(const char *filename, - const bool &iPrintTotal = false) const { - if (m_EnablePerfoStep) - m_Perfostep->PrintToMarkdown(filename, iPrintTotal); + void PrintToMarkdown(const char* filename, const bool& iPrintTotal = false) const { + if (m_EnablePerfoStep) m_Perfostep->PrintToMarkdown(filename, iPrintTotal); } - void PrintToCSV(const char *filename, const bool &iPrintTotal = false) const { - if (m_EnablePerfoStep) - m_Perfostep->PrintToCSV(filename, iPrintTotal); + void PrintToCSV(const char* filename, const bool& iPrintTotal = false) const { + if (m_EnablePerfoStep) m_Perfostep->PrintToCSV(filename, iPrintTotal); } - void Switch(const std::string &iReport, const ColumnNames &iCol = {}) { - if (m_EnablePerfoStep) - m_Perfostep->Switch(iReport, iCol); + void Switch(const std::string& iReport, const ColumnNames& iCol = {}) { + if (m_EnablePerfoStep) m_Perfostep->Switch(iReport, iCol); } private: diff --git a/src/fft.cu b/src/fft.cu index 87dd2b1..b74947e 100644 --- a/src/fft.cu +++ b/src/fft.cu @@ -1,7 +1,7 @@ -#include "logger.hpp" -#include "perfostep.hpp" #include "checks.h" #include "fft.h" +#include "logger.hpp" +#include "perfostep.hpp" #include #include #include @@ -12,33 +12,26 @@ namespace jaxdecomp { template -HRESULT FourierExecutor::Initialize(cudecompHandle_t handle, - cudecompGridDescConfig_t config, - size_t &work_size, - fftDescriptor &fft_descriptor) { +HRESULT FourierExecutor::Initialize(cudecompHandle_t handle, cudecompGridDescConfig_t config, size_t& work_size, + fftDescriptor& fft_descriptor) { Perfostep profiler; profiler.Start("CreateGridDesc"); m_GridDescConfig = config; - CHECK_CUDECOMP_EXIT( - cudecompGridDescCreate(handle, &m_GridConfig, &config, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGridDescCreate(handle, &m_GridConfig, &config, nullptr)); // Get x-pencil information (complex) cudecompPencilInfo_t pinfo_x_c; - CHECK_CUDECOMP_EXIT( - cudecompGetPencilInfo(handle, m_GridConfig, &pinfo_x_c, 0, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, m_GridConfig, &pinfo_x_c, 0, nullptr)); // Get y-pencil information (complex) cudecompPencilInfo_t pinfo_y_c; - CHECK_CUDECOMP_EXIT( - cudecompGetPencilInfo(handle, m_GridConfig, &pinfo_y_c, 1, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, m_GridConfig, &pinfo_y_c, 1, nullptr)); // Get z-pencil information (complex) cudecompPencilInfo_t pinfo_z_c; - CHECK_CUDECOMP_EXIT( - cudecompGetPencilInfo(handle, m_GridConfig, &pinfo_z_c, 2, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, m_GridConfig, &pinfo_z_c, 2, nullptr)); // Get workspace size int64_t num_elements_work_c; - CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize(handle, m_GridConfig, - &num_elements_work_c)); + CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize(handle, m_GridConfig, &num_elements_work_c)); profiler.Stop(); // Set up the FFT plan @@ -48,9 +41,8 @@ HRESULT FourierExecutor::Initialize(cudecompHandle_t handle, // Simple code is better // No need to handle mixed contiguous and non-contiguous cases - bool is_contiguous = config.transpose_axis_contiguous[0] || - config.transpose_axis_contiguous[1] || - config.transpose_axis_contiguous[2]; + bool is_contiguous = + config.transpose_axis_contiguous[0] || config.transpose_axis_contiguous[1] || config.transpose_axis_contiguous[2]; is_contiguous = true; // Force only contiguous case for now because I need to // review the non-contiguous case @@ -62,22 +54,17 @@ HRESULT FourierExecutor::Initialize(cudecompHandle_t handle, switch (GetDecomposition(config.pdims)) { case Decomposition::slab_XY: profiler.Start("InitializeSlabXY"); - hr = InitializeSlabXY(config, pinfo_x_c, pinfo_y_c, pinfo_z_c, - work_sz_cufft, is_contiguous); + hr = InitializeSlabXY(config, pinfo_x_c, pinfo_y_c, pinfo_z_c, work_sz_cufft, is_contiguous); break; case Decomposition::slab_YZ: profiler.Start("InitializeSlabYZ"); - hr = InitializeSlabYZ(config, pinfo_x_c, pinfo_y_c, pinfo_z_c, - work_sz_cufft, is_contiguous); + hr = InitializeSlabYZ(config, pinfo_x_c, pinfo_y_c, pinfo_z_c, work_sz_cufft, is_contiguous); break; case Decomposition::pencil: profiler.Start("InitializePencils"); - hr = InitializePencils(config, pinfo_x_c, pinfo_y_c, pinfo_z_c, - work_sz_cufft, is_contiguous); - break; - case Decomposition::unknown: - hr = E_FAIL; + hr = InitializePencils(config, pinfo_x_c, pinfo_y_c, pinfo_z_c, work_sz_cufft, is_contiguous); break; + case Decomposition::unknown: hr = E_FAIL; break; } profiler.Stop(); @@ -107,14 +94,14 @@ HRESULT FourierExecutor::Initialize(cudecompHandle_t handle, } template -HRESULT FourierExecutor::InitializePencils( - cudecompGridDescConfig_t &iGridConfig, cudecompPencilInfo_t &x_pencil_info, - cudecompPencilInfo_t &y_pencil_info, cudecompPencilInfo_t &z_pencil_info, - int64_t &work_size, const bool &is_contiguous) { - - int &gx = iGridConfig.gdims[0]; // take reference to avoid copying - int &gy = iGridConfig.gdims[1]; - int &gz = iGridConfig.gdims[2]; +HRESULT +FourierExecutor::InitializePencils(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info, + cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, + int64_t& work_size, const bool& is_contiguous) { + + int& gx = iGridConfig.gdims[0]; // take reference to avoid copying + int& gy = iGridConfig.gdims[1]; + int& gz = iGridConfig.gdims[2]; // Create the plans CHECK_CUFFT_EXIT(cufftCreate(&m_Plan_c2c_x)); CHECK_CUFFT_EXIT(cufftCreate(&m_Plan_c2c_y)); @@ -128,32 +115,29 @@ HRESULT FourierExecutor::InitializePencils( // The work size size_t work_sz_c2c_x, work_sz_c2c_y, work_sz_c2c_z; // The X plan - CHECK_CUFFT_EXIT(cufftMakePlan1d( - m_Plan_c2c_x, gx, get_cufft_type_c2c(real_t(0)), - x_pencil_info.shape[1] * x_pencil_info.shape[2], &work_sz_c2c_x)); + CHECK_CUFFT_EXIT(cufftMakePlan1d(m_Plan_c2c_x, gx, get_cufft_type_c2c(real_t(0)), + x_pencil_info.shape[1] * x_pencil_info.shape[2], &work_sz_c2c_x)); // The Y plan - CHECK_CUFFT_EXIT(cufftMakePlan1d( - m_Plan_c2c_y, gy, get_cufft_type_c2c(real_t(0)), - y_pencil_info.shape[1] * y_pencil_info.shape[2], &work_sz_c2c_y)); + CHECK_CUFFT_EXIT(cufftMakePlan1d(m_Plan_c2c_y, gy, get_cufft_type_c2c(real_t(0)), + y_pencil_info.shape[1] * y_pencil_info.shape[2], &work_sz_c2c_y)); // The Z plan - CHECK_CUFFT_EXIT(cufftMakePlan1d( - m_Plan_c2c_z, gz, get_cufft_type_c2c(real_t(0)), - z_pencil_info.shape[1] * z_pencil_info.shape[2], &work_sz_c2c_z)); + CHECK_CUFFT_EXIT(cufftMakePlan1d(m_Plan_c2c_z, gz, get_cufft_type_c2c(real_t(0)), + z_pencil_info.shape[1] * z_pencil_info.shape[2], &work_sz_c2c_z)); work_size = std::max(work_sz_c2c_x, std::max(work_sz_c2c_y, work_sz_c2c_z)); return work_size > 0 ? S_OK : E_FAIL; } template -HRESULT FourierExecutor::InitializeSlabXY( - cudecompGridDescConfig_t &iGridConfig, cudecompPencilInfo_t &x_pencil_info, - cudecompPencilInfo_t &y_pencil_info, cudecompPencilInfo_t &z_pencil_info, - int64_t &work_size, const bool &is_contiguous) { - int &gx = iGridConfig.gdims[0]; // take reference to avoid copying - int &gy = iGridConfig.gdims[1]; - int &gz = iGridConfig.gdims[2]; +HRESULT +FourierExecutor::InitializeSlabXY(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info, + cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, + int64_t& work_size, const bool& is_contiguous) { + int& gx = iGridConfig.gdims[0]; // take reference to avoid copying + int& gy = iGridConfig.gdims[1]; + int& gz = iGridConfig.gdims[2]; // The XY plan CHECK_CUFFT_EXIT(cufftCreate(&m_Plan_c2c_xy)); CHECK_CUFFT_EXIT(cufftSetAutoAllocation(m_Plan_c2c_xy, 0)); @@ -169,30 +153,26 @@ HRESULT FourierExecutor::InitializeSlabXY( // (Side note: the first axis is always contiguous in cuDecomp) std::array y_x{gy, gx}; - CHECK_CUFFT_EXIT(cufftMakePlanMany( - m_Plan_c2c_xy, 2, y_x.data(), nullptr, 1, - x_pencil_info.shape[0] * x_pencil_info.shape[1], nullptr, 1, - x_pencil_info.shape[0] * x_pencil_info.shape[1], - get_cufft_type_c2c(real_t(0)), x_pencil_info.shape[2], &work_size_xy)); + CHECK_CUFFT_EXIT(cufftMakePlanMany(m_Plan_c2c_xy, 2, y_x.data(), nullptr, 1, + x_pencil_info.shape[0] * x_pencil_info.shape[1], nullptr, 1, + x_pencil_info.shape[0] * x_pencil_info.shape[1], get_cufft_type_c2c(real_t(0)), + x_pencil_info.shape[2], &work_size_xy)); if (is_contiguous) { // make the second plan - CHECK_CUFFT_EXIT(cufftMakePlan1d( - m_Plan_c2c_z, gz, get_cufft_type_c2c(real_t(0)), - z_pencil_info.shape[1] * z_pencil_info.shape[2], &work_size_z)); + CHECK_CUFFT_EXIT(cufftMakePlan1d(m_Plan_c2c_z, gz, get_cufft_type_c2c(real_t(0)), + z_pencil_info.shape[1] * z_pencil_info.shape[2], &work_size_z)); } else { // TODO(wassim) : I did not understand this yet // Making the second non contiguous plans first for Z Y slab Y is not // contiguous here - CHECK_CUFFT_EXIT(cufftMakePlanMany( - m_Plan_c2c_z, 1, &gz /* unused */, &gz, - z_pencil_info.shape[0] * z_pencil_info.shape[1], 1, &gz, - z_pencil_info.shape[0] * z_pencil_info.shape[1], 1, - get_cufft_type_c2c(real_t(0)), - z_pencil_info.shape[0] * z_pencil_info.shape[1], &work_size_z)); + CHECK_CUFFT_EXIT( + cufftMakePlanMany(m_Plan_c2c_z, 1, &gz /* unused */, &gz, z_pencil_info.shape[0] * z_pencil_info.shape[1], 1, + &gz, z_pencil_info.shape[0] * z_pencil_info.shape[1], 1, get_cufft_type_c2c(real_t(0)), + z_pencil_info.shape[0] * z_pencil_info.shape[1], &work_size_z)); // Another Batched many plan should be made here } @@ -202,14 +182,14 @@ HRESULT FourierExecutor::InitializeSlabXY( } template -HRESULT FourierExecutor::InitializeSlabYZ( - cudecompGridDescConfig_t &iGridConfig, cudecompPencilInfo_t &x_pencil_info, - cudecompPencilInfo_t &y_pencil_info, cudecompPencilInfo_t &z_pencil_info, - int64_t &work_size, const bool &is_contiguous) { - - int &gx = iGridConfig.gdims[0]; // take reference to avoid copying - int &gy = iGridConfig.gdims[1]; - int &gz = iGridConfig.gdims[2]; +HRESULT +FourierExecutor::InitializeSlabYZ(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info, + cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, + int64_t& work_size, const bool& is_contiguous) { + + int& gx = iGridConfig.gdims[0]; // take reference to avoid copying + int& gy = iGridConfig.gdims[1]; + int& gz = iGridConfig.gdims[2]; // The XY plan CHECK_CUFFT_EXIT(cufftCreate(&m_Plan_c2c_x)); CHECK_CUFFT_EXIT(cufftSetAutoAllocation(m_Plan_c2c_x, 0)); @@ -219,29 +199,26 @@ HRESULT FourierExecutor::InitializeSlabYZ( // Get the plan sizes size_t work_size_x, work_size_yz; - CHECK_CUFFT_EXIT(cufftMakePlan1d( - m_Plan_c2c_x, gx, get_cufft_type_c2c(real_t(0)), - x_pencil_info.shape[1] * x_pencil_info.shape[2], &work_size_x)); + CHECK_CUFFT_EXIT(cufftMakePlan1d(m_Plan_c2c_x, gx, get_cufft_type_c2c(real_t(0)), + x_pencil_info.shape[1] * x_pencil_info.shape[2], &work_size_x)); if (is_contiguous) { // make the second plan YZ std::array n{gz, gy}; - CHECK_CUFFT_EXIT(cufftMakePlanMany( - m_Plan_c2c_yz, 2, n.data(), nullptr, 1, - y_pencil_info.shape[0] * y_pencil_info.shape[1], nullptr, 1, - y_pencil_info.shape[0] * y_pencil_info.shape[1], - get_cufft_type_c2c(real_t(0)), y_pencil_info.shape[2], &work_size_yz)); + CHECK_CUFFT_EXIT(cufftMakePlanMany(m_Plan_c2c_yz, 2, n.data(), nullptr, 1, + y_pencil_info.shape[0] * y_pencil_info.shape[1], nullptr, 1, + y_pencil_info.shape[0] * y_pencil_info.shape[1], get_cufft_type_c2c(real_t(0)), + y_pencil_info.shape[2], &work_size_yz)); } else { // TODO(wassim) : I did not understand this yet // Making the second non contiguous plans first for Z Y slab Y is not // contiguous here - CHECK_CUFFT_EXIT(cufftMakePlanMany( - m_Plan_c2c_y, 1, &gy /* unused */, &gy, y_pencil_info.shape[0], 1, &gy, - y_pencil_info.shape[0], 1, get_cufft_type_c2c(real_t(0)), - y_pencil_info.shape[0], &work_size_yz)); + CHECK_CUFFT_EXIT(cufftMakePlanMany(m_Plan_c2c_y, 1, &gy /* unused */, &gy, y_pencil_info.shape[0], 1, &gy, + y_pencil_info.shape[0], 1, get_cufft_type_c2c(real_t(0)), y_pencil_info.shape[0], + &work_size_yz)); // Another Batched many plan should be made here } @@ -251,21 +228,20 @@ HRESULT FourierExecutor::InitializeSlabYZ( } template -HRESULT FourierExecutor::forward(cudecompHandle_t handle, - fftDescriptor desc, - cudaStream_t stream, void **buffers) { +HRESULT FourierExecutor::forward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, + void** buffers) { Perfostep profiler; profiler.Start("forward"); HRESULT hr(E_FAIL); - void *data_d = buffers[0]; - void *work_d = buffers[1]; - complex_t *data_c_d = static_cast(data_d); - complex_t *input = data_c_d; - complex_t *output = data_c_d; + void* data_d = buffers[0]; + void* work_d = buffers[1]; + complex_t* data_c_d = static_cast(data_d); + complex_t* input = data_c_d; + complex_t* output = data_c_d; // Assign cuFFT work area and current XLA stream - complex_t *work_c_d = static_cast(work_d); + complex_t* work_c_d = static_cast(work_d); switch (desc.decomposition) { case Decomposition::slab_XY: profiler.Start("forwardXY"); @@ -279,28 +255,26 @@ HRESULT FourierExecutor::forward(cudecompHandle_t handle, profiler.Start("forwardPencil"); hr = forwardPencil(handle, desc, stream, input, output, work_c_d); break; - case Decomposition::unknown: - hr = E_FAIL; + case Decomposition::unknown: hr = E_FAIL; } profiler.Stop(); return hr; } template -HRESULT FourierExecutor::backward(cudecompHandle_t handle, - fftDescriptor desc, - cudaStream_t stream, void **buffers) { +HRESULT FourierExecutor::backward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, + void** buffers) { Perfostep profiler; profiler.Start("backward"); HRESULT hr(E_FAIL); - void *data_d = buffers[0]; - void *work_d = buffers[1]; - complex_t *data_c_d = static_cast(data_d); - complex_t *input = data_c_d; - complex_t *output = data_c_d; + void* data_d = buffers[0]; + void* work_d = buffers[1]; + complex_t* data_c_d = static_cast(data_d); + complex_t* input = data_c_d; + complex_t* output = data_c_d; // Assign cuFFT work area and current XLA stream - complex_t *work_c_d = static_cast(work_d); + complex_t* work_c_d = static_cast(work_d); switch (desc.decomposition) { case Decomposition::slab_XY: profiler.Start("backwardXY"); @@ -314,18 +288,15 @@ HRESULT FourierExecutor::backward(cudecompHandle_t handle, profiler.Start("backwardPencil"); hr = backwardPencil(handle, desc, stream, input, output, work_c_d); break; - case Decomposition::unknown: - hr = E_FAIL; + case Decomposition::unknown: hr = E_FAIL; } profiler.Stop(); return hr; } template -HRESULT -FourierExecutor::forwardXY(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, - complex_t *output, complex_t *work_d) { +HRESULT FourierExecutor::forwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, + complex_t* input, complex_t* output, complex_t* work_d) { const int DIRECTION = desc.adjoint ? CUFFT_INVERSE : CUFFT_FORWARD; @@ -336,13 +307,11 @@ FourierExecutor::forwardXY(cudecompHandle_t handle, fftDescriptor desc, // FFT on the first slab CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_xy, input, output, DIRECTION)); // Tranpose X to Y - CHECK_CUDECOMP_EXIT(cudecompTransposeXToY( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeXToY(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // Tranpose Y to Z - CHECK_CUDECOMP_EXIT(cudecompTransposeYToZ( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeYToZ(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // FFT on the second slab CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_z, output, output, DIRECTION)); @@ -350,10 +319,8 @@ FourierExecutor::forwardXY(cudecompHandle_t handle, fftDescriptor desc, } template -HRESULT -FourierExecutor::backwardXY(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, - complex_t *output, complex_t *work_d) { +HRESULT FourierExecutor::backwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, + complex_t* input, complex_t* output, complex_t* work_d) { const int DIRECTION = desc.adjoint ? CUFFT_FORWARD : CUFFT_INVERSE; @@ -365,13 +332,11 @@ FourierExecutor::backwardXY(cudecompHandle_t handle, fftDescriptor desc, // FFT on the first slab CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_z, input, output, DIRECTION)); // Tranpose Z to Y - CHECK_CUDECOMP_EXIT(cudecompTransposeZToY( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeZToY(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // Tranpose Y to X - CHECK_CUDECOMP_EXIT(cudecompTransposeYToX( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeYToX(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // IFFT on the second slab CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_xy, output, output, DIRECTION)); @@ -379,10 +344,8 @@ FourierExecutor::backwardXY(cudecompHandle_t handle, fftDescriptor desc, } template -HRESULT -FourierExecutor::forwardYZ(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, - complex_t *output, complex_t *work_d) { +HRESULT FourierExecutor::forwardYZ(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, + complex_t* input, complex_t* output, complex_t* work_d) { const int DIRECTION = desc.adjoint ? CUFFT_INVERSE : CUFFT_FORWARD; @@ -393,9 +356,8 @@ FourierExecutor::forwardYZ(cudecompHandle_t handle, fftDescriptor desc, // FFT on the first slab CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_x, input, output, DIRECTION)); // Tranpose X to Y - CHECK_CUDECOMP_EXIT(cudecompTransposeXToY( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeXToY(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // FFT on the second slab CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_yz, output, output, DIRECTION)); @@ -403,10 +365,8 @@ FourierExecutor::forwardYZ(cudecompHandle_t handle, fftDescriptor desc, } template -HRESULT -FourierExecutor::backwardYZ(cudecompHandle_t handle, fftDescriptor desc, - cudaStream_t stream, complex_t *input, - complex_t *output, complex_t *work_d) { +HRESULT FourierExecutor::backwardYZ(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, + complex_t* input, complex_t* output, complex_t* work_d) { const int DIRECTION = desc.adjoint ? CUFFT_FORWARD : CUFFT_INVERSE; @@ -418,9 +378,8 @@ FourierExecutor::backwardYZ(cudecompHandle_t handle, fftDescriptor desc, // FFT on the first slab CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_yz, input, output, DIRECTION)); // Tranpose Y to X - CHECK_CUDECOMP_EXIT(cudecompTransposeYToX( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeYToX(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // IFFT on the second slab CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_x, output, output, DIRECTION)); @@ -429,16 +388,14 @@ FourierExecutor::backwardYZ(cudecompHandle_t handle, fftDescriptor desc, // DEBUG ONLY ... I WARN YOU template -void FourierExecutor::inspect_device_array(complex_t *data, int size, - cudaStream_t stream) { +void FourierExecutor::inspect_device_array(complex_t* data, int size, cudaStream_t stream) { int rank; CHECK_MPI_EXIT(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); // Copy input to host const int local_size = 4 * 4 * size; - complex_t *host = new complex_t[local_size]; + complex_t* host = new complex_t[local_size]; - cudaMemcpyAsync(host, data, sizeof(complex_t) * local_size, - cudaMemcpyDeviceToHost, stream); + cudaMemcpyAsync(host, data, sizeof(complex_t) * local_size, cudaMemcpyDeviceToHost, stream); cudaStreamSynchronize(stream); MPI_Barrier(MPI_COMM_WORLD); @@ -449,8 +406,7 @@ void FourierExecutor::inspect_device_array(complex_t *data, int size, if (rank == r) // to force printing in order so I have less headache for (int i = 0; i < size * size * size; i++) { - std::cout << "Rank[" << rank << "] Element [" << i - << "] : " << host[i].real() << " + " << host[i].imag() << "i" + std::cout << "Rank[" << rank << "] Element [" << i << "] : " << host[i].real() << " + " << host[i].imag() << "i" << std::endl; } MPI_Barrier(MPI_COMM_WORLD); @@ -467,9 +423,8 @@ void FourierExecutor::inspect_device_array(complex_t *data, int size, for (int x = 0; x < size; x++) { if (rank == r) { int indx = x + y * size + z * size * size; - std::cout << "Rank[" << rank << "] Element (" << x << "," << y - << "," << z << ") : " << host[indx].real() << " + " - << host[indx].imag() << "i" << std::endl; + std::cout << "Rank[" << rank << "] Element (" << x << "," << y << "," << z << ") : " << host[indx].real() + << " + " << host[indx].imag() << "i" << std::endl; } MPI_Barrier(MPI_COMM_WORLD); } @@ -479,9 +434,8 @@ void FourierExecutor::inspect_device_array(complex_t *data, int size, } template -HRESULT FourierExecutor::forwardPencil( - cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, - complex_t *input, complex_t *output, complex_t *work_d) { +HRESULT FourierExecutor::forwardPencil(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, + complex_t* input, complex_t* output, complex_t* work_d) { const int DIRECTION = desc.adjoint ? CUFFT_INVERSE : CUFFT_FORWARD; @@ -495,24 +449,21 @@ HRESULT FourierExecutor::forwardPencil( // FFT on the first pencil CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_x, input, output, DIRECTION)); // Tranpose X to Y - CHECK_CUDECOMP_EXIT(cudecompTransposeXToY( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeXToY(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // FFT on the second pencil CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_y, output, output, DIRECTION)); // Tranpose Y to Z - CHECK_CUDECOMP_EXIT(cudecompTransposeYToZ( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeYToZ(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // FFT on the third pencil CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_z, output, output, DIRECTION)); return S_OK; } template -HRESULT FourierExecutor::backwardPencil( - cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, - complex_t *input, complex_t *output, complex_t *work_d) { +HRESULT FourierExecutor::backwardPencil(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, + complex_t* input, complex_t* output, complex_t* work_d) { const int DIRECTION = desc.adjoint ? CUFFT_FORWARD : CUFFT_INVERSE; @@ -526,15 +477,13 @@ HRESULT FourierExecutor::backwardPencil( CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_z, input, output, DIRECTION)); // Tranpose Z to Y - CHECK_CUDECOMP_EXIT(cudecompTransposeZToY( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeZToY(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // FFT on the second pencil CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_y, output, output, DIRECTION)); // Tranpose Y to X - CHECK_CUDECOMP_EXIT(cudecompTransposeYToX( - handle, m_GridConfig, output, output, work_d, - get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); + CHECK_CUDECOMP_EXIT(cudecompTransposeYToX(handle, m_GridConfig, output, output, work_d, + get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream)); // FFT on the third pencil CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_x, output, output, DIRECTION)); @@ -557,8 +506,7 @@ template HRESULT FourierExecutor::clearPlans() { cufftDestroy(m_Plan_c2c_y); cufftDestroy(m_Plan_c2c_z); break; - case Decomposition::unknown: - break; + case Decomposition::unknown: break; } return S_OK; diff --git a/src/grid_descriptor_mgr.cc b/src/grid_descriptor_mgr.cc index 300cc40..87986d7 100644 --- a/src/grid_descriptor_mgr.cc +++ b/src/grid_descriptor_mgr.cc @@ -1,8 +1,8 @@ #include "grid_descriptor_mgr.h" -#include "logger.hpp" -#include "fft.h" #include "checks.h" +#include "fft.h" +#include "logger.hpp" #include #include #include @@ -21,17 +21,14 @@ GridDescriptorManager::GridDescriptorManager() : m_Tracer("JAXDECOMP") { // Check if MPI has already been initialized by the user (maybe with mpi4py) int is_initialized; CHECK_MPI_EXIT(MPI_Initialized(&is_initialized)); - if (!is_initialized) { - CHECK_MPI_EXIT(MPI_Init(nullptr, nullptr)); - } + if (!is_initialized) { CHECK_MPI_EXIT(MPI_Init(nullptr, nullptr)); } // Initialize cuDecomp CHECK_CUDECOMP_EXIT(cudecompInit(&m_Handle, mpi_comm)); isInitialized = true; } -HRESULT GridDescriptorManager::createFFTExecutor( - fftDescriptor &descriptor, size_t &work_size, - std::shared_ptr> &executor) { +HRESULT GridDescriptorManager::createFFTExecutor(fftDescriptor& descriptor, size_t& work_size, + std::shared_ptr>& executor) { HRESULT hr(E_FAIL); @@ -44,18 +41,14 @@ HRESULT GridDescriptorManager::createFFTExecutor( } if (hr == E_FAIL) { - hr = executor->Initialize(m_Handle, descriptor.config, work_size, - descriptor); - if (SUCCEEDED(hr)) { - m_Descriptors64[descriptor] = executor; - } + hr = executor->Initialize(m_Handle, descriptor.config, work_size, descriptor); + if (SUCCEEDED(hr)) { m_Descriptors64[descriptor] = executor; } } return hr; } -HRESULT GridDescriptorManager::createFFTExecutor( - fftDescriptor &descriptor, size_t &work_size, - std::shared_ptr> &executor) { +HRESULT GridDescriptorManager::createFFTExecutor(fftDescriptor& descriptor, size_t& work_size, + std::shared_ptr>& executor) { HRESULT hr(E_FAIL); @@ -67,18 +60,14 @@ HRESULT GridDescriptorManager::createFFTExecutor( } if (hr == E_FAIL) { - hr = executor->Initialize(m_Handle, descriptor.config, work_size, - descriptor); - if (SUCCEEDED(hr)) { - m_Descriptors32[descriptor] = executor; - } + hr = executor->Initialize(m_Handle, descriptor.config, work_size, descriptor); + if (SUCCEEDED(hr)) { m_Descriptors32[descriptor] = executor; } } return hr; } -HRESULT GridDescriptorManager::createHaloExecutor( - haloDescriptor_t &descriptor, size_t &work_size, - std::shared_ptr> &executor) { +HRESULT GridDescriptorManager::createHaloExecutor(haloDescriptor_t& descriptor, size_t& work_size, + std::shared_ptr>& executor) { HRESULT hr(E_FAIL); @@ -92,16 +81,13 @@ HRESULT GridDescriptorManager::createHaloExecutor( if (hr == E_FAIL) { executor = std::make_shared>(); hr = executor->get_halo_descriptor(m_Handle, work_size, descriptor); - if (SUCCEEDED(hr)) { - m_HaloDescriptors32[descriptor] = executor; - } + if (SUCCEEDED(hr)) { m_HaloDescriptors32[descriptor] = executor; } } return hr; } -HRESULT GridDescriptorManager::createHaloExecutor( - haloDescriptor_t &descriptor, size_t &work_size, - std::shared_ptr> &executor) { +HRESULT GridDescriptorManager::createHaloExecutor(haloDescriptor_t& descriptor, size_t& work_size, + std::shared_ptr>& executor) { HRESULT hr(E_FAIL); @@ -115,45 +101,38 @@ HRESULT GridDescriptorManager::createHaloExecutor( if (hr == E_FAIL) { executor = std::make_shared>(); hr = executor->get_halo_descriptor(m_Handle, work_size, descriptor); - if (SUCCEEDED(hr)) { - m_HaloDescriptors64[descriptor] = executor; - } + if (SUCCEEDED(hr)) { m_HaloDescriptors64[descriptor] = executor; } } return hr; } void GridDescriptorManager::finalize() { - if (!isInitialized) - return; + if (!isInitialized) return; StartTraceInfo(m_Tracer) << "JaxDecomp shut down" << std::endl; // Destroy grid descriptors - for (auto &descriptor : m_Descriptors64) { - auto &executor = descriptor.second; + for (auto& descriptor : m_Descriptors64) { + auto& executor = descriptor.second; // TODO(wassim) : Cleanup cudecomp resources // CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc_c, work)); This can // be used instead of requesting XLA to allocate the memory - cudecompResult_t err = - cudecompGridDescDestroy(m_Handle, executor->m_GridConfig); + cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig); // Do not throw exceptioin here, this called when the library is being // unloaded we should not throw exceptions here if (CUDECOMP_RESULT_SUCCESS != err) { - StartTraceInfo(m_Tracer) - << "CUDECOMP error.at exit " << err << ")" << std::endl; + StartTraceInfo(m_Tracer) << "CUDECOMP error.at exit " << err << ")" << std::endl; } executor->clearPlans(); } - for (auto &descriptor : m_Descriptors32) { - auto &executor = descriptor.second; + for (auto& descriptor : m_Descriptors32) { + auto& executor = descriptor.second; // Cleanup cudecomp resources // CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc_c, work)); This can // be used instead of requesting XLA to allocate the memory - cudecompResult_t err = - cudecompGridDescDestroy(m_Handle, executor->m_GridConfig); + cudecompResult_t err = cudecompGridDescDestroy(m_Handle, executor->m_GridConfig); if (CUDECOMP_RESULT_SUCCESS != err) { - StartTraceInfo(m_Tracer) - << "CUDECOMP error.at exit " << err << ")" << std::endl; + StartTraceInfo(m_Tracer) << "CUDECOMP error.at exit " << err << ")" << std::endl; } executor->clearPlans(); } @@ -164,13 +143,11 @@ void GridDescriptorManager::finalize() { // Clean finish CHECK_CUDA_EXIT(cudaDeviceSynchronize()); // MPI is finalized by the mpi4py runtime (I wish it wasn't) - //CHECK_MPI_EXIT(MPI_Finalize()); + // CHECK_MPI_EXIT(MPI_Finalize()); isInitialized = false; } GridDescriptorManager::~GridDescriptorManager() { - if (isInitialized) { - finalize(); - } + if (isInitialized) { finalize(); } } } // namespace jaxdecomp diff --git a/src/halo.cu b/src/halo.cu index c8f6f92..27093c5 100644 --- a/src/halo.cu +++ b/src/halo.cu @@ -9,40 +9,30 @@ namespace jaxdecomp { -static inline cudecompDataType_t get_cudecomp_datatype(float) { - return CUDECOMP_FLOAT; -} -static inline cudecompDataType_t get_cudecomp_datatype(double) { - return CUDECOMP_DOUBLE; -} +static inline cudecompDataType_t get_cudecomp_datatype(float) { return CUDECOMP_FLOAT; } +static inline cudecompDataType_t get_cudecomp_datatype(double) { return CUDECOMP_DOUBLE; } template -HRESULT HaloExchange::get_halo_descriptor(cudecompHandle_t handle, - size_t &work_size, - haloDescriptor_t &halo_desc) { +HRESULT HaloExchange::get_halo_descriptor(cudecompHandle_t handle, size_t& work_size, + haloDescriptor_t& halo_desc) { - cudecompGridDescConfig_t &config = halo_desc.config; + cudecompGridDescConfig_t& config = halo_desc.config; - CHECK_CUDECOMP_EXIT( - cudecompGridDescCreate(handle, &m_GridConfig, &config, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGridDescCreate(handle, &m_GridConfig, &config, nullptr)); // Get pencil information for the specified axis - CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, m_GridConfig, &m_PencilInfo, - halo_desc.axis, - halo_desc.halo_extents.data())); + CHECK_CUDECOMP_EXIT( + cudecompGetPencilInfo(handle, m_GridConfig, &m_PencilInfo, halo_desc.axis, halo_desc.halo_extents.data())); cudecompPencilInfo_t no_halo; // Get pencil information for the specified axis - CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, m_GridConfig, &no_halo, - halo_desc.axis, nullptr)); - + CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, m_GridConfig, &no_halo, halo_desc.axis, nullptr)); // Get workspace size int64_t workspace_num_elements; - CHECK_CUDECOMP_EXIT(cudecompGetHaloWorkspaceSize( - handle, m_GridConfig, halo_desc.axis, m_PencilInfo.halo_extents, - &workspace_num_elements)); + CHECK_CUDECOMP_EXIT(cudecompGetHaloWorkspaceSize(handle, m_GridConfig, halo_desc.axis, m_PencilInfo.halo_extents, + &workspace_num_elements)); int64_t dtype_size; if (halo_desc.double_precision) @@ -56,31 +46,26 @@ HRESULT HaloExchange::get_halo_descriptor(cudecompHandle_t handle, } template -HRESULT HaloExchange::halo_exchange(cudecompHandle_t handle, - haloDescriptor_t desc, - cudaStream_t stream, - void **buffers) { - void *data_d = buffers[0]; - void *work_d = buffers[1]; - - //desc.axis = 2; - // Perform halo exchange along the three dimensions +HRESULT HaloExchange::halo_exchange(cudecompHandle_t handle, haloDescriptor_t desc, cudaStream_t stream, + void** buffers) { + void* data_d = buffers[0]; + void* work_d = buffers[1]; + + // desc.axis = 2; + // Perform halo exchange along the three dimensions for (int i = 0; i < 3; ++i) { switch (desc.axis) { case 0: - CHECK_CUDECOMP_EXIT(cudecompUpdateHalosX( - handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)), - m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream)); + CHECK_CUDECOMP_EXIT(cudecompUpdateHalosX(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)), + m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream)); break; case 1: - CHECK_CUDECOMP_EXIT(cudecompUpdateHalosY( - handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)), - m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream)); + CHECK_CUDECOMP_EXIT(cudecompUpdateHalosY(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)), + m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream)); break; case 2: - CHECK_CUDECOMP_EXIT(cudecompUpdateHalosZ( - handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)), - m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream)); + CHECK_CUDECOMP_EXIT(cudecompUpdateHalosZ(handle, m_GridConfig, data_d, work_d, get_cudecomp_datatype(real_t(0)), + m_PencilInfo.halo_extents, desc.halo_periods.data(), i, stream)); break; } } diff --git a/src/jaxdecomp.cc b/src/jaxdecomp.cc index c597ea1..901be19 100644 --- a/src/jaxdecomp.cc +++ b/src/jaxdecomp.cc @@ -1,15 +1,14 @@ -#include -#include -#include -#include -#include "checks.h" -#include "helpers.h" #include "jaxdecomp.h" -#include "logger.hpp" -#include "grid_descriptor_mgr.h" #include "checks.h" #include "fft.h" +#include "grid_descriptor_mgr.h" #include "halo.h" +#include "helpers.h" +#include "logger.hpp" +#include +#include +#include +#include namespace py = pybind11; namespace jd = jaxdecomp; @@ -23,24 +22,21 @@ namespace jaxdecomp { /** * @brief Finalizes the cuDecomp library */ -void finalize(){jd::GridDescriptorManager::getInstance().finalize();}; +void finalize() { jd::GridDescriptorManager::getInstance().finalize(); }; /** * @brief Returns Pencil information for a given grid */ -decompPencilInfo_t getPencilInfo(decompGridDescConfig_t grid_config, - int32_t axis) { +decompPencilInfo_t getPencilInfo(decompGridDescConfig_t grid_config, int32_t axis) { cudecompHandle_t handle(jd::GridDescriptorManager::getInstance().getHandle()); // Create cuDecomp grid descriptor cudecompGridDescConfig_t config; cudecompGridDescConfigSet(&config, &grid_config); // Create the grid description cudecompGridDesc_t grid_desc; - CHECK_CUDECOMP_EXIT( - cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); cudecompPencilInfo_t pencil_info; - CHECK_CUDECOMP_EXIT( - cudecompGetPencilInfo(handle, grid_desc, &pencil_info, axis, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGetPencilInfo(handle, grid_desc, &pencil_info, axis, nullptr)); decompPencilInfo_t result; decompPencilInfoSet(&result, &pencil_info); @@ -61,12 +57,10 @@ decompPencilInfo_t getPencilInfo(decompGridDescConfig_t grid_config, * @param halo_periods * @return decompGridDescConfig_t */ -decompGridDescConfig_t -getAutotunedGridConfig(decompGridDescConfig_t grid_config, - bool double_precision, bool disable_nccl_backends, - bool disable_nvshmem_backends, bool tune_with_transpose, - std::array halo_extents, - std::array halo_periods) { +decompGridDescConfig_t getAutotunedGridConfig(decompGridDescConfig_t grid_config, bool double_precision, + bool disable_nccl_backends, bool disable_nvshmem_backends, + bool tune_with_transpose, std::array halo_extents, + std::array halo_periods) { // Create cuDecomp grid descriptor cudecompHandle_t handle(jd::GridDescriptorManager::getInstance().getHandle()); cudecompGridDescConfig_t config; @@ -82,8 +76,7 @@ getAutotunedGridConfig(decompGridDescConfig_t grid_config, options.disable_nvshmem_backends = disable_nvshmem_backends; // Process grid autotuning options - options.grid_mode = tune_with_transpose ? CUDECOMP_AUTOTUNE_GRID_TRANSPOSE - : CUDECOMP_AUTOTUNE_GRID_HALO; + options.grid_mode = tune_with_transpose ? CUDECOMP_AUTOTUNE_GRID_TRANSPOSE : CUDECOMP_AUTOTUNE_GRID_HALO; options.allow_uneven_decompositions = false; // Transpose communication backend autotuning options @@ -108,16 +101,13 @@ getAutotunedGridConfig(decompGridDescConfig_t grid_config, options.halo_periods[2] = halo_periods[2]; cudecompGridDesc_t grid_desc; - CHECK_CUDECOMP_EXIT( - cudecompGridDescCreate(handle, &grid_desc, &config, &options)); + CHECK_CUDECOMP_EXIT(cudecompGridDescCreate(handle, &grid_desc, &config, &options)); decompGridDescConfig_t output_config; output_config.halo_comm_backend = config.halo_comm_backend; output_config.transpose_comm_backend = config.transpose_comm_backend; - for (int i = 0; i < 3; i++) - output_config.gdims[i] = config.gdims[i]; - for (int i = 0; i < 2; i++) - output_config.pdims[i] = config.pdims[i]; + for (int i = 0; i < 3; i++) output_config.gdims[i] = config.gdims[i]; + for (int i = 0; i < 2; i++) output_config.pdims[i] = config.pdims[i]; CHECK_CUDECOMP_EXIT(cudecompGridDescDestroy(handle, grid_desc)); @@ -125,36 +115,30 @@ getAutotunedGridConfig(decompGridDescConfig_t grid_config, }; /// XLA interface ops -void transposeXtoY(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { +void transposeXtoY(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { cudecompHandle_t handle(jd::GridDescriptorManager::getInstance().getHandle()); - void *data_d = buffers[0]; // In place operations, so only one buffer + void* data_d = buffers[0]; // In place operations, so only one buffer // Create cuDecomp grid descriptor - cudecompGridDescConfig_t config = - *UnpackDescriptor(opaque, opaque_len); + cudecompGridDescConfig_t config = *UnpackDescriptor(opaque, opaque_len); // Create the grid description cudecompGridDesc_t grid_desc; - CHECK_CUDECOMP_EXIT( - cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); // Get workspace sizes int64_t transpose_work_num_elements; - CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize( - handle, grid_desc, &transpose_work_num_elements)); + CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize(handle, grid_desc, &transpose_work_num_elements)); int64_t dtype_size; CHECK_CUDECOMP_EXIT(cudecompGetDataTypeSize(CUDECOMP_FLOAT, &dtype_size)); - double *transpose_work_d; - CHECK_CUDECOMP_EXIT(cudecompMalloc( - handle, grid_desc, reinterpret_cast(&transpose_work_d), - transpose_work_num_elements * dtype_size)); + double* transpose_work_d; + CHECK_CUDECOMP_EXIT(cudecompMalloc(handle, grid_desc, reinterpret_cast(&transpose_work_d), + transpose_work_num_elements * dtype_size)); - CHECK_CUDECOMP_EXIT(cudecompTransposeXToY(handle, grid_desc, data_d, data_d, - transpose_work_d, CUDECOMP_FLOAT, + CHECK_CUDECOMP_EXIT(cudecompTransposeXToY(handle, grid_desc, data_d, data_d, transpose_work_d, CUDECOMP_FLOAT, nullptr, nullptr, stream)); CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc, transpose_work_d)); @@ -162,36 +146,30 @@ void transposeXtoY(cudaStream_t stream, void **buffers, const char *opaque, CHECK_CUDECOMP_EXIT(cudecompGridDescDestroy(handle, grid_desc)); } -void transposeYtoZ(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { +void transposeYtoZ(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { cudecompHandle_t handle(jd::GridDescriptorManager::getInstance().getHandle()); - void *data_d = buffers[0]; // In place operations, so only one buffer + void* data_d = buffers[0]; // In place operations, so only one buffer // Create cuDecomp grid descriptor - cudecompGridDescConfig_t config = - *UnpackDescriptor(opaque, opaque_len); + cudecompGridDescConfig_t config = *UnpackDescriptor(opaque, opaque_len); // Create the grid description cudecompGridDesc_t grid_desc; - CHECK_CUDECOMP_EXIT( - cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); // Get workspace sizes int64_t transpose_work_num_elements; - CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize( - handle, grid_desc, &transpose_work_num_elements)); + CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize(handle, grid_desc, &transpose_work_num_elements)); int64_t dtype_size; CHECK_CUDECOMP_EXIT(cudecompGetDataTypeSize(CUDECOMP_FLOAT, &dtype_size)); - double *transpose_work_d; - CHECK_CUDECOMP_EXIT(cudecompMalloc( - handle, grid_desc, reinterpret_cast(&transpose_work_d), - transpose_work_num_elements * dtype_size)); + double* transpose_work_d; + CHECK_CUDECOMP_EXIT(cudecompMalloc(handle, grid_desc, reinterpret_cast(&transpose_work_d), + transpose_work_num_elements * dtype_size)); - CHECK_CUDECOMP_EXIT(cudecompTransposeYToZ(handle, grid_desc, data_d, data_d, - transpose_work_d, CUDECOMP_FLOAT, + CHECK_CUDECOMP_EXIT(cudecompTransposeYToZ(handle, grid_desc, data_d, data_d, transpose_work_d, CUDECOMP_FLOAT, nullptr, nullptr, stream)); CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc, transpose_work_d)); @@ -199,36 +177,30 @@ void transposeYtoZ(cudaStream_t stream, void **buffers, const char *opaque, CHECK_CUDECOMP_EXIT(cudecompGridDescDestroy(handle, grid_desc)); } -void transposeZtoY(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { +void transposeZtoY(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { cudecompHandle_t handle(jd::GridDescriptorManager::getInstance().getHandle()); - void *data_d = buffers[0]; // In place operations, so only one buffer + void* data_d = buffers[0]; // In place operations, so only one buffer // Create cuDecomp grid descriptor - cudecompGridDescConfig_t config = - *UnpackDescriptor(opaque, opaque_len); + cudecompGridDescConfig_t config = *UnpackDescriptor(opaque, opaque_len); // Create the grid description cudecompGridDesc_t grid_desc; - CHECK_CUDECOMP_EXIT( - cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); // Get workspace sizes int64_t transpose_work_num_elements; - CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize( - handle, grid_desc, &transpose_work_num_elements)); + CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize(handle, grid_desc, &transpose_work_num_elements)); int64_t dtype_size; CHECK_CUDECOMP_EXIT(cudecompGetDataTypeSize(CUDECOMP_FLOAT, &dtype_size)); - double *transpose_work_d; - CHECK_CUDECOMP_EXIT(cudecompMalloc( - handle, grid_desc, reinterpret_cast(&transpose_work_d), - transpose_work_num_elements * dtype_size)); + double* transpose_work_d; + CHECK_CUDECOMP_EXIT(cudecompMalloc(handle, grid_desc, reinterpret_cast(&transpose_work_d), + transpose_work_num_elements * dtype_size)); - CHECK_CUDECOMP_EXIT(cudecompTransposeZToY(handle, grid_desc, data_d, data_d, - transpose_work_d, CUDECOMP_FLOAT, + CHECK_CUDECOMP_EXIT(cudecompTransposeZToY(handle, grid_desc, data_d, data_d, transpose_work_d, CUDECOMP_FLOAT, nullptr, nullptr, stream)); CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc, transpose_work_d)); @@ -236,36 +208,30 @@ void transposeZtoY(cudaStream_t stream, void **buffers, const char *opaque, CHECK_CUDECOMP_EXIT(cudecompGridDescDestroy(handle, grid_desc)); } -void transposeYtoX(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { +void transposeYtoX(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { cudecompHandle_t handle(jd::GridDescriptorManager::getInstance().getHandle()); - void *data_d = buffers[0]; // In place operations, so only one buffer + void* data_d = buffers[0]; // In place operations, so only one buffer // Create cuDecomp grid descriptor - cudecompGridDescConfig_t config = - *UnpackDescriptor(opaque, opaque_len); + cudecompGridDescConfig_t config = *UnpackDescriptor(opaque, opaque_len); // Create the grid description cudecompGridDesc_t grid_desc; - CHECK_CUDECOMP_EXIT( - cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); + CHECK_CUDECOMP_EXIT(cudecompGridDescCreate(handle, &grid_desc, &config, nullptr)); // Get workspace sizes int64_t transpose_work_num_elements; - CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize( - handle, grid_desc, &transpose_work_num_elements)); + CHECK_CUDECOMP_EXIT(cudecompGetTransposeWorkspaceSize(handle, grid_desc, &transpose_work_num_elements)); int64_t dtype_size; CHECK_CUDECOMP_EXIT(cudecompGetDataTypeSize(CUDECOMP_FLOAT, &dtype_size)); - double *transpose_work_d; - CHECK_CUDECOMP_EXIT(cudecompMalloc( - handle, grid_desc, reinterpret_cast(&transpose_work_d), - transpose_work_num_elements * dtype_size)); + double* transpose_work_d; + CHECK_CUDECOMP_EXIT(cudecompMalloc(handle, grid_desc, reinterpret_cast(&transpose_work_d), + transpose_work_num_elements * dtype_size)); - CHECK_CUDECOMP_EXIT(cudecompTransposeYToX(handle, grid_desc, data_d, data_d, - transpose_work_d, CUDECOMP_FLOAT, + CHECK_CUDECOMP_EXIT(cudecompTransposeYToX(handle, grid_desc, data_d, data_d, transpose_work_d, CUDECOMP_FLOAT, nullptr, nullptr, stream)); CHECK_CUDECOMP_EXIT(cudecompFree(handle, grid_desc, transpose_work_d)); @@ -276,21 +242,17 @@ void transposeYtoX(cudaStream_t stream, void **buffers, const char *opaque, /** * @brief Wrapper to cuDecomp-based FFTs */ -void pfft3d(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { +void pfft3d(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { - fftDescriptor descriptor = - *UnpackDescriptor(opaque, opaque_len); + fftDescriptor descriptor = *UnpackDescriptor(opaque, opaque_len); size_t work_size; - cudecompHandle_t my_handle( - jd::GridDescriptorManager::getInstance().getHandle()); + cudecompHandle_t my_handle(jd::GridDescriptorManager::getInstance().getHandle()); // Execute the correct version of the FFT if (descriptor.double_precision) { auto executor = std::make_shared>(); - jd::GridDescriptorManager::getInstance().createFFTExecutor( - descriptor, work_size, executor); + jd::GridDescriptorManager::getInstance().createFFTExecutor(descriptor, work_size, executor); if (descriptor.forward) executor->forward(my_handle, descriptor, stream, buffers); @@ -299,8 +261,7 @@ void pfft3d(cudaStream_t stream, void **buffers, const char *opaque, } else { auto executor = std::make_shared>(); - jd::GridDescriptorManager::getInstance().createFFTExecutor( - descriptor, work_size, executor); + jd::GridDescriptorManager::getInstance().createFFTExecutor(descriptor, work_size, executor); if (descriptor.forward) executor->forward(my_handle, descriptor, stream, buffers); @@ -313,27 +274,23 @@ void pfft3d(cudaStream_t stream, void **buffers, const char *opaque, * @brief Perfom a halo exchange along the 3 dimensions * */ -void halo(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { +void halo(cudaStream_t stream, void** buffers, const char* opaque, size_t opaque_len) { cudecompHandle_t handle(jd::GridDescriptorManager::getInstance().getHandle()); - haloDescriptor_t descriptor = - *UnpackDescriptor(opaque, opaque_len); + haloDescriptor_t descriptor = *UnpackDescriptor(opaque, opaque_len); size_t work_size; // Execute the correct version of the halo exchange if (descriptor.double_precision) { auto executor = std::make_shared>(); - jd::GridDescriptorManager::getInstance().createHaloExecutor( - descriptor, work_size, executor); + jd::GridDescriptorManager::getInstance().createHaloExecutor(descriptor, work_size, executor); executor->halo_exchange(handle, descriptor, stream, buffers); } else { auto executor = std::make_shared>(); - jd::GridDescriptorManager::getInstance().createHaloExecutor( - descriptor, work_size, executor); + jd::GridDescriptorManager::getInstance().createHaloExecutor(descriptor, work_size, executor); executor->halo_exchange(handle, descriptor, stream, buffers); } @@ -368,97 +325,75 @@ PYBIND11_MODULE(_jaxdecomp, m) { return jd::PackDescriptor(cuconfig); }); - m.def("build_fft_descriptor", [](jd::decompGridDescConfig_t config, - bool forward, bool double_precision, - bool adjoint) { - // Create a real cuDecomp grid descriptor - cudecompGridDescConfig_t cuconfig; - cudecompGridDescConfigSet(&cuconfig, &config); + m.def("build_fft_descriptor", + [](jd::decompGridDescConfig_t config, bool forward, bool double_precision, bool adjoint) { + // Create a real cuDecomp grid descriptor + cudecompGridDescConfig_t cuconfig; + cudecompGridDescConfigSet(&cuconfig, &config); - size_t work_size; - jd::fftDescriptor fftdesc(cuconfig, double_precision, forward, adjoint); - if (double_precision) { + size_t work_size; + jd::fftDescriptor fftdesc(cuconfig, double_precision, forward, adjoint); + if (double_precision) { - auto executor = std::make_shared>(); + auto executor = std::make_shared>(); - HRESULT hr = jd::GridDescriptorManager::getInstance().createFFTExecutor( - fftdesc, work_size, executor); + HRESULT hr = jd::GridDescriptorManager::getInstance().createFFTExecutor(fftdesc, work_size, executor); - return std::pair(work_size, - PackDescriptor(fftdesc)); + return std::pair(work_size, PackDescriptor(fftdesc)); - } else { - auto executor = std::make_shared>(); + } else { + auto executor = std::make_shared>(); - HRESULT hr = jd::GridDescriptorManager::getInstance().createFFTExecutor( - fftdesc, work_size, executor); + HRESULT hr = jd::GridDescriptorManager::getInstance().createFFTExecutor(fftdesc, work_size, executor); - return std::pair(work_size, - PackDescriptor(fftdesc)); - } - }); + return std::pair(work_size, PackDescriptor(fftdesc)); + } + }); - m.def("build_halo_descriptor", [](jd::decompGridDescConfig_t config, - bool double_precision, - std::array halo_extents, - std::array halo_periods, - int axis = 0) { - // Create a real cuDecomp grid descriptor - cudecompGridDescConfig_t cuconfig; - cudecompGridDescConfigSet(&cuconfig, &config); - cudecompHandle_t handle( - jd::GridDescriptorManager::getInstance().getHandle()); - - size_t work_size; - jd::haloDescriptor_t halo_desc; - halo_desc.double_precision = double_precision; - halo_desc.halo_extents = halo_extents; - halo_desc.halo_periods = halo_periods; - halo_desc.axis = axis; - halo_desc.config = cuconfig; - - if (double_precision) { - auto executor = std::make_shared>(); - HRESULT hr = jd::GridDescriptorManager::getInstance().createHaloExecutor( - halo_desc, work_size, executor); - } else { - auto executor = std::make_shared>(); - HRESULT hr = jd::GridDescriptorManager::getInstance().createHaloExecutor( - halo_desc, work_size, executor); - } - - return std::pair(work_size, - PackDescriptor(halo_desc)); - }); + m.def("build_halo_descriptor", + [](jd::decompGridDescConfig_t config, bool double_precision, std::array halo_extents, + std::array halo_periods, int axis = 0) { + // Create a real cuDecomp grid descriptor + cudecompGridDescConfig_t cuconfig; + cudecompGridDescConfigSet(&cuconfig, &config); + cudecompHandle_t handle(jd::GridDescriptorManager::getInstance().getHandle()); + + size_t work_size; + jd::haloDescriptor_t halo_desc; + halo_desc.double_precision = double_precision; + halo_desc.halo_extents = halo_extents; + halo_desc.halo_periods = halo_periods; + halo_desc.axis = axis; + halo_desc.config = cuconfig; + + if (double_precision) { + auto executor = std::make_shared>(); + HRESULT hr = jd::GridDescriptorManager::getInstance().createHaloExecutor(halo_desc, work_size, executor); + } else { + auto executor = std::make_shared>(); + HRESULT hr = jd::GridDescriptorManager::getInstance().createHaloExecutor(halo_desc, work_size, executor); + } + + return std::pair(work_size, PackDescriptor(halo_desc)); + }); // Exported types py::enum_(m, "TransposeCommBackend") - .value("TRANSPOSE_COMM_MPI_P2P", - cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_MPI_P2P) - .value("TRANSPOSE_COMM_MPI_P2P_PL", - cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_MPI_P2P_PL) - .value("TRANSPOSE_COMM_MPI_A2A", - cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_MPI_A2A) - .value("TRANSPOSE_COMM_NCCL", - cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_NCCL) - .value("TRANSPOSE_COMM_NCCL_PL", - cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_NCCL_PL) - .value("TRANSPOSE_COMM_NVSHMEM", - cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_NVSHMEM) - .value("TRANSPOSE_COMM_NVSHMEM_PL", - cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL) + .value("TRANSPOSE_COMM_MPI_P2P", cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_MPI_P2P) + .value("TRANSPOSE_COMM_MPI_P2P_PL", cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_MPI_P2P_PL) + .value("TRANSPOSE_COMM_MPI_A2A", cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_MPI_A2A) + .value("TRANSPOSE_COMM_NCCL", cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_NCCL) + .value("TRANSPOSE_COMM_NCCL_PL", cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_NCCL_PL) + .value("TRANSPOSE_COMM_NVSHMEM", cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_NVSHMEM) + .value("TRANSPOSE_COMM_NVSHMEM_PL", cudecompTransposeCommBackend_t::CUDECOMP_TRANSPOSE_COMM_NVSHMEM_PL) .export_values(); py::enum_(m, "HaloCommBackend") .value("HALO_COMM_MPI", cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_MPI) - .value("HALO_COMM_MPI_BLOCKING", - cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_MPI_BLOCKING) - .value("HALO_COMM_NCCL", - cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_NCCL) - .value("HALO_COMM_NVSHMEM", - cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_NVSHMEM) - .value("HALO_COMM_NVSHMEM_BLOCKING", - cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_NVSHMEM_BLOCKING) + .value("HALO_COMM_MPI_BLOCKING", cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_MPI_BLOCKING) + .value("HALO_COMM_NCCL", cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_NCCL) + .value("HALO_COMM_NVSHMEM", cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_NVSHMEM) + .value("HALO_COMM_NVSHMEM_BLOCKING", cudecompHaloCommBackend_t::CUDECOMP_HALO_COMM_NVSHMEM_BLOCKING) .export_values(); py::class_ pencil_info(m, "PencilInfo"); @@ -474,8 +409,6 @@ PYBIND11_MODULE(_jaxdecomp, m) { config.def(py::init<>()) .def_readwrite("gdims", &jd::decompGridDescConfig_t::gdims) .def_readwrite("pdims", &jd::decompGridDescConfig_t::pdims) - .def_readwrite("transpose_comm_backend", - &jd::decompGridDescConfig_t::transpose_comm_backend) - .def_readwrite("halo_comm_backend", - &jd::decompGridDescConfig_t::halo_comm_backend); + .def_readwrite("transpose_comm_backend", &jd::decompGridDescConfig_t::transpose_comm_backend) + .def_readwrite("halo_comm_backend", &jd::decompGridDescConfig_t::halo_comm_backend); } From 3f705308b68a1d78eb02a1bb48a6bc7d7ce18bd4 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Mon, 29 Apr 2024 23:23:00 +0200 Subject: [PATCH 2/3] remove clang-format from the CI workflow --- .github/workflows/formatting.yml | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index 684c9c7..97cd358 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -19,19 +19,3 @@ jobs: python -m pip install pre-commit - name: Run pre-commit run: python -m pre_commit run --all-files - formatting-check: - name: Formatting Check - runs-on: ubuntu-latest - strategy: - matrix: - path: - - 'src' - - 'include' - steps: - - uses: actions/checkout@v3 - - name: Run clang-format style check for C/C++/Protobuf programs. - uses: jidicula/clang-format-action@v4.11.0 - with: - clang-format-version: '13' - check-path: ${{ matrix.path }} - fallback-style: 'LLVM' # optional From 743b663c3a59be18305a0eb06f3cfb5b629e1512 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Tue, 30 Apr 2024 22:58:51 +0200 Subject: [PATCH 3/3] Exclude cudecomp and pybind from clang-format pre-commit --- .pre-commit-config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75959b6..c7141cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,4 +19,6 @@ repos: rev: v18.1.4 hooks: - id: clang-format + files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$' + exclude: '^third_party/|/pybind11/' name: clang-format