diff --git a/test/TensorTest.cpp b/test/TensorTest.cpp index cb50e62..3f2d1cf 100644 --- a/test/TensorTest.cpp +++ b/test/TensorTest.cpp @@ -215,76 +215,129 @@ TEST_F(TensorTest, Transpose) { // 测试 sym_size TEST_F(TensorTest, SymSize) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); // 获取符号化的单个维度大小 c10::SymInt sym_size_0 = tensor.sym_size(0); c10::SymInt sym_size_1 = tensor.sym_size(1); c10::SymInt sym_size_2 = tensor.sym_size(2); - - // 验证符号化大小与实际大小一致 - EXPECT_EQ(sym_size_0, 2); - EXPECT_EQ(sym_size_1, 3); - EXPECT_EQ(sym_size_2, 4); - +#if USE_PADDLE_API + file << std::to_string(sym_size_0) << " "; + file << std::to_string(sym_size_1) << " "; + file << std::to_string(sym_size_2) << " "; +#else + file << std::to_string(sym_size_0.guard_int(__FILE__, __LINE__)) << " "; + file << std::to_string(sym_size_1.guard_int(__FILE__, __LINE__)) << " "; + file << std::to_string(sym_size_2.guard_int(__FILE__, __LINE__)) << " "; +#endif // 测试负索引 c10::SymInt sym_size_neg1 = tensor.sym_size(-1); - EXPECT_EQ(sym_size_neg1, 4); +#if USE_PADDLE_API + file << std::to_string(sym_size_neg1) << " "; +#else + file << std::to_string(sym_size_neg1.guard_int(__FILE__, __LINE__)) << " "; +#endif + file.saveFile(); } // 测试 sym_stride TEST_F(TensorTest, SymStride) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); // 获取符号化的单个维度步长 c10::SymInt sym_stride_0 = tensor.sym_stride(0); c10::SymInt sym_stride_1 = tensor.sym_stride(1); c10::SymInt sym_stride_2 = tensor.sym_stride(2); - - // 验证符号化步长 - EXPECT_GT(sym_stride_0, 0); - EXPECT_GT(sym_stride_1, 0); - EXPECT_GT(sym_stride_2, 0); - +#if USE_PADDLE_API + file << std::to_string(sym_stride_0) << " "; + file << std::to_string(sym_stride_1) << " "; + file << std::to_string(sym_stride_2) << " "; +#else + file << std::to_string(sym_stride_0.guard_int(__FILE__, __LINE__)) << " "; + file << std::to_string(sym_stride_1.guard_int(__FILE__, __LINE__)) << " "; + file << std::to_string(sym_stride_2.guard_int(__FILE__, __LINE__)) << " "; +#endif // 测试负索引 c10::SymInt sym_stride_neg1 = tensor.sym_stride(-1); - EXPECT_EQ(sym_stride_neg1, 1); // 最后一维步长通常为1 +#if USE_PADDLE_API + file << std::to_string(sym_stride_neg1) << " "; +#else + file << std::to_string(sym_stride_neg1.guard_int(__FILE__, __LINE__)) << " "; +#endif + file.saveFile(); } // 测试 sym_sizes TEST_F(TensorTest, SymSizes) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); // 获取符号化的所有维度大小 c10::SymIntArrayRef sym_sizes = tensor.sym_sizes(); - - // 验证维度数量 - EXPECT_EQ(sym_sizes.size(), 3U); - - // 验证每个维度的大小 - EXPECT_EQ(sym_sizes[0], 2); - EXPECT_EQ(sym_sizes[1], 3); - EXPECT_EQ(sym_sizes[2], 4); + file << std::to_string(sym_sizes.size()) << " "; + for (size_t i = 0; i < sym_sizes.size(); ++i) { +#if USE_PADDLE_API + file << std::to_string(sym_sizes[i]) << " "; +#else + file << std::to_string(sym_sizes[i].guard_int(__FILE__, __LINE__)) << " "; +#endif + } + file.saveFile(); } // 测试 sym_strides TEST_F(TensorTest, SymStrides) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); // 获取符号化的所有维度步长 c10::SymIntArrayRef sym_strides = tensor.sym_strides(); - - // 验证维度数量 - EXPECT_EQ(sym_strides.size(), 3U); - - // 验证步长值都大于0 + file << std::to_string(sym_strides.size()) << " "; for (size_t i = 0; i < sym_strides.size(); ++i) { - EXPECT_GT(sym_strides[i], 0); +#if USE_PADDLE_API + file << std::to_string(sym_strides[i]) << " "; +#else + file << std::to_string(sym_strides[i].guard_int(__FILE__, __LINE__)) << " "; +#endif } + file.saveFile(); } // 测试 sym_numel TEST_F(TensorTest, SymNumel) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); // 获取符号化的元素总数 c10::SymInt sym_numel = tensor.sym_numel(); +#if USE_PADDLE_API + file << std::to_string(sym_numel) << " "; +#else + file << std::to_string(sym_numel.guard_int(__FILE__, __LINE__)) << " "; +#endif + file << std::to_string(tensor.numel()) << " "; + file.saveFile(); +} - // 验证符号化元素数与实际元素数一致 - EXPECT_EQ(sym_numel, 24); // 2*3*4 +// 测试 defined +TEST_F(TensorTest, Defined) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + file << std::to_string(tensor.defined()) << " "; + file.saveFile(); +} - // 验证与 numel() 结果一致 - EXPECT_EQ(sym_numel, tensor.numel()); +// 测试 reset +TEST_F(TensorTest, Reset) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + tensor.reset(); + file << std::to_string(tensor.defined()) << " "; + file.saveFile(); } } // namespace test