diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrMatMultTests.java b/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrMatMultTests.java index 747b26f29..125c38e05 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrMatMultTests.java +++ b/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrMatMultTests.java @@ -185,6 +185,8 @@ void standardAsSparseTests() { expCsrColIndices = new int[]{1, 4, 10, 12, 13, 1, 6, 10, 12, 14, 7, 7, 2, 6, 14, 1, 2, 3, 7, 8, 12, 1, 2, 3, 8, 9, 6, 8, 9, 12}; expCsr = new CsrCMatrix(expCsrShape, expCsrEntries, expCsrRowPointers, expCsrColIndices); + assertEquals(expCsr, A.mult2CSR(B)); + // ---------------------- Sub-case 2 ---------------------- aShape = new Shape(1201, 502); aEntries = new CNumber[]{new CNumber(0.80192, 0.00457), new CNumber(0.75879, 0.96137), new CNumber(0.7197, 0.79965), new CNumber(0.27629, 0.72594), new CNumber(0.49392, 0.66413), new CNumber(0.4067, 0.26811), new CNumber(0.10677, 0.6162), new CNumber(0.07077, 0.41851), new CNumber(0.50209, 0.29478), new CNumber(0.40524, 0.70039), new CNumber(0.30549, 0.65839), new CNumber(0.60399, 0.91417), new CNumber(0.3789, 0.59518), new CNumber(0.28813, 0.22652), new CNumber(0.95052, 0.01256), new CNumber(0.30087, 0.23646), new CNumber(0.92451, 0.7899), new CNumber(0.7624, 0.63808), new CNumber(0.79608, 0.04811), new CNumber(0.1743, 0.94944), new CNumber(0.00145, 0.56106), new CNumber(0.25729, 0.03609), new CNumber(0.72895, 0.64931), new CNumber(0.47464, 0.41308), new CNumber(0.71833, 0.15304), new CNumber(0.98708, 0.98231), new CNumber(0.64243, 0.95304), new CNumber(0.90844, 0.90654), new CNumber(0.11218, 0.33391), new CNumber(0.4905, 0.98171), new CNumber(0.81564, 0.68582), new CNumber(0.1198, 0.38116), new CNumber(0.33929, 0.50226), new CNumber(0.54688, 0.69092), new CNumber(0.94487, 0.4582), new CNumber(0.4023, 0.02026), new CNumber(0.2841, 0.29161), new CNumber(0.40666, 0.46315), new CNumber(0.47243, 0.5369), new CNumber(0.46895, 0.87651), new CNumber(0.17839, 0.12726), new CNumber(0.3982, 0.43276), new CNumber(0.94624, 0.59931), new CNumber(0.55679, 0.95208), new CNumber(0.30977, 0.8028), new CNumber(0.31346, 0.2119), new CNumber(0.56704, 0.88278), new CNumber(0.94851, 0.51192), new CNumber(0.13748, 0.13437), new CNumber(0.03488, 0.6052), new CNumber(0.15132, 0.81863), new CNumber(0.09701, 0.64024), new CNumber(0.63921, 0.09402), new CNumber(0.14725, 0.86802), new CNumber(0.74567, 0.92625), new CNumber(0.71516, 0.17888), new CNumber(0.76541, 0.02061), new CNumber(0.90628, 0.80481), new CNumber(0.40745, 0.41716), new CNumber(0.97685, 0.29665)}; @@ -204,7 +206,31 @@ void standardAsSparseTests() { expCsrColIndices = new int[]{624, 498, 343, 1514, 230, 1081, 821, 716, 651, 1138, 1485}; expCsr = new CsrCMatrix(expCsrShape, expCsrEntries, expCsrRowPointers, expCsrColIndices); + assertEquals(expCsr, A.mult2CSR(B)); + // ---------------------- Sub-case 3 ---------------------- + aShape = new Shape(12, 15); + aEntries = new CNumber[]{new CNumber(0.22316, 0.35562), new CNumber(0.94242, 0.37764), new CNumber(0.6349, 0.4183), new CNumber(0.98095, 0.96738), new CNumber(0.79842, 0.03641), new CNumber(0.49714, 0.20468), new CNumber(0.44936, 0.37731), new CNumber(0.82674, 0.72264), new CNumber(0.29454, 0.8007), new CNumber(0.64223, 0.46247), new CNumber(0.24694, 0.87524), new CNumber(0.0005, 0.42085), new CNumber(0.2366, 0.44245), new CNumber(0.9721, 0.0875), new CNumber(0.55155, 0.62224), new CNumber(0.76339, 0.05573), new CNumber(0.57728, 0.45146), new CNumber(0.87833, 0.75748)}; + aRowPointers = new int[]{0, 2, 3, 4, 5, 8, 9, 10, 12, 14, 16, 17, 18}; + aColIndices = new int[]{5, 11, 11, 7, 4, 2, 7, 13, 7, 2, 3, 7, 4, 7, 3, 10, 7, 14}; + A = new CsrCMatrix(aShape, aEntries, aRowPointers, aColIndices); + + bShape = new Shape(15, 3); + bRealEntries = new double[]{0.8897510865291638, 0.02884886126710373, 0.9073269737206647, 0.9197691430889076, 0.6633399541563916 + , 0.7645310847724797, 0.6949989054896308, 0.04413656628380547, 0.13411817488702715}; + bRowPointers = new int[]{0, 1, 1, 1, 3, 3, 3, 4, 6, 8, 8, 8, 8, 8, 8, 9}; + bColIndices = new int[]{0, 0, 2, 2, 1, 2, 0, 2, 0}; + BReal = new CsrMatrix(bShape, bRealEntries, bRowPointers, bColIndices); + + expCsrShape = new Shape(12, 3); + expCsrEntries = new CNumber[]{new CNumber(0.6507033280297123, 0.64170180485181), new CNumber(0.7499667676075639, 0.7395920807872014), new CNumber(0.2980784417997161, 0.2502847981027481), new CNumber(0.34354968825336146, 0.2884652235955043), new CNumber(0.1953801500972236, 0.5311363012930227), new CNumber(0.22518498570888618, 0.6121600395773245), new CNumber(0.0071239378012985955, 0.02524967733541987), new CNumber(0.0003316699770781958, 0.2791666197067174), new CNumber(0.22443758843296718, 1.1158817675057726), new CNumber(0.6448327694354282, 0.05804224598868426), new CNumber(0.7432006675073275, 0.06689646991759196), new CNumber(0.015911589431871064, 0.017950915434842625), new CNumber(0.5004361923556326, 0.5645751361279464), new CNumber(0.38293288873540177, 0.2994714557034445), new CNumber(0.44134850461745706, 0.3451552035313836), new CNumber(0.11780001654852257, 0.10159183511342533)}; + expCsrRowPointers = new int[]{0, 0, 0, 2, 2, 4, 6, 6, 9, 11, 13, 15, 16}; + expCsrColIndices = new int[]{1, 2, 1, 2, 1, 2, 0, 1, 2, 1, 2, 0, 2, 1, 2, 0}; + expCsr = new CsrCMatrix(expCsrShape, expCsrEntries, expCsrRowPointers, expCsrColIndices); + + assertEquals(expCsr, A.mult2CSR(BReal)); + + // ---------------------- Sub-case 4 ---------------------- A = new CsrCMatrix(10, 15); B = new CsrCMatrix(11, 124); assertThrows(LinearAlgebraException.class, ()->A.mult2CSR(B)); @@ -213,7 +239,7 @@ void standardAsSparseTests() { B = new CsrCMatrix(11, 156); assertThrows(LinearAlgebraException.class, ()->A.mult2CSR(B)); - // ---------------------- Sub-case 4 ---------------------- + // ---------------------- Sub-case 5 ---------------------- A = new CsrCMatrix(10, 15); BReal = new CsrMatrix(11, 124); assertThrows(LinearAlgebraException.class, ()->A.mult2CSR(BReal));