Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/file_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,37 @@ void FileManerger::setFileName(const std::string& value) {
std::unique_lock<std::shared_mutex> lock(mutex_);
file_name_ = value;
}

void FileManerger::captureStdout(std::function<void()> func) {
std::unique_lock<std::shared_mutex> lock(mutex_);

if (!file_stream_.is_open()) {
throw std::runtime_error(
"File stream is not open. Call createFile() first.");
}

// 保存原来的 cout buffer
std::streambuf* old_cout_buf = std::cout.rdbuf();

// 创建一个 stringstream 来捕获输出
std::stringstream captured_output;

// 重定向 cout 到 stringstream
std::cout.rdbuf(captured_output.rdbuf());

Comment on lines +82 to +90
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

captureStdout uses std::stringstream but file_manager.cpp does not include <sstream>, which will fail to compile on standard-conforming toolchains. Add the missing header (and include <stdexcept> explicitly if relying on std::runtime_error).

Copilot uses AI. Check for mistakes.
Comment on lines +82 to +90
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redirecting std::cout via std::cout.rdbuf(...) is a process-global side effect and is not made safe by the per-instance mutex_. If tests (or other code) run concurrently, output from other threads/tests may be captured or disrupted. Consider guarding the redirection with a single global/static mutex (and keep the critical section as small as possible) or using a dedicated logging/capture mechanism that avoids global std::cout redirection.

Copilot uses AI. Check for mistakes.
try {
// 执行函数
func();

// 恢复 cout
Comment on lines +74 to +95
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

captureStdout holds mutex_ with a unique_lock while executing the provided callback. If the callback calls any FileManerger method (e.g., operator<<, saveFile()), it will attempt to lock mutex_ again and deadlock. Consider releasing the lock before invoking func() and only locking around access to file_stream_ (or document/enforce that the callback must not call back into FileManerger).

Copilot uses AI. Check for mistakes.
std::cout.rdbuf(old_cout_buf);

// 将捕获的输出写入文件
file_stream_ << captured_output.str();
} catch (...) {
// 确保恢复 cout
std::cout.rdbuf(old_cout_buf);
throw;
}
}
} // namespace paddle_api_test
4 changes: 4 additions & 0 deletions src/file_manager.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <fstream>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <string>
Expand All @@ -17,6 +18,9 @@ class FileManerger {
FileManerger& operator<<(const std::string& str);
void saveFile();

// 捕获标准输出到文件
void captureStdout(std::function<void()> func);

private:
mutable std::shared_mutex mutex_;
std::string basic_path_ = "/tmp/paddle_cpp_api_test/";
Expand Down
17 changes: 17 additions & 0 deletions test/TensorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,23 @@ TEST_F(TensorTest, PinMemoryResult) {
file.saveFile();
}

// 测试 sym_size
TEST_F(TensorTest, SymSize) {
// 获取符号化的单个维度大小
c10::SymInt sym_size_0 = tensor.sym_size(0);
c10::SymInt sym_size_1 = tensor.sym_size(1);
c10::SymInt sym_size_2 = tensor.sym_size(2);

// 验证符号化大小与实际大小一致
EXPECT_EQ(sym_size_0, 2);
EXPECT_EQ(sym_size_1, 3);
EXPECT_EQ(sym_size_2, 4);

// 测试负索引
c10::SymInt sym_size_neg1 = tensor.sym_size(-1);
EXPECT_EQ(sym_size_neg1, 4);
}

// 测试 sym_stride
TEST_F(TensorTest, SymStride) {
// 获取符号化的单个维度步长
Expand Down
132 changes: 132 additions & 0 deletions test/TensorUtilTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/ops/ones.h>
#include <gtest/gtest.h>
#include <torch/all.h>

#include <string>
#include <vector>

#include "../src/file_manager.h"

extern paddle_api_test::ThreadSafeParam g_custom_param;

namespace at {
namespace test {

using paddle_api_test::FileManerger;
using paddle_api_test::ThreadSafeParam;

class TensorUtilTest : public ::testing::Test {
protected:
void SetUp() override {
std::vector<int64_t> shape = {2, 3, 4};
tensor = at::ones(shape, at::kFloat);
}

at::Tensor tensor;
};

// 测试 toString
TEST_F(TensorUtilTest, ToString) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
std::string tensor_str = tensor.toString();
file << tensor_str << " ";
file.saveFile();
Comment on lines +31 to +37
Copy link

Copilot AI Feb 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All tests in this file write to the same result filename (g_custom_param.get() == executable name) and each test calls createFile(), which truncates/removes the file. When the test binary runs with RUN_ALL_TESTS(), only the last test's output will remain, so the earlier API checks (e.g., toString, is_same, use_count) won’t be reflected in the compared result file. Consider consolidating these checks into a single TEST_F, or switch to a per-test output filename scheme and update the result comparison workflow accordingly.

Copilot uses AI. Check for mistakes.
}

// 测试 is_contiguous_or_false
TEST_F(TensorUtilTest, IsContiguousOrFalse) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();
file << std::to_string(tensor.is_contiguous_or_false()) << " ";

// 测试非连续的tensor
at::Tensor transposed = tensor.transpose(0, 2);
file << std::to_string(transposed.is_contiguous_or_false()) << " ";
file.saveFile();
}

// 测试 is_same
TEST_F(TensorUtilTest, IsSame) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();

// Test that tensor is same as itself
file << std::to_string(tensor.is_same(tensor)) << " ";

// Test that two different tensors are not the same
at::Tensor other_tensor = at::ones({2, 3, 4}, at::kFloat);
file << std::to_string(tensor.is_same(other_tensor)) << " ";

// Test that a shallow copy points to the same tensor
at::Tensor shallow_copy = tensor;
file << std::to_string(tensor.is_same(shallow_copy)) << " ";

// Test that a view of the tensor
at::Tensor view = tensor.view({24});
file << std::to_string(tensor.is_same(view)) << " ";
file.saveFile();
}

// 测试 use_count
TEST_F(TensorUtilTest, UseCount) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();

// Get initial use count
size_t initial_count = tensor.use_count();
file << std::to_string(initial_count) << " ";

// Create a copy, should increase use count
{
at::Tensor copy = tensor;
size_t new_count = tensor.use_count();
file << std::to_string(new_count) << " ";
file << std::to_string(new_count - initial_count) << " "; // 差值
}

// After copy goes out of scope, use count should decrease
size_t final_count = tensor.use_count();
file << std::to_string(final_count) << " ";
file.saveFile();
}

// 测试 weak_use_count
TEST_F(TensorUtilTest, WeakUseCount) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();

// Get initial weak use count
size_t initial_weak_count = tensor.weak_use_count();
file << std::to_string(initial_weak_count) << " ";
file.saveFile();
}

// 测试 print
TEST_F(TensorUtilTest, Print) {
auto file_name = g_custom_param.get();
FileManerger file(file_name);
file.createFile();

// 创建一个小的tensor用于print测试
at::Tensor small_tensor = at::ones({2, 2}, at::kFloat);

// 使用 captureStdout 捕获 print() 的输出
file.captureStdout([&]() {
tensor.print();
small_tensor.print();
});

file << std::to_string(1) << " "; // 如果执行到这里说明print()没有崩溃
file.saveFile();
}

} // namespace test
} // namespace at