diff --git a/Cargo.lock b/Cargo.lock index 54c475f..3751de2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -642,6 +642,7 @@ dependencies = [ "libdeflater", "log", "lru", + "ndarray 0.15.6", "ndarray 0.16.0", "netcdf", "numpy", @@ -998,6 +999,7 @@ dependencies = [ "num-integer", "num-traits", "rawpointer", + "rayon", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 942f5a4..e1bf2a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,8 @@ hdf5-sys = { package = "hdf5-metno-sys", version = "0.9.1" } log = "0.4" rayon = "1.10" ndarray = { version = "0.16", features = [ "rayon" ] } +# Remove when https://github.com/PyO3/rust-numpy/pull/439 is addressed +ndarray_0_15 = { package = "ndarray", version = "0.15", features = ["rayon"] } pyo3 = { version = "0.21", optional = true, features = ["anyhow", "auto-initialize", "abi3-py39"] } numpy = { version = "0.21.0", optional = true } netcdf = { version = "0.10.4", optional = true } diff --git a/src/python/mod.rs b/src/python/mod.rs index c5847f4..430cb1b 100644 --- a/src/python/mod.rs +++ b/src/python/mod.rs @@ -164,7 +164,7 @@ impl Dataset { let arr = arr.downcast::>().unwrap(); let mut v = unsafe { arr.as_array_mut() }; - v.mapv_inplace(|v| if v == cond { fv } else { v }); + v.par_mapv_inplace(|v| if v == cond { fv } else { v }); } }