Skip to content

Commit

Permalink
Merge pull request #32 from DifferentiableUniverseInitiative/push-to-…
Browse files Browse the repository at this point in the history
…pypi-the-sequel

Push to pypi the sequel
  • Loading branch information
ASKabalan authored Oct 24, 2024
2 parents f17bae2 + ba352b9 commit 773ca5b
Show file tree
Hide file tree
Showing 35 changed files with 137 additions and 71 deletions.
68 changes: 68 additions & 0 deletions .github/workflows/github-deploy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
name: Build and upload to PyPI

on:
workflow_dispatch:
pull_request:
push:
branches:
- main
# release:
# types:
# - published

jobs:
build_wheels:
name: Build wheels on ${{ matrix.os }}
runs-on: ${{ matrix.os }}
strategy:
matrix:
# macos-13 is an intel runner, macos-14 is apple silicon
os: [ubuntu-latest]

steps:
- uses: actions/checkout@v4

- name: Build wheels
uses: pypa/cibuildwheel@v2.21.3
env:
CIBW_BUILD: "cp310-* cp311-* cp312-*"
CIBW_BUILD_VERBOSITY: 2
- uses: actions/upload-artifact@v4
with:
name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }}
path: ./wheelhouse/*.whl

build_sdist:
name: Build source distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Build sdist
run: pipx run build --sdist

- uses: actions/upload-artifact@v4
with:
name: cibw-sdist
path: dist/*.tar.gz

upload_pypi:
needs: [build_wheels, build_sdist]
runs-on: ubuntu-latest
environment: pypi
permissions:
id-token: write
# if: github.event_name == 'release' && github.event.action == 'published'
# or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this)
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
steps:
- uses: actions/download-artifact@v4
with:
# unpacks all CIBW artifacts into dist/
pattern: cibw-*
path: dist
merge-multiple: true

- uses: pypa/gh-action-pypi-publish@release/v1
#with:
# repository-url: https://test.pypi.org/legacy/
16 changes: 8 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,18 +59,18 @@ if(CMAKE_CUDA_COMPILER AND JD_CUDECOMP_BACKEND)

# Add _jaxdecomp modulei
pybind11_add_module(_jaxdecomp
src/halo.cu
src/jaxdecomp.cc
src/grid_descriptor_mgr.cc
src/fft.cu
src/transpose.cu
src/csrc/halo.cu
src/csrc/jaxdecomp.cc
src/csrc/grid_descriptor_mgr.cc
src/csrc/fft.cu
src/csrc/transpose.cu
)

set_target_properties(_jaxdecomp PROPERTIES CUDA_ARCHITECTURES "${CUDECOMP_CUDA_CC_LIST}")

target_include_directories(_jaxdecomp
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/include
${CMAKE_CURRENT_LIST_DIR}/src/csrc/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cuDecomp/include
${NVHPC_CUDA_INCLUDE_DIR}
${MPI_CXX_INCLUDE_DIRS}
Expand All @@ -88,8 +88,8 @@ if(CMAKE_CUDA_COMPILER AND JD_CUDECOMP_BACKEND)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)
target_compile_definitions(_jaxdecomp PRIVATE JD_CUDECOMP_BACKEND)
else()
pybind11_add_module(_jaxdecomp src/jaxdecomp.cc)
target_include_directories(_jaxdecomp PRIVATE ${CMAKE_CURRENT_LIST_DIR}/include)
pybind11_add_module(_jaxdecomp src/csrc/jaxdecomp.cc)
target_include_directories(_jaxdecomp PRIVATE ${CMAKE_CURRENT_LIST_DIR}/src/csrc/include)
target_compile_definitions(_jaxdecomp PRIVATE JD_JAX_BACKEND)
endif()

Expand Down
25 changes: 17 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
[build-system]
requires = [ "scikit-build-core","pybind11"]
requires = ["scikit-build-core>=0.4.0", "pybind11>=2.9.0"]
build-backend = "scikit_build_core.build"

[project]
name = "jaxdecomp"
version = "0.1.0"
version = "0.2.0"
description = "JAX bindings for the cuDecomp library"
authors = [
{ name = "Wassim Kabalan" },
Expand All @@ -15,23 +14,33 @@ readme = "README.md"
license = { file = "LICENSE" }
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
"Operating System :: OS Independent",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
]
dependencies = [
"jaxtyping>=0.2.33",
"jax>=0.4.30",
]
dependencies = ["jaxtyping"]

[project.optional-dependencies]
test = ["pytest"]
test = ["pytest>=8.0.0" , "jax[cpu]>=0.4.30"]

[tool.scikit-build]
minimum-version = "0.8"
cmake.version = ">=3.25"
build-dir = "build/{wheel_tag}"
wheel.py-api = "py3"
cmake.build-type = "Release"
# Add any additional configurations for scikit-build if necessary
wheel.install-dir = "jaxdecomp/_src"

[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""
CMAKE_EXPORT_COMPILE_COMMANDS = "ON"

#[tool.cibuildwheel]
#test-extras = "test"
#test-command = "pytest {project}/tests"
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
54 changes: 31 additions & 23 deletions include/fft.h → src/csrc/include/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex<float>) { ret
static cudecompDataType_t get_cudecomp_datatype(cuda::std::complex<double>) { return CUDECOMP_DOUBLE_COMPLEX; }
namespace jaxdecomp {

enum Decomposition { slab_XY = 0, slab_YZ = 1, pencil = 2, no_decomp = 3 };
enum Decomposition { slab_XY, slab_YZ, pencil, unknown };

static Decomposition GetDecomposition(const int pdims[2]) {
if (pdims[0] == 1 && pdims[1] > 1) {
Expand All @@ -34,23 +34,24 @@ static Decomposition GetDecomposition(const int pdims[2]) {
} else if (pdims[0] > 1 && pdims[1] > 1) {
return Decomposition::pencil;
}
return Decomposition::no_decomp;
// Return pencils on one devices for testing
return Decomposition::pencil;
// return Decomposition::unknown;
}

// fftDescriptor hash should be triavially computable
// because it contains only bools and integers
class fftDescriptor {
public:
bool adjoint = false;
bool contiguous = true;
bool forward = true; ///< forward or backward pass
// fft_is_forward_pass and forwad are used for the Execution but not for the
// hash This way IFFT and FFT have the same plans when operating with the same
// grid and pdims
int32_t gdims[3]; ///< dimensions of global data grid
// Decomposition type is used in order to allow reusing plans
// from the XY and XZ forward pass for ZY and YZ backward pass respectively
Decomposition decomposition = Decomposition::no_decomp; ///< decomposition type
Decomposition decomposition = Decomposition::unknown; ///< decomposition type
bool double_precision = false;
cudecompGridDescConfig_t config; // Descriptor for the grid

Expand All @@ -62,7 +63,7 @@ class fftDescriptor {
// created in the bottom of the file)
bool operator==(const fftDescriptor& other) const {
if (double_precision != other.double_precision || gdims[0] != other.gdims[0] || gdims[1] != other.gdims[1] ||
gdims[2] != other.gdims[2] || decomposition != other.decomposition || contiguous != other.contiguous) {
gdims[2] != other.gdims[2] || decomposition != other.decomposition) {
return false;
}
return true;
Expand All @@ -73,15 +74,14 @@ class fftDescriptor {
// this is used for subsequent ffts to find the Executor that was already
// defined
fftDescriptor(cudecompGridDescConfig_t& config, const bool& double_precision, const bool& iForward,
const bool& iAdjoint, const bool& iContiguous, const Decomposition& iDecomposition)
: double_precision(double_precision), config(config), forward(iForward), contiguous(iContiguous),
adjoint(iAdjoint), decomposition(iDecomposition) {
const bool& iAdjoint)
: double_precision(double_precision), config(config) {
gdims[0] = config.gdims[0];
gdims[1] = config.gdims[1];
gdims[2] = config.gdims[2];
this->config.transpose_axis_contiguous[0] = iContiguous;
this->config.transpose_axis_contiguous[1] = iContiguous;
this->config.transpose_axis_contiguous[2] = iContiguous;
forward = iForward;
adjoint = iAdjoint;
decomposition = GetDecomposition(config.pdims);
}
};

Expand All @@ -96,7 +96,8 @@ template <typename real_t> class FourierExecutor {
FourierExecutor() : m_Tracer("JAXDECOMP") {}
~FourierExecutor();

HRESULT Initialize(cudecompHandle_t handle, size_t& work_size, fftDescriptor& fft_descriptor);
HRESULT Initialize(cudecompHandle_t handle, cudecompGridDescConfig_t config, size_t& work_size,
fftDescriptor& fft_descriptor);

HRESULT forward(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, void** buffers);

Expand All @@ -113,9 +114,6 @@ template <typename real_t> class FourierExecutor {

cudecompGridDesc_t m_GridConfig;
cudecompGridDescConfig_t m_GridDescConfig;
cudecompPencilInfo_t m_XPencilInfo;
cudecompPencilInfo_t m_YPencilInfo;
cudecompPencilInfo_t m_ZPencilInfo;
// For the sake of expressive code, plans have the name of their corresponding
// goal Instead of reusing pencils plans for slabs, or even ZY to YZ we store
// properly named plans
Expand All @@ -124,20 +122,25 @@ template <typename real_t> class FourierExecutor {
cufftHandle m_Plan_c2c_x;
cufftHandle m_Plan_c2c_y;
cufftHandle m_Plan_c2c_z;
// For Slabs XY FFT (Y) FFT(XZ) but JAX redifines the axes to YZX as X pencil for cudecomp
// so in the end it is FFT(X) FFT(YZ)
// For Slabs XZ FFT (X) FFT(YZ)
cufftHandle m_Plan_c2c_yz;
// For Slabs XY
cufftHandle m_Plan_c2c_xy;
// For Slabs XZ
cufftHandle m_Plan_c2c_yz;
// work size
int64_t m_WorkSize;

// Internal functions
HRESULT InitializePencils(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT InitializePencils(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info,
int64_t& work_size, const bool& is_contiguous);

HRESULT InitializeSlabXY(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT InitializeSlabXY(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, int64_t& work_size,
const bool& is_contiguous);

HRESULT InitializeSlabYZ(int64_t& work_size, fftDescriptor& fft_descriptor);
HRESULT InitializeSlabYZ(cudecompGridDescConfig_t& iGridConfig, cudecompPencilInfo_t& x_pencil_info,
cudecompPencilInfo_t& y_pencil_info, cudecompPencilInfo_t& z_pencil_info, int64_t& work_size,
const bool& is_contiguous);

HRESULT forwardXY(cudecompHandle_t handle, fftDescriptor desc, cudaStream_t stream, complex_t* input,
complex_t* output, complex_t* work_buffer);
Expand All @@ -158,6 +161,9 @@ template <typename real_t> class FourierExecutor {
complex_t* output, complex_t* work_buffer);

HRESULT clearPlans();

// DEBUG ONLY ... I WARN YOU
void inspect_device_array(complex_t* data, int size, cudaStream_t stream);
};

} // namespace jaxdecomp
Expand All @@ -168,9 +174,11 @@ template <> struct hash<jaxdecomp::fftDescriptor> {
// Only hash The double precision and the gdims and pdims
// If adjoint is changed then the plan should be the same
// adjoint is to be used to execute the backward plan
static const size_t xy_hash = std::hash<int>()(jaxdecomp::Decomposition::slab_XY);

size_t hash = std::hash<double>()(descriptor.double_precision) ^ std::hash<int>()(descriptor.gdims[0]) ^
std::hash<int>()(descriptor.gdims[1]) ^ std::hash<int>()(descriptor.gdims[2]) ^
std::hash<int>()(descriptor.decomposition) ^ std::hash<bool>()(descriptor.contiguous);
std::hash<int>()(descriptor.decomposition);
return hash;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class GridDescriptorManager {

AsyncLogger m_Tracer;
bool isInitialized = false;
int isMPIalreadyInitialized = false;

cudecompHandle_t m_Handle;

std::unordered_map<fftDescriptor, std::shared_ptr<FourierExecutor<double>>, std::hash<fftDescriptor>, std::equal_to<>>
Expand Down
17 changes: 5 additions & 12 deletions include/halo.h → src/csrc/include/halo.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,37 +22,30 @@ class haloDescriptor_t {
~haloDescriptor_t() = default;

bool operator==(const haloDescriptor_t& other) const {
return (double_precision == other.double_precision && halo_extents[0] == other.halo_extents[0] &&
halo_extents[1] == other.halo_extents[1] && halo_extents[2] == other.halo_extents[2] &&
halo_periods[0] == other.halo_periods[0] && halo_periods[1] == other.halo_periods[1] &&
halo_periods[2] == other.halo_periods[2] && axis == other.axis &&
config.gdims[0] == other.config.gdims[0] && config.gdims[1] == other.config.gdims[1] &&
config.gdims[2] == other.config.gdims[2] && config.pdims[0] == other.config.pdims[0] &&
config.pdims[1] == other.config.pdims[1]);
return (double_precision == other.double_precision && halo_extents == other.halo_extents &&
halo_periods == other.halo_periods && axis == other.axis && config.gdims[0] == other.config.gdims[0] &&
config.gdims[1] == other.config.gdims[1] && config.gdims[2] == other.config.gdims[2] &&
config.pdims[0] == other.config.pdims[0] && config.pdims[1] == other.config.pdims[1]);
}
};

template <typename real_t> class HaloExchange {
friend class GridDescriptorManager;

public:
HaloExchange() : m_Tracer("JAXDECOMP") {}
HaloExchange() = default;
// Grid descriptors are handled by the GridDescriptorManager
// No memory should be cleaned up here everything is handled by the GridDescriptorManager
~HaloExchange() = default;

HRESULT get_halo_descriptor(cudecompHandle_t handle, size_t& work_size, haloDescriptor_t& halo_desc);
HRESULT halo_exchange(cudecompHandle_t handle, haloDescriptor_t desc, cudaStream_t stream, void** buffers);

private:
AsyncLogger m_Tracer;

cudecompGridDesc_t m_GridConfig;
cudecompGridDescConfig_t m_GridDescConfig;
cudecompPencilInfo_t m_PencilInfo;

int64_t m_WorkSize;
HRESULT cleanUp(cudecompHandle_t handle);
};

} // namespace jaxdecomp
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 773ca5b

Please sign in to comment.