diff --git a/src/file_manager.cpp b/src/file_manager.cpp index 6ee0ba0..1fb36ea 100644 --- a/src/file_manager.cpp +++ b/src/file_manager.cpp @@ -70,4 +70,37 @@ void FileManerger::setFileName(const std::string& value) { std::unique_lock lock(mutex_); file_name_ = value; } + +void FileManerger::captureStdout(std::function func) { + std::unique_lock lock(mutex_); + + if (!file_stream_.is_open()) { + throw std::runtime_error( + "File stream is not open. Call createFile() first."); + } + + // 保存原来的 cout buffer + std::streambuf* old_cout_buf = std::cout.rdbuf(); + + // 创建一个 stringstream 来捕获输出 + std::stringstream captured_output; + + // 重定向 cout 到 stringstream + std::cout.rdbuf(captured_output.rdbuf()); + + try { + // 执行函数 + func(); + + // 恢复 cout + std::cout.rdbuf(old_cout_buf); + + // 将捕获的输出写入文件 + file_stream_ << captured_output.str(); + } catch (...) { + // 确保恢复 cout + std::cout.rdbuf(old_cout_buf); + throw; + } +} } // namespace paddle_api_test diff --git a/src/file_manager.h b/src/file_manager.h index 1d7a5da..2672a63 100644 --- a/src/file_manager.h +++ b/src/file_manager.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include @@ -17,6 +18,9 @@ class FileManerger { FileManerger& operator<<(const std::string& str); void saveFile(); + // 捕获标准输出到文件 + void captureStdout(std::function func); + private: mutable std::shared_mutex mutex_; std::string basic_path_ = "/tmp/paddle_cpp_api_test/"; diff --git a/test/TensorTest.cpp b/test/TensorTest.cpp index 49c9294..8bf0fd6 100644 --- a/test/TensorTest.cpp +++ b/test/TensorTest.cpp @@ -276,6 +276,23 @@ TEST_F(TensorTest, PinMemoryResult) { file.saveFile(); } +// 测试 sym_size +TEST_F(TensorTest, SymSize) { + // 获取符号化的单个维度大小 + c10::SymInt sym_size_0 = tensor.sym_size(0); + c10::SymInt sym_size_1 = tensor.sym_size(1); + c10::SymInt sym_size_2 = tensor.sym_size(2); + + // 验证符号化大小与实际大小一致 + EXPECT_EQ(sym_size_0, 2); + EXPECT_EQ(sym_size_1, 3); + EXPECT_EQ(sym_size_2, 4); + + // 测试负索引 + c10::SymInt sym_size_neg1 = tensor.sym_size(-1); + EXPECT_EQ(sym_size_neg1, 4); +} + // 测试 sym_stride TEST_F(TensorTest, SymStride) { // 获取符号化的单个维度步长 diff --git a/test/TensorUtilTest.cpp b/test/TensorUtilTest.cpp new file mode 100644 index 0000000..ba734b5 --- /dev/null +++ b/test/TensorUtilTest.cpp @@ -0,0 +1,132 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include "../src/file_manager.h" + +extern paddle_api_test::ThreadSafeParam g_custom_param; + +namespace at { +namespace test { + +using paddle_api_test::FileManerger; +using paddle_api_test::ThreadSafeParam; + +class TensorUtilTest : public ::testing::Test { + protected: + void SetUp() override { + std::vector shape = {2, 3, 4}; + tensor = at::ones(shape, at::kFloat); + } + + at::Tensor tensor; +}; + +// 测试 toString +TEST_F(TensorUtilTest, ToString) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + std::string tensor_str = tensor.toString(); + file << tensor_str << " "; + file.saveFile(); +} + +// 测试 is_contiguous_or_false +TEST_F(TensorUtilTest, IsContiguousOrFalse) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + file << std::to_string(tensor.is_contiguous_or_false()) << " "; + + // 测试非连续的tensor + at::Tensor transposed = tensor.transpose(0, 2); + file << std::to_string(transposed.is_contiguous_or_false()) << " "; + file.saveFile(); +} + +// 测试 is_same +TEST_F(TensorUtilTest, IsSame) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // Test that tensor is same as itself + file << std::to_string(tensor.is_same(tensor)) << " "; + + // Test that two different tensors are not the same + at::Tensor other_tensor = at::ones({2, 3, 4}, at::kFloat); + file << std::to_string(tensor.is_same(other_tensor)) << " "; + + // Test that a shallow copy points to the same tensor + at::Tensor shallow_copy = tensor; + file << std::to_string(tensor.is_same(shallow_copy)) << " "; + + // Test that a view of the tensor + at::Tensor view = tensor.view({24}); + file << std::to_string(tensor.is_same(view)) << " "; + file.saveFile(); +} + +// 测试 use_count +TEST_F(TensorUtilTest, UseCount) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // Get initial use count + size_t initial_count = tensor.use_count(); + file << std::to_string(initial_count) << " "; + + // Create a copy, should increase use count + { + at::Tensor copy = tensor; + size_t new_count = tensor.use_count(); + file << std::to_string(new_count) << " "; + file << std::to_string(new_count - initial_count) << " "; // 差值 + } + + // After copy goes out of scope, use count should decrease + size_t final_count = tensor.use_count(); + file << std::to_string(final_count) << " "; + file.saveFile(); +} + +// 测试 weak_use_count +TEST_F(TensorUtilTest, WeakUseCount) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // Get initial weak use count + size_t initial_weak_count = tensor.weak_use_count(); + file << std::to_string(initial_weak_count) << " "; + file.saveFile(); +} + +// 测试 print +TEST_F(TensorUtilTest, Print) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + + // 创建一个小的tensor用于print测试 + at::Tensor small_tensor = at::ones({2, 2}, at::kFloat); + + // 使用 captureStdout 捕获 print() 的输出 + file.captureStdout([&]() { + tensor.print(); + small_tensor.print(); + }); + + file << std::to_string(1) << " "; // 如果执行到这里说明print()没有崩溃 + file.saveFile(); +} + +} // namespace test +} // namespace at