Skip to content

Commit

Permalink
Adding CPU options struct
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Mar 13, 2024
1 parent 7d67d84 commit 9c792a7
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 0 deletions.
44 changes: 44 additions & 0 deletions lib/jax_finufft_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

#include "jax_finufft_cpu.h"

#include "jax_finufft_cpu_opts.h"
#include "pybind11_kernel_helpers.h"

using namespace jax_finufft;
namespace py = pybind11;

namespace {

Expand Down Expand Up @@ -92,6 +94,48 @@ PYBIND11_MODULE(jax_finufft_cpu, m) {
m.def("registrations", &Registrations);
m.def("build_descriptorf", &build_descriptor<float>);
m.def("build_descriptor", &build_descriptor<double>);

py::class_<jax_finufft_opts> opts(m, "Opts");

py::enum_<jax_finufft_opts::DebugLevel>(opts, "DebugLevel")
.value("Silent", jax_finufft_opts::DebugLevel::Silent)
.value("Vebose", jax_finufft_opts::DebugLevel::Verbose)
.value("Noisy", jax_finufft_opts::DebugLevel::Noisy)
.export_values();

py::enum_<jax_finufft_opts::FftwFlags>(opts, "FftwFlags")
.value("Estimate", jax_finufft_opts::FftwFlags::Estimate)
.value("Measure", jax_finufft_opts::FftwFlags::Measure)
.value("Patient", jax_finufft_opts::FftwFlags::Patient)
.value("Exhaustive", jax_finufft_opts::FftwFlags::Exhaustive)
.value("WisdomOnly", jax_finufft_opts::FftwFlags::WisdomOnly)
.export_values();

py::enum_<jax_finufft_opts::SpreadSort>(opts, "SpreadSort")
.value("No", jax_finufft_opts::SpreadSort::No)
.value("Yes", jax_finufft_opts::SpreadSort::Yes)
.value("Heuristic", jax_finufft_opts::SpreadSort::Heuristic)
.export_values();

py::enum_<jax_finufft_opts::SpreadThread>(opts, "SpreadThread")
.value("Auto", jax_finufft_opts::SpreadThread::Auto)
.value("Seq", jax_finufft_opts::SpreadThread::Seq)
.value("Parallel", jax_finufft_opts::SpreadThread::Parallel)
.export_values();

opts.def(
py::init<bool, bool, jax_finufft_opts::DebugLevel, jax_finufft_opts::DebugLevel, bool, int,
int, jax_finufft_opts::SpreadSort, bool, bool, double,
jax_finufft_opts::SpreadThread, int, int, int>(),
py::arg("modeord") = false, py::arg("chkbnds") = true,
py::arg("debug") = jax_finufft_opts::DebugLevel::Silent,
py::arg("spread_debug") = jax_finufft_opts::DebugLevel::Silent, py::arg("showwarn") = false,
py::arg("nthreads") = 0, py::arg("fftw") = int(FFTW_ESTIMATE),
py::arg("spread_sort") = jax_finufft_opts::SpreadSort::Heuristic,
py::arg("spread_kerevalmeth") = true, py::arg("spread_kerpad") = true,
py::arg("upsampfac") = 0.0, py::arg("spread_thread") = jax_finufft_opts::SpreadThread::Auto,
py::arg("maxbatchsize") = 0, py::arg("spread_nthr_atomic") = -1,
py::arg("spread_max_sp_size") = 0);
}

} // namespace
72 changes: 72 additions & 0 deletions lib/jax_finufft_cpu_opts.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#ifndef _JAX_FINUFFT_OPTS_H_
#define _JAX_FINUFFT_OPTS_H_

#include <fftw3.h>

#include "finufft.h"
#include "jax_finufft_cpu.h"

namespace jax_finufft {

struct jax_finufft_opts {
enum DebugLevel { Silent = 0, Verbose, Noisy };
enum FftwFlags {
Estimate = FFTW_ESTIMATE,
Measure = FFTW_MEASURE,
Patient = FFTW_PATIENT,
Exhaustive = FFTW_EXHAUSTIVE,
WisdomOnly = FFTW_WISDOM_ONLY
};
enum SpreadSort { No = 0, Yes, Heuristic };
enum SpreadThread { Auto = 0, Seq, Parallel };

finufft_opts opts;

jax_finufft_opts(
bool modeord, // (type 1,2 only): 0 CMCL-style increasing mode order
// 1 FFT-style mode order
bool chkbnds, // 0 don't check NU pts in [-3pi,3pi), 1 do (<few % slower)

// diagnostic opts...
DebugLevel debug,
DebugLevel spread_debug, // spreader: 0 silent, 1 some timing/debug, or 2 tonnes
bool showwarn, // 0 don't print warnings to stderr, 1 do

// algorithm performance opts...
int nthreads, // number of threads to use, or 0 uses all available
int fftw, // plan flags to FFTW (FFTW_ESTIMATE=64, FFTW_MEASURE=0,...)
SpreadSort spread_sort, // spreader: 0 don't sort, 1 do, or 2 heuristic choice
bool spread_kerevalmeth, // spreader: 0 exp(sqrt()), 1 Horner piecewise poly (faster)
bool spread_kerpad, // (exp(sqrt()) only): 0 don't pad kernel to 4n, 1 do
double upsampfac, // upsampling ratio sigma: 2.0 std, 1.25 small FFT, 0.0 auto
SpreadThread spread_thread, // (vectorized ntr>1 only): 0 auto, 1 seq multithreaded,
// 2 parallel single-thread spread
int maxbatchsize, // (vectorized ntr>1 only): max transform batch, 0 auto
int spread_nthr_atomic, // if >=0, threads above which spreader OMP critical goes atomic
int spread_max_sp_size // if >0, overrides spreader (dir=1) max subproblem size
) {
default_opts<double>(&opts);

opts.modeord = modeord;
opts.chkbnds = chkbnds;

opts.debug = int(debug);
opts.spread_debug = int(spread_debug);
opts.showwarn = int(showwarn);

opts.nthreads = nthreads;
opts.fftw = fftw;
opts.spread_sort = spread_sort;
opts.spread_kerevalmeth = int(spread_kerevalmeth);
opts.spread_kerpad = int(spread_kerpad);
opts.upsampfac = upsampfac;
opts.spread_thread = int(spread_thread);
opts.maxbatchsize = maxbatchsize;
opts.spread_nthr_atomic = spread_nthr_atomic;
opts.spread_max_sp_size = spread_max_sp_size;
}
};

} // namespace jax_finufft

#endif

0 comments on commit 9c792a7

Please sign in to comment.