diff --git a/test/ops/SqueezeTest.cpp b/test/ops/SqueezeTest.cpp new file mode 100644 index 0000000..0db4134 --- /dev/null +++ b/test/ops/SqueezeTest.cpp @@ -0,0 +1,96 @@ +#include +#include +#include +#include + +#include +#include + +#include "../../src/file_manager.h" + +extern paddle_api_test::ThreadSafeParam g_custom_param; + +namespace at { +namespace test { + +using paddle_api_test::FileManerger; +using paddle_api_test::ThreadSafeParam; + +class SqueezeTest : public ::testing::Test { + protected: + void SetUp() override { + // 创建一个包含大小为1的维度的tensor: shape = {2, 1, 3, 1, 4} + tensor_with_ones = at::ones({2, 1, 3, 1, 4}, at::kFloat); + } + at::Tensor tensor_with_ones; +}; + +// 测试 squeeze - 移除所有大小为1的维度 +TEST_F(SqueezeTest, SqueezeAll) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + at::Tensor squeezed = tensor_with_ones.squeeze(); + file << std::to_string(squeezed.dim()) << " "; + file << std::to_string(squeezed.numel()) << " "; + for (int64_t i = 0; i < squeezed.dim(); ++i) { + file << std::to_string(squeezed.sizes()[i]) << " "; + } + file.saveFile(); +} + +// 测试 squeeze - 移除指定维度 +TEST_F(SqueezeTest, SqueezeDim) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + // 移除维度1(大小为1) + at::Tensor squeezed_dim1 = tensor_with_ones.squeeze(1); + file << std::to_string(squeezed_dim1.dim()) << " "; + file << std::to_string(squeezed_dim1.numel()) << " "; + for (int64_t i = 0; i < squeezed_dim1.dim(); ++i) { + file << std::to_string(squeezed_dim1.sizes()[i]) << " "; + } + file.saveFile(); +} + +// 测试 squeeze_ - 原位移除所有大小为1的维度 +TEST_F(SqueezeTest, SqueezeInplaceAll) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + // 记录原始数据指针 + void* original_ptr = tensor_with_ones.data_ptr(); + // 原位移除所有大小为1的维度 + tensor_with_ones.squeeze_(); + file << std::to_string(tensor_with_ones.dim()) << " "; + file << std::to_string(tensor_with_ones.numel()) << " "; + for (int64_t i = 0; i < tensor_with_ones.dim(); ++i) { + file << std::to_string(tensor_with_ones.sizes()[i]) << " "; + } + // 验证是原位操作(数据指针未改变) + file << std::to_string(tensor_with_ones.data_ptr() == original_ptr) << " "; + file.saveFile(); +} + +// 测试 squeeze_ - 原位移除指定维度 +TEST_F(SqueezeTest, SqueezeInplaceDim) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + // 记录原始数据指针 + void* original_ptr = tensor_with_ones.data_ptr(); + // 原位移除维度1 + tensor_with_ones.squeeze_(1); + file << std::to_string(tensor_with_ones.dim()) << " "; + file << std::to_string(tensor_with_ones.numel()) << " "; + for (int64_t i = 0; i < tensor_with_ones.dim(); ++i) { + file << std::to_string(tensor_with_ones.sizes()[i]) << " "; + } + // 验证是原位操作(数据指针未改变) + file << std::to_string(tensor_with_ones.data_ptr() == original_ptr) << " "; + file.saveFile(); +} + +} // namespace test +} // namespace at diff --git a/test/ops/UnsqueezeTest.cpp b/test/ops/UnsqueezeTest.cpp new file mode 100644 index 0000000..ecd4d85 --- /dev/null +++ b/test/ops/UnsqueezeTest.cpp @@ -0,0 +1,109 @@ +#include +#include +#include +#include + +#include +#include + +#include "../../src/file_manager.h" + +extern paddle_api_test::ThreadSafeParam g_custom_param; + +namespace at { +namespace test { + +using paddle_api_test::FileManerger; +using paddle_api_test::ThreadSafeParam; + +class UnsqueezeTest : public ::testing::Test { + protected: + void SetUp() override { + // 创建一个基础tensor: shape = {2, 3, 4} + tensor = at::ones({2, 3, 4}, at::kFloat); + } + at::Tensor tensor; +}; + +// 测试 unsqueeze - 在维度0之前添加维度 +TEST_F(UnsqueezeTest, UnsqueezeDim0) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + at::Tensor unsqueezed0 = tensor.unsqueeze(0); + file << std::to_string(unsqueezed0.dim()) << " "; + file << std::to_string(unsqueezed0.numel()) << " "; + for (int64_t i = 0; i < unsqueezed0.dim(); ++i) { + file << std::to_string(unsqueezed0.sizes()[i]) << " "; + } + file.saveFile(); +} + +// 测试 unsqueeze - 在维度2之前添加维度 +TEST_F(UnsqueezeTest, UnsqueezeDim2) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + at::Tensor unsqueezed2 = tensor.unsqueeze(2); + file << std::to_string(unsqueezed2.dim()) << " "; + file << std::to_string(unsqueezed2.numel()) << " "; + for (int64_t i = 0; i < unsqueezed2.dim(); ++i) { + file << std::to_string(unsqueezed2.sizes()[i]) << " "; + } + file.saveFile(); +} + +// 测试 unsqueeze - 使用负索引在最后添加维度 +TEST_F(UnsqueezeTest, UnsqueezeNegativeDim) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + at::Tensor unsqueezed_last = tensor.unsqueeze(-1); + file << std::to_string(unsqueezed_last.dim()) << " "; + file << std::to_string(unsqueezed_last.numel()) << " "; + for (int64_t i = 0; i < unsqueezed_last.dim(); ++i) { + file << std::to_string(unsqueezed_last.sizes()[i]) << " "; + } + file.saveFile(); +} + +// 测试 unsqueeze_ - 原位在维度0之前添加维度 +TEST_F(UnsqueezeTest, UnsqueezeInplaceDim0) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + // 记录原始数据指针 + void* original_ptr = tensor.data_ptr(); + // 原位在维度0之前添加维度 + tensor.unsqueeze_(0); + file << std::to_string(tensor.dim()) << " "; + file << std::to_string(tensor.numel()) << " "; + for (int64_t i = 0; i < tensor.dim(); ++i) { + file << std::to_string(tensor.sizes()[i]) << " "; + } + // 验证是原位操作(数据指针未改变) + file << std::to_string(tensor.data_ptr() == original_ptr) << " "; + file.saveFile(); +} + +// 测试 unsqueeze_ - 原位使用负索引添加维度 +TEST_F(UnsqueezeTest, UnsqueezeInplaceNegativeDim) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + // 记录原始数据指针 + void* original_ptr = tensor.data_ptr(); + // 原位在最后添加维度 + tensor.unsqueeze_(-1); + file << std::to_string(tensor.dim()) << " "; + file << std::to_string(tensor.numel()) << " "; + for (int64_t i = 0; i < tensor.dim(); ++i) { + file << std::to_string(tensor.sizes()[i]) << " "; + } + // 验证是原位操作(数据指针未改变) + file << std::to_string(tensor.data_ptr() == original_ptr) << " "; + file.saveFile(); +} + +} // namespace test +} // namespace at