diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a62923..1be8242 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # Changelog +## 0.10.0 2026-01-10 + +### Added + +* Rust + * Add top-level `interpn`, `interpn_alloc` and `interpn_serial` methods along with supporting enums for selecting methods + * Add `par` feature, enabled by default, that enables parallelism with rayon in `interpn` function + * Add lazy static for getting number of physical cores to advise thread chunk sizes + +### Changed + +* Rust + * Update deps +* Python + * !Refactor `interpn` function + * Combine `check_bounds: bool` and `bounds_check_atol: float` to `check_bounds_with_atol: float | None` + * Replace `assume_regular` input with optional `grid_kind` to allow assuming either regular or rectilinear, or making no assumption + * Add `max_threads: int | None` input to allow manually limiting parallelism + * Use rust top-level `interpn` function as backend for method selection, bounds checks, and parallelism + ## 0.9.1 2025-12-31 ### Changed diff --git a/Cargo.lock b/Cargo.lock index 461a64f..c1d4854 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "anes" version = "0.1.6" @@ -111,10 +120,11 @@ checksum = "b94f61472cee1439c0b966b47e3aca9ae07e45d070759512cd390ea2bebc6675" [[package]] name = "criterion" -version = "0.7.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" +checksum = "4d883447757bb0ee46f233e9dc22eb84d93a9508c9b868687b274fc431d886bf" dependencies = [ + "alloca", "anes", "cast", "ciborium", @@ -123,6 +133,7 @@ dependencies = [ "itertools 0.13.0", "num-traits", "oorandom", + 
"page_size", "plotters", "rayon", "regex", @@ -134,9 +145,9 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.6.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" +checksum = "ed943f81ea2faa8dcecbbfa50164acf95d555afec96a27871663b300e387b2e4" dependencies = [ "cast", "itertools 0.13.0", @@ -213,6 +224,12 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" + [[package]] name = "indoc" version = "2.0.5" @@ -221,16 +238,18 @@ checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] name = "interpn" -version = "0.9.1" +version = "0.10.0" dependencies = [ "criterion", "crunchy", "itertools 0.14.0", "ndarray", "num-traits", + "num_cpus", "numpy", "pyo3", "rand", + "rayon", ] [[package]] @@ -353,6 +372,16 @@ dependencies = [ "libm", ] +[[package]] +name = "num_cpus" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "numpy" version = "0.27.1" @@ -381,6 +410,16 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "plotters" version = "0.3.7" @@ -824,6 +863,22 
@@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.11" @@ -833,6 +888,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-link" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index 5ddb9a4..847926c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "interpn" -version = "0.9.1" +version = "0.10.0" edition = "2024" rust-version = "1.87" # 2025-05-15 authors = ["James Logan "] @@ -17,8 +17,12 @@ crate-type = ["cdylib", "rlib"] [dependencies] # Rust lib deps -num-traits = { version = "0.2.19", default-features = false, features = ["libm"] } -crunchy = { version = "0.2.4", default-features = false } +num-traits = { version = "^0.2.19", default-features = false, features = ["libm"] } +crunchy = { version = "^0.2.4", default-features = false } + +# Parallelism +rayon = { version = "^1.11.0", optional = true } +num_cpus = { version = "^1.17.0", optional = true } # Python bindings pyo3 = { version = "0.27.2", features = ["extension-module", "abi3-py310", "generate-import-lib"], optional = true } @@ -27,17 +31,19 @@ numpy = { version = "0.27.1", optional = true } # Test-only utils itertools = { version = "0.14.0", optional = true } + [dev-dependencies] rand = "0.9.2" -criterion = "0.7.0" 
+criterion = "0.8.1" ndarray = "0.17.1" [features] -default = ["std", "crunchy/limit_64"] +default = ["std", "crunchy/limit_64", "par"] python = ["numpy", "pyo3", "std"] std = ["itertools"] deep-unroll = ["crunchy/limit_256"] fma = [] +par = ["std", "rayon", "num_cpus"] [profile.release] opt-level = 3 diff --git a/benches/bench_cpu.py b/benches/bench_cpu.py index ab06eab..b5fcdc3 100644 --- a/benches/bench_cpu.py +++ b/benches/bench_cpu.py @@ -18,6 +18,7 @@ MulticubicRectilinear, NearestRegular, NearestRectilinear, + interpn as interpn_fn, ) # Toggle SciPy/NumPy baselines via environment for PGO workloads. @@ -50,6 +51,11 @@ def average_call_time( DASH_STYLES = ["solid", "dash", "dot", "dashdot", "longdash", "longdashdot"] +THREAD_SPEEDUP_COLORS = { + "linear": "#1f77b4", + "cubic": "#ff7f0e", + "nearest": "#2ca02c", +} def _normalized_line_style(index: int) -> str: @@ -466,6 +472,196 @@ def _plot_speedup_vs_dims( fig.show() +def _thread_counts() -> list[int]: + max_threads = os.cpu_count() or 1 + max_threads = max(int(max_threads / 2), 1) # Real threads, not hyperthreads + counts = [] + threads = 1 + while threads < max_threads: + counts.append(threads) + threads *= 2 + counts.append(max_threads) + return sorted(set(counts)) + + +def _plot_speedup_vs_threads( + *, + thread_counts: list[int], + speedups: dict[str, dict[str, list[float]]], + nobs: int, + output_path: Path, +) -> None: + fig = make_subplots(rows=1, cols=1) + dash_styles = [ + _normalized_line_style(i) + for i in range(len(["linear", "cubic", "nearest"]) * 2) + ] + all_values = [] + thread_arr = np.array(thread_counts) + series: list[tuple[str, np.ndarray]] = [] + for grid_kind in ["regular", "rectilinear"]: + for method in ["linear", "cubic", "nearest"]: + values = speedups[grid_kind].get(method) + if not values: + continue + values_arr = np.array(values, dtype=float) + all_values.append(values_arr) + series.append((f"{method.title()} {grid_kind}", values_arr)) + for _, values_arr in series: + ones = 
np.ones_like(values_arr) + fill_between( + fig, + x=thread_arr, + upper=np.maximum(values_arr, ones), + lower=np.minimum(values_arr, ones), + row=1, + col=1, + fillcolor="rgba(139, 196, 59, 0.25)", + ) + for idx, (label, values_arr) in enumerate(series): + fig.add_trace( + go.Scatter( + x=thread_arr, + y=values_arr, + mode="lines+markers", + name=label, + line=dict( + color="black", + width=2, + dash=dash_styles[idx], + ), + marker=dict(size=7, color="black"), + showlegend=False, + ), + row=1, + col=1, + ) + fig.add_hline( + y=1.0, + line=dict(color="black", dash="dot", width=1), + row=1, + col=1, + ) + y_min = 1.0 + if all_values: + y_min = min(1.0, min(values.min() for values in all_values)) + y_max = max(thread_counts) if thread_counts else 1 + fig.update_xaxes( + title_text="Threads", + row=1, + col=1, + range=[min(thread_counts), max(thread_counts)] if thread_counts else None, + showline=True, + linecolor="black", + linewidth=1, + mirror=True, + ticks="outside", + tickcolor="black", + showgrid=False, + zeroline=False, + ) + fig.update_yaxes( + title_text="Speedup vs. 
1 Thread", + row=1, + col=1, + range=[y_min, y_max], + showline=True, + linecolor="black", + linewidth=1, + mirror=True, + ticks="outside", + tickcolor="black", + showgrid=False, + zeroline=False, + ) + + fig.update_layout( + title=dict( + text=f"InterpN Thread Speedup ({nobs} Observation Points)", + y=0.98, + yanchor="top", + ), + height=430, + margin=dict(t=70, l=60, r=40, b=80), + showlegend=False, + plot_bgcolor="rgba(0,0,0,0)", + paper_bgcolor="rgba(0,0,0,0)", + font=dict(color="black"), + ) + fig.write_image(str(output_path)) + fig.write_html( + str(output_path.with_suffix(".html")), + include_plotlyjs="cdn", + full_html=False, + ) + fig.show() + + +def bench_thread_speedup_vs_threads(): + nobs = 10_000_000 + ndims = 2 + ngrid = int(10e6**0.5) + rng = np.random.default_rng(17) + + thread_counts = _thread_counts() + speedups: dict[str, dict[str, list[float]]] = { + "regular": {}, + "rectilinear": {}, + } + + def make_grids(kind: str) -> list[NDArray]: + base = np.linspace(-1.0, 1.0, ngrid) + if kind == "regular": + return [base for _ in range(ndims)] + warped = base**3 + return [warped for _ in range(ndims)] + + for grid_kind in ["regular", "rectilinear"]: + grids = make_grids(grid_kind) + mesh = np.meshgrid(*grids, indexing="ij") + vals = (mesh[0] + 2.0 * mesh[1]).astype(np.float64) + vals_flat = vals.ravel() + obs = [] + for grid in grids: + lo = grid[0] + hi = grid[-1] + span = hi - lo + obs.append( + rng.uniform(lo + 0.05 * span, hi - 0.05 * span, size=nobs).astype( + np.float64 + ) + ) + out = np.zeros_like(obs[0]) + + for method in ["linear", "cubic", "nearest"]: + timings = [] + for threads in thread_counts: + timed = average_call_time( + lambda points, threads=threads: interpn_fn( + obs=points, + grids=grids, + vals=vals_flat, + method=method, + out=out, + linearize_extrapolation=False, + grid_kind=grid_kind, + max_threads=threads, + ), + obs, + ) + timings.append(timed) + baseline = timings[0] + speedups[grid_kind][method] = [baseline / t if t else 
0.0 for t in timings] + + output_path = Path(__file__).parent / f"../docs/speedup_vs_threads_{nobs}_obs.svg" + _plot_speedup_vs_threads( + thread_counts=thread_counts, + speedups=speedups, + nobs=nobs, + output_path=output_path, + ) + + def bench_4_dims_1_obs(): nbench = 30 # Bench iterations preallocate = False # Whether to preallocate output array for InterpN @@ -650,7 +846,7 @@ def bench_4_dims_1_obs(): def bench_3_dims_n_obs_unordered(): - for preallocate in [False, True]: + for preallocate in [True, False]: ndims = 3 # Number of grid dimensions ngrid = 20 # Size of grid on each dimension @@ -790,7 +986,7 @@ def bench_3_dims_n_obs_unordered(): def bench_4_dims_n_obs_unordered(): - for preallocate in [False, True]: + for preallocate in [True, False]: ndims = 4 # Number of grid dimensions ngrid = 20 # Size of grid on each dimension @@ -1077,6 +1273,7 @@ def bench_throughput_vs_dims(): def main(): bench_throughput_vs_dims() + bench_thread_speedup_vs_threads() bench_4_dims_1_obs() bench_4_dims_n_obs_unordered() bench_3_dims_n_obs_unordered() diff --git a/docs/speedup_vs_threads_10000000_obs.html b/docs/speedup_vs_threads_10000000_obs.html new file mode 100644 index 0000000..1c55871 --- /dev/null +++ b/docs/speedup_vs_threads_10000000_obs.html @@ -0,0 +1,2 @@ +
+
\ No newline at end of file diff --git a/scripts/distr_pgo_profile.sh b/scripts/distr_pgo_profile.sh index 46323e3..68755fa 100644 --- a/scripts/distr_pgo_profile.sh +++ b/scripts/distr_pgo_profile.sh @@ -1,23 +1,26 @@ #!/bin/bash -# Get llvm-profdata from `apt install llvm-20` -# Must match rust's llvm version or it will crash +# Get llvm-profdata from `apt install llvm-21`. +# Must match rust's llvm version or it will crash. -# Build instrumented wheel +# Build instrumented wheel. cargo clean uv cache clean uv pip install maturin rm -rf dist/ UV_NO_BUILD_CACHE=1 uv run --no-sync maturin build --compatibility pypi --out dist --verbose -- "-Cprofile-generate=${PWD}/scripts/pgo-profiles/pgo.profraw" -# Install instrumented wheel +# Install instrumented wheel. uv pip install $(find dist/ -name '*.whl')[pydantic] --group test --group bench --reinstall -# Clear existing profiles +# Clear existing profiles. rm -rf ./scripts/pgo-profiles; mkdir ./scripts/pgo-profiles; mkdir ./scripts/pgo-profiles/pgo.profraw -# Run reference workload to generate profile -uv run --no-sync ./scripts/profile_workload.py +# Run reference workload to generate profile. +# Name each profile file per-binary (%m), per-process (%p) +# so that they don't overwrite each other's results. +# Unfortunately, per-thread (%t) is not always available. +LLVM_PROFILE_FILE="${PWD}/scripts/pgo-profiles/pgo.profraw/%m-%p.profraw" uv run --no-sync ./scripts/profile_workload.py -# Merge profiles +# Merge profiles. 
/usr/lib/llvm-21/bin/llvm-profdata merge -o scripts/pgo-profiles/pgo.profdata $(find scripts/pgo-profiles/pgo.profraw -name '*.profraw') diff --git a/scripts/profile_workload.py b/scripts/profile_workload.py index 87debc1..267789a 100644 --- a/scripts/profile_workload.py +++ b/scripts/profile_workload.py @@ -3,19 +3,16 @@ from __future__ import annotations +import subprocess +import sys +from pathlib import Path + import numpy as np -from interpn import ( - MulticubicRectilinear, - MulticubicRegular, - MultilinearRectilinear, - MultilinearRegular, - NearestRectilinear, - NearestRegular, -) +from interpn import interpn as interpn_fn _TARGET_COUNT = int(1e4) -_OBSERVATION_COUNTS = (1, 3, 571, 2017) +_OBSERVATION_COUNTS = (1, 3, 571, 2017, int(1e4)) _MAX_DIMS = 4 _GRID_SIZE = 30 @@ -34,12 +31,35 @@ def _observation_points( return points -def _evaluate(interpolator, points: list[np.ndarray]) -> None: - # Without preallocated output - interpolator.eval(points) - # With preallocated output +def _evaluate( + *, + grids: list[np.ndarray], + vals: np.ndarray, + points: list[np.ndarray], + method: str, + grid_kind: str, + max_threads: int | None, +) -> None: + interpn_fn( + obs=points, + grids=grids, + vals=vals, + method=method, + grid_kind=grid_kind, + linearize_extrapolation=True, + max_threads=max_threads, + ) out = np.empty_like(points[0]) - interpolator.eval(points, out) + interpn_fn( + obs=points, + grids=grids, + vals=vals, + method=method, + grid_kind=grid_kind, + linearize_extrapolation=True, + out=out, + max_threads=max_threads, + ) def main() -> None: @@ -55,47 +75,40 @@ def main() -> None: ] mesh = np.meshgrid(*grids, indexing="ij") zgrid = rng.uniform(-1.0, 1.0, mesh[0].size).astype(dtype) - dims = [grid.size for grid in grids] - starts = np.array([grid[0] for grid in grids], dtype=dtype) - steps = np.array([grid[1] - grid[0] for grid in grids], dtype=dtype) - - linear_regular = MultilinearRegular.new(dims, starts, steps, zgrid) - linear_rect = 
MultilinearRectilinear.new(grids_rect, zgrid) - cubic_regular = MulticubicRegular.new( - dims, - starts, - steps, - zgrid, - linearize_extrapolation=True, - ) - cubic_rect = MulticubicRectilinear.new( - grids_rect, - zgrid, - linearize_extrapolation=True, + cases = ( + ("linear", "regular", grids), + ("linear", "rectilinear", grids_rect), + ("cubic", "regular", grids), + ("cubic", "rectilinear", grids_rect), + ("nearest", "regular", grids), + ("nearest", "rectilinear", grids_rect), ) - nearest_regular = NearestRegular.new(dims, starts, steps, zgrid) - nearest_rect = NearestRectilinear.new(grids_rect, zgrid) for nobs in _OBSERVATION_COUNTS: nreps = max(int(_TARGET_COUNT / nobs), 1) - for interpolator in ( - linear_regular, - linear_rect, - cubic_regular, - cubic_rect, - nearest_regular, - nearest_rect, - ): - for _ in range(nreps): - points = _observation_points(rng, ndims, nobs, dtype) - _evaluate(interpolator, points) - - print( - f"Completed {type(interpolator).__name__} " - f"dtype={np.dtype(dtype).name} ndims={ndims} nobs={nobs}" - ) + for max_threads in (None, 1): + for method, grid_kind, grids_in in cases: + for _ in range(nreps): + points = _observation_points(rng, ndims, nobs, dtype) + _evaluate( + grids=grids_in, + vals=zgrid, + points=points, + method=method, + grid_kind=grid_kind, + max_threads=max_threads, + ) + + mode = "parallel" if max_threads is None else "serial" + print( + f"Completed interpn method={method} grid={grid_kind} " + f"dtype={np.dtype(dtype).name} ndims={ndims} nobs={nobs} " + f"mode={mode}" + ) if __name__ == "__main__": main() + script = Path(__file__).with_name("profile_workload_ser.py") + subprocess.run([sys.executable, str(script)], check=True) diff --git a/scripts/profile_workload_ser.py b/scripts/profile_workload_ser.py new file mode 100644 index 0000000..87debc1 --- /dev/null +++ b/scripts/profile_workload_ser.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +"""Lightweight workload used to gather PGO profiles for interpn.""" + 
+from __future__ import annotations + +import numpy as np + +from interpn import ( + MulticubicRectilinear, + MulticubicRegular, + MultilinearRectilinear, + MultilinearRegular, + NearestRectilinear, + NearestRegular, +) + +_TARGET_COUNT = int(1e4) +_OBSERVATION_COUNTS = (1, 3, 571, 2017) +_MAX_DIMS = 4 +_GRID_SIZE = 30 + + +def _observation_points( + rng: np.random.Generator, ndims: int, nobs: int, dtype: np.dtype +) -> tuple[list[np.ndarray], list[np.ndarray]]: + """Generate observation points inside and outside the grid domain. + The fraction of points outside the domain here will set the relative weight of + extrapolation branches. + """ + m = max(int(float(nobs) ** (1.0 / ndims) + 2.0), 2) + axes = [rng.uniform(-1.05, 1.05, m).astype(dtype) for _ in range(ndims)] + mesh = np.meshgrid(*axes, indexing="ij") + points = [axis.flatten()[:nobs].copy() for axis in mesh] + return points + + +def _evaluate(interpolator, points: list[np.ndarray]) -> None: + # Without preallocated output + interpolator.eval(points) + # With preallocated output + out = np.empty_like(points[0]) + interpolator.eval(points, out) + + +def main() -> None: + rng = np.random.default_rng(2394587) + + for dtype in (np.float64, np.float32): + for ndims in range(1, _MAX_DIMS + 1): + ngrid = _GRID_SIZE if ndims < 5 else 6 + grids = [np.linspace(-1.0, 1.0, ngrid, dtype=dtype) for _ in range(ndims)] + grids_rect = [ + np.array(sorted(np.random.uniform(-1.0, 1.0, ngrid).astype(dtype))) + for _ in range(ndims) + ] + mesh = np.meshgrid(*grids, indexing="ij") + zgrid = rng.uniform(-1.0, 1.0, mesh[0].size).astype(dtype) + dims = [grid.size for grid in grids] + starts = np.array([grid[0] for grid in grids], dtype=dtype) + steps = np.array([grid[1] - grid[0] for grid in grids], dtype=dtype) + + linear_regular = MultilinearRegular.new(dims, starts, steps, zgrid) + linear_rect = MultilinearRectilinear.new(grids_rect, zgrid) + cubic_regular = MulticubicRegular.new( + dims, + starts, + steps, + zgrid, + 
linearize_extrapolation=True, + ) + cubic_rect = MulticubicRectilinear.new( + grids_rect, + zgrid, + linearize_extrapolation=True, + ) + nearest_regular = NearestRegular.new(dims, starts, steps, zgrid) + nearest_rect = NearestRectilinear.new(grids_rect, zgrid) + + for nobs in _OBSERVATION_COUNTS: + nreps = max(int(_TARGET_COUNT / nobs), 1) + + for interpolator in ( + linear_regular, + linear_rect, + cubic_regular, + cubic_rect, + nearest_regular, + nearest_rect, + ): + for _ in range(nreps): + points = _observation_points(rng, ndims, nobs, dtype) + _evaluate(interpolator, points) + + print( + f"Completed {type(interpolator).__name__} " + f"dtype={np.dtype(dtype).name} ndims={ndims} nobs={nobs}" + ) + + +if __name__ == "__main__": + main() diff --git a/src/interpn/__init__.py b/src/interpn/__init__.py index fd2db38..64cc2ef 100644 --- a/src/interpn/__init__.py +++ b/src/interpn/__init__.py @@ -56,37 +56,41 @@ def interpn( method: Literal["linear", "cubic", "nearest"] = "linear", out: NDArray | None = None, linearize_extrapolation: bool = True, - assume_regular: bool = False, - check_bounds: bool = False, - bounds_atol: float = 1e-8, + grid_kind: Literal["regular", "rectilinear"] | None = None, + check_bounds_with_atol: float | None = None, + max_threads: int | None = None, ) -> NDArray: """ Evaluate an N-dimensional grid at the supplied observation points. - Performs some small allocations to prepare the inputs and - performs O(gridsize) checks to determine grid regularity - unless `assume_regular` is set. To avoid this overhead entirely, - use the persistent wrapper classes or raw bindings instead. - Reallocates input arrays if and only if they are not contiguous yet. + Note: values must be defined in C-order, like made by + `numpy.meshgrid(*grids, indexing="ij")`. Values on meshgrids defined + in graphics-order without `indexing="ij"` will not have the desired effect. 
+ + If a pre-allocated output array is provided, the returned array is a + reference to that array. + Args: obs: Observation coordinates, one array per dimension. grids: Grid axis coordinates, one array per dimension. - vals: Values defined on the full tensor-product grid. + vals: Values defined on the full cartesian-product grid. method: Interpolation kind, one of ``"linear"``, ``"cubic"``, or ``"nearest"``. out: Optional preallocated array that receives the result. linearize_extrapolation: Whether cubic extrapolation should fall back to linear behaviour outside the grid bounds. - assume_regular: Treat the grid as regular without checking spacing. - check_bounds: When True, raise if any observation lies outside the grid. - bounds_atol: Absolute tolerance for bounds checks, to avoid spurious errors + grid_kind: Optional ``"regular"`` or ``"rectilinear"`` to skip + grid-shape autodetection. + check_bounds_with_atol: When set, raise if any observation lies outside + the grid by more than this absolute tolerance. + max_threads: Optional upper bound for parallel execution threads. 
Returns: Interpolated values """ # Allocate for the output if it is not supplied - out = out or np.zeros_like(obs[0]) + out = out if out is not None else np.zeros_like(obs[0]) outshape = out.shape out = out.ravel() # Flat view without reallocating @@ -101,106 +105,35 @@ def interpn( "`interpn` defined only for float32 and float64 data" ) - # Check regularity - is_regular = assume_regular or _check_regular(grids) - - if is_regular: - dims = np.array([len(grid) for grid in grids], dtype=int) - starts = np.array([grid[0] for grid in grids], dtype=dtype) - steps = np.array([grid[1] - grid[0] for grid in grids], dtype=dtype) - else: - # Pyright doesn't understand match-case - dims = np.empty((0,), dtype=int) - starts = np.empty((0,), dtype=dtype) - steps = starts - - # Check bounds - if check_bounds: - outb = np.zeros_like(out.shape, dtype=bool) - match (dtype, is_regular): - case (np.float32, True): - raw.check_bounds_regular_f32( - dims, starts, steps, obs, atol=bounds_atol, out=outb - ) - case (np.float64, True): - raw.check_bounds_regular_f64( - dims, starts, steps, obs, atol=bounds_atol, out=outb - ) - case (np.float32, False): - raw.check_bounds_rectilinear_f32(grids, obs, atol=bounds_atol, out=outb) - case (np.float64, False): - raw.check_bounds_rectilinear_f64(grids, obs, atol=bounds_atol, out=outb) - - if any(outb): - raise ValueError("Observation points violate interpolator bounds") - # Do interpolation - match (dtype, is_regular, method): - case (np.float32, True, "linear"): - raw.interpn_linear_regular_f32(dims, starts, steps, vals, obs, out) - case (np.float64, True, "linear"): - raw.interpn_linear_regular_f64(dims, starts, steps, vals, obs, out) - case (np.float32, False, "linear"): - raw.interpn_linear_rectilinear_f32(grids, vals, obs, out) - case (np.float64, False, "linear"): - raw.interpn_linear_rectilinear_f64(grids, vals, obs, out) - case (np.float32, True, "nearest"): - raw.interpn_nearest_regular_f32(dims, starts, steps, vals, obs, out) - case 
(np.float64, True, "nearest"): - raw.interpn_nearest_regular_f64(dims, starts, steps, vals, obs, out) - case (np.float32, False, "nearest"): - raw.interpn_nearest_rectilinear_f32(grids, vals, obs, out) - case (np.float64, False, "nearest"): - raw.interpn_nearest_rectilinear_f64(grids, vals, obs, out) - case (np.float32, True, "cubic"): - raw.interpn_cubic_regular_f32( - dims, - starts, - steps, - vals, - linearize_extrapolation, - obs, - out, - ) - case (np.float64, True, "cubic"): - raw.interpn_cubic_regular_f64( - dims, - starts, - steps, - vals, - linearize_extrapolation, - obs, - out, - ) - case (np.float32, False, "cubic"): - raw.interpn_cubic_rectilinear_f32( + match dtype: + case np.float32: + raw.interpn_f32( grids, vals, - linearize_extrapolation, obs, out, + method=method, + grid_kind=grid_kind, + linearize_extrapolation=linearize_extrapolation, + check_bounds_with_atol=check_bounds_with_atol, + max_threads=max_threads, ) - case (np.float64, False, "cubic"): - raw.interpn_cubic_rectilinear_f64( + case np.float64: + raw.interpn_f64( grids, vals, - linearize_extrapolation, obs, out, + method=method, + grid_kind=grid_kind, + linearize_extrapolation=linearize_extrapolation, + check_bounds_with_atol=check_bounds_with_atol, + max_threads=max_threads, ) case _: raise ValueError( - "Unsupported interpolation configuration:" - f" {dtype}, {is_regular}, {method}" + f"Unsupported interpolation configuration: {dtype}, {method}" ) return out.reshape(outshape) - - -def _check_regular(grids: Sequence[NDArray]) -> bool: - """Check if grids are all regularly spaced""" - is_regular = True - for grid in grids: - dgrid = np.diff(grid) - is_regular = is_regular and np.all(dgrid == dgrid[0]) - return bool(is_regular) diff --git a/src/interpn/raw.py b/src/interpn/raw.py index 1d72221..3b80dac 100644 --- a/src/interpn/raw.py +++ b/src/interpn/raw.py @@ -4,6 +4,8 @@ """ from .interpn import ( + interpn_f64, + interpn_f32, interpn_linear_regular_f64, interpn_linear_regular_f32, 
interpn_linear_rectilinear_f64, @@ -23,6 +25,8 @@ ) __all__ = [ + "interpn_f64", + "interpn_f32", "interpn_linear_regular_f64", "interpn_linear_regular_f32", "interpn_linear_rectilinear_f64", diff --git a/src/interpn/raw.pyi b/src/interpn/raw.pyi index af7f424..e585c4c 100644 --- a/src/interpn/raw.pyi +++ b/src/interpn/raw.pyi @@ -11,6 +11,8 @@ BoolArray = NDArray[np.bool_] IntArray = NDArray[np.intp] __all__ = [ + "interpn_f64", + "interpn_f32", "interpn_linear_regular_f64", "interpn_linear_regular_f32", "interpn_linear_rectilinear_f64", @@ -29,6 +31,28 @@ __all__ = [ "check_bounds_rectilinear_f32", ] +def interpn_f64( + grids: Sequence[NDArrayF64], + vals: NDArrayF64, + obs: Sequence[NDArrayF64], + out: NDArrayF64, + method: str = "linear", + grid_kind: str | None = None, + linearize_extrapolation: bool = True, + check_bounds_with_atol: float | None = None, + max_threads: int | None = None, +) -> None: ... +def interpn_f32( + grids: Sequence[NDArrayF32], + vals: NDArrayF32, + obs: Sequence[NDArrayF32], + out: NDArrayF32, + method: str = "linear", + grid_kind: str | None = None, + linearize_extrapolation: bool = True, + check_bounds_with_atol: float | None = None, + max_threads: int | None = None, +) -> None: ... def interpn_linear_regular_f64( dims: IntArray, starts: NDArrayF64, diff --git a/src/lib.rs b/src/lib.rs index 88d8cab..a407d68 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,12 +88,24 @@ // expanded code that is entirely in const. 
#![allow(clippy::absurd_extreme_comparisons)] +use num_traits::Float; + pub mod multilinear; pub use multilinear::{MultilinearRectilinear, MultilinearRegular}; pub mod multicubic; pub use multicubic::{MulticubicRectilinear, MulticubicRegular}; +pub mod linear { + pub use crate::multilinear::rectilinear; + pub use crate::multilinear::regular; +} + +pub mod cubic { + pub use crate::multicubic::rectilinear; + pub use crate::multicubic::regular; +} + pub mod nearest; pub use nearest::{NearestRectilinear, NearestRegular}; @@ -103,6 +115,15 @@ pub use one_dim::{ linear::Linear1D, linear::LinearHoldLast1D, }; +#[cfg(feature = "par")] +use rayon::{ + iter::{IndexedParallelIterator, ParallelIterator}, + slice::ParallelSliceMut, +}; + +#[cfg(feature = "par")] +use std::sync::{LazyLock, Mutex}; + #[cfg(feature = "std")] pub mod utils; @@ -112,6 +133,395 @@ pub(crate) mod testing; #[cfg(feature = "python")] pub mod python; +/// Interpolant function for multi-dimensional methods. +#[derive(Clone, Copy)] +pub enum GridInterpMethod { + /// Multi-linear interpolation. + Linear, + /// Cubic Hermite spline interpolation. + Cubic, + /// Nearest-neighbor interpolation. + Nearest, +} + +/// Grid spacing category for multi-dimensional methods. +#[derive(Clone, Copy)] +pub enum GridKind { + /// Evenly-spaced points along each axis. + Regular, + /// Un-evenly spaced points along each axis. + Rectilinear, +} + +const MAXDIMS: usize = 8; +const MAXDIMS_ERR: &str = + "Dimension exceeds maximum (8). Use interpolator struct directly for higher dimensions."; +const MIN_CHUNK_SIZE: usize = 1024; + +/// The number of physical cores present on the machine; +/// initialized once, then never again, because each call involves some file I/O +/// and allocations that can be slower than the function call that they support. +/// +/// On subsequent accesses, each access is an atomic load without any waiting paths. 
+/// +/// This lock can only be contended if multiple threads attempt access +/// before it is initialized; in that case, the waiting threads may park +/// until initialization is complete, which can cause ~20us delays +/// on first access only. +#[cfg(feature = "par")] +static PHYSICAL_CORES: LazyLock = LazyLock::new(num_cpus::get_physical); + +/// Evaluate multidimensional interpolation on a regular or rectilinear grid in up to 8 dimensions. +/// Assumes C-style ordering of vals (z(x0, y0), z(x0, y1), ..., z(x0, yn), z(x1, y0), ...). +/// +/// For lower dimensions, a fast flattened method is used. For higher dimensions, where that flattening +/// becomes impractical due to compile times and instruction size, evaluation defers to a bounded +/// recursion. +/// The linear method uses the flattening for 1-6 dimensions, while +/// flattened cubic methods are available up to 3 dimensions by default and up to 4 dimensions +/// with the `deep-unroll` feature enabled. +/// +/// This is a convenience function; best performance will be achieved by using the exact right +/// number for the N parameter, as this will slightly reduce compute and storage overhead, +/// and the underlying method can be extended to more than this function's limit of 8 dimensions. +/// The limit of 8 dimensions was chosen for no more specific reason than to reduce unit test times. +/// +/// While this method initializes the interpolator struct on every call, the overhead of doing this +/// is minimal even when using it to evaluate one observation point at a time. +/// +/// Like most grid search algorithms (including in the standard library), the uniqueness and +/// monotonicity of the grid is the responsibility of the user, because checking it is often much +/// more expensive than the algorithm that we will perform on it. Behavior with ill-posed grids +/// is undefined. +/// +/// #### Args: +/// +/// * `grids`: `N` slices of each axis' grid coordinates. Must be unique and monotonically increasing. 
+/// * `vals`: Flattened `N`-dimensional array of data values at each grid point in C-style order. +/// Must be the same length as the cartesian product of the grids, (n_x * n_y * ...). +/// * `obs`: `N` slices of Observation points where the interpolant should be evaluated. +/// Must be of equal length. +/// * `out`: Pre-allocated output buffer to place the resulting values. +/// Must be the same length as each of the `obs` slices. +/// * `method`: Choice of interpolant function. +/// * `assume_grid_kind`: Whether to assume the grid is regular (evenly-spaced), +/// rectilinear (un-evenly spaced), or make no assumption. +/// If an assumption is provided, this bypasses a check of each +/// grid, which can be a major speedup in some cases. +/// * `linearize_extrapolation`: Whether cubic methods should extrapolate linearly instead of the default +/// quadratic extrapolation. Linearization is recommended to prevent +/// the interpolant from diverging to extremes outside the grid. +/// * `check_bounds_with_atol`: If provided, return an error if any observation points are outside the grid +/// by an amount exceeding the provided tolerance. +/// * `max_threads`: If provided, limit number of threads used to at most this number. Otherwise, +/// use a heuristic to choose the number that will provide the best throughput. +#[cfg(feature = "par")] +pub fn interpn( + grids: &[&[T]], + vals: &[T], + obs: &[&[T]], + out: &mut [T], + method: GridInterpMethod, + assume_grid_kind: Option, + linearize_extrapolation: bool, + check_bounds_with_atol: Option, + max_threads: Option, +) -> Result<(), &'static str> { + let ndims = grids.len(); + if ndims > MAXDIMS { + return Err(MAXDIMS_ERR); + } + let n = out.len(); + + // Resolve grid kind, checking the grid if the kind is not provided by the user. + // We do this once at the top level so that the work is not repeated by each thread. 
+ let kind = resolve_grid_kind(assume_grid_kind, grids)?; + + // If there are enough points to justify it, run parallel + if 2 * MIN_CHUNK_SIZE <= n { + // Chunk for parallelism. + // + // By default, use only physical cores, because on most machines as of + // 2026, only half the available cores represent real compute capability due to + // the widespread adoption of hyperthreading. If a larger number is requested for + // max_threads, that value is clamped to the total available threads so that we don't + // queue chunks unnecessarily. + // + // We also use a minimum chunk size of 1024 as a heuristic, because below that limit, + // single-threaded performance is usually faster due to a combination of thread spawning overhead, + // memory page sizing, and improved vectorization over larger inputs. + let num_cores_physical = *PHYSICAL_CORES; // Real cores, populated on first access + let num_cores_pool = rayon::current_num_threads(); // Available cores from rayon thread pool + let num_cores_available = num_cores_physical.min(num_cores_pool).max(1); // Real max + let num_cores = match max_threads { + Some(num_cores_requested) => num_cores_requested.min(num_cores_available), + None => num_cores_available, + }; + let chunk = MIN_CHUNK_SIZE.max(n / num_cores); + + // Make a shared error indicator + let result: Mutex> = Mutex::new(None); + let write_err = |msg: &'static str| { + let mut guard = result.lock().unwrap(); + if guard.is_none() { + *guard = Some(msg); + } + }; + + // Run threaded + out.par_chunks_mut(chunk).enumerate().for_each(|(i, outc)| { + // Calculate the start and end of observation point chunks + let start = chunk * i; + let end = start + outc.len(); + + // Chunk observation points + let mut obs_slices: [&[T]; 8] = [&[]; 8]; + for (j, o) in obs.iter().enumerate() { + let s = &o.get(start..end); + match s { + Some(s) => obs_slices[j] = s, + None => { + write_err("Dimension mismatch"); + return; + } + }; + } + + // Do interpolations + let res_inner = 
interpn_serial( + grids, + vals, + &obs_slices[..ndims], + outc, + method, + Some(kind), + linearize_extrapolation, + check_bounds_with_atol, + ); + + match res_inner { + Ok(()) => {} + Err(msg) => write_err(msg), + } + }); + + // Handle errors from threads + match *result.lock().unwrap() { + Some(msg) => Err(msg), + None => Ok(()), + } + } else { + // If there are not enough points to justify parallelism, run serial + interpn_serial( + grids, + vals, + obs, + out, + method, + Some(kind), + linearize_extrapolation, + check_bounds_with_atol, + ) + } +} + +/// Allocating variant of [interpn]. +/// It is recommended to pre-allocate outputs and use the non-allocating variant +/// whenever possible. +#[cfg(feature = "par")] +pub fn interpn_alloc( + grids: &[&[T]], + vals: &[T], + obs: &[&[T]], + out: Option>, + method: GridInterpMethod, + assume_grid_kind: Option, + linearize_extrapolation: bool, + check_bounds_with_atol: Option, + max_threads: Option, +) -> Result, &'static str> { + // Empty input -> empty output + if obs.len() == 0 { + return Ok(Vec::with_capacity(0)); + } + + // If output storage was not provided, build it now + let mut out = out.unwrap_or_else(|| vec![T::zero(); obs[0].len()]); + + interpn( + grids, + vals, + obs, + &mut out, + method, + assume_grid_kind, + linearize_extrapolation, + check_bounds_with_atol, + max_threads, + )?; + + Ok(out) +} + +/// Single-threaded, non-allocating variant of [interpn] available without `par` feature. +pub fn interpn_serial( + grids: &[&[T]], + vals: &[T], + obs: &[&[T]], + out: &mut [T], + method: GridInterpMethod, + assume_grid_kind: Option, + linearize_extrapolation: bool, + check_bounds_with_atol: Option, +) -> Result<(), &'static str> { + let ndims = grids.len(); + if ndims > MAXDIMS { + return Err(MAXDIMS_ERR); + } + + // Resolve grid kind, checking the grid if the kind is not provided by the user. 
+ let kind = resolve_grid_kind(assume_grid_kind, grids)?; + + // Extract regular grid params + let get_regular_grid = || { + let mut dims = [0_usize; MAXDIMS]; + let mut starts = [T::zero(); MAXDIMS]; + let mut steps = [T::zero(); MAXDIMS]; + + for (i, grid) in grids.iter().enumerate() { + if grid.len() < 2 { + return Err("All grids must have at least two entries"); + } + dims[i] = grid.len(); + starts[i] = grid[0]; + steps[i] = grid[1] - grid[0]; + } + + Ok((dims, starts, steps)) + }; + + // Bounds checks for regular grid, if requested + let maybe_check_bounds_regular = |dims: &[usize], starts: &[T], steps: &[T], obs: &[&[T]]| { + if let Some(atol) = check_bounds_with_atol { + let mut bounds = [false; MAXDIMS]; + let out = &mut bounds[..ndims]; + multilinear::regular::check_bounds( + &dims[..ndims], + &starts[..ndims], + &steps[..ndims], + obs, + atol, + out, + )?; + if bounds.iter().any(|x| *x) { + return Err("At least one observation point is outside the grid."); + } + } + Ok(()) + }; + + // Bounds checks for rectilinear grid, if requested + let maybe_check_bounds_rectilinear = |grids, obs| { + if let Some(atol) = check_bounds_with_atol { + let mut bounds = [false; MAXDIMS]; + let out = &mut bounds[..ndims]; + multilinear::rectilinear::check_bounds(grids, obs, atol, out)?; + if bounds.iter().any(|x| *x) { + return Err("At least one observation point is outside the grid."); + } + } + Ok(()) + }; + + // Select lower-level method + match (method, kind) { + (GridInterpMethod::Linear, GridKind::Regular) => { + let (dims, starts, steps) = get_regular_grid()?; + maybe_check_bounds_regular(&dims, &starts, &steps, obs)?; + linear::regular::interpn( + &dims[..ndims], + &starts[..ndims], + &steps[..ndims], + vals, + obs, + out, + ) + } + (GridInterpMethod::Linear, GridKind::Rectilinear) => { + maybe_check_bounds_rectilinear(grids, obs)?; + linear::rectilinear::interpn(grids, vals, obs, out) + } + (GridInterpMethod::Cubic, GridKind::Regular) => { + let (dims, starts, steps) 
= get_regular_grid()?; + maybe_check_bounds_regular(&dims, &starts, &steps, obs)?; + cubic::regular::interpn( + &dims[..ndims], + &starts[..ndims], + &steps[..ndims], + vals, + linearize_extrapolation, + obs, + out, + ) + } + (GridInterpMethod::Cubic, GridKind::Rectilinear) => { + maybe_check_bounds_rectilinear(grids, obs)?; + cubic::rectilinear::interpn(grids, vals, linearize_extrapolation, obs, out) + } + (GridInterpMethod::Nearest, GridKind::Regular) => { + let (dims, starts, steps) = get_regular_grid()?; + maybe_check_bounds_regular(&dims, &starts, &steps, obs)?; + nearest::regular::interpn( + &dims[..ndims], + &starts[..ndims], + &steps[..ndims], + vals, + obs, + out, + ) + } + (GridInterpMethod::Nearest, GridKind::Rectilinear) => { + maybe_check_bounds_rectilinear(grids, obs)?; + nearest::rectilinear::interpn(grids, vals, obs, out) + } + } +} + +/// Figure out whether a grid is regular or rectilinear. +fn resolve_grid_kind( + assume_grid_kind: Option, + grids: &[&[T]], +) -> Result { + let kind = match assume_grid_kind { + Some(GridKind::Regular) => GridKind::Regular, + Some(GridKind::Rectilinear) => GridKind::Rectilinear, + None => { + // Check whether grid is regular + let mut is_regular = true; + + for grid in grids.iter() { + if grid.len() < 2 { + return Err("All grids must have at least two entries"); + } + let step = grid[1] - grid[0]; + + if !grid.windows(2).all(|pair| pair[1] - pair[0] == step) { + is_regular = false; + break; + } + } + + if is_regular { + GridKind::Regular + } else { + GridKind::Rectilinear + } + } + }; + + Ok(kind) +} + /// Index a single value from an array #[inline] pub(crate) fn index_arr(loc: &[usize], dimprod: &[usize], data: &[T]) -> T { @@ -132,11 +542,8 @@ pub(crate) fn index_arr_fixed_dims( ) -> T { let mut i = 0; - // unroll! 
{ - // for j < 7 in 0..N { for j in 0..N { i += loc[j] * dimprod[j]; - // } } data[i] diff --git a/src/python.rs b/src/python.rs index 165d684..2a3d3c9 100644 --- a/src/python.rs +++ b/src/python.rs @@ -5,6 +5,7 @@ use pyo3::prelude::*; use crate::multicubic; use crate::multilinear; use crate::nearest; +use crate::{GridInterpMethod, GridKind}; /// Maximum number of dimensions for linear interpn convenience methods const MAXDIMS: usize = 8; @@ -35,6 +36,9 @@ fn interpn<'py>(_py: Python, m: &Bound<'py, PyModule>) -> PyResult<()> { // Multicubic with rectilinear grid m.add_function(wrap_pyfunction!(interpn_cubic_rectilinear_f64, m)?)?; m.add_function(wrap_pyfunction!(interpn_cubic_rectilinear_f32, m)?)?; + // Top-level interpn dispatch + m.add_function(wrap_pyfunction!(interpn_f64, m)?)?; + m.add_function(wrap_pyfunction!(interpn_f32, m)?)?; Ok(()) } @@ -290,3 +294,69 @@ macro_rules! interpn_cubic_rectilinear_impl { interpn_cubic_rectilinear_impl!(interpn_cubic_rectilinear_f64, f64); interpn_cubic_rectilinear_impl!(interpn_cubic_rectilinear_f32, f32); + +fn parse_grid_interp_method(method: &str) -> Result { + match method.to_ascii_lowercase().as_str() { + "linear" => Ok(GridInterpMethod::Linear), + "cubic" => Ok(GridInterpMethod::Cubic), + "nearest" => Ok(GridInterpMethod::Nearest), + _ => Err(exceptions::PyValueError::new_err( + "`method` must be 'linear', 'cubic', or 'nearest'", + )), + } +} + +fn parse_grid_kind(grid_kind: Option<&str>) -> Result, PyErr> { + match grid_kind.map(|kind| kind.to_ascii_lowercase()) { + None => Ok(None), + Some(kind) => match kind.as_str() { + "regular" => Ok(Some(GridKind::Regular)), + "rectilinear" => Ok(Some(GridKind::Rectilinear)), + _ => Err(exceptions::PyValueError::new_err( + "`grid_kind` must be 'regular', 'rectilinear', or None", + )), + }, + } +} + +macro_rules! 
interpn_top_impl { + ($funcname:ident, $T:ty) => { + #[pyfunction] + #[pyo3(signature = (grids, vals, obs, out, method="linear", grid_kind=None, linearize_extrapolation=true, check_bounds_with_atol=None, max_threads=None))] + fn $funcname( + grids: Vec>, + vals: PyReadonlyArray1<$T>, + obs: Vec>, + mut out: PyReadwriteArray1<$T>, + method: &str, + grid_kind: Option<&str>, + linearize_extrapolation: bool, + check_bounds_with_atol: Option<$T>, + max_threads: Option, + ) -> PyResult<()> { + unpack_vec_of_arr!(grids, grids, $T); + unpack_vec_of_arr!(obs, obs, $T); + + let method = parse_grid_interp_method(method)?; + let grid_kind = parse_grid_kind(grid_kind)?; + + match crate::interpn( + grids, + vals.as_slice()?, + obs, + out.as_slice_mut()?, + method, + grid_kind, + linearize_extrapolation, + check_bounds_with_atol, + max_threads, + ) { + Ok(()) => Ok(()), + Err(msg) => Err(exceptions::PyAssertionError::new_err(msg)), + } + } + }; +} + +interpn_top_impl!(interpn_f64, f64); +interpn_top_impl!(interpn_f32, f32); diff --git a/test/test_interpn.py b/test/test_interpn.py index 989731b..0f185f8 100644 --- a/test/test_interpn.py +++ b/test/test_interpn.py @@ -5,7 +5,8 @@ @pytest.mark.parametrize("dtype", [np.float64, np.float32]) -def test_interpn_check_bounds_regular(dtype): +@pytest.mark.parametrize("max_threads", [None, 1], ids=["parallel", "serial"]) +def test_interpn_check_bounds_regular(dtype, max_threads): grid = np.linspace(-1.0, 1.0, 5).astype(dtype) vals = np.linspace(0.0, 10.0, grid.size).astype(dtype) @@ -17,22 +18,25 @@ def test_interpn_check_bounds_regular(dtype): grids=[grid], vals=vals, method="linear", - check_bounds=True, + check_bounds_with_atol=1e-8, + max_threads=max_threads, ) assert inside.shape == obs_inside[0].shape - with pytest.raises(ValueError): + with pytest.raises(AssertionError): interpn( obs=obs_outside, grids=[grid], vals=vals, method="linear", - check_bounds=True, + check_bounds_with_atol=1e-8, + max_threads=max_threads, ) 
@pytest.mark.parametrize("dtype", [np.float64, np.float32]) -def test_interpn_check_bounds_rectilinear(dtype): +@pytest.mark.parametrize("max_threads", [None, 1], ids=["parallel", "serial"]) +def test_interpn_check_bounds_rectilinear(dtype, max_threads): grid = np.array([-1.0, -0.25, 0.5, 2.0], dtype=dtype) vals = np.linspace(0.0, 10.0, grid.size).astype(dtype) @@ -44,15 +48,17 @@ def test_interpn_check_bounds_rectilinear(dtype): grids=[grid], vals=vals, method="linear", - check_bounds=True, + check_bounds_with_atol=1e-8, + max_threads=max_threads, ) assert inside.shape == obs_inside[0].shape - with pytest.raises(ValueError): + with pytest.raises(AssertionError): interpn( obs=obs_outside, grids=[grid], vals=vals, method="linear", - check_bounds=True, + check_bounds_with_atol=1e-8, + max_threads=max_threads, ) diff --git a/test/test_multicubic_rectilinear.py b/test/test_multicubic_rectilinear.py index 65d5b7f..61313dc 100644 --- a/test/test_multicubic_rectilinear.py +++ b/test/test_multicubic_rectilinear.py @@ -1,8 +1,10 @@ import numpy as np +import pytest import interpn -def test_multilinear_rectilinear(): +@pytest.mark.parametrize("max_threads", [None, 1], ids=["parallel", "serial"]) +def test_multilinear_rectilinear(max_threads): for dtype in [np.float64, np.float32]: x = np.linspace(0.0, 10.0, 5).astype(dtype) y = np.linspace(20.0, 30.0, 4).astype(dtype) @@ -48,6 +50,7 @@ def test_multilinear_rectilinear(): vals=zgrid.flatten(), method="cubic", linearize_extrapolation=False, + max_threads=max_threads, ) for i in range(out_helper.size): assert out_helper[i] == zf[i] diff --git a/test/test_multicubic_regular.py b/test/test_multicubic_regular.py index 54cfa4a..8989f42 100644 --- a/test/test_multicubic_regular.py +++ b/test/test_multicubic_regular.py @@ -1,8 +1,10 @@ import numpy as np +import pytest import interpn -def test_multicubic_regular(): +@pytest.mark.parametrize("max_threads", [None, 1], ids=["parallel", "serial"]) +def 
test_multicubic_regular(max_threads): for dtype, tol in [(np.float64, 1e-12), (np.float32, 1e-6)]: x = np.linspace(0.0, 10.0, 7).astype(dtype) y = np.linspace(20.0, 30.0, 5).astype(dtype) @@ -55,6 +57,7 @@ def test_multicubic_regular(): vals=zgrid.flatten(), method="cubic", linearize_extrapolation=False, + max_threads=max_threads, ) for i in range(out_helper.size): assert approx(out_helper[i], zf[i], dtype(tol)) diff --git a/test/test_multilinear_rectilinear.py b/test/test_multilinear_rectilinear.py index e6b9528..aa1fbe0 100644 --- a/test/test_multilinear_rectilinear.py +++ b/test/test_multilinear_rectilinear.py @@ -1,8 +1,10 @@ import numpy as np +import pytest import interpn -def test_multilinear_rectilinear(): +@pytest.mark.parametrize("max_threads", [None, 1], ids=["parallel", "serial"]) +def test_multilinear_rectilinear(max_threads): for dtype in [np.float64, np.float32]: x = np.linspace(0.0, 10.0, 5).astype(dtype) y = np.linspace(20.0, 30.0, 3).astype(dtype) @@ -46,6 +48,7 @@ def test_multilinear_rectilinear(): grids=grids, vals=zgrid.flatten(), method="linear", + max_threads=max_threads, ) for i in range(out_helper.size): diff --git a/test/test_multilinear_regular.py b/test/test_multilinear_regular.py index 4280869..b39d7df 100644 --- a/test/test_multilinear_regular.py +++ b/test/test_multilinear_regular.py @@ -1,8 +1,10 @@ import numpy as np +import pytest import interpn -def test_multilinear_regular(): +@pytest.mark.parametrize("max_threads", [None, 1], ids=["parallel", "serial"]) +def test_multilinear_regular(max_threads): for dtype in [np.float64, np.float32]: x = np.linspace(0.0, 10.0, 5).astype(dtype) y = np.linspace(20.0, 30.0, 3).astype(dtype) @@ -53,6 +55,7 @@ def test_multilinear_regular(): grids=grids, vals=zgrid.flatten(), method="linear", + max_threads=max_threads, ) for i in range(out_helper.size): diff --git a/test/test_nearest_rectilinear.py b/test/test_nearest_rectilinear.py index f47182a..9ece6f4 100644 --- 
a/test/test_nearest_rectilinear.py +++ b/test/test_nearest_rectilinear.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import interpn @@ -11,7 +12,8 @@ def _nearest_rectilinear_index(value: float, grid: np.ndarray) -> int: return idx if dt <= 0.5 else idx + 1 -def test_nearest_rectilinear(): +@pytest.mark.parametrize("max_threads", [None, 1], ids=["parallel", "serial"]) +def test_nearest_rectilinear(max_threads): for dtype in [np.float64, np.float32]: x = np.array([0.0, 1.0, 3.5, 4.0], dtype=dtype) y = np.array([-2.0, -0.5, 0.1], dtype=dtype) @@ -54,6 +56,7 @@ def test_nearest_rectilinear(): grids=grids, vals=zgrid.flatten(), method="nearest", + max_threads=max_threads, ) np.testing.assert_array_equal(out_helper, expected) diff --git a/test/test_nearest_regular.py b/test/test_nearest_regular.py index f1d64ca..9f22a8d 100644 --- a/test/test_nearest_regular.py +++ b/test/test_nearest_regular.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import interpn @@ -10,7 +11,8 @@ def _nearest_regular_index(value: float, start: float, step: float, size: int) - return min(loc if dt <= 0.5 else loc + 1, size - 1) -def test_nearest_regular(): +@pytest.mark.parametrize("max_threads", [None, 1], ids=["parallel", "serial"]) +def test_nearest_regular(max_threads): for dtype in [np.float64, np.float32]: x = np.linspace(0.0, 6.0, 4).astype(dtype) y = np.linspace(-3.0, 3.0, 3).astype(dtype) @@ -65,6 +67,7 @@ def test_nearest_regular(): grids=grids, vals=zgrid.flatten(), method="nearest", + max_threads=max_threads, ) np.testing.assert_array_equal(out_helper, expected)