Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions test/TensorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,32 @@ TEST_F(TensorTest, IsCuda) {
EXPECT_FALSE(tensor.is_cuda());
}

// 测试 is_sparse
TEST_F(TensorTest, IsSparse) {
// 密集张量应该返回 false
EXPECT_FALSE(tensor.is_sparse());

// 创建稀疏 COO 张量 - 先创建模板,再使用 zeros_like
at::TensorOptions sparse_options =
at::TensorOptions().dtype(at::kFloat).layout(at::kSparse);
at::Tensor sparse_template = at::empty({2, 3}, sparse_options);
at::Tensor sparse_tensor = at::zeros_like(sparse_template);
EXPECT_TRUE(sparse_tensor.is_sparse());
}

// 测试 is_sparse_csr
TEST_F(TensorTest, IsSparseCsr) {
// 密集张量应该返回 false
EXPECT_FALSE(tensor.is_sparse_csr());

// 创建稀疏 CSR 张量 - 先创建模板,再使用 zeros_like
at::TensorOptions sparse_csr_options =
at::TensorOptions().dtype(at::kFloat).layout(at::kSparseCsr);
at::Tensor sparse_csr_template = at::empty({2, 3}, sparse_csr_options);
at::Tensor sparse_csr_tensor = at::zeros_like(sparse_csr_template);
EXPECT_TRUE(sparse_csr_tensor.is_sparse_csr());
}

// 测试 reshape
TEST_F(TensorTest, Reshape) {
// Tensor tensor(paddle_tensor_);
Expand Down