-
Notifications
You must be signed in to change notification settings - Fork 3
add pointer related API tests #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,4 +70,37 @@ void FileManerger::setFileName(const std::string& value) { | |
| std::unique_lock<std::shared_mutex> lock(mutex_); | ||
| file_name_ = value; | ||
| } | ||
|
|
||
| void FileManerger::captureStdout(std::function<void()> func) { | ||
| std::unique_lock<std::shared_mutex> lock(mutex_); | ||
|
|
||
| if (!file_stream_.is_open()) { | ||
| throw std::runtime_error( | ||
| "File stream is not open. Call createFile() first."); | ||
| } | ||
|
|
||
| // 保存原来的 cout buffer | ||
| std::streambuf* old_cout_buf = std::cout.rdbuf(); | ||
|
|
||
| // 创建一个 stringstream 来捕获输出 | ||
| std::stringstream captured_output; | ||
|
|
||
| // 重定向 cout 到 stringstream | ||
| std::cout.rdbuf(captured_output.rdbuf()); | ||
|
|
||
|
Comment on lines
+82
to
+90
|
||
| try { | ||
| // 执行函数 | ||
| func(); | ||
|
|
||
| // 恢复 cout | ||
|
Comment on lines
+74
to
+95
|
||
| std::cout.rdbuf(old_cout_buf); | ||
|
|
||
| // 将捕获的输出写入文件 | ||
| file_stream_ << captured_output.str(); | ||
| } catch (...) { | ||
| // 确保恢复 cout | ||
| std::cout.rdbuf(old_cout_buf); | ||
| throw; | ||
| } | ||
| } | ||
| } // namespace paddle_api_test | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| #include <ATen/ATen.h> | ||
| #include <ATen/core/Tensor.h> | ||
| #include <ATen/ops/ones.h> | ||
| #include <gtest/gtest.h> | ||
| #include <torch/all.h> | ||
|
|
||
| #include <string> | ||
| #include <vector> | ||
|
|
||
| #include "../src/file_manager.h" | ||
|
|
||
| extern paddle_api_test::ThreadSafeParam g_custom_param; | ||
|
|
||
| namespace at { | ||
| namespace test { | ||
|
|
||
| using paddle_api_test::FileManerger; | ||
| using paddle_api_test::ThreadSafeParam; | ||
|
|
||
| class TensorUtilTest : public ::testing::Test { | ||
| protected: | ||
| void SetUp() override { | ||
| std::vector<int64_t> shape = {2, 3, 4}; | ||
| tensor = at::ones(shape, at::kFloat); | ||
| } | ||
|
|
||
| at::Tensor tensor; | ||
| }; | ||
|
|
||
| // 测试 toString | ||
| TEST_F(TensorUtilTest, ToString) { | ||
| auto file_name = g_custom_param.get(); | ||
| FileManerger file(file_name); | ||
| file.createFile(); | ||
| std::string tensor_str = tensor.toString(); | ||
| file << tensor_str << " "; | ||
| file.saveFile(); | ||
|
Comment on lines
+31
to
+37
|
||
| } | ||
|
|
||
| // 测试 is_contiguous_or_false | ||
| TEST_F(TensorUtilTest, IsContiguousOrFalse) { | ||
| auto file_name = g_custom_param.get(); | ||
| FileManerger file(file_name); | ||
| file.createFile(); | ||
| file << std::to_string(tensor.is_contiguous_or_false()) << " "; | ||
|
|
||
| // 测试非连续的tensor | ||
| at::Tensor transposed = tensor.transpose(0, 2); | ||
| file << std::to_string(transposed.is_contiguous_or_false()) << " "; | ||
| file.saveFile(); | ||
| } | ||
|
|
||
| // 测试 is_same | ||
| TEST_F(TensorUtilTest, IsSame) { | ||
| auto file_name = g_custom_param.get(); | ||
| FileManerger file(file_name); | ||
| file.createFile(); | ||
|
|
||
| // Test that tensor is same as itself | ||
| file << std::to_string(tensor.is_same(tensor)) << " "; | ||
|
|
||
| // Test that two different tensors are not the same | ||
| at::Tensor other_tensor = at::ones({2, 3, 4}, at::kFloat); | ||
| file << std::to_string(tensor.is_same(other_tensor)) << " "; | ||
|
|
||
| // Test that a shallow copy points to the same tensor | ||
| at::Tensor shallow_copy = tensor; | ||
| file << std::to_string(tensor.is_same(shallow_copy)) << " "; | ||
|
|
||
| // Test that a view of the tensor | ||
| at::Tensor view = tensor.view({24}); | ||
| file << std::to_string(tensor.is_same(view)) << " "; | ||
| file.saveFile(); | ||
| } | ||
|
|
||
| // 测试 use_count | ||
| TEST_F(TensorUtilTest, UseCount) { | ||
| auto file_name = g_custom_param.get(); | ||
| FileManerger file(file_name); | ||
| file.createFile(); | ||
|
|
||
| // Get initial use count | ||
| size_t initial_count = tensor.use_count(); | ||
| file << std::to_string(initial_count) << " "; | ||
|
|
||
| // Create a copy, should increase use count | ||
| { | ||
| at::Tensor copy = tensor; | ||
| size_t new_count = tensor.use_count(); | ||
| file << std::to_string(new_count) << " "; | ||
| file << std::to_string(new_count - initial_count) << " "; // 差值 | ||
| } | ||
|
|
||
| // After copy goes out of scope, use count should decrease | ||
| size_t final_count = tensor.use_count(); | ||
| file << std::to_string(final_count) << " "; | ||
| file.saveFile(); | ||
| } | ||
|
|
||
| // 测试 weak_use_count | ||
| TEST_F(TensorUtilTest, WeakUseCount) { | ||
| auto file_name = g_custom_param.get(); | ||
| FileManerger file(file_name); | ||
| file.createFile(); | ||
|
|
||
| // Get initial weak use count | ||
| size_t initial_weak_count = tensor.weak_use_count(); | ||
| file << std::to_string(initial_weak_count) << " "; | ||
| file.saveFile(); | ||
| } | ||
|
|
||
| // 测试 print | ||
| TEST_F(TensorUtilTest, Print) { | ||
| auto file_name = g_custom_param.get(); | ||
| FileManerger file(file_name); | ||
| file.createFile(); | ||
|
|
||
| // 创建一个小的tensor用于print测试 | ||
| at::Tensor small_tensor = at::ones({2, 2}, at::kFloat); | ||
|
|
||
| // 使用 captureStdout 捕获 print() 的输出 | ||
| file.captureStdout([&]() { | ||
| tensor.print(); | ||
| small_tensor.print(); | ||
| }); | ||
|
|
||
| file << std::to_string(1) << " "; // 如果执行到这里说明print()没有崩溃 | ||
| file.saveFile(); | ||
| } | ||
|
|
||
| } // namespace test | ||
| } // namespace at | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
captureStdoutusesstd::stringstreambutfile_manager.cppdoes not include<sstream>, which will fail to compile on standard-conforming toolchains. Add the missing header (and include<stdexcept>explicitly if relying onstd::runtime_error).