From 28f42622b4498508bb84caac08252f7f5b48cc0b Mon Sep 17 00:00:00 2001 From: youge325 Date: Fri, 6 Feb 2026 15:13:44 +0800 Subject: [PATCH] add select, detach, reciprocal and split related tests --- test/ops/DetachTest.cpp | 163 +++++++++++++++++++++ test/ops/ReciprocalTest.cpp | 130 +++++++++++++++++ test/ops/SelectTest.cpp | 221 ++++++++++++++++++++++++++++ test/ops/SplitTest.cpp | 281 ++++++++++++++++++++++++++++++++++++ 4 files changed, 795 insertions(+) create mode 100644 test/ops/DetachTest.cpp create mode 100644 test/ops/ReciprocalTest.cpp create mode 100644 test/ops/SelectTest.cpp create mode 100644 test/ops/SplitTest.cpp diff --git a/test/ops/DetachTest.cpp b/test/ops/DetachTest.cpp new file mode 100644 index 0000000..0d24669 --- /dev/null +++ b/test/ops/DetachTest.cpp @@ -0,0 +1,163 @@ +#include +#include +#include +#include + +#include +#include + +#include "../../src/file_manager.h" + +extern paddle_api_test::ThreadSafeParam g_custom_param; + +namespace at { +namespace test { + +using paddle_api_test::FileManerger; +using paddle_api_test::ThreadSafeParam; + +class DetachTest : public ::testing::Test { + protected: + void SetUp() override { + std::vector shape = {3, 4}; + test_tensor = at::zeros(shape, at::kFloat); + float* data = test_tensor.data_ptr(); + for (int64_t i = 0; i < 12; ++i) { + data[i] = static_cast(i + 1); + } + } + at::Tensor test_tensor; +}; + +static void write_detach_result_to_file(FileManerger* file, + const at::Tensor& result, + const at::Tensor& original) { + *file << std::to_string(result.dim()) << " "; + *file << std::to_string(result.numel()) << " "; + + // 写入形状信息 + for (int64_t i = 0; i < result.dim(); ++i) { + *file << std::to_string(result.sizes()[i]) << " "; + } + + // 写入数据内容 + float* result_data = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); ++i) { + *file << std::to_string(result_data[i]) << " "; + } + + // 验证数据指针是否相同(共享存储) + *file << std::to_string(result.data_ptr() == + original.data_ptr()) + << " "; +} + +// 测试 detach() 方法 - 创建新的 tensor,不跟踪梯度 +TEST_F(DetachTest, BasicDetach) { + at::Tensor detached = test_tensor.detach(); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + write_detach_result_to_file(&file, detached, test_tensor); + file.saveFile(); +} + +// 测试 detach_() in-place 方法 +TEST_F(DetachTest, InplaceDetach) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 保存原始指针 + float* original_ptr = test_tensor.data_ptr(); + + // 调用 in-place 版本 + at::Tensor& result = test_tensor.detach_(); + + // 验证返回的是同一个 tensor + file << std::to_string(result.data_ptr() == original_ptr) << " "; + + // 写入数据 + float* data = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); ++i) { + file << std::to_string(data[i]) << " "; + } + file.saveFile(); +} + +// 测试 detach 后修改数据 +TEST_F(DetachTest, DetachAndModify) { + at::Tensor detached = test_tensor.detach(); + + // 修改 detached tensor 的数据 + float* detached_data = detached.data_ptr(); + detached_data[0] = 99.0f; + detached_data[1] = 88.0f; + + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + // 验证原始 tensor 的数据也被修改了(因为共享存储) + float* original_data = test_tensor.data_ptr(); + file << std::to_string(original_data[0]) << " "; + file << std::to_string(original_data[1]) << " "; + file << std::to_string(detached_data[0]) << " "; + file << std::to_string(detached_data[1]) << " "; + file.saveFile(); +} + +// 测试不同类型 tensor 的 detach +TEST_F(DetachTest, DetachDifferentTensor) { + at::Tensor different_tensor = at::zeros({2, 2}, at::kFloat); + float* data = different_tensor.data_ptr(); + data[0] = 1.0f; + data[1] = 2.0f; + data[2] = 3.0f; + data[3] = 4.0f; + + at::Tensor detached = different_tensor.detach(); + + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + file << std::to_string(detached.numel()) << " "; + file << std::to_string(detached.dim()) << " "; + + float* detached_data = detached.data_ptr(); + for (int64_t i = 0; i < detached.numel(); ++i) { + file << std::to_string(detached_data[i]) << " "; + } + file.saveFile(); +} + +// 测试多维 tensor 的 detach +TEST_F(DetachTest, MultiDimensionalDetach) { + at::Tensor multi_tensor = at::zeros({2, 3, 4}, at::kFloat); + float* data = multi_tensor.data_ptr(); + for (int64_t i = 0; i < 24; ++i) { + data[i] = static_cast(i); + } + + at::Tensor detached = multi_tensor.detach(); + + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + file << std::to_string(detached.dim()) << " "; + file << std::to_string(detached.sizes()[0]) << " "; + file << std::to_string(detached.sizes()[1]) << " "; + file << std::to_string(detached.sizes()[2]) << " "; + file << std::to_string(detached.numel()) << " "; + + // 验证数据共享 + file << std::to_string(detached.data_ptr() == + multi_tensor.data_ptr()) + << " "; + file.saveFile(); +} + +} // namespace test +} // namespace at diff --git a/test/ops/ReciprocalTest.cpp b/test/ops/ReciprocalTest.cpp new file mode 100644 index 0000000..3116006 --- /dev/null +++ b/test/ops/ReciprocalTest.cpp @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include "../../src/file_manager.h" + +extern paddle_api_test::ThreadSafeParam g_custom_param; + +namespace at { +namespace test { + +using paddle_api_test::FileManerger; +using paddle_api_test::ThreadSafeParam; + +class ReciprocalTest : public ::testing::Test { + protected: + void SetUp() override { + std::vector shape = {4}; + test_tensor = at::zeros(shape, at::kFloat); + float* data = test_tensor.data_ptr(); + data[0] = 1.0f; + data[1] = 2.0f; + data[2] = 0.5f; + data[3] = 4.0f; + } + at::Tensor test_tensor; +}; + +static void write_reciprocal_result_to_file(FileManerger* file, + const at::Tensor& result) { + *file << std::to_string(result.dim()) << " "; + *file << std::to_string(result.numel()) << " "; + float* result_data = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); ++i) { + *file << std::to_string(result_data[i]) << " "; + } +} + +// 测试 reciprocal() 方法 +TEST_F(ReciprocalTest, BasicReciprocal) { + at::Tensor result = test_tensor.reciprocal(); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + write_reciprocal_result_to_file(&file, result); + + // 验证原始 tensor 未被修改 + float* original_data = test_tensor.data_ptr(); + file << std::to_string(original_data[0]) << " "; + file << std::to_string(original_data[1]) << " "; + file.saveFile(); +} + +// 测试 reciprocal_() in-place 方法 +TEST_F(ReciprocalTest, InplaceReciprocal) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 保存原始数据指针 + float* original_ptr = test_tensor.data_ptr(); + + // 调用 in-place 版本 + at::Tensor& result = test_tensor.reciprocal_(); + + // 验证返回的是同一个 tensor + file << std::to_string(result.data_ptr() == original_ptr) << " "; + + write_reciprocal_result_to_file(&file, result); + file.saveFile(); +} + +// 测试不同值的 reciprocal +TEST_F(ReciprocalTest, VariousValues) { + at::Tensor various_tensor = at::zeros({5}, at::kFloat); + float* data = various_tensor.data_ptr(); + data[0] = 10.0f; + data[1] = 0.1f; + data[2] = -2.0f; + data[3] = -0.5f; + data[4] = 100.0f; + + at::Tensor result = various_tensor.reciprocal(); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + write_reciprocal_result_to_file(&file, result); + file.saveFile(); +} + +// 测试多维 tensor 的 reciprocal +TEST_F(ReciprocalTest, MultiDimensionalTensor) { + at::Tensor multi_dim_tensor = at::zeros({2, 3}, at::kFloat); + float* data = multi_dim_tensor.data_ptr(); + data[0] = 1.0f; + data[1] = 2.0f; + data[2] = 4.0f; + data[3] = 0.25f; + data[4] = 0.5f; + data[5] = 8.0f; + + at::Tensor result = multi_dim_tensor.reciprocal(); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + file << std::to_string(result.dim()) << " "; + file << std::to_string(result.sizes()[0]) << " "; + file << std::to_string(result.sizes()[1]) << " "; + write_reciprocal_result_to_file(&file, result); + file.saveFile(); +} + +// 测试使用 at::reciprocal 全局函数 +TEST_F(ReciprocalTest, GlobalReciprocal) { + at::Tensor result = at::reciprocal(test_tensor); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + write_reciprocal_result_to_file(&file, result); + file.saveFile(); +} + +} // namespace test +} // namespace at diff --git a/test/ops/SelectTest.cpp b/test/ops/SelectTest.cpp new file mode 100644 index 0000000..591797e --- /dev/null +++ b/test/ops/SelectTest.cpp @@ -0,0 +1,221 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include "../../src/file_manager.h" + +extern paddle_api_test::ThreadSafeParam g_custom_param; + +namespace at { +namespace test { + +using paddle_api_test::FileManerger; +using paddle_api_test::ThreadSafeParam; + +class SelectTest : public ::testing::Test { + protected: + void SetUp() override { + // 创建一个 3x4x5 的三维 tensor + test_tensor = at::zeros({3, 4, 5}, at::kFloat); + float* data = test_tensor.data_ptr(); + for (int64_t i = 0; i < 60; ++i) { + data[i] = static_cast(i); + } + } + at::Tensor test_tensor; +}; + +static void write_select_result_to_file(FileManerger* file, + const at::Tensor& result) { + *file << std::to_string(result.dim()) << " "; + *file << std::to_string(result.numel()) << " "; + + // 写入形状信息 + for (int64_t i = 0; i < result.dim(); ++i) { + *file << std::to_string(result.sizes()[i]) << " "; + } + + // 写入前几个数据值 + float* result_data = result.data_ptr(); + int64_t max_elements = std::min(result.numel(), static_cast(10)); + for (int64_t i = 0; i < max_elements; ++i) { + *file << std::to_string(result_data[i]) << " "; + } +} + +// 测试 select 第 0 维 +TEST_F(SelectTest, SelectDim0) { + // 从第 0 维选择索引 1,结果应该是 4x5 + at::Tensor result = test_tensor.select(0, 1); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + write_select_result_to_file(&file, result); + file.saveFile(); +} + +// 测试 select 第 1 维 +TEST_F(SelectTest, SelectDim1) { + // 从第 1 维选择索引 2,结果应该是 3x5 + at::Tensor result = test_tensor.select(1, 2); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + write_select_result_to_file(&file, result); + file.saveFile(); +} + +// 测试 select 第 2 维 +TEST_F(SelectTest, SelectDim2) { + // 从第 2 维选择索引 3,结果应该是 3x4 + at::Tensor result = test_tensor.select(2, 3); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + write_select_result_to_file(&file, result); + file.saveFile(); +} + +// 测试 select 使用负数索引 +TEST_F(SelectTest, SelectNegativeIndex) { + // 从第 0 维选择索引 -1(最后一个),结果应该是 4x5 + at::Tensor result = test_tensor.select(0, -1); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + write_select_result_to_file(&file, result); + file.saveFile(); +} + +// 测试 select 链式调用 +TEST_F(SelectTest, SelectChain) { + // 先选择第 0 维的索引 1,再选择第 0 维的索引 2,最后选择第 0 维的索引 3 + at::Tensor result = test_tensor.select(0, 1).select(0, 2).select(0, 3); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + file << std::to_string(result.dim()) << " "; + file << std::to_string(result.numel()) << " "; + + // 对于标量或一维 tensor,直接输出值 + if (result.numel() == 1) { + file << std::to_string(result.item()) << " "; + } else { + float* data = result.data_ptr(); + for (int64_t i = 0; i < std::min(result.numel(), static_cast(5)); + ++i) { + file << std::to_string(data[i]) << " "; + } + } + file.saveFile(); +} + +// 测试 select_symint 方法 +TEST_F(SelectTest, SelectSymInt) { + // 使用 SymInt 选择第 1 维的索引 1 + c10::SymInt sym_index(1); + at::Tensor result = test_tensor.select_symint(1, sym_index); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + write_select_result_to_file(&file, result); + file.saveFile(); +} + +// 测试 select_symint 使用不同的维度和索引 +TEST_F(SelectTest, SelectSymIntVariousIndices) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + // 选择第 0 维的索引 0 + c10::SymInt sym_index_0(0); + at::Tensor result1 = test_tensor.select_symint(0, sym_index_0); + file << std::to_string(result1.dim()) << " "; + file << std::to_string(result1.numel()) << " "; + + // 选择第 2 维的索引 4 + c10::SymInt sym_index_4(4); + at::Tensor result2 = test_tensor.select_symint(2, sym_index_4); + file << std::to_string(result2.dim()) << " "; + file << std::to_string(result2.numel()) << " "; + + file.saveFile(); +} + +// 测试二维 tensor 的 select +TEST_F(SelectTest, Select2DTensor) { + at::Tensor tensor_2d = at::zeros({4, 5}, at::kFloat); + float* data = tensor_2d.data_ptr(); + for (int64_t i = 0; i < 20; ++i) { + data[i] = static_cast(i * 2); + } + + // 选择第 0 维的索引 2,结果应该是一维 tensor,大小为 5 + at::Tensor result = tensor_2d.select(0, 2); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + file << std::to_string(result.dim()) << " "; + file << std::to_string(result.numel()) << " "; + + float* result_data = result.data_ptr(); + for (int64_t i = 0; i < result.numel(); ++i) { + file << std::to_string(result_data[i]) << " "; + } + file.saveFile(); +} + +// 测试一维 tensor 的 select +TEST_F(SelectTest, Select1DTensor) { + at::Tensor tensor_1d = at::zeros({10}, at::kFloat); + float* data = tensor_1d.data_ptr(); + for (int64_t i = 0; i < 10; ++i) { + data[i] = static_cast(i * 10); + } + + // 选择第 0 维的索引 5,结果应该是标量 + at::Tensor result = tensor_1d.select(0, 5); + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + file << std::to_string(result.dim()) << " "; + file << std::to_string(result.numel()) << " "; + file << std::to_string(result.item()) << " "; + file.saveFile(); +} + +// 测试 select 返回的 view 与原始 tensor 共享存储 +TEST_F(SelectTest, SelectViewSharing) { + at::Tensor selected = test_tensor.select(0, 0); + + // 修改 selected 的数据 + float* selected_data = selected.data_ptr(); + float original_value = selected_data[0]; + selected_data[0] = 999.0f; + + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.openAppend(); + + // 验证原始 tensor 的对应位置也被修改 + float* original_data = test_tensor.data_ptr(); + file << std::to_string(original_value) << " "; + file << std::to_string(original_data[0]) << " "; + file << std::to_string(selected_data[0]) << " "; + + // 恢复数据 + selected_data[0] = original_value; + file.saveFile(); +} + +} // namespace test +} // namespace at diff --git a/test/ops/SplitTest.cpp b/test/ops/SplitTest.cpp new file mode 100644 index 0000000..7101356 --- /dev/null +++ b/test/ops/SplitTest.cpp @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include "../../src/file_manager.h" + +extern paddle_api_test::ThreadSafeParam g_custom_param; + +namespace at { +namespace test { + +using paddle_api_test::FileManerger; +using paddle_api_test::ThreadSafeParam; + +class SplitTest : public ::testing::Test { + protected: + void SetUp() override { + // 创建一个 4x6x8 的 tensor 方便测试 + tensor = at::ones({4, 6, 8}, at::kFloat); + } + + at::Tensor tensor; +}; + +// 测试 split - 按大小分割 +TEST_F(SplitTest, SplitBySize) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度0上,每2个元素分割 + std::vector splits = tensor.split(2, 0); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[0]) << " "; + } + file.saveFile(); +} + +// 测试 split - 按大小数组分割 +TEST_F(SplitTest, SplitBySizes) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度1上,分割为大小 [2, 3, 1] + std::vector splits = tensor.split({2, 3, 1}, 1); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[1]) << " "; + } + file.saveFile(); +} + +// 测试 split_with_sizes +TEST_F(SplitTest, SplitWithSizes) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度2上,分割为大小 [3, 2, 3] + std::vector splits = tensor.split_with_sizes({3, 2, 3}, 2); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[2]) << " "; + } + file.saveFile(); +} + +// 测试 unsafe_split +TEST_F(SplitTest, UnsafeSplit) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度0上,每2个元素分割 + std::vector splits = tensor.unsafe_split(2, 0); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[0]) << " "; + } + file.saveFile(); +} + +// 测试 unsafe_split_with_sizes +TEST_F(SplitTest, UnsafeSplitWithSizes) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度1上,分割为大小 [2, 4] + std::vector splits = tensor.unsafe_split_with_sizes({2, 4}, 1); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[1]) << " "; + } + file.saveFile(); +} + +// 测试 tensor_split - 按节数分割 +TEST_F(SplitTest, TensorSplitBySections) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度0上,分割为2个部分(4能被2整除) + std::vector splits = tensor.tensor_split(2, 0); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[0]) << " "; + } + file.saveFile(); +} + +// 测试 tensor_split - 按索引分割 +TEST_F(SplitTest, TensorSplitByIndices) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度1上,在索引 [2, 4] 处分割 + std::vector splits = tensor.tensor_split({2, 4}, 1); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[1]) << " "; + } + file.saveFile(); +} + +// 测试 tensor_split - 使用 tensor 作为索引 +TEST_F(SplitTest, TensorSplitByTensorIndices) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 创建大小数组 tensor(Paddle要求总和等于维度大小8) + std::vector indices_data = {2, 3, 3}; + at::Tensor indices = + at::from_blob(indices_data.data(), {3}, at::kLong).clone(); + // 在维度2上分割 + std::vector splits = tensor.tensor_split(indices, 2); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[2]) << " "; + } + file.saveFile(); +} + +// 测试 hsplit - 按节数水平分割 +TEST_F(SplitTest, HsplitBySections) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 水平分割为3个部分 + std::vector splits = tensor.hsplit(3); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[1]) << " "; + } + file.saveFile(); +} + +// 测试 hsplit - 按索引水平分割 +TEST_F(SplitTest, HsplitByIndices) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在索引 [2, 4] 处水平分割 + std::vector splits = tensor.hsplit({2, 4}); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[1]) << " "; + } + file.saveFile(); +} + +// 测试 vsplit - 按节数垂直分割 +TEST_F(SplitTest, VsplitBySections) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 垂直分割为2个部分 + std::vector splits = tensor.vsplit(2); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[0]) << " "; + } + file.saveFile(); +} + +// 测试 vsplit - 按索引垂直分割 +TEST_F(SplitTest, VsplitByIndices) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在索引 [1, 3] 处垂直分割 + std::vector splits = tensor.vsplit({1, 3}); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[0]) << " "; + } + file.saveFile(); +} + +// 测试 dsplit - 按节数深度分割 +TEST_F(SplitTest, DsplitBySections) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 深度分割为4个部分 + std::vector splits = tensor.dsplit(4); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[2]) << " "; + } + file.saveFile(); +} + +// 测试 dsplit - 按索引深度分割 +TEST_F(SplitTest, DsplitByIndices) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 使用大小数组 [3, 5](总和为8,等于维度2的大小) + std::vector splits = tensor.dsplit({3, 5}); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[2]) << " "; + } + file.saveFile(); +} + +// 测试 split 不同维度 +TEST_F(SplitTest, SplitDifferentDims) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度0上分割 + std::vector splits0 = tensor.split(1, 0); + file << std::to_string(splits0.size()) << " "; + + // 在维度1上分割 + std::vector splits1 = tensor.split(2, 1); + file << std::to_string(splits1.size()) << " "; + + // 在维度2上分割 + std::vector splits2 = tensor.split(4, 2); + file << std::to_string(splits2.size()) << " "; + + file.saveFile(); +} + +// 测试不均等分割 +TEST_F(SplitTest, UnevenSplit) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 在维度1上,每5个元素分割(不能整除) + std::vector splits = tensor.split(5, 1); + file << std::to_string(splits.size()) << " "; + for (const auto& split : splits) { + file << std::to_string(split.sizes()[1]) << " "; + } + file.saveFile(); +} + +} // namespace test +} // namespace at