Skip to content

Commit

Permalink
Merge pull request #16 from DifferentiableUniverseInitiative/format-c…
Browse files Browse the repository at this point in the history
…lang

Add a clang-format pre-commit and workflow
  • Loading branch information
EiffL authored Apr 30, 2024
2 parents b817ae8 + 743b663 commit e246d43
Show file tree
Hide file tree
Showing 13 changed files with 589 additions and 865 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,10 @@ repos:
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.4
hooks:
- id: clang-format
files: '\.(c|cc|cpp|h|hpp|cxx|hh|cu|cuh)$'
exclude: '^third_party/|/pybind11/'
name: clang-format
103 changes: 49 additions & 54 deletions include/checks.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,69 +26,64 @@ using namespace std;
#define E_NOTIMPL ((HRESULT)0x80004001L)

// Macro to check for CUDA errors
#define CHECK_CUDA_EXIT(call) \
do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
printf("CUDA error at %s %d: %s\n", __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
// Evaluates a CUDA runtime call exactly once. If the result is anything other
// than cudaSuccess, prints the file, line and cudaGetErrorString() text to
// stderr (matching the other CHECK_* macros) and terminates the process with
// EXIT_FAILURE. The do { ... } while (0) wrapper makes the expansion a single
// statement, so it is safe inside an unbraced if/else.
#define CHECK_CUDA_EXIT(call) \
do { \
/* Distinct local name: a caller variable called `err` used inside `call` is not shadowed. */ \
const cudaError_t check_cuda_status_ = (call); \
if (check_cuda_status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error at %s %d: %s\n", __FILE__, __LINE__, cudaGetErrorString(check_cuda_status_)); \
exit(EXIT_FAILURE); \
} \
} while (0)

// Error checking macros
#define CHECK_CUDECOMP_EXIT(call) \
do { \
cudecompResult_t err = call; \
if (CUDECOMP_RESULT_SUCCESS != err) { \
fprintf(stderr, "%s:%d CUDECOMP error. (error code %d)\n", __FILE__, \
__LINE__, err); \
throw exception(); \
} \
// Evaluates a cuDecomp call exactly once. If the result is anything other than
// CUDECOMP_RESULT_SUCCESS, prints the file, line and numeric error code to
// stderr and throws std::exception. Note: despite the _EXIT suffix this macro
// throws rather than terminating the process.
#define CHECK_CUDECOMP_EXIT(call) \
do { \
/* Distinct local name: a caller variable called `err` used inside `call` is not shadowed. */ \
const cudecompResult_t check_cudecomp_status_ = (call); \
if (check_cudecomp_status_ != CUDECOMP_RESULT_SUCCESS) { \
fprintf(stderr, "%s:%d CUDECOMP error. (error code %d)\n", __FILE__, __LINE__, \
static_cast<int>(check_cudecomp_status_)); \
/* Qualified name so the macro does not depend on `using namespace std;`. */ \
throw std::exception(); \
} \
} while (false)

#define CHECK_CUFFT_EXIT(call) \
do { \
cufftResult_t err = call; \
if (CUFFT_SUCCESS != err) { \
fprintf(stderr, "%s:%d CUFFT error. (error code %d)\n", __FILE__, \
__LINE__, err); \
throw exception(); \
} \
// Evaluates a cuFFT call exactly once. If the result is anything other than
// CUFFT_SUCCESS, prints the file, line and numeric error code to stderr and
// throws std::exception. Note: despite the _EXIT suffix this macro throws
// rather than terminating the process.
#define CHECK_CUFFT_EXIT(call) \
do { \
/* Distinct local name: a caller variable called `err` used inside `call` is not shadowed. */ \
const cufftResult_t check_cufft_status_ = (call); \
if (check_cufft_status_ != CUFFT_SUCCESS) { \
fprintf(stderr, "%s:%d CUFFT error. (error code %d)\n", __FILE__, __LINE__, \
static_cast<int>(check_cufft_status_)); \
/* Qualified name so the macro does not depend on `using namespace std;`. */ \
throw std::exception(); \
} \
} while (false)

#define CHECK_MPI_EXIT(call) \
{ \
int err = call; \
if (0 != err) { \
char error_str[MPI_MAX_ERROR_STRING]; \
int len; \
MPI_Error_string(err, error_str, &len); \
if (error_str) { \
fprintf(stderr, "%s:%d MPI error. (%s)\n", __FILE__, __LINE__, \
error_str); \
} else { \
fprintf(stderr, "%s:%d MPI error. (error code %d)\n", __FILE__, \
__LINE__, err); \
} \
exit(EXIT_FAILURE); \
} \
} \
// Evaluates an MPI call exactly once. If the result is not MPI_SUCCESS, prints
// the file, line and the MPI_Error_string() description to stderr, then
// terminates the process with EXIT_FAILURE. If MPI_Error_string() itself
// fails, the numeric error code is printed instead.
// Wrapped in do { ... } while (false) so the expansion is a single statement
// and stays correct inside an unbraced if/else.
#define CHECK_MPI_EXIT(call) \
do { \
/* Distinct local names: caller variables used inside `call` are not shadowed. */ \
const int check_mpi_status_ = (call); \
if (check_mpi_status_ != MPI_SUCCESS) { \
char check_mpi_text_[MPI_MAX_ERROR_STRING]; \
int check_mpi_len_ = 0; \
/* Only trust the buffer when the lookup itself succeeded. */ \
if (MPI_Error_string(check_mpi_status_, check_mpi_text_, &check_mpi_len_) == MPI_SUCCESS) { \
fprintf(stderr, "%s:%d MPI error. (%s)\n", __FILE__, __LINE__, check_mpi_text_); \
} else { \
fprintf(stderr, "%s:%d MPI error. (error code %d)\n", __FILE__, __LINE__, check_mpi_status_); \
} \
exit(EXIT_FAILURE); \
} \
} while (false)

#define HR2STR(hr) \
((hr == S_OK) ? "S_OK" \
: (hr == S_FALSE) ? "S_FALSE" \
: (hr == E_ABORT) ? "E_ABORT" \
: (hr == E_ACCESSDENIED) ? "E_ACCESSDENIED" \
: (hr == E_FAIL) ? "E_FAIL" \
: (hr == E_HANDLE) ? "E_HANDLE" \
: (hr == E_INVALIDARG) ? "E_INVALIDARG" \
: (hr == E_NOINTERFACE) ? "E_NOINTERFACE" \
: (hr == E_NOTIMPL) ? "E_NOTIMPL" \
: (hr == E_OUTOFMEMORY) ? "E_OUTOFMEMORY" \
: (hr == E_POINTER) ? "E_POINTER" \
: (hr == E_UNEXPECTED) ? "E_UNEXPECTED" \
// Maps an HRESULT-style status value to the name of the matching constant as
// a string literal, or "Unknown HRESULT" when no listed constant matches.
// The argument is parenthesised at every use so that expressions such as
// `a & b` compare correctly; note that it is evaluated once per comparison, so
// it must not have side effects.
#define HR2STR(hr) \
(((hr) == S_OK) ? "S_OK" \
: ((hr) == S_FALSE) ? "S_FALSE" \
: ((hr) == E_ABORT) ? "E_ABORT" \
: ((hr) == E_ACCESSDENIED) ? "E_ACCESSDENIED" \
: ((hr) == E_FAIL) ? "E_FAIL" \
: ((hr) == E_HANDLE) ? "E_HANDLE" \
: ((hr) == E_INVALIDARG) ? "E_INVALIDARG" \
: ((hr) == E_NOINTERFACE) ? "E_NOINTERFACE" \
: ((hr) == E_NOTIMPL) ? "E_NOTIMPL" \
: ((hr) == E_OUTOFMEMORY) ? "E_OUTOFMEMORY" \
: ((hr) == E_POINTER) ? "E_POINTER" \
: ((hr) == E_UNEXPECTED) ? "E_UNEXPECTED" \
: "Unknown HRESULT")

#endif // _JAX_DECOMP_CHECKS_H_
116 changes: 47 additions & 69 deletions include/fft.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#ifndef _JAX_DECOMP_FFT_H_
#define _JAX_DECOMP_FFT_H_

#include "logger.hpp"
#include "checks.h"
#include "logger.hpp"
#include <array>
#include <cmath> // has to be included before cuda/std/complex
#include <cstddef>
Expand All @@ -20,12 +20,8 @@ static cufftType get_cufft_type_c2r(double) { return CUFFT_Z2D; }
// Overload selected for float: returns the cuFFT type constant CUFFT_C2R.
// The unnamed parameter is never read; it only picks the overload.
static cufftType get_cufft_type_c2r(float) { return CUFFT_C2R; }
// Overloads returning the cuFFT type constant for the c2c case: CUFFT_Z2Z when
// called with a double, CUFFT_C2C when called with a float. The unnamed
// parameter is never read; it only picks the overload.
static cufftType get_cufft_type_c2c(double) { return CUFFT_Z2Z; }
static cufftType get_cufft_type_c2c(float) { return CUFFT_C2C; }
static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex<float>) {
return CUDECOMP_FLOAT_COMPLEX;
}
static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex<double>) {
return CUDECOMP_DOUBLE_COMPLEX;
}
// Overloads returning the cuDecomp data type enumerator for a complex element
// type: CUDECOMP_FLOAT_COMPLEX for cuda::std::complex<float> and
// CUDECOMP_DOUBLE_COMPLEX for cuda::std::complex<double>. The unnamed
// parameter is never read; it only picks the overload.
static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex<float>) { return CUDECOMP_FLOAT_COMPLEX; }
static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex<double>) { return CUDECOMP_DOUBLE_COMPLEX; }
namespace jaxdecomp {

// Decomposition kinds stored in fftDescriptor::decomposition and folded into
// its equality comparison and hash.
// NOTE(review): the exact layout each enumerator denotes is not shown in this
// excerpt — confirm against the implementation before relying on the names.
enum Decomposition { slab_XY, slab_YZ, pencil, unknown };
Expand Down Expand Up @@ -61,13 +57,12 @@ class fftDescriptor {

// To make it trivially copyable
fftDescriptor() = default;
fftDescriptor(const fftDescriptor &other) = default;
fftDescriptor &operator=(const fftDescriptor &other) = default;
fftDescriptor(const fftDescriptor& other) = default;
fftDescriptor& operator=(const fftDescriptor& other) = default;
// Create a compare operator to be used in the unordered_map (a hash is also
// created in the bottom of the file)
bool operator==(const fftDescriptor &other) const {
if (double_precision != other.double_precision ||
gdims[0] != other.gdims[0] || gdims[1] != other.gdims[1] ||
bool operator==(const fftDescriptor& other) const {
if (double_precision != other.double_precision || gdims[0] != other.gdims[0] || gdims[1] != other.gdims[1] ||
gdims[2] != other.gdims[2] || decomposition != other.decomposition) {
return false;
}
Expand All @@ -78,8 +73,8 @@ class fftDescriptor {
// Initialize the descriptor from the grid configuration
// this is used for subsequent ffts to find the Executor that was already
// defined
fftDescriptor(cudecompGridDescConfig_t &config, const bool &double_precision,
const bool &iForward, const bool &iAdjoint)
fftDescriptor(cudecompGridDescConfig_t& config, const bool& double_precision, const bool& iForward,
const bool& iAdjoint)
: double_precision(double_precision), config(config) {
gdims[0] = config.gdims[0];
gdims[1] = config.gdims[1];
Expand All @@ -101,14 +96,12 @@ template <typename real_t> class FourierExecutor {
FourierExecutor() : m_Tracer("JAXDECOMP") {}
~FourierExecutor();

HRESULT Initialize(cudecompHandle_t handle, cudecompGridDescConfig_t config,
size_t &work_size, fftDescriptor &fft_descriptor);
HRESULT Initialize(cudecompHandle_t handle, cudecompGridDescConfig_t config, size_t& work_size,
fftDescriptor& fft_descriptor);

HRESULT forward(cudecompHandle_t handle, fftDescriptor desc,
cudaStream_t stream, void **buffers);
HRESULT forward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, void** buffers);

HRESULT backward(cudecompHandle_t handle, fftDescriptor desc,
cudaStream_t stream, void **buffers);
HRESULT backward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, void** buffers);

private:
AsyncLogger m_Tracer;
Expand Down Expand Up @@ -137,69 +130,54 @@ template <typename real_t> class FourierExecutor {
int64_t m_WorkSize;

// Internal functions
HRESULT InitializePencils(cudecompGridDescConfig_t &iGridConfig,
cudecompPencilInfo_t &x_pencil_info,
cudecompPencilInfo_t &y_pencil_info,
cudecompPencilInfo_t &z_pencil_info,
int64_t &work_size, const bool &is_contiguous);

HRESULT InitializeSlabXY(cudecompGridDescConfig_t &iGridConfig,
cudecompPencilInfo_t &x_pencil_info,
cudecompPencilInfo_t &y_pencil_info,
cudecompPencilInfo_t &z_pencil_info,
int64_t &work_size, const bool &is_contiguous);

HRESULT InitializeSlabYZ(cudecompGridDescConfig_t &iGridConfig,
cudecompPencilInfo_t &x_pencil_info,
cudecompPencilInfo_t &y_pencil_info,
cudecompPencilInfo_t &z_pencil_info,
int64_t &work_size, const bool &is_contiguous);

HRESULT forwardXY(cudecompHandle_t handle, fftDescriptor desc,
cudaStream_t stream, complex_t *input, complex_t *output,
complex_t *work_buffer);

HRESULT backwardXY(cudecompHandle_t handle, fftDescriptor desc,
cudaStream_t stream, complex_t *input, complex_t *output,
complex_t *work_buffer);

HRESULT forwardYZ(cudecompHandle_t handle, fftDescriptor desc,
cudaStream_t stream, complex_t *input, complex_t *output,
complex_t *work_buffer);

HRESULT backwardYZ(cudecompHandle_t handle, fftDescriptor desc,
cudaStream_t stream, complex_t *input, complex_t *output,
complex_t *work_buffer);

HRESULT forwardPencil(cudecompHandle_t handle, fftDescriptor desc,
cudaStream_t stream, complex_t *input,
complex_t *output, complex_t *work_buffer);

HRESULT backwardPencil(cudecompHandle_t handle, fftDescriptor desc,
cudaStream_t stream, complex_t *input,
complex_t *output, complex_t *work_buffer);
HRESULT InitializePencils(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info,
int64_t& work_size, const bool& is_contiguous);

HRESULT InitializeSlabXY(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, int64_t& work_size,
const bool& is_contiguous);

HRESULT InitializeSlabYZ(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, int64_t& work_size,
const bool& is_contiguous);

HRESULT forwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);

HRESULT backwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);

HRESULT forwardYZ(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);

HRESULT backwardYZ(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);

HRESULT forwardPencil(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);

HRESULT backwardPencil(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);

HRESULT clearPlans();

// DEBUG ONLY ... I WARN YOU
void inspect_device_array(complex_t *data, int size, cudaStream_t stream);
void inspect_device_array(complex_t* data, int size, cudaStream_t stream);
};

} // namespace jaxdecomp

namespace std {
template <> struct hash<jaxdecomp::fftDescriptor> {
std::size_t operator()(const jaxdecomp::fftDescriptor &descriptor) const {
std::size_t operator()(const jaxdecomp::fftDescriptor& descriptor) const {
// Only hash The double precision and the gdims and pdims
// If adjoint is changed then the plan should be the same
// adjoint is to be used to execute the backward plan
static const size_t xy_hash =
std::hash<int>()(jaxdecomp::Decomposition::slab_XY);
static const size_t xy_hash = std::hash<int>()(jaxdecomp::Decomposition::slab_XY);

size_t hash = std::hash<double>()(descriptor.double_precision) ^
std::hash<int>()(descriptor.gdims[0]) ^
std::hash<int>()(descriptor.gdims[1]) ^
std::hash<int>()(descriptor.gdims[2]) ^
size_t hash = std::hash<double>()(descriptor.double_precision) ^ std::hash<int>()(descriptor.gdims[0]) ^
std::hash<int>()(descriptor.gdims[1]) ^ std::hash<int>()(descriptor.gdims[2]) ^
std::hash<int>()(descriptor.decomposition);
return hash;
}
Expand Down
Loading

0 comments on commit e246d43

Please sign in to comment.