1
+ #include < nuTens/tensors/tensor.hpp>
2
+ #include < nuTens/tensors/dtypes.hpp>
3
+ #include < complex.h>
4
+
5
+ /*
6
+ Do some very basic tests of tensor functionality
7
+ e.g. test that complex matrices work as expected, 1+1 == 2 etc.
8
+ */
9
+
10
+ int main (){
11
+ std::cout << " Tensor library: " << Tensor::getTensorLibrary () << std::endl;
12
+
13
+ std::cout << " ########################################" << std::endl;
14
+ std::cout << " Float: " << std::endl;
15
+ Tensor tensorFloat;
16
+ tensorFloat.zeros ({3 , 3 }, NTdtypes::kDouble ).dType (NTdtypes::kFloat ).device (NTdtypes::kCPU );
17
+ tensorFloat.setValue ({0 ,0 }, 0.0 );
18
+ tensorFloat.setValue ({0 ,1 }, 1.0 );
19
+ tensorFloat.setValue ({0 ,2 }, 2.0 );
20
+
21
+ tensorFloat.setValue ({1 ,0 }, 3.0 );
22
+ tensorFloat.setValue ({1 ,1 }, 4.0 );
23
+ tensorFloat.setValue ({1 ,2 }, 5.0 );
24
+
25
+ tensorFloat.setValue ({2 ,0 }, 6.0 );
26
+ tensorFloat.setValue ({2 ,1 }, 7.0 );
27
+ tensorFloat.setValue ({2 ,2 }, 8.0 );
28
+ std::cout << " real: " << std::endl << tensorFloat.real () << std::endl;
29
+ std::cout << " Middle value: " << tensorFloat.getValue <float >({1 ,1 }) << std::endl;
30
+
31
+ Tensor realSquared = Tensor::matmul (tensorFloat, tensorFloat);
32
+ std::cout << " Squared: " << std::endl;
33
+ std::cout << realSquared << std::endl;
34
+ std::cout << " ########################################" << std::endl << std::endl;
35
+
36
+
37
+ std::cout << " ########################################" << std::endl;
38
+ std::cout << " Complex float: " << std::endl;
39
+ Tensor tensorComplex;
40
+ tensorComplex.zeros ({3 , 3 }, NTdtypes::kComplexFloat );
41
+ tensorComplex.setValue ({0 ,0 }, std::complex<float >(0 .0j));
42
+ tensorComplex.setValue ({0 ,1 }, std::complex<float >(1 .0j));
43
+ tensorComplex.setValue ({0 ,2 }, std::complex<float >(2 .0j));
44
+
45
+ tensorComplex.setValue ({1 ,0 }, std::complex<float >(3 .0j));
46
+ tensorComplex.setValue ({1 ,1 }, std::complex<float >(4 .0j));
47
+ tensorComplex.setValue ({1 ,2 }, std::complex<float >(5 .0j));
48
+
49
+ tensorComplex.setValue ({2 ,0 }, std::complex<float >(6 .0j));
50
+ tensorComplex.setValue ({2 ,1 }, std::complex<float >(7 .0j));
51
+ tensorComplex.setValue ({2 ,2 }, std::complex<float >(8 .0j));
52
+
53
+ std::cout << " real: " << std::endl << tensorComplex.real () << std::endl;
54
+ std::cout << " imag: " << std::endl << tensorComplex.imag () << std::endl << std::endl;
55
+
56
+ Tensor imagSquared = Tensor::matmul (tensorComplex, tensorComplex);
57
+ std::cout << " Squared: " << std::endl;
58
+ std::cout << imagSquared << std::endl;
59
+ std::cout << " ########################################" << std::endl << std::endl;
60
+
61
+ // check if the real matrix squared is equal to the -ve of the imaginary one squared
62
+ if ( realSquared != - imagSquared.real ()){
63
+ std::cerr << std::endl;
64
+ std::cerr << " real**2 != -imaginary**2" << std::endl;
65
+ std::cerr << std::endl;
66
+ return 1 ;
67
+ }
68
+
69
+ Tensor ones;
70
+ ones.ones ({3 ,3 }, NTdtypes::kFloat );
71
+ Tensor twos = ones + ones;
72
+
73
+ std::cout << " ones + ones: " << std::endl;
74
+ std::cout << twos << std::endl;
75
+
76
+ // check that adding works
77
+ if ( twos.getValue <float >({1 , 1 }) != 2.0 ){
78
+ std::cerr << std::endl;
79
+ std::cerr << " ERROR: 1 + 1 != 2 !!!!" << std::endl;
80
+ std::cerr << std::endl;
81
+ return 1 ;
82
+ }
83
+
84
+ }
0 commit comments