diff --git a/Cargo.toml b/Cargo.toml index 35a85cf..78aa7e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ rust-version = "1.84" [dependencies] byteorder = "1.3.2" -ndarray = "0.16" +ndarray = "0.17.1" num-complex-0_4 = { package = "num-complex", version = "0.4", optional = true } num-traits = "0.2" py_literal = "0.4" diff --git a/src/npy/mod.rs b/src/npy/mod.rs index 7f3f0ca..cb89692 100644 --- a/src/npy/mod.rs +++ b/src/npy/mod.rs @@ -60,7 +60,7 @@ where pub fn write_npy(path: P, array: &T) -> Result<(), WriteNpyError> where P: AsRef, - T: WriteNpyExt, + T: WriteNpyExt + ?Sized, { array.write_npy(BufWriter::new(File::create(path)?)) } @@ -323,10 +323,9 @@ pub trait WriteNpyExt { fn write_npy(&self, writer: W) -> Result<(), WriteNpyError>; } -impl WriteNpyExt for ArrayBase +impl WriteNpyExt for ArrayRef where A: WritableElement, - S: Data, D: Dimension, { fn write_npy(&self, mut writer: W) -> Result<(), WriteNpyError> { @@ -361,6 +360,18 @@ where } } +impl WriteNpyExt for ArrayBase +where + A: WritableElement, + S: Data, + D: Dimension, +{ + fn write_npy(&self, writer: W) -> Result<(), WriteNpyError> { + let arr: &ArrayRef = self; + arr.write_npy(writer) + } +} + /// An error reading array data. #[derive(Debug)] pub enum ReadDataError { diff --git a/src/npz.rs b/src/npz.rs index 5681ceb..c71fe61 100644 --- a/src/npz.rs +++ b/src/npz.rs @@ -1,8 +1,6 @@ -use crate::{ - ReadNpyError, ReadNpyExt, ReadableElement, WritableElement, WriteNpyError, WriteNpyExt, -}; +use crate::{ReadNpyError, ReadNpyExt, ReadableElement, WriteNpyError, WriteNpyExt}; use ndarray::prelude::*; -use ndarray::{Data, DataOwned}; +use ndarray::DataOwned; use std::error::Error; use std::fmt; use std::io::{BufWriter, Read, Seek, Write}; @@ -121,26 +119,28 @@ impl NpzWriter { /// /// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or /// [`aview0`](ndarray::aview0). - pub fn add_array( - &mut self, - name: N, - array: &ArrayBase, - ) -> Result<(), WriteNpzError> + pub fn add_array(&mut self, name: N, array: &T) -> Result<(), WriteNpzError> where N: Into, - S::Elem: WritableElement, - S: Data, - D: Dimension, + T: WriteNpyExt + ?Sized, { - self.zip.start_file(name.into() + ".npy", self.options)?; - // Buffering when writing individual arrays is beneficial even when the - // underlying writer is `Cursor>` instead of a real file. The - // only exception I saw in testing was the "compressed, in-memory - // writer, standard layout case". See - // https://github.com/jturner314/ndarray-npy/issues/50#issuecomment-812802481 - // for details. - array.write_npy(BufWriter::new(&mut self.zip))?; - Ok(()) + fn inner(npz: &mut NpzWriter, name: String, array: &T) -> Result<(), WriteNpzError> + where + W: Write + Seek, + T: WriteNpyExt + ?Sized, + { + npz.zip.start_file(name + ".npy", npz.options)?; + // Buffering when writing individual arrays is beneficial even when the + // underlying writer is `Cursor>` instead of a real file. The + // only exception I saw in testing was the "compressed, in-memory + // writer, standard layout case". See + // https://github.com/jturner314/ndarray-npy/issues/50#issuecomment-812802481 + // for details. + array.write_npy(BufWriter::new(&mut npz.zip))?; + Ok(()) + } + + inner(self, name.into(), array) } /// Calls [`.finish()`](ZipWriter::finish) on the zip file and diff --git a/tests/integration/examples.rs b/tests/integration/examples.rs index 2b1aee6..6c02fb7 100644 --- a/tests/integration/examples.rs +++ b/tests/integration/examples.rs @@ -13,6 +13,13 @@ use std::fs::{self, File}; use std::io::{Read, Seek, SeekFrom, Write}; use std::mem; +#[track_caller] +fn assert_written_is_correct(arr: &T, correct: &[u8]) { + let mut writer = Vec::::new(); + arr.write_npy(&mut writer).unwrap(); + assert_eq!(&correct, &writer); +} + #[test] fn write_f64_standard() { #[cfg(target_endian = "little")] @@ -21,13 +28,13 @@ fn write_f64_standard() { let path = "resources/example_f64_big_endian_standard.npy"; let correct = fs::read(path).unwrap(); - let mut writer = Vec::::new(); let mut arr = Array3::::zeros((2, 3, 4)); for (i, elem) in arr.iter_mut().enumerate() { *elem = i as f64; } - arr.write_npy(&mut writer).unwrap(); - assert_eq!(&correct, &writer); + assert_written_is_correct(&arr, &correct); + let arr_ref: &ArrayRef3 = &*arr; + assert_written_is_correct(arr_ref, &correct); } #[cfg(feature = "num-complex-0_4")] @@ -39,15 +46,15 @@ fn write_c64_standard() { let path = "resources/example_c64_big_endian_standard.npy"; let correct = fs::read(path).unwrap(); - let mut writer = Vec::::new(); let mut arr = Array3::>::zeros((2, 3, 4)); for (i, elem) in arr.iter_mut().enumerate() { // The `+ 0.` is necessary to get the same behavior as Python with // respect to signed zeros. *elem = Complex::new(i as f64, -(i as f64) + 0.); } - arr.write_npy(&mut writer).unwrap(); - assert_eq!(&correct, &writer); + assert_written_is_correct(&arr, &correct); + let arr_ref: &ArrayRef3> = &*arr; + assert_written_is_correct(arr_ref, &correct); } #[test] @@ -58,13 +65,13 @@ fn write_f64_fortran() { let path = "resources/example_f64_big_endian_fortran.npy"; let correct = fs::read(path).unwrap(); - let mut writer = Vec::::new(); let mut arr = Array3::::zeros((2, 3, 4).f()); for (i, elem) in arr.iter_mut().enumerate() { *elem = i as f64; } - arr.write_npy(&mut writer).unwrap(); - assert_eq!(&correct[..], &writer[..]); + assert_written_is_correct(&arr, &correct); + let arr_ref: &ArrayRef3 = &*arr; + assert_written_is_correct(arr_ref, &correct); } #[cfg(feature = "num-complex-0_4")] @@ -76,15 +83,15 @@ fn write_c64_fortran() { let path = "resources/example_c64_big_endian_fortran.npy"; let correct = fs::read(path).unwrap(); - let mut writer = Vec::::new(); let mut arr = Array3::>::zeros((2, 3, 4).f()); for (i, elem) in arr.iter_mut().enumerate() { // The `+ 0.` is necessary to get the same behavior as Python with // respect to signed zeros. *elem = Complex::new(i as f64, -(i as f64) + 0.); } - arr.write_npy(&mut writer).unwrap(); - assert_eq!(&correct[..], &writer[..]); + assert_written_is_correct(&arr, &correct); + let arr_ref: &ArrayRef3> = &*arr; + assert_written_is_correct(arr_ref, &correct); } #[test] @@ -95,15 +102,15 @@ fn write_f64_discontiguous() { let path = "resources/example_f64_big_endian_standard.npy"; let correct = fs::read(path).unwrap(); - let mut writer = Vec::::new(); let mut arr = Array3::::zeros((3, 4, 4)); arr.slice_axis_inplace(Axis(1), Slice::new(0, None, 2)); arr.swap_axes(0, 1); for (i, elem) in arr.iter_mut().enumerate() { *elem = i as f64; } - arr.write_npy(&mut writer).unwrap(); - assert_eq!(&correct, &writer); + assert_written_is_correct(&arr, &correct); + let arr_ref: &ArrayRef3 = &*arr; + assert_written_is_correct(arr_ref, &correct); } #[cfg(feature = "num-complex-0_4")] @@ -115,7 +122,6 @@ fn write_c64_discontiguous() { let path = "resources/example_c64_big_endian_standard.npy"; let correct = fs::read(path).unwrap(); - let mut writer = Vec::::new(); let mut arr = Array3::>::zeros((3, 4, 4)); arr.slice_axis_inplace(Axis(1), Slice::new(0, None, 2)); arr.swap_axes(0, 1); @@ -124,8 +130,9 @@ fn write_c64_discontiguous() { // respect to signed zeros. *elem = Complex::new(i as f64, -(i as f64) + 0.); } - arr.write_npy(&mut writer).unwrap(); - assert_eq!(&correct, &writer); + assert_written_is_correct(&arr, &correct); + let arr_ref: &ArrayRef3> = &*arr; + assert_written_is_correct(arr_ref, &correct); } #[test] diff --git a/tests/integration/npz.rs b/tests/integration/npz.rs index 9be26d3..ab042a4 100644 --- a/tests/integration/npz.rs +++ b/tests/integration/npz.rs @@ -1,6 +1,6 @@ //! .npz examples. -use ndarray::{array, Array2}; +use ndarray::{array, Array2, ArrayRef2}; use ndarray_npy::{NpzReader, NpzWriter}; use std::{error::Error, io::Cursor}; @@ -10,21 +10,23 @@ fn round_trip_npz() -> Result<(), Box> { let arr1 = array![[1i32, 3, 0], [4, 7, -1]]; let arr2 = array![[9i32, 6], [-5, 2], [3, -1]]; + let arr3: &ArrayRef2 = &arr1; { let mut writer = NpzWriter::new(Cursor::new(&mut buf)); writer.add_array("arr1", &arr1)?; writer.add_array("arr2", &arr2)?; + writer.add_array("arr3", arr3)?; writer.finish()?; } { let mut reader = NpzReader::new(Cursor::new(&buf))?; assert!(!reader.is_empty()); - assert_eq!(reader.len(), 2); + assert_eq!(reader.len(), 3); assert_eq!( reader.names()?, - vec!["arr1".to_string(), "arr2".to_string()], + vec!["arr1".to_string(), "arr2".to_string(), "arr3".to_string()], ); { let by_name: Array2 = reader.by_name("arr1")?; @@ -42,6 +44,14 @@ fn round_trip_npz() -> Result<(), Box> { let by_name: Array2 = reader.by_name("arr2.npy")?; assert_eq!(by_name, arr2); } + { + let by_name: Array2 = reader.by_name("arr3")?; + assert_eq!(*by_name, arr3); + } + { + let by_name: Array2 = reader.by_name("arr3.npy")?; + assert_eq!(*by_name, arr3); + } { let res: Result, _> = reader.by_name("arr1.npy.npy"); assert!(res.is_err());