Skip to content

Commit

Permalink
Clean up sphericart-jax-cuda even more
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 12, 2024
1 parent c2be239 commit 46718ef
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 120 deletions.
47 changes: 0 additions & 47 deletions sphericart-jax/include/sphericart/sphericart_jax_cuda.hpp

This file was deleted.

1 change: 0 additions & 1 deletion sphericart-jax/src/sphericart_jax_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ void cpu_sph_with_hessians(void* out_tuple, const void** in) {
}

// Registration of the custom calls with pybind11

pybind11::dict Registrations() {
pybind11::dict dict;
dict["cpu_spherical_f32"] = EncapsulateFunction(cpu_sph<sphericart::SphericalHarmonics, float>);
Expand Down
99 changes: 27 additions & 72 deletions sphericart-jax/src/sphericart_jax_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// devices. It is exposed as a standard pybind11 module defining "capsule"
// objects containing our methods. For simplicity, we export a separate capsule
// for each supported dtype.
// This file is separated from `sphericart_jax_cuda.cu` because pybind11 does
// not accept cuda files.

#include <cstdlib>
#include <map>
Expand All @@ -12,7 +10,11 @@

#include "sphericart_cuda.hpp"
#include "sphericart/pybind11_kernel_helpers.hpp"
#include "sphericart/sphericart_jax_cuda.hpp"

struct SphDescriptor {
std::int64_t n_samples;
std::int64_t lmax;
};

namespace sphericart_jax {
namespace cuda {
Expand Down Expand Up @@ -85,78 +87,31 @@ inline void cuda_sph_with_hessians(
calculator->compute_with_hessians(xyz, n_samples, sph, dsph, ddsph, stream);
}

void cuda_spherical_f32(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph<sphericart::cuda::SphericalHarmonics, float>(stream, in, opaque, opaque_len);
}

void cuda_spherical_f64(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph<sphericart::cuda::SphericalHarmonics, double>(stream, in, opaque, opaque_len);
}

void cuda_dspherical_f32(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph_with_gradients<sphericart::cuda::SphericalHarmonics, float>(
stream, in, opaque, opaque_len
);
}

void cuda_dspherical_f64(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph_with_gradients<sphericart::cuda::SphericalHarmonics, double>(
stream, in, opaque, opaque_len
);
}

void cuda_ddspherical_f32(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph_with_hessians<sphericart::cuda::SphericalHarmonics, float>(
stream, in, opaque, opaque_len
);
}

void cuda_ddspherical_f64(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph_with_hessians<sphericart::cuda::SphericalHarmonics, double>(
stream, in, opaque, opaque_len
);
}

void cuda_solid_f32(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph<sphericart::cuda::SolidHarmonics, float>(stream, in, opaque, opaque_len);
}

void cuda_solid_f64(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph<sphericart::cuda::SolidHarmonics, double>(stream, in, opaque, opaque_len);
}

void cuda_dsolid_f32(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph_with_gradients<sphericart::cuda::SolidHarmonics, float>(stream, in, opaque, opaque_len);
}

void cuda_dsolid_f64(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph_with_gradients<sphericart::cuda::SolidHarmonics, double>(stream, in, opaque, opaque_len);
}

void cuda_ddsolid_f32(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph_with_hessians<sphericart::cuda::SolidHarmonics, float>(stream, in, opaque, opaque_len);
}

void cuda_ddsolid_f64(void* stream, void** in, const char* opaque, std::size_t opaque_len) {
cuda_sph_with_hessians<sphericart::cuda::SolidHarmonics, double>(stream, in, opaque, opaque_len);
}

// Registration of the custom calls with pybind11

pybind11::dict Registrations() {
pybind11::dict dict;
dict["cuda_spherical_f32"] = EncapsulateFunction(cuda_spherical_f32);
dict["cuda_spherical_f64"] = EncapsulateFunction(cuda_spherical_f64);
dict["cuda_dspherical_f32"] = EncapsulateFunction(cuda_dspherical_f32);
dict["cuda_dspherical_f64"] = EncapsulateFunction(cuda_dspherical_f64);
dict["cuda_ddspherical_f32"] = EncapsulateFunction(cuda_ddspherical_f32);
dict["cuda_ddspherical_f64"] = EncapsulateFunction(cuda_ddspherical_f64);
dict["cuda_solid_f32"] = EncapsulateFunction(cuda_solid_f32);
dict["cuda_solid_f64"] = EncapsulateFunction(cuda_solid_f64);
dict["cuda_dsolid_f32"] = EncapsulateFunction(cuda_dsolid_f32);
dict["cuda_dsolid_f64"] = EncapsulateFunction(cuda_dsolid_f64);
dict["cuda_ddsolid_f32"] = EncapsulateFunction(cuda_ddsolid_f32);
dict["cuda_ddsolid_f64"] = EncapsulateFunction(cuda_ddsolid_f64);
dict["cuda_spherical_f32"] =
EncapsulateFunction(cuda_sph<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_spherical_f64"] =
EncapsulateFunction(cuda_sph<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_dspherical_f32"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_dspherical_f64"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_ddspherical_f32"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SphericalHarmonics, float>);
dict["cuda_ddspherical_f64"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SphericalHarmonics, double>);
dict["cuda_solid_f32"] = EncapsulateFunction(cuda_sph<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_solid_f64"] = EncapsulateFunction(cuda_sph<sphericart::cuda::SolidHarmonics, double>);
dict["cuda_dsolid_f32"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_dsolid_f64"] =
EncapsulateFunction(cuda_sph_with_gradients<sphericart::cuda::SolidHarmonics, double>);
dict["cuda_ddsolid_f32"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SolidHarmonics, float>);
dict["cuda_ddsolid_f64"] =
EncapsulateFunction(cuda_sph_with_hessians<sphericart::cuda::SolidHarmonics, double>);
return dict;
}

Expand Down

0 comments on commit 46718ef

Please sign in to comment.