Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions test/TensorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,23 @@
#include <ATen/core/Tensor.h>
#include <ATen/ops/ones.h>
#include <gtest/gtest.h>
#if !USE_PADDLE_API
#include <torch/all.h>
#endif

#include <string>
#include <vector>
#if USE_PADDLE_API
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose of the compat API is to be fully compatible with the torch headers; there is no need to add an extra macro here.

#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/memory/malloc.h"
namespace phi {
// Stream-insertion overload for AllocationType, presumably so gtest can
// print the enum in assertion-failure messages — TODO confirm it is not
// already provided by the paddle headers. Prints the underlying integer.
inline std::ostream& operator<<(std::ostream& os, AllocationType type) {
return os << static_cast<int>(type);
}
} // namespace phi
#endif

#include "../src/file_manager.h"

Expand Down Expand Up @@ -213,5 +226,66 @@ TEST_F(TensorTest, Transpose) {
file.saveFile();
}

// Tests var(bool unbiased): full-tensor variance with Bessel's correction
// (unbiased=true, divides by n-1) and without it (unbiased=false, divides
// by n). Records result rank and value to the golden file for comparison.
TEST_F(TensorTest, VarUnbiased) {
  auto file_name = g_custom_param.get();
  FileManerger file(file_name);
  file.createFile();
  std::vector<int64_t> shape = {2, 3};
  at::Tensor test_tensor = at::ones(shape, at::kFloat);
  // Fill with 1..6 — same fill loop as the other var tests, replacing the
  // six copy-pasted element assignments.
  for (int i = 0; i < 6; ++i) {
    test_tensor.data_ptr<float>()[i] = static_cast<float>(i + 1);
  }
  // Unbiased (sample) variance; result is a scalar (0-dim) tensor.
  at::Tensor var_result = test_tensor.var(true);
  file << std::to_string(var_result.dim()) << " ";
  file << std::to_string(var_result.data_ptr<float>()[0]) << " ";
  // Biased (population) variance.
  at::Tensor var_result_biased = test_tensor.var(false);
  file << std::to_string(var_result_biased.dim()) << " ";
  file << std::to_string(var_result_biased.data_ptr<float>()[0]) << " ";
  file.saveFile();
}

// Tests var(OptionalIntArrayRef dim, bool unbiased, bool keepdim):
// variance reduced along a chosen dimension, both with the reduced
// dimension dropped and with it kept as size 1.
TEST_F(TensorTest, VarDim) {
  auto file_name = g_custom_param.get();
  FileManerger file(file_name);
  file.createFile();
  std::vector<int64_t> shape = {2, 3};
  at::Tensor input = at::ones(shape, at::kFloat);
  // Fill with 1..6 through a cached raw pointer.
  float* data = input.data_ptr<float>();
  for (int idx = 0; idx < 6; ++idx) {
    data[idx] = static_cast<float>(idx + 1);
  }
  // Reduce along dim 0 without keepdim: rank drops from 2 to 1.
  at::Tensor reduced = input.var({0}, true, false);
  file << std::to_string(reduced.dim()) << " ";
  file << std::to_string(reduced.size(0)) << " ";
  // Reduce along dim 1 with keepdim: rank stays 2, reduced dim becomes 1.
  at::Tensor kept = input.var({1}, true, true);
  file << std::to_string(kept.dim()) << " ";
  file << std::to_string(kept.size(0)) << " ";
  file << std::to_string(kept.size(1)) << " ";
  file.saveFile();
}

// Tests var(OptionalIntArrayRef dim, optional<Scalar> correction,
// bool keepdim): variance along dim 0 with correction=1 (sample) and
// correction=0 (population). Records result rank and size.
TEST_F(TensorTest, VarCorrection) {
  auto file_name = g_custom_param.get();
  FileManerger file(file_name);
  file.createFile();
  std::vector<int64_t> shape = {2, 3};
  at::Tensor input = at::ones(shape, at::kFloat);
  float* data = input.data_ptr<float>();
  for (int idx = 0; idx < 6; ++idx) {
    data[idx] = static_cast<float>(idx + 1);
  }
  // correction = 1: equivalent to the unbiased (n-1 denominator) variance.
  at::Tensor sample_var = input.var({0}, at::Scalar(1.0), false);
  file << std::to_string(sample_var.dim()) << " ";
  file << std::to_string(sample_var.size(0)) << " ";
  // correction = 0: population (n denominator) variance.
  at::Tensor population_var = input.var({0}, at::Scalar(0.0), false);
  file << std::to_string(population_var.dim()) << " ";
  file << std::to_string(population_var.size(0)) << " ";
  file.saveFile();
}

} // namespace test
} // namespace at