From 0eac6d5501487b44e5b0dcf47a53b54d2ab46950 Mon Sep 17 00:00:00 2001 From: Jim Turner Date: Sun, 16 Nov 2025 23:00:33 -0500 Subject: [PATCH] Introduce new Layout type for clarity --- src/npy/header.rs | 43 +++++++++++++++++++++++++++++++++---------- src/npy/mod.rs | 20 ++++++++++---------- 2 files changed, 43 insertions(+), 20 deletions(-) diff --git a/src/npy/header.rs b/src/npy/header.rs index bf3a3f3..2a1e5f4 100644 --- a/src/npy/header.rs +++ b/src/npy/header.rs @@ -316,10 +316,26 @@ impl From for WriteHeaderError { } } +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum Layout { + /// Standard layout (C order). + Standard, + /// Fortran layout. + Fortran, +} + +impl Layout { + /// Returns `true` if the layout is [`Fortran`](Self::Fortran). + #[inline] + pub fn is_fortran(&self) -> bool { + matches!(*self, Layout::Fortran) + } +} + #[derive(Clone, Debug)] pub struct Header { pub type_descriptor: PyValue, - pub fortran_order: bool, + pub layout: Layout, pub shape: Vec, } @@ -333,7 +349,7 @@ impl Header { fn from_py_value(value: PyValue) -> Result { if let PyValue::Dict(dict) = value { let mut type_descriptor: Option = None; - let mut fortran_order: Option = None; + let mut is_fortran: Option = None; let mut shape: Option> = None; for (key, value) in dict { match key { @@ -342,7 +358,7 @@ impl Header { } PyValue::String(ref k) if k == "fortran_order" => { if let PyValue::Boolean(b) = value { - fortran_order = Some(b); + is_fortran = Some(b); } else { return Err(ParseHeaderError::IllegalValue { key: "fortran_order".to_owned(), @@ -370,12 +386,19 @@ impl Header { k => return Err(ParseHeaderError::UnknownKey(k)), } } - match (type_descriptor, fortran_order, shape) { - (Some(type_descriptor), Some(fortran_order), Some(shape)) => Ok(Header { - type_descriptor, - fortran_order, - shape, - }), + match (type_descriptor, is_fortran, shape) { + (Some(type_descriptor), Some(is_fortran), Some(shape)) => { + let layout = if is_fortran { + Layout::Fortran + } else { + Layout::Standard + }; + Ok(Header { + type_descriptor, + layout, + shape, + }) + } (None, _, _) => Err(ParseHeaderError::MissingKey("descr".to_owned())), (_, None, _) => Err(ParseHeaderError::MissingKey("fortran_order".to_owned())), (_, _, None) => Err(ParseHeaderError::MissingKey("shaper".to_owned())), @@ -433,7 +456,7 @@ impl Header { ), ( PyValue::String("fortran_order".into()), - PyValue::Boolean(self.fortran_order), + PyValue::Boolean(self.layout.is_fortran()), ), ( PyValue::String("shape".into()), diff --git a/src/npy/mod.rs b/src/npy/mod.rs index cb89692..e8b30f8 100644 --- a/src/npy/mod.rs +++ b/src/npy/mod.rs @@ -2,7 +2,7 @@ mod elements; pub mod header; use self::header::{ - FormatHeaderError, Header, ParseHeaderError, ReadHeaderError, WriteHeaderError, + FormatHeaderError, Header, Layout, ParseHeaderError, ReadHeaderError, WriteHeaderError, }; use ndarray::prelude::*; use ndarray::{Data, DataOwned, IntoDimension}; @@ -170,7 +170,7 @@ where .expect("overflow converting length of data to u64"); Header { type_descriptor: A::type_descriptor(), - fortran_order: false, + layout: Layout::Standard, shape: dim.as_array_view().to_vec(), } .write(file)?; @@ -329,10 +329,10 @@ where D: Dimension, { fn write_npy(&self, mut writer: W) -> Result<(), WriteNpyError> { - let write_contiguous = |mut writer: W, fortran_order: bool| { + let write_contiguous = |mut writer: W, layout: Layout| { Header { type_descriptor: A::type_descriptor(), - fortran_order, + layout, shape: self.shape().to_owned(), } .write(&mut writer)?; @@ -341,13 +341,13 @@ where Ok(()) }; if self.is_standard_layout() { - write_contiguous(writer, false) + write_contiguous(writer, Layout::Standard) } else if self.view().reversed_axes().is_standard_layout() { - write_contiguous(writer, true) + write_contiguous(writer, Layout::Fortran) } else { Header { type_descriptor: A::type_descriptor(), - fortran_order: false, + layout: Layout::Standard, shape: self.shape().to_owned(), } .write(&mut writer)?; @@ -577,7 +577,7 @@ where let ndim = shape.ndim(); let len = shape_length_checked::(&shape).ok_or(ReadNpyError::LengthOverflow)?; let data = A::read_to_end_exact_vec(&mut reader, &header.type_descriptor, len)?; - ArrayBase::from_shape_vec(shape.set_f(header.fortran_order), data) + ArrayBase::from_shape_vec(shape.set_f(header.layout.is_fortran()), data) .unwrap() .into_dimensionality() .map_err(|_| ReadNpyError::WrongNdim(D::NDIM, ndim)) @@ -821,7 +821,7 @@ where let ndim = shape.ndim(); let len = shape_length_checked::(&shape).ok_or(ViewNpyError::LengthOverflow)?; let data = A::bytes_as_slice(reader, &header.type_descriptor, len)?; - ArrayView::from_shape(shape.set_f(header.fortran_order), data) + ArrayView::from_shape(shape.set_f(header.layout.is_fortran()), data) .unwrap() .into_dimensionality() .map_err(|_| ViewNpyError::WrongNdim(D::NDIM, ndim)) @@ -841,7 +841,7 @@ where let len = shape_length_checked::(&shape).ok_or(ViewNpyError::LengthOverflow)?; let mid = buf.len() - reader.len(); let data = A::bytes_as_mut_slice(&mut buf[mid..], &header.type_descriptor, len)?; - ArrayViewMut::from_shape(shape.set_f(header.fortran_order), data) + ArrayViewMut::from_shape(shape.set_f(header.layout.is_fortran()), data) .unwrap() .into_dimensionality() .map_err(|_| ViewNpyError::WrongNdim(D::NDIM, ndim))