diff --git a/resources/array.npz b/resources/array.npz
new file mode 100644
index 0000000..151b5c0
Binary files /dev/null and b/resources/array.npz differ
diff --git a/src/lib.rs b/src/lib.rs
index 3caf11a..3a7b75b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -52,7 +52,7 @@ mod npy;
mod npz;
pub use crate::npy::{
- read_npy, write_npy, write_zeroed_npy, ReadDataError, ReadNpyError, ReadNpyExt,
+ read_npy, write_npy, write_zeroed_npy, read_npz, write_npz, ReadDataError, ReadNpyError, ReadNpyExt,
ReadableElement, ViewDataError, ViewElement, ViewMutElement, ViewMutNpyExt, ViewNpyError,
ViewNpyExt, WritableElement, WriteDataError, WriteNpyError, WriteNpyExt,
};
diff --git a/src/npy/mod.rs b/src/npy/mod.rs
index 7e246d1..7a6d770 100644
--- a/src/npy/mod.rs
+++ b/src/npy/mod.rs
@@ -65,6 +65,78 @@ where
array.write_npy(std::fs::File::create(path)?)
}
+/// Writes an array to an `.npz` file at the specified path.
+///
+/// This function will create the file if it does not exist, or overwrite it if
+/// it does.
+///
+/// This is a convenience function for using `File::create` followed by
+/// wrapping the file handle with [`ndarray_npy::NpzWriter`] and using its
+/// `new_compressed` and `add_array` methods to write to the created file.
+///
+/// The array will be labled with the name "arr_0.npy" following numpy's conventions
+/// for labeling unnamed arrays in `savez_compressed`.
+///
+/// # Example
+///
+/// ```no_run
+/// use ndarray::array;
+/// use ndarray_npy::write_npz;
+/// # use ndarray_npy::WriteNpzError;
+///
+/// let arr = array![[1, 2, 3], [4, 5, 6]];
+/// write_npz("array.npz", &arr)?;
+/// # Ok::<_, WriteNpzError>(())
+/// ```
+pub fn write_npz
(path: P, array: &ArrayBase) -> Result<(), crate::WriteNpzError>
+where
+ P: AsRef,
+ S::Elem: WritableElement,
+ S: Data,
+ D: Dimension
+{
+ let file = std::fs::File::create(path)
+ .map_err(|e| crate::WriteNpzError::Npy(WriteNpyError::Io(e)))?;
+ let mut wtr = crate::NpzWriter::new_compressed(file);
+ wtr.add_array("arr_0.npy", array)?;
+ Ok(())
+}
+
+/// Read an array from a `.npz` file located at the specified path and name.
+///
+/// This is a convience function for opening a file and using `NpzReader` to
+/// extract one array from it.
+///
+/// The name of a single array written to an `.npz` file using `write_npz`
+/// will be "arr_0.npy", following numpy's conventions for labeling unnamed
+/// arrays in `savez_compressed`.
+///
+/// # Example
+///
+/// ```
+/// use ndarray::Array2;
+/// use ndarray_npy::read_npz;
+/// # use ndarray_npy::ReadNpzError;
+/// let arr: Array2 = read_npz("resources/array.npz", "arr_0.npy")?;
+/// # println!("arr = {}", arr);
+/// # Ok::<_, ReadNpzError>(())
+/// ```
+pub fn read_npz(path: P, name: N) -> Result, crate::ReadNpzError>
+where
+ P: AsRef,
+ N: Into,
+ S::Elem: ReadableElement,
+ S: DataOwned,
+ D: Dimension,
+{
+ let file = std::fs::File::open(path)
+ .map_err(|e| crate::ReadNpzError::Npy(ReadNpyError::Io(e)))?;
+ let mut rdr = crate::NpzReader::new(file)?;
+ let name: String = name.into();
+ let arr = rdr.by_name(&name)?;
+ Ok(arr)
+}
+
/// Writes an `.npy` file (sparse if possible) with bitwise-zero-filled data.
///
/// The `.npy` file represents an array with element type `A` and shape
diff --git a/tests/examples.rs b/tests/examples.rs
index 5f62784..d672c59 100644
--- a/tests/examples.rs
+++ b/tests/examples.rs
@@ -302,3 +302,30 @@ fn zeroed() {
assert_eq!(arr, Array3::::zeros(SHAPE));
assert!(arr.is_standard_layout());
}
+
+#[test]
+fn convenience_functions_round_trip_f64_standard() {
+ let mut arr = Array3::::zeros((2, 3, 4));
+ for (i, elem) in arr.iter_mut().enumerate() {
+ *elem = (i as f64).sin() * std::f64::consts::PI;
+ }
+
+ let tmp = tempfile::tempdir().unwrap();
+
+ // npy round trip
+ let npy_path = tmp.path().join("f64-example.npy");
+ ndarray_npy::write_npy(&npy_path, &arr).unwrap();
+ assert!(npy_path.exists());
+ let rt_arr: Array3 = ndarray_npy::read_npy(&npy_path).unwrap();
+ assert_eq!(arr, rt_arr);
+
+ // npz round trip
+ let npz_path = tmp.path().join("f64-example.npz");
+ ndarray_npy::write_npz(&npz_path, &arr).unwrap();
+ assert!(npz_path.exists());
+ let rtz_arr: Array3 = ndarray_npy::read_npz(&npz_path, "arr_0.npy").unwrap();
+ assert_eq!(arr, rtz_arr);
+ tmp.close().unwrap();
+}
+
+