From 4beb2688b09077a0751de8c46474827aa15e8bea Mon Sep 17 00:00:00 2001 From: Le-soleile <3516093767@qq.com> Date: Mon, 2 Feb 2026 18:22:39 +0800 Subject: [PATCH] add all test --- test/TensorTest.cpp | 62 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/test/TensorTest.cpp b/test/TensorTest.cpp index acfcdaf..4e2fdb3 100644 --- a/test/TensorTest.cpp +++ b/test/TensorTest.cpp @@ -213,5 +213,67 @@ TEST_F(TensorTest, Transpose) { file.saveFile(); } +// 测试 all() +TEST_F(TensorTest, AllNoDim) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + std::vector shape = {2, 3}; + at::Tensor t = at::ones(shape, at::kFloat); + at::Tensor out = t.all(); + file << std::to_string(out.dim()) << " "; + file << std::to_string(static_cast(out.scalar_type())) << " "; + file << std::to_string(static_cast(out.data_ptr()[0])) << " "; + file.saveFile(); +} + +// 测试 all(dim, keepdim) +TEST_F(TensorTest, AllDim) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + std::vector shape = {2, 3}; + at::Tensor t = at::ones(shape, at::kFloat); + at::Tensor out = t.all(0, false); + file << std::to_string(out.dim()) << " "; + file << std::to_string(out.size(0)) << " "; + bool* p = out.data_ptr(); + for (int64_t i = 0; i < out.numel(); ++i) { + file << std::to_string(static_cast(p[i])) << " "; + } + file.saveFile(); +} + +// 测试 all(dim, keepdim=true) +TEST_F(TensorTest, AllOptDim) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + std::vector shape = {2, 3}; + at::Tensor t = at::ones(shape, at::kFloat); + at::Tensor out = t.all(1, true); + file << std::to_string(out.dim()) << " "; + file << std::to_string(out.size(0)) << " "; + file << std::to_string(out.size(1)) << " "; + file.saveFile(); +} + +// 测试 allclose +TEST_F(TensorTest, Allclose) { + auto file_name = g_custom_param.get(); + FileManerger file(file_name); + file.createFile(); + std::vector shape = {2, 3}; + at::Tensor a = at::ones(shape, at::kFloat); + at::Tensor b = at::ones(shape, at::kFloat); + bool result = a.allclose(b, 1e-5, 1e-8, false); + file << std::to_string(static_cast(result)) << " "; + at::Tensor c = at::ones(shape, at::kFloat); + c.data_ptr()[0] = 100.0f; + bool result2 = a.allclose(c, 1e-5, 1e-8, false); + file << std::to_string(static_cast(result2)) << " "; + file.saveFile(); +} + } // namespace test } // namespace at