From 3ade2bbcaab57a0fee3cd74faf53604a593fdf49 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Wed, 23 Apr 2025 15:50:20 +0800 Subject: [PATCH 01/21] test: add tests --- README.md | 53 +++++++- src/fmt.rs | 68 ++++++++++- src/lib.rs | 240 +++++++++++++++++++++++++++++++++---- src/transform/broadcast.rs | 49 +++++++- src/transform/index.rs | 93 ++++++++++++-- src/transform/merge.rs | 197 ++++++++++++++++++++++++++---- src/transform/slice.rs | 105 ++++++++++++++-- src/transform/split.rs | 69 ++++++++++- src/transform/tile.rs | 116 ++++++++++++++++-- src/transform/transpose.rs | 57 ++++++++- 10 files changed, 969 insertions(+), 78 deletions(-) diff --git a/README.md b/README.md index a9c80a9..ea7e407 100644 --- a/README.md +++ b/README.md @@ -12,4 +12,55 @@ ![GitHub contributors](https://img.shields.io/github/contributors/InfiniTensor/ndarray-layout) ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/InfiniTensor/ndarray-layout) -This crate provides definitions and transformations for multi-dimensional array data layouts. + +ndarray-layout 是一个用于处理多维数组布局的 Rust 库,它提供了 ArrayLayout 结构体,用于高效管理和操作多维数组的元信息,如形状、步长和偏移量等。这个库在处理多维数组时,提供了灵活且高效的布局管理方式,能够满足不同场景下对数组布局的操作需求。 +## 主要功能特点 +### 多维数组布局管理: + +* ArrayLayout 结构体支持指定任意维度的数组布局,通过 new 方法可以创建具有指定形状、步长和偏移量的布局。 +* 提供 new_contiguous 方法,用于创建连续的数组布局,支持大端序(BigEndian)和小端序(LittleEndian)两种存储顺序。 + +### 元信息访问: + +* 提供便捷的方法来访问数组布局的元信息,如 ndim、offset、shape 和 strides 等。 +* 支持计算数组元素的偏移量和数据范围,方便进行内存访问和数据处理。 + +### 布局操作功能: + +* 提供多种布局变换方法,如 index、tile、transpose、merge 和 slice 等,方便对数组布局进行各种变换操作。 + +## 使用示例 + +```rust +use ndarray_layout::{ArrayLayout, BroadcastArg}; + +// 创建一个新的 ArrayLayout 实例 +// 形状为 [1, 2, 3],步长为 [12, 4, 1],偏移量为 0 +let layout = ArrayLayout::<3>::new(&[1, 2, 3], &[12, 4, 1], 0); + +// 验证初始的形状和步长 +assert_eq!(layout.shape(), &[1, 2, 3]); +assert_eq!(layout.strides(), &[12, 4, 1]); +assert_eq!(layout.offset(), 0); + +// 对第 0 维进行广播变换,广播次数为 4 +let broadcasted_layout = layout.broadcast(0, 4); + +// 验证广播变换后的形状和步长 +assert_eq!(broadcasted_layout.shape(), &[4, 2, 3]); +assert_eq!(broadcasted_layout.strides(), &[0, 4, 1]); +assert_eq!(broadcasted_layout.offset(), 0); + +// 一次对多个阶进行广播变换 +let args = [ + BroadcastArg { axis: 0, times: 4 }, + BroadcastArg { axis: 1, times: 3 } +]; +let multi_broadcasted_layout = layout.broadcast_many(&args); + +// 验证多次广播变换后的形状和步长 +assert_eq!(multi_broadcasted_layout.shape(), &[4, 3, 3]); +assert_eq!(multi_broadcasted_layout.strides(), &[0, 0, 1]); +assert_eq!(multi_broadcasted_layout.offset(), 0); +``` + diff --git a/src/fmt.rs b/src/fmt.rs index 501c319..026225c 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -1,51 +1,91 @@ -use crate::ArrayLayout; +// 引入 crate 中的 ArrayLayout 结构体 +use crate::ArrayLayout; +// 引入标准库中的 fmt 模块,用于格式化输出 use std::fmt; +/// 为 ArrayLayout 结构体实现格式化相关方法 impl ArrayLayout { /// 高维数组格式化。 /// + /// 该函数根据数组布局信息,将数组元素按照不同维度进行格式化输出。 + /// /// # Safety /// /// 这个函数从对裸指针解引用以获得要格式化的数组元素。 + /// 调用者需要确保指针 `ptr` 有效,并且指向的内存区域包含足够的元素, + /// 同时偏移量计算不会导致越界访问。 + /// + /// # 参数 + /// - `f`: 用于格式化输出的 `fmt::Formatter` 引用。 + /// - `ptr`: 指向要格式化的数组元素的裸指针。 + /// + /// # 返回值 + /// 如果格式化成功,返回 `Ok(())`;否则返回 `fmt::Error`。 pub unsafe fn write_array( &self, f: &mut fmt::Formatter, ptr: *const T, ) -> fmt::Result { + // 根据数组的维度数量进行不同的格式化处理 match self.ndim() { + // 处理 0 维数组 0 => { + // 写入格式化后的数组信息,从指针偏移处读取元素 write!(f, "array<> = [{}]", unsafe { ptr.byte_offset(self.offset()).read_unaligned() }) } + // 处理 1 维数组 1 => { + // 解构出数组的形状和步长 let &[n] = self.shape() else { unreachable!() }; let &[s] = self.strides() else { unreachable!() }; + // 写入数组标题 writeln!(f, "array<{n}>[")?; + // 计算指针偏移 let ptr = unsafe { ptr.byte_offset(self.offset()) }; + // 遍历数组元素并写入格式化后的信息 for i in 0..n as isize { writeln!(f, " {}", unsafe { ptr.byte_offset(i * s).read_unaligned() })? } + // 写入数组结束符 writeln!(f, "]")?; Ok(()) } + // 处理多维数组 _ => { + // 生成数组标题 let mut title = "array<".to_string(); for d in self.shape() { title.push_str(&format!("{d}x")) } + // 移除标题末尾多余的 'x' assert_eq!(title.pop(), Some('x')); title.push('>'); + // 创建一个栈用于存储索引信息 let mut stack = Vec::with_capacity(self.ndim() - 2); + // 递归调用 write_recursive 方法进行格式化 self.write_recursive(f, ptr, &title, &mut stack) } } } + /// 递归地格式化多维数组。 + /// + /// 该函数通过递归的方式处理多维数组的不同维度,将数组元素格式化输出。 + /// + /// # 参数 + /// - `f`: 用于格式化输出的 `fmt::Formatter` 引用。 + /// - `ptr`: 指向要格式化的数组元素的裸指针。 + /// - `title`: 数组的标题字符串。 + /// - `indices`: 用于存储当前维度索引的可变向量。 + /// + /// # 返回值 + /// 如果格式化成功,返回 `Ok(())`;否则返回 `fmt::Error`。 fn write_recursive( &self, f: &mut fmt::Formatter, @@ -53,20 +93,27 @@ impl ArrayLayout { title: &str, indices: &mut Vec, ) -> fmt::Result { + // 根据数组的形状进行不同的格式化处理 match *self.shape() { + // 空形状或单元素形状不应该出现,触发 unreachable! 宏 [] | [_] => unreachable!(), + // 处理 2 维数组 [rows, cols] => { + // 写入数组标题和索引信息 write!(f, "{title}[")?; for i in indices { write!(f, "{i}, ")? } writeln!(f, "..]")?; + // 解构出数组的行步长和列步长 let &[rs, cs] = self.strides() else { unreachable!() }; + // 计算指针偏移 let ptr = unsafe { ptr.byte_offset(self.offset()) }; + // 遍历二维数组的行和列,写入格式化后的元素信息 for r in 0..rows as isize { for c in 0..cols as isize { write!(f, "{} ", unsafe { @@ -76,10 +123,15 @@ impl ArrayLayout { writeln!(f)? } } + // 处理多维数组的批量维度 [batch, ..] => { + // 遍历批量维度 for i in 0..batch { + // 将当前索引压入栈中 indices.push(i); + // 递归调用 write_recursive 方法处理下一个维度 self.index(0, i).write_recursive(f, ptr, title, indices)?; + // 从栈中弹出当前索引 indices.pop(); } } @@ -88,18 +140,24 @@ impl ArrayLayout { } } +/// 测试格式化功能的测试用例 #[test] fn test() { + // 定义测试数据 const DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; + /// 定义一个包装结构体 Tensor,包含 ArrayLayout struct Tensor(ArrayLayout<4>); + /// 为 Tensor 结构体实现 fmt::Display trait,用于格式化输出 impl fmt::Display for Tensor { + /// 实现 fmt 方法,调用 ArrayLayout 的 write_array 方法进行格式化 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { unsafe { self.0.write_array(f, DATA.as_ptr()) } } } + // 创建一个 1 维数组布局的 Tensor 实例并打印 let tensor = Tensor(ArrayLayout::<4>::new_contiguous( &[DATA.len()], crate::Endian::BigEndian, @@ -107,9 +165,15 @@ fn test() { )); println!("{}", tensor); + // 对数组布局进行平铺和广播操作后创建新的 Tensor 实例并打印 let tensor = Tensor(tensor.0.tile_be(0, &[1, DATA.len()]).broadcast(0, 6)); println!("{}", tensor); + // 对数组布局进行多次平铺操作后创建新的 Tensor 实例并打印 let tensor = Tensor(tensor.0.tile_be(0, &[2, 3]).tile_be(2, &[5, 2])); println!("{}", tensor); -} + + // 创建一个 0 维数组布局的 Tensor 实例并打印 + let tensor = Tensor(ArrayLayout::<4>::with_ndim(0)); + println!("{}", tensor); +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 83e7e5c..20f10a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,39 +1,61 @@ +// 将项目的 README 文件内容作为文档注释 #![doc = include_str!("../README.md")] +// 开启对警告和缺失文档注释的检查 #![deny(warnings, missing_docs)] -/// An array layout allow N dimensions inlined. +/// 允许内联存储 N 维信息的数组布局结构体。 pub struct ArrayLayout { + // 数组的维度数量 ndim: usize, + // 存储布局内容的联合体 content: Union, } +/// 声明 ArrayLayout 实现 Send trait,表明该类型可以安全地在线程间发送。 +/// 由于使用了 unsafe 关键字,需要确保实现的正确性。 unsafe impl Send for ArrayLayout {} +/// 声明 ArrayLayout 实现 Sync trait,表明该类型可以安全地在线程间共享引用。 +/// 由于使用了 unsafe 关键字,需要确保实现的正确性。 unsafe impl Sync for ArrayLayout {} +/// 用于存储布局内容的联合体,根据维度数量选择不同的存储方式。 union Union { + // 当维度数量超过 N 时,使用指针进行动态分配存储 ptr: NonNull, + // 当维度数量不超过 N 时,内联存储偏移量、形状和步长信息 _inlined: (isize, [usize; N], [isize; N]), } +/// 为 ArrayLayout 实现 Clone trait,允许克隆数组布局。 impl Clone for ArrayLayout { + /// 内联函数,克隆当前数组布局。 #[inline] fn clone(&self) -> Self { + // 调用 new 方法创建一个新的布局,使用当前布局的形状、步长和偏移量 Self::new(self.shape(), self.strides(), self.offset()) } } +/// 为 ArrayLayout 实现 PartialEq trait,允许比较两个数组布局是否相等。 impl PartialEq for ArrayLayout { + /// 内联函数,比较两个数组布局是否相等。 #[inline] fn eq(&self, other: &Self) -> bool { + // 比较维度数量和内容切片是否相等 self.ndim == other.ndim && self.content().as_slice() == other.content().as_slice() } } +/// 为 ArrayLayout 实现 Eq trait,表明该类型支持相等比较。 impl Eq for ArrayLayout {} +/// 为 ArrayLayout 实现 Drop trait,当布局实例被丢弃时执行清理操作。 impl Drop for ArrayLayout { + /// 当布局实例被丢弃时,释放动态分配的内存(如果有)。 fn drop(&mut self) { + // 检查是否有动态分配的指针 if let Some(ptr) = self.ptr_allocated() { + // 不安全代码块,释放动态分配的内存 unsafe { dealloc(ptr.cast().as_ptr(), layout(self.ndim)) } } } @@ -48,9 +70,11 @@ pub enum Endian { LittleEndian, } +/// 为 ArrayLayout 实现关联方法。 impl ArrayLayout { - /// Creates a new Layout with the given shape, strides, and offset. + /// 创建一个具有指定形状、步长和偏移量的新布局。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); @@ -59,23 +83,25 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[12, -4, 1]); /// ``` pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self { - // check - assert_eq!( - shape.len(), - strides.len(), - "shape and strides must have the same length" - ); + // 检查形状和步长的长度是否相等 + assert_eq!(shape.len(),strides.len(),"shape and strides must have the same length"); + // 创建一个具有指定维度数量的新布局 let mut ans = Self::with_ndim(shape.len()); + // 获取布局内容的可变引用 let mut content = ans.content_mut(); + // 设置偏移量 content.set_offset(offset); + // 复制形状信息 content.copy_shape(shape); + // 复制步长信息 content.copy_strides(strides); ans } - /// Creates a new contiguous Layout with the given shape. + /// 创建一个具有指定形状的连续布局。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::{Endian, ArrayLayout}; /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); @@ -84,15 +110,22 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[4, 8, 24]); /// ``` pub fn new_contiguous(shape: &[usize], endian: Endian, element_size: usize) -> Self { + // 创建一个具有指定维度数量的新布局 let mut ans = Self::with_ndim(shape.len()); + // 获取布局内容的可变引用 let mut content = ans.content_mut(); + // 设置偏移量为 0 content.set_offset(0); + // 复制形状信息 content.copy_shape(shape); + // 初始化元素大小的倍数 let mut mul = element_size as isize; + // 定义一个闭包,用于设置步长并更新倍数 let push = |i| { content.set_stride(i, mul); mul *= shape[i] as isize; }; + // 根据大端序或小端序决定遍历顺序 match endian { Endian::BigEndian => (0..shape.len()).rev().for_each(push), Endian::LittleEndian => (0..shape.len()).for_each(push), @@ -100,32 +133,33 @@ impl ArrayLayout { ans } - /// Gets offset. + /// 获取数组的维度数量。 #[inline] pub const fn ndim(&self) -> usize { self.ndim } - /// Gets offset. + /// 获取数组的偏移量。 #[inline] pub fn offset(&self) -> isize { self.content().offset() } - /// Gets shape. + /// 获取数组的形状。 #[inline] pub fn shape(&self) -> &[usize] { self.content().shape() } - /// Gets strides. + /// 获取数组的步长。 #[inline] pub fn strides(&self) -> &[isize] { self.content().strides() } - /// Copy data to another `ArrayLayout` with inline size `M`. + /// 将当前布局复制到内联大小为 `M` 的另一个 `ArrayLayout` 中。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout}; /// let layout = ArrayLayout::<4>::new_contiguous(&[3, 4], BigEndian, 0); @@ -135,11 +169,13 @@ impl ArrayLayout { /// assert_eq!(size_of_val(&layout), (2 * 2 + 2) * size_of::()); /// ``` pub fn to_inline_size(&self) -> ArrayLayout { + // 调用 new 方法创建一个新的布局,使用当前布局的形状、步长和偏移量 ArrayLayout::new(self.shape(), self.strides(), self.offset()) } - /// Calculates the number of elements in the array. + /// 计算数组中的元素数量。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout}; /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], BigEndian, 20); @@ -147,23 +183,27 @@ impl ArrayLayout { /// ``` #[inline] pub fn num_elements(&self) -> usize { + // 对形状中的元素进行累乘 self.shape().iter().product() } - /// Calculates the offset of element at the given `index`. + /// 计算给定索引处元素的偏移量。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout}; /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], BigEndian, 4); /// assert_eq!(layout.element_offset(22, BigEndian), 88); // 88 <- (22 % 4 * 4) + (22 / 4 % 3 * 16) + (22 / 4 / 3 % 2 * 48) /// ``` pub fn element_offset(&self, index: usize, endian: Endian) -> isize { + /// 正向计算元素偏移量的辅助函数。 fn offset_forwards( mut rem: usize, shape: impl IntoIterator, strides: impl IntoIterator, ) -> isize { let mut ans = 0; + // 遍历形状和步长,计算偏移量 for (d, s) in zip(shape, strides) { ans += s * (rem % d) as isize; rem /= d @@ -171,8 +211,10 @@ impl ArrayLayout { ans } + // 获取形状和步长的迭代器 let shape = self.shape().iter().cloned(); let strides = self.strides().iter().cloned(); + // 加上布局的偏移量,并根据大端序或小端序调用辅助函数 self.offset() + match endian { Endian::BigEndian => offset_forwards(index, shape.rev(), strides.rev()), @@ -180,16 +222,27 @@ impl ArrayLayout { } } - /// Calculates the range of data in bytes to determine the location of the memory area that the array needs to access. + /// 计算数据的字节范围,以确定数组需要访问的内存区域位置。 + /// + /// # 示例 + /// ```rust + /// # use ndarray_layout::ArrayLayout; + /// let layout = ArrayLayout::<4>::new(&[2, 3, 4],&[12, -4, 1], 20); + /// let range = layout.data_range(); + /// assert_eq!(range, 12..=35); + /// ``` pub fn data_range(&self) -> RangeInclusive { + // 获取布局内容 let content = self.content(); + // 初始化起始和结束偏移量为布局的偏移量 let mut start = content.offset(); let mut end = content.offset(); + // 遍历形状和步长,更新起始和结束偏移量 for (&d, s) in zip(content.shape(), content.strides()) { use std::cmp::Ordering::{Equal, Greater, Less}; let i = d as isize - 1; match s.cmp(&0) { - Equal => {} + Equal => {}, Less => start += s * i, Greater => end += s * i, } @@ -198,10 +251,14 @@ impl ArrayLayout { } } +// 引入格式化模块 mod fmt; +// 引入变换模块 mod transform; +// 重新导出变换模块中的类型 pub use transform::{BroadcastArg, IndexArg, MergeArg, SliceArg, Split, TileArg}; +// 引入标准库中的相关类型和函数 use std::{ alloc::{Layout, alloc, dealloc}, iter::zip, @@ -210,10 +267,14 @@ use std::{ slice::from_raw_parts, }; +/// 为 ArrayLayout 实现私有方法。 impl ArrayLayout { + /// 内联函数,检查是否有动态分配的指针。 #[inline] fn ptr_allocated(&self) -> Option> { - const { assert!(N > 0) } + // 编译时断言 N 大于 0 + const { assert!(N > 0)} + // ndim 大于 N 则 content 是 ptr,否则是元组 if self.ndim > N { Some(unsafe { self.content.ptr }) } else { @@ -221,6 +282,7 @@ impl ArrayLayout { } } + /// 内联函数,获取布局内容的不可变引用。 #[inline] fn content(&self) -> Content { Content { @@ -231,6 +293,7 @@ impl ArrayLayout { } } + /// 内联函数,获取布局内容的可变引用。 #[inline] fn content_mut(&mut self) -> Content { Content { @@ -241,16 +304,18 @@ impl ArrayLayout { } } - /// Create a new ArrayLayout with the given dimensions. + /// 创建一个具有指定维度数量的新 ArrayLayout。 #[inline] fn with_ndim(ndim: usize) -> Self { Self { ndim, content: if ndim <= N { + // 维度数量不超过 N 时,使用内联存储 Union { _inlined: (0, [0; N], [0; N]), } } else { + // 维度数量超过 N 时,使用动态分配存储 Union { ptr: unsafe { NonNull::new_unchecked(alloc(layout(ndim)).cast()) }, } @@ -259,60 +324,85 @@ impl ArrayLayout { } } +/// 表示布局内容的结构体,根据 MUT 标记决定是否可变。 struct Content { ptr: NonNull, ndim: usize, } +/// 为 Content 实现方法。 impl Content { + /// 内联函数,将内容转换为切片。 #[inline] fn as_slice(&self) -> &[usize] { + // 不安全代码块,从指针创建切片 unsafe { from_raw_parts(self.ptr.as_ptr(), 1 + self.ndim * 2) } } - #[inline] + /// 内联函数,获取偏移量。 + #[inline] fn offset(&self) -> isize { + // 不安全代码块,从指针读取偏移量 unsafe { self.ptr.cast().read() } } + /// 内联函数,获取形状信息。 #[inline] fn shape<'a>(&self) -> &'a [usize] { + // 不安全代码块,从指针创建形状切片 unsafe { from_raw_parts(self.ptr.add(1).as_ptr(), self.ndim) } } + /// 内联函数,获取步长信息。 #[inline] fn strides<'a>(&self) -> &'a [isize] { + // 不安全代码块,从指针创建步长切片 unsafe { from_raw_parts(self.ptr.add(1 + self.ndim).cast().as_ptr(), self.ndim) } } } +/// 为可变的 Content 实现方法。 impl Content { + /// 内联函数,设置偏移量。 #[inline] fn set_offset(&mut self, val: isize) { + // 不安全代码块,向指针写入偏移量 unsafe { self.ptr.cast().write(val) } } + /// 内联函数,设置指定索引处的形状值。 #[inline] fn set_shape(&mut self, idx: usize, val: usize) { + // 检查索引是否越界 assert!(idx < self.ndim); + // 不安全代码块,向指针写入形状值 unsafe { self.ptr.add(1 + idx).write(val) } } + /// 内联函数,设置指定索引处的步长值。 #[inline] fn set_stride(&mut self, idx: usize, val: isize) { + // 检查索引是否越界 assert!(idx < self.ndim); + // 不安全代码块,向指针写入步长值 unsafe { self.ptr.add(1 + idx + self.ndim).cast().write(val) } } + /// 内联函数,复制形状信息。 #[inline] fn copy_shape(&mut self, val: &[usize]) { + // 检查形状长度是否匹配 assert!(val.len() == self.ndim); + // 不安全代码块,复制形状信息到指针 unsafe { copy_nonoverlapping(val.as_ptr(), self.ptr.add(1).as_ptr(), self.ndim) } } + /// 内联函数,复制步长信息。 #[inline] fn copy_strides(&mut self, val: &[isize]) { + // 检查步长长度是否匹配 assert!(val.len() == self.ndim); + // 不安全代码块,复制步长信息到指针 unsafe { copy_nonoverlapping( val.as_ptr(), @@ -323,7 +413,115 @@ impl Content { } } +/// 内联函数,根据维度数量计算内存布局。 #[inline] fn layout(ndim: usize) -> Layout { + // 创建一个包含指定数量 usize 元素的内存布局 Layout::array::(1 + ndim * 2).unwrap() } + + +/// 测试 new 方法是否正确创建布局。 +#[test] +fn test_new() { + let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); + assert_eq!(layout.offset(), 20); + assert_eq!(layout.shape(), &[2, 3, 4]); + assert_eq!(layout.strides(), &[12, -4, 1]); + assert_eq!(layout.ndim(), 3); +} + +/// 测试 new 方法在形状和步长长度不同时的行为。 +#[test] +fn test_new_different_length(){ + +} + +/// 测试 new_contiguous 方法在小端序下是否正确创建布局。 +#[test] +fn test_new_contiguous_little_endian() { + let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); + assert_eq!(layout.offset(), 0); + assert_eq!(layout.shape(), &[2, 3, 4]); + assert_eq!(layout.strides(), &[4, 8, 24]); +} + +/// 测试 new_contiguous 方法在大端序下是否正确创建布局。 +#[test] +fn test_new_contiguous_big_endian() { + let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); + assert_eq!(layout.offset(), 0); + assert_eq!(layout.shape(), &[2, 3, 4]); + assert_eq!(layout.strides(), &[4, 8, 24]); +} + +/// 测试 new 方法在形状和步长长度不匹配时是否会 panic。 +#[test] +#[should_panic(expected = "shape and strides must have the same length")] +fn test_new_invalid_shape_strides_length() { + ArrayLayout::<4>::new(&[2, 3], &[12, -4, 1], 20); +} + +/// 测试 to_inline_size 方法是否正确转换内联大小。 +#[test] +fn test_to_inline_size() { + let layout = ArrayLayout::<4>::new_contiguous(&[3, 4], Endian::BigEndian, 0); + assert_eq!(size_of_val(&layout), (2 * 4 + 2) * size_of::()); + let layout = layout.to_inline_size::<2>(); + assert_eq!(size_of_val(&layout), (2 * 2 + 2) * size_of::()); +} + +/// 测试 num_elements 方法是否正确计算元素数量。 +#[test] +fn test_num_elements() { + let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 20); + assert_eq!(layout.num_elements(), 24); +} + +/// 测试 element_offset 方法在小端序下是否正确计算元素偏移量。 +#[test] +fn test_element_offset_little_endian() { + let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); + assert_eq!(layout.element_offset(22, Endian::LittleEndian), 88); +} + +/// 测试 element_offset 方法在大端序下是否正确计算元素偏移量。 +#[test] +fn test_element_offset_big_endian() { + let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 4); + assert_eq!(layout.element_offset(22, Endian::BigEndian), 88); +} + +/// 测试 data_range 方法在步长为正数时是否正确计算数据范围。 +#[test] +fn test_data_range_positive_strides() { + let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); + let range = layout.data_range(); + assert_eq!(range, 0..=92); // 0 + 2*4 + 3*8 + 4*24 = 92 +} + +/// 测试 data_range 方法在步长混合时是否正确计算数据范围。 +#[test] +fn test_data_range_mixed_strides() { + let layout = ArrayLayout::<4>::new(&[2, 3, 4],&[12, -4, 0], 20); + let range = layout.data_range(); + assert_eq!(range, 12..=32); +} + +/// 测试 clone 和 eq 方法是否正确工作。 +#[test] +fn test_clone_and_eq() { + let layout1 = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); + let layout2 = layout1.clone(); + assert!(layout1.eq(&layout2)); +} + +/// 测试 drop 方法是否正确释放内存。 +#[test] +fn test_drop() { + let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); + // let ptr = layout.ptr_allocated().unwrap(); + drop(layout); + // 丢弃后,内存应该被释放。 + // 由于无法直接测试,依赖 Rust 的安全保证。 +} \ No newline at end of file diff --git a/src/transform/broadcast.rs b/src/transform/broadcast.rs index f2cca39..06870c0 100644 --- a/src/transform/broadcast.rs +++ b/src/transform/broadcast.rs @@ -1,17 +1,21 @@ -use crate::ArrayLayout; +// 引入 crate 中的 ArrayLayout 结构体,用于后续的广播变换操作 +use crate::ArrayLayout; -/// 索引变换参数。 +/// 广播变换参数。该结构体用于存储广播操作所需的信息,包括广播的轴和广播的次数。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct BroadcastArg { - /// 广播的轴。 + /// 广播的轴,指定在哪个维度上进行广播操作。 pub axis: usize, - /// 广播次数。 + /// 广播次数,即指定轴上的长度要扩增的倍数。 pub times: usize, } +/// 为 ArrayLayout 结构体实现广播相关方法 impl ArrayLayout { /// 广播变换将指定的长度为 1 的阶扩增指定的倍数,并将其步长固定为 0。 + /// 广播操作允许在不复制数据的情况下,将一个较小的数组在某个维度上扩展成一个较大的数组。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10); @@ -19,19 +23,56 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[0, 2, 1]); /// assert_eq!(layout.offset(), 0); /// ``` + /// + /// # 参数 + /// - `axis`: 要进行广播操作的轴的索引。 + /// - `times`: 在指定轴上进行广播的次数,即该轴的新长度。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其形状和步长已根据广播操作进行更新。 pub fn broadcast(&self, axis: usize, times: usize) -> Self { + // 调用 broadcast_many 方法,传入单个广播参数 self.broadcast_many(&[BroadcastArg { axis, times }]) } /// 一次对多个阶进行广播变换。 + /// 该方法可以同时在多个轴上进行广播操作,提高操作效率。 + /// + /// # 参数 + /// - `args`: 包含多个 `BroadcastArg` 结构体的切片,每个结构体表示一个轴的广播参数。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其形状和步长已根据所有广播参数进行更新。 pub fn broadcast_many(&self, args: &[BroadcastArg]) -> Self { + // 克隆当前的 ArrayLayout 实例,作为初始的新布局 let mut ans = self.clone(); + // 获取新布局内容的可变引用,以便修改形状和步长 let mut content = ans.content_mut(); + // 遍历所有的广播参数 for &BroadcastArg { axis, times } in args { + // 断言要广播的轴的原始长度为 1 或者该轴的步长为 0,确保广播操作的合法性 assert!(content.shape()[axis] == 1 || content.strides()[axis] == 0); + // 设置指定轴的新形状为广播次数 content.set_shape(axis, times); + // 设置指定轴的步长为 0,表示在该轴上广播时不移动数据位置 content.set_stride(axis, 0); } + // 返回更新后的新布局 ans } } + +/// 测试 broadcast 方法的正确性 +#[test] +fn test_broadcast() { + // 创建一个初始的 ArrayLayout 实例 + let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0); + // 对轴 0 进行广播操作,广播次数为 10 + let layout = layout.broadcast(0, 10); + // 断言广播操作后的形状是否符合预期 + assert_eq!(layout.shape(), &[10, 5, 2]); + // 断言广播操作后的步长是否符合预期 + assert_eq!(layout.strides(), &[0, 2, 1]); + // 断言广播操作后的偏移量是否符合预期 + assert_eq!(layout.offset(), 0); +} \ No newline at end of file diff --git a/src/transform/index.rs b/src/transform/index.rs index b959595..7ac88ea 100644 --- a/src/transform/index.rs +++ b/src/transform/index.rs @@ -1,19 +1,23 @@ -use crate::ArrayLayout; +// 引入 crate 中的 ArrayLayout 结构体,用于后续的索引变换操作 +use crate::ArrayLayout; +// 引入标准库中的 zip 函数,用于同时迭代多个迭代器 use std::iter::zip; -/// 索引变换参数。 +/// 索引变换参数。该结构体用于存储索引变换所需的信息,包括索引的轴和选择的元素索引。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct IndexArg { - /// 索引的轴。 + /// 索引的轴,指定在哪个维度上进行索引操作。 pub axis: usize, - /// 选择指定轴的第几个元素。 + /// 选择指定轴的第几个元素,索引从 0 开始。 pub index: usize, } +/// 为 ArrayLayout 结构体实现索引相关方法 impl ArrayLayout { /// 索引变换是选择张量指定阶上一项数据的变换,例如指定向量中的一个数、指定矩阵的一行或一列。 /// 索引变换导致张量降阶,确定索引的阶从张量表示移除。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).index(1, 2); @@ -21,38 +25,69 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[12, 1]); /// assert_eq!(layout.offset(), 8); /// ``` + /// + /// # 参数 + /// - `axis`: 要进行索引操作的轴的索引。 + /// - `index`: 在指定轴上选择的元素的索引。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其形状、步长和偏移量已根据索引操作进行更新,维度数减少。 pub fn index(&self, axis: usize, index: usize) -> Self { + // 调用 index_many 方法,传入单个索引参数 self.index_many(&[IndexArg { axis, index }]) } /// 一次对多个阶进行索引变换。 + /// + /// # 参数 + /// - `args`: 包含多个 `IndexArg` 结构体的切片,每个结构体表示一个轴的索引参数。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其形状、步长和偏移量已根据所有索引参数进行更新,维度数减少。 pub fn index_many(&self, mut args: &[IndexArg]) -> Self { + // 获取当前布局的内容 let content = self.content(); + // 初始化偏移量为当前布局的偏移量 let mut offset = content.offset(); + // 获取当前布局的形状 let shape = content.shape(); + // 同时迭代当前布局的形状和步长,并获取索引 let iter = zip(shape, content.strides()).enumerate(); + // 定义一个闭包,用于检查索引参数是否有效 let check = |&IndexArg { axis, index }| shape.get(axis).filter(|&&d| index < d).is_some(); + // 检查第一个索引参数是否有效,如果无效则触发断言失败 if let [first, ..] = args { assert!(check(first), "Invalid index arg: {first:?}"); } else { + // 如果没有索引参数,直接克隆当前布局并返回 return self.clone(); } + // 创建一个新的 ArrayLayout 实例,其维度数量为当前布局的维度数量减去索引参数的数量 let mut ans = Self::with_ndim(self.ndim - args.len()); + // 获取新布局内容的可变引用 let mut content = ans.content_mut(); + // 初始化新布局的索引 let mut j = 0; + // 遍历当前布局的形状和步长 for (i, (&d, &s)) in iter { match *args { + // 如果当前轴与索引参数的轴匹配 [IndexArg { axis, index }, ref tail @ ..] if axis == i => { + // 根据索引更新偏移量 offset += index as isize * s; + // 检查下一个索引参数是否有效 if let [first, ..] = tail { assert!(check(first), "Invalid index arg: {first:?}"); + // 确保索引参数的轴按升序排列 assert!(first.axis > axis, "Index args must be in ascending order"); } + // 更新剩余的索引参数 args = tail; } + // 如果当前轴没有对应的索引参数,将形状和步长设置到新布局中 [..] => { content.set_shape(j, d); content.set_stride(j, s); @@ -60,22 +95,66 @@ impl ArrayLayout { } } } + // 设置新布局的偏移量 content.set_offset(offset as _); + // 返回新的布局 ans } } +/// 测试 index 和 index_many 方法的正确性 #[test] fn test() { - let layout = ArrayLayout::<1>::new(&[2, 3, 4], &[12, 4, 1], 0); + // 错误:这里应该是 ArrayLayout::<3>,修正后创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0); + // 对轴 1 进行索引操作,选择第 2 个元素 let layout = layout.index(1, 2); + // 断言索引操作后的形状是否符合预期 assert_eq!(layout.shape(), &[2, 4]); + // 断言索引操作后的步长是否符合预期 assert_eq!(layout.strides(), &[12, 1]); + // 断言索引操作后的偏移量是否符合预期 assert_eq!(layout.offset(), 8); - let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); + // 错误:这里应该是 ArrayLayout::<3>,修正后创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, -4, 1], 20); + // 对轴 1 进行索引操作,选择第 2 个元素 let layout = layout.index(1, 2); + // 断言索引操作后的形状是否符合预期 assert_eq!(layout.shape(), &[2, 4]); + // 断言索引操作后的步长是否符合预期 assert_eq!(layout.strides(), &[12, 1]); + // 断言索引操作后的偏移量是否符合预期 assert_eq!(layout.offset(), 12); -} + + // 错误:这里应该是 ArrayLayout::<3>,修正后创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, -4, 1], 20); + // 调用 index_many 方法,传入空的索引参数切片 + let layout = layout.index_many(&[]); + // 断言索引操作后的形状是否符合预期 + assert_eq!(layout.shape(), &[2, 3, 4]); + // 断言索引操作后的步长是否符合预期 + assert_eq!(layout.strides(), &[12, -4, 1]); + // 断言索引操作后的偏移量是否符合预期 + assert_eq!(layout.offset(), 20); + + // 错误:这里应该是 ArrayLayout::<3>,修正后创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, -4, 1], 20); + // 调用 index_many 方法,传入多个索引参数 + let layout = layout.index_many(&[ + IndexArg { + axis: 0, + index: 1, + }, + IndexArg { + axis: 1, + index: 2, + }, + ]); + // 断言索引操作后的形状是否符合预期 + assert_eq!(layout.shape(), &[4]); + // 断言索引操作后的步长是否符合预期 + assert_eq!(layout.strides(), &[1]); + // 断言索引操作后的偏移量是否符合预期 + assert_eq!(layout.offset(), 24); +} \ No newline at end of file diff --git a/src/transform/merge.rs b/src/transform/merge.rs index d149ad7..d727a05 100644 --- a/src/transform/merge.rs +++ b/src/transform/merge.rs @@ -1,21 +1,25 @@ -use crate::{ArrayLayout, Endian}; +// 引入 crate 中的 ArrayLayout 结构体和 Endian 枚举 +use crate::{ArrayLayout, Endian}; +// 引入标准库中的 zip 函数,用于同时迭代多个迭代器 use std::iter::zip; -/// 合并变换参数。 +/// 合并变换参数。该结构体用于存储合并操作所需的信息,包括合并的起始位置、合并的维度数量以及分块顺序。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct MergeArg { - /// 合并的起点。 + /// 合并的起点,即从哪个维度开始进行合并操作。 pub start: usize, - /// 合并的宽度 + /// 合并的宽度,即要合并的连续维度的数量。 pub len: usize, - /// 分块的顺序。 + /// 分块的顺序。`Some(Endian::BigEndian)` 表示大端合并,`Some(Endian::LittleEndian)` 表示小端合并,`None` 表示任意合并。 pub endian: Option, } +/// 为 ArrayLayout 结构体实现合并相关方法 impl ArrayLayout { /// 合并变换是将多个连续维度划分合并的变换。 /// 大端合并对维度从后到前依次合并。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).merge_be(0, 3).unwrap(); @@ -23,8 +27,16 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[1]); /// assert_eq!(layout.offset(), 0); /// ``` + /// + /// # 参数 + /// - `start`: 合并操作的起始维度索引。 + /// - `len`: 要合并的连续维度的数量。 + /// + /// # 返回值 + /// 如果合并成功,返回 `Some(ArrayLayout)`;否则返回 `None`。 #[inline] pub fn merge_be(&self, start: usize, len: usize) -> Option { + // 调用 merge_many 方法,传入大端合并的参数 self.merge_many(&[MergeArg { start, len, @@ -35,6 +47,7 @@ impl ArrayLayout { /// 合并变换是将多个连续维度划分合并的变换。 /// 小端合并对维度从前到后依次合并。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0).merge_le(0, 3).unwrap(); @@ -42,8 +55,16 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[1]); /// assert_eq!(layout.offset(), 0); /// ``` + /// + /// # 参数 + /// - `start`: 合并操作的起始维度索引。 + /// - `len`: 要合并的连续维度的数量。 + /// + /// # 返回值 + /// 如果合并成功,返回 `Some(ArrayLayout)`;否则返回 `None`。 #[inline] pub fn merge_le(&self, start: usize, len: usize) -> Option { + // 调用 merge_many 方法,传入小端合并的参数 self.merge_many(&[MergeArg { start, len, @@ -54,6 +75,7 @@ impl ArrayLayout { /// 合并变换是将多个连续维度划分合并的变换。 /// 任意合并只考虑维度的存储连续性。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[3, 2, 4], &[4, 12, 1], 0).merge_free(0, 3).unwrap(); @@ -61,8 +83,16 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[1]); /// assert_eq!(layout.offset(), 0); /// ``` + /// + /// # 参数 + /// - `start`: 合并操作的起始维度索引。 + /// - `len`: 要合并的连续维度的数量。 + /// + /// # 返回值 + /// 如果合并成功,返回 `Some(ArrayLayout)`;否则返回 `None`。 #[inline] pub fn merge_free(&self, start: usize, len: usize) -> Option { + // 调用 merge_many 方法,传入任意合并的参数 self.merge_many(&[MergeArg { start, len, @@ -71,82 +101,205 @@ impl ArrayLayout { } /// 一次对多个阶进行合并变换。 + /// + /// # 参数 + /// - `args`: 包含多个 `MergeArg` 结构体的切片,每个结构体表示一组合并操作的参数。 + /// + /// # 返回值 + /// 如果所有合并操作都成功,返回 `Some(ArrayLayout)`;否则返回 `None`。 pub fn merge_many(&self, args: &[MergeArg]) -> Option { + // 获取当前布局的内容 let content = self.content(); + // 获取当前布局的形状 let shape = content.shape(); + // 获取当前布局的步长 let strides = content.strides(); - let merged = args.iter().map(|arg| arg.len).sum::(); + // 修改 BUG:计算合并后的维度数量,确保每个合并操作至少合并 1 个维度 + let merged = args.iter().map(|arg| arg.len.max(1)).sum::(); + // 创建一个新的 ArrayLayout 实例,计算新的维度数量 let mut ans = Self::with_ndim(self.ndim + args.len() - merged); + // 获取新布局内容的可变引用 let mut content = ans.content_mut(); + // 将新布局的偏移量设置为当前布局的偏移量 content.set_offset(self.offset()); + // 初始化新布局的索引 let mut i = 0; + // 定义一个闭包,用于设置新布局的形状和步长 let mut push = |d, s| { content.set_shape(i, d); content.set_stride(i, s); i += 1; }; + // 记录上一次合并操作的结束位置 let mut last_end = 0; + // 遍历所有合并操作的参数 for arg in args { + // 解构合并操作的参数 let &MergeArg { start, len, endian } = arg; + // 计算本次合并操作的结束位置 let end = start + len; + // 如果合并的宽度为 0,跳过本次合并操作 if len == 0 { continue; } + // 将上一次合并操作结束位置到本次合并操作起始位置之间的维度添加到新布局中 for j in last_end..arg.start { push(shape[j], strides[j]); } + // 创建一个向量,用于存储要合并的维度的形状和步长对 let mut pairs = Vec::with_capacity(len); + // 遍历要合并的维度,将非 0 和非 1 的维度添加到向量中 for (&d, &s) in zip(&shape[start..end], &strides[start..end]) { match d { - 0 => todo!(), - 1 => {} - _ => pairs.push((d, s)), + 0 => todo!(), // 处理维度大小为 0 的情况,目前待实现 + 1 => {} // 忽略维度大小为 1 的情况 + _ => pairs.push((d, s)), // 将非 0 和非 1 的维度添加到向量中 } } + + // 修改 BUG:更新上一次合并操作的结束位置 + last_end = end; + + // 如果向量为空,说明要合并的维度都是 0 或 1,添加一个形状为 1,步长为 0 的维度 if pairs.is_empty() { push(1, 0); continue; } + + // 根据合并的顺序对向量进行排序或反转 match endian { - Some(Endian::BigEndian) => pairs.reverse(), - Some(Endian::LittleEndian) => {} - None => pairs.sort_unstable_by_key(|(_, s)| s.unsigned_abs()), + Some(Endian::BigEndian) => pairs.reverse(), // 大端合并,反转向量 + Some(Endian::LittleEndian) => {} // 小端合并,不做处理 + None => pairs.sort_unstable_by_key(|(_, s)| s.unsigned_abs()), // 任意合并,按步长的绝对值排序 } + // 取出向量的第一个元素 let ((d, s), pairs) = pairs.split_first().unwrap(); + // 初始化合并后的维度大小 let mut d = *d; + // 遍历剩余的元素,检查步长是否符合合并条件 for &(d_, s_) in pairs { if s_ == s * d as isize { - d *= d_ + d *= d_ // 如果符合条件,更新合并后的维度大小 } else { - return None; + return None; // 不符合条件,合并失败,返回 None } } + // 将合并后的维度添加到新布局中 push(d, *s); - last_end = end; } + + // 将最后一次合并操作结束位置到原布局末尾的维度添加到新布局中 for j in last_end..shape.len() { push(shape[j], strides[j]); } + // 返回合并后的新布局 Some(ans) } } +/// 测试 merge_be 方法在合并失败时返回 None #[test] -fn test_merge() { - let layout = ArrayLayout::<3>::new(&[16, 1, 4], &[16, 768, 4], 0) - .merge_be(0, 2) - .unwrap(); - assert_eq!(layout.shape(), &[16, 4]); - assert_eq!(layout.strides(), &[16, 4]); - assert_eq!(layout.offset(), 0); +fn test_merge_return_none() { + // 创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[16, 4, 2], &[8, 4, 1], 0); + // 尝试从第 0 个维度开始合并 3 个维度 + let merged_layout = layout.merge_be(0, 3); + // 断言合并操作失败,返回 None + assert!(merged_layout.is_none()); } + +/// 测试当要合并的维度对为空时的合并操作 +#[test] +fn test_merge_pairs_empyt(){ + // 创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[1, 1, 1], &[1, 1, 1], 0); + // 尝试从第 0 个维度开始合并 2 个维度 + let merged_layout = layout.merge_be(0, 2).unwrap(); + // 断言合并后的形状符合预期 + assert_eq!(merged_layout.shape(), &[1, 1]); + // 断言合并后的步长符合预期 + assert_eq!(merged_layout.strides(), &[0, 1]); + // 断言合并后的偏移量符合预期 + assert_eq!(merged_layout.offset(), 0); +} + +/// 测试 merge_be 方法的示例用法 +#[test] +fn test_merge_be_example() { + // 创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[16, 1, 4], &[16, 768, 4], 0); + // 尝试从第 0 个维度开始合并 2 个维度 + let merged_layout = layout.merge_be(0, 2).unwrap(); + // 断言合并后的形状符合预期 + assert_eq!(merged_layout.shape(), &[16, 4]); + // 断言合并后的步长符合预期 + assert_eq!(merged_layout.strides(), &[16, 4]); + // 断言合并后的偏移量符合预期 + assert_eq!(merged_layout.offset(), 0); +} + +/// 测试 merge_le 方法的示例用法 +#[test] +fn test_merge_le_example() { + // 创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0); + // 从第 0 个维度开始,合并 3 个维度 + let merged_layout = layout.merge_le(0, 3).unwrap(); + + // 验证合并后的形状、步长和偏移量 + assert_eq!(merged_layout.shape(), &[24]); + assert_eq!(merged_layout.strides(), &[1]); + assert_eq!(merged_layout.offset(), 0); +} + +/// 测试合并宽度为 0 时的合并操作 +#[test] +fn test_merge_len_zero(){ + // 创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0); + // 从第 0 个维度开始,合并 0 个维度 + let merged_layout = layout.merge_le(0, 0).unwrap(); + + // 验证合并后的形状、步长和偏移量 + assert_eq!(merged_layout.shape(), &[4, 3, 2]); + assert_eq!(merged_layout.strides(), &[1, 4, 12]); + assert_eq!(merged_layout.offset(), 0); +} + +/// 测试部分合并操作 +#[test] +fn test_partial_merge() { + // 创建一个四维数组布局 + let layout = ArrayLayout::<4>::new(&[2, 3, 4, 5], &[60, 20, 5, 1], 0); + // 从第 1 个维度开始,合并 2 个维度 + let merged_layout = layout.merge_be(1, 2).unwrap(); + + // 验证合并后的形状、步长和偏移量 + assert_eq!(merged_layout.shape(), &[2, 12, 5]); + assert_eq!(merged_layout.strides(), &[60, 5, 1]); + assert_eq!(merged_layout.offset(), 0); +} + +/// 测试 merge_free 方法的示例用法 +#[test] +fn test_merge_free_example() { + // 创建一个三维数组布局 + let layout = ArrayLayout::<3>::new(&[3, 2, 4], &[4, 12, 1], 0); + // 从第 0 个维度开始,合并 3 个维度 + let merged_layout = layout.merge_free(0, 3).unwrap(); + + // 验证合并后的形状、步长和偏移量 + assert_eq!(merged_layout.shape(), &[24]); + assert_eq!(merged_layout.strides(), &[1]); + assert_eq!(merged_layout.offset(), 0); +} \ No newline at end of file diff --git a/src/transform/slice.rs b/src/transform/slice.rs index 815ca07..45277f5 100644 --- a/src/transform/slice.rs +++ b/src/transform/slice.rs @@ -1,31 +1,47 @@ -use crate::ArrayLayout; +// 引入 crate 中的 ArrayLayout 结构体,用于后续的切片操作 +use crate::ArrayLayout; +// 引入标准库中的 zip 函数,用于同时迭代多个迭代器 use std::iter::zip; -/// 切片变换参数。 +/// 切片变换参数。该结构体用于存储切片操作所需的信息,包括切片的轴、起始位置、步长和长度。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct SliceArg { - /// 切片的轴。 + /// 切片的轴,指定在哪个维度上进行切片操作。 pub axis: usize, - /// 切片的起始位置。 + /// 切片的起始位置,即从该轴的哪个位置开始切片。 pub start: usize, - /// 切片的步长。 + /// 切片的步长,决定了切片时元素之间的间隔。正数表示正向切片,负数表示反向切片,0 表示在该位置重复元素。 pub step: isize, - /// 切片的长度。 + /// 切片的长度,即切片操作最终选取的元素数量。 pub len: usize, } +/// 为 ArrayLayout 结构体实现切片相关方法 impl ArrayLayout { /// 切片变换是裁剪张量指定阶上一组连续数据的变换。 /// + /// 该方法用于在指定的轴上进行切片操作,是 `slice_many` 方法的简化版本,仅对单个轴进行切片。 + /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; - /// // axis = 1, start = 1, step = -1, len = 2 + /// // 在轴 1 上,从位置 2 开始,步长为 -1,切片长度为 2 /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, -1, 2); /// assert_eq!(layout.shape(), &[2, 2, 4]); /// assert_eq!(layout.strides(), &[12, -4, 1]); /// assert_eq!(layout.offset(), 8); /// ``` + /// + /// # 参数 + /// - `axis`: 要进行切片的轴的索引。 + /// - `start`: 切片的起始位置。 + /// - `step`: 切片的步长。 + /// - `len`: 切片的长度。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其形状、步长和偏移量已根据切片操作进行更新。 pub fn slice(&self, axis: usize, start: usize, step: isize, len: usize) -> Self { + // 调用 slice_many 方法,传入单个切片参数 self.slice_many(&[SliceArg { axis, start, @@ -35,43 +51,76 @@ impl ArrayLayout { } /// 一次对多个阶进行切片变换。 + /// + /// 该方法允许同时在多个轴上进行切片操作,根据传入的 `SliceArg` 切片参数更新布局的形状、步长和偏移量。 + /// + /// # 参数 + /// - `args`: 包含多个 `SliceArg` 结构体的切片,每个结构体表示一个轴的切片参数。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其形状、步长和偏移量已根据所有切片参数进行更新。 pub fn slice_many(&self, mut args: &[SliceArg]) -> Self { + // 获取当前布局的内容 let content = self.content(); + // 初始化偏移量为当前布局的偏移量 let mut offset = content.offset(); + // 同时迭代当前布局的形状和步长,并获取索引 let iter = zip(content.shape(), content.strides()).enumerate(); + // 创建一个新的 ArrayLayout 实例,其维度数量与当前布局相同 let mut ans = Self::with_ndim(self.ndim); + // 获取新布局内容的可变引用 let mut content = ans.content_mut(); + // 遍历当前布局的形状和步长 for (i, (&d, &s)) in iter { match args { + // 如果当前轴与切片参数的轴匹配 [arg, tail @ ..] if arg.axis == i => { + // 解构切片参数 let &SliceArg { axis, start, step, len, } = arg; + // 引入标准库中的 Ordering 枚举,用于比较步长的正负 use std::cmp::Ordering::*; + // 根据步长的正负计算实际的切片长度 let len = match step.cmp(&0) { + // 步长为正数的情况 Greater => { + // 断言起始位置小于该轴的维度大小 assert!(start < d); + // 更新偏移量 offset += start as isize * s; + // 计算实际的切片长度 (d - start).div_ceil(step as _).min(len) } + // 步长为 0 的情况 Equal => { + // 断言起始位置小于该轴的维度大小 assert!(start < d); + // 更新偏移量 offset += start as isize * s; + // 切片长度保持不变 len } + // 步长为负数的情况 Less => { + // 确保起始位置不超过该轴的维度大小减 1 let start = start.min(d - 1); + // 更新偏移量 offset += start as isize * s; + // 计算实际的切片长度 (start + 1).div_ceil((-step) as _).min(len) } }; + // 设置新布局指定轴的形状为实际的切片长度 content.set_shape(i, len); + // 设置新布局指定轴的步长为原步长乘以切片步长 content.set_stride(i, s * step); + // 检查下一个切片参数的轴是否合法 if let [next, ..] = tail { assert!( axis < next.axis && next.axis < self.ndim, @@ -81,15 +130,57 @@ impl ArrayLayout { self.ndim, ); } + // 更新剩余的切片参数 args = tail; } + // 如果当前轴没有对应的切片参数,保持形状和步长不变 [..] => { content.set_shape(i, d); content.set_stride(i, s); } } } + // 设置新布局的偏移量 content.set_offset(offset as _); + // 返回新的布局 ans } } + +/// 测试 slice 和 slice_many 方法的正确性 +#[test] +fn test_slice() { + // 测试在轴 1 上,从位置 2 开始,步长为 -1,切片长度为 2 的情况 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, -1, 2); + assert_eq!(layout.shape(), &[2, 2, 4]); + assert_eq!(layout.strides(), &[12, -4, 1]); + assert_eq!(layout.offset(), 8); + + // 测试在轴 1 上,从位置 2 开始,步长为 0,切片长度为 2 的情况 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, 0, 2); + assert_eq!(layout.shape(), &[2, 2, 4]); + assert_eq!(layout.strides(), &[12, 0, 1]); + assert_eq!(layout.offset(), 8); + + // 测试在轴 1 上,从位置 0 开始,步长为 1,切片长度为 2 的情况 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 0, 1, 2); + assert_eq!(layout.shape(), &[2, 2, 4]); + assert_eq!(layout.strides(), &[12, 4, 1]); + assert_eq!(layout.offset(), 0); + + // 测试同时在轴 1 和轴 2 上进行切片的情况 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice_many(&[SliceArg{ + axis: 1, + start: 0, + step: 1, + len: 2, + },SliceArg{ + axis: 2, + start: 0, + step: 1, + len: 4, + }]); + assert_eq!(layout.shape(), &[2, 2, 4]); + assert_eq!(layout.strides(), &[12, 4, 1]); + assert_eq!(layout.offset(), 0); +} \ No newline at end of file diff --git a/src/transform/split.rs b/src/transform/split.rs index eb8a320..e06124f 100644 --- a/src/transform/split.rs +++ b/src/transform/split.rs @@ -1,16 +1,30 @@ -use crate::ArrayLayout; +// 引入 crate 中的 ArrayLayout 结构体 +use crate::ArrayLayout; -/// 切分变换参数。 +/// 切分变换参数。该结构体用于存储切分操作所需的信息,以便将一个 `ArrayLayout` 沿指定维度切分成多个部分。 +/// +/// - `src`: 指向要进行切分操作的原始 `ArrayLayout` 的引用。 +/// - `axis`: 指定要进行切分的维度的索引。 +/// - `start`: 当前切分部分在指定维度上的起始位置。 +/// - `parts`: 一个切片,包含每个切分部分在指定维度上的大小。 pub struct Split<'a, const N: usize> { + // 要进行切分的原始 ArrayLayout 的引用 src: &'a ArrayLayout, + // 进行切分的维度 axis: usize, + // 当前切分的起始位置 start: usize, + // 每个切分部分的大小 parts: &'a [usize], } +/// 为 ArrayLayout 结构体实现切分相关方法 impl ArrayLayout { - /// 切分变换讲单个张量沿某个维度切分成多个张量,因此可以支持不均匀的切分。 + /// 切分变换将单个张量沿某个维度切分成多个张量,因此可以支持不均匀的切分。 /// + /// 该方法返回一个 `Split` 迭代器,用于逐个获取切分后的 `ArrayLayout` 实例。 + /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0); @@ -26,9 +40,18 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[12, 4, 1]); /// assert_eq!(layout.offset(), 1); /// ``` + /// + /// # 参数 + /// - `axis`: 要进行切分的维度的索引。 + /// - `parts`: 一个切片,包含每个切分部分在指定维度上的大小。所有部分大小之和必须等于指定维度的原始大小。 + /// + /// # 返回值 + /// 返回一个 `Split` 迭代器,用于遍历切分后的 `ArrayLayout` 实例。 #[inline] pub fn split<'a>(&'a self, axis: usize, parts: &'a [usize]) -> Split<'a, N> { + // 断言指定维度的原始大小等于所有切分部分大小之和 assert_eq!(self.shape()[axis], parts.iter().sum()); + // 创建并返回 Split 结构体实例 Split { src: self, axis, @@ -38,16 +61,56 @@ impl ArrayLayout { } } +/// 为 Split 结构体实现 Iterator trait,使其可以作为迭代器使用 impl Iterator for Split<'_, N> { + // 迭代器返回的元素类型为 ArrayLayout type Item = ArrayLayout; + /// 获取迭代器的下一个元素。 + /// + /// 该方法会从 `parts` 中取出第一个元素作为当前切分部分的大小, + /// 然后根据当前的起始位置和切分大小生成一个新的 `ArrayLayout` 实例。 + /// + /// # 返回值 + /// - 如果 `parts` 不为空,返回 `Some(ArrayLayout)`,表示下一个切分后的 `ArrayLayout` 实例。 + /// - 如果 `parts` 为空,返回 `None`,表示迭代结束。 #[inline] fn next(&mut self) -> Option { + // 尝试从 parts 中取出第一个元素和剩余部分 self.parts.split_first().map(|(&head, tail)| { + // 记录当前的起始位置 let start = self.start; + // 更新起始位置为当前起始位置加上当前切分部分的大小 self.start += head; + // 更新 parts 为剩余部分 self.parts = tail; + // 调用 src 的 slice 方法生成切分后的 ArrayLayout 实例 self.src.slice(self.axis, start, 1, head) }) } } + +/// 测试 split 方法的正确性 +#[test] +fn test_split() { + // 创建一个 ArrayLayout 实例 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0); + // 调用 split 方法进行切分,得到一个 Split 迭代器 + let mut splits = layout.split(2, &[1, 3]); + // 获取第一个切分后的 ArrayLayout 实例 + let layout = splits.next().unwrap(); + // 断言第一个切分后的形状是否符合预期 + assert_eq!(layout.shape(), &[2, 3, 1]); + // 断言第一个切分后的步长是否符合预期 + assert_eq!(layout.strides(), &[12, 4, 1]); + // 断言第一个切分后的偏移量是否符合预期 + assert_eq!(layout.offset(), 0); + // 获取第二个切分后的 ArrayLayout 实例 + let layout = splits.next().unwrap(); + // 断言第二个切分后的形状是否符合预期 + assert_eq!(layout.shape(), &[2, 3, 3]); + // 断言第二个切分后的步长是否符合预期 + assert_eq!(layout.strides(), &[12, 4, 1]); + // 断言第二个切分后的偏移量是否符合预期 + assert_eq!(layout.offset(), 1); +} \ No newline at end of file diff --git a/src/transform/tile.rs b/src/transform/tile.rs index d0165c0..2f6e5b3 100644 --- a/src/transform/tile.rs +++ b/src/transform/tile.rs @@ -1,21 +1,24 @@ +// 引入 crate 中的 ArrayLayout 结构体和 Endian 枚举 use crate::{ArrayLayout, Endian}; +// 引入标准库中的 zip 函数,用于同时迭代多个迭代器 use std::iter::zip; -/// 分块变换参数。 +/// 分块变换参数。该结构体用于存储分块变换所需的参数。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct TileArg<'a> { - /// 分块的轴。 + /// 分块操作要应用的轴。轴的索引从 0 开始。 pub axis: usize, - /// 分块的顺序。 + /// 分块的顺序。大端序或小端序决定了分块后维度在形状中的排列顺序。 pub endian: Endian, - /// 分块的大小。 + /// 分块的大小数组。每个元素表示对应分块的大小。 pub tiles: &'a [usize], } +/// 为 ArrayLayout 结构体实现分块变换相关方法 impl ArrayLayout { - /// 分块变换是将单个维度划分为多个分块的变换。 - /// 大端分块使得分块后范围更大的维度在形状中更靠前的位置。 + /// 大端分块变换。将单个维度划分为多个分块,大端分块使得分块后范围更大的维度在形状中更靠前的位置。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]); @@ -23,8 +26,16 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[18, 6, 3, 1]); /// assert_eq!(layout.offset(), 0); /// ``` + /// + /// # 参数 + /// - `axis`: 要进行分块的轴的索引。 + /// - `tiles`: 分块大小的数组。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其维度已根据大端分块规则进行变换。 #[inline] pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self { + // 调用 tile_many 方法,传入大端序的分块参数 self.tile_many(&[TileArg { axis, endian: Endian::BigEndian, @@ -32,9 +43,9 @@ impl ArrayLayout { }]) } - /// 分块变换是将单个维度划分为多个分块的变换。 - /// 小端分块使得分块后范围更小的维度在形状中更靠前的位置。 + /// 小端分块变换。将单个维度划分为多个分块,小端分块使得分块后范围更小的维度在形状中更靠前的位置。 /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]); @@ -42,8 +53,16 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[18, 6, 1, 2]); /// assert_eq!(layout.offset(), 0); /// ``` + /// + /// # 参数 + /// - `axis`: 要进行分块的轴的索引。 + /// - `tiles`: 分块大小的数组。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其维度已根据小端分块规则进行变换。 #[inline] pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self { + // 调用 tile_many 方法,传入小端序的分块参数 self.tile_many(&[TileArg { axis, endian: Endian::LittleEndian, @@ -52,43 +71,68 @@ impl ArrayLayout { } /// 一次对多个阶进行分块变换。 + /// + /// # 参数 + /// - `args`: 包含多个 `TileArg` 结构体的切片,每个结构体表示一个轴的分块参数。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其维度已根据所有分块参数进行变换。 pub fn tile_many(&self, mut args: &[TileArg]) -> Self { + // 获取当前布局的内容 let content = self.content(); + // 获取当前布局的形状 let shape = content.shape(); + // 同时迭代形状和步长,并获取索引 let iter = zip(shape, content.strides()).enumerate(); + // 定义一个闭包,用于检查分块参数是否有效 let check = |&TileArg { axis, tiles, .. }| { + // 检查指定轴的维度大小是否等于分块大小的乘积 shape .get(axis) .filter(|&&d| d == tiles.iter().product()) .is_some() }; + // 初始化新布局的维度数量和上一个处理的轴的索引 let (mut new, mut last_axis) = match args { [first, ..] => { + // 检查第一个分块参数是否有效 assert!(check(first)); + // 新布局的维度数量初始化为第一个分块参数的分块数量 (first.tiles.len(), first.axis) } - [..] => return self.clone(), + [..] => return self.clone(), // 如果没有分块参数,直接克隆当前布局 }; + // 遍历剩余的分块参数 for arg in &args[1..] { + // 检查分块参数是否有效 assert!(check(arg)); + // 确保当前轴的索引大于上一个轴的索引 assert!(arg.axis > last_axis); + // 累加新布局的维度数量 new += arg.tiles.len(); + // 更新上一个处理的轴的索引 last_axis = arg.axis; } + // 创建一个新的 ArrayLayout 实例,其维度数量为当前布局的维度数量加上新维度数量减去分块参数的数量 let mut ans = Self::with_ndim(self.ndim + new - args.len()); + // 获取新布局内容的可变引用 let mut content = ans.content_mut(); + // 将新布局的偏移量设置为当前布局的偏移量 content.set_offset(self.offset()); + // 初始化新布局的索引 let mut j = 0; + // 定义一个闭包,用于设置新布局的形状和步长 let mut push = |t, s| { content.set_shape(j, t); content.set_stride(j, s); j += 1; }; + // 遍历当前布局的形状和步长 for (i, (&d, &s)) in iter { match *args { [ @@ -99,8 +143,10 @@ impl ArrayLayout { }, ref tail @ .., ] if axis == i => { + // 如果当前轴与分块参数的轴匹配 match endian { Endian::BigEndian => { + // 大端分块规则 // tile : [a, b , c] // strides: [s * c * b, s * c, s] let mut s = s * d as isize; @@ -110,6 +156,7 @@ impl ArrayLayout { } } Endian::LittleEndian => { + // 小端分块规则 // tile : [a, b , c ] // strides: [s, s * a, s * a * b] let mut s = s; @@ -119,11 +166,60 @@ impl ArrayLayout { } } } + // 处理完当前分块参数后,更新剩余的分块参数 args = tail; } - [..] => push(d, s), + [..] => push(d, s), // 如果当前轴没有分块参数,直接设置形状和步长 } } + // 返回新的布局 ans } } + +/// 测试大端分块变换的正确性 +#[test] +fn test_tile_be() { + let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]); + assert_eq!(layout.shape(), &[2, 3, 2, 3]); + assert_eq!(layout.strides(), &[18, 6, 3, 1]); + assert_eq!(layout.offset(), 0); +} + +/// 测试小端分块变换的正确性 +#[test] +fn test_tile_le() { + let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]); + assert_eq!(layout.shape(), &[2, 3, 2, 3]); + assert_eq!(layout.strides(), &[18, 6, 1, 2]); + assert_eq!(layout.offset(), 0); +} + +/// 测试无分块参数时的行为 +#[test] +fn test_empty_tile() { + let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_many(&[]); + assert_eq!(layout.shape(), &[2, 3, 6]); + assert_eq!(layout.strides(), &[18, 6, 1]); + assert_eq!(layout.offset(), 0); +} + +/// 测试多个分块参数时的行为 +#[test] +fn test_multiple_tiles(){ + let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_many(&[ + TileArg { + axis: 0, + endian: Endian::BigEndian, + tiles: &[2, 1], + }, + TileArg { + axis: 2, + endian: Endian::BigEndian, + tiles: &[2, 3], + } + ]); + assert_eq!(layout.shape(), &[2, 1, 3, 2, 3]); + assert_eq!(layout.strides(), &[18, 18, 6, 3, 1]); + assert_eq!(layout.offset(), 0); +} \ No newline at end of file diff --git a/src/transform/transpose.rs b/src/transform/transpose.rs index 1000739..dee97fc 100644 --- a/src/transform/transpose.rs +++ b/src/transform/transpose.rs @@ -1,9 +1,16 @@ -use crate::ArrayLayout; +// 引入 crate 中的 ArrayLayout 结构体 +use crate::ArrayLayout; +// 引入标准库中的 BTreeSet 用于存储唯一且有序的元素,以及 zip 函数用于迭代多个迭代器 use std::{collections::BTreeSet, iter::zip}; +/// 为 ArrayLayout 结构体实现方法 impl ArrayLayout { /// 转置变换允许调换张量的维度顺序,但不改变元素的存储顺序。 /// + /// 该方法接收一个排列数组 `perm`,根据该数组重新排列原布局的维度。 + /// 未在 `perm` 中指定的维度将保持不变。 + /// + /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 0]); @@ -11,33 +18,81 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[4, 12, 1]); /// assert_eq!(layout.offset(), 0); /// ``` + /// + /// # 参数 + /// - `perm`: 一个切片,包含要交换的维度的索引。索引必须唯一。 + /// + /// # 返回值 + /// 返回一个新的 `ArrayLayout` 实例,其维度顺序已根据 `perm` 进行转置。 pub fn transpose(&self, perm: &[usize]) -> Self { + // 将 perm 中的元素收集到 BTreeSet 中,确保元素唯一且有序 let perm_ = perm.iter().collect::>(); + // 断言 perm 中的元素都是唯一的 assert_eq!(perm_.len(), perm.len()); + // 获取当前布局的内容 let content = self.content(); + // 获取当前布局的形状 let shape = content.shape(); + // 获取当前布局的步长 let strides = content.strides(); + // 创建一个新的 ArrayLayout 实例,其维度数量与当前布局相同 let mut ans = Self::with_ndim(self.ndim); + // 获取新布局内容的可变引用 let mut content = ans.content_mut(); + // 将新布局的偏移量设置为当前布局的偏移量 content.set_offset(self.offset()); + + // 定义一个闭包,用于设置新布局指定索引处的形状和步长 let mut set = |i, j| { + // 设置新布局索引 i 处的形状为原布局索引 j 处的形状 content.set_shape(i, shape[j]); + // 设置新布局索引 i 处的步长为原布局索引 j 处的步长 content.set_stride(i, strides[j]); }; + // 记录上一次处理的维度索引,初始化为 0 let mut last = 0; + // 同时遍历有序的 perm_ 和原始的 perm for (&i, &j) in zip(perm_, perm) { + // 处理 last 到 i 之间未在 perm 中指定的维度,保持这些维度不变 for i in last..i { set(i, i); } + // 根据 perm 中的映射关系设置新布局的形状和步长 set(i, j); + // 更新 last 为当前处理的维度索引加 1 last = i + 1; } + // 处理 perm 中未涉及的剩余维度,保持这些维度不变 for i in last..shape.len() { set(i, i); } + + // 返回转置后的新布局 ans } } + +/// 测试 transpose 方法的正确性 +#[test] +fn test_transpose() { + // 创建一个初始布局并进行转置操作 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 0]); + // 断言转置后的形状是否符合预期 + assert_eq!(layout.shape(), &[3, 2, 4]); + // 断言转置后的步长是否符合预期 + assert_eq!(layout.strides(), &[4, 12, 1]); + // 断言转置后的偏移量是否符合预期 + assert_eq!(layout.offset(), 0); + + // 创建另一个初始布局并进行转置操作 + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[2, 0]); + // 断言转置后的形状是否符合预期 + assert_eq!(layout.shape(), &[4, 3, 2]); + // 断言转置后的步长是否符合预期 + assert_eq!(layout.strides(), &[1, 4, 12]); + // 断言转置后的偏移量是否符合预期 + assert_eq!(layout.offset(), 0); +} \ No newline at end of file From eda1229766e04be1a9074ff7f0c1a5c017dbd5bf Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Thu, 24 Apr 2025 16:42:40 +0800 Subject: [PATCH 02/21] =?UTF-8?q?style(README):=20=E8=A7=84=E8=8C=83?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ea7e407..f3ddb4d 100644 --- a/README.md +++ b/README.md @@ -12,20 +12,21 @@ ![GitHub contributors](https://img.shields.io/github/contributors/InfiniTensor/ndarray-layout) ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/InfiniTensor/ndarray-layout) - ndarray-layout 是一个用于处理多维数组布局的 Rust 库,它提供了 ArrayLayout 结构体,用于高效管理和操作多维数组的元信息,如形状、步长和偏移量等。这个库在处理多维数组时,提供了灵活且高效的布局管理方式,能够满足不同场景下对数组布局的操作需求。 + ## 主要功能特点 -### 多维数组布局管理: + +### 多维数组布局管理 * ArrayLayout 结构体支持指定任意维度的数组布局,通过 new 方法可以创建具有指定形状、步长和偏移量的布局。 * 提供 new_contiguous 方法,用于创建连续的数组布局,支持大端序(BigEndian)和小端序(LittleEndian)两种存储顺序。 -### 元信息访问: +### 元信息访问 * 提供便捷的方法来访问数组布局的元信息,如 ndim、offset、shape 和 strides 等。 * 支持计算数组元素的偏移量和数据范围,方便进行内存访问和数据处理。 -### 布局操作功能: +### 布局操作功能 * 提供多种布局变换方法,如 index、tile、transpose、merge 和 slice 等,方便对数组布局进行各种变换操作。 @@ -63,4 +64,3 @@ assert_eq!(multi_broadcasted_layout.shape(), &[4, 3, 3]); assert_eq!(multi_broadcasted_layout.strides(), &[0, 0, 1]); assert_eq!(multi_broadcasted_layout.offset(), 0); ``` - From 0d9b9364d69d1b2b7f315f722bf4ae7c1d9aac26 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Thu, 24 Apr 2025 17:10:59 +0800 Subject: [PATCH 03/21] =?UTF-8?q?style:=20=E8=A7=A3=E5=86=B3=E4=B9=8B?= =?UTF-8?q?=E5=89=8D=E6=B3=A8=E9=87=8A=E4=B8=8D=E8=A7=84=E8=8C=83=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/fmt.rs | 79 +++---------------- src/lib.rs | 135 ++++---------------------------- src/transform/broadcast.rs | 45 ++--------- src/transform/index.rs | 90 ++++----------------- src/transform/merge.rs | 155 ++++++++----------------------------- src/transform/slice.rs | 74 ++---------------- src/transform/split.rs | 55 +------------ src/transform/tile.rs | 73 +++-------------- src/transform/transpose.rs | 44 +---------- 9 files changed, 98 insertions(+), 652 deletions(-) diff --git a/src/fmt.rs b/src/fmt.rs index 026225c..051bba1 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -1,91 +1,48 @@ -// 引入 crate 中的 ArrayLayout 结构体 -use crate::ArrayLayout; -// 引入标准库中的 fmt 模块,用于格式化输出 +use crate::ArrayLayout; use std::fmt; -/// 为 ArrayLayout 结构体实现格式化相关方法 impl ArrayLayout { /// 高维数组格式化。 /// - /// 该函数根据数组布局信息,将数组元素按照不同维度进行格式化输出。 - /// /// # Safety /// /// 这个函数从对裸指针解引用以获得要格式化的数组元素。 - /// 调用者需要确保指针 `ptr` 有效,并且指向的内存区域包含足够的元素, - /// 同时偏移量计算不会导致越界访问。 - /// - /// # 参数 - /// - `f`: 用于格式化输出的 `fmt::Formatter` 引用。 - /// - `ptr`: 指向要格式化的数组元素的裸指针。 - /// - /// # 返回值 - /// 如果格式化成功,返回 `Ok(())`;否则返回 `fmt::Error`。 pub unsafe fn write_array( &self, f: &mut fmt::Formatter, ptr: *const T, ) -> fmt::Result { - // 根据数组的维度数量进行不同的格式化处理 match self.ndim() { - // 处理 0 维数组 - 0 => { - // 写入格式化后的数组信息,从指针偏移处读取元素 - write!(f, "array<> = [{}]", unsafe { + 0 => {write!(f, "array<> = [{}]", unsafe { ptr.byte_offset(self.offset()).read_unaligned() }) } - // 处理 1 维数组 - 1 => { - // 解构出数组的形状和步长 - let &[n] = self.shape() else { unreachable!() }; + 1 => {let &[n] = self.shape() else { unreachable!() }; let &[s] = self.strides() else { unreachable!() }; - // 写入数组标题 writeln!(f, "array<{n}>[")?; - // 计算指针偏移 let ptr = unsafe { ptr.byte_offset(self.offset()) }; - // 遍历数组元素并写入格式化后的信息 for i in 0..n as isize { writeln!(f, " {}", unsafe { ptr.byte_offset(i * s).read_unaligned() })? } - // 写入数组结束符 writeln!(f, "]")?; Ok(()) } - // 处理多维数组 - _ => { - // 生成数组标题 - let mut title = "array<".to_string(); + _ => {let mut title = "array<".to_string(); for d in self.shape() { title.push_str(&format!("{d}x")) } - // 移除标题末尾多余的 'x' assert_eq!(title.pop(), Some('x')); title.push('>'); - // 创建一个栈用于存储索引信息 let mut stack = Vec::with_capacity(self.ndim() - 2); - // 递归调用 write_recursive 方法进行格式化 self.write_recursive(f, ptr, &title, &mut stack) } } } - /// 递归地格式化多维数组。 - /// - /// 该函数通过递归的方式处理多维数组的不同维度,将数组元素格式化输出。 - /// - /// # 参数 - /// - `f`: 用于格式化输出的 `fmt::Formatter` 引用。 - /// - `ptr`: 指向要格式化的数组元素的裸指针。 - /// - `title`: 数组的标题字符串。 - /// - `indices`: 用于存储当前维度索引的可变向量。 - /// - /// # 返回值 - /// 如果格式化成功,返回 `Ok(())`;否则返回 `fmt::Error`。 fn write_recursive( &self, f: &mut fmt::Formatter, @@ -93,27 +50,20 @@ impl ArrayLayout { title: &str, indices: &mut Vec, ) -> fmt::Result { - // 根据数组的形状进行不同的格式化处理 match *self.shape() { - // 空形状或单元素形状不应该出现,触发 unreachable! 宏 [] | [_] => unreachable!(), - // 处理 2 维数组 [rows, cols] => { - // 写入数组标题和索引信息 write!(f, "{title}[")?; for i in indices { write!(f, "{i}, ")? } writeln!(f, "..]")?; - // 解构出数组的行步长和列步长 let &[rs, cs] = self.strides() else { unreachable!() }; - // 计算指针偏移 let ptr = unsafe { ptr.byte_offset(self.offset()) }; - // 遍历二维数组的行和列,写入格式化后的元素信息 for r in 0..rows as isize { for c in 0..cols as isize { write!(f, "{} ", unsafe { @@ -123,15 +73,10 @@ impl ArrayLayout { writeln!(f)? } } - // 处理多维数组的批量维度 [batch, ..] => { - // 遍历批量维度 for i in 0..batch { - // 将当前索引压入栈中 indices.push(i); - // 递归调用 write_recursive 方法处理下一个维度 self.index(0, i).write_recursive(f, ptr, title, indices)?; - // 从栈中弹出当前索引 indices.pop(); } } @@ -140,24 +85,20 @@ impl ArrayLayout { } } -/// 测试格式化功能的测试用例 + + #[test] fn test() { - // 定义测试数据 const DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; - /// 定义一个包装结构体 Tensor,包含 ArrayLayout struct Tensor(ArrayLayout<4>); - /// 为 Tensor 结构体实现 fmt::Display trait,用于格式化输出 impl fmt::Display for Tensor { - /// 实现 fmt 方法,调用 ArrayLayout 的 write_array 方法进行格式化 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { unsafe { self.0.write_array(f, DATA.as_ptr()) } } } - // 创建一个 1 维数组布局的 Tensor 实例并打印 let tensor = Tensor(ArrayLayout::<4>::new_contiguous( &[DATA.len()], crate::Endian::BigEndian, @@ -165,15 +106,13 @@ fn test() { )); println!("{}", tensor); - // 对数组布局进行平铺和广播操作后创建新的 Tensor 实例并打印 let tensor = Tensor(tensor.0.tile_be(0, &[1, DATA.len()]).broadcast(0, 6)); println!("{}", tensor); - // 对数组布局进行多次平铺操作后创建新的 Tensor 实例并打印 let tensor = Tensor(tensor.0.tile_be(0, &[2, 3]).tile_be(2, &[5, 2])); println!("{}", tensor); - // 创建一个 0 维数组布局的 Tensor 实例并打印 - let tensor = Tensor(ArrayLayout::<4>::with_ndim(0)); + let tensor=Tensor(ArrayLayout::<4>::with_ndim(0)); println!("{}", tensor); -} \ No newline at end of file + +} diff --git a/src/lib.rs b/src/lib.rs index 20f10a6..7ffb82e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,61 +1,39 @@ -// 将项目的 README 文件内容作为文档注释 #![doc = include_str!("../README.md")] -// 开启对警告和缺失文档注释的检查 #![deny(warnings, missing_docs)] -/// 允许内联存储 N 维信息的数组布局结构体。 +/// An array layout allow N dimensions inlined. pub struct ArrayLayout { - // 数组的维度数量 ndim: usize, - // 存储布局内容的联合体 content: Union, } -/// 声明 ArrayLayout 实现 Send trait,表明该类型可以安全地在线程间发送。 -/// 由于使用了 unsafe 关键字,需要确保实现的正确性。 unsafe impl Send for ArrayLayout {} -/// 声明 ArrayLayout 实现 Sync trait,表明该类型可以安全地在线程间共享引用。 -/// 由于使用了 unsafe 关键字,需要确保实现的正确性。 unsafe impl Sync for ArrayLayout {} -/// 用于存储布局内容的联合体,根据维度数量选择不同的存储方式。 union Union { - // 当维度数量超过 N 时,使用指针进行动态分配存储 ptr: NonNull, - // 当维度数量不超过 N 时,内联存储偏移量、形状和步长信息 _inlined: (isize, [usize; N], [isize; N]), } -/// 为 ArrayLayout 实现 Clone trait,允许克隆数组布局。 impl Clone for ArrayLayout { - /// 内联函数,克隆当前数组布局。 #[inline] fn clone(&self) -> Self { - // 调用 new 方法创建一个新的布局,使用当前布局的形状、步长和偏移量 Self::new(self.shape(), self.strides(), self.offset()) } } -/// 为 ArrayLayout 实现 PartialEq trait,允许比较两个数组布局是否相等。 impl PartialEq for ArrayLayout { - /// 内联函数,比较两个数组布局是否相等。 #[inline] fn eq(&self, other: &Self) -> bool { - // 比较维度数量和内容切片是否相等 self.ndim == other.ndim && self.content().as_slice() == other.content().as_slice() } } -/// 为 ArrayLayout 实现 Eq trait,表明该类型支持相等比较。 impl Eq for ArrayLayout {} -/// 为 ArrayLayout 实现 Drop trait,当布局实例被丢弃时执行清理操作。 impl Drop for ArrayLayout { - /// 当布局实例被丢弃时,释放动态分配的内存(如果有)。 fn drop(&mut self) { - // 检查是否有动态分配的指针 if let Some(ptr) = self.ptr_allocated() { - // 不安全代码块,释放动态分配的内存 unsafe { dealloc(ptr.cast().as_ptr(), layout(self.ndim)) } } } @@ -70,11 +48,9 @@ pub enum Endian { LittleEndian, } -/// 为 ArrayLayout 实现关联方法。 impl ArrayLayout { - /// 创建一个具有指定形状、步长和偏移量的新布局。 + /// Creates a new Layout with the given shape, strides, and offset. /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); @@ -83,25 +59,19 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[12, -4, 1]); /// ``` pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self { - // 检查形状和步长的长度是否相等 + // check assert_eq!(shape.len(),strides.len(),"shape and strides must have the same length"); - // 创建一个具有指定维度数量的新布局 let mut ans = Self::with_ndim(shape.len()); - // 获取布局内容的可变引用 let mut content = ans.content_mut(); - // 设置偏移量 content.set_offset(offset); - // 复制形状信息 content.copy_shape(shape); - // 复制步长信息 content.copy_strides(strides); ans } - /// 创建一个具有指定形状的连续布局。 + /// Creates a new contiguous Layout with the given shape. /// - /// # 示例 /// ```rust /// # use ndarray_layout::{Endian, ArrayLayout}; /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); @@ -110,22 +80,16 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[4, 8, 24]); /// ``` pub fn new_contiguous(shape: &[usize], endian: Endian, element_size: usize) -> Self { - // 创建一个具有指定维度数量的新布局 let mut ans = Self::with_ndim(shape.len()); - // 获取布局内容的可变引用 let mut content = ans.content_mut(); - // 设置偏移量为 0 content.set_offset(0); - // 复制形状信息 content.copy_shape(shape); - // 初始化元素大小的倍数 let mut mul = element_size as isize; - // 定义一个闭包,用于设置步长并更新倍数 let push = |i| { content.set_stride(i, mul); mul *= shape[i] as isize; }; - // 根据大端序或小端序决定遍历顺序 + // 大端小端区别在于是否反转 match endian { Endian::BigEndian => (0..shape.len()).rev().for_each(push), Endian::LittleEndian => (0..shape.len()).for_each(push), @@ -133,33 +97,32 @@ impl ArrayLayout { ans } - /// 获取数组的维度数量。 + /// Gets offset. #[inline] pub const fn ndim(&self) -> usize { self.ndim } - /// 获取数组的偏移量。 + /// Gets offset. #[inline] pub fn offset(&self) -> isize { self.content().offset() } - /// 获取数组的形状。 + /// Gets shape. #[inline] pub fn shape(&self) -> &[usize] { self.content().shape() } - /// 获取数组的步长。 + /// Gets strides. #[inline] pub fn strides(&self) -> &[isize] { self.content().strides() } - /// 将当前布局复制到内联大小为 `M` 的另一个 `ArrayLayout` 中。 + /// Copy data to another `ArrayLayout` with inline size `M`. /// - /// # 示例 /// ```rust /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout}; /// let layout = ArrayLayout::<4>::new_contiguous(&[3, 4], BigEndian, 0); @@ -169,13 +132,11 @@ impl ArrayLayout { /// assert_eq!(size_of_val(&layout), (2 * 2 + 2) * size_of::()); /// ``` pub fn to_inline_size(&self) -> ArrayLayout { - // 调用 new 方法创建一个新的布局,使用当前布局的形状、步长和偏移量 ArrayLayout::new(self.shape(), self.strides(), self.offset()) } - /// 计算数组中的元素数量。 + /// Calculates the number of elements in the array. /// - /// # 示例 /// ```rust /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout}; /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], BigEndian, 20); @@ -183,27 +144,23 @@ impl ArrayLayout { /// ``` #[inline] pub fn num_elements(&self) -> usize { - // 对形状中的元素进行累乘 self.shape().iter().product() } - /// 计算给定索引处元素的偏移量。 + /// Calculates the offset of element at the given `index`. /// - /// # 示例 /// ```rust /// # use ndarray_layout::{Endian::BigEndian, ArrayLayout}; /// let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], BigEndian, 4); /// assert_eq!(layout.element_offset(22, BigEndian), 88); // 88 <- (22 % 4 * 4) + (22 / 4 % 3 * 16) + (22 / 4 / 3 % 2 * 48) /// ``` pub fn element_offset(&self, index: usize, endian: Endian) -> isize { - /// 正向计算元素偏移量的辅助函数。 fn offset_forwards( mut rem: usize, shape: impl IntoIterator, strides: impl IntoIterator, ) -> isize { let mut ans = 0; - // 遍历形状和步长,计算偏移量 for (d, s) in zip(shape, strides) { ans += s * (rem % d) as isize; rem /= d @@ -211,10 +168,8 @@ impl ArrayLayout { ans } - // 获取形状和步长的迭代器 let shape = self.shape().iter().cloned(); let strides = self.strides().iter().cloned(); - // 加上布局的偏移量,并根据大端序或小端序调用辅助函数 self.offset() + match endian { Endian::BigEndian => offset_forwards(index, shape.rev(), strides.rev()), @@ -222,9 +177,8 @@ impl ArrayLayout { } } - /// 计算数据的字节范围,以确定数组需要访问的内存区域位置。 + /// Calculates the range of data in bytes to determine the location of the memory area that the array needs to access. /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<4>::new(&[2, 3, 4],&[12, -4, 1], 20); @@ -232,12 +186,9 @@ impl ArrayLayout { /// assert_eq!(range, 12..=35); /// ``` pub fn data_range(&self) -> RangeInclusive { - // 获取布局内容 let content = self.content(); - // 初始化起始和结束偏移量为布局的偏移量 let mut start = content.offset(); let mut end = content.offset(); - // 遍历形状和步长,更新起始和结束偏移量 for (&d, s) in zip(content.shape(), content.strides()) { use std::cmp::Ordering::{Equal, Greater, Less}; let i = d as isize - 1; @@ -251,14 +202,10 @@ impl ArrayLayout { } } -// 引入格式化模块 mod fmt; -// 引入变换模块 mod transform; -// 重新导出变换模块中的类型 pub use transform::{BroadcastArg, IndexArg, MergeArg, SliceArg, Split, TileArg}; -// 引入标准库中的相关类型和函数 use std::{ alloc::{Layout, alloc, dealloc}, iter::zip, @@ -267,14 +214,11 @@ use std::{ slice::from_raw_parts, }; -/// 为 ArrayLayout 实现私有方法。 impl ArrayLayout { - /// 内联函数,检查是否有动态分配的指针。 #[inline] fn ptr_allocated(&self) -> Option> { - // 编译时断言 N 大于 0 const { assert!(N > 0)} - // ndim 大于 N 则 content 是 ptr,否则是元组 + // ndim>N则content是ptr,否则是元组 if self.ndim > N { Some(unsafe { self.content.ptr }) } else { @@ -282,7 +226,6 @@ impl ArrayLayout { } } - /// 内联函数,获取布局内容的不可变引用。 #[inline] fn content(&self) -> Content { Content { @@ -293,7 +236,6 @@ impl ArrayLayout { } } - /// 内联函数,获取布局内容的可变引用。 #[inline] fn content_mut(&mut self) -> Content { Content { @@ -304,18 +246,16 @@ impl ArrayLayout { } } - /// 创建一个具有指定维度数量的新 ArrayLayout。 + /// Create a new ArrayLayout with the given dimensions. #[inline] fn with_ndim(ndim: usize) -> Self { Self { ndim, content: if ndim <= N { - // 维度数量不超过 N 时,使用内联存储 Union { _inlined: (0, [0; N], [0; N]), } } else { - // 维度数量超过 N 时,使用动态分配存储 Union { ptr: unsafe { NonNull::new_unchecked(alloc(layout(ndim)).cast()) }, } @@ -324,85 +264,60 @@ impl ArrayLayout { } } -/// 表示布局内容的结构体,根据 MUT 标记决定是否可变。 struct Content { ptr: NonNull, ndim: usize, } -/// 为 Content 实现方法。 impl Content { - /// 内联函数,将内容转换为切片。 #[inline] fn as_slice(&self) -> &[usize] { - // 不安全代码块,从指针创建切片 unsafe { from_raw_parts(self.ptr.as_ptr(), 1 + self.ndim * 2) } } - /// 内联函数,获取偏移量。 #[inline] fn offset(&self) -> isize { - // 不安全代码块,从指针读取偏移量 unsafe { self.ptr.cast().read() } } - /// 内联函数,获取形状信息。 #[inline] fn shape<'a>(&self) -> &'a [usize] { - // 不安全代码块,从指针创建形状切片 unsafe { from_raw_parts(self.ptr.add(1).as_ptr(), self.ndim) } } - /// 内联函数,获取步长信息。 #[inline] fn strides<'a>(&self) -> &'a [isize] { - // 不安全代码块,从指针创建步长切片 unsafe { from_raw_parts(self.ptr.add(1 + self.ndim).cast().as_ptr(), self.ndim) } } } -/// 为可变的 Content 实现方法。 impl Content { - /// 内联函数,设置偏移量。 #[inline] fn set_offset(&mut self, val: isize) { - // 不安全代码块,向指针写入偏移量 unsafe { self.ptr.cast().write(val) } } - /// 内联函数,设置指定索引处的形状值。 #[inline] fn set_shape(&mut self, idx: usize, val: usize) { - // 检查索引是否越界 assert!(idx < self.ndim); - // 不安全代码块,向指针写入形状值 unsafe { self.ptr.add(1 + idx).write(val) } } - /// 内联函数,设置指定索引处的步长值。 #[inline] fn set_stride(&mut self, idx: usize, val: isize) { - // 检查索引是否越界 assert!(idx < self.ndim); - // 不安全代码块,向指针写入步长值 unsafe { self.ptr.add(1 + idx + self.ndim).cast().write(val) } } - /// 内联函数,复制形状信息。 #[inline] fn copy_shape(&mut self, val: &[usize]) { - // 检查形状长度是否匹配 assert!(val.len() == self.ndim); - // 不安全代码块,复制形状信息到指针 unsafe { copy_nonoverlapping(val.as_ptr(), self.ptr.add(1).as_ptr(), self.ndim) } } - /// 内联函数,复制步长信息。 #[inline] fn copy_strides(&mut self, val: &[isize]) { - // 检查步长长度是否匹配 assert!(val.len() == self.ndim); - // 不安全代码块,复制步长信息到指针 unsafe { copy_nonoverlapping( val.as_ptr(), @@ -413,15 +328,12 @@ impl Content { } } -/// 内联函数,根据维度数量计算内存布局。 #[inline] fn layout(ndim: usize) -> Layout { - // 创建一个包含指定数量 usize 元素的内存布局 Layout::array::(1 + ndim * 2).unwrap() } -/// 测试 new 方法是否正确创建布局。 #[test] fn test_new() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); @@ -431,13 +343,11 @@ fn test_new() { assert_eq!(layout.ndim(), 3); } -/// 测试 new 方法在形状和步长长度不同时的行为。 #[test] fn test_new_different_length(){ } -/// 测试 new_contiguous 方法在小端序下是否正确创建布局。 #[test] fn test_new_contiguous_little_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); @@ -446,7 +356,6 @@ fn test_new_contiguous_little_endian() { assert_eq!(layout.strides(), &[4, 8, 24]); } -/// 测试 new_contiguous 方法在大端序下是否正确创建布局。 #[test] fn test_new_contiguous_big_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); @@ -455,14 +364,12 @@ fn test_new_contiguous_big_endian() { assert_eq!(layout.strides(), &[4, 8, 24]); } -/// 测试 new 方法在形状和步长长度不匹配时是否会 panic。 #[test] #[should_panic(expected = "shape and strides must have the same length")] fn test_new_invalid_shape_strides_length() { ArrayLayout::<4>::new(&[2, 3], &[12, -4, 1], 20); } -/// 测试 to_inline_size 方法是否正确转换内联大小。 #[test] fn test_to_inline_size() { let layout = ArrayLayout::<4>::new_contiguous(&[3, 4], Endian::BigEndian, 0); @@ -471,28 +378,24 @@ fn test_to_inline_size() { assert_eq!(size_of_val(&layout), (2 * 2 + 2) * size_of::()); } -/// 测试 num_elements 方法是否正确计算元素数量。 #[test] fn test_num_elements() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 20); assert_eq!(layout.num_elements(), 24); } -/// 测试 element_offset 方法在小端序下是否正确计算元素偏移量。 #[test] fn test_element_offset_little_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); assert_eq!(layout.element_offset(22, Endian::LittleEndian), 88); } -/// 测试 element_offset 方法在大端序下是否正确计算元素偏移量。 #[test] fn test_element_offset_big_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 4); assert_eq!(layout.element_offset(22, Endian::BigEndian), 88); } -/// 测试 data_range 方法在步长为正数时是否正确计算数据范围。 #[test] fn test_data_range_positive_strides() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); @@ -500,7 +403,6 @@ fn test_data_range_positive_strides() { assert_eq!(range, 0..=92); // 0 + 2*4 + 3*8 + 4*24 = 92 } -/// 测试 data_range 方法在步长混合时是否正确计算数据范围。 #[test] fn test_data_range_mixed_strides() { let layout = ArrayLayout::<4>::new(&[2, 3, 4],&[12, -4, 0], 20); @@ -508,7 +410,6 @@ fn test_data_range_mixed_strides() { assert_eq!(range, 12..=32); } -/// 测试 clone 和 eq 方法是否正确工作。 #[test] fn test_clone_and_eq() { let layout1 = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); @@ -516,12 +417,8 @@ fn test_clone_and_eq() { assert!(layout1.eq(&layout2)); } -/// 测试 drop 方法是否正确释放内存。 #[test] fn test_drop() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); - // let ptr = layout.ptr_allocated().unwrap(); drop(layout); - // 丢弃后,内存应该被释放。 - // 由于无法直接测试,依赖 Rust 的安全保证。 -} \ No newline at end of file +} diff --git a/src/transform/broadcast.rs b/src/transform/broadcast.rs index 06870c0..7998e0f 100644 --- a/src/transform/broadcast.rs +++ b/src/transform/broadcast.rs @@ -1,21 +1,17 @@ -// 引入 crate 中的 ArrayLayout 结构体,用于后续的广播变换操作 -use crate::ArrayLayout; +use crate::ArrayLayout; -/// 广播变换参数。该结构体用于存储广播操作所需的信息,包括广播的轴和广播的次数。 +/// 索引变换参数。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct BroadcastArg { - /// 广播的轴,指定在哪个维度上进行广播操作。 + /// 广播的轴。 pub axis: usize, - /// 广播次数,即指定轴上的长度要扩增的倍数。 + /// 广播次数。 pub times: usize, } -/// 为 ArrayLayout 结构体实现广播相关方法 impl ArrayLayout { /// 广播变换将指定的长度为 1 的阶扩增指定的倍数,并将其步长固定为 0。 - /// 广播操作允许在不复制数据的情况下,将一个较小的数组在某个维度上扩展成一个较大的数组。 /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10); @@ -23,56 +19,27 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[0, 2, 1]); /// assert_eq!(layout.offset(), 0); /// ``` - /// - /// # 参数 - /// - `axis`: 要进行广播操作的轴的索引。 - /// - `times`: 在指定轴上进行广播的次数,即该轴的新长度。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其形状和步长已根据广播操作进行更新。 pub fn broadcast(&self, axis: usize, times: usize) -> Self { - // 调用 broadcast_many 方法,传入单个广播参数 self.broadcast_many(&[BroadcastArg { axis, times }]) } /// 一次对多个阶进行广播变换。 - /// 该方法可以同时在多个轴上进行广播操作,提高操作效率。 - /// - /// # 参数 - /// - `args`: 包含多个 `BroadcastArg` 结构体的切片,每个结构体表示一个轴的广播参数。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其形状和步长已根据所有广播参数进行更新。 pub fn broadcast_many(&self, args: &[BroadcastArg]) -> Self { - // 克隆当前的 ArrayLayout 实例,作为初始的新布局 let mut ans = self.clone(); - // 获取新布局内容的可变引用,以便修改形状和步长 let mut content = ans.content_mut(); - // 遍历所有的广播参数 for &BroadcastArg { axis, times } in args { - // 断言要广播的轴的原始长度为 1 或者该轴的步长为 0,确保广播操作的合法性 assert!(content.shape()[axis] == 1 || content.strides()[axis] == 0); - // 设置指定轴的新形状为广播次数 content.set_shape(axis, times); - // 设置指定轴的步长为 0,表示在该轴上广播时不移动数据位置 content.set_stride(axis, 0); } - // 返回更新后的新布局 ans } } -/// 测试 broadcast 方法的正确性 #[test] fn test_broadcast() { - // 创建一个初始的 ArrayLayout 实例 - let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0); - // 对轴 0 进行广播操作,广播次数为 10 - let layout = layout.broadcast(0, 10); - // 断言广播操作后的形状是否符合预期 + let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10); assert_eq!(layout.shape(), &[10, 5, 2]); - // 断言广播操作后的步长是否符合预期 assert_eq!(layout.strides(), &[0, 2, 1]); - // 断言广播操作后的偏移量是否符合预期 assert_eq!(layout.offset(), 0); -} \ No newline at end of file +} diff --git a/src/transform/index.rs b/src/transform/index.rs index 7ac88ea..dba9ed1 100644 --- a/src/transform/index.rs +++ b/src/transform/index.rs @@ -1,23 +1,19 @@ -// 引入 crate 中的 ArrayLayout 结构体,用于后续的索引变换操作 -use crate::ArrayLayout; -// 引入标准库中的 zip 函数,用于同时迭代多个迭代器 +use crate::ArrayLayout; use std::iter::zip; -/// 索引变换参数。该结构体用于存储索引变换所需的信息,包括索引的轴和选择的元素索引。 +/// 索引变换参数。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct IndexArg { - /// 索引的轴,指定在哪个维度上进行索引操作。 + /// 索引的轴。 pub axis: usize, - /// 选择指定轴的第几个元素,索引从 0 开始。 + /// 选择指定轴的第几个元素。 pub index: usize, } -/// 为 ArrayLayout 结构体实现索引相关方法 impl ArrayLayout { /// 索引变换是选择张量指定阶上一项数据的变换,例如指定向量中的一个数、指定矩阵的一行或一列。 /// 索引变换导致张量降阶,确定索引的阶从张量表示移除。 /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).index(1, 2); @@ -25,69 +21,38 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[12, 1]); /// assert_eq!(layout.offset(), 8); /// ``` - /// - /// # 参数 - /// - `axis`: 要进行索引操作的轴的索引。 - /// - `index`: 在指定轴上选择的元素的索引。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其形状、步长和偏移量已根据索引操作进行更新,维度数减少。 pub fn index(&self, axis: usize, index: usize) -> Self { - // 调用 index_many 方法,传入单个索引参数 self.index_many(&[IndexArg { axis, index }]) } /// 一次对多个阶进行索引变换。 - /// - /// # 参数 - /// - `args`: 包含多个 `IndexArg` 结构体的切片,每个结构体表示一个轴的索引参数。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其形状、步长和偏移量已根据所有索引参数进行更新,维度数减少。 pub fn index_many(&self, mut args: &[IndexArg]) -> Self { - // 获取当前布局的内容 let content = self.content(); - // 初始化偏移量为当前布局的偏移量 let mut offset = content.offset(); - // 获取当前布局的形状 let shape = content.shape(); - // 同时迭代当前布局的形状和步长,并获取索引 let iter = zip(shape, content.strides()).enumerate(); - // 定义一个闭包,用于检查索引参数是否有效 let check = |&IndexArg { axis, index }| shape.get(axis).filter(|&&d| index < d).is_some(); - // 检查第一个索引参数是否有效,如果无效则触发断言失败 if let [first, ..] = args { assert!(check(first), "Invalid index arg: {first:?}"); } else { - // 如果没有索引参数,直接克隆当前布局并返回 return self.clone(); } - // 创建一个新的 ArrayLayout 实例,其维度数量为当前布局的维度数量减去索引参数的数量 let mut ans = Self::with_ndim(self.ndim - args.len()); - // 获取新布局内容的可变引用 let mut content = ans.content_mut(); - // 初始化新布局的索引 let mut j = 0; - // 遍历当前布局的形状和步长 for (i, (&d, &s)) in iter { match *args { - // 如果当前轴与索引参数的轴匹配 [IndexArg { axis, index }, ref tail @ ..] if axis == i => { - // 根据索引更新偏移量 offset += index as isize * s; - // 检查下一个索引参数是否有效 if let [first, ..] = tail { assert!(check(first), "Invalid index arg: {first:?}"); - // 确保索引参数的轴按升序排列 assert!(first.axis > axis, "Index args must be in ascending order"); } - // 更新剩余的索引参数 args = tail; } - // 如果当前轴没有对应的索引参数,将形状和步长设置到新布局中 [..] => { content.set_shape(j, d); content.set_stride(j, s); @@ -95,66 +60,37 @@ impl ArrayLayout { } } } - // 设置新布局的偏移量 content.set_offset(offset as _); - // 返回新的布局 ans } } -/// 测试 index 和 index_many 方法的正确性 #[test] fn test() { - // 错误:这里应该是 ArrayLayout::<3>,修正后创建一个三维数组布局 - let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0); - // 对轴 1 进行索引操作,选择第 2 个元素 + let layout = ArrayLayout::<1>::new(&[2, 3, 4], &[12, 4, 1], 0); let layout = layout.index(1, 2); - // 断言索引操作后的形状是否符合预期 assert_eq!(layout.shape(), &[2, 4]); - // 断言索引操作后的步长是否符合预期 assert_eq!(layout.strides(), &[12, 1]); - // 断言索引操作后的偏移量是否符合预期 assert_eq!(layout.offset(), 8); - // 错误:这里应该是 ArrayLayout::<3>,修正后创建一个三维数组布局 - let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, -4, 1], 20); - // 对轴 1 进行索引操作,选择第 2 个元素 + let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); let layout = layout.index(1, 2); - // 断言索引操作后的形状是否符合预期 assert_eq!(layout.shape(), &[2, 4]); - // 断言索引操作后的步长是否符合预期 assert_eq!(layout.strides(), &[12, 1]); - // 断言索引操作后的偏移量是否符合预期 assert_eq!(layout.offset(), 12); - // 错误:这里应该是 ArrayLayout::<3>,修正后创建一个三维数组布局 - let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, -4, 1], 20); - // 调用 index_many 方法,传入空的索引参数切片 + let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); let layout = layout.index_many(&[]); - // 断言索引操作后的形状是否符合预期 assert_eq!(layout.shape(), &[2, 3, 4]); - // 断言索引操作后的步长是否符合预期 assert_eq!(layout.strides(), &[12, -4, 1]); - // 断言索引操作后的偏移量是否符合预期 assert_eq!(layout.offset(), 20); - // 错误:这里应该是 ArrayLayout::<3>,修正后创建一个三维数组布局 - let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, -4, 1], 20); - // 调用 index_many 方法,传入多个索引参数 - let layout = layout.index_many(&[ - IndexArg { - axis: 0, - index: 1, - }, - IndexArg { - axis: 1, - index: 2, - }, - ]); - // 断言索引操作后的形状是否符合预期 + let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); + let layout = layout.index_many(&[IndexArg{axis:0, + index:1}, + IndexArg{axis:1, + index:2}]); assert_eq!(layout.shape(), &[4]); - // 断言索引操作后的步长是否符合预期 assert_eq!(layout.strides(), &[1]); - // 断言索引操作后的偏移量是否符合预期 assert_eq!(layout.offset(), 24); -} \ No newline at end of file +} diff --git a/src/transform/merge.rs b/src/transform/merge.rs index d727a05..50208f1 100644 --- a/src/transform/merge.rs +++ b/src/transform/merge.rs @@ -1,25 +1,21 @@ -// 引入 crate 中的 ArrayLayout 结构体和 Endian 枚举 -use crate::{ArrayLayout, Endian}; -// 引入标准库中的 zip 函数,用于同时迭代多个迭代器 +use crate::{ArrayLayout, Endian}; use std::iter::zip; -/// 合并变换参数。该结构体用于存储合并操作所需的信息,包括合并的起始位置、合并的维度数量以及分块顺序。 +/// 合并变换参数。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct MergeArg { - /// 合并的起点,即从哪个维度开始进行合并操作。 + /// 合并的起点。 pub start: usize, - /// 合并的宽度,即要合并的连续维度的数量。 + /// 合并的宽度 pub len: usize, - /// 分块的顺序。`Some(Endian::BigEndian)` 表示大端合并,`Some(Endian::LittleEndian)` 表示小端合并,`None` 表示任意合并。 + /// 分块的顺序。 pub endian: Option, } -/// 为 ArrayLayout 结构体实现合并相关方法 impl ArrayLayout { /// 合并变换是将多个连续维度划分合并的变换。 /// 大端合并对维度从后到前依次合并。 /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).merge_be(0, 3).unwrap(); @@ -27,16 +23,8 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[1]); /// assert_eq!(layout.offset(), 0); /// ``` - /// - /// # 参数 - /// - `start`: 合并操作的起始维度索引。 - /// - `len`: 要合并的连续维度的数量。 - /// - /// # 返回值 - /// 如果合并成功,返回 `Some(ArrayLayout)`;否则返回 `None`。 #[inline] pub fn merge_be(&self, start: usize, len: usize) -> Option { - // 调用 merge_many 方法,传入大端合并的参数 self.merge_many(&[MergeArg { start, len, @@ -47,7 +35,6 @@ impl ArrayLayout { /// 合并变换是将多个连续维度划分合并的变换。 /// 小端合并对维度从前到后依次合并。 /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0).merge_le(0, 3).unwrap(); @@ -55,16 +42,8 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[1]); /// assert_eq!(layout.offset(), 0); /// ``` - /// - /// # 参数 - /// - `start`: 合并操作的起始维度索引。 - /// - `len`: 要合并的连续维度的数量。 - /// - /// # 返回值 - /// 如果合并成功,返回 `Some(ArrayLayout)`;否则返回 `None`。 #[inline] pub fn merge_le(&self, start: usize, len: usize) -> Option { - // 调用 merge_many 方法,传入小端合并的参数 self.merge_many(&[MergeArg { start, len, @@ -75,7 +54,6 @@ impl ArrayLayout { /// 合并变换是将多个连续维度划分合并的变换。 /// 任意合并只考虑维度的存储连续性。 /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[3, 2, 4], &[4, 12, 1], 0).merge_free(0, 3).unwrap(); @@ -83,16 +61,8 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[1]); /// assert_eq!(layout.offset(), 0); /// ``` - /// - /// # 参数 - /// - `start`: 合并操作的起始维度索引。 - /// - `len`: 要合并的连续维度的数量。 - /// - /// # 返回值 - /// 如果合并成功,返回 `Some(ArrayLayout)`;否则返回 `None`。 #[inline] pub fn merge_free(&self, start: usize, len: usize) -> Option { - // 调用 merge_many 方法,传入任意合并的参数 self.merge_many(&[MergeArg { start, len, @@ -101,205 +71,144 @@ impl ArrayLayout { } /// 一次对多个阶进行合并变换。 - /// - /// # 参数 - /// - `args`: 包含多个 `MergeArg` 结构体的切片,每个结构体表示一组合并操作的参数。 - /// - /// # 返回值 - /// 如果所有合并操作都成功,返回 `Some(ArrayLayout)`;否则返回 `None`。 pub fn merge_many(&self, args: &[MergeArg]) -> Option { - // 获取当前布局的内容 let content = self.content(); - // 获取当前布局的形状 let shape = content.shape(); - // 获取当前布局的步长 let strides = content.strides(); - // 修改 BUG:计算合并后的维度数量,确保每个合并操作至少合并 1 个维度 + // 修改BUG let merged = args.iter().map(|arg| arg.len.max(1)).sum::(); - // 创建一个新的 ArrayLayout 实例,计算新的维度数量 let mut ans = Self::with_ndim(self.ndim + args.len() - merged); - // 获取新布局内容的可变引用 let mut content = ans.content_mut(); - // 将新布局的偏移量设置为当前布局的偏移量 content.set_offset(self.offset()); - // 初始化新布局的索引 let mut i = 0; - // 定义一个闭包,用于设置新布局的形状和步长 let mut push = |d, s| { content.set_shape(i, d); content.set_stride(i, s); i += 1; }; - // 记录上一次合并操作的结束位置 let mut last_end = 0; - // 遍历所有合并操作的参数 for arg in args { - // 解构合并操作的参数 let &MergeArg { start, len, endian } = arg; - // 计算本次合并操作的结束位置 let end = start + len; - // 如果合并的宽度为 0,跳过本次合并操作 if len == 0 { continue; } - // 将上一次合并操作结束位置到本次合并操作起始位置之间的维度添加到新布局中 for j in last_end..arg.start { push(shape[j], strides[j]); } - // 创建一个向量,用于存储要合并的维度的形状和步长对 let mut pairs = Vec::with_capacity(len); - // 遍历要合并的维度,将非 0 和非 1 的维度添加到向量中 for (&d, &s) in zip(&shape[start..end], &strides[start..end]) { match d { - 0 => todo!(), // 处理维度大小为 0 的情况,目前待实现 - 1 => {} // 忽略维度大小为 1 的情况 - _ => pairs.push((d, s)), // 将非 0 和非 1 的维度添加到向量中 + 0 => todo!(), + 1 => {} + _ => pairs.push((d, s)), } } - // 修改 BUG:更新上一次合并操作的结束位置 + // 修改BUG last_end = end; - // 如果向量为空,说明要合并的维度都是 0 或 1,添加一个形状为 1,步长为 0 的维度 if pairs.is_empty() { push(1, 0); continue; } - - // 根据合并的顺序对向量进行排序或反转 match endian { - Some(Endian::BigEndian) => pairs.reverse(), // 大端合并,反转向量 - Some(Endian::LittleEndian) => {} // 小端合并,不做处理 - None => pairs.sort_unstable_by_key(|(_, s)| s.unsigned_abs()), // 任意合并,按步长的绝对值排序 + Some(Endian::BigEndian) => pairs.reverse(), + Some(Endian::LittleEndian) => {} + None => pairs.sort_unstable_by_key(|(_, s)| s.unsigned_abs()), } - // 取出向量的第一个元素 let ((d, s), pairs) = pairs.split_first().unwrap(); - // 初始化合并后的维度大小 let mut d = *d; - // 遍历剩余的元素,检查步长是否符合合并条件 for &(d_, s_) in pairs { if s_ == s * d as isize { - d *= d_ // 如果符合条件,更新合并后的维度大小 + d *= d_ } else { - return None; // 不符合条件,合并失败,返回 None + return None; } } - // 将合并后的维度添加到新布局中 push(d, *s); + } - - // 将最后一次合并操作结束位置到原布局末尾的维度添加到新布局中 for j in last_end..shape.len() { push(shape[j], strides[j]); } - // 返回合并后的新布局 Some(ans) } } -/// 测试 merge_be 方法在合并失败时返回 None #[test] fn test_merge_return_none() { - // 创建一个三维数组布局 - let layout = ArrayLayout::<3>::new(&[16, 4, 2], &[8, 4, 1], 0); - // 尝试从第 0 个维度开始合并 3 个维度 - let merged_layout = layout.merge_be(0, 3); - // 断言合并操作失败,返回 None - assert!(merged_layout.is_none()); + let layout = ArrayLayout::<3>::new(&[16, 4, 2], &[8, 4, 1], 0) + .merge_be(0, 3); + assert!(layout.is_none()); } -/// 测试当要合并的维度对为空时的合并操作 #[test] fn test_merge_pairs_empyt(){ - // 创建一个三维数组布局 - let layout = ArrayLayout::<3>::new(&[1, 1, 1], &[1, 1, 1], 0); - // 尝试从第 0 个维度开始合并 2 个维度 - let merged_layout = layout.merge_be(0, 2).unwrap(); - // 断言合并后的形状符合预期 - assert_eq!(merged_layout.shape(), &[1, 1]); - // 断言合并后的步长符合预期 - assert_eq!(merged_layout.strides(), &[0, 1]); - // 断言合并后的偏移量符合预期 - assert_eq!(merged_layout.offset(), 0); + let layout = ArrayLayout::<3>::new(&[1, 1, 1], &[1, 1, 1], 0) + .merge_be(0, 2) + .unwrap(); + assert_eq!(layout.shape(), &[1, 1]); + assert_eq!(layout.strides(), &[0, 1]); + assert_eq!(layout.offset(), 0); } -/// 测试 merge_be 方法的示例用法 #[test] fn test_merge_be_example() { - // 创建一个三维数组布局 - let layout = ArrayLayout::<3>::new(&[16, 1, 4], &[16, 768, 4], 0); - // 尝试从第 0 个维度开始合并 2 个维度 - let merged_layout = layout.merge_be(0, 2).unwrap(); - // 断言合并后的形状符合预期 - assert_eq!(merged_layout.shape(), &[16, 4]); - // 断言合并后的步长符合预期 - assert_eq!(merged_layout.strides(), &[16, 4]); - // 断言合并后的偏移量符合预期 - assert_eq!(merged_layout.offset(), 0); + let layout = ArrayLayout::<3>::new(&[16, 1, 4], &[16, 768, 4], 0) + .merge_be(0, 2) + .unwrap(); + assert_eq!(layout.shape(), &[16, 4]); + assert_eq!(layout.strides(), &[16, 4]); + assert_eq!(layout.offset(), 0); } -/// 测试 merge_le 方法的示例用法 #[test] fn test_merge_le_example() { - // 创建一个三维数组布局 let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0); - // 从第 0 个维度开始,合并 3 个维度 let merged_layout = layout.merge_le(0, 3).unwrap(); - // 验证合并后的形状、步长和偏移量 assert_eq!(merged_layout.shape(), &[24]); assert_eq!(merged_layout.strides(), &[1]); assert_eq!(merged_layout.offset(), 0); } -/// 测试合并宽度为 0 时的合并操作 #[test] fn test_merge_len_zero(){ - // 创建一个三维数组布局 let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0); - // 从第 0 个维度开始,合并 0 个维度 let merged_layout = layout.merge_le(0, 0).unwrap(); - // 验证合并后的形状、步长和偏移量 assert_eq!(merged_layout.shape(), &[4, 3, 2]); assert_eq!(merged_layout.strides(), &[1, 4, 12]); assert_eq!(merged_layout.offset(), 0); } -/// 测试部分合并操作 #[test] fn test_partial_merge() { - // 创建一个四维数组布局 let layout = ArrayLayout::<4>::new(&[2, 3, 4, 5], &[60, 20, 5, 1], 0); - // 从第 1 个维度开始,合并 2 个维度 let merged_layout = layout.merge_be(1, 2).unwrap(); - // 验证合并后的形状、步长和偏移量 assert_eq!(merged_layout.shape(), &[2, 12, 5]); assert_eq!(merged_layout.strides(), &[60, 5, 1]); assert_eq!(merged_layout.offset(), 0); } -/// 测试 merge_free 方法的示例用法 #[test] fn test_merge_free_example() { - // 创建一个三维数组布局 let layout = ArrayLayout::<3>::new(&[3, 2, 4], &[4, 12, 1], 0); - // 从第 0 个维度开始,合并 3 个维度 let merged_layout = layout.merge_free(0, 3).unwrap(); - // 验证合并后的形状、步长和偏移量 assert_eq!(merged_layout.shape(), &[24]); assert_eq!(merged_layout.strides(), &[1]); assert_eq!(merged_layout.offset(), 0); -} \ No newline at end of file +} diff --git a/src/transform/slice.rs b/src/transform/slice.rs index 45277f5..3108fc2 100644 --- a/src/transform/slice.rs +++ b/src/transform/slice.rs @@ -1,47 +1,31 @@ -// 引入 crate 中的 ArrayLayout 结构体,用于后续的切片操作 -use crate::ArrayLayout; -// 引入标准库中的 zip 函数,用于同时迭代多个迭代器 +use crate::ArrayLayout; use std::iter::zip; -/// 切片变换参数。该结构体用于存储切片操作所需的信息,包括切片的轴、起始位置、步长和长度。 +/// 切片变换参数。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct SliceArg { - /// 切片的轴,指定在哪个维度上进行切片操作。 + /// 切片的轴。 pub axis: usize, - /// 切片的起始位置,即从该轴的哪个位置开始切片。 + /// 切片的起始位置。 pub start: usize, - /// 切片的步长,决定了切片时元素之间的间隔。正数表示正向切片,负数表示反向切片,0 表示在该位置重复元素。 + /// 切片的步长。 pub step: isize, - /// 切片的长度,即切片操作最终选取的元素数量。 + /// 切片的长度。 pub len: usize, } -/// 为 ArrayLayout 结构体实现切片相关方法 impl ArrayLayout { /// 切片变换是裁剪张量指定阶上一组连续数据的变换。 /// - /// 该方法用于在指定的轴上进行切片操作,是 `slice_many` 方法的简化版本,仅对单个轴进行切片。 - /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; - /// // 在轴 1 上,从位置 2 开始,步长为 -1,切片长度为 2 + /// // axis = 1, start = 1, step = -1, len = 2 /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, -1, 2); /// assert_eq!(layout.shape(), &[2, 2, 4]); /// assert_eq!(layout.strides(), &[12, -4, 1]); /// assert_eq!(layout.offset(), 8); /// ``` - /// - /// # 参数 - /// - `axis`: 要进行切片的轴的索引。 - /// - `start`: 切片的起始位置。 - /// - `step`: 切片的步长。 - /// - `len`: 切片的长度。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其形状、步长和偏移量已根据切片操作进行更新。 pub fn slice(&self, axis: usize, start: usize, step: isize, len: usize) -> Self { - // 调用 slice_many 方法,传入单个切片参数 self.slice_many(&[SliceArg { axis, start, @@ -51,76 +35,43 @@ impl ArrayLayout { } /// 一次对多个阶进行切片变换。 - /// - /// 该方法允许同时在多个轴上进行切片操作,根据传入的 `SliceArg` 切片参数更新布局的形状、步长和偏移量。 - /// - /// # 参数 - /// - `args`: 包含多个 `SliceArg` 结构体的切片,每个结构体表示一个轴的切片参数。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其形状、步长和偏移量已根据所有切片参数进行更新。 pub fn slice_many(&self, mut args: &[SliceArg]) -> Self { - // 获取当前布局的内容 let content = self.content(); - // 初始化偏移量为当前布局的偏移量 let mut offset = content.offset(); - // 同时迭代当前布局的形状和步长,并获取索引 let iter = zip(content.shape(), content.strides()).enumerate(); - // 创建一个新的 ArrayLayout 实例,其维度数量与当前布局相同 let mut ans = Self::with_ndim(self.ndim); - // 获取新布局内容的可变引用 let mut content = ans.content_mut(); - // 遍历当前布局的形状和步长 for (i, (&d, &s)) in iter { match args { - // 如果当前轴与切片参数的轴匹配 [arg, tail @ ..] if arg.axis == i => { - // 解构切片参数 let &SliceArg { axis, start, step, len, } = arg; - // 引入标准库中的 Ordering 枚举,用于比较步长的正负 use std::cmp::Ordering::*; - // 根据步长的正负计算实际的切片长度 let len = match step.cmp(&0) { - // 步长为正数的情况 Greater => { - // 断言起始位置小于该轴的维度大小 assert!(start < d); - // 更新偏移量 offset += start as isize * s; - // 计算实际的切片长度 (d - start).div_ceil(step as _).min(len) } - // 步长为 0 的情况 Equal => { - // 断言起始位置小于该轴的维度大小 assert!(start < d); - // 更新偏移量 offset += start as isize * s; - // 切片长度保持不变 len } - // 步长为负数的情况 Less => { - // 确保起始位置不超过该轴的维度大小减 1 let start = start.min(d - 1); - // 更新偏移量 offset += start as isize * s; - // 计算实际的切片长度 (start + 1).div_ceil((-step) as _).min(len) } }; - // 设置新布局指定轴的形状为实际的切片长度 content.set_shape(i, len); - // 设置新布局指定轴的步长为原步长乘以切片步长 content.set_stride(i, s * step); - // 检查下一个切片参数的轴是否合法 if let [next, ..] = tail { assert!( axis < next.axis && next.axis < self.ndim, @@ -130,45 +81,36 @@ impl ArrayLayout { self.ndim, ); } - // 更新剩余的切片参数 args = tail; } - // 如果当前轴没有对应的切片参数,保持形状和步长不变 [..] => { content.set_shape(i, d); content.set_stride(i, s); } } } - // 设置新布局的偏移量 content.set_offset(offset as _); - // 返回新的布局 ans } } -/// 测试 slice 和 slice_many 方法的正确性 #[test] fn test_slice() { - // 测试在轴 1 上,从位置 2 开始,步长为 -1,切片长度为 2 的情况 let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, -1, 2); assert_eq!(layout.shape(), &[2, 2, 4]); assert_eq!(layout.strides(), &[12, -4, 1]); assert_eq!(layout.offset(), 8); - // 测试在轴 1 上,从位置 2 开始,步长为 0,切片长度为 2 的情况 let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 2, 0, 2); assert_eq!(layout.shape(), &[2, 2, 4]); assert_eq!(layout.strides(), &[12, 0, 1]); assert_eq!(layout.offset(), 8); - // 测试在轴 1 上,从位置 0 开始,步长为 1,切片长度为 2 的情况 let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice(1, 0, 1, 2); assert_eq!(layout.shape(), &[2, 2, 4]); assert_eq!(layout.strides(), &[12, 4, 1]); assert_eq!(layout.offset(), 0); - // 测试同时在轴 1 和轴 2 上进行切片的情况 let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice_many(&[SliceArg{ axis: 1, start: 0, @@ -183,4 +125,4 @@ fn test_slice() { assert_eq!(layout.shape(), &[2, 2, 4]); assert_eq!(layout.strides(), &[12, 4, 1]); assert_eq!(layout.offset(), 0); -} \ No newline at end of file +} diff --git a/src/transform/split.rs b/src/transform/split.rs index e06124f..c08722a 100644 --- a/src/transform/split.rs +++ b/src/transform/split.rs @@ -1,30 +1,16 @@ -// 引入 crate 中的 ArrayLayout 结构体 -use crate::ArrayLayout; +use crate::ArrayLayout; -/// 切分变换参数。该结构体用于存储切分操作所需的信息,以便将一个 `ArrayLayout` 沿指定维度切分成多个部分。 -/// -/// - `src`: 指向要进行切分操作的原始 `ArrayLayout` 的引用。 -/// - `axis`: 指定要进行切分的维度的索引。 -/// - `start`: 当前切分部分在指定维度上的起始位置。 -/// - `parts`: 一个切片,包含每个切分部分在指定维度上的大小。 +/// 切分变换参数。 pub struct Split<'a, const N: usize> { - // 要进行切分的原始 ArrayLayout 的引用 src: &'a ArrayLayout, - // 进行切分的维度 axis: usize, - // 当前切分的起始位置 start: usize, - // 每个切分部分的大小 parts: &'a [usize], } -/// 为 ArrayLayout 结构体实现切分相关方法 impl ArrayLayout { - /// 切分变换将单个张量沿某个维度切分成多个张量,因此可以支持不均匀的切分。 + /// 切分变换讲单个张量沿某个维度切分成多个张量,因此可以支持不均匀的切分。 /// - /// 该方法返回一个 `Split` 迭代器,用于逐个获取切分后的 `ArrayLayout` 实例。 - /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0); @@ -40,18 +26,9 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[12, 4, 1]); /// assert_eq!(layout.offset(), 1); /// ``` - /// - /// # 参数 - /// - `axis`: 要进行切分的维度的索引。 - /// - `parts`: 一个切片,包含每个切分部分在指定维度上的大小。所有部分大小之和必须等于指定维度的原始大小。 - /// - /// # 返回值 - /// 返回一个 `Split` 迭代器,用于遍历切分后的 `ArrayLayout` 实例。 #[inline] pub fn split<'a>(&'a self, axis: usize, parts: &'a [usize]) -> Split<'a, N> { - // 断言指定维度的原始大小等于所有切分部分大小之和 assert_eq!(self.shape()[axis], parts.iter().sum()); - // 创建并返回 Split 结构体实例 Split { src: self, axis, @@ -61,56 +38,30 @@ impl ArrayLayout { } } -/// 为 Split 结构体实现 Iterator trait,使其可以作为迭代器使用 impl Iterator for Split<'_, N> { - // 迭代器返回的元素类型为 ArrayLayout type Item = ArrayLayout; - /// 获取迭代器的下一个元素。 - /// - /// 该方法会从 `parts` 中取出第一个元素作为当前切分部分的大小, - /// 然后根据当前的起始位置和切分大小生成一个新的 `ArrayLayout` 实例。 - /// - /// # 返回值 - /// - 如果 `parts` 不为空,返回 `Some(ArrayLayout)`,表示下一个切分后的 `ArrayLayout` 实例。 - /// - 如果 `parts` 为空,返回 `None`,表示迭代结束。 #[inline] fn next(&mut self) -> Option { - // 尝试从 parts 中取出第一个元素和剩余部分 self.parts.split_first().map(|(&head, tail)| { - // 记录当前的起始位置 let start = self.start; - // 更新起始位置为当前起始位置加上当前切分部分的大小 self.start += head; - // 更新 parts 为剩余部分 self.parts = tail; - // 调用 src 的 slice 方法生成切分后的 ArrayLayout 实例 self.src.slice(self.axis, start, 1, head) }) } } -/// 测试 split 方法的正确性 #[test] fn test_split() { - // 创建一个 ArrayLayout 实例 let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0); - // 调用 split 方法进行切分,得到一个 Split 迭代器 let mut splits = layout.split(2, &[1, 3]); - // 获取第一个切分后的 ArrayLayout 实例 let layout = splits.next().unwrap(); - // 断言第一个切分后的形状是否符合预期 assert_eq!(layout.shape(), &[2, 3, 1]); - // 断言第一个切分后的步长是否符合预期 assert_eq!(layout.strides(), &[12, 4, 1]); - // 断言第一个切分后的偏移量是否符合预期 assert_eq!(layout.offset(), 0); - // 获取第二个切分后的 ArrayLayout 实例 let layout = splits.next().unwrap(); - // 断言第二个切分后的形状是否符合预期 assert_eq!(layout.shape(), &[2, 3, 3]); - // 断言第二个切分后的步长是否符合预期 assert_eq!(layout.strides(), &[12, 4, 1]); - // 断言第二个切分后的偏移量是否符合预期 assert_eq!(layout.offset(), 1); } \ No newline at end of file diff --git a/src/transform/tile.rs b/src/transform/tile.rs index 2f6e5b3..18f7bf2 100644 --- a/src/transform/tile.rs +++ b/src/transform/tile.rs @@ -1,24 +1,21 @@ -// 引入 crate 中的 ArrayLayout 结构体和 Endian 枚举 use crate::{ArrayLayout, Endian}; -// 引入标准库中的 zip 函数,用于同时迭代多个迭代器 use std::iter::zip; -/// 分块变换参数。该结构体用于存储分块变换所需的参数。 +/// 分块变换参数。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct TileArg<'a> { - /// 分块操作要应用的轴。轴的索引从 0 开始。 + /// 分块的轴。 pub axis: usize, - /// 分块的顺序。大端序或小端序决定了分块后维度在形状中的排列顺序。 + /// 分块的顺序。 pub endian: Endian, - /// 分块的大小数组。每个元素表示对应分块的大小。 + /// 分块的大小。 pub tiles: &'a [usize], } -/// 为 ArrayLayout 结构体实现分块变换相关方法 impl ArrayLayout { - /// 大端分块变换。将单个维度划分为多个分块,大端分块使得分块后范围更大的维度在形状中更靠前的位置。 + /// 分块变换是将单个维度划分为多个分块的变换。 + /// 大端分块使得分块后范围更大的维度在形状中更靠前的位置。 /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]); @@ -26,16 +23,8 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[18, 6, 3, 1]); /// assert_eq!(layout.offset(), 0); /// ``` - /// - /// # 参数 - /// - `axis`: 要进行分块的轴的索引。 - /// - `tiles`: 分块大小的数组。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其维度已根据大端分块规则进行变换。 #[inline] pub fn tile_be(&self, axis: usize, tiles: &[usize]) -> Self { - // 调用 tile_many 方法,传入大端序的分块参数 self.tile_many(&[TileArg { axis, endian: Endian::BigEndian, @@ -43,9 +32,9 @@ impl ArrayLayout { }]) } - /// 小端分块变换。将单个维度划分为多个分块,小端分块使得分块后范围更小的维度在形状中更靠前的位置。 + /// 分块变换是将单个维度划分为多个分块的变换。 + /// 小端分块使得分块后范围更小的维度在形状中更靠前的位置。 /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]); @@ -53,16 +42,8 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[18, 6, 1, 2]); /// assert_eq!(layout.offset(), 0); /// ``` - /// - /// # 参数 - /// - `axis`: 要进行分块的轴的索引。 - /// - `tiles`: 分块大小的数组。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其维度已根据小端分块规则进行变换。 #[inline] pub fn tile_le(&self, axis: usize, tiles: &[usize]) -> Self { - // 调用 tile_many 方法,传入小端序的分块参数 self.tile_many(&[TileArg { axis, endian: Endian::LittleEndian, @@ -71,68 +52,43 @@ impl ArrayLayout { } /// 一次对多个阶进行分块变换。 - /// - /// # 参数 - /// - `args`: 包含多个 `TileArg` 结构体的切片,每个结构体表示一个轴的分块参数。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其维度已根据所有分块参数进行变换。 pub fn tile_many(&self, mut args: &[TileArg]) -> Self { - // 获取当前布局的内容 let content = self.content(); - // 获取当前布局的形状 let shape = content.shape(); - // 同时迭代形状和步长,并获取索引 let iter = zip(shape, content.strides()).enumerate(); - // 定义一个闭包,用于检查分块参数是否有效 let check = |&TileArg { axis, tiles, .. }| { - // 检查指定轴的维度大小是否等于分块大小的乘积 shape .get(axis) .filter(|&&d| d == tiles.iter().product()) .is_some() }; - // 初始化新布局的维度数量和上一个处理的轴的索引 let (mut new, mut last_axis) = match args { [first, ..] => { - // 检查第一个分块参数是否有效 assert!(check(first)); - // 新布局的维度数量初始化为第一个分块参数的分块数量 (first.tiles.len(), first.axis) } - [..] => return self.clone(), // 如果没有分块参数,直接克隆当前布局 + [..] => return self.clone(), }; - // 遍历剩余的分块参数 for arg in &args[1..] { - // 检查分块参数是否有效 assert!(check(arg)); - // 确保当前轴的索引大于上一个轴的索引 assert!(arg.axis > last_axis); - // 累加新布局的维度数量 new += arg.tiles.len(); - // 更新上一个处理的轴的索引 last_axis = arg.axis; } - // 创建一个新的 ArrayLayout 实例,其维度数量为当前布局的维度数量加上新维度数量减去分块参数的数量 let mut ans = Self::with_ndim(self.ndim + new - args.len()); - // 获取新布局内容的可变引用 let mut content = ans.content_mut(); - // 将新布局的偏移量设置为当前布局的偏移量 content.set_offset(self.offset()); - // 初始化新布局的索引 let mut j = 0; - // 定义一个闭包,用于设置新布局的形状和步长 let mut push = |t, s| { content.set_shape(j, t); content.set_stride(j, s); j += 1; }; - // 遍历当前布局的形状和步长 for (i, (&d, &s)) in iter { match *args { [ @@ -143,10 +99,8 @@ impl ArrayLayout { }, ref tail @ .., ] if axis == i => { - // 如果当前轴与分块参数的轴匹配 match endian { Endian::BigEndian => { - // 大端分块规则 // tile : [a, b , c] // strides: [s * c * b, s * c, s] let mut s = s * d as isize; @@ -156,7 +110,6 @@ impl ArrayLayout { } } Endian::LittleEndian => { - // 小端分块规则 // tile : [a, b , c ] // strides: [s, s * a, s * a * b] let mut s = s; @@ -166,18 +119,15 @@ impl ArrayLayout { } } } - // 处理完当前分块参数后,更新剩余的分块参数 args = tail; } - [..] => push(d, s), // 如果当前轴没有分块参数,直接设置形状和步长 + [..] => push(d, s), } } - // 返回新的布局 ans } } -/// 测试大端分块变换的正确性 #[test] fn test_tile_be() { let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_be(2, &[2, 3]); @@ -186,7 +136,6 @@ fn test_tile_be() { assert_eq!(layout.offset(), 0); } -/// 测试小端分块变换的正确性 #[test] fn test_tile_le() { let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_le(2, &[2, 3]); @@ -195,7 +144,6 @@ fn test_tile_le() { assert_eq!(layout.offset(), 0); } -/// 测试无分块参数时的行为 #[test] fn test_empty_tile() { let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_many(&[]); @@ -204,7 +152,6 @@ fn test_empty_tile() { assert_eq!(layout.offset(), 0); } -/// 测试多个分块参数时的行为 #[test] fn test_multiple_tiles(){ let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_many(&[ diff --git a/src/transform/transpose.rs b/src/transform/transpose.rs index dee97fc..a54ceae 100644 --- a/src/transform/transpose.rs +++ b/src/transform/transpose.rs @@ -1,16 +1,9 @@ -// 引入 crate 中的 ArrayLayout 结构体 -use crate::ArrayLayout; -// 引入标准库中的 BTreeSet 用于存储唯一且有序的元素,以及 zip 函数用于迭代多个迭代器 +use crate::ArrayLayout; use std::{collections::BTreeSet, iter::zip}; -/// 为 ArrayLayout 结构体实现方法 impl ArrayLayout { /// 转置变换允许调换张量的维度顺序,但不改变元素的存储顺序。 /// - /// 该方法接收一个排列数组 `perm`,根据该数组重新排列原布局的维度。 - /// 未在 `perm` 中指定的维度将保持不变。 - /// - /// # 示例 /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 0]); @@ -18,81 +11,46 @@ impl ArrayLayout { /// assert_eq!(layout.strides(), &[4, 12, 1]); /// assert_eq!(layout.offset(), 0); /// ``` - /// - /// # 参数 - /// - `perm`: 一个切片,包含要交换的维度的索引。索引必须唯一。 - /// - /// # 返回值 - /// 返回一个新的 `ArrayLayout` 实例,其维度顺序已根据 `perm` 进行转置。 pub fn transpose(&self, perm: &[usize]) -> Self { - // 将 perm 中的元素收集到 BTreeSet 中,确保元素唯一且有序 let perm_ = perm.iter().collect::>(); - // 断言 perm 中的元素都是唯一的 assert_eq!(perm_.len(), perm.len()); - // 获取当前布局的内容 let content = self.content(); - // 获取当前布局的形状 let shape = content.shape(); - // 获取当前布局的步长 let strides = content.strides(); - // 创建一个新的 ArrayLayout 实例,其维度数量与当前布局相同 let mut ans = Self::with_ndim(self.ndim); - // 获取新布局内容的可变引用 let mut content = ans.content_mut(); - // 将新布局的偏移量设置为当前布局的偏移量 content.set_offset(self.offset()); - - // 定义一个闭包,用于设置新布局指定索引处的形状和步长 let mut set = |i, j| { - // 设置新布局索引 i 处的形状为原布局索引 j 处的形状 content.set_shape(i, shape[j]); - // 设置新布局索引 i 处的步长为原布局索引 j 处的步长 content.set_stride(i, strides[j]); }; - // 记录上一次处理的维度索引,初始化为 0 let mut last = 0; - // 同时遍历有序的 perm_ 和原始的 perm for (&i, &j) in zip(perm_, perm) { - // 处理 last 到 i 之间未在 perm 中指定的维度,保持这些维度不变 for i in last..i { set(i, i); } - // 根据 perm 中的映射关系设置新布局的形状和步长 set(i, j); - // 更新 last 为当前处理的维度索引加 1 last = i + 1; } - // 处理 perm 中未涉及的剩余维度,保持这些维度不变 for i in last..shape.len() { set(i, i); } - - // 返回转置后的新布局 ans } } -/// 测试 transpose 方法的正确性 #[test] fn test_transpose() { - // 创建一个初始布局并进行转置操作 let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[1, 0]); - // 断言转置后的形状是否符合预期 assert_eq!(layout.shape(), &[3, 2, 4]); - // 断言转置后的步长是否符合预期 assert_eq!(layout.strides(), &[4, 12, 1]); - // 断言转置后的偏移量是否符合预期 assert_eq!(layout.offset(), 0); - // 创建另一个初始布局并进行转置操作 let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).transpose(&[2, 0]); - // 断言转置后的形状是否符合预期 assert_eq!(layout.shape(), &[4, 3, 2]); - // 断言转置后的步长是否符合预期 assert_eq!(layout.strides(), &[1, 4, 12]); - // 断言转置后的偏移量是否符合预期 assert_eq!(layout.offset(), 0); } \ No newline at end of file From 0b0cda41ce05f0a340e1cf87fe4f55005cd3e94f Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Thu, 24 Apr 2025 17:30:57 +0800 Subject: [PATCH 04/21] =?UTF-8?q?style:=20=E6=A0=BC=E5=BC=8F=E5=8C=96?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transform/index.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transform/index.rs b/src/transform/index.rs index dba9ed1..5b8f1fc 100644 --- a/src/transform/index.rs +++ b/src/transform/index.rs @@ -86,10 +86,10 @@ fn test() { assert_eq!(layout.offset(), 20); let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); - let layout = layout.index_many(&[IndexArg{axis:0, - index:1}, - IndexArg{axis:1, - index:2}]); + let layout = layout.index_many(&[ + IndexArg { axis: 0, index: 1 }, + IndexArg { axis: 1, index: 2 }, + ]); assert_eq!(layout.shape(), &[4]); assert_eq!(layout.strides(), &[1]); assert_eq!(layout.offset(), 24); From c27509910486424e7f54017315f40815dca2ee3a Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Thu, 24 Apr 2025 17:31:25 +0800 Subject: [PATCH 05/21] =?UTF-8?q?style:=20=E4=BD=BF=E7=94=A8vscode?= =?UTF-8?q?=E8=87=AA=E5=8A=A8=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/fmt.rs | 14 +++++++------- src/lib.rs | 21 +++++++++++---------- src/transform/merge.rs | 8 +++----- src/transform/slice.rs | 25 ++++++++++++++----------- src/transform/split.rs | 2 +- src/transform/tile.rs | 6 +++--- src/transform/transpose.rs | 2 +- 7 files changed, 40 insertions(+), 38 deletions(-) diff --git a/src/fmt.rs b/src/fmt.rs index 051bba1..63381f0 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -13,11 +13,13 @@ impl ArrayLayout { ptr: *const T, ) -> fmt::Result { match self.ndim() { - 0 => {write!(f, "array<> = [{}]", unsafe { + 0 => { + write!(f, "array<> = [{}]", unsafe { ptr.byte_offset(self.offset()).read_unaligned() }) } - 1 => {let &[n] = self.shape() else { unreachable!() }; + 1 => { + let &[n] = self.shape() else { unreachable!() }; let &[s] = self.strides() else { unreachable!() }; writeln!(f, "array<{n}>[")?; @@ -30,7 +32,8 @@ impl ArrayLayout { writeln!(f, "]")?; Ok(()) } - _ => {let mut title = "array<".to_string(); + _ => { + let mut title = "array<".to_string(); for d in self.shape() { title.push_str(&format!("{d}x")) } @@ -85,8 +88,6 @@ impl ArrayLayout { } } - - #[test] fn test() { const DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; @@ -112,7 +113,6 @@ fn test() { let tensor = Tensor(tensor.0.tile_be(0, &[2, 3]).tile_be(2, &[5, 2])); println!("{}", tensor); - let tensor=Tensor(ArrayLayout::<4>::with_ndim(0)); + let tensor = Tensor(ArrayLayout::<4>::with_ndim(0)); println!("{}", tensor); - } diff --git a/src/lib.rs b/src/lib.rs index 7ffb82e..892f640 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,7 +60,11 @@ impl ArrayLayout { /// ``` pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self { // check - assert_eq!(shape.len(),strides.len(),"shape and strides must have the same length"); + assert_eq!( + shape.len(), + strides.len(), + "shape and strides must have the same length" + ); let mut ans = Self::with_ndim(shape.len()); let mut content = ans.content_mut(); @@ -178,7 +182,7 @@ impl ArrayLayout { } /// Calculates the range of data in bytes to determine the location of the memory area that the array needs to access. - /// + /// /// ```rust /// # use ndarray_layout::ArrayLayout; /// let layout = ArrayLayout::<4>::new(&[2, 3, 4],&[12, -4, 1], 20); @@ -193,7 +197,7 @@ impl ArrayLayout { use std::cmp::Ordering::{Equal, Greater, Less}; let i = d as isize - 1; match s.cmp(&0) { - Equal => {}, + Equal => {} Less => start += s * i, Greater => end += s * i, } @@ -217,7 +221,7 @@ use std::{ impl ArrayLayout { #[inline] fn ptr_allocated(&self) -> Option> { - const { assert!(N > 0)} + const { assert!(N > 0) } // ndim>N则content是ptr,否则是元组 if self.ndim > N { Some(unsafe { self.content.ptr }) @@ -275,7 +279,7 @@ impl Content { unsafe { from_raw_parts(self.ptr.as_ptr(), 1 + self.ndim * 2) } } - #[inline] + #[inline] fn offset(&self) -> isize { unsafe { self.ptr.cast().read() } } @@ -333,7 +337,6 @@ fn layout(ndim: usize) -> Layout { Layout::array::(1 + ndim * 2).unwrap() } - #[test] fn test_new() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); @@ -344,9 +347,7 @@ fn test_new() { } #[test] -fn test_new_different_length(){ - -} +fn test_new_different_length() {} #[test] fn test_new_contiguous_little_endian() { @@ -405,7 +406,7 @@ fn test_data_range_positive_strides() { #[test] fn test_data_range_mixed_strides() { - let layout = ArrayLayout::<4>::new(&[2, 3, 4],&[12, -4, 0], 20); + let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 0], 20); let range = layout.data_range(); assert_eq!(range, 12..=32); } diff --git a/src/transform/merge.rs b/src/transform/merge.rs index 50208f1..93f280f 100644 --- a/src/transform/merge.rs +++ b/src/transform/merge.rs @@ -136,7 +136,6 @@ impl ArrayLayout { } push(d, *s); - } for j in last_end..shape.len() { push(shape[j], strides[j]); @@ -148,13 +147,12 @@ impl ArrayLayout { #[test] fn test_merge_return_none() { - let layout = ArrayLayout::<3>::new(&[16, 4, 2], &[8, 4, 1], 0) - .merge_be(0, 3); + let layout = ArrayLayout::<3>::new(&[16, 4, 2], &[8, 4, 1], 0).merge_be(0, 3); assert!(layout.is_none()); } #[test] -fn test_merge_pairs_empyt(){ +fn test_merge_pairs_empyt() { let layout = ArrayLayout::<3>::new(&[1, 1, 1], &[1, 1, 1], 0) .merge_be(0, 2) .unwrap(); @@ -184,7 +182,7 @@ fn test_merge_le_example() { } #[test] -fn test_merge_len_zero(){ +fn test_merge_len_zero() { let layout = ArrayLayout::<3>::new(&[4, 3, 2], &[1, 4, 12], 0); let merged_layout = layout.merge_le(0, 0).unwrap(); diff --git a/src/transform/slice.rs b/src/transform/slice.rs index 3108fc2..9a3695e 100644 --- a/src/transform/slice.rs +++ b/src/transform/slice.rs @@ -111,17 +111,20 @@ fn test_slice() { assert_eq!(layout.strides(), &[12, 4, 1]); assert_eq!(layout.offset(), 0); - let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice_many(&[SliceArg{ - axis: 1, - start: 0, - step: 1, - len: 2, - },SliceArg{ - axis: 2, - start: 0, - step: 1, - len: 4, - }]); + let layout = ArrayLayout::<3>::new(&[2, 3, 4], &[12, 4, 1], 0).slice_many(&[ + SliceArg { + axis: 1, + start: 0, + step: 1, + len: 2, + }, + SliceArg { + axis: 2, + start: 0, + step: 1, + len: 4, + }, + ]); assert_eq!(layout.shape(), &[2, 2, 4]); assert_eq!(layout.strides(), &[12, 4, 1]); assert_eq!(layout.offset(), 0); diff --git a/src/transform/split.rs b/src/transform/split.rs index c08722a..9e28219 100644 --- a/src/transform/split.rs +++ b/src/transform/split.rs @@ -64,4 +64,4 @@ fn test_split() { assert_eq!(layout.shape(), &[2, 3, 3]); assert_eq!(layout.strides(), &[12, 4, 1]); assert_eq!(layout.offset(), 1); -} \ No newline at end of file +} diff --git a/src/transform/tile.rs b/src/transform/tile.rs index 18f7bf2..d0b655b 100644 --- a/src/transform/tile.rs +++ b/src/transform/tile.rs @@ -153,7 +153,7 @@ fn test_empty_tile() { } #[test] -fn test_multiple_tiles(){ +fn test_multiple_tiles() { let layout = ArrayLayout::<3>::new(&[2, 3, 6], &[18, 6, 1], 0).tile_many(&[ TileArg { axis: 0, @@ -164,9 +164,9 @@ fn test_multiple_tiles(){ axis: 2, endian: Endian::BigEndian, tiles: &[2, 3], - } + }, ]); assert_eq!(layout.shape(), &[2, 1, 3, 2, 3]); assert_eq!(layout.strides(), &[18, 18, 6, 3, 1]); assert_eq!(layout.offset(), 0); -} \ No newline at end of file +} diff --git a/src/transform/transpose.rs b/src/transform/transpose.rs index a54ceae..30009d6 100644 --- a/src/transform/transpose.rs +++ b/src/transform/transpose.rs @@ -53,4 +53,4 @@ fn test_transpose() { assert_eq!(layout.shape(), &[4, 3, 2]); assert_eq!(layout.strides(), &[1, 4, 12]); assert_eq!(layout.offset(), 0); -} \ No newline at end of file +} From f6cea4fd99d6234d687b2684574b25d88e0581ed Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Thu, 24 Apr 2025 17:57:21 +0800 Subject: [PATCH 06/21] =?UTF-8?q?style:=20=E6=A0=BC=E5=BC=8F=E5=8C=96Rust?= =?UTF-8?q?=20Document?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/fmt.rs | 3 ++- src/lib.rs | 20 ++++++++++++++++++-- src/transform/broadcast.rs | 3 +++ src/transform/index.rs | 3 +++ 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/fmt.rs b/src/fmt.rs index 63381f0..ab3123d 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -88,7 +88,8 @@ impl ArrayLayout { } } -#[test] +# [test] + fn test() { const DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; diff --git a/src/lib.rs b/src/lib.rs index 892f640..50dfd61 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#![doc = include_str!("../README.md")] +# ![doc = include_str!("../README.md")] #![deny(warnings, missing_docs)] /// An array layout allow N dimensions inlined. @@ -40,7 +40,9 @@ impl Drop for ArrayLayout { } /// 元信息存储顺序。 + #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] + pub enum Endian { /// 大端序,范围更大的维度在元信息中更靠前的位置。 BigEndian, @@ -57,7 +59,7 @@ impl ArrayLayout { /// assert_eq!(layout.offset(), 20); /// assert_eq!(layout.shape(), &[2, 3, 4]); /// assert_eq!(layout.strides(), &[12, -4, 1]); - /// ``` + ///``` pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self { // check assert_eq!( @@ -333,11 +335,13 @@ impl Content { } #[inline] + fn layout(ndim: usize) -> Layout { Layout::array::(1 + ndim * 2).unwrap() } #[test] + fn test_new() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); assert_eq!(layout.offset(), 20); @@ -347,9 +351,11 @@ fn test_new() { } #[test] + fn test_new_different_length() {} #[test] + fn test_new_contiguous_little_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); assert_eq!(layout.offset(), 0); @@ -358,6 +364,7 @@ fn test_new_contiguous_little_endian() { } #[test] + fn test_new_contiguous_big_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); assert_eq!(layout.offset(), 0); @@ -367,11 +374,13 @@ fn test_new_contiguous_big_endian() { #[test] #[should_panic(expected = "shape and strides must have the same length")] + fn test_new_invalid_shape_strides_length() { ArrayLayout::<4>::new(&[2, 3], &[12, -4, 1], 20); } #[test] + fn test_to_inline_size() { let layout = ArrayLayout::<4>::new_contiguous(&[3, 4], Endian::BigEndian, 0); assert_eq!(size_of_val(&layout), (2 * 4 + 2) * size_of::()); @@ -380,24 +389,28 @@ fn test_to_inline_size() { } #[test] + fn test_num_elements() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 20); assert_eq!(layout.num_elements(), 24); } #[test] + fn test_element_offset_little_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); assert_eq!(layout.element_offset(22, Endian::LittleEndian), 88); } #[test] + fn test_element_offset_big_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 4); assert_eq!(layout.element_offset(22, Endian::BigEndian), 88); } #[test] + fn test_data_range_positive_strides() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); let range = layout.data_range(); @@ -405,6 +418,7 @@ fn test_data_range_positive_strides() { } #[test] + fn test_data_range_mixed_strides() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 0], 20); let range = layout.data_range(); @@ -412,6 +426,7 @@ fn test_data_range_mixed_strides() { } #[test] + fn test_clone_and_eq() { let layout1 = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); let layout2 = layout1.clone(); @@ -419,6 +434,7 @@ fn test_clone_and_eq() { } #[test] + fn test_drop() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); drop(layout); diff --git a/src/transform/broadcast.rs b/src/transform/broadcast.rs index 7998e0f..175a2bd 100644 --- a/src/transform/broadcast.rs +++ b/src/transform/broadcast.rs @@ -1,7 +1,9 @@ use crate::ArrayLayout; /// 索引变换参数。 + #[derive(Clone, PartialEq, Eq, Debug)] + pub struct BroadcastArg { /// 广播的轴。 pub axis: usize, @@ -37,6 +39,7 @@ impl ArrayLayout { } #[test] + fn test_broadcast() { let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10); assert_eq!(layout.shape(), &[10, 5, 2]); diff --git a/src/transform/index.rs b/src/transform/index.rs index 5b8f1fc..3b040aa 100644 --- a/src/transform/index.rs +++ b/src/transform/index.rs @@ -2,7 +2,9 @@ use std::iter::zip; /// 索引变换参数。 + #[derive(Clone, PartialEq, Eq, Debug)] + pub struct IndexArg { /// 索引的轴。 pub axis: usize, @@ -66,6 +68,7 @@ impl ArrayLayout { } #[test] + fn test() { let layout = ArrayLayout::<1>::new(&[2, 3, 4], &[12, 4, 1], 0); let layout = layout.index(1, 2); From 09e323391694cc6d27ff8561600994d3a93c42f8 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Mon, 28 Apr 2025 11:35:16 +0800 Subject: [PATCH 07/21] =?UTF-8?q?style:=20=E6=9B=B4=E6=94=B9=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/fmt.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fmt.rs b/src/fmt.rs index ab3123d..e75e89b 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -88,7 +88,7 @@ impl ArrayLayout { } } -# [test] +#[test] fn test() { const DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; From 0afe8076a981b3f17c641f5f371354deb377a4e5 Mon Sep 17 00:00:00 2001 From: Shenghu Su <63157630+Simon25772@users.noreply.github.com> Date: Mon, 28 Apr 2025 11:31:21 +0800 Subject: [PATCH 08/21] =?UTF-8?q?ci:=20=E6=9B=B4=E6=96=B0action=E8=84=9A?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/build.yml | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index dd1430a..149e5ac 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,4 +1,4 @@ -# This workflow uses actions that are not certified by GitHub. +# This workflow uses actions that are not certified by GitHub. # They are provided by a third-party and are governed by # separate terms of service, privacy policy, and support # documentation. @@ -51,3 +51,18 @@ jobs: with: sarif_file: rust-clippy-results.sarif wait-for-processing: true + + - name: Install required cargo + run: cargo install cargo-tarpaulin + + - name: Generate code coverage + run: + cargo tarpaulin + --all-features + --workspace --timeout 120 --out xml + + - name: Upload to codecov.io + uses: codecov/codecov-action@v5 + with: + token: ${{secrets.CODECOV_TOKEN}} + fail_ci_if_error: true From a0854ff1f2aaef124369e8897cd78ad5b229476c Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Mon, 28 Apr 2025 12:05:42 +0800 Subject: [PATCH 09/21] =?UTF-8?q?docs(README):=20=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/README.md b/README.md index f3ddb4d..4cf6496 100644 --- a/README.md +++ b/README.md @@ -51,16 +51,4 @@ let broadcasted_layout = layout.broadcast(0, 4); assert_eq!(broadcasted_layout.shape(), &[4, 2, 3]); assert_eq!(broadcasted_layout.strides(), &[0, 4, 1]); assert_eq!(broadcasted_layout.offset(), 0); - -// 一次对多个阶进行广播变换 -let args = [ - BroadcastArg { axis: 0, times: 4 }, - BroadcastArg { axis: 1, times: 3 } -]; -let multi_broadcasted_layout = layout.broadcast_many(&args); - -// 验证多次广播变换后的形状和步长 -assert_eq!(multi_broadcasted_layout.shape(), &[4, 3, 3]); -assert_eq!(multi_broadcasted_layout.strides(), &[0, 0, 1]); -assert_eq!(multi_broadcasted_layout.offset(), 0); ``` From 5091f7e5c6c3bf117006cd9abfbe3f67d4b816bc Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Mon, 28 Apr 2025 14:20:25 +0800 Subject: [PATCH 10/21] =?UTF-8?q?style(README,lib,merge):=20=E8=B0=83?= =?UTF-8?q?=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 10 +++++----- src/lib.rs | 5 ++--- src/transform/merge.rs | 18 ++++++++++++++---- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 4cf6496..d668b3b 100644 --- a/README.md +++ b/README.md @@ -18,17 +18,17 @@ ndarray-layout 是一个用于处理多维数组布局的 Rust 库,它提供 ### 多维数组布局管理 -* ArrayLayout 结构体支持指定任意维度的数组布局,通过 new 方法可以创建具有指定形状、步长和偏移量的布局。 -* 提供 new_contiguous 方法,用于创建连续的数组布局,支持大端序(BigEndian)和小端序(LittleEndian)两种存储顺序。 +- ArrayLayout 结构体支持指定任意维度的数组布局,通过 new 方法可以创建具有指定形状、步长和偏移量的布局; +- 提供 new_contiguous 方法,用于创建连续的数组布局,支持大端序(BigEndian)和小端序(LittleEndian)两种存储顺序; ### 元信息访问 -* 提供便捷的方法来访问数组布局的元信息,如 ndim、offset、shape 和 strides 等。 -* 支持计算数组元素的偏移量和数据范围,方便进行内存访问和数据处理。 +- 提供便捷的方法来访问数组布局的元信息,如 ndim、offset、shape 和 strides 等; +- 支持计算数组元素的偏移量和数据范围,方便进行内存访问和数据处理; ### 布局操作功能 -* 提供多种布局变换方法,如 index、tile、transpose、merge 和 slice 等,方便对数组布局进行各种变换操作。 +- 提供多种布局变换方法,如 index、tile、transpose、merge 和 slice 等,方便对数组布局进行各种变换操作; ## 使用示例 diff --git a/src/lib.rs b/src/lib.rs index 50dfd61..464dd8e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -# ![doc = include_str!("../README.md")] +#![doc = include_str!("../README.md")] #![deny(warnings, missing_docs)] /// An array layout allow N dimensions inlined. @@ -40,7 +40,6 @@ impl Drop for ArrayLayout { } /// 元信息存储顺序。 - #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub enum Endian { @@ -59,7 +58,7 @@ impl ArrayLayout { /// assert_eq!(layout.offset(), 20); /// assert_eq!(layout.shape(), &[2, 3, 4]); /// assert_eq!(layout.strides(), &[12, -4, 1]); - ///``` + /// ``` pub fn new(shape: &[usize], strides: &[isize], offset: isize) -> Self { // check assert_eq!( diff --git a/src/transform/merge.rs b/src/transform/merge.rs index 93f280f..3d304db 100644 --- a/src/transform/merge.rs +++ b/src/transform/merge.rs @@ -76,8 +76,19 @@ impl ArrayLayout { let shape = content.shape(); let strides = content.strides(); - // 修改BUG - let merged = args.iter().map(|arg| arg.len.max(1)).sum::(); + let (merged, flag) = args.iter().fold((0, true), |(acc, _f), arg| { + ( + acc + arg.len.max(1), + match arg.len { + x if x >= 2 => false, + _ => true, + }, + ) + }); + // 如果所有arg.len都是0或者1,直接返回原布局 + if flag { + return Some(self.clone()); + } let mut ans = Self::with_ndim(self.ndim + args.len() - merged); let mut content = ans.content_mut(); @@ -94,7 +105,7 @@ impl ArrayLayout { let &MergeArg { start, len, endian } = arg; let end = start + len; - if len == 0 { + if len < 2 { continue; } @@ -111,7 +122,6 @@ impl ArrayLayout { } } - // 修改BUG last_end = end; if pairs.is_empty() { From 0ce18bae88477e308c48a201e7708a6fea71a93b Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Mon, 28 Apr 2025 14:28:06 +0800 Subject: [PATCH 11/21] =?UTF-8?q?style(lib,merge):=20=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/lib.rs | 17 ----------------- src/transform/merge.rs | 8 +------- 2 files changed, 1 insertion(+), 24 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 464dd8e..59b18dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -334,13 +334,11 @@ impl Content { } #[inline] - fn layout(ndim: usize) -> Layout { Layout::array::(1 + ndim * 2).unwrap() } #[test] - fn test_new() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); assert_eq!(layout.offset(), 20); @@ -350,11 +348,6 @@ fn test_new() { } #[test] - -fn test_new_different_length() {} - -#[test] - fn test_new_contiguous_little_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); assert_eq!(layout.offset(), 0); @@ -363,7 +356,6 @@ fn test_new_contiguous_little_endian() { } #[test] - fn test_new_contiguous_big_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); assert_eq!(layout.offset(), 0); @@ -373,13 +365,11 @@ fn test_new_contiguous_big_endian() { #[test] #[should_panic(expected = "shape and strides must have the same length")] - fn test_new_invalid_shape_strides_length() { ArrayLayout::<4>::new(&[2, 3], &[12, -4, 1], 20); } #[test] - fn test_to_inline_size() { let layout = ArrayLayout::<4>::new_contiguous(&[3, 4], Endian::BigEndian, 0); assert_eq!(size_of_val(&layout), (2 * 4 + 2) * size_of::()); @@ -388,28 +378,24 @@ fn test_to_inline_size() { } #[test] - fn test_num_elements() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 20); assert_eq!(layout.num_elements(), 24); } #[test] - fn test_element_offset_little_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); assert_eq!(layout.element_offset(22, Endian::LittleEndian), 88); } #[test] - fn test_element_offset_big_endian() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::BigEndian, 4); assert_eq!(layout.element_offset(22, Endian::BigEndian), 88); } #[test] - fn test_data_range_positive_strides() { let layout = ArrayLayout::<4>::new_contiguous(&[2, 3, 4], Endian::LittleEndian, 4); let range = layout.data_range(); @@ -417,7 +403,6 @@ fn test_data_range_positive_strides() { } #[test] - fn test_data_range_mixed_strides() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 0], 20); let range = layout.data_range(); @@ -425,7 +410,6 @@ fn test_data_range_mixed_strides() { } #[test] - fn test_clone_and_eq() { let layout1 = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); let layout2 = layout1.clone(); @@ -433,7 +417,6 @@ fn test_clone_and_eq() { } #[test] - fn test_drop() { let layout = ArrayLayout::<4>::new(&[2, 3, 4], &[12, -4, 1], 20); drop(layout); diff --git a/src/transform/merge.rs b/src/transform/merge.rs index 3d304db..21c5e24 100644 --- a/src/transform/merge.rs +++ b/src/transform/merge.rs @@ -77,13 +77,7 @@ impl ArrayLayout { let strides = content.strides(); let (merged, flag) = args.iter().fold((0, true), |(acc, _f), arg| { - ( - acc + arg.len.max(1), - match arg.len { - x if x >= 2 => false, - _ => true, - }, - ) + (acc + arg.len.max(1), !matches!(arg.len, x if x >= 2)) }); // 如果所有arg.len都是0或者1,直接返回原布局 if flag { From 64ba42a46e9333f825768dcd24942c995eff8403 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Mon, 28 Apr 2025 21:37:24 +0800 Subject: [PATCH 12/21] docs(README): add the coverage tag, but the link uses my own codecov --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index d668b3b..817364c 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ [![Latest version](https://img.shields.io/crates/v/ndarray-layout.svg)](https://crates.io/crates/ndarray-layout) [![Documentation](https://docs.rs/ndarray-layout/badge.svg)](https://docs.rs/ndarray-layout) [![license](https://img.shields.io/github/license/InfiniTensor/ndarray-layout)](https://mit-license.org/) +[![codecov](https://codecov.io/github/Simon25772/ndarray-layout/branch/ShenghuSu/graph/badge.svg)](https://codecov.io/github/Simon25772/ndarray-layout/tree/Shenghu) ![GitHub repo size](https://img.shields.io/github/repo-size/InfiniTensor/ndarray-layout) ![GitHub code size in bytes](https://img.shields.io/github/languages/code-size/InfiniTensor/ndarray-layout) From 44086cf49894e4f659acfee9c2114b7e7ba32993 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 15:03:11 +0800 Subject: [PATCH 13/21] style: remove redundant blank lines --- src/fmt.rs | 1 - src/lib.rs | 1 - src/transform/broadcast.rs | 2 -- src/transform/index.rs | 2 -- 4 files changed, 6 deletions(-) diff --git a/src/fmt.rs b/src/fmt.rs index e75e89b..63381f0 100644 --- a/src/fmt.rs +++ b/src/fmt.rs @@ -89,7 +89,6 @@ impl ArrayLayout { } #[test] - fn test() { const DATA: &[u8] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 0]; diff --git a/src/lib.rs b/src/lib.rs index 59b18dc..978dd01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,6 @@ impl Drop for ArrayLayout { /// 元信息存储顺序。 #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] - pub enum Endian { /// 大端序,范围更大的维度在元信息中更靠前的位置。 BigEndian, diff --git a/src/transform/broadcast.rs b/src/transform/broadcast.rs index 175a2bd..159136e 100644 --- a/src/transform/broadcast.rs +++ b/src/transform/broadcast.rs @@ -3,7 +3,6 @@ /// 索引变换参数。 #[derive(Clone, PartialEq, Eq, Debug)] - pub struct BroadcastArg { /// 广播的轴。 pub axis: usize, @@ -39,7 +38,6 @@ impl ArrayLayout { } #[test] - fn test_broadcast() { let layout = ArrayLayout::<3>::new(&[1, 5, 2], &[10, 2, 1], 0).broadcast(0, 10); assert_eq!(layout.shape(), &[10, 5, 2]); diff --git a/src/transform/index.rs b/src/transform/index.rs index 3b040aa..b261300 100644 --- a/src/transform/index.rs +++ b/src/transform/index.rs @@ -4,7 +4,6 @@ use std::iter::zip; /// 索引变换参数。 #[derive(Clone, PartialEq, Eq, Debug)] - pub struct IndexArg { /// 索引的轴。 pub axis: usize, @@ -68,7 +67,6 @@ impl ArrayLayout { } #[test] - fn test() { let layout = ArrayLayout::<1>::new(&[2, 3, 4], &[12, 4, 1], 0); let layout = layout.index(1, 2); From e3cf510aadc130aed0a955a982471675a3b2d548 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 15:26:41 +0800 Subject: [PATCH 14/21] docs: improve the readability of README.md --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 817364c..2f7d34e 100644 --- a/README.md +++ b/README.md @@ -13,23 +13,23 @@ ![GitHub contributors](https://img.shields.io/github/contributors/InfiniTensor/ndarray-layout) ![GitHub commit activity](https://img.shields.io/github/commit-activity/m/InfiniTensor/ndarray-layout) -ndarray-layout 是一个用于处理多维数组布局的 Rust 库,它提供了 ArrayLayout 结构体,用于高效管理和操作多维数组的元信息,如形状、步长和偏移量等。这个库在处理多维数组时,提供了灵活且高效的布局管理方式,能够满足不同场景下对数组布局的操作需求。 +*ndarray-layout* 是一个用于处理多维数组布局的 *Rust* 库,它提供了 `ArrayLayout` 结构体,用于高效管理和操作多维数组的元信息,如形状、步长和偏移量等。这个库在处理多维数组时,提供了灵活且高效的布局管理方式,能够满足不同场景下对数组布局的操作需求。 ## 主要功能特点 ### 多维数组布局管理 -- ArrayLayout 结构体支持指定任意维度的数组布局,通过 new 方法可以创建具有指定形状、步长和偏移量的布局; -- 提供 new_contiguous 方法,用于创建连续的数组布局,支持大端序(BigEndian)和小端序(LittleEndian)两种存储顺序; +- `ArrayLayout` 结构体支持指定任意维度的数组布局,通过 `new` 方法可以创建具有指定形状、步长和偏移量的布局; +- 提供 `new_contiguous` 方法,用于创建连续的数组布局,支持大端序(`BigEndian`)和小端序(`LittleEndian`)两种存储顺序; ### 元信息访问 -- 提供便捷的方法来访问数组布局的元信息,如 ndim、offset、shape 和 strides 等; +- 提供便捷的方法来访问数组布局的元信息,如 `ndim`、`offset`、`shape` 和 `strides` 等; - 支持计算数组元素的偏移量和数据范围,方便进行内存访问和数据处理; ### 布局操作功能 -- 提供多种布局变换方法,如 index、tile、transpose、merge 和 slice 等,方便对数组布局进行各种变换操作; +- 提供多种布局变换方法,如 `index`、`tile`、`transpose`、`merge` 和 `slice` 等,方便对数组布局进行各种变换操作; ## 使用示例 From 73be6b98a3f56c70984fd0c8186629b43188f8cb Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 15:45:04 +0800 Subject: [PATCH 15/21] style: standardized annotation format --- README.md | 10 +++++----- src/lib.rs | 2 +- src/transform/merge.rs | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 2f7d34e..2f0278f 100644 --- a/README.md +++ b/README.md @@ -36,19 +36,19 @@ ```rust use ndarray_layout::{ArrayLayout, BroadcastArg}; -// 创建一个新的 ArrayLayout 实例 -// 形状为 [1, 2, 3],步长为 [12, 4, 1],偏移量为 0 +// 创建一个新的 `ArrayLayout` 实例。 +// 形状为 [1, 2, 3],步长为 [12, 4, 1],偏移量为 0。 let layout = ArrayLayout::<3>::new(&[1, 2, 3], &[12, 4, 1], 0); -// 验证初始的形状和步长 +// 验证初始的形状和步长。 assert_eq!(layout.shape(), &[1, 2, 3]); assert_eq!(layout.strides(), &[12, 4, 1]); assert_eq!(layout.offset(), 0); -// 对第 0 维进行广播变换,广播次数为 4 +// 对第 0 维进行广播变换,广播次数为 4。 let broadcasted_layout = layout.broadcast(0, 4); -// 验证广播变换后的形状和步长 +// 验证广播变换后的形状和步长。 assert_eq!(broadcasted_layout.shape(), &[4, 2, 3]); assert_eq!(broadcasted_layout.strides(), &[0, 4, 1]); assert_eq!(broadcasted_layout.offset(), 0); diff --git a/src/lib.rs b/src/lib.rs index 978dd01..3097f88 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -222,7 +222,7 @@ impl ArrayLayout { #[inline] fn ptr_allocated(&self) -> Option> { const { assert!(N > 0) } - // ndim>N则content是ptr,否则是元组 + // ndim > N 则 content 是 ptr,否则是元组。 if self.ndim > N { Some(unsafe { self.content.ptr }) } else { diff --git a/src/transform/merge.rs b/src/transform/merge.rs index 21c5e24..6b9ef8e 100644 --- a/src/transform/merge.rs +++ b/src/transform/merge.rs @@ -79,7 +79,7 @@ impl ArrayLayout { let (merged, flag) = args.iter().fold((0, true), |(acc, _f), arg| { (acc + arg.len.max(1), !matches!(arg.len, x if x >= 2)) }); - // 如果所有arg.len都是0或者1,直接返回原布局 + // 如果所有 arg.len 都是 0 或者 1,直接返回原布局。 if flag { return Some(self.clone()); } From 65dcf00e15f1352cfee918276058bfd79a6d9c09 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 15:53:39 +0800 Subject: [PATCH 16/21] refactor(transform/merge): simplify the merging judgment conditions The original judgment logic for accumulating flags in fold was changed to compare merged with `args.len()` to determine whether to merge. This judgment is equivalent but more concise. --- src/transform/merge.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transform/merge.rs b/src/transform/merge.rs index 6b9ef8e..ebae853 100644 --- a/src/transform/merge.rs +++ b/src/transform/merge.rs @@ -76,13 +76,11 @@ impl ArrayLayout { let shape = content.shape(); let strides = content.strides(); - let (merged, flag) = args.iter().fold((0, true), |(acc, _f), arg| { - (acc + arg.len.max(1), !matches!(arg.len, x if x >= 2)) - }); - // 如果所有 arg.len 都是 0 或者 1,直接返回原布局。 - if flag { + let merged = args.iter().map(|arg| arg.len.max(1)).sum::(); + if merged == args.len() { return Some(self.clone()); } + let mut ans = Self::with_ndim(self.ndim + args.len() - merged); let mut content = ans.content_mut(); From 3163b7ed4433e35b6bbad0810b4adf602f8ab680 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 16:05:53 +0800 Subject: [PATCH 17/21] docs: append Simon25772 to authors list --- Cargo.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 685f8ba..e4ac656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,10 @@ name = "ndarray-layout" description = "This crate provides definitions and transformations for multi-dimensional array data layouts." version = "0.2.1" edition = "2024" -authors = ["YdrMaster "] +authors = [ + "YdrMaster ", + "Simon25772 ", +] repository = "https://github.com/InfiniTensor/ndarray-layout.git" documentation = "https://docs.rs/ndarray-layout" license = "MIT" From b23873d05d74d4f162e4ecf401b46b035a5ffc56 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 16:19:52 +0800 Subject: [PATCH 18/21] docs(CHANGELOG): update CHANGELOG --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e575dcc..78062ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Compatible with meaningless input: Merging 0 or 1 dimensions will not change the layout; + ### Added -- Add `to_inline_size` function, to copy data from `ArrayLayout` into `ArrayLayout`. +- Add `to_inline_size` function, to copy data from `ArrayLayout` into `ArrayLayout`; ## [0.2.1] - 2025-03-28 From f39c127cb3eab8bdad35fa6961879b96fef5926b Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 16:23:57 +0800 Subject: [PATCH 19/21] style(transform/broadcast): remove redundant blank lines --- src/transform/broadcast.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transform/broadcast.rs b/src/transform/broadcast.rs index 159136e..7998e0f 100644 --- a/src/transform/broadcast.rs +++ b/src/transform/broadcast.rs @@ -1,7 +1,6 @@ use crate::ArrayLayout; /// 索引变换参数。 - #[derive(Clone, PartialEq, Eq, Debug)] pub struct BroadcastArg { /// 广播的轴。 From 4d063b3916cc8aa3980ec8053c754a875f2d682f Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 16:24:30 +0800 Subject: [PATCH 20/21] style(transform/index): remove redundant blank lines --- src/transform/index.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transform/index.rs b/src/transform/index.rs index b261300..5b8f1fc 100644 --- a/src/transform/index.rs +++ b/src/transform/index.rs @@ -2,7 +2,6 @@ use std::iter::zip; /// 索引变换参数。 - #[derive(Clone, PartialEq, Eq, Debug)] pub struct IndexArg { /// 索引的轴。 From 466853faa714c61a35f533241f9d53e147ac7cd9 Mon Sep 17 00:00:00 2001 From: Simon25772 Date: Tue, 29 Apr 2025 16:29:00 +0800 Subject: [PATCH 21/21] =?UTF-8?q?docs(transform/broadcast):=20=E6=9B=B4?= =?UTF-8?q?=E6=AD=A3=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transform/broadcast.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/broadcast.rs b/src/transform/broadcast.rs index 7998e0f..87a2c96 100644 --- a/src/transform/broadcast.rs +++ b/src/transform/broadcast.rs @@ -1,6 +1,6 @@ use crate::ArrayLayout; -/// 索引变换参数。 +/// 广播变换参数。 #[derive(Clone, PartialEq, Eq, Debug)] pub struct BroadcastArg { /// 广播的轴。