diff --git a/lib/jax_finufft_cpu.cc b/lib/jax_finufft_cpu.cc
index ddd4563..10cf204 100644
--- a/lib/jax_finufft_cpu.cc
+++ b/lib/jax_finufft_cpu.cc
@@ -4,9 +4,11 @@
 #include "jax_finufft_cpu.h"
+#include "jax_finufft_cpu_opts.h"
 #include "pybind11_kernel_helpers.h"
 
 using namespace jax_finufft;
+namespace py = pybind11;
 
 namespace {
 
@@ -92,6 +94,48 @@ PYBIND11_MODULE(jax_finufft_cpu, m) {
   m.def("registrations", &Registrations);
   m.def("build_descriptorf", &build_descriptor<float>);
   m.def("build_descriptor", &build_descriptor<double>);
+
+  py::class_<jax_finufft_opts> opts(m, "Opts");
+
+  py::enum_<jax_finufft_opts::DebugLevel>(opts, "DebugLevel")
+      .value("Silent", jax_finufft_opts::DebugLevel::Silent)
+      .value("Verbose", jax_finufft_opts::DebugLevel::Verbose)
+      .value("Noisy", jax_finufft_opts::DebugLevel::Noisy)
+      .export_values();
+
+  py::enum_<jax_finufft_opts::FftwFlags>(opts, "FftwFlags")
+      .value("Estimate", jax_finufft_opts::FftwFlags::Estimate)
+      .value("Measure", jax_finufft_opts::FftwFlags::Measure)
+      .value("Patient", jax_finufft_opts::FftwFlags::Patient)
+      .value("Exhaustive", jax_finufft_opts::FftwFlags::Exhaustive)
+      .value("WisdomOnly", jax_finufft_opts::FftwFlags::WisdomOnly)
+      .export_values();
+
+  py::enum_<jax_finufft_opts::SpreadSort>(opts, "SpreadSort")
+      .value("No", jax_finufft_opts::SpreadSort::No)
+      .value("Yes", jax_finufft_opts::SpreadSort::Yes)
+      .value("Heuristic", jax_finufft_opts::SpreadSort::Heuristic)
+      .export_values();
+
+  py::enum_<jax_finufft_opts::SpreadThread>(opts, "SpreadThread")
+      .value("Auto", jax_finufft_opts::SpreadThread::Auto)
+      .value("Seq", jax_finufft_opts::SpreadThread::Seq)
+      .value("Parallel", jax_finufft_opts::SpreadThread::Parallel)
+      .export_values();
+
+  opts.def(
+      py::init<bool, bool, jax_finufft_opts::DebugLevel, jax_finufft_opts::DebugLevel, bool, int,
+               int, jax_finufft_opts::SpreadSort, bool, bool, double,
+               jax_finufft_opts::SpreadThread, int, int, int>(),
+      py::arg("modeord") = false, py::arg("chkbnds") = true,
+      py::arg("debug") = jax_finufft_opts::DebugLevel::Silent,
+      py::arg("spread_debug") = jax_finufft_opts::DebugLevel::Silent, py::arg("showwarn") = false,
+      py::arg("nthreads") = 0, py::arg("fftw") = int(FFTW_ESTIMATE),
+      py::arg("spread_sort") = jax_finufft_opts::SpreadSort::Heuristic,
+      py::arg("spread_kerevalmeth") = true, py::arg("spread_kerpad") = true,
+      py::arg("upsampfac") = 0.0, py::arg("spread_thread") = jax_finufft_opts::SpreadThread::Auto,
+      py::arg("maxbatchsize") = 0, py::arg("spread_nthr_atomic") = -1,
+      py::arg("spread_max_sp_size") = 0);
 }
 
 }  // namespace
diff --git a/lib/jax_finufft_cpu_opts.h b/lib/jax_finufft_cpu_opts.h
new file mode 100644
index 0000000..7078732
--- /dev/null
+++ b/lib/jax_finufft_cpu_opts.h
@@ -0,0 +1,72 @@
+#ifndef _JAX_FINUFFT_OPTS_H_
+#define _JAX_FINUFFT_OPTS_H_
+
+#include <fftw3.h>
+
+#include "finufft.h"
+#include "jax_finufft_cpu.h"
+
+namespace jax_finufft {
+
+// Python-facing wrapper around finufft_opts: converts typed constructor
+// arguments into the plain int fields that finufft expects.
+// NOTE(review): this header was reconstructed from a corrupted patch —
+// confirm the parameter comments against the FINUFFT opts documentation.
+struct jax_finufft_opts {
+  enum DebugLevel { Silent = 0, Verbose, Noisy };
+  enum FftwFlags {
+    Estimate = FFTW_ESTIMATE,
+    Measure = FFTW_MEASURE,
+    Patient = FFTW_PATIENT,
+    Exhaustive = FFTW_EXHAUSTIVE,
+    WisdomOnly = FFTW_WISDOM_ONLY
+  };
+  enum SpreadSort { No = 0, Yes, Heuristic };
+  enum SpreadThread { Auto = 0, Seq, Parallel };
+
+  finufft_opts opts;
+
+  jax_finufft_opts(
+      bool modeord,  // (type 1,2 only): 0 CMCL-style increasing mode order
+                     //                  1 FFT-style mode order
+      bool chkbnds,  // 0 don't check NU pts in [-3pi,3pi), 1 do
+      DebugLevel debug,          // library debug output: silent, verbose, or noisy
+      DebugLevel spread_debug,   // spreader debug output: silent, verbose, or noisy
+      bool showwarn,             // whether to print warnings to stderr
+      int nthreads,              // number of threads to use; 0 uses all available
+      int fftw,                  // FFTW planner flags (e.g. FFTW_ESTIMATE)
+      SpreadSort spread_sort,    // spreader sorting: no, yes, or heuristic choice
+      bool spread_kerevalmeth,   // kernel eval: 0 exp(sqrt()), 1 Horner poly (faster)
+      bool spread_kerpad,        // (exp(sqrt()) only): whether to pad the kernel
+      double upsampfac,          // upsampling ratio sigma; 0.0 chooses automatically
+      SpreadThread spread_thread,  // (vectorized ntr>1 only): 0 auto, 1 seq multithreaded,
+                                   //                          2 parallel single-thread spread
+      int maxbatchsize,        // (vectorized ntr>1 only): max transform batch, 0 auto
+      int spread_nthr_atomic,  // if >=0, threads above which spreader OMP critical goes atomic
+      int spread_max_sp_size   // if >0, overrides spreader (dir=1) max subproblem size
+  ) {
+    default_opts(&opts);
+
+    opts.modeord = modeord;
+    opts.chkbnds = chkbnds;
+
+    opts.debug = int(debug);
+    opts.spread_debug = int(spread_debug);
+    opts.showwarn = int(showwarn);
+
+    opts.nthreads = nthreads;
+    opts.fftw = fftw;
+    opts.spread_sort = spread_sort;
+    opts.spread_kerevalmeth = int(spread_kerevalmeth);
+    opts.spread_kerpad = int(spread_kerpad);
+    opts.upsampfac = upsampfac;
+    opts.spread_thread = int(spread_thread);
+    opts.maxbatchsize = maxbatchsize;
+    opts.spread_nthr_atomic = spread_nthr_atomic;
+    opts.spread_max_sp_size = spread_max_sp_size;
+  }
+};
+
+}  // namespace jax_finufft
+
+#endif