Skip to content

Commit 55f9c81

Browse files
committed
add some very basic testing of the tensor library
1 parent 9dccd64 commit 55f9c81

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
22
project(nuTens)
3+
enable_testing()
34

45
find_package(Torch REQUIRED)
56
find_package(Protobuf REQUIRED)

tests/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
add_executable(basic basic.cpp)
3+
target_link_libraries(basic PUBLIC tensor)
4+
target_include_directories(basic PUBLIC "${CMAKE_SOURCE_DIR}")
5+
6+
add_test(NAME basicTest COMMAND basic)

tests/basic.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#include <nuTens/tensors/tensor.hpp>
2+
#include <nuTens/tensors/dtypes.hpp>
3+
#include <complex.h>
4+
5+
/*
6+
Do some very basic tests of tensor functionality
7+
e.g. test that complex matrices work as expected, 1+1 == 2 etc.
8+
*/
9+
10+
int main(){
11+
std::cout << "Tensor library: " << Tensor::getTensorLibrary() << std::endl;
12+
13+
std::cout << "########################################" << std::endl;
14+
std::cout << "Float: " << std::endl;
15+
Tensor tensorFloat;
16+
tensorFloat.zeros({3, 3}, NTdtypes::kDouble).dType(NTdtypes::kFloat).device(NTdtypes::kCPU);
17+
tensorFloat.setValue({0,0}, 0.0);
18+
tensorFloat.setValue({0,1}, 1.0);
19+
tensorFloat.setValue({0,2}, 2.0);
20+
21+
tensorFloat.setValue({1,0}, 3.0);
22+
tensorFloat.setValue({1,1}, 4.0);
23+
tensorFloat.setValue({1,2}, 5.0);
24+
25+
tensorFloat.setValue({2,0}, 6.0);
26+
tensorFloat.setValue({2,1}, 7.0);
27+
tensorFloat.setValue({2,2}, 8.0);
28+
std::cout << "real: " << std::endl << tensorFloat.real() << std::endl;
29+
std::cout << "Middle value: " << tensorFloat.getValue<float>({1,1}) << std::endl;
30+
31+
Tensor realSquared = Tensor::matmul(tensorFloat, tensorFloat);
32+
std::cout << "Squared: " << std::endl;
33+
std::cout << realSquared << std::endl;
34+
std::cout << "########################################" << std::endl << std::endl;
35+
36+
37+
std::cout << "########################################" << std::endl;
38+
std::cout << "Complex float: " << std::endl;
39+
Tensor tensorComplex;
40+
tensorComplex.zeros({3, 3}, NTdtypes::kComplexFloat);
41+
tensorComplex.setValue({0,0}, std::complex<float>(0.0j));
42+
tensorComplex.setValue({0,1}, std::complex<float>(1.0j));
43+
tensorComplex.setValue({0,2}, std::complex<float>(2.0j));
44+
45+
tensorComplex.setValue({1,0}, std::complex<float>(3.0j));
46+
tensorComplex.setValue({1,1}, std::complex<float>(4.0j));
47+
tensorComplex.setValue({1,2}, std::complex<float>(5.0j));
48+
49+
tensorComplex.setValue({2,0}, std::complex<float>(6.0j));
50+
tensorComplex.setValue({2,1}, std::complex<float>(7.0j));
51+
tensorComplex.setValue({2,2}, std::complex<float>(8.0j));
52+
53+
std::cout << "real: " << std::endl << tensorComplex.real() << std::endl;
54+
std::cout << "imag: " << std::endl << tensorComplex.imag() << std::endl << std::endl;
55+
56+
Tensor imagSquared = Tensor::matmul(tensorComplex, tensorComplex);
57+
std::cout << "Squared: " << std::endl;
58+
std::cout << imagSquared << std::endl;
59+
std::cout << "########################################" << std::endl << std::endl;
60+
61+
// check if the real matrix squared is equal to the -ve of the imaginary one squared
62+
if( realSquared != - imagSquared.real()){
63+
std::cerr << std::endl;
64+
std::cerr << "real**2 != -imaginary**2" << std::endl;
65+
std::cerr << std::endl;
66+
return 1;
67+
}
68+
69+
Tensor ones;
70+
ones.ones({3,3}, NTdtypes::kFloat);
71+
Tensor twos = ones + ones;
72+
73+
std::cout << "ones + ones: " << std::endl;
74+
std::cout << twos << std::endl;
75+
76+
// check that adding works
77+
if( twos.getValue<float>({1, 1}) != 2.0 ){
78+
std::cerr << std::endl;
79+
std::cerr << "ERROR: 1 + 1 != 2 !!!!" << std::endl;
80+
std::cerr << std::endl;
81+
return 1;
82+
}
83+
84+
}

0 commit comments

Comments
 (0)