From 6a7a8f3a42d42aad5c85191f062f7168a8010b5d Mon Sep 17 00:00:00 2001 From: youge325 Date: Thu, 22 Jan 2026 10:42:41 +0800 Subject: [PATCH] add sparse related API tests --- test/TensorTest.cpp | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/TensorTest.cpp b/test/TensorTest.cpp index 5868349..aa0d450 100644 --- a/test/TensorTest.cpp +++ b/test/TensorTest.cpp @@ -151,6 +151,32 @@ TEST_F(TensorTest, IsCuda) { EXPECT_FALSE(tensor.is_cuda()); } +// 测试 is_sparse +TEST_F(TensorTest, IsSparse) { + // 密集张量应该返回 false + EXPECT_FALSE(tensor.is_sparse()); + + // 创建稀疏 COO 张量 - 先创建模板,再使用 zeros_like + at::TensorOptions sparse_options = + at::TensorOptions().dtype(at::kFloat).layout(at::kSparse); + at::Tensor sparse_template = at::empty({2, 3}, sparse_options); + at::Tensor sparse_tensor = at::zeros_like(sparse_template); + EXPECT_TRUE(sparse_tensor.is_sparse()); +} + +// 测试 is_sparse_csr +TEST_F(TensorTest, IsSparseCsr) { + // 密集张量应该返回 false + EXPECT_FALSE(tensor.is_sparse_csr()); + + // 创建稀疏 CSR 张量 - 先创建模板,再使用 zeros_like + at::TensorOptions sparse_csr_options = + at::TensorOptions().dtype(at::kFloat).layout(at::kSparseCsr); + at::Tensor sparse_csr_template = at::empty({2, 3}, sparse_csr_options); + at::Tensor sparse_csr_tensor = at::zeros_like(sparse_csr_template); + EXPECT_TRUE(sparse_csr_tensor.is_sparse_csr()); +} + // 测试 reshape TEST_F(TensorTest, Reshape) { // Tensor tensor(paddle_tensor_);