Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ rust-version = "1.84"

[dependencies]
byteorder = "1.3.2"
ndarray = "0.16"
ndarray = "0.17.1"
num-complex-0_4 = { package = "num-complex", version = "0.4", optional = true }
num-traits = "0.2"
py_literal = "0.4"
Expand Down
17 changes: 14 additions & 3 deletions src/npy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ where
pub fn write_npy<P, T>(path: P, array: &T) -> Result<(), WriteNpyError>
where
P: AsRef<std::path::Path>,
T: WriteNpyExt,
T: WriteNpyExt + ?Sized,
{
array.write_npy(BufWriter::new(File::create(path)?))
}
Expand Down Expand Up @@ -323,10 +323,9 @@ pub trait WriteNpyExt {
fn write_npy<W: io::Write>(&self, writer: W) -> Result<(), WriteNpyError>;
}

impl<A, S, D> WriteNpyExt for ArrayBase<S, D>
impl<A, D> WriteNpyExt for ArrayRef<A, D>
where
A: WritableElement,
S: Data<Elem = A>,
D: Dimension,
{
fn write_npy<W: io::Write>(&self, mut writer: W) -> Result<(), WriteNpyError> {
Expand Down Expand Up @@ -361,6 +360,18 @@ where
}
}

impl<A, S, D> WriteNpyExt for ArrayBase<S, D>
where
A: WritableElement,
S: Data<Elem = A>,
D: Dimension,
{
fn write_npy<W: io::Write>(&self, writer: W) -> Result<(), WriteNpyError> {
let arr: &ArrayRef<A, D> = self;
arr.write_npy(writer)
}
}

/// An error reading array data.
#[derive(Debug)]
pub enum ReadDataError {
Expand Down
42 changes: 21 additions & 21 deletions src/npz.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::{
ReadNpyError, ReadNpyExt, ReadableElement, WritableElement, WriteNpyError, WriteNpyExt,
};
use crate::{ReadNpyError, ReadNpyExt, ReadableElement, WriteNpyError, WriteNpyExt};
use ndarray::prelude::*;
use ndarray::{Data, DataOwned};
use ndarray::DataOwned;
use std::error::Error;
use std::fmt;
use std::io::{BufWriter, Read, Seek, Write};
Expand Down Expand Up @@ -121,26 +119,28 @@ impl<W: Write + Seek> NpzWriter<W> {
///
/// To write a scalar value, create a zero-dimensional array using [`arr0`](ndarray::arr0) or
/// [`aview0`](ndarray::aview0).
pub fn add_array<N, S, D>(
&mut self,
name: N,
array: &ArrayBase<S, D>,
) -> Result<(), WriteNpzError>
pub fn add_array<N, T>(&mut self, name: N, array: &T) -> Result<(), WriteNpzError>
where
N: Into<String>,
S::Elem: WritableElement,
S: Data,
D: Dimension,
T: WriteNpyExt + ?Sized,
{
self.zip.start_file(name.into() + ".npy", self.options)?;
// Buffering when writing individual arrays is beneficial even when the
// underlying writer is `Cursor<Vec<u8>>` instead of a real file. The
// only exception I saw in testing was the "compressed, in-memory
// writer, standard layout case". See
// https://github.com/jturner314/ndarray-npy/issues/50#issuecomment-812802481
// for details.
array.write_npy(BufWriter::new(&mut self.zip))?;
Ok(())
fn inner<W, T>(npz: &mut NpzWriter<W>, name: String, array: &T) -> Result<(), WriteNpzError>
where
W: Write + Seek,
T: WriteNpyExt + ?Sized,
{
npz.zip.start_file(name + ".npy", npz.options)?;
// Buffering when writing individual arrays is beneficial even when the
// underlying writer is `Cursor<Vec<u8>>` instead of a real file. The
// only exception I saw in testing was the "compressed, in-memory
// writer, standard layout case". See
// https://github.com/jturner314/ndarray-npy/issues/50#issuecomment-812802481
// for details.
array.write_npy(BufWriter::new(&mut npz.zip))?;
Ok(())
}

inner(self, name.into(), array)
}

/// Calls [`.finish()`](ZipWriter::finish) on the zip file and
Expand Down
43 changes: 25 additions & 18 deletions tests/integration/examples.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ use std::fs::{self, File};
use std::io::{Read, Seek, SeekFrom, Write};
use std::mem;

#[track_caller]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this track_caller?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That way, if the assertion fails, the error message is more useful since the line number in the error message points to the specific line in the test that failed, not the internals of the assert_written_is_correct function. (Regardless, you can see the full backtrace with RUST_BACKTRACE=1 cargo test; it's just more convenient if the output of cargo test without RUST_BACKTRACE is also informative.)

fn assert_written_is_correct<T: WriteNpyExt + ?Sized>(arr: &T, correct: &[u8]) {
let mut writer = Vec::<u8>::new();
arr.write_npy(&mut writer).unwrap();
assert_eq!(&correct, &writer);
}

#[test]
fn write_f64_standard() {
#[cfg(target_endian = "little")]
Expand All @@ -21,13 +28,13 @@ fn write_f64_standard() {
let path = "resources/example_f64_big_endian_standard.npy";

let correct = fs::read(path).unwrap();
let mut writer = Vec::<u8>::new();
let mut arr = Array3::<f64>::zeros((2, 3, 4));
for (i, elem) in arr.iter_mut().enumerate() {
*elem = i as f64;
}
arr.write_npy(&mut writer).unwrap();
assert_eq!(&correct, &writer);
assert_written_is_correct(&arr, &correct);
let arr_ref: &ArrayRef3<f64> = &*arr;
assert_written_is_correct(arr_ref, &correct);
}

#[cfg(feature = "num-complex-0_4")]
Expand All @@ -39,15 +46,15 @@ fn write_c64_standard() {
let path = "resources/example_c64_big_endian_standard.npy";

let correct = fs::read(path).unwrap();
let mut writer = Vec::<u8>::new();
let mut arr = Array3::<Complex<f64>>::zeros((2, 3, 4));
for (i, elem) in arr.iter_mut().enumerate() {
// The `+ 0.` is necessary to get the same behavior as Python with
// respect to signed zeros.
*elem = Complex::new(i as f64, -(i as f64) + 0.);
}
arr.write_npy(&mut writer).unwrap();
assert_eq!(&correct, &writer);
assert_written_is_correct(&arr, &correct);
let arr_ref: &ArrayRef3<Complex<f64>> = &*arr;
assert_written_is_correct(arr_ref, &correct);
}

#[test]
Expand All @@ -58,13 +65,13 @@ fn write_f64_fortran() {
let path = "resources/example_f64_big_endian_fortran.npy";

let correct = fs::read(path).unwrap();
let mut writer = Vec::<u8>::new();
let mut arr = Array3::<f64>::zeros((2, 3, 4).f());
for (i, elem) in arr.iter_mut().enumerate() {
*elem = i as f64;
}
arr.write_npy(&mut writer).unwrap();
assert_eq!(&correct[..], &writer[..]);
assert_written_is_correct(&arr, &correct);
let arr_ref: &ArrayRef3<f64> = &*arr;
assert_written_is_correct(arr_ref, &correct);
}

#[cfg(feature = "num-complex-0_4")]
Expand All @@ -76,15 +83,15 @@ fn write_c64_fortran() {
let path = "resources/example_c64_big_endian_fortran.npy";

let correct = fs::read(path).unwrap();
let mut writer = Vec::<u8>::new();
let mut arr = Array3::<Complex<f64>>::zeros((2, 3, 4).f());
for (i, elem) in arr.iter_mut().enumerate() {
// The `+ 0.` is necessary to get the same behavior as Python with
// respect to signed zeros.
*elem = Complex::new(i as f64, -(i as f64) + 0.);
}
arr.write_npy(&mut writer).unwrap();
assert_eq!(&correct[..], &writer[..]);
assert_written_is_correct(&arr, &correct);
let arr_ref: &ArrayRef3<Complex<f64>> = &*arr;
assert_written_is_correct(arr_ref, &correct);
}

#[test]
Expand All @@ -95,15 +102,15 @@ fn write_f64_discontiguous() {
let path = "resources/example_f64_big_endian_standard.npy";

let correct = fs::read(path).unwrap();
let mut writer = Vec::<u8>::new();
let mut arr = Array3::<f64>::zeros((3, 4, 4));
arr.slice_axis_inplace(Axis(1), Slice::new(0, None, 2));
arr.swap_axes(0, 1);
for (i, elem) in arr.iter_mut().enumerate() {
*elem = i as f64;
}
arr.write_npy(&mut writer).unwrap();
assert_eq!(&correct, &writer);
assert_written_is_correct(&arr, &correct);
let arr_ref: &ArrayRef3<f64> = &*arr;
assert_written_is_correct(arr_ref, &correct);
}

#[cfg(feature = "num-complex-0_4")]
Expand All @@ -115,7 +122,6 @@ fn write_c64_discontiguous() {
let path = "resources/example_c64_big_endian_standard.npy";

let correct = fs::read(path).unwrap();
let mut writer = Vec::<u8>::new();
let mut arr = Array3::<Complex<f64>>::zeros((3, 4, 4));
arr.slice_axis_inplace(Axis(1), Slice::new(0, None, 2));
arr.swap_axes(0, 1);
Expand All @@ -124,8 +130,9 @@ fn write_c64_discontiguous() {
// respect to signed zeros.
*elem = Complex::new(i as f64, -(i as f64) + 0.);
}
arr.write_npy(&mut writer).unwrap();
assert_eq!(&correct, &writer);
assert_written_is_correct(&arr, &correct);
let arr_ref: &ArrayRef3<Complex<f64>> = &*arr;
assert_written_is_correct(arr_ref, &correct);
}

#[test]
Expand Down
16 changes: 13 additions & 3 deletions tests/integration/npz.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! .npz examples.

use ndarray::{array, Array2};
use ndarray::{array, Array2, ArrayRef2};
use ndarray_npy::{NpzReader, NpzWriter};
use std::{error::Error, io::Cursor};

Expand All @@ -10,21 +10,23 @@ fn round_trip_npz() -> Result<(), Box<dyn Error>> {

let arr1 = array![[1i32, 3, 0], [4, 7, -1]];
let arr2 = array![[9i32, 6], [-5, 2], [3, -1]];
let arr3: &ArrayRef2<i32> = &arr1;

{
let mut writer = NpzWriter::new(Cursor::new(&mut buf));
writer.add_array("arr1", &arr1)?;
writer.add_array("arr2", &arr2)?;
writer.add_array("arr3", arr3)?;
writer.finish()?;
}

{
let mut reader = NpzReader::new(Cursor::new(&buf))?;
assert!(!reader.is_empty());
assert_eq!(reader.len(), 2);
assert_eq!(reader.len(), 3);
assert_eq!(
reader.names()?,
vec!["arr1".to_string(), "arr2".to_string()],
vec!["arr1".to_string(), "arr2".to_string(), "arr3".to_string()],
);
{
let by_name: Array2<i32> = reader.by_name("arr1")?;
Expand All @@ -42,6 +44,14 @@ fn round_trip_npz() -> Result<(), Box<dyn Error>> {
let by_name: Array2<i32> = reader.by_name("arr2.npy")?;
assert_eq!(by_name, arr2);
}
{
let by_name: Array2<i32> = reader.by_name("arr3")?;
assert_eq!(*by_name, arr3);
}
{
let by_name: Array2<i32> = reader.by_name("arr3.npy")?;
assert_eq!(*by_name, arr3);
}
{
let res: Result<Array2<i32>, _> = reader.by_name("arr1.npy.npy");
assert!(res.is_err());
Expand Down