diff --git a/resources/array.npz b/resources/array.npz new file mode 100644 index 0000000..151b5c0 Binary files /dev/null and b/resources/array.npz differ diff --git a/src/lib.rs b/src/lib.rs index 3caf11a..3a7b75b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,7 +52,7 @@ mod npy; mod npz; pub use crate::npy::{ - read_npy, write_npy, write_zeroed_npy, ReadDataError, ReadNpyError, ReadNpyExt, + read_npy, write_npy, write_zeroed_npy, read_npz, write_npz, ReadDataError, ReadNpyError, ReadNpyExt, ReadableElement, ViewDataError, ViewElement, ViewMutElement, ViewMutNpyExt, ViewNpyError, ViewNpyExt, WritableElement, WriteDataError, WriteNpyError, WriteNpyExt, }; diff --git a/src/npy/mod.rs b/src/npy/mod.rs index 7e246d1..7a6d770 100644 --- a/src/npy/mod.rs +++ b/src/npy/mod.rs @@ -65,6 +65,78 @@ where array.write_npy(std::fs::File::create(path)?) } +/// Writes an array to an `.npz` file at the specified path. +/// +/// This function will create the file if it does not exist, or overwrite it if +/// it does. +/// +/// This is a convenience function for using `File::create` followed by +/// wrapping the file handle with [`ndarray_npy::NpzWriter`] and using its +/// `new_compressed` and `add_array` methods to write to the created file. +/// +/// The array will be labled with the name "arr_0.npy" following numpy's conventions +/// for labeling unnamed arrays in `savez_compressed`. +/// +/// # Example +/// +/// ```no_run +/// use ndarray::array; +/// use ndarray_npy::write_npz; +/// # use ndarray_npy::WriteNpzError; +/// +/// let arr = array![[1, 2, 3], [4, 5, 6]]; +/// write_npz("array.npz", &arr)?; +/// # Ok::<_, WriteNpzError>(()) +/// ``` +pub fn write_npz(path: P, array: &ArrayBase) -> Result<(), crate::WriteNpzError> +where + P: AsRef, + S::Elem: WritableElement, + S: Data, + D: Dimension +{ + let file = std::fs::File::create(path) + .map_err(|e| crate::WriteNpzError::Npy(WriteNpyError::Io(e)))?; + let mut wtr = crate::NpzWriter::new_compressed(file); + wtr.add_array("arr_0.npy", array)?; + Ok(()) +} + +/// Read an array from a `.npz` file located at the specified path and name. +/// +/// This is a convience function for opening a file and using `NpzReader` to +/// extract one array from it. +/// +/// The name of a single array written to an `.npz` file using `write_npz` +/// will be "arr_0.npy", following numpy's conventions for labeling unnamed +/// arrays in `savez_compressed`. +/// +/// # Example +/// +/// ``` +/// use ndarray::Array2; +/// use ndarray_npy::read_npz; +/// # use ndarray_npy::ReadNpzError; +/// let arr: Array2 = read_npz("resources/array.npz", "arr_0.npy")?; +/// # println!("arr = {}", arr); +/// # Ok::<_, ReadNpzError>(()) +/// ``` +pub fn read_npz(path: P, name: N) -> Result, crate::ReadNpzError> +where + P: AsRef, + N: Into, + S::Elem: ReadableElement, + S: DataOwned, + D: Dimension, +{ + let file = std::fs::File::open(path) + .map_err(|e| crate::ReadNpzError::Npy(ReadNpyError::Io(e)))?; + let mut rdr = crate::NpzReader::new(file)?; + let name: String = name.into(); + let arr = rdr.by_name(&name)?; + Ok(arr) +} + /// Writes an `.npy` file (sparse if possible) with bitwise-zero-filled data. /// /// The `.npy` file represents an array with element type `A` and shape diff --git a/tests/examples.rs b/tests/examples.rs index 5f62784..d672c59 100644 --- a/tests/examples.rs +++ b/tests/examples.rs @@ -302,3 +302,30 @@ fn zeroed() { assert_eq!(arr, Array3::::zeros(SHAPE)); assert!(arr.is_standard_layout()); } + +#[test] +fn convenience_functions_round_trip_f64_standard() { + let mut arr = Array3::::zeros((2, 3, 4)); + for (i, elem) in arr.iter_mut().enumerate() { + *elem = (i as f64).sin() * std::f64::consts::PI; + } + + let tmp = tempfile::tempdir().unwrap(); + + // npy round trip + let npy_path = tmp.path().join("f64-example.npy"); + ndarray_npy::write_npy(&npy_path, &arr).unwrap(); + assert!(npy_path.exists()); + let rt_arr: Array3 = ndarray_npy::read_npy(&npy_path).unwrap(); + assert_eq!(arr, rt_arr); + + // npz round trip + let npz_path = tmp.path().join("f64-example.npz"); + ndarray_npy::write_npz(&npz_path, &arr).unwrap(); + assert!(npz_path.exists()); + let rtz_arr: Array3 = ndarray_npy::read_npz(&npz_path, "arr_0.npy").unwrap(); + assert_eq!(arr, rtz_arr); + tmp.close().unwrap(); +} + +