diff --git a/test/TensorTest.cpp b/test/TensorTest.cpp index 5868349..aa0d450 100644 --- a/test/TensorTest.cpp +++ b/test/TensorTest.cpp @@ -151,6 +151,32 @@ TEST_F(TensorTest, IsCuda) { EXPECT_FALSE(tensor.is_cuda()); } +// 测试 is_sparse +TEST_F(TensorTest, IsSparse) { + // 密集张量应该返回 false + EXPECT_FALSE(tensor.is_sparse()); + + // 创建稀疏 COO 张量 - 先创建模板,再使用 zeros_like + at::TensorOptions sparse_options = + at::TensorOptions().dtype(at::kFloat).layout(at::kSparse); + at::Tensor sparse_template = at::empty({2, 3}, sparse_options); + at::Tensor sparse_tensor = at::zeros_like(sparse_template); + EXPECT_TRUE(sparse_tensor.is_sparse()); +} + +// 测试 is_sparse_csr +TEST_F(TensorTest, IsSparseCsr) { + // 密集张量应该返回 false + EXPECT_FALSE(tensor.is_sparse_csr()); + + // 创建稀疏 CSR 张量 - 先创建模板,再使用 zeros_like + at::TensorOptions sparse_csr_options = + at::TensorOptions().dtype(at::kFloat).layout(at::kSparseCsr); + at::Tensor sparse_csr_template = at::empty({2, 3}, sparse_csr_options); + at::Tensor sparse_csr_tensor = at::zeros_like(sparse_csr_template); + EXPECT_TRUE(sparse_csr_tensor.is_sparse_csr()); +} + // 测试 reshape TEST_F(TensorTest, Reshape) { // Tensor tensor(paddle_tensor_);