diff --git a/docs/flag4j/org/flag4j/linalg/operations/dense/real/RealDenseMatrixMultTranspose.html b/docs/flag4j/org/flag4j/linalg/operations/dense/real/RealDenseMatrixMultTranspose.html index 3b64e931f..51a016244 100644 --- a/docs/flag4j/org/flag4j/linalg/operations/dense/real/RealDenseMatrixMultTranspose.html +++ b/docs/flag4j/org/flag4j/linalg/operations/dense/real/RealDenseMatrixMultTranspose.html @@ -86,7 +86,7 @@

Class RealDenseMatrixMultTranspose

java.lang.Object -
org.flag4j.linalg.ops.dense.real.RealDenseMatrixMultTranspose
+
org.flag4j.linalg.ops.dense.real.RealDenseMatMultTranspose
diff --git a/docs/flag4j/org/flag4j/linalg/operations/dense/real/RealDenseMatrixMultiplication.html b/docs/flag4j/org/flag4j/linalg/operations/dense/real/RealDenseMatrixMultiplication.html index fba3f1650..f542e0c95 100644 --- a/docs/flag4j/org/flag4j/linalg/operations/dense/real/RealDenseMatrixMultiplication.html +++ b/docs/flag4j/org/flag4j/linalg/operations/dense/real/RealDenseMatrixMultiplication.html @@ -86,7 +86,7 @@

Class RealDenseMatrixMultiplication

java.lang.Object -
org.flag4j.linalg.ops.dense.real.RealDenseMatrixMultiplication
+
org.flag4j.linalg.ops.dense.real.RealDenseMatMult
diff --git a/docs/flag4j/org/flag4j/linalg/operations/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.html b/docs/flag4j/org/flag4j/linalg/operations/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.html index 2d69b5d84..21317cda0 100644 --- a/docs/flag4j/org/flag4j/linalg/operations/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.html +++ b/docs/flag4j/org/flag4j/linalg/operations/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.html @@ -86,7 +86,7 @@

Class DenseCsrFieldMatMult

java.lang.Object -
org.flag4j.linalg.ops.dense_sparse.csr.field_ops.DenseCsrFieldMatMult
+
org.flag4j.linalg.ops.dense_sparse.csr.semiring_ops.DenseCsrSemiringMatMult
diff --git a/docs/flag4j/org/flag4j/linalg/operations/sparse/coo/field_ops/CooFieldNorms.html b/docs/flag4j/org/flag4j/linalg/operations/sparse/coo/field_ops/CooFieldNorms.html index c209d29ab..6663a0117 100644 --- a/docs/flag4j/org/flag4j/linalg/operations/sparse/coo/field_ops/CooFieldNorms.html +++ b/docs/flag4j/org/flag4j/linalg/operations/sparse/coo/field_ops/CooFieldNorms.html @@ -86,7 +86,7 @@

Class CooFieldNorms

java.lang.Object -
org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldNorms
+
org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingNorms
diff --git a/docs/flag4j/org/flag4j/linalg/operations/sparse/coo/real/RealSparseMatrixGetSet.html b/docs/flag4j/org/flag4j/linalg/operations/sparse/coo/real/RealSparseMatrixGetSet.html index 2a7ea5937..d80604f55 100644 --- a/docs/flag4j/org/flag4j/linalg/operations/sparse/coo/real/RealSparseMatrixGetSet.html +++ b/docs/flag4j/org/flag4j/linalg/operations/sparse/coo/real/RealSparseMatrixGetSet.html @@ -86,7 +86,7 @@

Class RealSparseMatrixGetSet

java.lang.Object -
org.flag4j.linalg.ops.sparse.coo.real.RealSparseMatrixGetSet
+
org.flag4j.linalg.ops.sparse.coo.real.RealCooMatrixGetSet
diff --git a/docs/flag4j/org/flag4j/linalg/operations/sparse/csr/field_ops/CsrFieldEquals.html b/docs/flag4j/org/flag4j/linalg/operations/sparse/csr/field_ops/CsrFieldEquals.html index 9761fe286..64bc69ee2 100644 --- a/docs/flag4j/org/flag4j/linalg/operations/sparse/csr/field_ops/CsrFieldEquals.html +++ b/docs/flag4j/org/flag4j/linalg/operations/sparse/csr/field_ops/CsrFieldEquals.html @@ -86,7 +86,7 @@

Class CsrFieldEquals

java.lang.Object -
org.flag4j.linalg.ops.sparse.csr.field_ops.CsrFieldEquals
+
org.flag4j.linalg.ops.sparse.csr.ring_ops.CsrRingProperties
diff --git a/docs/flag4j/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultTranspose.html b/docs/flag4j/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultTranspose.html index 14f83582a..dad610e4b 100644 --- a/docs/flag4j/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultTranspose.html +++ b/docs/flag4j/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultTranspose.html @@ -86,7 +86,7 @@

Class RealDenseMatrixMultTranspose

java.lang.Object -
org.flag4j.linalg.ops.dense.real.RealDenseMatrixMultTranspose
+
org.flag4j.linalg.ops.dense.real.RealDenseMatMultTranspose
diff --git a/docs/flag4j/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultiplication.html b/docs/flag4j/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultiplication.html index 51c23bc12..7386d4c4b 100644 --- a/docs/flag4j/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultiplication.html +++ b/docs/flag4j/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultiplication.html @@ -94,7 +94,7 @@

Class RealDenseMatrixMultiplication

java.lang.Object -
org.flag4j.linalg.ops.dense.real.RealDenseMatrixMultiplication
+
org.flag4j.linalg.ops.dense.real.RealDenseMatMult
diff --git a/docs/flag4j/org/flag4j/linalg/ops/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.html b/docs/flag4j/org/flag4j/linalg/ops/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.html index 1384563fd..5037b3c8a 100644 --- a/docs/flag4j/org/flag4j/linalg/ops/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.html +++ b/docs/flag4j/org/flag4j/linalg/ops/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.html @@ -85,7 +85,7 @@

Class DenseCsrFieldMatMult

java.lang.Object -
org.flag4j.linalg.ops.dense_sparse.csr.field_ops.DenseCsrFieldMatMult
+
org.flag4j.linalg.ops.dense_sparse.csr.semiring_ops.DenseCsrSemiringMatMult
diff --git a/docs/flag4j/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldNorms.html b/docs/flag4j/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldNorms.html index 5b8fc60f6..097c773b1 100644 --- a/docs/flag4j/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldNorms.html +++ b/docs/flag4j/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldNorms.html @@ -85,7 +85,7 @@

Class CooFieldNorms

java.lang.Object -
org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldNorms
+
org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingNorms
diff --git a/docs/flag4j/org/flag4j/linalg/ops/sparse/coo/real/RealSparseMatrixGetSet.html b/docs/flag4j/org/flag4j/linalg/ops/sparse/coo/real/RealSparseMatrixGetSet.html index de7e2325c..a0c3ece64 100644 --- a/docs/flag4j/org/flag4j/linalg/ops/sparse/coo/real/RealSparseMatrixGetSet.html +++ b/docs/flag4j/org/flag4j/linalg/ops/sparse/coo/real/RealSparseMatrixGetSet.html @@ -104,7 +104,7 @@

Class RealSparseMatrixGetSet

java.lang.Object -
org.flag4j.linalg.ops.sparse.coo.real.RealSparseMatrixGetSet
+
org.flag4j.linalg.ops.sparse.coo.real.RealCooMatrixGetSet
diff --git a/docs/flag4j/org/flag4j/linalg/ops/sparse/csr/field_ops/CsrFieldEquals.html b/docs/flag4j/org/flag4j/linalg/ops/sparse/csr/field_ops/CsrFieldEquals.html index 6f3c15e47..73d4e3c5f 100644 --- a/docs/flag4j/org/flag4j/linalg/ops/sparse/csr/field_ops/CsrFieldEquals.html +++ b/docs/flag4j/org/flag4j/linalg/ops/sparse/csr/field_ops/CsrFieldEquals.html @@ -84,7 +84,7 @@

Class CsrFieldEquals

java.lang.Object -
org.flag4j.linalg.ops.sparse.csr.field_ops.CsrFieldEquals
+
org.flag4j.linalg.ops.sparse.csr.ring_ops.CsrRingProperties
diff --git a/src/main/java/module-info.java b/src/main/java/module-info.java index a8cdeef66..7dd801dcd 100644 --- a/src/main/java/module-info.java +++ b/src/main/java/module-info.java @@ -3,6 +3,7 @@ */ module flag4j { requires java.logging; + requires java.desktop; // Abstract algebra stuff. exports org.flag4j.algebraic_structures; @@ -67,14 +68,12 @@ exports org.flag4j.linalg.ops.sparse; exports org.flag4j.linalg.ops.sparse.coo; - exports org.flag4j.linalg.ops.sparse.coo.field_ops; exports org.flag4j.linalg.ops.sparse.coo.real; exports org.flag4j.linalg.ops.sparse.coo.real_complex; exports org.flag4j.linalg.ops.sparse.coo.ring_ops; exports org.flag4j.linalg.ops.sparse.coo.semiring_ops; exports org.flag4j.linalg.ops.sparse.csr; - exports org.flag4j.linalg.ops.sparse.csr.field_ops; exports org.flag4j.linalg.ops.sparse.csr.real; exports org.flag4j.linalg.ops.sparse.csr.real_complex; // ------------------------------------------------------------------------ @@ -85,4 +84,7 @@ // Utilities exports org.flag4j.util; exports org.flag4j.util.exceptions; + exports org.flag4j.linalg.ops.sparse.csr.ring_ops; + exports org.flag4j.linalg.ops.dense_sparse.csr.semiring_ops; + exports org.flag4j.linalg.decompositions.balance; } \ No newline at end of file diff --git a/src/main/java/org/flag4j/algebraic_structures/Complex128.java b/src/main/java/org/flag4j/algebraic_structures/Complex128.java index bbc21da15..17375ec8b 100644 --- a/src/main/java/org/flag4j/algebraic_structures/Complex128.java +++ b/src/main/java/org/flag4j/algebraic_structures/Complex128.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -350,7 +350,7 @@ public Complex128 mult(Complex128 b) { */ @Override public boolean isZero() { - return equals(ZERO); + return re == 0.0 && im == 0.0; } @@ -363,7 +363,7 @@ public boolean isZero() { */ @Override public boolean isOne() { - return equals(ONE); + return re == 1.0 && im == 0.0; } @@ -398,6 +398,7 @@ public Complex128 getOne() { * @param b Second element in product. * @return The product of this field element and {@code b}. */ + @Override public Complex128 mult(double b) { return new Complex128(re*b, im*b); } diff --git a/src/main/java/org/flag4j/algebraic_structures/Complex64.java b/src/main/java/org/flag4j/algebraic_structures/Complex64.java index 3ba5ff626..e786e41aa 100644 --- a/src/main/java/org/flag4j/algebraic_structures/Complex64.java +++ b/src/main/java/org/flag4j/algebraic_structures/Complex64.java @@ -180,6 +180,15 @@ public Complex64(String num) { } + /** + * Checks if this complex has zero imaginary part and real part equal to a double. + * @return True if {@code this.re == b && this.im == 0}. False otherwise. + */ + public boolean equals(float b) { + return this.re == b && this.im == 0; + } + + /** * Checks if an object is equal to this Field element. * @param b Object to compare to this Field element. @@ -502,18 +511,6 @@ public Complex64 conj() { } - /** - * Compute a raised to the power of b. This method wraps {@link Math#pow(double, double)} - * and returns a {@link Complex64}. - * @param a The base. - * @param b The exponent. - * @return a raised to the power of b. - */ - public static Complex64 pow(float a, float b) { - return new Complex64((float) Math.pow(a, b)); - } - - /** * Compute a raised to the power of b. * and returns a {@link Complex64}. diff --git a/src/main/java/org/flag4j/algebraic_structures/Field.java b/src/main/java/org/flag4j/algebraic_structures/Field.java index 8f069cd03..f7c51feb9 100644 --- a/src/main/java/org/flag4j/algebraic_structures/Field.java +++ b/src/main/java/org/flag4j/algebraic_structures/Field.java @@ -169,7 +169,7 @@ default T mult(double b) { */ default T div(double b) { throw new UnsupportedOperationException("Division with primitive doubles is not supported for this field: " - + getClass() + "."); + + getClass().getName() + "."); } diff --git a/src/main/java/org/flag4j/algebraic_structures/Ring.java b/src/main/java/org/flag4j/algebraic_structures/Ring.java index ad508b683..2c47cca1b 100644 --- a/src/main/java/org/flag4j/algebraic_structures/Ring.java +++ b/src/main/java/org/flag4j/algebraic_structures/Ring.java @@ -105,7 +105,7 @@ public interface Ring> extends Semiring { * @param b Second ring element in difference. * @return The difference of this ring element and {@code b}. */ - public T sub(T b); + T sub(T b); /** @@ -115,7 +115,7 @@ public interface Ring> extends Semiring { * * @return The additive inverse for this ring element. */ - public T addInv(); + T addInv(); /** @@ -133,8 +133,9 @@ default double abs() { * Computes the magnitude of this ring element. * @return The magnitude of this ring element. */ - default public double mag() { - throw new UnsupportedOperationException("Magnitude/absolute value is not defined for this ring: " + getClass() + "."); + default double mag() { + throw new UnsupportedOperationException("Magnitude/absolute value is not defined for this algebraic object: " + + getClass().getName() + "."); } @@ -143,7 +144,8 @@ default public double mag() { * @return The conjugation of this ring's element. * @implNote The default implementation of this method simply returns this rings element. */ - public default T conj() { - throw new UnsupportedOperationException("Magnitude/absolute value is not defined for this ring: " + getClass() + "."); + default T conj() { + throw new UnsupportedOperationException("Conjugation is not defined for this algebraic object: " + + getClass().getName() + "."); } } diff --git a/src/main/java/org/flag4j/arrays/SmartMatrix.java b/src/main/java/org/flag4j/arrays/SmartMatrix.java index d6d30b66b..16d3289ed 100644 --- a/src/main/java/org/flag4j/arrays/SmartMatrix.java +++ b/src/main/java/org/flag4j/arrays/SmartMatrix.java @@ -26,7 +26,6 @@ import org.flag4j.algebraic_structures.Complex128; import org.flag4j.algebraic_structures.Field; -import org.flag4j.algebraic_structures.RealFloat64; import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.backend.MatrixMixin; import org.flag4j.arrays.backend.smart_visitors.*; @@ -394,45 +393,4 @@ public int hashCode() { public String toString() { return "SmartMatrix type: " + this.matrix.getClass().getSimpleName() + "\n" + matrix.toString(); } - - - // TODO: TESTING - public static void main(String[] args) { - var realDense = new Matrix(new double[][]{ - {1, 2, 3}, - {4, 5, 6}, - {7, 8, 9} - }); - - var complexDense = new CMatrix(new Complex128[][]{ - {new Complex128(1.46, 12.6), new Complex128(-2, 1),new Complex128(0, 4)}, - {new Complex128(2), new Complex128(2.1, 3),new Complex128(-1, 3)}, - {new Complex128(1, 1), new Complex128(12, 5),new Complex128(0, -2)} - }); - - var fieldDense = new FieldMatrix<>(new Complex128[][]{ - {new Complex128(1.46, 12.6), new Complex128(-2, 1),new Complex128(0, 4)}, - {new Complex128(2), new Complex128(2.1, 3),new Complex128(-1, 3)}, - {new Complex128(1, 1), new Complex128(12, 5),new Complex128(0, -2)} - }); - - var fieldDense2 = new FieldMatrix<>(new RealFloat64[][]{ - {new RealFloat64(1.5566), new RealFloat64(-9.3), new RealFloat64(0)}, - {new RealFloat64(1), new RealFloat64(-2), new RealFloat64(400.1)}, - {new RealFloat64(2), new RealFloat64(0.84), new RealFloat64(8e9)} - }); - - SmartMatrix a = new SmartMatrix(realDense); - SmartMatrix b = new SmartMatrix(complexDense); - SmartMatrix c = new SmartMatrix(fieldDense); - SmartMatrix d = new SmartMatrix(fieldDense2); - - double trace = a.tr(Double.class); - - System.out.println("a:\n" + a + "\n"); - System.out.println("b:\n" + b + "\n"); - System.out.println("a + b:\n" + a.add(b) + "\n"); - System.out.println("c + c:\n" + d.add(d) + "\n"); - System.out.println(); - } } diff --git a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldMatrix.java b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldMatrix.java index 4b90027da..27b3c8e88 100644 --- a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldMatrix.java @@ -25,1087 +25,98 @@ package org.flag4j.arrays.backend.field_arrays; import org.flag4j.algebraic_structures.Field; -import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.SparseMatrixData; -import org.flag4j.arrays.SparseVectorData; -import org.flag4j.arrays.backend.AbstractTensor; import org.flag4j.arrays.backend.MatrixMixin; +import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingMatrix; import org.flag4j.arrays.sparse.CooMatrix; import org.flag4j.linalg.ops.common.field_ops.FieldOps; import org.flag4j.linalg.ops.common.ring_ops.RingOps; -import org.flag4j.linalg.ops.common.semiring_ops.CompareSemiring; -import org.flag4j.linalg.ops.sparse.SparseElementSearch; -import org.flag4j.linalg.ops.sparse.SparseUtils; -import org.flag4j.linalg.ops.sparse.coo.*; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldMatrixProperties; -import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingMatrixOps; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatMult; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatrixOps; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatrixProperties; -import org.flag4j.util.ArrayUtils; +import org.flag4j.linalg.ops.sparse.coo.CooConversions; import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.LinearAlgebraException; -import org.flag4j.util.exceptions.TensorShapeException; -import java.math.BigDecimal; -import java.math.RoundingMode; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.function.BinaryOperator; -import static org.flag4j.linalg.ops.sparse.SparseUtils.copyRanges; - - -/** - *

A sparse matrix stored in coordinate list (COO) format. The {@link #data} of this COO vector are - * elements of a {@link Field}. - * - *

The {@link #data non-zero data} and non-zero indices of a COO matrix are mutable but the {@link #shape} - * and total number of non-zero data is fixed. - * - *

Sparse matrices allow for the efficient storage of and ops on matrices that contain many zero values. - * - *

COO matrices are optimized for hyper-sparse matrices (i.e. matrices which contain almost all zeros relative to the size of the - * matrix). - * - *

COO Representation:

- * A sparse COO matrix is stored as: - *
    - *
  • The full {@link #shape shape} of the matrix.
  • - *
  • The non-zero {@link #data} of the matrix. All other data in the matrix are - * assumed to be zero. Zero values can also explicitly be stored in {@link #data}.
  • - *
  • The {@link #rowIndices row indices} of the non-zero values in the sparse matrix.
  • - *
  • The {@link #colIndices column indices} of the non-zero values in the sparse matrix.
  • - *
- * - *

Note: many ops assume that the data of the COO matrix are sorted lexicographically by the row and column indices. - * (i.e.) by row indices first then column indices. However, this is not explicitly verified but any ops implemented in this - * class will preserve the lexicographical sorting. - * - *

If indices need to be sorted, call {@link #sortIndices()}. - * - * @param Type of this sparse COO matrix. - * @param Type of dense matrix which is similar to {@code T}. - * @param Type of sparse COO vector which is similar to {@code T}. - * @param Type of the field element in this matrix. - */ -public abstract class AbstractCooFieldMatrix, - U extends AbstractDenseFieldMatrix, - V extends AbstractCooFieldVector, - W extends Field> - extends AbstractTensor - implements FieldTensorMixin, MatrixMixin { - - /** - * The zero element for the field that this tensor's elements belong to. - */ - private W zeroElement; - /** - * Row indices for non-zero value of this sparse COO matrix. - */ - public final int[] rowIndices; - /** - * column indices for non-zero value of this sparse COO matrix. - */ - public final int[] colIndices; - /** - * Number of non-zero data in this COO matrix. - */ - public final int nnz; - /** - * The number of rows in this matrix. - */ - public final int numRows; - /** - * The number of columns in this matrix. - */ - public final int numCols; - /** - * The sparsity of this matrix. - */ - public final double sparsity; - - - /** - * Creates a sparse coo matrix with the specified non-zero data, non-zero indices, and shape. - * - * @param shape Shape of this tensor. - * @param entries Non-zero data of this sparse matrix. - * @param rowIndices Non-zero row indices of this sparse matrix. - * @param colIndices Non-zero column indies of this sparse matrix. - */ - protected AbstractCooFieldMatrix(Shape shape, W[] entries, int[] rowIndices, int[] colIndices) { - super(shape, entries); - ValidateParameters.ensureRank(shape, 2); - ValidateParameters.ensureIndicesInBounds(shape.get(0), rowIndices); - ValidateParameters.ensureIndicesInBounds(shape.get(1), colIndices); - ValidateParameters.ensureArrayLengthsEq(entries.length, rowIndices.length, colIndices.length); - - this.rowIndices = rowIndices; - this.colIndices = colIndices; - nnz = entries.length; - numRows = shape.get(0); - numCols = shape.get(1); - sparsity = BigDecimal.valueOf(nnz).divide(new BigDecimal(shape.totalEntries()), RoundingMode.HALF_UP).doubleValue(); - - // Attempt to set the zero element for the field. - this.zeroElement = (entries.length > 0) ? entries[0].getZero() : null; - } - - - /** - * Constructs a sparse COO tensor of the same type as this tensor with the specified non-zero data and indices. - * @param shape Shape of the matrix. - * @param entries Non-zero data of the matrix. - * @param rowIndices Non-zero row indices of the matrix. - * @param colIndices Non-zero column indices of the matrix. - * @return A sparse COO tensor of the same type as this tensor with the specified non-zero data and indices. - */ - public abstract T makeLikeTensor(Shape shape, W[] entries, int[] rowIndices, int[] colIndices); - - - /** - * Constructs a COO matrix with the specified shape, non-zero data, and non-zero indices. - * @param shape Shape of the matrix. - * @param entries Non-zero values of the matrix. - * @param rowIndices Non-zero row indices of the matrix. - * @param colIndices Non-zero column indices of the matrix. - * @return A COO matrix with the specified shape, non-zero data, and non-zero indices. - */ - public abstract T makeLikeTensor(Shape shape, List entries, List rowIndices, List colIndices); - - - /** - * Constructs a sparse COO vector of a similar type to this COO matrix. - * @param shape Shape of the vector. Must be rank 1. - * @param entries Non-zero data of the COO vector. - * @param indices Non-zero indices of the COO vector. - * @return A sparse COO vector of a similar type to this COO matrix. - */ - public abstract V makeLikeVector(Shape shape, W[] entries, int[] indices); - - - /** - * Constructs a dense tensor with the specified {@code shape} and {@code data} which is a similar type to this sparse tensor. - * @param shape Shape of the dense tensor. - * @param entries Entries of the dense tensor. - * @return A dense tensor with the specified {@code shape} and {@code data} which is a similar type to this sparse tensor. - */ - public abstract U makeLikeDenseTensor(Shape shape, W[] entries); - - - /** - * Constructs a sparse CSR matrix of a similar type to this sparse COO matrix. - * @param shape Shape of the CSR matrix to construct. - * @param entries Non-zero data of the CSR matrix. - * @param rowPointers Non-zero row pointers of the CSR matrix. - * @param colIndices Non-zero column indices of the CSR matrix. - * @return A CSR matrix of a similar type to this sparse COO matrix. - */ - public abstract AbstractCsrFieldMatrix makeLikeCsrMatrix( - Shape shape, W[] entries, int[] rowPointers, int[] colIndices); - - - /** - * Gets the zero element for the field of this tensor. - * @return The zero element for the field of this tensor. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Field)} then {@code null} will be returned. - * - * @see #setZeroElement(Field) - */ - public W getZeroElement() { - return zeroElement; - } - - - /** - * Gets the sparsity of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are zero. - * @return The sparsity of this matrix as a decimal percentage. - * @see #density() - */ - public double sparsity() { - return sparsity; - } - - - /** - * Gets the density of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are non-zero. - * @return The density of this matrix as a decimal percentage. - * @see #sparsity - */ - public double density() { - return 1.0 - sparsity; - } - - - /** - * Gets the length of the data array which backs this matrix. - * - * @return The length of the data array which backs this matrix. - */ - @Override - public int dataLength() { - return data.length; - } - - - /** - * Sets the zero element for the field of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the field. - * - * @see #getZeroElement() - */ - public void setZeroElement(W zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } - } - - - /** - * Gets the element of this tensor at the specified index. - * - * @param index Indices of the element to get. - * - * @return The element of this tensor at the specified index. If there is a non-zero value with the specified index, that value - * will be returned. If there is no non-zero value at the specified index than the zero element will attempt to be - * returned (i.e. the additive identity of the field). However, if the zero element could not be determined during - * construction or if it was not set with {@link #setZeroElement(Field)} then - * {@code null} will be returned. - * - * @throws ArrayIndexOutOfBoundsException If any index are not within this tensor. - */ - @Override - public W get(int... index) { - ValidateParameters.validateTensorIndex(shape, index); - W value = CooGetSet.getCoo(data, rowIndices, colIndices, index[0], index[1]); - return (value == null) ? getZeroElement() : value; - } - - - /** - * Sets the element of this tensor at the specified indices. - * - * @param value New value to set the specified index of this tensor to. - * @param indices Indices of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code indices} is not within the bounds of this tensor. - */ - @Override - public T set(W value, int... indices) { - ValidateParameters.validateTensorIndex(shape, indices); - return set(value, indices[0], indices[1]); - } - - - /** - * Sets an index of this matrix to the specified value. - * - * @param value Value to set. - * @param row Row index to set. - * @param col Column index to set. - * - * @return A reference to this matrix. - */ - @Override - public T set(W value, int row, int col) { - // Find position of row index within the row indices if it exits. - int idx = SparseElementSearch.matrixBinarySearch(rowIndices, colIndices, row, col); - W[] destEntries; - int[] destRowIndices; - int[] destColIndices; - - if(idx < 0) { - idx = -idx - 1; - - // No non-zero element with these indices exists. Insert new value. - destEntries = makeEmptyDataArray(data.length + 1); - destRowIndices = new int[data.length + 1]; - destColIndices = new int[data.length + 1]; - - CooGetSet.cooInsertNewValue( - value, row, col, - data, rowIndices, colIndices, idx, - destEntries, destRowIndices, destColIndices); - } else { - // Value with these indices exists. Simply update value. - destEntries = Arrays.copyOf(data, data.length); - destEntries[idx] = value; - destRowIndices = rowIndices.clone(); - destColIndices = colIndices.clone(); - } - - return makeLikeTensor(shape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Flattens this matrix to a single row. - * - * @return The flattened matrix. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - return flatten(0); - } - - - /** - * Flattens a tensor along the specified axis. - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - int[] dims = {1, 1}; - dims[1-axis] = shape.totalEntriesIntValueExact(); - Shape flatShape = new Shape(dims); - - int[] destIndices = new int[data.length]; - - for(int i = 0; i < data.length; i++) - destIndices[i] = shape.getFlatIndex(rowIndices[i], colIndices[i]); - - return (axis == 0) - ? makeLikeTensor(flatShape, data.clone(), new int[data.length], destIndices) - : makeLikeTensor(flatShape, data.clone(), destIndices, new int[data.length]); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - ValidateParameters.ensureBroadcastable(shape, newShape); - int oldColCount = shape.get(1); - int newColCount = newShape.get(1); - - // Initialize new COO structures with the same size as the original. - int[] newRowIndices = new int[rowIndices.length]; - int[] newColIndices = new int[colIndices.length]; - - for (int i = 0; i < rowIndices.length; i++) { - int flatIndex = rowIndices[i]*oldColCount + colIndices[i]; - newRowIndices[i] = flatIndex / newColCount; - newColIndices[i] = flatIndex % newColCount; - } - - return makeLikeTensor(newShape, data.clone(), newRowIndices, newColIndices); - } - - - /** - * Computes the transpose of a tensor by exchanging the first and last axes of this tensor. - * - * @return The transpose of this tensor. - * - * @see #T(int, int) - * @see #T(int...) - */ - @Override - public T T() { - T transpose = makeLikeTensor(shape.swapAxes(0, 1), data.clone(), colIndices.clone(), rowIndices.clone()); - transpose.sortIndices(); // Ensure the indices are sorted correctly. - return transpose; - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - if(axis1 == axis2) return copy(); - else return T(); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - if(axes.length != 2) - throw new IllegalArgumentException("Expecting two axes in transpose but got " + axes.length + "."); - return T(axes[0], axes[1]); - } - - - /** - * Gets the number of rows in this matrix. - * - * @return The number of rows in this matrix. - */ - @Override - public int numRows() { - return numRows; - } - - - /** - * Gets the number of columns in this matrix. - * - * @return The number of columns in this matrix. - */ - @Override - public int numCols() { - return numCols; - } - - - /** - * Gets the element of this matrix at this specified {@code row} and {@code col}. - * - * @param row Row index of the item to get from this matrix. - * @param col Column index of the item to get from this matrix. - * - * @return The element of this matrix at the specified index. - */ - @Override - public W get(int row, int col) { - return CooGetSet.getCoo(data, rowIndices, colIndices, row, col); - } - - - /** - *

Computes the trace of this matrix. That is, the sum of elements along the principle diagonal of this matrix. - * - *

Same as {@link #trace()}. - * - * @return The trace of this matrix. - * - * @throws IllegalArgumentException If this matrix is not square. - */ - @Override - public W tr() { - W trace = getZeroElement(); - - for(int i = 0; i< data.length; i++) - if(rowIndices[i]==colIndices[i]) trace = trace.add(data[i]); // Then entry is on the diagonal. - - return trace; - } - - - /** - * Checks if this matrix is upper triangular. - * - * @return {@code true} is this matrix is upper triangular; {@code false} otherwise. - * - * @see #isTri() - * @see #isTriL() - * @see #isDiag() - */ - @Override - public boolean isTriU() { - for(int i = 0; i< data.length; i++) - if(rowIndices[i] > colIndices[i] && !data[i].isZero()) return false; // Then non-zero entry is not in upper triangle. - - return true; - } - - - /** - * Checks if this matrix is lower triangular. - * - * @return {@code true} is this matrix is lower triangular; {@code false} otherwise. - * - * @see #isTri() - * @see #isTriU() - * @see #isDiag() - */ - @Override - public boolean isTriL() { - for(int i = 0; i< data.length; i++) - if(rowIndices[i] < colIndices[i] && !data[i].isZero()) return false; // Then non-zero entry is not in lower triangle. - - return true; - } - - - /** - * Checks if this matrix is the identity matrix. That is, checks if this matrix is square and contains - * only ones along the principle diagonal and zeros everywhere else. - * - * @return {@code true} if this matrix is the identity matrix; {@code false} otherwise. - */ - @Override - public boolean isI() { - return CooSemiringMatrixProperties.isIdentity(shape, data, rowIndices, colIndices); - } - - - /** - * Computes the matrix multiplication between two matrices. - * - * @param b Second matrix in the matrix multiplication. - * - * @return The result of matrix multiplying this matrix with matrix {@code b}. - * - * @throws LinearAlgebraException If the number of columns in this matrix do not equal the number - * of rows in matrix {@code b}. - */ - @Override - public U mult(T b) { - ValidateParameters.ensureMatMultShapes(shape, b.shape); - W[] dest = makeEmptyDataArray(numRows*b.numCols); - CooSemiringMatMult.standard( - data, rowIndices, colIndices, shape, - b.data, b.rowIndices, b.colIndices, b.shape, dest); - - return makeLikeDenseTensor(new Shape(numRows, b.numCols), dest); - } - - // TODO: Add mult2Sparse methods for all COO matrices. - - /** - * Multiplies this matrix with the transpose of the {@code b} tensor as if by - * {@code this.mult(b.H())}. - * For large matrices, this method may - * be significantly faster than directly computing the Hermitian transpose followed by the multiplication as - * {@code this.mult(b.H())}. - * - * @param b The second matrix in the multiplication and the matrix to transpose. - * - * @return The result of multiplying this matrix with the Hermitian transpose of {@code b}. - */ - @Override - public U multTranspose(T b) { - // TODO: MAke sure all complex and field matrices use the hermitian transpose for this method. - ValidateParameters.ensureEquals(numCols, b.numCols); - return mult(b.H()); - } - - - /** - * Stacks matrices along columns.
- * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking this matrix on top of the matrix {@code b}. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of columns. - * @see #stack(MatrixMixin, int) - * @see #augment(T) - */ - @Override - public T stack(T b) { - ValidateParameters.ensureEquals(numCols, b.numCols); - - Shape destShape = new Shape(numRows+b.numRows, numCols); - W[] destEntries = makeEmptyDataArray(data.length + b.data.length); - int[] destRowIndices = new int[destEntries.length]; - int[] destColIndices = new int[destEntries.length]; - CooConcat.stack(data, rowIndices, colIndices, numRows, - b.data, b.rowIndices, b.colIndices, - destEntries, destRowIndices, destColIndices); - - return makeLikeTensor(destShape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Stacks matrices along rows. - * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking {@code b} to the right of this matrix. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of rows. - * @see #stack(T) - * @see #stack(MatrixMixin, int) - */ - @Override - public T augment(T b) { - ValidateParameters.ensureEquals(numRows, b.numRows); - - Shape destShape = new Shape(numRows, numCols + b.numCols); - W[] destEntries = makeEmptyDataArray(data.length + b.data.length); - int[] destRowIndices = new int[destEntries.length]; - int[] destColIndices = new int[destEntries.length]; - CooConcat.augment(data, rowIndices, colIndices, numCols, - b.data, b.rowIndices, b.colIndices, - destEntries, destRowIndices, destColIndices); - - return makeLikeTensor(destShape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Augments a vector to this matrix. - * - * @param b The vector to augment to this matrix. - * - * @return The result of augmenting {@code b} to this matrix. - */ - @Override - public T augment(V b) { - ValidateParameters.ensureEquals(numRows, b.size); - - Shape destShape = new Shape(numRows, numCols + 1); - W[] destEntries = makeEmptyDataArray(nnz + b.data.length); - int[] destRowIndices = new int[destEntries.length]; - int[] destColIndices = new int[destEntries.length]; - CooConcat.augmentVector( - data, rowIndices, colIndices, numCols, - b.data, b.indices, - destEntries, destRowIndices, destColIndices); - - return makeLikeTensor(destShape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Swaps specified rows in the matrix. This is done in place. - * - * @param rowIndex1 Index of the first row to swap. - * @param rowIndex2 Index of the second row to swap. - * - * @return A reference to this matrix. - * - * @throws ArrayIndexOutOfBoundsException If either index is outside the matrix bounds. - */ - @Override - public T swapRows(int rowIndex1, int rowIndex2) { - CooManipulations.swapRows(shape, data, rowIndices, colIndices, rowIndex1, rowIndex2); - return (T) this; - } - - - /** - * Swaps specified columns in the matrix. This is done in place. - * - * @param colIndex1 Index of the first column to swap. - * @param colIndex2 Index of the second column to swap. - * - * @return A reference to this matrix. - * - * @throws ArrayIndexOutOfBoundsException If either index is outside the matrix bounds. - */ - @Override - public T swapCols(int colIndex1, int colIndex2) { - CooManipulations.swapCols(shape, data, rowIndices, colIndices, colIndex1, colIndex2); - return (T) this; - } - - - /** - * Checks if a matrix is symmetric. That is, if the matrix is square and equal to its transpose. - * - * @return {@code true} if this matrix is symmetric; {@code false} otherwise. - */ - @Override - public boolean isSymmetric() { - return CooSemiringMatrixProperties.isSymmetric(shape, data, rowIndices, colIndices); - } - - - /** - * Checks if a matrix is Hermitian. That is, if the matrix is square and equal to its conjugate transpose. - * - * @return {@code true} if this matrix is Hermitian; {@code false} otherwise. - */ - @Override - public boolean isHermitian() { - return CooFieldMatrixProperties.isHermitian(shape, data, rowIndices, colIndices); - } - - - /** - * Checks if this matrix is orthogonal. That is, if the inverse of this matrix is equal to its transpose. - * - * @return {@code true} if this matrix it is orthogonal; {@code false} otherwise. - */ - @Override - public boolean isOrthogonal() { - if(isSquare()) return mult(T()).isI(); - else return false; - } - - - /** - * Gets a range of a row of this matrix. - * - * @param rowIdx The index of the row to get. - * @param start The staring column of the row range to get (inclusive). - * @param stop The ending column of the row range to get (exclusive). - * - * @return A vector containing the elements of the specified row over the range [start, stop). - * - * @throws IllegalArgumentException If {@code rowIdx < 0 || rowIdx >= this.numRows()} or {@code start < 0 || start >= numCols} or - * {@code stop < start || stop > numCols}. - */ - @Override - public V getRow(int rowIdx, int start, int stop) { - SparseVectorData data = CooGetSet.getRow(shape, - this.data, rowIndices, colIndices, rowIdx, start, stop); - return makeLikeVector(data.shape(), - data.data().toArray(makeEmptyDataArray(data.data().size())), - data.indicesToArray()); - } - - - /** - * Gets a range of a column of this matrix. - * - * @param colIdx The index of the column to get. - * @param start The staring row of the column range to get (inclusive). - * @param stop The ending row of the column range to get (exclusive). - * - * @return A vector containing the elements of the specified column over the range [start, stop). - * - * @throws IllegalArgumentException If {@code colIdx < 0 || colIdx >= this.numCols()} or {@code start < 0 || start >= numRows} or - * {@code stop < start || stop > numRows}. - */ - @Override - public V getCol(int colIdx, int start, int stop) { - SparseVectorData data = CooGetSet.getCol(shape, this.data, rowIndices, colIndices, colIdx, start, stop); - return makeLikeVector(data.shape(), - data.data().toArray(makeEmptyDataArray(data.data().size())), - data.indicesToArray()); - } - - - /** - * Gets the elements of this matrix along the specified diagonal. - * - * @param diagOffset The diagonal to get within this matrix. - *

    - *
  • If {@code diagOffset == 0}: Then the elements of the principle diagonal are collected.
  • - *
  • If {@code diagOffset < 0}: Then the elements of the sub-diagonal {@code diagOffset} below the principle diagonal - * are collected.
  • - *
  • If {@code diagOffset > 0}: Then the elements of the super-diagonal {@code diagOffset} above the principle diagonal - * are collected.
  • - *
- * - * @return The elements of the specified diagonal as a vector. - */ - @Override - public V getDiag(int diagOffset) { - SparseVectorData data = CooGetSet.getDiag(shape, this.data, rowIndices, colIndices, diagOffset); - return makeLikeVector(data.shape(), - data.data().toArray(makeEmptyDataArray(data.data().size())), - data.indicesToArray()); - } - - - /** - * Sets a specified row of this matrix to a vector. - * - * @param row Vector to replace specified row in this matrix. - * @param rowIdx Index of the row to set. - * - * @return If this matrix is dense, the row set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new row and returned. - */ - @Override - public T setRow(V row, int rowIdx) { - SparseMatrixData data = CooGetSet.setRow( - shape, this.data, rowIndices, colIndices, - rowIdx, row.size, row.data, row.indices); - return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); - } - - - /** - * Sets a column of this matrix at the given index to the specified vector. - * - * @param col Vector containing new column data. - * @param colIndex The index of the column which is to be set. - * - * @return A copy of this matrix with the specified column set to {@code col}. - * - * @throws IllegalArgumentException If the {@code col} vector has a different length than the number of rows of this matrix. - * @throws IndexOutOfBoundsException If {@code colIndex < 0 || colIndex >= this.numCols}. - */ - public T setCol(V col, int colIndex) { - SparseMatrixData data = CooGetSet.setCol( - shape, this.data, rowIndices, colIndices, - colIndex, col.size, col.data, col.indices); - CooDataSorter sorter = new CooDataSorter(data.data(), data.rowData(), data.colData()).sparseSort(); - return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); - } - - - /** - * Removes a specified row from this matrix. - * - * @param rowIndex Index of the row to remove from this matrix. - * - * @return A copy of this matrix with the specified row removed. - */ - @Override - public T removeRow(int rowIndex) { - Shape shape = new Shape(numRows-1, numCols); - - // Find the start and end index within the data array which have the given row index. - int[] startEnd = SparseElementSearch.matrixFindRowStartEnd(rowIndices, rowIndex); - int size = data.length - (startEnd[1]-startEnd[0]); - - // Initialize arrays. - W[] entries = makeEmptyDataArray(size); - int[] rowIndices = new int[size]; - int[] colIndices = new int[size]; - - copyRanges(this.data, this.rowIndices, this.colIndices, entries, rowIndices, colIndices, startEnd); - - return makeLikeTensor(shape, entries, rowIndices, colIndices); - } - - - /** - * Removes a specified set of rows from this matrix. - * - * @param rowIdxs The indices of the rows to remove from this matrix. Assumed to contain unique values. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeRows(int... rowIdxs) { - // TODO: This should be doable for a general COO matrix. Return SparseMatrixData object. - Shape shape = new Shape(numRows-rowIdxs.length, numCols); - int initSize = Math.max(0, nnz - (nnz / numRows) * rowIdxs.length); // Estimate the number of non-zeros in the result. - List entries = new ArrayList<>(initSize); - List rowIndices = new ArrayList<>(initSize); - List colIndices = new ArrayList<>(initSize); - - for(int i=0; i destEntries = new ArrayList<>(data.length); - List destRowIndices = new ArrayList<>(data.length); - List destColIndices = new ArrayList<>(data.length); - - for(int i = 0; i< data.length; i++) { - if(colIndices[i] != colIndex) { - // Then entry is not in the specified column, so remove it. - destEntries.add(data[i]); - destRowIndices.add(rowIndices[i]); - - if(colIndices[i] < colIndex) destColIndices.add(colIndices[i]); - else destColIndices.add(colIndices[i]-1); - } - } - - return makeLikeTensor(shape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Removes a specified set of columns from this matrix. - * - * @param colIdxs Indices of the columns to remove from this matrix. Assumed to contain unique values. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeCols(int... colIdxs) { - Shape shape = new Shape(numRows, numCols-1); - List destEntries = new ArrayList<>(data.length); - List destRowIndices = new ArrayList<>(data.length); - List destColIndices = new ArrayList<>(data.length); - - for(int i = 0; i< data.length; i++) { - int idx = Arrays.binarySearch(colIdxs, colIndices[i]); - - if(idx < 0) { - // Then entry is not in the specified column, so copy it with the appropriate column index shift. - destEntries.add(data[i]); - destRowIndices.add(rowIndices[i]); - destColIndices.add(colIndices[i] + (idx+1)); - } - } - - return makeLikeTensor(shape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Creates a copy of this matrix and sets a slice of the copy to the specified values. The rowStart and colStart parameters specify the upper - * left index location of the slice to set. - * - * @param values New values for the specified slice. - * @param rowStart Starting row index for the slice (inclusive). - * @param colStart Starting column index for the slice (inclusive). - * - * @return A copy of this matrix with the given slice set to the specified values. - * - * @throws IndexOutOfBoundsException If rowStart or colStart are not within the matrix. - * @throws IllegalArgumentException If the values slice, with upper left corner at the specified location, does not - * fit completely within this matrix. - */ - @Override - public T setSliceCopy(T values, int rowStart, int colStart) { - SparseMatrixData sliceData = CooGetSet.setSlice( - shape, data, rowIndices, colIndices, - values.shape, values.data, values.rowIndices, values.colIndices, - rowStart, colStart); - return makeLikeTensor(sliceData.shape(), sliceData.data(), sliceData.rowData(), sliceData.colData()); - } - - - /** - * Gets a specified slice of this matrix. - * - * @param rowStart Starting row index of slice (inclusive). - * @param rowEnd Ending row index of slice (exclusive). - * @param colStart Starting column index of slice (inclusive). - * @param colEnd Ending row index of slice (exclusive). - * - * @return The specified slice of this matrix. This is a completely new matrix and NOT a view into the matrix. - * - * @throws ArrayIndexOutOfBoundsException If any of the indices are out of bounds of this matrix. - * @throws IllegalArgumentException If {@code rowEnd} is not greater than {@code rowStart} or if {@code colEnd} is not greater than {@code colStart}. - */ - @Override - public T getSlice(int rowStart, int rowEnd, int colStart, int colEnd) { - SparseMatrixData sliceData = CooGetSet.getSlice( - shape, data, rowIndices, colIndices, - rowStart, rowEnd, colStart, colEnd); - return makeLikeTensor(sliceData.shape(), sliceData.data(), sliceData.rowData(), sliceData.colData()); - } - - - /** - * Extracts the upper-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for upper-triangular portion to extract: - *
    - *
  • If zero, then all data at and above the principle diagonal of this matrix are extracted.
  • - *
  • If positive, then all data at and above the equivalent super-diagonal are extracted.
  • - *
  • If negative, then all data at and above the equivalent sub-diagonal are extracted.
  • - *
- * - * @return The upper-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriU(int diagOffset) { - SparseMatrixData data = CooGetSet.getTriU(diagOffset, shape, this.data, rowIndices, colIndices); - return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); - } +/** + *

A sparse matrix stored in coordinate list (COO) format. The {@link #data} of this COO vector are + * elements of a {@link Field}. + * + *

The {@link #data non-zero data} and non-zero indices of a COO matrix are mutable but the {@link #shape} + * and total number of non-zero data is fixed. + * + *

Sparse matrices allow for the efficient storage of and ops on matrices that contain many zero values. + * + *

COO matrices are optimized for hyper-sparse matrices (i.e. matrices which contain almost all zeros relative to the size of the + * matrix). + * + *

COO Representation:

+ * A sparse COO matrix is stored as: + *
    + *
  • The full {@link #shape shape} of the matrix.
  • + *
  • The non-zero {@link #data} of the matrix. All other data in the matrix are + * assumed to be zero. Zero values can also explicitly be stored in {@link #data}.
  • + *
  • The {@link #rowIndices row indices} of the non-zero values in the sparse matrix.
  • + *
  • The {@link #colIndices column indices} of the non-zero values in the sparse matrix.
  • + *
+ * + *

Note: many ops assume that the data of the COO matrix are sorted lexicographically by the row and column indices. + * (i.e.) by row indices first then column indices. However, this is not explicitly verified but any ops implemented in this + * class will preserve the lexicographical sorting. + * + *

If indices need to be sorted, call {@link #sortIndices()}. + * + * @param Type of this sparse COO matrix. + * @param Type of dense matrix which is similar to {@code T}. + * @param Type of sparse COO vector which is similar to {@code T}. + * @param Type of the field element in this matrix. + */ +public abstract class AbstractCooFieldMatrix, + U extends AbstractDenseFieldMatrix, + V extends AbstractCooFieldVector, + W extends Field> + extends AbstractCooRingMatrix + implements FieldTensorMixin, MatrixMixin { /** - * Extracts the lower-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for lower-triangular portion to extract: - *

    - *
  • If zero, then all data at and above the principle diagonal of this matrix are extracted.
  • - *
  • If positive, then all data at and above the equivalent super-diagonal are extracted.
  • - *
  • If negative, then all data at and above the equivalent sub-diagonal are extracted.
  • - *
- * - * @return The lower-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. + * Creates a sparse coo matrix with the specified non-zero data, non-zero indices, and shape. * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). + * @param shape Shape of this tensor. + * @param entries Non-zero data of this sparse matrix. + * @param rowIndices Non-zero row indices of this sparse matrix. + * @param colIndices Non-zero column indies of this sparse matrix. */ - @Override - public T getTriL(int diagOffset) { - SparseMatrixData data = CooGetSet.getTriL(diagOffset, shape, this.data, rowIndices, colIndices); - return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); + protected AbstractCooFieldMatrix(Shape shape, W[] entries, int[] rowIndices, int[] colIndices) { + super(shape, entries, rowIndices, colIndices); } /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. + * Constructs a sparse CSR matrix of a similar type to this sparse COO matrix. + * @param shape Shape of the CSR matrix to construct. + * @param entries Non-zero data of the CSR matrix. + * @param rowPointers Non-zero row pointers of the CSR matrix. + * @param colIndices Non-zero column indices of the CSR matrix. + * @return A CSR matrix of a similar type to this sparse COO matrix. */ - @Override - public T copy() { - return makeLikeTensor(shape, data); - } + public abstract AbstractCsrFieldMatrix makeLikeCsrMatrix( + Shape shape, W[] entries, int[] rowPointers, int[] colIndices); /** - * Computes the element-wise difference between two tensors of the same shape. - * - * @param b Second tensor in the element-wise difference. + * Multiplies this matrix with the transpose of the {@code b} tensor as if by + * {@code this.mult(b.H())}. + * For large matrices, this method may + * be significantly faster than directly computing the Hermitian transpose followed by the multiplication as + * {@code this.mult(b.H())}. * - * @return The difference of this tensor with {@code b}. + * @param b The second matrix in the multiplication and the matrix to transpose. * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. + * @return The result of multiplying this matrix with the Hermitian transpose of {@code b}. */ @Override - public T sub(T b) { - SparseMatrixData data = CooRingMatrixOps.sub( - shape, this.data, rowIndices, colIndices, - b.shape, b.data, b.rowIndices, b.colIndices); - - return makeLikeTensor(data.shape(), - data.data().toArray(makeEmptyDataArray(data.data().size())), - data.rowIndicesToArray(), - data.colIndicesToArray()); + public U multTranspose(T b) { + // TODO: Ensure all complex and field matrices use the hermitian transpose for this method. + ValidateParameters.ensureEquals(numCols, b.numCols); + return mult(b.H()); } @@ -1175,146 +186,6 @@ public T H(int... axes) { } - /** - * Finds the minimum value in this tensor. If this tensor is complex, then this method finds the smallest value in magnitude. - * - * @return The minimum value (smallest in magnitude for a complex valued tensor) in this tensor. - */ - @Override - public W min() { - return CompareSemiring.min(data); - } - - - /** - * Finds the maximum value in this tensor. If this tensor is complex, then this method finds the largest value in magnitude. - * - * @return The maximum value (largest in magnitude for a complex valued tensor) in this tensor. - */ - @Override - public W max() { - return (W) CompareSemiring.max(data); - } - - - /** - * Finds the indices of the minimum value in this tensor. - * - * @return The indices of the minimum value in this tensor. If this value occurs multiple times, the indices of the first - * entry (in row-major ordering) are returned. - */ - @Override - public int[] argmin() { - int idx = CompareSemiring.argmin(data); - return new int[]{rowIndices[idx], colIndices[idx]}; - } - - - /** - * Finds the indices of the maximum value in this tensor. - * - * @return The indices of the maximum value in this tensor. If this value occurs multiple times, the indices of the first - * entry (in row-major ordering) are returned. - */ - @Override - public int[] argmax() { - int idx = CompareSemiring.argmax(data); - return new int[]{rowIndices[idx], colIndices[idx]}; - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - SparseMatrixData data = CooSemiringMatrixOps.add( - shape, this.data, rowIndices, colIndices, - b.shape, b.data, b.rowIndices, b.colIndices); - - return makeLikeTensor(data.shape(), - data.data().toArray(makeEmptyDataArray(data.data().size())), - data.rowIndicesToArray(), - data.colIndicesToArray()); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - SparseMatrixData data = CooSemiringMatrixOps.elemMult( - shape, this.data, rowIndices, colIndices, - b.shape, b.data, b.rowIndices, b.colIndices); - - return makeLikeTensor(data.shape(), - data.data().toArray(makeEmptyDataArray(data.data().size())), - data.rowIndicesToArray(), - data.colIndicesToArray()); - } - - - /** - *

Computes the generalized trace of this tensor along the specified axes. - * - *

The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - ValidateParameters.ensureNotEquals(axis1, axis2); - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - // TODO: Investigate if this cast (W[]) is safe for example for Complex128[]. - return makeLikeTensor(new Shape(1, 1), (W[]) new Field[]{tr()}, new int[]{0}, new int[]{0}); - } - - - /** - * Sorts the indices of this tensor in lexicographical order while maintaining the associated value for each index. - */ - public void sortIndices() { - CooDataSorter.wrap(data, rowIndices, colIndices).sparseSort().unwrap(data, rowIndices, colIndices); - } - - - /** - * Converts this sparse COO matrix to an equivalent dense matrix. - * @return A dense matrix equivalent to this sparse COO matrix. - */ - public U toDense() { - W[] dense = makeEmptyDataArray(shape.totalEntriesIntValueExact()); - Arrays.fill(dense, zeroElement); - - for(int i = 0; i toCsr() { } - /** - * Converts this matrix to an equivalent tensor. - * @return A tensor which is equivalent to this matrix. - */ - public abstract AbstractCooFieldTensor toTensor(); - - - /** - * Converts this matrix to an equivalent tensor with the specified shape. - * @param newShape New shape for the tensor. Can be any rank but must be broadcastable to {@link #shape this.shape}. - * @return A tensor equivalent to this matrix which has been reshaped to {@code newShape} - */ - public abstract AbstractCooFieldTensor toTensor(Shape newShape); - - - /** - * Converts this sparse CSR matrix to an equivalent vector. If this matrix is not a row or column vector it will be flattened - * before conversion. - * @return A vector equivalent to this CSR matrix. - */ - public V toVector() { - int[] destIndices = new int[data.length]; - for(int i = 0; i< data.length; i++) - destIndices[i] = rowIndices[i]*numCols + colIndices[i]; - - return makeLikeVector(new Shape(numRows*numCols), data.clone(), destIndices); - } - - /** *

Computes the element-wise quotient between two tensors. *

WARNING: This method is not supported for sparse tensors. If called on a sparse tensor, @@ -1427,41 +269,4 @@ public boolean isInfinite() { public boolean isNaN() { return FieldOps.isNaN(data); } - - - /** - * Coalesces this sparse COO matrix. An uncoalesced matrix is a sparse matrix with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by summing duplicated data. If another form of aggregation other - * than summing is desired, use {@link #coalesce(BinaryOperator)}. - * @return A new coalesced sparse COO matrix which is equivalent to this COO matrix. - * @see #coalesce(BinaryOperator) - */ - public T coalesce() { - SparseMatrixData mat = SparseUtils.coalesce(Semiring::add, shape, data, rowIndices, colIndices); - return makeLikeTensor(mat.shape(), mat.data(), mat.rowData(), mat.colData()); - } - - - /** - * Coalesces this sparse COO matrix. An uncoalesced matrix is a sparse matrix with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by aggregating duplicated data using {@code aggregator}. - * @param aggregator Custom aggregation function to combine multiple. - * @return A new coalesced sparse COO matrix which is equivalent to this COO matrix. - * @see #coalesce() - */ - public T coalesce(BinaryOperator aggregator) { - SparseMatrixData mat = SparseUtils.coalesce(aggregator, shape, data, rowIndices, colIndices); - return makeLikeTensor(mat.shape(), mat.data(), mat.rowData(), mat.colData()); - } - - - /** - * Drops any explicit zeros in this sparse COO matrix. - * @return A copy of this COO matrix with any explicitly stored zeros removed. - */ - public T dropZeros() { - SparseMatrixData mat = SparseUtils.dropZeros(shape, data, rowIndices, colIndices); - return makeLikeTensor(mat.shape(), mat.data(), mat.rowData(), mat.colData()); - } - } diff --git a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldTensor.java b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldTensor.java index 1684e6ab5..969e8710f 100644 --- a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldTensor.java +++ b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldTensor.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,29 +25,14 @@ package org.flag4j.arrays.backend.field_arrays; import org.flag4j.algebraic_structures.Field; -import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.SparseTensorData; -import org.flag4j.arrays.backend.AbstractTensor; +import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingTensor; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCooSemiringTensor; import org.flag4j.arrays.sparse.CooTensor; import org.flag4j.linalg.ops.common.field_ops.FieldOps; import org.flag4j.linalg.ops.common.ring_ops.RingOps; -import org.flag4j.linalg.ops.common.semiring_ops.CompareSemiring; -import org.flag4j.linalg.ops.sparse.SparseElementSearch; -import org.flag4j.linalg.ops.sparse.SparseUtils; -import org.flag4j.linalg.ops.sparse.coo.*; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldHermTranspose; -import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingTensorOps; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringTensorOps; +import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingHermTranspose; import org.flag4j.util.ArrayUtils; -import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.TensorShapeException; - -import java.math.BigDecimal; -import java.math.RoundingMode; -import java.util.Arrays; -import java.util.List; -import java.util.function.BinaryOperator; /** @@ -77,35 +62,15 @@ * * * @param Type of this sparse COO tensor. - * @param Type of dense tensor equivalent to {@code T}. This type parameter is required because some ops (e.g. - * {@link #tensorDot(AbstractCooFieldTensor, int[], int[])} between two sparse tensors results in a dense tensor. + * @param Type of dense tensor equivalent to {@code T}. This type parameter is required because some operations (e.g. + * {@link #tensorDot(AbstractCooSemiringTensor, int[], int[])} between two sparse tensors results in a dense tensor. * @param Type of the {@link Field} which the data of this tensor belong to. */ public abstract class AbstractCooFieldTensor, U extends AbstractDenseFieldTensor, V extends Field> - extends AbstractTensor + extends AbstractCooRingTensor implements FieldTensorMixin { - /** - * The zero element for the field that this tensor's elements belong to. - */ - private V zeroElement; - /** - *

The non-zero indices of this sparse tensor. - * - *

Has shape {@code (nnz, rank)} where {@code nnz} is the number of non-zero data in this sparse tensor. - */ - public final int[][] indices; - /** - * The number of non-zero data in this sparse tensor. - */ - public final int nnz; - /** - * Stores the sparsity of this matrix. - */ - public final double sparsity; - - /** * Creates a tensor with the specified data and shape. * @@ -114,451 +79,7 @@ public abstract class AbstractCooFieldTensor 0 && entries[0] != null) ? entries[0].getZero() : null; - } - - - /** - * Constructs a tensor of the same type as this tensor with the specified shape and non-zero data. - * @param shape Shape of the tensor to construct. - * @param entries Non-zero data of the tensor to construct. - * @param indices Indices of the non-zero data of the tensor. - * @return A tensor of the same type as this tensor with the specified shape and non-zero data. - */ - public abstract T makeLikeTensor(Shape shape, V[] entries, int[][] indices); - - - /** - * Constructs a tensor of the same type as this tensor with the specified shape and non-zero data. - * @param shape Shape of the tensor to construct. - * @param entries Non-zero data of the tensor to construct. - * @param indices Indices of the non-zero data of the tensor. - * @return A tensor of the same type as this tensor with the specified shape and non-zero data. - */ - public abstract T makeLikeTensor(Shape shape, List entries, List indices); - - - /** - * Constructs a dense tensor that is a similar type as this sparse COO tensor. - * @param shape Shape of the tensor to construct. - * @param entries The data of the dense tensor to construct. - * @return A dense tensor that is a similar type as this sparse COO tensor. - */ - public abstract U makeLikeDenseTensor(Shape shape, V[] entries); - - - /** - * Gets the zero element for the field of this tensor. - * @return The zero element for the field of this tensor. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Field)} then {@code null} will be returned. - */ - public V getZeroElement() { - return zeroElement; - } - - - /** - * Sets the zero element for the field of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the field. - */ - public void setZeroElement(V zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } - } - - /** - * Gets the sparsity of this tensor as a decimal percentage. - * That is, the percentage of data in this tensor that are zero. - * @return The sparsity of this tensor as a decimal percentage. - * @see #density() - */ - public double sparsity() { - return sparsity; - } - - - /** - * Gets the density of this tensor as a decimal percentage. - * That is, the percentage of data in this tensor that are non-zero. - * @return The density of this tensor as a decimal percentage. - * @see #sparsity() - */ - public double density() { - return 1.0 - sparsity; - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - SparseTensorData sum = CooSemiringTensorOps.add( - shape, data, indices, - b.shape, b.data, b.indices - ); - - return makeLikeTensor(sum.shape(), - sum.data().toArray(makeEmptyDataArray(sum.data().size())), - sum.indices().toArray(new int[sum.indices().size()][])); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - SparseTensorData prod = CooSemiringTensorOps.elemMult( - shape, data, indices, b.shape, b.data, b.indices); - return makeLikeTensor(prod.shape(), - prod.data().toArray(makeEmptyDataArray(prod.data().size())), - prod.indices().toArray(new int[prod.indices().size()][])); - } - - - /** - * Computes the tensor contraction of this tensor with a specified tensor over the specified set of axes. That is, - * computes the sum of products between the two tensors along the specified set of axes. - * - * @param src2 Tensor to contract with this tensor. - * @param aAxes Axes along which to compute products for this tensor. - * @param bAxes Axes along which to compute products for {@code src2} tensor. - * - * @return The tensor dot product over the specified axes. - * - * @throws IllegalArgumentException If the two tensors shapes do not match along the specified axes pairwise in - * {@code aAxes} and {@code bAxes}. - * @throws IllegalArgumentException If {@code aAxes} and {@code bAxes} do not match in length, or if any of the axes - * are out of bounds for the corresponding tensor. - */ - @Override - public U tensorDot(T src2, int[] aAxes, int[] bAxes) { - CooTensorDot problem = new CooTensorDot<>(shape, data, indices, - src2.shape, src2.data, src2.indices, - aAxes, bAxes); - V[] dest = makeEmptyDataArray(problem.getOutputSize()); - problem.compute(dest); - return makeLikeDenseTensor(problem.getOutputShape(), dest); - } - - - /** - *

Computes the generalized trace of this tensor along the specified axes. - * - *

The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - SparseTensorData tr = CooSemiringTensorOps.tensorTr( - shape, data, indices, axis1, axis2); - return makeLikeTensor(tr.shape(), - tr.data().toArray(makeEmptyDataArray(tr.data().size())), - tr.indices().toArray(new int[tr.indices().size()][])); - } - - - /** - * Computes the transpose of a tensor by exchanging the first and last axes of this tensor. - * - * @return The transpose of this tensor. - * - * @see #T(int, int) - * @see #T(int...) - */ - @Override - public T T() { - V[] destEntries = makeEmptyDataArray(nnz); - int[][] destIndices = new int[nnz][rank]; - CooTranspose.tensorTranspose(shape, data, indices,0, shape.getRank()-1, destEntries, destIndices); - return makeLikeTensor(shape.swapAxes(0, rank-1), destEntries, destIndices); - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - V[] destEntries = makeEmptyDataArray(nnz); - int[][] destIndices = new int[nnz][rank]; - CooTranspose.tensorTranspose(shape, data, indices, axis1, axis2, destEntries, destIndices); - return makeLikeTensor(shape.swapAxes(axis1, axis2), destEntries, destIndices); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - V[] destEntries = makeEmptyDataArray(nnz); - int[][] destIndices = new int[nnz][rank]; - CooTranspose.tensorTranspose(shape, data, indices, axes, destEntries, destIndices); - return makeLikeTensor(shape.permuteAxes(axes), destEntries, destIndices); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data.clone()); - } - - - /** - * Finds the minimum (non-zero) value in this tensor. If this tensor is complex, then this method finds the smallest value in - * magnitude. - * - * @return The minimum (non-zero) value in this tensor. - */ - @Override - public V min() { - return CompareSemiring.min(data); - } - - - /** - * Finds the maximum (non-zero) value in this tensor. - * - * @return The maximum (non-zero) value in this tensor. - */ - @Override - public V max() { - return CompareSemiring.max(data); - } - - - /** - * Finds the indices of the minimum (non-zero) value in this tensor. - * - * @return The indices of the minimum (non-zero) value in this tensor. - */ - @Override - public int[] argmin() { - return indices[CompareSemiring.argmin(data)]; - } - - - /** - * Finds the indices of the maximum (non-zero) value in this tensor. - * - * @return The indices of the maximum (non-zero) value in this tensor. - */ - @Override - public int[] argmax() { - return indices[CompareSemiring.argmin(data)]; - } - - - /** - * Gets the element of this tensor at the specified target. - * - * @param target Index of the element to get. - * - * @return The element of this tensor at the specified index. If there is a non-zero value with the specified index, that value - * will be returned. If there is no non-zero value at the specified index than the zero element will attempt to be - * returned (i.e. the additive identity of the field). However, if the zero element could not be determined during - * construction or if it was not set with {@link #setZeroElement(Field)} then - * {@code null} will be returned. - * - * @throws ArrayIndexOutOfBoundsException If any target are not within this tensor. - */ - @Override - public V get(int... target) { - ValidateParameters.validateTensorIndex(shape, target); - V value = CooGetSet.getCoo(data, indices, target); - return (value == null) ? getZeroElement() : value; - } - - - /** - * Sets the element of this tensor at the specified target. - * - * @param value New value to set the specified index of this tensor to. - * @param target Index of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code target} is not within the bounds of this tensor. - */ - @Override - public T set(V value, int... target) { - ValidateParameters.validateTensorIndex(shape, target); - int idx = SparseElementSearch.binarySearchCoo(indices, target); - - V[] destEntries; - int[][] destIndices; - - if (idx >= 0) { - // Target index found. - destEntries = data.clone(); - destIndices = ArrayUtils.deepCopy(indices, null); - destEntries[idx] = value; - destIndices[idx] = target; - } else { - // Target not found, insert new value and index. - destEntries = makeEmptyDataArray(nnz + 1); - destIndices = new int[nnz + 1][rank]; - int insertionPoint = - (idx + 1); - CooGetSet.cooInsertNewValue(value, target, data, indices, insertionPoint, destEntries, destIndices); - } - - return makeLikeTensor(shape, destEntries, destIndices); - } - - - /** - * Flattens tensor to single dimension while preserving order of data. - * - * @return The flattened tensor. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - return makeLikeTensor( - shape.flatten(), - data.clone(), - SparseUtils.cooFlattenIndices(shape, indices)); - } - - - /** - * Flattens a tensor along the specified axis. Unlike {@link #flatten()} - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - int[] destShape = new int[indices[0].length]; - Arrays.fill(destShape, 1); - destShape[axis] = shape.totalEntries().intValueExact(); - - return makeLikeTensor( - new Shape(destShape), - data.clone(), - SparseUtils.cooFlattenIndices(shape, indices, axis)); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - return makeLikeTensor(newShape, data.clone(), SparseUtils.cooReshape(shape, newShape, indices)); - } - - - /** - * Sorts the indices of this tensor in lexicographical order while maintaining the associated value for each index. - */ - public void sortIndices() { - CooDataSorter.wrap(data, indices).sparseSort().unwrap(data, indices); - } - - - /** - * Converts this COO tensor to an equivalent dense tensor. - * @return A dense tensor which is equivalent to this COO tensor. - * @throws ArithmeticException If the number of data in the dense tensor exceeds 2,147,483,647. - */ - public U toDense() { - V[] denseEntries = makeEmptyDataArray(shape.totalEntriesIntValueExact()); - CooConversions.toDense(shape, data, indices, denseEntries); - return makeLikeDenseTensor(shape, denseEntries); - } - - - /** - * Computes the element-wise difference between two tensors of the same shape. - * - * @param b Second tensor in the element-wise difference. - * - * @return The difference of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T sub(T b) { - SparseTensorData diff = CooRingTensorOps.sub( - shape, data, indices, - b.shape, b.data, b.indices); - - return makeLikeTensor(diff.shape(), - diff.data().toArray(makeEmptyDataArray(diff.data().size())), - diff.indicesToArray()); + super(shape, entries, indices); } @@ -591,7 +112,7 @@ public CooTensor abs() { public T H(int axis1, int axis2) { V[] destData = makeEmptyDataArray(data.length); int[][] destIndices = new int[nnz][rank]; - CooFieldHermTranspose.tensorHermTranspose(shape, data, indices, axis1, axis2, destData, destIndices); + CooRingHermTranspose.tensorHermTranspose(shape, data, indices, axis1, axis2, destData, destIndices); return makeLikeTensor(shape.swapAxes(axis1, axis2), destData, destIndices); } @@ -614,7 +135,7 @@ public T H(int axis1, int axis2) { public T H(int... axes) { V[] destData = makeEmptyDataArray(data.length); int[][] destIndices = new int[nnz][rank]; - CooFieldHermTranspose.tensorHermTranspose(shape, data, indices, axes, destData, destIndices); + CooRingHermTranspose.tensorHermTranspose(shape, data, indices, axes, destData, destIndices); return makeLikeTensor(shape.permuteAxes(axes), destData, destIndices); } @@ -689,40 +210,4 @@ public boolean isInfinite() { public boolean isNaN() { return FieldOps.isInfinite(data); } - - - /** - * Coalesces this sparse COO tensor. An uncoalesced tensor is a sparse tensor with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by summing duplicated data. If another form of aggregation other - * than summing is desired, use {@link #coalesce(BinaryOperator)}. - * @return A new coalesced sparse COO tensor which is equivalent to this COO tensor. - * @see #coalesce(BinaryOperator) - */ - public T coalesce() { - SparseTensorData tensor = SparseUtils.coalesce(Semiring::add, shape, data, indices); - return makeLikeTensor(tensor.shape(), tensor.data(), tensor.indices()); - } - - - /** - * Coalesces this sparse COO tensor. An uncoalesced tensor is a sparse tensor with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by aggregating duplicated data using {@code aggregator}. - * @param aggregator Custom aggregation function to combine multiple. - * @return A new coalesced sparse COO tensor which is equivalent to this COO tensor. - * @see #coalesce() - */ - public T coalesce(BinaryOperator aggregator) { - SparseTensorData tensor = SparseUtils.coalesce(aggregator, shape, data, indices); - return makeLikeTensor(tensor.shape(), tensor.data(), tensor.indices()); - } - - - /** - * Drops any explicit zeros in this sparse COO tensor. - * @return A copy of this COO tensor with any explicitly stored zeros removed. - */ - public T dropZeros() { - SparseTensorData tensor = SparseUtils.dropZeros(shape, data, indices); - return makeLikeTensor(tensor.shape(), tensor.data(), tensor.indices()); - } } diff --git a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldVector.java b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldVector.java index 03a749d10..c91fe4128 100644 --- a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldVector.java +++ b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCooFieldVector.java @@ -25,31 +25,13 @@ package org.flag4j.arrays.backend.field_arrays; import org.flag4j.algebraic_structures.Field; -import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.SparseVectorData; -import org.flag4j.arrays.backend.AbstractTensor; import org.flag4j.arrays.backend.VectorMixin; +import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingVector; import org.flag4j.arrays.sparse.CooVector; import org.flag4j.linalg.ops.common.field_ops.FieldOps; import org.flag4j.linalg.ops.common.ring_ops.RingOps; -import org.flag4j.linalg.ops.sparse.SparseUtils; -import org.flag4j.linalg.ops.sparse.coo.CooConcat; -import org.flag4j.linalg.ops.sparse.coo.CooDataSorter; -import org.flag4j.linalg.ops.sparse.coo.CooGetSet; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldVectorOps; -import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingVectorOps; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringVectorOps; -import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.LinearAlgebraException; -import org.flag4j.util.exceptions.TensorShapeException; - -import java.math.BigDecimal; -import java.math.RoundingMode; -import java.util.Arrays; -import java.util.List; -import java.util.function.BinaryOperator; /** @@ -89,30 +71,9 @@ public abstract class AbstractCooFieldVector< V extends AbstractCooFieldMatrix, W extends AbstractDenseFieldMatrix, Y extends Field> - extends AbstractTensor + extends AbstractCooRingVector implements FieldTensorMixin, VectorMixin { - /** - * The zero element for the arrays that this tensor's elements belong to. - */ - private Y zeroElement; - /** - * Indices of the non-zero values of this sparse COO vector. - */ - public final int[] indices; - /** - * The number of non-zero data in this sparse COO vector. - */ - public final int nnz; - /** - * The total size of this sparse COO vector (including zero values). - */ - public final int size; - /** - * The sparsity of this matrix. - */ - public final double sparsity; - /** * Creates a tensor with the specified data and shape. @@ -122,600 +83,7 @@ public abstract class AbstractCooFieldVector< * If this tensor is sparse, this specifies only the non-zero data of the tensor. */ protected AbstractCooFieldVector(Shape shape, Y[] data, int[] indices) { - super(shape, data); - ValidateParameters.ensureRank(shape, 1); - ValidateParameters.ensureIndicesInBounds(shape.get(0), indices); - this.size = shape.totalEntriesIntValueExact(); - - if(data.length != indices.length) { - throw new IllegalArgumentException("data and indices arrays of a COO vector must have the same length but got " + - "lengths " + data.length + " and " + indices.length + "."); - } - if(data.length > size) { - throw new IllegalArgumentException("The number of data cannot be greater than the size of the vector but but got " + - "data.length=" + data.length + " and size=" + size + "."); - } - - this.indices = indices; - nnz = data.length; - sparsity = BigDecimal.valueOf(nnz).divide(new BigDecimal(shape.totalEntries()), RoundingMode.HALF_UP).doubleValue(); - - // Attempt to set the zero element for the arrays. - zeroElement = (data.length > 0 && data[0] != null) ? data[0].getZero() : null; - } - - - /** - * Constructs a sparse COO vector of the same type as this vector with the specified non-zero data and indices. - * @param shape Shape of the vector to construct. - * @param entries Non-zero data of the vector to construct. - * @param indices Non-zero row indices of the vector to construct. - * @return A sparse COO vector of the same type as this vector with the specified non-zero data and indices. - */ - public abstract T makeLikeTensor(Shape shape, Y[] entries, int[] indices); - - - /** - * Constructs a dense vector of a similar type as this vector with the specified shape and data. - * @param shape Shape of the vector to construct. - * @param entries Entries of the vector to construct. - * @return A dense vector of a similar type as this vector with the specified data. - */ - public abstract U makeLikeDenseTensor(Shape shape, Y... entries); - - - /** - * Constructs a dense matrix of a similar type as this vector with the specified shape and data. - * @param shape Shape of the matrix to construct. - * @param entries Entries of the matrix to construct. - * @return A dense matrix of a similar type as this vector with the specified data. - */ - public abstract W makeLikeDenseMatrix(Shape shape, Y... entries); - - - /** - * Constructs a COO vector with the specified shape, non-zero data, and non-zero indices. - * @param shape Shape of the vector. - * @param entries Non-zero values of the vector. - * @param indices Indices of the non-zero values in the vector. - * @return A COO vector of the same type as this vector with the specified shape, non-zero data, and non-zero indices. - */ - public abstract T makeLikeTensor(Shape shape, List entries, List indices); - - - /** - * Constructs a COO matrix with the specified shape, non-zero data, and row and column indices. - * @param shape Shape of the matrix to construct. - * @param entries Non-zero data of the matrix. - * @param rowIndices Row indices of the matrix. - * @param colIndices Column indices of the matrix. - * @return A COO matrix of similar type as this vector with the specified shape, non-zero data, and non-zero row/col indices. - */ - public abstract V makeLikeMatrix(Shape shape, Y[] entries, int[] rowIndices, int[] colIndices); - - - /** - * Gets the sparsity of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are zero. - * @return The sparsity of this matrix as a decimal percentage. - * @see #density() - */ - public double sparsity() { - return sparsity; - } - - - /** - * Gets the density of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are non-zero. - * @return The density of this matrix as a decimal percentage. - * @see #sparsity - */ - public double density() { - return 1.0 - sparsity; - } - - - /** - * Sorts the indices of this tensor in lexicographical order while maintaining the associated value for each index. - */ - public void sortIndices() { - CooDataSorter.wrap(data, indices).sparseSort().unwrap(data, indices); - } - - - /** - * Gets the element of this tensor at the specified indices. - * - * @param target Indices of the element to get. - * - * @return The element of this tensor at the specified indices. - * - * @throws IndexOutOfBoundsException If any {target} are not within this tensor. - */ - @Override - public Y get(int... target) { - ValidateParameters.ensureArrayLengthsEq(1, target.length); - return get(target[0]); - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - return copy(); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - if(axes.length != 1) - throw new IllegalArgumentException("Axes for tensor of rank 1 must be permutation of {1}."); - ValidateParameters.ensurePermutation(axes); - return copy(); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data); - } - - - /** - * Sets the element of this tensor at the specified indices. - * - * @param value New value to set the specified index of this tensor to. - * @param target Indices of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code indices} is not within the bounds of this tensor. - */ - @Override - public T set(Y value, int... target) { - ValidateParameters.validateTensorIndex(shape, target); - int idx = Arrays.binarySearch(indices, target[0]); - - Y[] destEntries; - int[] destIndices; - - if (idx >= 0) { - // Target index found. - destEntries = data.clone(); - destIndices = indices.clone(); - destEntries[idx] = value; - destIndices[idx] = target[0]; - } else { - // Target not found, insert new value and index. - destEntries = makeEmptyDataArray(nnz + 1); - destIndices = new int[nnz + 1]; - int insertionPoint = - (idx + 1); - CooGetSet.cooInsertNewValue(value, target[0], data, indices, insertionPoint, destEntries, destIndices); - } - - return makeLikeTensor(shape, destEntries, destIndices); - } - - - /** - * Flattens tensor to single dimension while preserving order of data. - * - * @return The flattened tensor. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - return copy(); - } - - - /** - * Flattens a tensor along the specified axis. Unlike {@link #flatten()} - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - return copy(); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - ValidateParameters.ensureRank(newShape, 1); - ValidateParameters.ensureBroadcastable(shape, newShape); - return copy(); - } - - - /** - * Joints specified vector with this vector. That is, creates a vector of length {@code this.length() + b.length()} containing - * first the elements of this vector followed by the elements of {@code b}. - * - * @param b Vector to join with this vector. - * - * @return A vector resulting from joining the specified vector with this vector. - */ - @Override - public T join(T b) { - Y[] destEntries = makeEmptyDataArray(this.data.length + b.data.length); - int[] destIndices = new int[this.indices.length + b.indices.length]; - CooConcat.join(data, indices, size, b.data, b.indices, destEntries, destIndices); - return makeLikeTensor(new Shape(shape.get(0) + b.shape.get(0)), destEntries, destIndices); - } - - - /** - *

Computes the inner product between two vectors. - * - *

Note: this method is distinct from {@link #dot(AbstractCooFieldVector)}. The inner product is equivalent to the dot product - * of this tensor with the conjugation of {@code b}. - * - * @param b Second vector in the inner product. - * - * @return The inner product between this vector and the vector {@code b}. - * - * @throws IllegalArgumentException If this vector and vector {@code b} do not have the same number of data. - * @see #dot(AbstractCooFieldVector) - */ - @Override - public Y inner(T b) { - return CooFieldVectorOps.inner(this, b); - } - - - /** - *

Computes the dot product between two vectors. - * - *

Note: this method is distinct from {@link #inner(AbstractCooFieldVector)}. - * The inner product is equivalent to the dot product of this tensor with the conjugation of {@code b}. - * - * @param b Second vector in the dot product. - * - * @return The dot product between this vector and the vector {@code b}. - * - * @throws IllegalArgumentException If this vector and vector {@code b} do not have the same number of data. - * @see #inner(AbstractCooFieldVector) - */ - @Override - public Y dot(T b) { - return (Y) CooSemiringVectorOps.dot(shape, data, indices, b.shape, b.data, b.indices); - } - - - /** - *

Gets the length of a vector. Same as {@link #size()}. - *

WARNING: This method will throw a {@link ArithmeticException} if the - * total number of data in this vector is greater than the maximum integer. In this case, the true size of this vector can - * still be found by calling {@code shape.totalEntries()} on this vector. - * - * @return The length, i.e. the number of data, in this vector. - * @throws ArithmeticException If the total number of data in this vector is greater than the maximum integer. - */ - @Override - public int length() { - return shape.totalEntriesIntValueExact(); - } - - - /** - * Repeats a vector {@code n} times along a certain axis to create a matrix. - * - * @param n Number of times to repeat vector. - * @param axis Axis along which to repeat vector: - *

    - *
  • If {@code axis=0}, then the vector will be treated as a row vector and stacked vertically {@code n} times.
  • - *
  • If {@code axis=1} then the vector will be treated as a column vector and stacked horizontally {@code n} times.
  • - *
- * - * @return A matrix whose rows/columns are this vector repeated. - */ - @Override - public V repeat(int n, int axis) { - Y[] tiledEntries = makeEmptyDataArray(n*data.length); - int[] tiledRows = new int[tiledEntries.length]; - int[] tiledCols = new int[tiledEntries.length]; - Shape tiledShape = CooConcat.repeat(data, indices, size, n, axis, tiledEntries, tiledRows, tiledCols); - return makeLikeMatrix(tiledShape, tiledEntries, tiledRows, tiledCols); - } - - - /** - *

- * Stacks two vectors along specified axis. - * - * - *

- * Stacking two vectors of length {@code n} along axis 0 stacks the vectors - * as if they were row vectors resulting in a {@code 2-by-n} matrix. - * - * - *

- * Stacking two vectors of length {@code n} along axis 1 stacks the vectors - * as if they were column vectors resulting in a {@code n-by-2} matrix. - * - * - * @param b Vector to stack with this vector. - * @param axis Axis along which to stack vectors. If {@code axis=0}, then vectors are stacked as if they are row - * vectors. If {@code axis=1}, then vectors are stacked as if they are column vectors. - * - * @return The result of stacking this vector and the vector {@code b}. - * - * @throws IllegalArgumentException If the number of data in this vector is different from the number of - * data in the vector {@code b}. - * @throws IllegalArgumentException If axis is not either 0 or 1. - */ - @Override - public V stack(T b, int axis) { - ValidateParameters.ensureEquals(size, b.size); - Y[] destEntries = makeEmptyDataArray(data.length + b.data.length); - int[][] destIndices = new int[2][indices.length + b.indices.length]; // Row and column indices. - - CooConcat.stack(data, indices, b.data, b.indices, destEntries, destIndices[0], destIndices[1]); - V mat = makeLikeMatrix(new Shape(2, size), destEntries, destIndices[0], destIndices[1]); - - return (axis == 0) ? mat : mat.T(); - } - - - /** - * Computes the outer product of two vectors. - * - * @param b Second vector in the outer product. - * - * @return The result of the vector outer product between this vector and {@code b}. - * - * @throws IllegalArgumentException If the two vectors do not have the same number of data. - */ - @Override - public W outer(T b) { - // TODO: This should almost certainly return a sparse tensor. It seems unlikely that if a vectors are worth storing as - // as COO vectors that the outer product would be dense. Further, this would almost never be useful as the dense matrix - // would take up so much more memory than the two sparse COO vectors (assuming they are 'very' sparse). - Shape destShape = new Shape(size, b.size); - Y[] dest = makeEmptyDataArray(size*b.size); - CooSemiringVectorOps.outerProduct(data, indices, size, b.data, b.indices, dest); - return makeLikeDenseMatrix(shape, dest); - } - - - /** - * Converts a vector to an equivalent matrix representing either a row or column vector. - * - * @param columVector Flag indicating whether to convert this vector to a matrix representing a row or column vector: - *

If {@code true}, the vector will be converted to a matrix representing a column vector. - *

If {@code false}, The vector will be converted to a matrix representing a row vector. - * - * @return A matrix equivalent to this vector. - */ - @Override - public V toMatrix(boolean columVector) { - if(columVector) { - // Convert to column vector - int[] rowIndices = indices.clone(); - int[] colIndices = new int[data.length]; - Shape matShape = new Shape(size, 1); - - return makeLikeMatrix(matShape, data.clone(), rowIndices, colIndices); - } else { - // Convert to row vector. - int[] rowIndices = new int[data.length]; - int[] colIndices = indices.clone(); - Shape matShape = new Shape(1, size); - - return makeLikeMatrix(matShape, data.clone(), rowIndices, colIndices); - } - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - SparseVectorData result = CooSemiringVectorOps.add( - shape, data, indices, b.shape, b.data, b.indices); - return makeLikeTensor(shape, - result.data().toArray(makeEmptyDataArray(result.data().size())), - ArrayUtils.fromIntegerList(result.indices())); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - SparseVectorData prod = CooSemiringVectorOps.elemMult( - shape, data, indices, - b.shape, b.data, b.indices); - return makeLikeTensor(shape, - prod.data().toArray(makeEmptyDataArray(prod.data().size())), - ArrayUtils.fromIntegerList(prod.indices())); - } - - - /** - * Computes the tensor contraction of this tensor with a specified tensor over the specified set of axes. That is, - * computes the sum of products between the two tensors along the specified set of axes. - * - * @param src2 Tensor to contract with this tensor. - * @param aAxes Axes along which to compute products for this tensor. - * @param bAxes Axes along which to compute products for {@code src2} tensor. - * - * @return The tensor dot product over the specified axes. - * - * @throws IllegalArgumentException If the two tensors shapes do not match along the specified axes pairwise in - * {@code aAxes} and {@code bAxes}. - * @throws IllegalArgumentException If {@code aAxes} and {@code bAxes} do not match in length, or if any of the axes - * are out of bounds for the corresponding tensor. - */ - @Override - public U tensorDot(T src2, int[] aAxes, int[] bAxes) { - if(aAxes.length != 1 || bAxes.length != 1) { - throw new LinearAlgebraException("Vector dot product requires exactly one dimension for each vector but got " - + aAxes.length + " and " + bAxes.length + "."); - } - if(aAxes[0] != 0 || bAxes[0] != 0) { - throw new LinearAlgebraException("Both axes must be 0 for vector dot product but got " - + aAxes[0] + " and " + bAxes[0] + "."); - } - - return makeLikeDenseTensor(shape, dot(src2)); - } - - - /** - *

Computes the generalized trace of this tensor along the specified axes. - * - *

The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - throw new LinearAlgebraException("Tensor trace cannot be computed for a rank 1 tensor " + - "(must be rank 2 or " + "greater)."); - } - - - /** - * Gets the zero element for the field of this vector. - * @return The zero element for the field of this vector. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Field)} then {@code null} will be returned. - */ - public Y getZeroElement() { - return (Y) zeroElement; - } - - - /** - * Sets the zero element for the field of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the ring. - */ - public void setZeroElement(Y zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } - } - - - /** - * Converts this sparse COO matrix to an equivalent dense matrix. - * @return A dense matrix equivalent to this sparse COO matrix. - */ - public U toDense() { - Y[] entries = makeEmptyDataArray(shape.totalEntriesIntValueExact()); - Arrays.fill(entries, zeroElement); - - for(int i = 0; i< nnz; i++) - entries[indices[i]] = data[i]; - - return makeLikeDenseTensor(shape, entries); - } - - - /** - * Converts this matrix to an equivalent rank 1 tensor. - * @return A tensor which is equivalent to this matrix. - */ - public abstract AbstractTensor toTensor(); - - - /** - * Converts this vector to an equivalent tensor with the specified shape. - * @param newShape New shape for the tensor. Can be any rank but must be broadcastable to {@link #shape this.shape}. - * @return A tensor equivalent to this matrix which has been reshaped to {@code newShape} - */ - public abstract AbstractTensor toTensor(Shape newShape); - - - /** - * Computes the element-wise difference between two tensors of the same shape. - * - * @param b Second tensor in the element-wise difference. - * - * @return The difference of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T sub(T b) { - SparseVectorData result = CooRingVectorOps.sub( - shape, data, indices, b.shape, b.data, b.indices); - return makeLikeTensor(shape, - result.data().toArray(makeEmptyDataArray(result.data().size())), - result.indicesToArray()); + super(shape, data, indices); } @@ -868,55 +236,4 @@ public Y mag() { return mag.sqrt(); } - - - /** - * Gets the element of this vector at the specified index. - * - * @param idx Index of the element to get within this vector. - * - * @return The element of this vector at index {@code idx}. - */ - @Override - public Y get(int idx) { - ValidateParameters.validateTensorIndex(shape, idx); - Y value = CooGetSet.getCoo(data, indices, idx); - return (value == null) ? getZeroElement() : value; - } - - - /** - * Coalesces this sparse COO vector. An uncoalesced vector is a sparse vector with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by summing duplicated data. If another form of aggregation other - * than summing is desired, use {@link #coalesce(BinaryOperator)}. - * @return A new coalesced sparse COO vector which is equivalent to this COO vector. - * @see #coalesce(BinaryOperator) - */ - public T coalesce() { - SparseVectorData vec = SparseUtils.coalesce(Semiring::add, shape, data, indices); - return makeLikeTensor(vec.shape(), vec.data(), vec.indices()); - } - - - /** - * Coalesces this sparse COO vector. An uncoalesced vector is a sparse vector with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by aggregating duplicated data using {@code aggregator}. - * @param aggregator Custom aggregation function to combine multiple. - * @return A new coalesced sparse COO vector which is equivalent to this COO vector. - * @see #coalesce() - */ - public T coalesce(BinaryOperator aggregator) { - SparseVectorData vec = SparseUtils.coalesce(aggregator, shape, data, indices); - return makeLikeTensor(vec.shape(), vec.data(), vec.indices()); - } - - - /** - * Drops any explicit zeros in this sparse COO vector. - * @return A copy of this COO vector with any explicitly stored zeros removed. - */ - public T dropZeros() { - SparseVectorData vec = SparseUtils.dropZeros(shape, data, indices); - return makeLikeTensor(vec.shape(), vec.data(), vec.indices()); - } } diff --git a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCsrFieldMatrix.java b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCsrFieldMatrix.java index e1795dd24..f7f1598b3 100644 --- a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCsrFieldMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractCsrFieldMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,71 +27,22 @@ import org.flag4j.algebraic_structures.Field; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseMatrixData; -import org.flag4j.arrays.backend.AbstractTensor; import org.flag4j.arrays.backend.MatrixMixin; +import org.flag4j.arrays.backend.ring_arrays.AbstractCsrRingMatrix; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCsrSemiringMatrix; import org.flag4j.arrays.sparse.CsrMatrix; import org.flag4j.linalg.ops.common.field_ops.FieldOps; import org.flag4j.linalg.ops.common.ring_ops.RingOps; -import org.flag4j.linalg.ops.sparse.csr.CsrConversions; import org.flag4j.linalg.ops.sparse.csr.CsrOps; -import org.flag4j.linalg.ops.sparse.csr.CsrProperties; -import org.flag4j.linalg.ops.sparse.csr.field_ops.CsrFieldMatrixProperties; +import org.flag4j.linalg.ops.sparse.csr.ring_ops.CsrRingProperties; import org.flag4j.linalg.ops.sparse.csr.semiring_ops.SemiringCsrMatMult; -import org.flag4j.linalg.ops.sparse.csr.semiring_ops.SemiringCsrOps; -import org.flag4j.linalg.ops.sparse.csr.semiring_ops.SemiringCsrProperties; -import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.LinearAlgebraException; -import org.flag4j.util.exceptions.TensorShapeException; - -import java.math.BigDecimal; -import java.math.RoundingMode; -import java.util.Arrays; -import java.util.List; - -import static org.flag4j.linalg.ops.sparse.SparseUtils.sortCsrMatrix; public abstract class AbstractCsrFieldMatrix, U extends AbstractDenseFieldMatrix, V extends AbstractCooFieldVector, W extends Field> - extends AbstractTensor + extends AbstractCsrRingMatrix implements FieldTensorMixin, MatrixMixin { - /** - * The zero element for the field that this tensor's elements belong to. - */ - private W zeroElement; - /** - *

Pointers indicating starting index of each row within the {@link #colIndices} and {@link #data} arrays. - * Has length {@link #numRows numRows + 1}. - * - *

The range [{@code data[rowPointers[i]], data[rowPointers[i+1]]}) contains all {@link #data non-zero data} within - * row {@code i}. - * - *

Similarly, [{@code colData[rowPointers[i]], colData[rowPointers[i+1]]}) contains all {@link #colIndices column indices} - * for the data in row {@code i}. - * - */ - public final int[] rowPointers; - /** - * Column indices for non-zero values of this sparse CSR matrix. - */ - public final int[] colIndices; - /** - * Number of non-zero data in this CSR matrix. - */ - public final int nnz; - /** - * The number of rows in this matrix. - */ - public final int numRows; - /** - * The number of columns in this matrix. - */ - public final int numCols; - /** - * The sparsity of this matrix. - */ - private final double sparsity; /** @@ -106,498 +57,16 @@ public abstract class AbstractCsrFieldMatrix 0 && entries[0] != null) ? entries[0].getZero() : null; - } - - - /** - * Constructs a sparse CSR tensor of the same type as this tensor with the specified non-zero data and indices. - * @param shape Shape of the matrix. - * @param entries Non-zero data of the CSR matrix. - * @param rowPointers Row pointers for the non-zero values in the CSR matrix. - * @param colIndices Non-zero column indices of the CSR matrix. - * @return A sparse CSR tensor of the same type as this tensor with the specified non-zero data and indices. - */ - public abstract T makeLikeTensor(Shape shape, W[] entries, int[] rowPointers, int[] colIndices); - - - /** - * Constructs a CSR matrix with the specified shape, non-zero data, and non-zero indices. - * @param shape Shape of the matrix. - * @param entries Non-zero values of the CSR matrix. - * @param rowPointers Row pointers for the non-zero values in the CSR matrix. - * @param colIndices Non-zero column indices of the CSR matrix. - * @return A CSR matrix with the specified shape, non-zero data, and non-zero indices. - */ - public abstract T makeLikeTensor(Shape shape, List entries, List rowPointers, List colIndices); - - - /** - * Constructs a dense matrix which is of a similar type to this sparse CSR matrix. - * @param shape Shape of the dense matrix. - * @param entries Entries of the dense matrix. - * @return A dense matrix which is of a similar type to this sparse CSR matrix with the specified {@code shape} - * and {@code data}. - */ - public abstract U makeLikeDenseTensor(Shape shape, W[] entries); - - - /** - *

Constructs a sparse COO matrix of a similar type to this sparse CSR matrix. - *

Note: this method constructs a new COO matrix with the specified data and indices. It does not convert this matrix - * to a CSR matrix. To convert this matrix to a sparse COO matrix use {@link #toCoo()}. - * @param shape Shape of the COO matrix. - * @param entries Non-zero data of the COO matrix. - * @param rowIndices Non-zero row indices of the sparse COO matrix. - * @param colIndices Non-zero column indices of the Sparse COO matrix. - * @return A sparse COO matrix of a similar type to this sparse CSR matrix. - */ - public abstract AbstractCooFieldMatrix makeLikeCooMatrix( - Shape shape, W[] entries, int[] rowIndices, int[] colIndices); - - - /** - * Gets the sparsity of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are zero. - * @return The sparsity of this matrix as a decimal percentage. - * @see #density() - */ - public double sparsity() { - return sparsity; - } - - - /** - * Gets the density of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are non-zero. - * @return The density of this matrix as a decimal percentage. - * @see #sparsity - */ - public double density() { - return 1.0 - sparsity; - } - - - /** - * Gets the length of the data array which backs this matrix. - * - * @return The length of the data array which backs this matrix. - */ - @Override - public int dataLength() { - return data.length; - } - - - /** - * Gets the zero element for the field of this tensor. - * @return The zero element for the field of this tensor. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Field)} then {@code null} will be returned. - * - * @see #setZeroElement(Field) - */ - public W getZeroElement() { - return zeroElement; - } - - - /** - * Sets the zero element for the field of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the field. - * - * @see #getZeroElement() - */ - public void setZeroElement(W zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } - } - - - - /** - * Gets the element of this tensor at the specified indices. - * - * @param indices Indices of the element to get. - * - * @return The element of this tensor at the specified indices. - * - * @throws ArrayIndexOutOfBoundsException If any indices are not within this tensor. - */ - @Override - public W get(int... indices) { - ValidateParameters.validateTensorIndex(shape, indices); - int row = indices[0]; - int col = indices[1]; - return get(row, col); - } - - - /** - * Sets the element of this tensor at the specified indices. - * - * @param value New value to set the specified index of this tensor to. - * @param indices Indices of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code indices} is not within the bounds of this tensor. - */ - @Override - public T set(W value, int... indices) { - ValidateParameters.validateTensorIndex(shape, indices); - return set(value, indices[0], indices[1]); - } - - - /** - * Flattens tensor to single dimension while preserving order of data. - * - * @return The flattened tensor. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - int[] newRowPointers = new int[2]; - newRowPointers[1] = nnz; - return makeLikeTensor( - new Shape(1, shape.totalEntriesIntValueExact()), - data.clone(), - newRowPointers, - colIndices.clone()); - } - - - /** - * Flattens a tensor along the specified axis. Unlike {@link #flatten()} - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - int[] newRowPointers; - int[] newColIndices; - - if (axis == 0) { - // Flatten to a single row. - newRowPointers = new int[2]; - newRowPointers[1] = nnz; - newColIndices = new int[nnz]; - } else { - // Flatten to a single column. - int flatSize = shape.totalEntriesIntValueExact(); - newColIndices = new int[nnz]; // Set all column indices to 0. - newRowPointers = new int[flatSize + 1]; - } - - Shape newShape = CsrConversions.flatten(shape, data, rowPointers, colIndices, axis, newRowPointers, newColIndices); - - return makeLikeTensor( - new Shape(shape.totalEntriesIntValueExact(), 1), - data.clone(), - newRowPointers, - newColIndices); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - return (T) toCoo().reshape(newShape).toCsr(); - } - - - /** - * Computes the transpose of a tensor by exchanging the first and last axes of this tensor. - * - * @return The transpose of this tensor. - * - * @see #T(int, int) - * @see #T(int...) - */ - @Override - public T T() { - W[] dest = makeEmptyDataArray(data.length); - int[] destRowPointers = new int[numCols+1]; - int[] destColIndices = new int[data.length]; - CsrOps.transpose(data, rowPointers, colIndices, dest, destRowPointers, destColIndices); - - return makeLikeTensor(shape.swapAxes(0, 1), dest, destRowPointers, destColIndices); - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - SparseMatrixData destData = CsrOps.applyBinOpp( - shape, data, rowPointers, colIndices, - b.shape, b.data, b.rowPointers, b.colIndices, - Field::add, null); - - return makeLikeTensor(shape, destData.data(), destData.rowData(), destData.colData()); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - SparseMatrixData destData = CsrOps.applyBinOpp( - shape, data, rowPointers, colIndices, - b.shape, b.data, b.rowPointers, b.colIndices, - Field::mult, null); - - return makeLikeTensor(shape, destData.data(), destData.rowData(), destData.colData()); - } - - - /** - *

Computes the generalized trace of this tensor along the specified axes. - * - *

The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - // TODO: Needs to return a tensor and probably be abstract and implemented in concrete children classes. - ValidateParameters.ensureNotEquals(axis1, axis2); - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - - // TODO: Investigate the (W[]) cast for array of specific field implementation (e.g. complex128). - return (T) makeLikeTensor(new Shape(1, 1), (W[]) new Field[]{tr()}, new int[]{0}, new int[]{0}); - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - if(axis1 == axis2) return copy(); - return T(); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - if(axes.length != 2) { - throw new IllegalArgumentException("Cannot transpose axes " - + Arrays.toString(axes) + " for a tensor of rank " + rank); - } - - return T(axes[0], axes[1]); - } - - - /** - * Gets the number of rows in this matrix. - * - * @return The number of rows in this matrix. - */ - @Override - public int numRows() { - return numRows; - } - - - /** - * Gets the number of columns in this matrix. - * - * @return The number of columns in this matrix. - */ - @Override - public int numCols() { - return numCols; - } - - - /** - * Gets the element of this matrix at this specified {@code row} and {@code col}. - * - * @param row Row index of the item to get from this matrix. - * @param col Column index of the item to get from this matrix. - * - * @return The element of this matrix at the specified index. - */ - @Override - public W get(int row, int col) { - ValidateParameters.validateTensorIndex(shape, row, col); - int loc = Arrays.binarySearch(colIndices, rowPointers[row], rowPointers[row+1], col); - - if(loc >= 0) return data[loc]; - else return zeroElement; - } - - - /** - *

Computes the trace of this matrix. That is, the sum of elements along the principle diagonal of this matrix. - * - *

Same as {@link #trace()}. - * - * @return The trace of this matrix. - * - * @throws IllegalArgumentException If this matrix is not square. - */ - @Override - public W tr() { - ValidateParameters.ensureSquare(shape); - W tr = SemiringCsrOps.trace(data, rowPointers, colIndices); - return (tr == null) ? zeroElement : tr; - } - - - /** - * Checks if this matrix is upper triangular. - * - * @return {@code true} is this matrix is upper triangular; {@code false} otherwise. - * - * @see #isTri() - * @see #isTriL() - * @see #isDiag() - */ - @Override - public boolean isTriU() { - return SemiringCsrProperties.isTriU(shape, data, rowPointers, colIndices); - } - - - /** - * Checks if this matrix is lower triangular. - * - * @return {@code true} is this matrix is lower triangular; {@code false} otherwise. - * - * @see #isTri() - * @see #isTriU() - * @see #isDiag() - */ - @Override - public boolean isTriL() { - return SemiringCsrProperties.isTriL(shape, data, rowPointers, colIndices); + super(shape, entries, rowPointers, colIndices); } - - /** - * Checks if this matrix is the identity matrix. That is, checks if this matrix is square and contains - * only ones along the principle diagonal and zeros everywhere else. - * - * @return {@code true} if this matrix is the identity matrix; {@code false} otherwise. - */ - @Override - public boolean isI() { - return SemiringCsrProperties.isIdentity(shape, data, rowPointers, colIndices); - } - - - /** - * Computes the matrix multiplication between two matrices. - * - * @param b Second matrix in the matrix multiplication. - * - * @return The result of matrix multiplying this matrix with matrix {@code b}. - * - * @throws LinearAlgebraException If the number of columns in this matrix do not equal the number - * of rows in matrix {@code b}. - * @see #mult2Csr(AbstractCsrFieldMatrix) - */ - @Override - public U mult(T b) { - Shape destShape = new Shape(numRows, b.numCols); - W[] destArray = makeEmptyDataArray(numRows*b.numCols); - - SemiringCsrMatMult.standard( - shape, data, rowPointers, colIndices, b.shape, - b.data, b.rowPointers, b.colIndices, - destArray, zeroElement); - - return makeLikeDenseTensor(destShape, destArray); - } - - /** *

Computes the matrix multiplication between two sparse CSR matrices and stores the result in a sparse matrix. *

Warning: this method should be used with caution as sparse-sparse matrix multiplication may result in a dense matrix. - * In such a case, this method will likely be significantly slower than {@link #mult(AbstractCsrFieldMatrix)}. + * In such a case, this method will likely be significantly slower than {@link #mult(AbstractCsrSemiringMatrix)}. * @param b Second matrix in the matrix multiplication. * @return The result of matrix multiplying this matrix with matrix {@code b} as a sparse CSR matrix. - * @see #mult(AbstractCsrFieldMatrix) + * @see #mult(AbstractCsrSemiringMatrix) */ public T mult2Csr(T b) { SparseMatrixData data = SemiringCsrMatMult.standardToSparse( @@ -608,116 +77,6 @@ public T mult2Csr(T b) { } - /** - * Multiplies this matrix with the transpose of the {@code b} tensor as if by - * {@code this.mult(b.T())}. - * For large matrices, this method may - * be significantly faster than directly computing the transpose followed by the multiplication as - * {@code this.mult(b.T())}. - * - * @param b The second matrix in the multiplication and the matrix to transpose. - * - * @return The result of multiplying this matrix with the transpose of {@code b}. - */ - @Override - public U multTranspose(T b) { - ValidateParameters.ensureEquals(numCols, b.numCols); - return mult(b.T()); - } - - - /** - * Stacks matrices along columns.
- * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking this matrix on top of the matrix {@code b}. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of columns. - * @see #stack(MatrixMixin, int) - * @see #augment(T) - */ - @Override - public T stack(T b) { - return (T) toCoo().stack(b.toCoo()).toCsr(); - } - - - /** - * Stacks matrices along rows. - * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking {@code b} to the right of this matrix. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of rows. - * @see #stack(T) - * @see #stack(MatrixMixin, int) - */ - @Override - public T augment(T b) { - return (T) toCoo().augment(b.toCoo()).toCsr(); - } - - - /** - * Augments a vector to this matrix. - * - * @param b The vector to augment to this matrix. - * - * @return The result of augmenting {@code b} to this matrix. - */ - @Override - public T augment(V b) { - return (T) toCoo().augment(b).toCsr(); - } - - - /** - * Swaps specified rows in the matrix. This is done in place. - * - * @param rowIndex1 Index of the first row to swap. - * @param rowIndex2 Index of the second row to swap. - * - * @return A reference to this matrix. - * - * @throws ArrayIndexOutOfBoundsException If either index is outside the matrix bounds. - */ - @Override - public T swapRows(int rowIndex1, int rowIndex2) { - CsrOps.swapRows(data, rowPointers, colIndices, rowIndex1, rowIndex2); - return (T) this; - } - - - /** - * Swaps specified columns in the matrix. This is done in place. - * - * @param colIndex1 Index of the first column to swap. - * @param colIndex2 Index of the second column to swap. - * - * @return A reference to this matrix. - * - * @throws ArrayIndexOutOfBoundsException If either index is outside the matrix bounds. - */ - @Override - public T swapCols(int colIndex1, int colIndex2) { - CsrOps.swapCols(data, rowPointers, colIndices, colIndex1, colIndex2); - return (T) this; - } - - - /** - * Checks if a matrix is symmetric. That is, if the matrix is square and equal to its transpose. - * - * @return {@code true} if this matrix is symmetric; {@code false} otherwise. - */ - @Override - public boolean isSymmetric() { - return CsrProperties.isSymmetric(shape, data, rowPointers, colIndices); - } - - /** * Checks if a matrix is Hermitian. That is, if the matrix is square and equal to its conjugate transpose. * @@ -726,7 +85,7 @@ public boolean isSymmetric() { @Override public boolean isHermitian() { // For a field matrix, same as isSymmetric. - return CsrFieldMatrixProperties.isHermitian(this); + return CsrRingProperties.isHermitian(shape, data, rowPointers, colIndices); } @@ -735,260 +94,12 @@ public boolean isHermitian() { * * @return {@code true} if this matrix it is orthogonal; {@code false} otherwise. */ - @Override - public boolean isOrthogonal() { - if(isSquare()) return mult(T()).isI(); + public boolean isUnitary() { + if(isSquare()) return mult(H()).isI(); else return false; } - /** - * Sets a specified row of this matrix to a vector. - * - * @param row Vector to replace specified row in this matrix. - * @param rowIdx Index of the row to set. - * - * @return If this matrix is dense, the row set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new row and returned. - */ - @Override - public T setRow(V row, int rowIdx) { - return (T) toCoo().setRow(row, rowIdx).toCsr(); - } - - - /** - * Sets a specified column of this matrix to a vector. - * - * @param col Vector to replace specified column in this matrix. - * @param colIdx Index of the column to set. - * - * @return If this matrix is dense, the column set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new column and returned. - */ - @Override - public T setCol(V col, int colIdx) { - return (T) toCoo().setCol(col, colIdx).toCsr(); - } - - - /** - * Removes a specified row from this matrix. - * - * @param rowIndex Index of the row to remove from this matrix. - * - * @return A copy of this matrix with the specified row removed. - */ - @Override - public T removeRow(int rowIndex) { - return (T) toCoo().removeRow(rowIndex).toCsr(); - } - - - /** - * Removes a specified set of rows from this matrix. - * - * @param rowIndices The indices of the rows to remove from this matrix. Assumed to contain unique values. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeRows(int... rowIndices) { - return (T) toCoo().removeRows(rowIndices).toCsr(); - } - - - /** - * Removes a specified column from this matrix. - * - * @param colIndex Index of the column to remove from this matrix. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeCol(int colIndex) { - return (T) toCoo().removeCol(colIndex).toCsr(); - } - - - /** - * Removes a specified set of columns from this matrix. - * - * @param colIndices Indices of the columns to remove from this matrix. Assumed to contain unique values. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeCols(int... colIndices) { - return (T) toCoo().removeCols(colIndices).toCsr(); - } - - - /** - * Creates a copy of this matrix and sets a slice of the copy to the specified values. The rowStart and colStart parameters specify the upper - * left index location of the slice to set. - * - * @param values New values for the specified slice. - * @param rowStart Starting row index for the slice (inclusive). - * @param colStart Starting column index for the slice (inclusive). - * - * @return A copy of this matrix with the given slice set to the specified values. - * - * @throws IndexOutOfBoundsException If rowStart or colStart are not within the matrix. - * @throws IllegalArgumentException If the values slice, with upper left corner at the specified location, does not - * fit completely within this matrix. - */ - @Override - public T setSliceCopy(T values, int rowStart, int colStart) { - return (T) toCoo().setSliceCopy(values.toCoo(), rowStart, colStart).toCsr(); - } - - - /** - * Gets a specified slice of this matrix. - * - * @param rowStart Starting row index of slice (inclusive). - * @param rowEnd Ending row index of slice (exclusive). - * @param colStart Starting column index of slice (inclusive). - * @param colEnd Ending row index of slice (exclusive). - * - * @return The specified slice of this matrix. This is a completely new matrix and NOT a view into the matrix. - * - * @throws ArrayIndexOutOfBoundsException If any of the indices are out of bounds of this matrix. - * @throws IllegalArgumentException If {@code rowEnd} is not greater than {@code rowStart} or if {@code colEnd} is not greater than {@code colStart}. - */ - @Override - public T getSlice(int rowStart, int rowEnd, int colStart, int colEnd) { - SparseMatrixData sliceData = CsrOps.getSlice( - data, rowPointers, colIndices, - rowStart, rowEnd, colStart, colEnd); - return makeLikeTensor(sliceData.shape(), sliceData.data(), - sliceData.rowData(), sliceData.colData()); - } - - - /** - * Sets an index of this matrix to the specified value. - * - * @param value Value to set. - * @param row Row index to set. - * @param col Column index to set. - * - * @return A reference to this matrix. - */ - @Override - public T set(W value, int row, int col) { - // Ensure indices are in bounds. - ValidateParameters.validateTensorIndex(shape, row, col); - W[] newEntries; - int[] newRowPointers = rowPointers.clone(); - int[] newColIndices; - boolean found = false; // Flag indicating an element already exists in this matrix at the specified row and col. - int loc = -1; - - if(rowPointers[row] < rowPointers[row+1]) { - int start = rowPointers[row]; - int stop = rowPointers[row+1]; - - loc = Arrays.binarySearch(colIndices, start, stop, col); - found = loc >= 0; - } - - if(found) { - newEntries = data.clone(); - newEntries[loc] = value; - newRowPointers = rowPointers.clone(); - newColIndices = colIndices.clone(); - } else { - loc = -loc - 1; // Compute insertion index as specified by Arrays.binarySearch. - newEntries = makeEmptyDataArray(data.length + 1); - newColIndices = new int[data.length + 1]; - - CsrOps.insertNewValue( - data, rowPointers, colIndices, - newEntries, newRowPointers, newColIndices, - row, col, loc, value); - } - - return makeLikeTensor(shape, newEntries, newRowPointers, newColIndices); - } - - - /** - * Extracts the upper-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for upper-triangular portion to extract: - *

    - *
  • If zero, then all data at and above the principle diagonal of this matrix are extracted.
  • - *
  • If positive, then all data at and above the equivalent super-diagonal are extracted.
  • - *
  • If negative, then all data at and above the equivalent sub-diagonal are extracted.
  • - *
- * - * @return The upper-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriU(int diagOffset) { - return (T) toCoo().getTriU(diagOffset).toCsr(); - } - - - /** - * Extracts the lower-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for lower-triangular portion to extract: - *
    - *
  • If zero, then all data at and above the principle diagonal of this matrix are extracted.
  • - *
  • If positive, then all data at and above the equivalent super-diagonal are extracted.
  • - *
  • If negative, then all data at and above the equivalent sub-diagonal are extracted.
  • - *
- * - * @return The lower-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriL(int diagOffset) { - return (T) toCoo().getTriL(diagOffset).toCsr(); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data.clone()); - } - - - /** - * Computes the element-wise difference between two tensors of the same shape. - * - * @param b Second tensor in the element-wise difference. - * - * @return The difference of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T sub(T b) { - SparseMatrixData destData = CsrOps.applyBinOpp( - shape, data, rowPointers, colIndices, - b.shape, b.data, b.rowPointers, b.colIndices, - Field::add, Field::addInv); - - return makeLikeTensor(shape, destData.data(), destData.rowData(), destData.colData()); - } - - /** * Computes the element-wise absolute value of this tensor. * @@ -1056,63 +167,6 @@ public T H(int... axes) { } - /** - * Sorts the indices of this tensor in lexicographical order while maintaining the associated value for each index. - */ - public void sortIndices() { - sortCsrMatrix(data, rowPointers, colIndices); - } - - - /** - *

Converts this sparse CSR matrix to an equivalent dense matrix. - * - *

The zero data of this CSR matrix will be attempted to be filled with a zero value if it could be determined during - * construction of this sparse CSR matrix. If the zero value could not be determined the zero data will be filled with - * {@code null} (this only happens when {@code nnz==0}). To avoid this, the zero element of the field for this - * matrix can be set explicitly using {@link #setZeroElement(Field)}. - * - * @return A dense matrix which is equivalent to this sparse CSR matrix. - */ - public U toDense() { - W[] dest = makeEmptyDataArray(shape.totalEntriesIntValueExact()); - CsrConversions.toDense(shape, data, rowPointers, colIndices, dest, zeroElement); - return makeLikeDenseTensor(shape, dest); - } - - - /** - * Converts this sparse CSR matrix to an equivalent sparse COO matrix. - * @return A sparse COO matrix equivalent to this sparse CSR matrix. - */ - public abstract AbstractCooFieldMatrix toCoo(); - - - /** - * Converts this CSR matrix to an equivalent sparse COO tensor. - * @return An sparse COO tensor equivalent to this CSR matrix. - */ - public abstract AbstractCooFieldTensor toTensor(); - - - /** - * Converts this CSR matrix to an equivalent COO tensor with the specified shape. - * @param newShape New shape for the COO tensor. Can be any rank but must be broadcastable to {@link #shape this.shape}. - * @return A COO tensor equivalent to this CSR matrix which has been reshaped to {@code newShape} - */ - public abstract AbstractCooFieldTensor toTensor(Shape shape); - - - /** - * Converts this sparse CSR matrix to an equivalent vector. If this matrix is not a row or column vector it will be flattened - * before conversion. - * @return A vector equivalent to this CSR matrix. - */ - public V toVector() { - return (V) toCoo().toVector(); - } - - /** *

Computes the element-wise quotient between two tensors. *

WARNING: This method is not supported for sparse tensors. If called on a sparse tensor, diff --git a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldMatrix.java b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldMatrix.java index 1bafb8efd..9c697c842 100644 --- a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,32 +26,20 @@ import org.flag4j.algebraic_structures.Field; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.SparseMatrixData; import org.flag4j.arrays.backend.MatrixMixin; +import org.flag4j.arrays.backend.ring_arrays.AbstractDenseRingMatrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.linalg.ops.TransposeDispatcher; +import org.flag4j.linalg.ops.common.field_ops.FieldOps; import org.flag4j.linalg.ops.common.ring_ops.RingOps; -import org.flag4j.linalg.ops.dense.field_ops.DenseFieldProperties; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringConversions; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringMatMultDispatcher; -import org.flag4j.util.ArrayUtils; -import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.LinearAlgebraException; - -import java.util.Arrays; +import org.flag4j.linalg.ops.common.ring_ops.RingProperties; +import org.flag4j.linalg.ops.dense.field_ops.DenseFieldElemDiv; +// TODO: Javadoc. public abstract class AbstractDenseFieldMatrix, U extends AbstractDenseFieldVector, V extends Field> - extends AbstractDenseFieldTensor implements MatrixMixin { - - /** - * The number of rows in this matrix. - */ - public final int numRows; - /** - * The number of columns in this matrix. - */ - public final int numCols; + extends AbstractDenseRingMatrix + implements MatrixMixin, FieldTensorMixin { /** @@ -63,32 +51,9 @@ public abstract class AbstractDenseFieldMatrix makeLikeCooMatrix( - Shape shape, V[] entries, int[] rowIndices, int[] colIndices); - - /** * Constructs a sparse CSR matrix which is of a similar type as this dense matrix. * @param shape Shape of the CSR matrix. @@ -102,877 +67,12 @@ protected AbstractDenseFieldMatrix(Shape shape, V[] data) { /** - * Gets the length of the data array which backs this matrix. - * - * @return The length of the data array which backs this matrix. - */ - @Override - public int dataLength() { - return data.length; - } - - - /** - * Computes the transpose of a tensor by exchanging the first and last axes of this tensor. - * - * @return The transpose of this tensor. - * - * @see #T(int, int) - * @see #T(int...) - */ - @Override - public T T() { - V[] dest = makeEmptyDataArray(data.length); - TransposeDispatcher.dispatch(data, shape, dest); - return makeLikeTensor(shape.swapAxes(0, 1), dest); - } - - - /** - * Gets the number of rows in this matrix. - * - * @return The number of rows in this matrix. - */ - @Override - public int numRows() { - return numRows; - } - - - /** - * Gets the number of columns in this matrix. - * - * @return The number of columns in this matrix. - */ - @Override - public int numCols() { - return numCols; - } - - - /** - * Gets the element of this matrix at this specified {@code row} and {@code col}. - * - * @param row Row index of the item to get from this matrix. - * @param col Column index of the item to get from this matrix. + * Checks if this matrix is unitary. That is, if the inverse of this matrix is approximately equal to its transpose. * - * @return The element of this matrix at the specified index. + * @return {@code true} if this matrix it is unitary; {@code false} otherwise. */ - @Override - public V get(int row, int col) { - return data[row*numCols + col]; - } - - - /** - *

Computes the trace of this matrix. That is, the sum of elements along the principle diagonal of this matrix. - * - *

Same as {@link #trace()}. - * - * @return The trace of this matrix. - * - * @throws IllegalArgumentException If this matrix is not square. - */ - @Override - public V tr() { - ValidateParameters.ensureSquareMatrix(shape); - V sum = data[0]; - int colsOffset = this.numCols + 1; - - for(int i=1; imay, be noticeably faster than directly computing the transpose followed by the - * multiplication as {@code this.mult(b.T())}. - * - * @param b The second matrix in the multiplication and the matrix to transpose. - * - * @return The result of multiplying this matrix with the transpose of {@code b}. - */ - @Override - public T multTranspose(T b) { - V[] dest = makeEmptyDataArray(numRows*b.numRows); - DenseSemiringMatMultDispatcher.dispatchTranspose(data, shape, b.data, b.shape, dest); - return makeLikeTensor(new Shape(numRows, b.numRows), dest); - } - - - /** - * Stacks matrices along columns.
- * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking this matrix on top of the matrix {@code b}. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of columns. - * @see #stack(MatrixMixin, int) - * @see #augment(T) - */ - @Override - public T stack(T b) { - ValidateParameters.ensureArrayLengthsEq(this.numCols, b.numCols); - Shape stackedShape = new Shape(this.numRows + b.numRows, this.numCols); - V[] stackedEntries = makeEmptyDataArray(stackedShape.totalEntries().intValueExact()); - - System.arraycopy(this.data, 0, stackedEntries, 0, this.data.length); - System.arraycopy(b.data, 0, stackedEntries, this.data.length, b.data.length); - - return makeLikeTensor(stackedShape, stackedEntries); - } - - - /** - * Stacks matrices along rows. - * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking {@code b} to the right of this matrix. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of rows. - * @see #stack(T) - * @see #stack(MatrixMixin, int) - */ - @Override - public T augment(T b) { - ValidateParameters.ensureArrayLengthsEq(numRows, b.numRows); - - int augNumCols = numCols + b.numCols; - Shape augShape = new Shape(numRows, augNumCols); - V[] augEntries = makeEmptyDataArray(numRows*augNumCols); - - // Copy data from this matrix. - for(int i=0; iNOT a view into the matrix. - * - * @throws ArrayIndexOutOfBoundsException If any of the indices are out of bounds of this matrix. - * @throws IllegalArgumentException If {@code rowEnd} is not greater than {@code rowStart} or if {@code colEnd} is not greater than {@code colStart}. - */ - @Override - public T getSlice(int rowStart, int rowEnd, int colStart, int colEnd) { - ValidateParameters.ensureValidArrayIndices(numRows, rowStart, rowEnd); - ValidateParameters.ensureValidArrayIndices(numCols, colStart, colEnd); - - int sliceRows = rowEnd-rowStart; - int sliceCols = colEnd-colStart; - int destPos = 0; - V[] slice = makeEmptyDataArray(sliceRows*sliceCols); - - for(int i=rowStart; i - *

  • If zero, then all data at and above the principle diagonal of this matrix are extracted.
  • - *
  • If positive, then all data at and above the equivalent super-diagonal are extracted.
  • - *
  • If negative, then all data at and above the equivalent sub-diagonal are extracted.
  • - * - * - * @return The upper-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriU(int diagOffset) { - ValidateParameters.ensureInRange(diagOffset, -numRows+1, numCols-1, "diagOffset"); - V[] copyEntries = makeEmptyDataArray(data.length); - Arrays.fill(copyEntries, (data.length > 0) ? data[0].getZero() : null); - T result = makeLikeTensor(shape, copyEntries); - - // Extract the upper triangular portion - for(int i=0; i= i + diagOffset) - result.data[rowOffset + j] = data[rowOffset + j]; - } - } - - return result; - } - - - /** - * Extracts the lower-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for lower-triangular portion to extract: - *
      - *
    • If zero, then all data at and above the principle diagonal of this matrix are extracted.
    • - *
    • If positive, then all data at and above the equivalent super-diagonal are extracted.
    • - *
    • If negative, then all data at and above the equivalent sub-diagonal are extracted.
    • - *
    - * - * @return The lower-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriL(int diagOffset) { - ValidateParameters.ensureInRange(diagOffset, -numRows+1, numCols-1, "diagOffset"); - V[] copyEntries = makeEmptyDataArray(data.length); - Arrays.fill(copyEntries, (data.length > 0) ? data[0].getZero() : null); - T result = makeLikeTensor(shape, copyEntries); - - // Extract the lower triangular portion - for(int i=0; i - *
  • If {@code diagOffset == 0}: Then the elements of the principle diagonal are collected.
  • - *
  • If {@code diagOffset < 0}: Then the elements of the sub-diagonal {@code diagOffset} below the principle diagonal - * are collected.
  • - *
  • If {@code diagOffset > 0}: Then the elements of the super-diagonal {@code diagOffset} above the principle diagonal - * are collected.
  • - * - * - * @return The elements of the specified diagonal as a vector. - */ - @Override - public U getDiag(int diagOffset) { - ValidateParameters.ensureInRange(diagOffset, -(numRows-1), numCols-1, "diagOffset"); - - // Check for some quick returns. - if(numRows == 1 && diagOffset > 0) return makeLikeVector((V[]) new Field[]{data[diagOffset]}); - if(numCols == 1 && diagOffset < 0) return makeLikeVector((V[]) new Field[]{data[-diagOffset]}); - - // Compute the length of the diagonal. - int newSize = Math.min(numRows, numCols); - int idx = 0; - - if(diagOffset > 0) { - newSize = Math.min(newSize, numCols - diagOffset); - idx = diagOffset; - } - else if(diagOffset < 0) { - newSize = Math.min(newSize, numRows + diagOffset); - idx = -diagOffset*numCols; - } - - V[] diag = makeEmptyDataArray(newSize); - - for(int i=0; i - *
  • If {@code axis == 0} a matrix with the shape {@code (this.numRows*this.numCols, 1)} is returned.
  • - *
  • If {@code axis == 1} a matrix with the shape {@code (1, this.numRows*this.numCols)} is returned.
  • - * + * Computes the element-wise square root of this tensor. * - * @throws ArrayIndexOutOfBoundsException If the axis is negative or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() + * @return The element-wise square root of this tensor. */ @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - return (axis == 0) - ? makeLikeTensor(new Shape(data.length, 1), data.clone()) - : makeLikeTensor(new Shape(1, data.length), data.clone()); - } - - - /** - * Converts this matrix to an equivalent sparse COO matrix. - * @return A sparse COO matrix that is equivalent to this dense matrix. - * @see #toCoo(double) - */ - public AbstractCooFieldMatrix toCoo() { - return toCoo(0.01); - } - - - /** - * Converts this matrix to an equivalent sparse COO matrix. - * @param estimatedSparsity Estimated sparsity of the matrix. Must be between 0 and 1 inclusive. If this is an accurate estimation - * it may provide a slight speedup and can reduce unneeded memory consumption. If memory is a concern, it is better to - * over-estimate the sparsity. If speed is the concern it is better to under-estimate the sparsity. - * @return A sparse COO matrix that is equivalent to this dense matrix. - * @see #toCoo() - */ - public AbstractCooFieldMatrix toCoo(double estimatedSparsity) { - SparseMatrixData data = DenseSemiringConversions.toCoo(shape, this.data, 0.1); - V[] cooEntries = data.data().toArray(makeEmptyDataArray(data.data().size())); - int[] rowIndices = ArrayUtils.fromIntegerList(data.rowData()); - int[] colIndices = ArrayUtils.fromIntegerList(data.colData()); - - return makeLikeCooMatrix(data.shape(), cooEntries, rowIndices, colIndices); + public T sqrt() { + V[] dest = makeEmptyDataArray(data.length); + FieldOps.sqrt(data, dest); + return makeLikeTensor(shape, dest); } /** - * Converts this matrix to an equivalent sparse CSR matrix. - * @return A sparse CSR matrix that is equivalent to this dense matrix. - * @see #toCsr(double) + * Checks if this tensor only contains finite values. + * + * @return {@code true} if this tensor only contains finite values; {@code false} otherwise. + * + * @see #isInfinite() + * @see #isNaN() */ - public AbstractCsrFieldMatrix toCsr() { - return toCoo(0.01).toCsr(); + @Override + public boolean isFinite() { + return FieldOps.isFinite(data); } /** - * Converts this matrix to an equivalent sparse CSR matrix. - * @param estimatedSparsity Estimated sparsity of the matrix. Must be between 0 and 1 inclusive. If this is an accurate estimation - * it may provide a slight speedup and can reduce unneeded memory consumption. If memory is a concern, it is better to - * over-estimate the sparsity. If speed is the concern it is better to under-estimate the sparsity. - * @return A sparse CSR matrix that is equivalent to this dense matrix. - * @see #toCsr() + * Checks if this tensor contains at least one infinite value. + * + * @return {@code true} if this tensor contains at least one infinite value; {@code false} otherwise. + * + * @see #isFinite() + * @see #isNaN() */ - public AbstractCsrFieldMatrix toCsr(double estimatedSparsity) { - return toCoo(estimatedSparsity).toCsr(); + @Override + public boolean isInfinite() { + return FieldOps.isInfinite(data); } /** - * Converts this matrix to an equivalent vector. If this matrix is not a row or column vector it will first be flattened then - * converted to a vector. + * Checks if this tensor contains at least one NaN value. * - * @return A vector which contains the same data as this matrix. + * @return {@code true} if this tensor contains at least one NaN value; {@code false} otherwise. + * + * @see #isFinite() + * @see #isInfinite() */ @Override - public U toVector() { - return makeLikeVector(data.clone()); + public boolean isNaN() { + return FieldOps.isInfinite(data); } /** - * Converts this matrix to an equivalent tensor. - * @return A tensor with the same shape and data as this matrix. + * Checks if all data of this matrix are 'close' as defined below. Custom tolerances may be specified using + * {@link #allClose(AbstractDenseFieldMatrix, double, double)}. + * @param b Second tensor in the comparison. + * @return True if both tensors have the same shape and all data are 'close' element-wise, i.e. + * elements {@code x} and {@code y} at the same positions in the two tensors respectively and satisfy + * {@code |x-y| <= (1E-08 + 1E-05*|y|)}. Otherwise, returns false. + * @see #allClose(AbstractDenseFieldMatrix, double, double) (AbstractDenseFieldTensor, double, double) */ - public abstract AbstractDenseFieldTensor toTensor(); + public boolean allClose(T b) { + return sameShape(b) && RingProperties.allClose(data, b.data); + } /** - * Converts this matrix to an equivalent tensor with the specified {@code newShape}. - * @param newShape Shape of the tensor. Can be any rank but must be broadcastable to the shape of this matrix. - * @return A tensor with the specified {@code newShape} and the same data as this matrix. + * Checks if all data of this matrix are 'close' as defined below. + * @param b Second tensor in the comparison. + * @return True if both tensors have the same length and all data are 'close' element-wise, i.e. + * elements {@code x} and {@code y} at the same positions in the two tensors respectively and satisfy + * {@code |x-y| <= (absTol + relTol*|y|)}. Otherwise, returns false. + * @see #allClose(AbstractDenseFieldMatrix) */ - public abstract AbstractDenseFieldTensor toTensor(Shape newShape); + public boolean allClose(T b, double relTol, double absTol) { + return sameShape(b) && RingProperties.allClose(data, b.data, relTol, absTol); + } } diff --git a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldTensor.java b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldTensor.java index e9e3f95c5..cff9974f6 100644 --- a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldTensor.java +++ b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldTensor.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,23 +26,12 @@ import org.flag4j.algebraic_structures.Field; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.SparseTensorData; -import org.flag4j.arrays.backend.AbstractTensor; -import org.flag4j.arrays.backend.VectorMixin; +import org.flag4j.arrays.backend.ring_arrays.AbstractDenseRingTensor; import org.flag4j.linalg.ops.TransposeDispatcher; import org.flag4j.linalg.ops.common.field_ops.FieldOps; -import org.flag4j.linalg.ops.common.field_ops.FieldProperties; -import org.flag4j.linalg.ops.dense.DenseSemiringTensorDot; +import org.flag4j.linalg.ops.common.ring_ops.RingProperties; import org.flag4j.linalg.ops.dense.field_ops.DenseFieldElemDiv; -import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; -import org.flag4j.linalg.ops.dense.ring_ops.DenseRingTensorOps; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringConversions; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringElemMult; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringOps; import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.TensorShapeException; - -import java.util.Arrays; /** *

    The base class for all dense {@link Field} tensors. @@ -54,14 +43,9 @@ * @param The type of the {@link Field} which this tensor's data belong to. */ public abstract class AbstractDenseFieldTensor, V extends Field> - extends AbstractTensor + extends AbstractDenseRingTensor implements FieldTensorMixin { - /** - * The zero element for the field that this tensor's elements belong to. - */ - private V zeroElement; - /** * Creates a tensor with the specified data and shape. * @@ -76,232 +60,6 @@ protected AbstractDenseFieldTensor(Shape shape, V[] entries) { } - /** - * Sets the zero element for the field of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the field. - * - * @see #getZeroElement() - */ - public void setZeroElement(V zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } - } - - - /** - * Gets the zero element for the field of this tensor. - * @return The zero element for the field of this tensor. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Field)} then {@code null} will be returned. - * - * @see #setZeroElement(Field) - */ - public V getZeroElement() { - return zeroElement; - } - - - /** - * Constructs a sparse COO tensor which is of a similar type as this dense tensor. - * @param shape Shape of the COO tensor. - * @param entries Non-zero data of the COO tensor. - * @param rowIndices Non-zero row indices of the COO tensor. - * @param colIndices Non-zero column indices of the COO tensor. - * @return A sparse COO tensor which is of a similar type as this dense tensor. - */ - protected abstract AbstractTensor makeLikeCooTensor( - Shape shape, V[] entries, int[][] indices); - - - /** - * Gets the element of this tensor at the specified indices. - * - * @param indices Indices of the element to get. - * - * @return The element of this tensor at the specified indices. - * - * @throws ArrayIndexOutOfBoundsException If any indices are not within this tensor. - */ - @Override - public V get(int... indices) { - return data[shape.getFlatIndex(indices)]; - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - V[] dest = makeEmptyDataArray(data.length); - TransposeDispatcher.dispatchTensor(data, shape, axis1, axis2, dest); - return makeLikeTensor(shape.swapAxes(axis1, axis2), dest); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - V[] dest = makeEmptyDataArray(data.length); - TransposeDispatcher.dispatchTensor(data, shape, axes, dest); - return makeLikeTensor(shape.permuteAxes(axes), dest); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data.clone()); - } - - - /** - * Sets the element of this tensor at the specified indices. - * - * @param value New value to set the specified index of this tensor to. - * @param indices Indices of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code indices} is not within the bounds of this tensor. - */ - @Override - public T set(V value, int... indices) { - data[shape.getFlatIndex(indices)] = value; - return (T) this; - } - - - /** - * Flattens tensor to single dimension while preserving order of data. - * - * @return The flattened tensor. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - return makeLikeTensor(shape.flatten(), data.clone()); - } - - - /** - * Flattens a tensor along the specified axis. Unlike {@link #flatten()} - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - int[] dims = new int[this.getRank()]; - Arrays.fill(dims, 1); - dims[axis] = shape.totalEntries().intValueExact(); - Shape flatShape = new Shape(dims); - - return makeLikeTensor(flatShape, data.clone()); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - // No need to make explicit broadcastable check as the constructor will verify that the number of data in the shape matches - // the number of data in the array. - return makeLikeTensor(newShape, data.clone()); - } - - - /** - * Computes the element-wise difference between two tensors of the same shape. - * - * @param b Second tensor in the element-wise difference. - * - * @return The difference of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T sub(T b) { - V[] diff = makeEmptyDataArray(data.length); - DenseRingTensorOps.sub(shape, data, b.shape, b.data, diff); - return makeLikeTensor(shape, diff); - } - - - /** - * Computes the element-wise difference between two tensors of the same shape and stores the result in this tensor. - * - * @param b Second tensor in the element-wise difference. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - public void subEq(T b) { - DenseRingTensorOps.sub(shape, data, b.shape, b.data, data); - } - - - /** - * Computes the conjugate transpose of a tensor by conjugating and exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange and conjugate. - * @param axis2 Second axis to exchange and conjugate. - * - * @return The conjugate transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #H() - * @see #H(int...) - */ - @Override - public T H(int axis1, int axis2) { - V[] dest = makeEmptyDataArray(data.length); - TransposeDispatcher.dispatchTensorHermitian(shape, data, axis1, axis2, dest); - return makeLikeTensor(shape.swapAxes(axis1, axis2), dest); - } - - /** * Computes the conjugate transpose of this tensor. That is, conjugates and permutes the axes of this tensor so that it matches * the permutation specified by {@code axes}. @@ -324,98 +82,6 @@ public T H(int... axes) { } - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - V[] sum = makeEmptyDataArray(data.length); - DenseSemiringOps.add(data, shape, b.data, b.shape, sum); - return makeLikeTensor(shape, sum); - } - - - /** - * Computes the element-wise sum between two tensors of the same shape and stores the result in this tensor. - * - * @param b Second tensor in the element-wise sum. - */ - public void addEq(T b) { - DenseSemiringOps.add(data, shape, b.data, b.shape, data); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - V[] prod = makeEmptyDataArray(data.length); - DenseSemiringElemMult.dispatch(data, shape, b.data, b.shape, prod); - return makeLikeTensor(shape, prod); - } - - - /** - * Computes the tensor contraction of this tensor with a specified tensor over the specified set of axes. That is, - * computes the sum of products between the two tensors along the specified set of axes. - * - * @param src2 Tensor to contract with this tensor. - * @param aAxes Axes along which to compute products for this tensor. - * @param bAxes Axes along which to compute products for {@code src2} tensor. - * - * @return The tensor dot product over the specified axes. - * - * @throws IllegalArgumentException If the two tensors shapes do not match along the specified axes pairwise in - * {@code aAxes} and {@code bAxes}. - * @throws IllegalArgumentException If {@code aAxes} and {@code bAxes} do not match in length, or if any of the axes - * are out of bounds for the corresponding tensor. - */ - @Override - public T tensorDot(T src2, int[] aAxes, int[] bAxes) { - DenseSemiringTensorDot dot = new DenseSemiringTensorDot(shape, data, src2.shape, src2.data, aAxes, bAxes); - V[] dest = makeEmptyDataArray(dot.getOutputSize()); - dot.compute(dest); - return makeLikeTensor(dot.getOutputShape(), dest); - } - - - /** - *

    Computes the generalized trace of this tensor along the specified axes. - * - *

    The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - Shape destShape = DenseSemiringOps.getTrShape(shape, axis1, axis2); - V[] dest = makeEmptyDataArray(destShape.totalEntriesIntValueExact()); - return makeLikeTensor(destShape, dest); - } - - /** * Computes the element-wise quotient between two tensors. * @@ -486,51 +152,17 @@ public boolean isNaN() { } - /** - * Converts this tensor to an equivalent sparse COO tensor. - * @return A sparse COO tensor that is equivalent to this dense tensor. - * @see #toCoo(double) - */ - public AbstractTensor toCoo() { - return toCoo(0.9); - } - - - /** - * Converts this tensor to an equivalent sparse COO tensor. - * @param estimatedSparsity Estimated sparsity of the tensor. Must be between 0 and 1 inclusive. If this is an accurate estimation - * it may provide a slight speedup and can reduce unneeded memory consumption. If memory is a concern, it is better to - * over-estimate the sparsity. If speed is the concern it is better to under-estimate the sparsity. - * @return A sparse COO tensor that is equivalent to this dense tensor. - * @see #toCoo(double) - */ - public AbstractTensor toCoo(double estimatedSparsity) { - SparseTensorData data = DenseSemiringConversions.toCooTensor(shape, this.data, estimatedSparsity); - V[] cooEntries = data.data().toArray(makeEmptyDataArray(data.data().size())); - - // TODO: First check if this tensor is a vector then delegate to specialized toCooVector - // or toCooTensor methods. - if(this instanceof VectorMixin) { - return makeLikeCooTensor( - data.shape(), cooEntries, - RealDenseTranspose.standardIntMatrix(data.indicesToArray())); - } else { - return makeLikeCooTensor(data.shape(), cooEntries, data.indicesToArray()); - } - } - - /** * Checks if all data of this matrix are 'close' as defined below. Custom tolerances may be specified using * {@link #allClose(AbstractDenseFieldTensor, double, double)}. * @param b Second tensor in the comparison. * @return True if both tensors have the same shape and all data are 'close' element-wise, i.e. * elements {@code x} and {@code y} at the same positions in the two tensors respectively and satisfy - * {@code |x-y| <= (1E-05 + 1E-08*|y|)}. Otherwise, returns false. + * {@code |x-y| <= (1E-08 + 1E-05*|y|)}. Otherwise, returns false. * @see #allClose(AbstractDenseFieldTensor, double, double) */ public boolean allClose(T b) { - return sameShape(b) && FieldProperties.allClose(data, b.data); + return sameShape(b) && RingProperties.allClose(data, b.data); } @@ -543,6 +175,6 @@ public boolean allClose(T b) { * @see #allClose(AbstractDenseFieldTensor) */ public boolean allClose(T b, double relTol, double absTol) { - return sameShape(b) && FieldProperties.allClose(data, b.data, relTol, absTol); + return sameShape(b) && RingProperties.allClose(data, b.data, relTol, absTol); } } diff --git a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldVector.java b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldVector.java index b19f715ff..39fc9c503 100644 --- a/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldVector.java +++ b/src/main/java/org/flag4j/arrays/backend/field_arrays/AbstractDenseFieldVector.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,12 +27,14 @@ import org.flag4j.algebraic_structures.Field; import org.flag4j.arrays.Shape; import org.flag4j.arrays.backend.VectorMixin; +import org.flag4j.arrays.backend.ring_arrays.AbstractDenseRingVector; import org.flag4j.arrays.dense.Vector; import org.flag4j.linalg.VectorNorms; +import org.flag4j.linalg.ops.common.field_ops.FieldOps; import org.flag4j.linalg.ops.common.ring_ops.RingOps; -import org.flag4j.linalg.ops.dense.DenseConcat; +import org.flag4j.linalg.ops.common.ring_ops.RingProperties; +import org.flag4j.linalg.ops.dense.field_ops.DenseFieldElemDiv; import org.flag4j.linalg.ops.dense.field_ops.DenseFieldVectorOps; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringVectorOps; import org.flag4j.util.ValidateParameters; @@ -49,13 +51,8 @@ */ public abstract class AbstractDenseFieldVector, U extends AbstractDenseFieldMatrix, V extends Field> - extends AbstractDenseFieldTensor - implements VectorMixin { - - /** - * The size of this vector. This is the total number of data stored in this vector. - */ - public final int size; + extends AbstractDenseRingVector + implements VectorMixin, FieldTensorMixin { /** @@ -67,8 +64,6 @@ public abstract class AbstractDenseFieldVectorComputes the inner product between two vectors. * @@ -140,197 +103,179 @@ public V inner(T b) { /** - *

    Computes the dot product between two vectors. - * - *

    Note: this method is distinct from {@link #inner(AbstractDenseFieldVector)}. The inner product is equivalent to the dot product - * of this tensor with the conjugation of {@code b}. + * Computes the outer product of two vectors. * - * @param b Second vector in the dot product. + * @param b Second vector in the outer product. * - * @return The dot product between this vector and the vector {@code b}. + * @return The result of the vector outer product between this vector and {@code b}. * - * @throws IllegalArgumentException If this vector and vector {@code b} do not have the same number of data. - * @see #inner(AbstractDenseFieldVector) + * @throws IllegalArgumentException If the two vectors do not have the same number of data. */ @Override - public V dot(T b) { - return DenseSemiringVectorOps.dotProduct(data, b.data); + public U outer(T b) { + V[] dest = makeEmptyDataArray(size*b.size); + DenseFieldVectorOps.outerProduct(data, b.data, dest); + return makeLikeMatrix(new Shape(size, b.size), dest); } /** - * Gets the length of a vector. Same as {@link #size()}. + * Computes the element-wise absolute value of this tensor. * - * @return The length, i.e. the number of data, in this vector. + * @return The element-wise absolute value of this tensor. */ @Override - public int length() { - return size; + public Vector abs() { + double[] abs = new double[data.length]; + RingOps.abs(data, abs); + return new Vector(shape, abs); } /** - * Repeats a vector {@code n} times along a certain axis to create a matrix. - * - * @param n Number of times to repeat vector. Must be positive. - * @param axis Axis along which to repeat vector. Must be either 1 or 0. - *

      - *
    • If {@code axis=0}, then the vector will be treated as a row vector and stacked vertically {@code n} times.
    • - *
    • If {@code axis=1} then the vector will be treated as a column vector and stacked horizontally {@code n} times.
    • - *
    + * Normalizes this vector to a unit length vector. * - * @return A matrix whose rows/columns are this vector repeated. + * @return This vector normalized to a unit length. */ @Override - public U repeat(int n, int axis) { - V[] dest = makeEmptyDataArray(size*n); - DenseConcat.repeat(data, n, axis, dest); // axis is verified to be 1 or 0 here. - Shape shape = (axis==0) ? new Shape(n, size) : new Shape(size, n); - - return makeLikeMatrix(shape, dest); + public T normalize() { + V[] dest = makeEmptyDataArray(size); + FieldOps.div(data, mag(), dest); + return makeLikeTensor(shape, dest); } /** - *

    Stacks two vectors along specified axis. - * - *

    Stacking two vectors of length {@code n} along axis 0 stacks the vectors - * as if they were row vectors resulting in a {@code 2-by-n} matrix. - * - *

    Stacking two vectors of length {@code n} along axis 1 stacks the vectors - * as if they were column vectors resulting in a {@code n-by-2} matrix. - * - * @param b Vector to stack with this vector. - * @param axis Axis along which to stack vectors. If {@code axis=0}, then vectors are stacked as if they are row - * vectors. If {@code axis=1}, then vectors are stacked as if they are column vectors. - * - * @return The result of stacking this vector and the vector {@code b}. + * Computes the magnitude of this vector. * - * @throws IllegalArgumentException If the number of data in this vector is different from the number of - * data in the vector {@code b}. - * @throws IllegalArgumentException If axis is not either 0 or 1. + * @return The magnitude of this vector. */ @Override - public U stack(T b, int axis) { - V[] dest = makeEmptyDataArray(2*size); - DenseConcat.stack(data, b.data, axis, dest); - Shape shape = (axis==0) ? new Shape(2, size) : new Shape(size, 2); - return makeLikeMatrix(shape, dest); + public V mag() { + V mag = getZeroElement(); + + for(int i=0; i - *

  • If {@code true}, the vector will be converted to a matrix representing a column vector.
  • - *
  • If {@code false}, The vector will be converted to a matrix representing a row vector.
  • - * + * @param b Second matrix in the element-wise quotient. * - * @return A matrix equivalent to this vector. + * @return The element-wise quotient of this matrix and {@code b}. */ @Override - public U toMatrix(boolean columVector) { - if(columVector) { - // Convert to column vector. - return makeLikeMatrix(new Shape(data.length, 1), data.clone()); - } else { - // Convert to row vector. - return makeLikeMatrix(new Shape(1, data.length), data.clone()); - } + public T div(T b) { + V[] dest = makeEmptyDataArray(data.length); + DenseFieldElemDiv.dispatch(data, shape, b.data, b.shape, dest); + return makeLikeTensor(shape, dest); } /** - * Computes the element-wise absolute value of this tensor. + * Computes the element-wise square root of this tensor. * - * @return The element-wise absolute value of this tensor. + * @return The element-wise square root of this tensor. */ @Override - public Vector abs() { - double[] abs = new double[data.length]; - RingOps.abs(data, abs); - return new Vector(shape, abs); + public T sqrt() { + V[] dest = makeEmptyDataArray(data.length); + FieldOps.sqrt(data, dest); + return makeLikeTensor(shape, dest); } /** - * Normalizes this vector to a unit length vector. + * Checks if this tensor only contains finite values. * - * @return This vector normalized to a unit length. + * @return {@code true} if this tensor only contains finite values; {@code false} otherwise. + * + * @see #isInfinite() + * @see #isNaN() */ @Override - public T normalize() { - return div(mag()); + public boolean isFinite() { + return FieldOps.isFinite(data); } /** - * Computes the magnitude of this vector. + * Checks if this tensor contains at least one infinite value. * - * @return The magnitude of this vector. + * @return {@code true} if this tensor contains at least one infinite value; {@code false} otherwise. + * + * @see #isFinite() + * @see #isNaN() */ @Override - public V mag() { - V mag = getZeroElement(); - - for(int i=0; i, U extends FieldTensorMixin, V extends Field> - extends TensorOverField { + extends TensorOverField, RingTensorMixin { /** * Creates an empty array of the same type as the data array of this tensor. @@ -432,6 +433,7 @@ default T div(double b) { return makeLikeTensor(getShape(), dest); } + /** * Divides each element of this tensor by a primitive scalar value and stores the result in this tensor. * @@ -448,6 +450,7 @@ default void divEq(double b) { } + // TODO: Remove norms. They need only be defined for vectors and matrices. /** * Computes the Euclidean norm of this vector. * @@ -458,6 +461,7 @@ default double norm() { } + // TODO: Remove norms. They need only be defined for vectors and matrices. /** * Computes the p-norm of this vector. * @@ -465,7 +469,7 @@ default double norm() { * * @return The Euclidean norm of this vector. */ - default double norm(int p) { + default double norm(double p) { return VectorNorms.norm(getData(), p); } } diff --git a/src/main/java/org/flag4j/arrays/backend/primitive_arrays/AbstractDenseDoubleTensor.java b/src/main/java/org/flag4j/arrays/backend/primitive_arrays/AbstractDenseDoubleTensor.java index bcd8b7905..65c1b7a6a 100644 --- a/src/main/java/org/flag4j/arrays/backend/primitive_arrays/AbstractDenseDoubleTensor.java +++ b/src/main/java/org/flag4j/arrays/backend/primitive_arrays/AbstractDenseDoubleTensor.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -395,7 +395,7 @@ public T roundToZero(double tolerance) { * @param b Second tensor in the comparison. * @return True if both tensors have the same shape and all data are 'close' element-wise, i.e. * elements {@code x} and {@code y} at the same positions in the two tensors respectively and satisfy - * {@code |x-y| <= (1E-05 + 1E-08*|y|)}. Otherwise, returns false. + * {@code |x-y| <= (1E-08 + 1E-05*|y|)}. Otherwise, returns false. * @see #allClose(AbstractDoubleTensor, double, double) */ public boolean allClose(T b) { diff --git a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingMatrix.java b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingMatrix.java index da28ee35d..a704d9cd2 100644 --- a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,33 +25,14 @@ package org.flag4j.arrays.backend.ring_arrays; import org.flag4j.algebraic_structures.Ring; -import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseMatrixData; -import org.flag4j.arrays.SparseVectorData; -import org.flag4j.arrays.backend.AbstractTensor; import org.flag4j.arrays.backend.MatrixMixin; -import org.flag4j.linalg.ops.common.semiring_ops.CompareSemiring; -import org.flag4j.linalg.ops.sparse.SparseElementSearch; -import org.flag4j.linalg.ops.sparse.SparseUtils; -import org.flag4j.linalg.ops.sparse.coo.*; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCooSemiringMatrix; +import org.flag4j.linalg.ops.sparse.coo.CooConversions; import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingMatrixOps; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatMult; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatrixOps; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatrixProperties; -import org.flag4j.util.ArrayUtils; -import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.LinearAlgebraException; import org.flag4j.util.exceptions.TensorShapeException; -import java.math.BigDecimal; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.function.BinaryOperator; - -import static org.flag4j.linalg.ops.sparse.SparseUtils.copyRanges; - /** *

    A sparse matrix stored in coordinate list (COO) format. The {@link #data} of this COO matrix are * elements of a {@link Ring}. @@ -89,40 +70,10 @@ public abstract class AbstractCooRingMatrix, V extends AbstractCooRingVector, W extends Ring> - extends AbstractTensor + extends AbstractCooSemiringMatrix implements RingTensorMixin, MatrixMixin { - /** - * The zero element for the arrays that this tensor's elements belong to. - */ - private W zeroElement; - /** - * Row indices for non-zero value of this sparse COO matrix. - */ - public final int[] rowIndices; - /** - * column indices for non-zero value of this sparse COO matrix. - */ - public final int[] colIndices; - /** - * Number of non-zero data in this COO matrix. - */ - public final int nnz; - /** - * The number of rows in this matrix. - */ - public final int numRows; - /** - * The number of columns in this matrix. - */ - public final int numCols; - /** - * The sparsity of this matrix. - */ - public final double sparsity; - - /** * Creates a sparse coo matrix with the specified non-zero data, non-zero indices, and shape. * @@ -132,65 +83,10 @@ public abstract class AbstractCooRingMatrix 0) ? entries[0].getZero() : null; + super(shape, entries, rowIndices, colIndices); } - /** - * Constructs a sparse COO tensor of the same type as this tensor with the specified non-zero data and indices. - * @param shape Shape of the matrix. - * @param entries Non-zero data of the matrix. - * @param rowIndices Non-zero row indices of the matrix. - * @param colIndices Non-zero column indices of the matrix. - * @return A sparse COO tensor of the same type as this tensor with the specified non-zero data and indices. - */ - public abstract T makeLikeTensor(Shape shape, W[] entries, int[] rowIndices, int[] colIndices); - - - /** - * Constructs a COO matrix with the specified shape, non-zero data, and non-zero indices. - * @param shape Shape of the matrix. - * @param entries Non-zero values of the matrix. - * @param rowIndices Non-zero row indices of the matrix. - * @param colIndices Non-zero column indices of the matrix. - * @return A COO matrix with the specified shape, non-zero data, and non-zero indices. - */ - public abstract T makeLikeTensor(Shape shape, List entries, List rowIndices, List colIndices); - - - /** - * Constructs a sparse COO vector of a similar type to this COO matrix. - * @param shape Shape of the vector. Must be rank 1. - * @param entries Non-zero data of the COO vector. - * @param indices Non-zero indices of the COO vector. - * @return A sparse COO vector of a similar type to this COO matrix. - */ - public abstract V makeLikeVector(Shape shape, W[] entries, int[] indices); - - - /** - * Constructs a dense tensor with the specified {@code shape} and {@code data} which is a similar type to this sparse tensor. - * @param shape Shape of the dense tensor. - * @param entries Entries of the dense tensor. - * @return A dense tensor with the specified {@code shape} and {@code data} which is a similar type to this sparse tensor. - */ - public abstract U makeLikeDenseTensor(Shape shape, W[] entries); - - /** * Constructs a sparse CSR matrix of a similar type to this sparse COO matrix. * @param shape Shape of the CSR matrix to construct. @@ -204,1162 +100,111 @@ public abstract AbstractCsrRingMatrix makeLikeCsrMatrix( /** - * Gets the zero element for the arrays of this tensor. - * @return The zero element for the ring of this tensor. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Ring)} then {@code null} will be returned. - * - * @see #setZeroElement(Ring) - */ - public W getZeroElement() { - return zeroElement; - } - - - /** - * Gets the sparsity of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are zero. - * @return The sparsity of this matrix as a decimal percentage. - * @see #density() - */ - public double sparsity() { - return sparsity; - } - - - /** - * Gets the density of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are non-zero. - * @return The density of this matrix as a decimal percentage. - * @see #sparsity - */ - public double density() { - return 1.0 - sparsity; - } - - - /** - * Gets the length of the data array which backs this matrix. - * - * @return The length of the data array which backs this matrix. - */ - @Override - public int dataLength() { - return data.length; - } - - - /** - * Sets the zero element for the ring of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the ring. - * - * @see #getZeroElement() - */ - public void setZeroElement(W zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } - } - - - /** - * Gets the element of this tensor at the specified index. - * - * @param index Indices of the element to get. - * - * @return The element of this tensor at the specified index. If there is a non-zero value with the specified index, that value - * will be returned. If there is no non-zero value at the specified index than the zero element will attempt to be - * returned (i.e. the additive identity of the ring). However, if the zero element could not be determined during - * construction or if it was not set with {@link #setZeroElement(Ring)} then - * {@code null} will be returned. - * - * @throws ArrayIndexOutOfBoundsException If any index are not within this tensor. - */ - @Override - public W get(int... index) { - ValidateParameters.validateTensorIndex(shape, index); - W value = CooGetSet.getCoo(data, rowIndices, colIndices, index[0], index[1]); - return (value == null) ? getZeroElement() : value; - } - - - /** - * Sets the element of this tensor at the specified indices. - * - * @param value New value to set the specified index of this tensor to. - * @param indices Indices of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code indices} is not within the bounds of this tensor. - */ - @Override - public T set(W value, int... indices) { - ValidateParameters.validateTensorIndex(shape, indices); - return set(value, indices[0], indices[1]); - } - - - /** - * Sets an index of this matrix to the specified value. - * - * @param value Value to set. - * @param row Row index to set. - * @param col Column index to set. - * - * @return A reference to this matrix. - */ - @Override - public T set(W value, int row, int col) { - // Find position of row index within the row indices if it exits. - int idx = SparseElementSearch.matrixBinarySearch(rowIndices, colIndices, row, col); - W[] destEntries; - int[] destRowIndices; - int[] destColIndices; - - if(idx < 0) { - idx = -idx - 1; - - // No non-zero element with these indices exists. Insert new value. - destEntries = (W[]) new Ring[data.length + 1]; - destRowIndices = new int[data.length + 1]; - destColIndices = new int[data.length + 1]; - - CooGetSet.cooInsertNewValue( - value, row, col, - data, rowIndices, colIndices, idx, - destEntries, destRowIndices, destColIndices); - } else { - // Value with these indices exists. Simply update value. - destEntries = Arrays.copyOf(data, data.length); - destEntries[idx] = value; - destRowIndices = rowIndices.clone(); - destColIndices = colIndices.clone(); - } - - return makeLikeTensor(shape, (W[]) destEntries, destRowIndices, destColIndices); - } - - - /** - * Sets a specified row of this matrix to a vector. - * - * @param row Vector to replace specified row in this matrix. - * @param rowIdx Index of the row to set. - * - * @return If this matrix is dense, the row set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new row and returned. - */ - @Override - public T setRow(V row, int rowIdx) { - SparseMatrixData dest = CooGetSet.setRow( - shape, data, rowIndices, colIndices, - rowIdx, row.size, row.data, row.indices); - return makeLikeTensor(dest.shape(), dest.data(), dest.rowData(), dest.colData()); - } - - - /** - * Sets a column of this matrix at the given index to the specified vector. - * - * @param col Vector containing new column data. - * @param colIndex The index of the column which is to be set. - * - * @return A copy of this matrix with the specified column set to {@code col}. - * - * @throws IllegalArgumentException If the {@code col} vector has a different length than the number of rows of this matrix. - * @throws IndexOutOfBoundsException If {@code colIndex < 0 || colIndex >= this.numCols}. + * Converts this sparse COO matrix to an equivalent sparse CSR matrix. + * @return A sparse CSR matrix equivalent to this sparse COO matrix. */ - public T setCol(V col, int colIndex) { - SparseMatrixData dest = CooGetSet.setCol( - shape, data, rowIndices, colIndices, - colIndex, col.size, col.data, col.indices); - CooDataSorter sorter = new CooDataSorter(dest.data(), dest.rowData(), dest.colData()).sparseSort(); - return makeLikeTensor(dest.shape(), dest.data(), dest.rowData(), dest.colData()); + public AbstractCsrRingMatrix toCsr() { + W[] csrEntries = (W[]) new Ring[data.length]; + int[] csrRowPointers = new int[numRows + 1]; + int[] csrColPointers = new int[colIndices.length]; + CooConversions.toCsr(shape, data, rowIndices, colIndices, csrEntries, csrRowPointers, csrColPointers); + return makeLikeCsrMatrix(shape, csrEntries, csrRowPointers, csrColPointers); } /** - * Flattens this matrix to a single row. - * - * @return The flattened matrix. - * - * @see #flatten(int) + * Converts this matrix to an equivalent tensor. + * @return A tensor which is equivalent to this matrix. */ - @Override - public T flatten() { - return flatten(0); - } + public abstract AbstractCooRingTensor toTensor(); /** - * Flattens a tensor along the specified axis. - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code }. - * @see #flatten() + * Converts this matrix to an equivalent tensor with the specified shape. + * @param newShape New shape for the tensor. Can be any rank but must be broadcastable to {@link #shape this.shape}. + * @return A tensor equivalent to this matrix which has been reshaped to {@code newShape} */ - @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - int[] dims = {1, 1}; - dims[1-axis] = shape.totalEntriesIntValueExact(); - Shape flatShape = new Shape(dims); - - int[] destIndices = new int[data.length]; - - for(int i = 0; i < data.length; i++) - destIndices[i] = shape.getFlatIndex(rowIndices[i], colIndices[i]); - - return (axis == 0) - ? makeLikeTensor(flatShape, (W[]) data.clone(), new int[data.length], destIndices) - : makeLikeTensor(flatShape, (W[]) data.clone(), destIndices, new int[data.length]); - } + public abstract AbstractCooRingTensor toTensor(Shape newShape); /** - * Copies and reshapes this tensor. + * Computes the element-wise difference between two tensors of the same shape. * - * @param newShape New shape for the tensor. + * @param b Second tensor in the element-wise difference. * - * @return A copy of this tensor with the new shape. + * @return The difference of this tensor with {@code b}. * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. + * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. */ @Override - public T reshape(Shape newShape) { - ValidateParameters.ensureBroadcastable(shape, newShape); - int oldColCount = shape.get(1); - int newColCount = newShape.get(1); - - // Initialize new COO structures with the same size as the original. - int[] newRowIndices = new int[rowIndices.length]; - int[] newColIndices = new int[colIndices.length]; - - for (int i = 0; i < rowIndices.length; i++) { - int flatIndex = rowIndices[i]*oldColCount + colIndices[i]; - newRowIndices[i] = flatIndex / newColCount; - newColIndices[i] = flatIndex % newColCount; - } + public T sub(T b) { + SparseMatrixData data = CooRingMatrixOps.sub( + shape, this.data, rowIndices, colIndices, + b.shape, b.data, b.rowIndices, b.colIndices); - return makeLikeTensor(newShape, (W[]) data.clone(), newRowIndices, newColIndices); + return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); } /** - * Computes the transpose of a tensor by exchanging the first and last axes of this tensor. - * - * @return The transpose of this tensor. + * Computes the Hermitian transpose of this matrix. * - * @see #T(int, int) - * @see #T(int...) + * @return The Hermitian transpose of this matrix. */ @Override - public T T() { - T transpose = makeLikeTensor(shape.swapAxes(0, 1), (W[]) data.clone(), colIndices.clone(), rowIndices.clone()); - transpose.sortIndices(); // Ensure the indices are sorted correctly. - - return transpose; + public T H() { + return T(); } /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. + * Computes the conjugate transpose of a tensor by conjugating and exchanging {@code axis1} and {@code axis2}. * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. + * @param axis1 First axis to exchange and conjugate. + * @param axis2 Second axis to exchange and conjugate. * - * @return The transpose of this tensor according to the specified axes. + * @return The conjugate transpose of this tensor according to the specified axes. * * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) + * @see #H() + * @see #H(int...) */ @Override - public T T(int axis1, int axis2) { - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - if(axis1 == axis2) return copy(); - else return T(); + public T H(int axis1, int axis2) { + return T(axis1, axis2); } /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches + * Computes the conjugate transpose of this tensor. That is, conjugates and permutes the axes of this tensor so that it matches * the permutation specified by {@code axes}. * * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. + * @return The conjugate transpose of this tensor with its axes permuted by the {@code axes} array. * * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - if(axes.length != 2) - throw new IllegalArgumentException("Expecting two axes in transpose but got " + axes.length + "."); - return T(axes[0], axes[1]); - } - - - /** - * Gets the number of rows in this matrix. - * - * @return The number of rows in this matrix. - */ - @Override - public int numRows() { - return numRows; - } - - - /** - * Gets the number of columns in this matrix. - * - * @return The number of columns in this matrix. - */ - @Override - public int numCols() { - return numCols; - } - - - /** - * Gets the element of this matrix at this specified {@code row} and {@code col}. - * - * @param row Row index of the item to get from this matrix. - * @param col Column index of the item to get from this matrix. - * - * @return The element of this matrix at the specified index. - */ - @Override - public W get(int row, int col) { - return (W) CooGetSet.getCoo(data, rowIndices, colIndices, row, col); - } - - - /** - *

    Computes the trace of this matrix. That is, the sum of elements along the principle diagonal of this matrix. - * - *

    Same as {@link #trace()}. - * - * @return The trace of this matrix. - * - * @throws IllegalArgumentException If this matrix is not square. - */ - @Override - public W tr() { - W trace = getZeroElement(); - - for(int i = 0; i< data.length; i++) - if(rowIndices[i]==colIndices[i]) trace = trace.add((W) data[i]); // Then entry is on the diagonal. - - return trace; - } - - - /** - * Checks if this matrix is upper triangular. - * - * @return {@code true} is this matrix is upper triangular; {@code false} otherwise. - * - * @see #isTri() - * @see #isTriL() - * @see #isDiag() - */ - @Override - public boolean isTriU() { - for(int i = 0; i< data.length; i++) - if(rowIndices[i] > colIndices[i] && !data[i].isZero()) return false; // Then non-zero entry is not in upper triangle. - - return true; - } - - - /** - * Checks if this matrix is lower triangular. - * - * @return {@code true} is this matrix is lower triangular; {@code false} otherwise. - * - * @see #isTri() - * @see #isTriU() - * @see #isDiag() - */ - @Override - public boolean isTriL() { - for(int i = 0; i< data.length; i++) - if(rowIndices[i] < colIndices[i] && !data[i].isZero()) return false; // Then non-zero entry is not in lower triangle. - - return true; - } - - - /** - * Checks if this matrix is the identity matrix. That is, checks if this matrix is square and contains - * only ones along the principle diagonal and zeros everywhere else. - * - * @return {@code true} if this matrix is the identity matrix; {@code false} otherwise. - */ - @Override - public boolean isI() { - return CooSemiringMatrixProperties.isIdentity(shape, data, rowIndices, colIndices); - } - - - /** - * Computes the matrix multiplication between two matrices. - * - * @param b Second matrix in the matrix multiplication. - * - * @return The result of matrix multiplying this matrix with matrix {@code b}. - * - * @throws LinearAlgebraException If the number of columns in this matrix do not equal the number - * of rows in matrix {@code b}. - */ - @Override - public U mult(T b) { - ValidateParameters.ensureMatMultShapes(shape, b.shape); - W[] dest = (W[]) new Ring[numRows*b.numCols]; - CooSemiringMatMult.standard( - data, rowIndices, colIndices, shape, - b.data, b.rowIndices, b.colIndices, b.shape, dest); - - return makeLikeDenseTensor(new Shape(numRows, b.numCols), dest); - } - - - /** - * Multiplies this matrix with the transpose of the {@code b} tensor as if by - * {@code this.mult(b.T())}. - * For large matrices, this method may - * be significantly faster than directly computing the transpose followed by the multiplication as - * {@code this.mult(b.T())}. - * - * @param b The second matrix in the multiplication and the matrix to transpose. - * - * @return The result of multiplying this matrix with the transpose of {@code b}. - */ - @Override - public U multTranspose(T b) { - ValidateParameters.ensureEquals(numCols, b.numCols); - return mult(b.T()); - } - - - /** - * Stacks matrices along columns.
    - * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking this matrix on top of the matrix {@code b}. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of columns. - * @see #stack(MatrixMixin, int) - * @see #augment(T) + * @see #H(int, int) + * @see #H() */ @Override - public T stack(T b) { - ValidateParameters.ensureEquals(numCols, b.numCols); - - Shape destShape = new Shape(numRows+b.numRows, numCols); - W[] destEntries = (W[]) new Ring[data.length + b.data.length]; - int[] destRowIndices = new int[destEntries.length]; - int[] destColIndices = new int[destEntries.length]; - CooConcat.stack(data, rowIndices, colIndices, numRows, - b.data, b.rowIndices, b.colIndices, - destEntries, destRowIndices, destColIndices); - - return makeLikeTensor(destShape, destEntries, destRowIndices, destColIndices); + public T H(int... axes) { + return T(axes); } /** - * Stacks matrices along rows. + * Checks if the matrix is "close" to an identity matrix. Two entries {@code x} and {@code y} are considered + * "close" if they satisfy the following: + *

    {@code
    +     *      |x-y| <= (1E-08 + 1E-05*|y|)
    +     * }
    * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking {@code b} to the right of this matrix. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of rows. - * @see #stack(T) - * @see #stack(MatrixMixin, int) - */ - @Override - public T augment(T b) { - ValidateParameters.ensureEquals(numRows, b.numRows); - - Shape destShape = new Shape(numRows, numCols + b.numCols); - W[] destEntries = (W[]) new Ring[data.length + b.data.length]; - int[] destRowIndices = new int[destEntries.length]; - int[] destColIndices = new int[destEntries.length]; - CooConcat.augment(data, rowIndices, colIndices, numCols, - b.data, b.rowIndices, b.colIndices, - destEntries, destRowIndices, destColIndices); - - return makeLikeTensor(destShape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Augments a vector to this matrix. - * - * @param b The vector to augment to this matrix. - * - * @return The result of augmenting {@code b} to this matrix. - */ - @Override - public T augment(V b) { - ValidateParameters.ensureEquals(numRows, b.size); - - Shape destShape = new Shape(numRows, numCols + 1); - W[] destEntries = (W[]) new Ring[nnz + b.data.length]; - int[] destRowIndices = new int[destEntries.length]; - int[] destColIndices = new int[destEntries.length]; - CooConcat.augmentVector( - data, rowIndices, colIndices, numCols, - b.data, b.indices, - destEntries, destRowIndices, destColIndices); - - return makeLikeTensor(destShape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Swaps specified rows in the matrix. This is done in place. - * - * @param rowIndex1 Index of the first row to swap. - * @param rowIndex2 Index of the second row to swap. - * - * @return A reference to this matrix. - * - * @throws ArrayIndexOutOfBoundsException If either index is outside the matrix bounds. - */ - @Override - public T swapRows(int rowIndex1, int rowIndex2) { - CooManipulations.swapRows(shape, data, rowIndices, colIndices, rowIndex1, rowIndex2); - return (T) this; - } - - - /** - * Swaps specified columns in the matrix. This is done in place. - * - * @param colIndex1 Index of the first column to swap. - * @param colIndex2 Index of the second column to swap. - * - * @return A reference to this matrix. - * - * @throws ArrayIndexOutOfBoundsException If either index is outside the matrix bounds. - */ - @Override - public T swapCols(int colIndex1, int colIndex2) { - CooManipulations.swapCols(shape, data, rowIndices, colIndices, colIndex1, colIndex2); - return (T) this; - } - - - /** - * Checks if a matrix is symmetric. That is, if the matrix is square and equal to its transpose. - * - * @return {@code true} if this matrix is symmetric; {@code false} otherwise. - */ - @Override - public boolean isSymmetric() { - return CooSemiringMatrixProperties.isSymmetric(shape, data, rowIndices, colIndices); - } - - - /** - * Checks if a matrix is Hermitian. That is, if the matrix is square and equal to its conjugate transpose. - * - * @return {@code true} if this matrix is Hermitian; {@code false} otherwise. - */ - @Override - public boolean isHermitian() { - return isSymmetric(); - } - - - /** - * Checks if this matrix is orthogonal. That is, if the inverse of this matrix is equal to its transpose. - * - * @return {@code true} if this matrix it is orthogonal; {@code false} otherwise. - */ - @Override - public boolean isOrthogonal() { - if(isSquare()) return mult(T()).isI(); - else return false; - } - - - /** - * Gets a range of a row of this matrix. - * - * @param rowIdx The index of the row to get. - * @param start The staring column of the row range to get (inclusive). - * @param stop The ending column of the row range to get (exclusive). - * - * @return A vector containing the elements of the specified row over the range [start, stop). - * - * @throws IllegalArgumentException If {@code rowIdx < 0 || rowIdx >= this.numRows()} or {@code start < 0 || start >= numCols} or - * {@code stop < start || stop > numCols}. - */ - @Override - public V getRow(int rowIdx, int start, int stop) { - SparseVectorData data = CooGetSet.getRow(shape, this.data, rowIndices, colIndices, rowIdx, start, stop); - return makeLikeVector(data.shape(), - (W[]) data.data().toArray(new Ring[data.data().size()]), - data.indicesToArray()); - } - - - /** - * Gets a range of a column of this matrix. - * - * @param colIdx The index of the column to get. - * @param start The staring row of the column range to get (inclusive). - * @param stop The ending row of the column range to get (exclusive). - * - * @return A vector containing the elements of the specified column over the range [start, stop). - * - * @throws IllegalArgumentException If {@code colIdx < 0 || colIdx >= this.numCols()} or {@code start < 0 || start >= numRows} or - * {@code stop < start || stop > numRows}. - */ - @Override - public V getCol(int colIdx, int start, int stop) { - SparseVectorData data = CooGetSet.getCol(shape, this.data, rowIndices, colIndices, colIdx, start, stop); - return makeLikeVector(data.shape(), - (W[]) data.data().toArray(new Ring[data.data().size()]), - data.indicesToArray()); - } - - - /** - * Gets the elements of this matrix along the specified diagonal. - * - * @param diagOffset The diagonal to get within this matrix. - *
      - *
    • If {@code diagOffset == 0}: Then the elements of the principle diagonal are collected.
    • - *
    • If {@code diagOffset < 0}: Then the elements of the sub-diagonal {@code diagOffset} below the principle diagonal - * are collected.
    • - *
    • If {@code diagOffset > 0}: Then the elements of the super-diagonal {@code diagOffset} above the principle diagonal - * are collected.
    • - *
    - * - * @return The elements of the specified diagonal as a vector. - */ - @Override - public V getDiag(int diagOffset) { - SparseVectorData data = CooGetSet.getDiag(shape, this.data, rowIndices, colIndices, diagOffset); - return makeLikeVector(data.shape(), - (W[]) data.data().toArray(new Ring[data.data().size()]), - data.indicesToArray()); - } - - - /** - * Removes a specified row from this matrix. - * - * @param rowIndex Index of the row to remove from this matrix. - * - * @return A copy of this matrix with the specified row removed. - */ - @Override - public T removeRow(int rowIndex) { - Shape shape = new Shape(numRows-1, numCols); - - // Find the start and end index within the data array which have the given row index. - int[] startEnd = SparseElementSearch.matrixFindRowStartEnd(rowIndices, rowIndex); - int size = data.length - (startEnd[1]-startEnd[0]); - - // Initialize arrays. - W[] entries = (W[]) new Ring[size]; - int[] rowIndices = new int[size]; - int[] colIndices = new int[size]; - - copyRanges(this.data, this.rowIndices, this.colIndices, entries, rowIndices, colIndices, startEnd); - - return makeLikeTensor(shape, entries, rowIndices, colIndices); - } - - - /** - * Removes a specified set of rows from this matrix. - * - * @param rowIdxs The indices of the rows to remove from this matrix. Assumed to contain unique values. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeRows(int... rowIdxs) { - // TODO: This should be doable for a general COO matrix. Return SparseMatrixData object. - Shape shape = new Shape(numRows-rowIdxs.length, numCols); - List entries = new ArrayList<>(nnz); - List rowIndices = new ArrayList<>(nnz); - List colIndices = new ArrayList<>(nnz); - - for(int i=0; i destEntries = new ArrayList<>(data.length); - List destRowIndices = new ArrayList<>(data.length); - List destColIndices = new ArrayList<>(data.length); - - for(int i = 0; i< data.length; i++) { - if(colIndices[i] != colIndex) { - // Then entry is not in the specified column, so remove it. - destEntries.add(data[i]); - destRowIndices.add(rowIndices[i]); - - if(colIndices[i] < colIndex) destColIndices.add(colIndices[i]); - else destColIndices.add(colIndices[i]-1); - } - } - - return makeLikeTensor(shape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Removes a specified set of columns from this matrix. - * - * @param colIdxs Indices of the columns to remove from this matrix. Assumed to contain unique values. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeCols(int... colIdxs) { - Shape shape = new Shape(numRows, numCols-1); - List destEntries = new ArrayList<>(data.length); - List destRowIndices = new ArrayList<>(data.length); - List destColIndices = new ArrayList<>(data.length); - - for(int i = 0; i< data.length; i++) { - int idx = Arrays.binarySearch(colIdxs, colIndices[i]); - - if(idx < 0) { - // Then entry is not in the specified column, so copy it with the appropriate column index shift. - destEntries.add(data[i]); - destRowIndices.add(rowIndices[i]); - destColIndices.add(colIndices[i] + (idx+1)); - } - } - - return makeLikeTensor(shape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Creates a copy of this matrix and sets a slice of the copy to the specified values. The rowStart and colStart parameters specify the upper - * left index location of the slice to set. - * - * @param values New values for the specified slice. - * @param rowStart Starting row index for the slice (inclusive). - * @param colStart Starting column index for the slice (inclusive). - * - * @return A copy of this matrix with the given slice set to the specified values. - * - * @throws IndexOutOfBoundsException If rowStart or colStart are not within the matrix. - * @throws IllegalArgumentException If the values slice, with upper left corner at the specified location, does not - * fit completely within this matrix. - */ - @Override - public T setSliceCopy(T values, int rowStart, int colStart) { - SparseMatrixData sliceData = CooGetSet.setSlice( - shape, data, rowIndices, colIndices, - values.shape, values.data, values.rowIndices, values.colIndices, - rowStart, colStart); - return makeLikeTensor(sliceData.shape(), sliceData.data(), sliceData.rowData(), sliceData.colData()); - } - - - /** - * Gets a specified slice of this matrix. - * - * @param rowStart Starting row index of slice (inclusive). - * @param rowEnd Ending row index of slice (exclusive). - * @param colStart Starting column index of slice (inclusive). - * @param colEnd Ending row index of slice (exclusive). - * - * @return The specified slice of this matrix. This is a completely new matrix and NOT a view into the matrix. - * - * @throws ArrayIndexOutOfBoundsException If any of the indices are out of bounds of this matrix. - * @throws IllegalArgumentException If {@code rowEnd} is not greater than {@code rowStart} or if {@code colEnd} is not greater than {@code colStart}. - */ - @Override - public T getSlice(int rowStart, int rowEnd, int colStart, int colEnd) { - SparseMatrixData sliceData = CooGetSet.getSlice( - shape, data, rowIndices, colIndices, - rowStart, rowEnd, colStart, colEnd); - return makeLikeTensor(sliceData.shape(), sliceData.data(), sliceData.rowData(), sliceData.colData()); - } - - - /** - * Extracts the upper-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for upper-triangular portion to extract: - *
      - *
    • If zero, then all data at and above the principle diagonal of this matrix are extracted.
    • - *
    • If positive, then all data at and above the equivalent super-diagonal are extracted.
    • - *
    • If negative, then all data at and above the equivalent sub-diagonal are extracted.
    • - *
    - * - * @return The upper-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriU(int diagOffset) { - SparseMatrixData data = CooGetSet.getTriU(diagOffset, shape, this.data, rowIndices, colIndices); - return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); - } - - - /** - * Extracts the lower-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for lower-triangular portion to extract: - *
      - *
    • If zero, then all data at and above the principle diagonal of this matrix are extracted.
    • - *
    • If positive, then all data at and above the equivalent super-diagonal are extracted.
    • - *
    • If negative, then all data at and above the equivalent sub-diagonal are extracted.
    • - *
    - * - * @return The lower-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriL(int diagOffset) { - SparseMatrixData data = CooGetSet.getTriL(diagOffset, shape, this.data, rowIndices, colIndices); - return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data); - } - - - /** - * Computes the element-wise difference between two tensors of the same shape. - * - * @param b Second tensor in the element-wise difference. - * - * @return The difference of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T sub(T b) { - SparseMatrixData data = CooRingMatrixOps.sub( - shape, this.data, rowIndices, colIndices, - b.shape, b.data, b.rowIndices, b.colIndices); - - return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); - } - - - /** - * Computes the Hermitian transpose of this matrix. - * - * @return The Hermitian transpose of this matrix. - */ - @Override - public T H() { - return T(); - } - - - /** - * Computes the conjugate transpose of a tensor by conjugating and exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange and conjugate. - * @param axis2 Second axis to exchange and conjugate. - * - * @return The conjugate transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #H() - * @see #H(int...) - */ - @Override - public T H(int axis1, int axis2) { - return T(axis1, axis2); - } - - - /** - * Computes the conjugate transpose of this tensor. That is, conjugates and permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The conjugate transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #H(int, int) - * @see #H() - */ - @Override - public T H(int... axes) { - return T(axes); - } - - - /** - * Finds the minimum value in this tensor. If this tensor is complex, then this method finds the smallest value in magnitude. - * - * @return The minimum value (smallest in magnitude for a complex valued tensor) in this tensor. - */ - @Override - public W min() { - return CompareSemiring.min(data); - } - - - /** - * Finds the maximum value in this tensor. If this tensor is complex, then this method finds the largest value in magnitude. - * - * @return The maximum value (largest in magnitude for a complex valued tensor) in this tensor. - */ - @Override - public W max() { - return CompareSemiring.max(data); - } - - - /** - * Finds the indices of the minimum value in this tensor. - * - * @return The indices of the minimum value in this tensor. If this value occurs multiple times, the indices of the first - * entry (in row-major ordering) are returned. - */ - @Override - public int[] argmin() { - int idx = CompareSemiring.argmin(data); - return new int[]{rowIndices[idx], colIndices[idx]}; - } - - - /** - * Finds the indices of the maximum value in this tensor. - * - * @return The indices of the maximum value in this tensor. If this value occurs multiple times, the indices of the first - * entry (in row-major ordering) are returned. - */ - @Override - public int[] argmax() { - int idx = CompareSemiring.argmax(data); - return new int[]{rowIndices[idx], colIndices[idx]}; - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - SparseMatrixData data = CooSemiringMatrixOps.add( - shape, this.data, rowIndices, colIndices, - b.shape, b.data, b.rowIndices, b.colIndices); - - return makeLikeTensor(data.shape(), - (W[]) data.data().toArray(new Ring[data.data().size()]), - data.rowIndicesToArray(), - data.colIndicesToArray()); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - SparseMatrixData data = CooSemiringMatrixOps.elemMult( - shape, this.data, rowIndices, colIndices, - b.shape, b.data, b.rowIndices, b.colIndices); - - return makeLikeTensor(data.shape(), - (W[]) data.data().toArray(new Ring[data.data().size()]), - data.rowIndicesToArray(), - data.colIndicesToArray()); - } - - - /** - *

    Computes the generalized trace of this tensor along the specified axes. - * - *

    The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - ValidateParameters.ensureNotEquals(axis1, axis2); - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - - return makeLikeTensor(new Shape(1, 1), (W[]) new Ring[]{tr()}, new int[]{0}, new int[]{0}); - } - - - /** - * Sorts the indices of this tensor in lexicographical order while maintaining the associated value for each index. - */ - public void sortIndices() { - CooDataSorter.wrap(data, rowIndices, colIndices).sparseSort().unwrap(data, rowIndices, colIndices); - } - - - /** - * Converts this sparse COO matrix to an equivalent dense matrix. - * @return A dense matrix equivalent to this sparse COO matrix. - */ - public U toDense() { - W[] data = (W[]) new Ring[totalEntries().intValueExact()]; - - for(int i = 0; i< nnz; i++) - data[rowIndices[i]*numCols + colIndices[i]] = this.data[i]; - - return makeLikeDenseTensor(shape, data); - } - - - /** - * Converts this sparse COO matrix to an equivalent sparse CSR matrix. - * @return A sparse CSR matrix equivalent to this sparse COO matrix. - */ - public AbstractCsrRingMatrix toCsr() { - W[] csrEntries = (W[]) new Ring[data.length]; - int[] csrRowPointers = new int[numRows + 1]; - int[] csrColPointers = new int[colIndices.length]; - CooConversions.toCsr(shape, data, rowIndices, colIndices, csrEntries, csrRowPointers, csrColPointers); - return makeLikeCsrMatrix(shape, csrEntries, csrRowPointers, csrColPointers); - } - - - /** - * Converts this matrix to an equivalent tensor. - * @return A tensor which is equivalent to this matrix. - */ - public abstract AbstractTensor toTensor(); - - - /** - * Converts this matrix to an equivalent tensor with the specified shape. - * @param newShape New shape for the tensor. Can be any rank but must be broadcastable to {@link #shape this.shape}. - * @return A tensor equivalent to this matrix which has been reshaped to {@code newShape} - */ - public abstract AbstractTensor toTensor(Shape newShape); - - - /** - * Converts this sparse CSR matrix to an equivalent vector. If this matrix is not a row or column vector it will be flattened - * before conversion. - * @return A vector equivalent to this CSR matrix. - */ - public V toVector() { - int[] destIndices = new int[data.length]; - for(int i = 0; i< data.length; i++) - destIndices[i] = rowIndices[i]*colIndices[i]; - - return makeLikeVector(new Shape(numRows*numCols), data.clone(), destIndices); - } - - - /** - * Coalesces this sparse COO matrix. An uncoalesced matrix is a sparse matrix with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by summing duplicated data. If another form of aggregation other - * than summing is desired, use {@link #coalesce(BinaryOperator)}. - * @return A new coalesced sparse COO matrix which is equivalent to this COO matrix. - * @see #coalesce(BinaryOperator) - */ - public T coalesce() { - SparseMatrixData mat = SparseUtils.coalesce(Semiring::add, shape, data, rowIndices, colIndices); - return makeLikeTensor(mat.shape(), mat.data(), mat.rowData(), mat.colData()); - } - - - /** - * Coalesces this sparse COO matrix. An uncoalesced matrix is a sparse matrix with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by aggregating duplicated data using {@code aggregator}. - * @param aggregator Custom aggregation function to combine multiple. - * @return A new coalesced sparse COO matrix which is equivalent to this COO matrix. - * @see #coalesce() - */ - public T coalesce(BinaryOperator aggregator) { - SparseMatrixData mat = SparseUtils.coalesce(aggregator, shape, data, rowIndices, colIndices); - return makeLikeTensor(mat.shape(), mat.data(), mat.rowData(), mat.colData()); - } - - - /** - * Drops any explicit zeros in this sparse COO matrix. - * @return A copy of this COO matrix with any explicitly stored zeros removed. + * @return {@code true} if the matrix is approximately an identity matrix, otherwise {@code false}. */ - public T dropZeros() { - SparseMatrixData mat = SparseUtils.dropZeros(shape, data, rowIndices, colIndices); - return makeLikeTensor(mat.shape(), mat.data(), mat.rowData(), mat.colData()); + public boolean isCloseToI() { + return CooRingMatrixOps.isCloseToIdentity(this); } } diff --git a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingTensor.java b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingTensor.java index a867b9d82..30f50c69c 100644 --- a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingTensor.java +++ b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingTensor.java @@ -25,25 +25,12 @@ package org.flag4j.arrays.backend.ring_arrays; import org.flag4j.algebraic_structures.Ring; -import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseTensorData; -import org.flag4j.arrays.backend.AbstractTensor; -import org.flag4j.linalg.ops.common.semiring_ops.CompareSemiring; -import org.flag4j.linalg.ops.sparse.SparseElementSearch; -import org.flag4j.linalg.ops.sparse.SparseUtils; -import org.flag4j.linalg.ops.sparse.coo.*; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCooSemiringTensor; import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingTensorOps; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringTensorOps; -import org.flag4j.util.ArrayUtils; -import org.flag4j.util.ValidateParameters; import org.flag4j.util.exceptions.TensorShapeException; -import java.math.BigDecimal; -import java.util.Arrays; -import java.util.List; -import java.util.function.BinaryOperator; - /** *

    Base class for all sparse {@link Ring} tensors stored in coordinate list (COO) format. The data of this COO tensor are * elements of a {@link Ring}. @@ -77,28 +64,9 @@ */ public abstract class AbstractCooRingTensor, U extends AbstractDenseRingTensor, V extends Ring> - extends AbstractTensor + extends AbstractCooSemiringTensor implements RingTensorMixin { - /** - * The zero element for the arrays that this tensor's elements belong to. - */ - private V zeroElement; - /** - *

    The non-zero indices of this sparse tensor. - * - *

    Has shape {@code (nnz, rank)} where {@code nnz} is the number of non-zero data in this sparse tensor. - */ - public final int[][] indices; - /** - * The number of non-zero data in this sparse tensor. - */ - public final int nnz; - /** - * Stores the sparsity of this matrix. - */ - public final double sparsity; - /** * Creates a tensor with the specified data and shape. @@ -108,430 +76,7 @@ public abstract class AbstractCooRingTensor 0 && data[0] != null) ? data[0].getZero() : null; - } - - - /** - * Constructs a tensor of the same type as this tensor with the specified shape and non-zero data. - * @param shape Shape of the tensor to construct. - * @param entries Non-zero data of the tensor to construct. - * @param indices Indices of the non-zero data of the tensor. - * @return A tensor of the same type as this tensor with the specified shape and non-zero data. - */ - public abstract T makeLikeTensor(Shape shape, V[] entries, int[][] indices); - - - /** - * Constructs a tensor of the same type as this tensor with the specified shape and non-zero data. - * @param shape Shape of the tensor to construct. - * @param entries Non-zero data of the tensor to construct. - * @param indices Indices of the non-zero data of the tensor. - * @return A tensor of the same type as this tensor with the specified shape and non-zero data. - */ - public abstract T makeLikeTensor(Shape shape, List entries, List indices); - - - /** - * Constructs a dense tensor that is a similar type as this sparse COO tensor. - * @param shape Shape of the tensor to construct. - * @param entries The data of the dense tensor to construct. - * @return A dense tensor that is a similar type as this sparse COO tensor. - */ - public abstract U makeLikeDenseTensor(Shape shape, V[] entries); - - - /** - * Gets the zero element for the field of this tensor. - * @return The zero element for the field of this tensor. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Ring)} then {@code null} will be returned. - */ - public V getZeroElement() { - return zeroElement; - } - - - /** - * Sets the zero element for the field of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the arrays. - */ - public void setZeroElement(V zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } - } - - /** - * Gets the sparsity of this tensor as a decimal percentage. - * That is, the percentage of data in this tensor that are zero. - * @return The sparsity of this tensor as a decimal percentage. - * @see #density() - */ - public double sparsity() { - return sparsity; - } - - - /** - * Gets the density of this tensor as a decimal percentage. - * That is, the percentage of data in this tensor that are non-zero. - * @return The density of this tensor as a decimal percentage. - * @see #sparsity() - */ - public double density() { - return 1.0 - sparsity; - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - SparseTensorData data = CooSemiringTensorOps.add( - shape, this.data, indices, - b.shape, b.data, b.indices - ); - - return makeLikeTensor(data.shape(), - (V[]) data.data().toArray(new Ring[data.data().size()]), - data.indicesToArray()); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - SparseTensorData data = CooSemiringTensorOps.elemMult( - shape, this.data, indices, b.shape, b.data, b.indices); - return makeLikeTensor(data.shape(), - (V[]) data.data().toArray(new Ring[data.data().size()]), - data.indicesToArray()); - } - - - /** - * Computes the tensor contraction of this tensor with a specified tensor over the specified set of axes. That is, - * computes the sum of products between the two tensors along the specified set of axes. - * - * @param src2 Tensor to contract with this tensor. - * @param aAxes Axes along which to compute products for this tensor. - * @param bAxes Axes along which to compute products for {@code src2} tensor. - * - * @return The tensor dot product over the specified axes. - * - * @throws IllegalArgumentException If the two tensors shapes do not match along the specified axes pairwise in - * {@code aAxes} and {@code bAxes}. - * @throws IllegalArgumentException If {@code aAxes} and {@code bAxes} do not match in length, or if any of the axes - * are out of bounds for the corresponding tensor. - */ - @Override - public U tensorDot(T src2, int[] aAxes, int[] bAxes) { - CooTensorDot problem = new CooTensorDot<>(shape, data, indices, - src2.shape, src2.data, src2.indices, - aAxes, bAxes); - V[] dest = (V[]) new Ring[problem.getOutputSize()]; - problem.compute(dest); - return makeLikeDenseTensor(problem.getOutputShape(), dest); - } - - - /** - *

    Computes the generalized trace of this tensor along the specified axes. - * - *

    The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - SparseTensorData data = CooSemiringTensorOps.tensorTr( - shape, this.data, indices, axis1, axis2); - return makeLikeTensor(data.shape(), - (V[]) data.data().toArray(new Ring[data.data().size()]), - data.indices().toArray(new int[data.indices().size()][])); - } - - - /** - * Computes the transpose of a tensor by exchanging the first and last axes of this tensor. - * - * @return The transpose of this tensor. - * - * @see #T(int, int) - * @see #T(int...) - */ - @Override - public T T() { - V[] destEntries = (V[]) new Ring[nnz]; - int[][] destIndices = new int[nnz][rank]; - CooTranspose.tensorTranspose(shape, data, indices,0, shape.getRank()-1, destEntries, destIndices); - return makeLikeTensor(shape.swapAxes(0, rank-1), (V[]) destEntries, destIndices); - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - V[] destEntries = (V[]) new Ring[nnz]; - int[][] destIndices = new int[nnz][rank]; - CooTranspose.tensorTranspose(shape, data, indices, axis1, axis2, destEntries, destIndices); - return makeLikeTensor(shape.swapAxes(axis1, axis2), (V[]) destEntries, destIndices); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - V[] destEntries = (V[]) new Ring[nnz]; - int[][] destIndices = new int[nnz][rank]; - CooTranspose.tensorTranspose(shape, data, indices, axes, destEntries, destIndices); - return makeLikeTensor(shape.permuteAxes(axes), destEntries, destIndices); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data.clone()); - } - - - /** - * Finds the minimum (non-zero) value in this tensor. If this tensor is complex, then this method finds the smallest value in - * magnitude. - * - * @return The minimum (non-zero) value in this tensor. - */ - @Override - public V min() { - return CompareSemiring.min(data); - } - - - /** - * Finds the maximum (non-zero) value in this tensor. - * - * @return The maximum (non-zero) value in this tensor. - */ - @Override - public V max() { - return CompareSemiring.max(data); - } - - - /** - * Finds the indices of the minimum (non-zero) value in this tensor. - * - * @return The indices of the minimum (non-zero) value in this tensor. - */ - @Override - public int[] argmin() { - return indices[CompareSemiring.argmin(data)]; - } - - - /** - * Finds the indices of the maximum (non-zero) value in this tensor. - * - * @return The indices of the maximum (non-zero) value in this tensor. - */ - @Override - public int[] argmax() { - return indices[CompareSemiring.argmin(data)]; - } - - - /** - * Gets the element of this tensor at the specified target. - * - * @param target Index of the element to get. - * - * @return The element of this tensor at the specified index. If there is a non-zero value with the specified index, that value - * will be returned. If there is no non-zero value at the specified index than the zero element will attempt to be - * returned (i.e. the additive identity of the arrays). However, if the zero element could not be determined during - * construction or if it was not set with {@link #setZeroElement(Ring)} then - * {@code null} will be returned. - * - * @throws ArrayIndexOutOfBoundsException If any target are not within this tensor. - */ - @Override - public V get(int... target) { - ValidateParameters.validateTensorIndex(shape, target); - V value = CooGetSet.getCoo(data, indices, target); - return (value == null) ? getZeroElement() : value; - } - - - /** - * Sets the element of this tensor at the specified target. - * - * @param value New value to set the specified index of this tensor to. - * @param target Index of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code target} is not within the bounds of this tensor. - */ - @Override - public T set(V value, int... target) { - ValidateParameters.validateTensorIndex(shape, target); - int idx = SparseElementSearch.binarySearchCoo(indices, target); - - V[] destEntries; - int[][] destIndices; - - if (idx >= 0) { - // Target index found. - destEntries = data.clone(); - destIndices = ArrayUtils.deepCopy(indices, null); - destEntries[idx] = value; - destIndices[idx] = target; - } else { - // Target not found, insert new value and index. - destEntries = (V[]) new Ring[nnz + 1]; - destIndices = new int[nnz + 1][rank]; - int insertionPoint = - (idx + 1); - CooGetSet.cooInsertNewValue(value, target, data, indices, insertionPoint, destEntries, destIndices); - } - - return makeLikeTensor(shape, destEntries, destIndices); - } - - - /** - * Flattens tensor to single dimension while preserving order of data. - * - * @return The flattened tensor. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - return makeLikeTensor( - shape.flatten(), - data.clone(), - SparseUtils.cooFlattenIndices(shape, indices)); - } - - - /** - * Flattens a tensor along the specified axis. Unlike {@link #flatten()} - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - int[] destShape = new int[indices[0].length]; - Arrays.fill(destShape, 1); - destShape[axis] = shape.totalEntries().intValueExact(); - - return makeLikeTensor( - new Shape(destShape), - data.clone(), - SparseUtils.cooFlattenIndices(shape, indices, axis)); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - return makeLikeTensor(newShape, data.clone(), SparseUtils.cooReshape(shape, newShape, indices)); - } - - - /** - * Sorts the indices of this tensor in lexicographical order while maintaining the associated value for each index. - */ - public void sortIndices() { - CooDataSorter.wrap(data, indices).sparseSort().unwrap(data, indices); - } - - - /** - * Converts this COO tensor to an equivalent dense tensor. - * @return A dense tensor which is equivalent to this COO tensor. - * @throws ArithmeticException If the number of data in the dense tensor exceeds 2,147,483,647. - */ - public U toDense() { - V[] denseEntries = (V[]) new Ring[shape.totalEntriesIntValueExact()]; - CooConversions.toDense(shape, data, indices, denseEntries); - return makeLikeDenseTensor(shape, denseEntries); + super(shape, data, indices); } @@ -546,13 +91,14 @@ public U toDense() { */ @Override public T sub(T b) { - SparseTensorData data = CooRingTensorOps.sub( - shape, this.data, indices, + SparseTensorData diff = CooRingTensorOps.sub( + shape, data, indices, b.shape, b.data, b.indices); - - return makeLikeTensor(data.shape(), - (V[]) data.data().toArray(new Ring[data.data().size()]), - data.indices().toArray(new int[data.indices().size()][])); + V[] dest = makeEmptyDataArray(diff.data().size()); + diff.data().toArray(dest); + return makeLikeTensor(diff.shape(), + dest, + diff.indices().toArray(new int[diff.indices().size()][])); } @@ -592,41 +138,4 @@ public T H(int axis1, int axis2) { public T H(int... axes) { return T(axes); } - - - - /** - * Coalesces this sparse COO tensor. An uncoalesced tensor is a sparse tensor with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by summing duplicated data. If another form of aggregation other - * than summing is desired, use {@link #coalesce(BinaryOperator)}. - * @return A new coalesced sparse COO tensor which is equivalent to this COO tensor. - * @see #coalesce(BinaryOperator) - */ - public T coalesce() { - SparseTensorData tensor = SparseUtils.coalesce(Semiring::add, shape, data, indices); - return makeLikeTensor(tensor.shape(), tensor.data(), tensor.indices()); - } - - - /** - * Coalesces this sparse COO tensor. An uncoalesced tensor is a sparse tensor with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by aggregating duplicated data using {@code aggregator}. - * @param aggregator Custom aggregation function to combine multiple. - * @return A new coalesced sparse COO tensor which is equivalent to this COO tensor. - * @see #coalesce() - */ - public T coalesce(BinaryOperator aggregator) { - SparseTensorData tensor = SparseUtils.coalesce(aggregator, shape, data, indices); - return makeLikeTensor(tensor.shape(), tensor.data(), tensor.indices()); - } - - - /** - * Drops any explicit zeros in this sparse COO tensor. - * @return A copy of this COO tensor with any explicitly stored zeros removed. - */ - public T dropZeros() { - SparseTensorData tensor = SparseUtils.dropZeros(shape, data, indices); - return makeLikeTensor(tensor.shape(), tensor.data(), tensor.indices()); - } } diff --git a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingVector.java b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingVector.java index 233a9d40d..2fb3e1bec 100644 --- a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingVector.java +++ b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCooRingVector.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,30 +24,14 @@ package org.flag4j.arrays.backend.ring_arrays; -import org.flag4j.algebraic_structures.Field; import org.flag4j.algebraic_structures.Ring; -import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseVectorData; -import org.flag4j.arrays.backend.AbstractTensor; import org.flag4j.arrays.backend.VectorMixin; -import org.flag4j.linalg.ops.common.semiring_ops.AggregateSemiring; -import org.flag4j.linalg.ops.sparse.SparseUtils; -import org.flag4j.linalg.ops.sparse.coo.CooConcat; -import org.flag4j.linalg.ops.sparse.coo.CooDataSorter; -import org.flag4j.linalg.ops.sparse.coo.CooGetSet; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCooSemiringVector; import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingVectorOps; -import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringVectorOps; -import org.flag4j.util.ArrayUtils; -import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.LinearAlgebraException; import org.flag4j.util.exceptions.TensorShapeException; -import java.math.BigDecimal; -import java.util.Arrays; -import java.util.List; -import java.util.function.BinaryOperator; - /** *

    A sparse vector stored in coordinate list (COO) format. The {@link #data} of this COO vector are * elements of a {@link Ring}. @@ -85,32 +69,10 @@ public abstract class AbstractCooRingVector< V extends AbstractCooRingMatrix, W extends AbstractDenseRingMatrix, Y extends Ring> - extends AbstractTensor + extends AbstractCooSemiringVector implements RingTensorMixin, VectorMixin { - /** - * The zero element for the arrays that this tensor's elements belong to. - */ - private Y zeroElement; - /** - * Indices of the non-zero values of this sparse COO vector. - */ - public final int[] indices; - /** - * The number of non-zero data in this sparse COO vector. - */ - public final int nnz; - /** - * The total size of this sparse COO vector (including zero values). - */ - public final int size; - /** - * The sparsity of this matrix. - */ - public final double sparsity; - - /** * Creates a COO vector with the specified data and shape. * @@ -120,278 +82,14 @@ public abstract class AbstractCooRingVector< * If this tensor is sparse, this specifies only the non-zero data of the tensor. */ protected AbstractCooRingVector(Shape shape, Y[] entries, int[] indices) { - super(shape, entries); - ValidateParameters.ensureRank(shape, 1); - ValidateParameters.ensureIndicesInBounds(shape.get(0), indices); - if(entries.length != indices.length) { - throw new IllegalArgumentException("data and indices arrays of a COO vector must have the same length but got " + - "lengths" + entries.length + " and " + indices.length + "."); - } - this.size = shape.totalEntriesIntValueExact(); - - if(entries.length > size) { - throw new IllegalArgumentException("The number of data cannot be greater than the size of the vector but but got " + - "data.length=" + entries.length + " and size=" + size + "."); - } - - this.indices = indices; - this.nnz = entries.length; - sparsity = BigDecimal.valueOf(nnz).divide(new BigDecimal(shape.totalEntries())).doubleValue(); - - // Attempt to set the zero element for the arrays. - this.zeroElement = (entries.length > 0) ? entries[0].getZero() : null; - } - - - /** - * Constructs a sparse COO vector of the same type as this vector with the specified non-zero data and indices. - * @param shape Shape of the vector to construct. - * @param entries Non-zero data of the vector to construct. - * @param indices Non-zero row indices of the vector to construct. - * @return A sparse COO vector of the same type as this vector with the specified non-zero data and indices. - */ - public abstract T makeLikeTensor(Shape shape, Y[] entries, int[] indices); - - - /** - * Constructs a dense vector of a similar type as this vector with the specified shape and data. - * @param shape Shape of the vector to construct. - * @param entries Entries of the vector to construct. - * @return A dense vector of a similar type as this vector with the specified data. - */ - public abstract U makeLikeDenseTensor(Shape shape, Y... entries); - - - /** - * Constructs a dense matrix of a similar type as this vector with the specified shape and data. - * @param shape Shape of the matrix to construct. - * @param entries Entries of the matrix to construct. - * @return A dense matrix of a similar type as this vector with the specified data. - */ - public abstract W makeLikeDenseMatrix(Shape shape, Y... entries); - - - /** - * Constructs a COO vector with the specified shape, non-zero data, and non-zero indices. - * @param shape Shape of the vector. - * @param entries Non-zero values of the vector. - * @param indices Indices of the non-zero values in the vector. - * @return A COO vector of the same type as this vector with the specified shape, non-zero data, and non-zero indices. - */ - public abstract T makeLikeTensor(Shape shape, List entries, List indices); - - - /** - * Constructs a COO matrix with the specified shape, non-zero data, and row and column indices. - * @param shape Shape of the matrix to construct. - * @param entries Non-zero data of the matrix. - * @param rowIndices Row indices of the matrix. - * @param colIndices Column indices of the matrix. - * @return A COO matrix of similar type as this vector with the specified shape, non-zero data, and non-zero row/col indices. - */ - public abstract V makeLikeMatrix(Shape shape, Y[] entries, int[] rowIndices, int[] colIndices); - - - /** - * Gets the sparsity of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are zero. - * @return The sparsity of this matrix as a decimal percentage. - * @see #density() - */ - public double sparsity() { - return sparsity; - } - - - /** - * Gets the density of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are non-zero. - * @return The density of this matrix as a decimal percentage. - * @see #sparsity - */ - public double density() { - return 1.0 - sparsity; - } - - - /** - * Sorts the indices of this tensor in lexicographical order while maintaining the associated value for each index. - */ - public void sortIndices() { - CooDataSorter.wrap(data, indices).sparseSort().unwrap(data, indices); - } - - - /** - * Gets the element of this tensor at the specified indices. - * - * @param target Indices of the element to get. - * - * @return The element of this tensor at the specified indices. - * - * @throws IndexOutOfBoundsException If any {target} are not within this tensor. - */ - @Override - public Y get(int... target) { - ValidateParameters.ensureArrayLengthsEq(1, target.length); - return get(target[0]); - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - return copy(); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - if(axes.length != 1) - throw new IllegalArgumentException("Axes for tensor of rank 1 must be permutation of {1}."); - ValidateParameters.ensurePermutation(axes); - return copy(); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data); - } - - - /** - * Sets the element of this tensor at the specified indices. - * - * @param value New value to set the specified index of this tensor to. - * @param target Indices of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code indices} is not within the bounds of this tensor. - */ - @Override - public T set(Y value, int... target) { - ValidateParameters.validateTensorIndex(shape, target); - int idx = Arrays.binarySearch(indices, target[0]); - - Y[] destEntries; - int[] destIndices; - - if (idx >= 0) { - // Target index found. - destEntries = data.clone(); - destIndices = indices.clone(); - destEntries[idx] = value; - destIndices[idx] = target[0]; - } else { - // Target not found, insert new value and index. - destEntries = (Y[]) new Ring[nnz + 1]; - destIndices = new int[nnz + 1]; - int insertionPoint = - (idx + 1); - CooGetSet.cooInsertNewValue(value, target[0], data, indices, insertionPoint, destEntries, destIndices); - } - - return makeLikeTensor(shape, destEntries, destIndices); - } - - - /** - * Flattens tensor to single dimension while preserving order of data. - * - * @return The flattened tensor. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - return copy(); - } - - - /** - * Flattens a tensor along the specified axis. Unlike {@link #flatten()} - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - return copy(); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - ValidateParameters.ensureRank(newShape, 1); - return copy(); - } - - - /** - * Joints specified vector with this vector. That is, creates a vector of length {@code this.length() + b.length()} containing - * first the elements of this vector followed by the elements of {@code b}. - * - * @param b Vector to join with this vector. - * - * @return A vector resulting from joining the specified vector with this vector. - */ - @Override - public T join(T b) { - Y[] destEntries = (Y[]) new Ring[this.data.length + b.data.length]; - int[] destIndices = new int[this.indices.length + b.indices.length]; - CooConcat.join(data, indices, size, b.data, b.indices, destEntries, destIndices); - return makeLikeTensor(new Shape(shape.get(0) + b.shape.get(0)), destEntries, destIndices); + super(shape, entries, indices); } /** *

    Computes the inner product between two vectors. * - *

    Note: this method is distinct from {@link #dot(AbstractCooRingVector)}. The inner product is equivalent to the dot product + *

    Note: this method is distinct from {@link #dot(AbstractCooSemiringVector)}. The inner product is equivalent to the dot product * of this tensor with the conjugation of {@code b}. * * @param b Second vector in the inner product. @@ -399,337 +97,14 @@ public T join(T b) { * @return The inner product between this vector and the vector {@code b}. * * @throws IllegalArgumentException If this vector and vector {@code b} do not have the same number of data. - * @see #dot(AbstractCooRingVector) + * @see #dot(AbstractCooSemiringVector) */ @Override public Y inner(T b) { - return dot(b); // For rings, this will be the same. - } - - - /** - *

    Computes the dot product between two vectors. - * - *

    Note: this method is distinct from {@link #inner(AbstractCooRingVector)}. - * The inner product is equivalent to the dot product of this tensor with the conjugation of {@code b}. - * - * @param b Second vector in the dot product. - * - * @return The dot product between this vector and the vector {@code b}. - * - * @throws IllegalArgumentException If this vector and vector {@code b} do not have the same number of data. - * @see #inner(AbstractCooRingVector) - */ - @Override - public Y dot(T b) { - return CooSemiringVectorOps.dot(shape, data, indices, b.shape, b.data, b.indices); - } - - - /** - *

    Gets the length of a vector. Same as {@link #size()}. - *

    WARNING: This method will throw a {@link ArithmeticException} if the - * total number of data in this vector is greater than the maximum integer. In this case, the true size of this vector can - * still be found by calling {@code shape.totalEntries()} on this vector. - * - * @return The length, i.e. the number of data, in this vector. - * @throws ArithmeticException If the total number of data in this vector is greater than the maximum integer. - */ - @Override - public int length() { - return shape.totalEntriesIntValueExact(); - } - - - /** - * Repeats a vector {@code n} times along a certain axis to create a matrix. - * - * @param n Number of times to repeat vector. - * @param axis Axis along which to repeat vector: - *

      - *
    • If {@code axis=0}, then the vector will be treated as a row vector and stacked vertically {@code n} times.
    • - *
    • If {@code axis=1} then the vector will be treated as a column vector and stacked horizontally {@code n} times.
    • - *
    - * - * @return A matrix whose rows/columns are this vector repeated. - */ - @Override - public V repeat(int n, int axis) { - Y[] tiledEntries = (Y[]) new Field[n*data.length]; - int[] tiledRows = new int[tiledEntries.length]; - int[] tiledCols = new int[tiledEntries.length]; - Shape tiledShape = CooConcat.repeat(data, indices, size, n, axis, tiledEntries, tiledRows, tiledCols); - return makeLikeMatrix(tiledShape, data, tiledRows, tiledCols); - } - - - /** - *

    - * Stacks two vectors along specified axis. - * - * - *

    - * Stacking two vectors of length {@code n} along axis 0 stacks the vectors - * as if they were row vectors resulting in a {@code 2-by-n} matrix. - * - * - *

    - * Stacking two vectors of length {@code n} along axis 1 stacks the vectors - * as if they were column vectors resulting in a {@code n-by-2} matrix. - * - * - * @param b Vector to stack with this vector. - * @param axis Axis along which to stack vectors. If {@code axis=0}, then vectors are stacked as if they are row - * vectors. If {@code axis=1}, then vectors are stacked as if they are column vectors. - * - * @return The result of stacking this vector and the vector {@code b}. - * - * @throws IllegalArgumentException If the number of data in this vector is different from the number of - * data in the vector {@code b}. - * @throws IllegalArgumentException If axis is not either 0 or 1. - */ - @Override - public V stack(T b, int axis) { - ValidateParameters.ensureEquals(size, b.size); - Y[] destEntries = (Y[]) new Ring[data.length + b.data.length]; - int[][] destIndices = new int[2][indices.length + indices.length]; // Row and column indices. - - CooConcat.stack(data, indices, b.data, b.indices, destEntries, destIndices[0], destIndices[1]); - V mat = makeLikeMatrix(new Shape(2, size), destEntries, destIndices[0], destIndices[1]); - - return (axis == 0) ? mat : mat.T(); - } - - - /** - * Computes the outer product of two vectors. - * - * @param b Second vector in the outer product. - * - * @return The result of the vector outer product between this vector and {@code b}. - * - * @throws IllegalArgumentException If the two vectors do not have the same number of data. - */ - @Override - public W outer(T b) { - Shape destShape = new Shape(size, b.size); - Y[] dest = (Y[]) new Ring[size*b.size]; - CooSemiringVectorOps.outerProduct(data, indices, size, b.data, b.indices, dest); - return makeLikeDenseMatrix(shape, dest); - } - - - /** - * Converts a vector to an equivalent matrix representing either a row or column vector. - * - * @param columVector Flag indicating whether to convert this vector to a matrix representing a row or column vector: - *

    If {@code true}, the vector will be converted to a matrix representing a column vector. - *

    If {@code false}, The vector will be converted to a matrix representing a row vector. - * - * @return A matrix equivalent to this vector. - */ - @Override - public V toMatrix(boolean columVector) { - if(columVector) { - // Convert to column vector - int[] rowIndices = indices.clone(); - int[] colIndices = new int[data.length]; - Shape matShape = new Shape(size, 1); - - return makeLikeMatrix(matShape, data.clone(), rowIndices, colIndices); - } else { - // Convert to row vector. - int[] rowIndices = new int[data.length]; - int[] colIndices = indices.clone(); - Shape matShape = new Shape(1, size); - - return makeLikeMatrix(matShape, data.clone(), rowIndices, colIndices); - } - } - - - /** - * Normalizes this vector to a unit length vector. - * - * @return This vector normalized to a unit length. - */ - @Override - public T normalize() { - throw new UnsupportedOperationException("Normalization not supported for arrays vectors."); - } - - - /** - * Computes the magnitude of this vector. - * - * @return The magnitude of this vector. - */ - @Override - public Y mag() { - return AggregateSemiring.sum(data); - } - - - /** - * Gets the element of this vector at the specified index. - * - * @param idx Index of the element to get within this vector. - * - * @return The element of this vector at index {@code idx}. - */ - @Override - public Y get(int idx) { - ValidateParameters.validateTensorIndex(shape, idx); - Y value = CooGetSet.getCoo(data, indices, idx); - return (value == null) ? getZeroElement() : value; - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - SparseVectorData result = CooSemiringVectorOps.add( - shape, data, indices, b.shape, b.data, b.indices); - return makeLikeTensor(shape, - (Y[]) result.data().toArray(new Ring[result.data().size()]), - ArrayUtils.fromIntegerList(result.indices())); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - SparseVectorData prod = CooSemiringVectorOps.elemMult( - shape, data, indices, - b.shape, b.data, b.indices); - return makeLikeTensor(shape, - (Y[]) prod.data().toArray(new Ring[prod.data().size()]), - ArrayUtils.fromIntegerList(prod.indices())); - } - - - /** - * Computes the tensor contraction of this tensor with a specified tensor over the specified set of axes. That is, - * computes the sum of products between the two tensors along the specified set of axes. - * - * @param src2 Tensor to contract with this tensor. - * @param aAxes Axes along which to compute products for this tensor. - * @param bAxes Axes along which to compute products for {@code src2} tensor. - * - * @return The tensor dot product over the specified axes. - * - * @throws IllegalArgumentException If the two tensors shapes do not match along the specified axes pairwise in - * {@code aAxes} and {@code bAxes}. - * @throws IllegalArgumentException If {@code aAxes} and {@code bAxes} do not match in length, or if any of the axes - * are out of bounds for the corresponding tensor. - */ - @Override - public U tensorDot(T src2, int[] aAxes, int[] bAxes) { - if(aAxes.length != 1 || bAxes.length != 1) { - throw new LinearAlgebraException("Vector dot product requires exactly one dimension for each vector but got " - + aAxes.length + " and " + bAxes.length + "."); - } - if(aAxes[0] != 0 || bAxes[0] != 0) { - throw new LinearAlgebraException("Both axes must be 0 for vector dot product but got " - + aAxes[0] + " and " + bAxes[0] + "."); - } - - return makeLikeDenseTensor(shape, dot(src2)); - } - - - /** - *

    Computes the generalized trace of this tensor along the specified axes. - * - *

    The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - throw new LinearAlgebraException("Tensor trace cannot be computed for a rank 1 tensor " + - "(must be rank 2 or " + "greater)."); - } - - - /** - * Gets the zero element for the field of this vector. - * @return The zero element for the field of this vector. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Ring)} then {@code null} will be returned. - */ - public Y getZeroElement() { - return zeroElement; - } - - - /** - * Sets the zero element for the field of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the ring. - */ - public void setZeroElement(Y zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } + return CooRingVectorOps.inner(this, b); } - /** - * Converts this sparse COO matrix to an equivalent dense matrix. - * @return A dense matrix equivalent to this sparse COO matrix. - */ - public U toDense() { - Y[] entries = (Y[]) new Ring[shape.totalEntriesIntValueExact()]; - - for(int i = 0; i< nnz; i++) - entries[indices[i]] = this.data[i]; - - return makeLikeDenseTensor(shape, entries); - } - - - /** - * Converts this matrix to an equivalent rank 1 tensor. - * @return A tensor which is equivalent to this matrix. - */ - public abstract AbstractTensor toTensor(); - - - /** - * Converts this vector to an equivalent tensor with the specified shape. - * @param newShape New shape for the tensor. Can be any rank but must be broadcastable to {@link #shape this.shape}. - * @return A tensor equivalent to this matrix which has been reshaped to {@code newShape} - */ - public abstract AbstractTensor toTensor(Shape newShape); - - /** * Computes the element-wise difference between two tensors of the same shape. * @@ -783,40 +158,4 @@ public T H(int axis1, int axis2) { public T H(int... axes) { return T(axes); } - - - /** - * Coalesces this sparse COO vector. An uncoalesced vector is a sparse vector with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by summing duplicated data. If another form of aggregation other - * than summing is desired, use {@link #coalesce(BinaryOperator)}. - * @return A new coalesced sparse COO vector which is equivalent to this COO vector. - * @see #coalesce(BinaryOperator) - */ - public T coalesce() { - SparseVectorData vec = SparseUtils.coalesce(Semiring::add, shape, data, indices); - return makeLikeTensor(vec.shape(), vec.data(), vec.indices()); - } - - - /** - * Coalesces this sparse COO vector. An uncoalesced vector is a sparse vector with multiple data for a single index. This - * method will ensure that each index only has one non-zero value by aggregating duplicated data using {@code aggregator}. - * @param aggregator Custom aggregation function to combine multiple. - * @return A new coalesced sparse COO vector which is equivalent to this COO vector. - * @see #coalesce() - */ - public T coalesce(BinaryOperator aggregator) { - SparseVectorData vec = SparseUtils.coalesce(aggregator, shape, data, indices); - return makeLikeTensor(vec.shape(), vec.data(), vec.indices()); - } - - - /** - * Drops any explicit zeros in this sparse COO vector. - * @return A copy of this COO vector with any explicitly stored zeros removed. - */ - public T dropZeros() { - SparseVectorData vec = SparseUtils.dropZeros(shape, data, indices); - return makeLikeTensor(vec.shape(), vec.data(), vec.indices()); - } } diff --git a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCsrRingMatrix.java b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCsrRingMatrix.java index 64b719669..3c1459262 100644 --- a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCsrRingMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractCsrRingMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,28 +25,15 @@ package org.flag4j.arrays.backend.ring_arrays; -import org.flag4j.algebraic_structures.Field; import org.flag4j.algebraic_structures.Ring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseMatrixData; -import org.flag4j.arrays.backend.AbstractTensor; import org.flag4j.arrays.backend.MatrixMixin; -import org.flag4j.linalg.ops.sparse.csr.CsrConversions; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCsrSemiringMatrix; import org.flag4j.linalg.ops.sparse.csr.CsrOps; -import org.flag4j.linalg.ops.sparse.csr.CsrProperties; -import org.flag4j.linalg.ops.sparse.csr.semiring_ops.SemiringCsrMatMult; -import org.flag4j.linalg.ops.sparse.csr.semiring_ops.SemiringCsrOps; -import org.flag4j.linalg.ops.sparse.csr.semiring_ops.SemiringCsrProperties; -import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.LinearAlgebraException; +import org.flag4j.linalg.ops.sparse.csr.ring_ops.CsrRingProperties; import org.flag4j.util.exceptions.TensorShapeException; -import java.math.BigDecimal; -import java.util.Arrays; -import java.util.List; - -import static org.flag4j.linalg.ops.sparse.SparseUtils.sortCsrMatrix; - /** *

    A sparse matrix stored in compressed sparse row (CSR) format. The {@link #data} of this CSR matrix are @@ -83,47 +70,9 @@ public abstract class AbstractCsrRingMatrix, V extends AbstractCooRingVector, W extends Ring> - extends AbstractTensor + extends AbstractCsrSemiringMatrix implements RingTensorMixin, MatrixMixin { - /** - * The zero element for the arrays that this tensor's elements belong to. - */ - private W zeroElement; - /** - *

    Pointers indicating starting index of each row within the {@link #colIndices} and {@link #data} arrays. - * Has length {@link #numRows numRows + 1}. - * - *

    The range [{@code data[rowPointers[i]], data[rowPointers[i+1]]}) contains all {@link #data non-zero data} within - * row {@code i}. - * - *

    Similarly, [{@code colData[rowPointers[i]], colData[rowPointers[i+1]]}) contains all {@link #colIndices column indices} - * for the data in row {@code i}. - * - */ - public final int[] rowPointers; - /** - * Column indices for non-zero values of this sparse CSR matrix. - */ - public final int[] colIndices; - /** - * Number of non-zero data in this CSR matrix. - */ - public final int nnz; - /** - * The number of rows in this matrix. - */ - public final int numRows; - /** - * The number of columns in this matrix. - */ - public final int numCols; - /** - * The sparsity of this matrix. - */ - private final double sparsity; - - /** * Creates a sparse CSR matrix with the specified {@code shape}, non-zero data, row pointers, and non-zero column indices. * @@ -136,884 +85,7 @@ public abstract class AbstractCsrRingMatrix 0) ? entries[0].getZero() : null; - } - - - /** - * Constructs a sparse CSR tensor of the same type as this tensor with the specified non-zero data and indices. - * @param shape Shape of the matrix. - * @param entries Non-zero data of the CSR matrix. - * @param rowPointers Row pointers for the non-zero values in the CSR matrix. - * @param colIndices Non-zero column indices of the CSR matrix. - * @return A sparse CSR tensor of the same type as this tensor with the specified non-zero data and indices. - */ - public abstract T makeLikeTensor(Shape shape, W[] entries, int[] rowPointers, int[] colIndices); - - - /** - * Constructs a CSR matrix with the specified shape, non-zero data, and non-zero indices. - * @param shape Shape of the matrix. - * @param entries Non-zero values of the CSR matrix. - * @param rowPointers Row pointers for the non-zero values in the CSR matrix. - * @param colIndices Non-zero column indices of the CSR matrix. - * @return A CSR matrix with the specified shape, non-zero data, and non-zero indices. - */ - public abstract T makeLikeTensor(Shape shape, List entries, List rowPointers, List colIndices); - - - /** - * Constructs a dense matrix which is of a similar type to this sparse CSR matrix. - * @param shape Shape of the dense matrix. - * @param entries Entries of the dense matrix. - * @return A dense matrix which is of a similar type to this sparse CSR matrix with the specified {@code shape} - * and {@code data}. - */ - public abstract U makeLikeDenseTensor(Shape shape, W[] entries); - - - /** - *

    Constructs a sparse COO matrix of a similar type to this sparse CSR matrix. - *

    Note: this method constructs a new COO matrix with the specified data and indices. It does not convert this matrix - * to a CSR matrix. To convert this matrix to a sparse COO matrix use {@link #toCoo()}. - * @param shape Shape of the COO matrix. - * @param entries Non-zero data of the COO matrix. - * @param rowIndices Non-zero row indices of the sparse COO matrix. - * @param colIndices Non-zero column indices of the Sparse COO matrix. - * @return A sparse COO matrix of a similar type to this sparse CSR matrix. - */ - public abstract AbstractCooRingMatrix makeLikeCooMatrix( - Shape shape, W[] entries, int[] rowIndices, int[] colIndices); - - - /** - * Gets the sparsity of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are zero. - * @return The sparsity of this matrix as a decimal percentage. - * @see #density() - */ - public double sparsity() { - return sparsity; - } - - - /** - * Gets the density of this matrix as a decimal percentage. - * That is, the percentage of data in this matrix that are non-zero. - * @return The density of this matrix as a decimal percentage. - * @see #sparsity - */ - public double density() { - return 1.0 - sparsity; - } - - - /** - * Gets the length of the data array which backs this matrix. - * - * @return The length of the data array which backs this matrix. - */ - @Override - public int dataLength() { - return data.length; - } - - - /** - * Gets the zero element for the arrays of this tensor. - * @return The zero element for the arrays of this tensor. If it could not be determined during construction of this object - * and has not been set explicitly by {@link #setZeroElement(Ring)} then {@code null} will be returned. - * - * @see #setZeroElement(Ring) - */ - public W getZeroElement() { - return zeroElement; - } - - - /** - * Sets the zero element for the arrays of this tensor. - * @param zeroElement The zero element of this tensor. - * @throws IllegalArgumentException If {@code zeroElement} is not an additive identity for the arrays. - * - * @see #getZeroElement() - */ - public void setZeroElement(W zeroElement) { - if (zeroElement.isZero()) { - this.zeroElement = zeroElement; - } else { - throw new IllegalArgumentException("The provided zeroElement is not an additive identity."); - } - } - - - - /** - * Gets the element of this tensor at the specified indices. - * - * @param indices Indices of the element to get. - * - * @return The element of this tensor at the specified indices. - * - * @throws ArrayIndexOutOfBoundsException If any indices are not within this tensor. - */ - @Override - public W get(int... indices) { - ValidateParameters.validateTensorIndex(shape, indices); - int row = indices[0]; - int col = indices[1]; - return get(row, col); - } - - - /** - * Sets the element of this tensor at the specified indices. - * - * @param value New value to set the specified index of this tensor to. - * @param indices Indices of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code indices} is not within the bounds of this tensor. - */ - @Override - public T set(W value, int... indices) { - ValidateParameters.validateTensorIndex(shape, indices); - return set(value, indices[0], indices[1]); - } - - - /** - * Sets a specified row of this matrix to a vector. - * - * @param row Vector to replace specified row in this matrix. - * @param rowIdx Index of the row to set. - * - * @return If this matrix is dense, the row set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new row and returned. - */ - @Override - public T setRow(V row, int rowIdx) { - return (T) toCoo().setRow(row, rowIdx).toCsr(); - } - - - /** - * Sets a specified column of this matrix to a vector. - * - * @param col Vector to replace specified column in this matrix. - * @param colIdx Index of the column to set. - * - * @return If this matrix is dense, the column set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new column and returned. - */ - @Override - public T setCol(V col, int colIdx) { - return (T) toCoo().setCol(col, colIdx).toCsr(); - } - - - /** - * Flattens tensor to single dimension while preserving order of data. - * - * @return The flattened tensor. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - int[] newRowPointers = new int[2]; - newRowPointers[1] = nnz; - return makeLikeTensor( - new Shape(1, shape.totalEntriesIntValueExact()), - data.clone(), - newRowPointers, - colIndices.clone()); - } - - - /** - * Flattens a tensor along the specified axis. Unlike {@link #flatten()} - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - int[] newRowPointers; - int[] newColIndices; - - if (axis == 0) { - // Flatten to a single row. - newRowPointers = new int[2]; - newRowPointers[1] = nnz; - newColIndices = new int[nnz]; - } else { - // Flatten to a single column. - int flatSize = shape.totalEntriesIntValueExact(); - newColIndices = new int[nnz]; // Set all column indices to 0. - newRowPointers = new int[flatSize + 1]; - } - - Shape newShape = CsrConversions.flatten(shape, data, rowPointers, colIndices, axis, newRowPointers, newColIndices); - - return makeLikeTensor( - new Shape(shape.totalEntriesIntValueExact(), 1), - data.clone(), - newRowPointers, - newColIndices); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - return (T) toCoo().reshape(newShape).toCsr(); - } - - - /** - * Computes the transpose of a tensor by exchanging the first and last axes of this tensor. - * - * @return The transpose of this tensor. - * - * @see #T(int, int) - * @see #T(int...) - */ - @Override - public T T() { - W[] dest = (W[]) new Ring[data.length]; - int[] destRowPointers = new int[numCols+1]; - int[] destColIndices = new int[data.length]; - CsrOps.transpose(data, rowPointers, colIndices, dest, destRowPointers, destColIndices); - - return makeLikeTensor(shape.swapAxes(0, 1), dest, destRowPointers, destColIndices); - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - SparseMatrixData destData = CsrOps.applyBinOpp( - shape, data, rowPointers, colIndices, - b.shape, b.data, b.rowPointers, b.colIndices, - Ring::add, null); - - return makeLikeTensor(shape, destData.data(), destData.rowData(), destData.colData()); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - SparseMatrixData destData = CsrOps.applyBinOpp( - shape, data, rowPointers, colIndices, - b.shape, b.data, b.rowPointers, b.colIndices, - Ring::mult, null); - - return makeLikeTensor(shape, destData.data(), destData.rowData(), destData.colData()); - } - - - /** - * Computes the tensor contraction of this tensor with a specified tensor over the specified set of axes. That is, - * computes the sum of products between the two tensors along the specified set of axes. - * - * @param src2 Tensor to contract with this tensor. - * @param aAxes Axes along which to compute products for this tensor. - * @param bAxes Axes along which to compute products for {@code src2} tensor. - * - * @return The tensor dot product over the specified axes. - * - * @throws IllegalArgumentException If the two tensors shapes do not match along the specified axes pairwise in - * {@code aAxes} and {@code bAxes}. - * @throws IllegalArgumentException If {@code aAxes} and {@code bAxes} do not match in length, or if any of the axes - * are out of bounds for the corresponding tensor. - */ - @Override - public AbstractDenseRingTensor tensorDot(T src2, int[] aAxes, int[] bAxes) { - // TODO: Implement this method. Need to wait for a concrete implementation of AbstractDenseSemiringTensor - return null; - } - - - /** - *

    Computes the generalized trace of this tensor along the specified axes. - * - *

    The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - ValidateParameters.ensureNotEquals(axis1, axis2); - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - - return makeLikeTensor(new Shape(1, 1), (W[]) new Ring[]{tr()}, new int[]{0}, new int[]{0}); - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - ValidateParameters.ensureValidAxes(shape, axis1, axis2); - if(axis1 == axis2) return copy(); - return T(); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - if(axes.length != 2) { - throw new IllegalArgumentException("Cannot transpose axes " - + Arrays.toString(axes) + " for a tensor of rank " + rank); - } - - return T(axes[0], axes[1]); - } - - - /** - * Gets the number of rows in this matrix. - * - * @return The number of rows in this matrix. - */ - @Override - public int numRows() { - return numRows; - } - - - /** - * Gets the number of columns in this matrix. - * - * @return The number of columns in this matrix. - */ - @Override - public int numCols() { - return numCols; - } - - - /** - * Gets the element of this matrix at this specified {@code row} and {@code col}. - * - * @param row Row index of the item to get from this matrix. - * @param col Column index of the item to get from this matrix. - * - * @return The element of this matrix at the specified index. - */ - @Override - public W get(int row, int col) { - int loc = Arrays.binarySearch(colIndices, rowPointers[row], rowPointers[row+1], col); - - if(loc >= 0) return data[loc]; - else return zeroElement; - } - - - /** - *

    Computes the trace of this matrix. That is, the sum of elements along the principle diagonal of this matrix. - * - *

    Same as {@link #trace()}. - * - * @return The trace of this matrix. - * - * @throws IllegalArgumentException If this matrix is not square. - */ - @Override - public W tr() { - W tr = SemiringCsrOps.trace(data, rowPointers, colIndices); - - return (tr == null) ? zeroElement : tr; - } - - - /** - * Checks if this matrix is upper triangular. - * - * @return {@code true} is this matrix is upper triangular; {@code false} otherwise. - * - * @see #isTri() - * @see #isTriL() - * @see #isDiag() - */ - @Override - public boolean isTriU() { - return SemiringCsrProperties.isTriU(shape, data, rowPointers, colIndices); - } - - - /** - * Checks if this matrix is lower triangular. - * - * @return {@code true} is this matrix is lower triangular; {@code false} otherwise. - * - * @see #isTri() - * @see #isTriU() - * @see #isDiag() - */ - @Override - public boolean isTriL() { - return SemiringCsrProperties.isTriL(shape, data, rowPointers, colIndices); - } - - - /** - * Checks if this matrix is the identity matrix. That is, checks if this matrix is square and contains - * only ones along the principle diagonal and zeros everywhere else. - * - * @return {@code true} if this matrix is the identity matrix; {@code false} otherwise. - */ - @Override - public boolean isI() { - return SemiringCsrProperties.isIdentity(shape, data, rowPointers, colIndices); - } - - - /** - * Computes the matrix multiplication between two matrices. - * - * @param b Second matrix in the matrix multiplication. - * - * @return The result of matrix multiplying this matrix with matrix {@code b}. - * - * @throws LinearAlgebraException If the number of columns in this matrix do not equal the number - * of rows in matrix {@code b}. - * @see #multToSparse(AbstractCsrRingMatrix) - */ - @Override - public U mult(T b) { - Shape destShape = new Shape(numRows, b.numCols); - W[] destArray = (W[]) new Ring[numRows*b.numCols]; - - SemiringCsrMatMult.standard( - shape, data, rowPointers, colIndices, b.shape, - b.data, b.rowPointers, b.colIndices, - destArray, zeroElement); - - return makeLikeDenseTensor(shape, destArray); - } - - - /** - *

    Computes the matrix multiplication between two sparse CSR matrices and stores the result in a sparse matrix. - *

    Warning: this method should be used with caution as sparse-sparse matrix multiplication may result in a dense matrix. - * In such a case, this method will likely be significantly slower than {@link #mult(AbstractCsrRingMatrix)}. - * @param b - * @return - */ - public T multToSparse(T b) { - SparseMatrixData data = SemiringCsrMatMult.standardToSparse( - shape, this.data, rowPointers, colIndices, b.shape, - b.data, b.rowPointers, b.colIndices); - - return makeLikeTensor(data.shape(), data.data(), data.rowData(), data.colData()); - } - - - /** - * Multiplies this matrix with the transpose of the {@code b} tensor as if by - * {@code this.mult(b.T())}. - * For large matrices, this method may - * be significantly faster than directly computing the transpose followed by the multiplication as - * {@code this.mult(b.T())}. - * - * @param b The second matrix in the multiplication and the matrix to transpose. - * - * @return The result of multiplying this matrix with the transpose of {@code b}. - */ - @Override - public U multTranspose(T b) { - ValidateParameters.ensureEquals(numCols, b.numCols); - return mult(b.T()); - } - - - /** - * Stacks matrices along columns.
    - * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking this matrix on top of the matrix {@code b}. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of columns. - * @see #stack(MatrixMixin, int) - * @see #augment(T) - */ - @Override - public T stack(T b) { - return (T) toCoo().stack(b.toCoo()).toCsr(); - } - - - /** - * Stacks matrices along rows. - * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking {@code b} to the right of this matrix. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of rows. - * @see #stack(T) - * @see #stack(MatrixMixin, int) - */ - @Override - public T augment(T b) { - return (T) toCoo().augment(b.toCoo()).toCsr(); - } - - - /** - * Augments a vector to this matrix. - * - * @param b The vector to augment to this matrix. - * - * @return The result of augmenting {@code b} to this matrix. - */ - @Override - public T augment(V b) { - return (T) toCoo().augment(b).toCsr(); - } - - - /** - * Swaps specified rows in the matrix. This is done in place. - * - * @param rowIndex1 Index of the first row to swap. - * @param rowIndex2 Index of the second row to swap. - * - * @return A reference to this matrix. - * - * @throws ArrayIndexOutOfBoundsException If either index is outside the matrix bounds. - */ - @Override - public T swapRows(int rowIndex1, int rowIndex2) { - CsrOps.swapRows(data, rowPointers, colIndices, rowIndex1, rowIndex2); - return (T) this; - } - - - /** - * Swaps specified columns in the matrix. This is done in place. - * - * @param colIndex1 Index of the first column to swap. - * @param colIndex2 Index of the second column to swap. - * - * @return A reference to this matrix. - * - * @throws ArrayIndexOutOfBoundsException If either index is outside the matrix bounds. - */ - @Override - public T swapCols(int colIndex1, int colIndex2) { - CsrOps.swapCols(data, rowPointers, colIndices, colIndex1, colIndex2); - return (T) this; - } - - - /** - * Checks if a matrix is symmetric. That is, if the matrix is square and equal to its transpose. - * - * @return {@code true} if this matrix is symmetric; {@code false} otherwise. - */ - @Override - public boolean isSymmetric() { - return CsrProperties.isSymmetric(shape, data, rowPointers, colIndices); - } - - - /** - * Checks if a matrix is Hermitian. That is, if the matrix is square and equal to its conjugate transpose. - * - * @return {@code true} if this matrix is Hermitian; {@code false} otherwise. - */ - @Override - public boolean isHermitian() { - // For a arrays matrix, same as isSymmetric. - return isSymmetric(); - } - - - /** - * Checks if this matrix is orthogonal. That is, if the inverse of this matrix is equal to its transpose. - * - * @return {@code true} if this matrix it is orthogonal; {@code false} otherwise. - */ - @Override - public boolean isOrthogonal() { - if(isSquare()) return mult(T()).isI(); - else return false; - } - - - /** - * Removes a specified row from this matrix. - * - * @param rowIndex Index of the row to remove from this matrix. - * - * @return A copy of this matrix with the specified row removed. - */ - @Override - public T removeRow(int rowIndex) { - return (T) toCoo().removeRow(rowIndex).toCsr(); - } - - - /** - * Removes a specified set of rows from this matrix. - * - * @param rowIndices The indices of the rows to remove from this matrix. Assumed to contain unique values. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeRows(int... rowIndices) { - return (T) toCoo().removeRows(rowIndices).toCsr(); - } - - - /** - * Removes a specified column from this matrix. - * - * @param colIndex Index of the column to remove from this matrix. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeCol(int colIndex) { - return (T) toCoo().removeCol(colIndex).toCsr(); - } - - - /** - * Removes a specified set of columns from this matrix. - * - * @param colIndices Indices of the columns to remove from this matrix. Assumed to contain unique values. - * - * @return A copy of this matrix with the specified column removed. - */ - @Override - public T removeCols(int... colIndices) { - return (T) toCoo().removeCols(colIndices).toCsr(); - } - - - /** - * Creates a copy of this matrix and sets a slice of the copy to the specified values. The rowStart and colStart parameters specify the upper - * left index location of the slice to set. - * - * @param values New values for the specified slice. - * @param rowStart Starting row index for the slice (inclusive). - * @param colStart Starting column index for the slice (inclusive). - * - * @return A copy of this matrix with the given slice set to the specified values. - * - * @throws IndexOutOfBoundsException If rowStart or colStart are not within the matrix. - * @throws IllegalArgumentException If the values slice, with upper left corner at the specified location, does not - * fit completely within this matrix. - */ - @Override - public T setSliceCopy(T values, int rowStart, int colStart) { - return (T) toCoo().setSliceCopy(values.toCoo(), rowStart, colStart).toCsr(); - } - - - /** - * Gets a specified slice of this matrix. - * - * @param rowStart Starting row index of slice (inclusive). - * @param rowEnd Ending row index of slice (exclusive). - * @param colStart Starting column index of slice (inclusive). - * @param colEnd Ending row index of slice (exclusive). - * - * @return The specified slice of this matrix. This is a completely new matrix and NOT a view into the matrix. - * - * @throws ArrayIndexOutOfBoundsException If any of the indices are out of bounds of this matrix. - * @throws IllegalArgumentException If {@code rowEnd} is not greater than {@code rowStart} or if {@code colEnd} is not greater than {@code colStart}. - */ - @Override - public T getSlice(int rowStart, int rowEnd, int colStart, int colEnd) { - SparseMatrixData sliceData = CsrOps.getSlice( - data, rowPointers, colIndices, - rowStart, rowEnd, colStart, colEnd); - return makeLikeTensor(sliceData.shape(), (List) sliceData.data(), - sliceData.rowData(), sliceData.colData()); - } - - - /** - * Sets an index of this matrix to the specified value. - * - * @param value Value to set. - * @param row Row index to set. - * @param col Column index to set. - * - * @return A reference to this matrix. - */ - @Override - public T set(W value, int row, int col) { - // Ensure indices are in bounds. - ValidateParameters.validateTensorIndex(shape, row, col); - W[] newEntries; - int[] newRowPointers = rowPointers.clone(); - int[] newColIndices; - boolean found = false; // Flag indicating an element already exists in this matrix at the specified row and col. - int loc = -1; - - if(rowPointers[row] < rowPointers[row+1]) { - int start = rowPointers[row]; - int stop = rowPointers[row+1]; - - loc = Arrays.binarySearch(colIndices, start, stop, col); - found = loc >= 0; - } - - if(found) { - newEntries = data.clone(); - newEntries[loc] = value; - newRowPointers = rowPointers.clone(); - newColIndices = colIndices.clone(); - } else { - loc = -loc - 1; // Compute insertion index as specified by Arrays.binarySearch. - newEntries = (W[]) new Field[data.length + 1]; - newColIndices = new int[data.length + 1]; - - CsrOps.insertNewValue( - data, rowPointers, colIndices, - newEntries, newRowPointers, newColIndices, - row, col, loc, value); - } - - return makeLikeTensor(shape, newEntries, newRowPointers, newColIndices); - } - - - /** - * Extracts the upper-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for upper-triangular portion to extract: - *

      - *
    • If zero, then all data at and above the principle diagonal of this matrix are extracted.
    • - *
    • If positive, then all data at and above the equivalent super-diagonal are extracted.
    • - *
    • If negative, then all data at and above the equivalent sub-diagonal are extracted.
    • - *
    - * - * @return The upper-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriU(int diagOffset) { - return (T) toCoo().getTriU(diagOffset).toCsr(); - } - - - /** - * Extracts the lower-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for lower-triangular portion to extract: - *
      - *
    • If zero, then all data at and above the principle diagonal of this matrix are extracted.
    • - *
    • If positive, then all data at and above the equivalent super-diagonal are extracted.
    • - *
    • If negative, then all data at and above the equivalent sub-diagonal are extracted.
    • - *
    - * - * @return The lower-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriL(int diagOffset) { - return (T) toCoo().getTriL(diagOffset).toCsr(); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data.clone()); + super(shape, entries, rowPointers, colIndices); } @@ -1098,68 +170,37 @@ public T H(int... axes) { /** - * Sorts the indices of this tensor in lexicographical order while maintaining the associated value for each index. - */ - public void sortIndices() { - sortCsrMatrix(data, rowPointers, colIndices); - } - - - /** - *

    Converts this sparse CSR matrix to an equivalent dense matrix. + * Checks if two sparse CSR ring matrices are element-wise equal within the following tolerance for two entries {@code x} + * and {@code y}: + *

    {@code
    +     *  |x-y| <= (1e-08 + 1e-05*|y|)
    +     * }
    * - *

    The zero data of this CSR matrix will be attempted to be filled with a zero value if it could be determined during - * construction of this sparse CSR matrix. If the zero value could not be determined the zero data will be filled with - * {@code null} (this only happens when {@code nnz==0}). To avoid this, the zero element of the arrays for this - * matrix can be set explicitly using {@link #setZeroElement(Ring)}. + * To specify the relative and absolute tolerances use {@link #allClose(AbstractCsrRingMatrix, double, double)} * - * @return A dense matrix which is equivalent to this sparse CSR matrix. - */ - public U toDense() { - W[] dest = (W[]) new Ring[shape.totalEntriesIntValueExact()]; - CsrConversions.toDense(shape, data, rowPointers, colIndices, dest, zeroElement); - return makeLikeDenseTensor(shape, dest); - } - - - /** - * Converts this sparse CSR matrix to an equivalent sparse COO matrix. - * @return A sparse COO matrix equivalent to this sparse CSR matrix. - */ - public AbstractCooRingMatrix toCoo() { - W[] cooEntries = (W[]) new Ring[nnz]; - int[] cooRowIndices = new int[nnz]; - int[] cooColIndices = new int[nnz]; - CsrConversions.toCoo(shape, data, rowPointers, colIndices, cooEntries, cooRowIndices, cooColIndices); - return makeLikeCooMatrix(shape, cooEntries, cooRowIndices, cooColIndices); - } - - - /** - * Converts this CSR matrix to an equivalent sparse COO tensor. - * @return An sparse COO tensor equivalent to this CSR matrix. - */ - public AbstractTensor toTensor() { - return toCoo().toTensor(); - } - - - /** - * Converts this CSR matrix to an equivalent COO tensor with the specified shape. - * @param newShape New shape for the COO tensor. Can be any rank but must be broadcastable to {@link #shape this.shape}. - * @return A COO tensor equivalent to this CSR matrix which has been reshaped to {@code newShape} + * @return {@code true} if this matrix and {@code b} element-wise equal within the tolerance {@code |x-y| <= (1e-08 + 1e-05*|y|)}. + * @see #allClose(AbstractCsrRingMatrix, double, double) */ - public AbstractTensor toTensor(Shape shape) { - return toCoo().toTensor(); + public boolean allClose(T b) { + return allClose(b, 1e-05, 1e-08); } /** - * Converts this sparse CSR matrix to an equivalent vector. If this matrix is not a row or column vector it will be flattened - * before conversion. - * @return A vector equivalent to this CSR matrix. + * Checks if two matrices are element-wise equal within the tolerance specified by {@code relTol} and {@code absTol}. Two elements + * {@code x} and {@code y} are considered "close" if they satisfy the following: + *

    {@code
    +     *  |x-y| <= (absTol + relTol*|y|)
    +     * }
    + * @param b Matrix to compare to this matrix. + * @param relTol Relative tolerance. + * @param absTol Absolute tolerance. + * @return {@code true} if the {@code src1} matrix is the same shape as the {@code src2} matrix and all data + * are 'close', i.e. elements {@code a} and {@code b} at the same positions in the two matrices respectively + * satisfy {@code |a-b| <= (absTol + relTol*|b|)}. Otherwise, returns {@code false}. + * @see #allClose(AbstractCsrRingMatrix) */ - public V toVector() { - return (V) toCoo().toVector(); + public boolean allClose(T b, double relRol, double absTol) { + return CsrRingProperties.allClose(this, b, relRol, absTol); } } diff --git a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingMatrix.java b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingMatrix.java index 238d166fa..7bc56ab53 100644 --- a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,16 +27,12 @@ import org.flag4j.algebraic_structures.Ring; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.SparseMatrixData; import org.flag4j.arrays.backend.MatrixMixin; -import org.flag4j.linalg.ops.TransposeDispatcher; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringConversions; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringMatMultDispatcher; -import org.flag4j.util.ArrayUtils; -import org.flag4j.util.ValidateParameters; -import org.flag4j.util.exceptions.LinearAlgebraException; - -import java.util.Arrays; +import org.flag4j.arrays.backend.semiring_arrays.AbstractDenseSemiringMatrix; +import org.flag4j.linalg.MatrixNorms; +import org.flag4j.linalg.ops.common.ring_ops.RingProperties; +import org.flag4j.linalg.ops.dense.ring_ops.DenseRingTensorOps; +import org.flag4j.util.exceptions.TensorShapeException; /** * The base class for all dense matrices whose elements are members of a {@link Ring}. @@ -46,16 +42,8 @@ */ public abstract class AbstractDenseRingMatrix, U extends AbstractDenseRingVector, V extends Ring> - extends AbstractDenseRingTensor implements MatrixMixin { - - /** - * The number of rows in this matrix. - */ - public final int numRows; - /** - * The number of columns in this matrix. - */ - public final int numCols; + extends AbstractDenseSemiringMatrix + implements RingTensorMixin, MatrixMixin { /** @@ -67,413 +55,71 @@ public abstract class AbstractDenseRingMatrix makeLikeCooMatrix( - Shape shape, V[] data, int[] rowIndices, int[] colIndices); - - - /** - * Constructs a sparse CSR matrix which is of a similar type as this dense matrix. - * @param shape Shape of the CSR matrix. - * @param data Non-zero data of the CSR matrix. - * @param rowPointers Non-zero row pointers of the CSR matrix. - * @param colIndices Non-zero column indices of the CSR matrix. - * @return A sparse CSR matrix which is of a similar type as this dense matrix. - */ - protected abstract AbstractCsrRingMatrix makeLikeCsrMatrix( - Shape shape, V[] data, int[] rowPointers, int[] colIndices); - - - /** - * Gets the length of the data array which backs this matrix. - * - * @return The length of the data array which backs this matrix. - */ - @Override - public int dataLength() { - return data.length; - } - - - /** - * Computes the transpose of a tensor by exchanging the first and last axes of this tensor. - * - * @return The transpose of this tensor. - * - * @see #T(int, int) - * @see #T(int...) - */ - @Override - public T T() { - V[] dest = (V[]) new Ring[data.length]; - TransposeDispatcher.dispatch(data, shape, dest); - return makeLikeTensor(shape.swapAxes(0, 1), dest); - } - - - /** - * Gets the number of rows in this matrix. - * - * @return The number of rows in this matrix. - */ - @Override - public int numRows() { - return numRows; - } - - - /** - * Gets the number of columns in this matrix. - * - * @return The number of columns in this matrix. - */ - @Override - public int numCols() { - return numCols; - } - - - /** - * Gets the element of this matrix at this specified {@code row} and {@code col}. - * - * @param row Row index of the item to get from this matrix. - * @param col Column index of the item to get from this matrix. - * - * @return The element of this matrix at the specified index. - */ - @Override - public V get(int row, int col) { - return data[row*numCols + col]; - } - - - /** - *

    Computes the trace of this matrix. That is, the sum of elements along the principle diagonal of this matrix. - * - *

    Same as {@link #trace()}. - * - * @return The trace of this matrix. - * - * @throws IllegalArgumentException If this matrix is not square. - */ - @Override - public V tr() { - ValidateParameters.ensureSquareMatrix(shape); - V sum = data[0]; - int colsOffset = this.numCols + 1; - - for(int i=1; imay, be noticeably faster than directly computing the transpose followed by the - * multiplication as {@code this.mult(b.T())}. + * Computes the conjugate transpose of a tensor by conjugating and exchanging {@code axis1} and {@code axis2}. * - * @param b The second matrix in the multiplication and the matrix to transpose. + * @param axis1 First axis to exchange and conjugate. + * @param axis2 Second axis to exchange and conjugate. * - * @return The result of multiplying this matrix with the transpose of {@code b}. - */ - @Override - public T multTranspose(T b) { - V[] dest = makeEmptyDataArray(numRows*b.numRows); - DenseSemiringMatMultDispatcher.dispatchTranspose(data, shape, b.data, b.shape, dest); - return makeLikeTensor(new Shape(numRows, b.numRows), dest); - } - - - /** - * Stacks matrices along columns.
    + * @return The conjugate transpose of this tensor according to the specified axes. * - * @param b Matrix to stack to this matrix. - * - * @return The result of stacking this matrix on top of the matrix {@code b}. - * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of columns. - * @see #stack(MatrixMixin, int) - * @see #augment(T) + * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. + * @see #H() + * @see #H(int...) */ @Override - public T stack(T b) { - ValidateParameters.ensureArrayLengthsEq(this.numCols, b.numCols); - Shape stackedShape = new Shape(this.numRows + b.numRows, this.numCols); - V[] stackedEntries = makeEmptyDataArray(stackedShape.totalEntries().intValueExact()); - - System.arraycopy(this.data, 0, stackedEntries, 0, this.data.length); - System.arraycopy(b.data, 0, stackedEntries, this.data.length, b.data.length); - - return makeLikeTensor(stackedShape, stackedEntries); + public T H(int axis1, int axis2) { + return T(); } /** - * Stacks matrices along rows. + * Computes the conjugate transpose of this tensor. That is, conjugates and permutes the axes of this tensor so that it matches + * the permutation specified by {@code axes}. * - * @param b Matrix to stack to this matrix. + * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length + * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. * - * @return The result of stacking {@code b} to the right of this matrix. + * @return The conjugate transpose of this tensor with its axes permuted by the {@code axes} array. * - * @throws IllegalArgumentException If this matrix and matrix {@code b} have a different number of rows. - * @see #stack(T) - * @see #stack(MatrixMixin, int) + * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. + * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. + * @see #H(int, int) + * @see #H() */ @Override - public T augment(T b) { - ValidateParameters.ensureArrayLengthsEq(numRows, b.numRows); - - int augNumCols = numCols + b.numCols; - Shape augShape = new Shape(numRows, augNumCols); - V[] augEntries = makeEmptyDataArray(numRows*augNumCols); - - // Copy data from this matrix. - for(int i=0; iNOT a view into the matrix. - * - * @throws ArrayIndexOutOfBoundsException If any of the indices are out of bounds of this matrix. - * @throws IllegalArgumentException If {@code rowEnd} is not greater than {@code rowStart} or if {@code colEnd} is not greater than {@code colStart}. - */ - @Override - public T getSlice(int rowStart, int rowEnd, int colStart, int colEnd) { - ValidateParameters.ensureValidArrayIndices(numRows, rowStart, rowEnd); - ValidateParameters.ensureValidArrayIndices(numCols, colStart, colEnd); - - int sliceRows = rowEnd-rowStart; - int sliceCols = colEnd-colStart; - int destPos = 0; - V[] slice = makeEmptyDataArray(sliceRows*sliceCols); - - for(int i=rowStart; i - *

  • If zero, then all data at and above the principle diagonal of this matrix are extracted.
  • - *
  • If positive, then all data at and above the equivalent super-diagonal are extracted.
  • - *
  • If negative, then all data at and above the equivalent sub-diagonal are extracted.
  • - * - * - * @return The upper-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriU(int diagOffset) { - ValidateParameters.ensureInRange(diagOffset, -numRows+1, numCols-1, "diagOffset"); - V[] copyEntries = makeEmptyDataArray(data.length); - Arrays.fill(copyEntries, (data.length > 0) ? data[0].getZero() : null); - T result = makeLikeTensor(shape, copyEntries); - - // Extract the upper triangular portion - for(int i=0; i= i + diagOffset) - result.data[rowOffset + j] = data[rowOffset + j]; - } - } - - return result; - } - - - /** - * Extracts the lower-triangular portion of this matrix with a specified diagonal offset. All other data of the resulting - * matrix will be zero. - * - * @param diagOffset Diagonal offset for lower-triangular portion to extract: - *
      - *
    • If zero, then all data at and above the principle diagonal of this matrix are extracted.
    • - *
    • If positive, then all data at and above the equivalent super-diagonal are extracted.
    • - *
    • If negative, then all data at and above the equivalent sub-diagonal are extracted.
    • - *
    - * - * @return The lower-triangular portion of this matrix with a specified diagonal offset. All other data of the returned - * matrix will be zero. - * - * @throws IllegalArgumentException If {@code diagOffset} is not in the range (-numRows, numCols). - */ - @Override - public T getTriL(int diagOffset) { - ValidateParameters.ensureInRange(diagOffset, -numRows+1, numCols-1, "diagOffset"); - V[] copyEntries = makeEmptyDataArray(data.length); - Arrays.fill(copyEntries, (data.length > 0) ? data[0].getZero() : null); - T result = makeLikeTensor(shape, copyEntries); - - // Extract the lower triangular portion - for(int i=0; i - *
  • If {@code diagOffset == 0}: Then the elements of the principle diagonal are collected.
  • - *
  • If {@code diagOffset < 0}: Then the elements of the sub-diagonal {@code diagOffset} below the principle diagonal - * are collected.
  • - *
  • If {@code diagOffset > 0}: Then the elements of the super-diagonal {@code diagOffset} above the principle diagonal - * are collected.
  • - * - * - * @return The elements of the specified diagonal as a vector. - */ - @Override - public U getDiag(int diagOffset) { - ValidateParameters.ensureInRange(diagOffset, -(numRows-1), numCols-1, "diagOffset"); - - // Check for some quick returns. - if(numRows == 1 && diagOffset > 0) return (U) makeLikeVector(shape, (V[]) new Ring[]{data[diagOffset]}); - if(numCols == 1 && diagOffset < 0) return (U) makeLikeVector(shape, (V[]) new Ring[]{data[-diagOffset]}); - - // Compute the length of the diagonal. - int newSize = Math.min(numRows, numCols); - int idx = 0; - - if(diagOffset > 0) { - newSize = Math.min(newSize, numCols - diagOffset); - idx = diagOffset; - } - else if(diagOffset < 0) { - newSize = Math.min(newSize, numRows + diagOffset); - idx = -diagOffset*numCols; - } - - V[] diag = makeEmptyDataArray(newSize); - - for(int i=0; i{@code + * |x-y| <= (1E-08 + 1E-05*|y|) + * } * - * @return If this matrix is dense, the row set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new row and returned. + * @return {@code true} if the matrix is approximately an identity matrix, otherwise {@code false}. */ - @Override - public T setRow(U row, int rowIdx) { - return setRow((V[]) row.data, rowIdx); + public boolean isCloseToIdentity() { + return DenseRingTensorOps.isCloseToIdentity(shape, data); } /** - * Sets a specified row of this matrix to an array. + * Checks if two sparse CSR ring matrices are element-wise equal within the following tolerance for two entries {@code x} + * and {@code y}: + *
    {@code
    +     *  |x-y| <= (1e-08 + 1e-05*|y|)
    +     * }
    * - * @param row Array containing values to replace specified row in this matrix. - * @param rowIdx Index of the row to set. + * To specify the relative and absolute tolerances use {@link #allClose(AbstractDenseRingMatrix, double, double)} * - * @return If this matrix is dense, the row set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new row and returned. + * @return {@code true} if this matrix and {@code b} element-wise equal within the tolerance {@code |x-y| <= (1e-08 + 1e-05*|y|)}. + * @see #allClose(AbstractDenseRingMatrix, double, double) */ - public T setRow(V[] row, int rowIdx) { - ValidateParameters.ensureArrayLengthsEq(row.length, this.numCols); - - for(int i=0, size=row.length, rowOffset=rowIdx*numCols; i{@code + * |x-y| <= (absTol + relTol*|y|) + * } + * @param b Matrix to compare to this matrix. + * @param relTol Relative tolerance. + * @param absTol Absolute tolerance. + * @return {@code true} if the {@code src1} matrix is the same shape as the {@code src2} matrix and all data + * are 'close', i.e. elements {@code a} and {@code b} at the same positions in the two matrices respectively + * satisfy {@code |a-b| <= (absTol + relTol*|b|)}. Otherwise, returns {@code false}. + * @see #allClose(AbstractDenseRingMatrix) */ - @Override - public T setCol(U col, int colIdx) { - return setRow((V[]) col.data, colIdx); + public boolean allClose(T b, double relRol, double absTol) { + return RingProperties.allClose(data, b.data, relRol, absTol); } /** - * Sets a specified column of this matrix to an array. + * Computes the p-norm of this vector. * - * @param col Vector to replace specified column in this matrix. - * @param colIdx Index of the column to set. + * @param p {@code p} value in the p-norm. * - * @return If this matrix is dense, the column set operation is done in place and a reference to this matrix is returned. - * If this matrix is sparse a copy will be created with the new column and returned. + * @return The Euclidean norm of this vector. */ - public T setCol(V[] col, int colIdx) { - ValidateParameters.ensureArrayLengthsEq(col.length, this.numRows); - - int rowOffset = 0; - for(int i=0, size=col.length; i toCoo() { - return toCoo(0.01); - } - - - /** - * Converts this matrix to an equivalent sparse COO matrix. - * @param estimatedSparsity Estimated sparsity of the matrix. Must be between 0 and 1 inclusive. If this is an accurate estimation - * it may provide a slight speedup and can reduce unneeded memory consumption. If memory is a concern, it is better to - * over-estimate the sparsity. If speed is the concern it is better to under-estimate the sparsity. - * @return A sparse COO matrix that is equivalent to this dense matrix. - * @see #toCoo() - */ - public AbstractCooRingMatrix toCoo(double estimatedSparsity) { - SparseMatrixData data = DenseSemiringConversions.toCoo(shape, this.data, 0.1); - V[] cooEntries = (V[]) data.data().toArray(new Ring[data.data().size()]); - int[] rowIndices = ArrayUtils.fromIntegerList(data.rowData()); - int[] colIndices = ArrayUtils.fromIntegerList(data.colData()); - - return makeLikeCooMatrix(data.shape(), cooEntries, rowIndices, colIndices); - } - - - /** - * Converts this matrix to an equivalent sparse CSR matrix. - * @return A sparse CSR matrix that is equivalent to this dense matrix. - * @see #toCsr(double) - */ - public AbstractCsrRingMatrix toCsr() { - return toCoo(0.01).toCsr(); - } - - - /** - * Converts this matrix to an equivalent sparse CSR matrix. - * @param estimatedSparsity Estimated sparsity of the matrix. Must be between 0 and 1 inclusive. If this is an accurate estimation - * it may provide a slight speedup and can reduce unneeded memory consumption. If memory is a concern, it is better to - * over-estimate the sparsity. If speed is the concern it is better to under-estimate the sparsity. - * @return A sparse CSR matrix that is equivalent to this dense matrix. - * @see #toCsr() - */ - public AbstractCsrRingMatrix toCsr(double estimatedSparsity) { - return toCoo(estimatedSparsity).toCsr(); - } - - - /** - * Converts this matrix to an equivalent vector. If this matrix is not a row or column vector it will first be flattened then - * converted to a vector. - * - * @return A vector which contains the same data as this matrix. - */ - @Override - public U toVector() { - return makeLikeVector(new Shape(numRows*numCols), data.clone()); - } - - - /** - * Converts this matrix to an equivalent tensor. - * @return A tensor with the same shape and data as this matrix. - */ - public abstract AbstractDenseRingTensor toTensor(); - - - /** - * Converts this matrix to an equivalent tensor with the specified {@code newShape}. - * @param newShape Shape of the tensor. Can be any rank but must be broadcastable to the shape of this matrix. - * @return A tensor with the specified {@code newShape} and the same data as this matrix. - */ - public abstract AbstractDenseRingTensor toTensor(Shape newShape); } diff --git a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingTensor.java b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingTensor.java index 2b4533160..c9afdc7a4 100644 --- a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingTensor.java +++ b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingTensor.java @@ -26,20 +26,12 @@ import org.flag4j.algebraic_structures.Ring; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.SparseTensorData; -import org.flag4j.arrays.backend.AbstractTensor; +import org.flag4j.arrays.backend.semiring_arrays.AbstractDenseSemiringTensor; import org.flag4j.linalg.ops.TransposeDispatcher; import org.flag4j.linalg.ops.common.ring_ops.CompareRing; -import org.flag4j.linalg.ops.dense.DenseSemiringTensorDot; import org.flag4j.linalg.ops.dense.ring_ops.DenseRingTensorOps; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringConversions; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringElemMult; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringOps; -import org.flag4j.util.ValidateParameters; import org.flag4j.util.exceptions.TensorShapeException; -import java.util.Arrays; - /** *

    The base class for all dense {@link Ring} tensors. *

    The {@link #data} of an AbstractDenseRingTensor are mutable but the {@link #shape} is fixed. @@ -50,14 +42,9 @@ * @param The type of the {@link Ring} which this tensor's data belong to. */ public abstract class AbstractDenseRingTensor, V extends Ring> - extends AbstractTensor + extends AbstractDenseSemiringTensor implements RingTensorMixin { - /** - * The zero element for the arrays that this tensor's elements belong to. - */ - protected V zeroElement; - /** * Creates a tensor with the specified data and shape. * @@ -67,259 +54,6 @@ public abstract class AbstractDenseRingTensor 0 && data[0] != null) ? data[0].getZero() : null; - } - - - /** - * Constructs a sparse COO tensor which is of a similar type as this dense tensor. - * @param shape Shape of the COO tensor. - * @param entries Non-zero data of the COO tensor. - * @param rowIndices Non-zero row indices of the COO tensor. - * @param colIndices Non-zero column indices of the COO tensor. - * @return A sparse COO tensor which is of a similar type as this dense tensor. - */ - protected abstract AbstractTensor makeLikeCooTensor( - Shape shape, V[] entries, int[][] indices); - - - /** - * Gets the shape of this tensor. - * - * @return The shape of this tensor. - */ - @Override - public Shape getShape() { - return super.getShape(); - } - - - /** - * Gets the element of this tensor at the specified indices. - * - * @param indices Indices of the element to get. - * - * @return The element of this tensor at the specified indices. - * - * @throws ArrayIndexOutOfBoundsException If any indices are not within this tensor. - */ - @Override - public V get(int... indices) { - return data[shape.getFlatIndex(indices)]; - } - - - /** - * Sets the element of this tensor at the specified indices. - * - * @param value New value to set the specified index of this tensor to. - * @param indices Indices of the element to set. - * - * @return If this tensor is dense, a reference to this tensor is returned. If this tensor is sparse, a copy of this tensor with - * the updated value is returned. - * - * @throws IndexOutOfBoundsException If {@code indices} is not within the bounds of this tensor. - */ - @Override - public T set(V value, int... indices) { - data[shape.getFlatIndex(indices)] = value; - return (T) this; - } - - - /** - * Flattens tensor to single dimension while preserving order of data. - * - * @return The flattened tensor. - * - * @see #flatten(int) - */ - @Override - public T flatten() { - return makeLikeTensor(shape.flatten(), data.clone()); - } - - - /** - * Flattens a tensor along the specified axis. Unlike {@link #flatten()} - * - * @param axis Axis along which to flatten tensor. - * - * @throws ArrayIndexOutOfBoundsException If the axis is not positive or larger than {@code this.{@link #getRank()}-1}. - * @see #flatten() - */ - @Override - public T flatten(int axis) { - ValidateParameters.ensureValidAxes(shape, axis); - int[] dims = new int[this.getRank()]; - Arrays.fill(dims, 1); - dims[axis] = shape.totalEntries().intValueExact(); - Shape flatShape = new Shape(dims); - - return makeLikeTensor(flatShape, data.clone()); - } - - - /** - * Copies and reshapes this tensor. - * - * @param newShape New shape for the tensor. - * - * @return A copy of this tensor with the new shape. - * - * @throws TensorShapeException If {@code newShape} is not broadcastable to {@link #shape this.shape}. - */ - @Override - public T reshape(Shape newShape) { - // No need to make explicit broadcastable check as the constructor will verify that the number of data in the shape matches - // the number of data in the array. - return makeLikeTensor(newShape, data.clone()); - } - - - /** - * Computes the transpose of a tensor by exchanging {@code axis1} and {@code axis2}. - * - * @param axis1 First axis to exchange. - * @param axis2 Second axis to exchange. - * - * @return The transpose of this tensor according to the specified axes. - * - * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #T() - * @see #T(int...) - */ - @Override - public T T(int axis1, int axis2) { - V[] dest = (V[]) new Ring[data.length]; - TransposeDispatcher.dispatchTensor(data, shape, axis1, axis2, dest); - return makeLikeTensor(shape, dest); - } - - - /** - * Computes the transpose of this tensor. That is, permutes the axes of this tensor so that it matches - * the permutation specified by {@code axes}. - * - * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length - * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. - * - * @return The transpose of this tensor with its axes permuted by the {@code axes} array. - * - * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. - * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. - * @see #T(int, int) - * @see #T() - */ - @Override - public T T(int... axes) { - V[] dest = (V[]) new Ring[data.length]; - TransposeDispatcher.dispatchTensor(data, shape, axes, dest); - return makeLikeTensor(shape.permuteAxes(axes), dest); - } - - - /** - * Creates a deep copy of this tensor. - * - * @return A deep copy of this tensor. - */ - @Override - public T copy() { - return makeLikeTensor(shape, data.clone()); - } - - - /** - * Computes the element-wise sum between two tensors of the same shape. - * - * @param b Second tensor in the element-wise sum. - * - * @return The sum of this tensor with {@code b}. - * - * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T add(T b) { - V[] sum = (V[]) new Ring[data.length]; - DenseSemiringOps.add(data, shape, b.data, b.shape, sum); - return makeLikeTensor(shape, sum); - } - - - /** - * Computes the element-wise sum between two tensors of the same shape and stores the result in this tensor. - * - * @param b Second tensor in the element-wise sum. - */ - public void addEq(T b) { - DenseSemiringOps.add(data, shape, b.data, b.shape, data); - } - - - /** - * Computes the element-wise multiplication of two tensors of the same shape. - * - * @param b Second tensor in the element-wise product. - * - * @return The element-wise product between this tensor and {@code b}. - * - * @throws IllegalArgumentException If this tensor and {@code b} do not have the same shape. - */ - @Override - public T elemMult(T b) { - V[] prod = (V[]) new Ring[data.length]; - DenseSemiringElemMult.dispatch(data, shape, b.data, b.shape, prod); - return makeLikeTensor(shape, prod); - } - - - /** - * Computes the tensor contraction of this tensor with a specified tensor over the specified set of axes. That is, - * computes the sum of products between the two tensors along the specified set of axes. - * - * @param src2 Tensor to contract with this tensor. - * @param aAxes Axes along which to compute products for this tensor. - * @param bAxes Axes along which to compute products for {@code src2} tensor. - * - * @return The tensor dot product over the specified axes. - * - * @throws IllegalArgumentException If the two tensors shapes do not match along the specified axes pairwise in - * {@code aAxes} and {@code bAxes}. - * @throws IllegalArgumentException If {@code aAxes} and {@code bAxes} do not match in length, or if any of the axes - * are out of bounds for the corresponding tensor. - */ - @Override - public T tensorDot(T src2, int[] aAxes, int[] bAxes) { - DenseSemiringTensorDot dot = new DenseSemiringTensorDot(shape, data, src2.shape, src2.data, aAxes, bAxes); - V[] dest = (V[]) new Ring[dot.getOutputSize()]; - dot.compute(dest); - return makeLikeTensor(dot.getOutputShape(), dest); - } - - - /** - *

    Computes the generalized trace of this tensor along the specified axes. - * - *

    The generalized tensor trace is the sum along the diagonal values of the 2D sub-arrays of this tensor specified by - * {@code axis1} and {@code axis2}. The shape of the resulting tensor is equal to this tensor with the - * {@code axis1} and {@code axis2} removed. - * - * @param axis1 First axis for 2D sub-array. - * @param axis2 Second axis for 2D sub-array. - * - * @return The generalized trace of this tensor along {@code axis1} and {@code axis2}. - * - * @throws IndexOutOfBoundsException If the two axes are not both larger than zero and less than this tensors rank. - * @throws IllegalArgumentException If {@code axis1 == axis2} or {@code this.shape.get(axis1) != this.shape.get(axis1)} - * (i.e. the axes are equal or the tensor does not have the same length along the two axes.) - */ - @Override - public T tensorTr(int axis1, int axis2) { - Shape destShape = DenseSemiringOps.getTrShape(shape, axis1, axis2); - V[] destEntries = (V[]) new Ring[destShape.totalEntriesIntValueExact()]; - return makeLikeTensor(destShape, destEntries); } @@ -334,7 +68,7 @@ public T tensorTr(int axis1, int axis2) { */ @Override public T sub(T b) { - V[] diff = (V[]) new Ring[data.length]; + V[] diff = makeEmptyDataArray(data.length); DenseRingTensorOps.sub(shape, data, b.shape, b.data, diff); return makeLikeTensor(shape, diff); } @@ -366,9 +100,9 @@ public void subEq(T b) { */ @Override public T H(int axis1, int axis2) { - V[] dest = (V[]) new Ring[data.length]; + V[] dest = makeEmptyDataArray(data.length); TransposeDispatcher.dispatchTensorHermitian(shape, data, axis1, axis2, dest); - return makeLikeTensor(shape, dest); + return makeLikeTensor(shape.swapAxes(axis1, axis2), dest); } @@ -388,7 +122,7 @@ public T H(int axis1, int axis2) { */ @Override public T H(int... axes) { - V[] dest = (V[]) new Ring[data.length]; + V[] dest = makeEmptyDataArray(data.length); TransposeDispatcher.dispatchTensorHermitian(shape, data, axes, dest); return makeLikeTensor(shape, dest); } @@ -438,29 +172,4 @@ public double minAbs() { public double maxAbs() { return CompareRing.maxAbs(data); } - - - /** - * Converts this tensor to an equivalent sparse COO tensor. - * @return A sparse COO tensor that is equivalent to this dense tensor. - * @see #toCoo(double) - */ - public AbstractTensor toCoo() { - return toCoo(0.9); - } - - - /** - * Converts this tensor to an equivalent sparse COO tensor. - * @param estimatedSparsity Estimated sparsity of the tensor. Must be between 0 and 1 inclusive. If this is an accurate estimation - * it may provide a slight speedup and can reduce unneeded memory consumption. If memory is a concern, it is better to - * over-estimate the sparsity. If speed is the concern it is better to under-estimate the sparsity. - * @return A sparse COO tensor that is equivalent to this dense tensor. - * @see #toCoo(double) - */ - public AbstractTensor toCoo(double estimatedSparsity) { - SparseTensorData data = DenseSemiringConversions.toCooTensor(shape, this.data, estimatedSparsity); - V[] cooEntries = (V[]) data.data().toArray(new Ring[data.data().size()]); - return makeLikeCooTensor(data.shape(), cooEntries, data.indicesToArray()); - } } diff --git a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingVector.java b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingVector.java index 9d2d92c3d..0302b0302 100644 --- a/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingVector.java +++ b/src/main/java/org/flag4j/arrays/backend/ring_arrays/AbstractDenseRingVector.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,11 +27,11 @@ import org.flag4j.algebraic_structures.Ring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.backend.VectorMixin; +import org.flag4j.arrays.backend.semiring_arrays.AbstractDenseSemiringVector; import org.flag4j.linalg.VectorNorms; -import org.flag4j.linalg.ops.common.semiring_ops.AggregateSemiring; -import org.flag4j.linalg.ops.dense.DenseConcat; -import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringVectorOps; +import org.flag4j.linalg.ops.dense.ring_ops.DenseRingTensorOps; import org.flag4j.util.ValidateParameters; +import org.flag4j.util.exceptions.TensorShapeException; /** *

    The base class for all dense vectors whose data are {@link Ring} elements. @@ -46,13 +46,8 @@ */ public abstract class AbstractDenseRingVector, U extends AbstractDenseRingMatrix, V extends Ring> - extends AbstractDenseRingTensor - implements VectorMixin { - - /** - * The size of this vector. This is the total number of data stored in this vector. - */ - public final int size; + extends AbstractDenseSemiringVector + implements RingTensorMixin, VectorMixin { /** @@ -64,233 +59,107 @@ public abstract class AbstractDenseRingVectorComputes the inner product between two vectors. - * - * @param b Second vector in the inner product. - * - * @return The inner product between this vector and the vector {@code b}. - * - * @throws IllegalArgumentException If this vector and vector {@code b} do not have the same number of data. - * @see #dot(AbstractDenseRingVector) - */ - @Override - public V inner(T b) { - return dot(b); // For a semiarrays, simply delegate to dot product since semirings do not define conjugates. } - /** - *

    Computes the dot product between two vectors. - * - *

    Note: this method is distinct from {@link #inner(AbstractDenseRingVector)}. The inner product is equivalent to the dot product - * of this tensor with the conjugation of {@code b}. - * - * @param b Second vector in the dot product. - * - * @return The dot product between this vector and the vector {@code b}. + * Normalizes this vector to a unit length vector. * - * @throws IllegalArgumentException If this vector and vector {@code b} do not have the same number of data. - * @see #inner(AbstractDenseRingVector) + * @return This vector normalized to a unit length. */ @Override - public V dot(T b) { - return DenseSemiringVectorOps.dotProduct(data, b.data); + public T normalize() { + throw new UnsupportedOperationException("Normalization not supported for arrays vectors."); } /** - * Gets the length of a vector. Same as {@link #size()}. + * Computes the Euclidean norm of this vector. * - * @return The length, i.e. the number of data, in this vector. + * @return The Euclidean norm of this vector. */ - @Override - public int length() { - return size; + public double norm() { + return VectorNorms.norm(data); } /** - * Repeats a vector {@code n} times along a certain axis to create a matrix. + * Computes the p-norm of this vector. * - * @param n Number of times to repeat vector. Must be positive. - * @param axis Axis along which to repeat vector. Must be either 1 or 0. - *

      - *
    • If {@code axis=0}, then the vector will be treated as a row vector and stacked vertically {@code n} times.
    • - *
    • If {@code axis=1} then the vector will be treated as a column vector and stacked horizontally {@code n} times.
    • - *
    + * @param p {@code p} value in the p-norm. * - * @return A matrix whose rows/columns are this vector repeated. + * @return The Euclidean norm of this vector. */ - @Override - public U repeat(int n, int axis) { - V[] dest = (V[]) new Ring[size*n]; - DenseConcat.repeat(data, n, axis, dest); // n is verified to be 1 or 0 here. - Shape shape = (n==0) ? new Shape(n, size) : new Shape(size, n); - return makeLikeMatrix(shape, dest); + public double norm(double p) { + return VectorNorms.norm(data, p); } /** - *

    Stacks two vectors along specified axis. - * - *

    Stacking two vectors of length {@code n} along axis 0 stacks the vectors - * as if they were row vectors resulting in a {@code 2-by-n} matrix. + * Computes the element-wise difference between two tensors of the same shape. * - *

    Stacking two vectors of length {@code n} along axis 1 stacks the vectors - * as if they were column vectors resulting in a {@code n-by-2} matrix. + * @param b Second tensor in the element-wise difference. * - * @param b Vector to stack with this vector. - * @param axis Axis along which to stack vectors. If {@code axis=0}, then vectors are stacked as if they are row - * vectors. If {@code axis=1}, then vectors are stacked as if they are column vectors. + * @return The difference of this tensor with {@code b}. * - * @return The result of stacking this vector and the vector {@code b}. - * - * @throws IllegalArgumentException If the number of data in this vector is different from the number of - * data in the vector {@code b}. - * @throws IllegalArgumentException If axis is not either 0 or 1. + * @throws TensorShapeException If this tensor and {@code b} do not have the same shape. */ @Override - public U stack(T b, int axis) { - V[] dest = (V[]) new Ring[2*size]; - DenseConcat.stack(data, b.data, axis, dest); - Shape shape = (axis==0) ? new Shape(2, size) : new Shape(size, 2); - return makeLikeMatrix(shape, dest); + public T sub(T b) { + V[] diff = makeEmptyDataArray(data.length); + DenseRingTensorOps.sub(shape, data, b.shape, b.data, diff); + return makeLikeTensor(shape, diff); } /** - * Computes the outer product of two vectors. - * - * @param b Second vector in the outer product. + * Computes the element-wise difference between two vectors of the same shape and stores the result in this vectors. * - * @return The result of the vector outer product between this vector and {@code b}. + * @param b Second vectors in the element-wise difference. * - * @throws IllegalArgumentException If the two vectors do not have the same number of data. + * @throws TensorShapeException If this vectors and {@code b} do not have the same shape. */ - @Override - public U outer(T b) { - V[] dest = (V[]) new Ring[size*b.size]; - DenseSemiringVectorOps.outerProduct(data, b.data, dest); - return makeLikeMatrix(new Shape(size, size), dest); + public void subEq(T b) { + DenseRingTensorOps.sub(shape, data, b.shape, b.data, data); } /** - * Converts a vector to an equivalent matrix representing either a row or column vector. + * Computes the conjugate transpose of a tensor by conjugating and exchanging {@code axis1} and {@code axis2}. * - * @param columVector Flag indicating whether to convert this vector to a matrix representing a row or column vector: - *

      - *
    • If {@code true}, the vector will be converted to a matrix representing a column vector.
    • - *
    • If {@code false}, The vector will be converted to a matrix representing a row vector.
    • - *
    + * @param axis1 First axis to exchange and conjugate. + * @param axis2 Second axis to exchange and conjugate. * - * @return A matrix equivalent to this vector. - */ - @Override - public U toMatrix(boolean columVector) { - if(columVector) { - // Convert to column vector. - return makeLikeMatrix(new Shape(data.length, 1), data.clone()); - } else { - // Convert to row vector. - return makeLikeMatrix(new Shape(1, data.length), data.clone()); - } - } - - - /** - * Normalizes this vector to a unit length vector. + * @return The conjugate transpose of this tensor according to the specified axes. * - * @return This vector normalized to a unit length. + * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. + * @see #H() + * @see #H(int...) */ @Override - public T normalize() { - throw new UnsupportedOperationException("Normalization not supported for arrays vectors."); + public T H(int axis1, int axis2) { + ValidateParameters.ensureValidAxes(shape, axis1, axis2); + return conj(); } /** - * Computes the magnitude of this vector. + * Computes the conjugate transpose of this tensor. That is, conjugates and permutes the axes of this tensor so that it matches + * the permutation specified by {@code axes}. * - * @return The magnitude of this vector. - */ - @Override - public V mag() { - return AggregateSemiring.sum(data); - } - - - /** - * Gets the element of this vector at the specified index. + * @param axes Permutation of tensor axis. If the tensor has rank {@code N}, then this must be an array of length + * {@code N} which is a permutation of {@code {0, 1, 2, ..., N-1}}. * - * @param idx Index of the element to get within this vector. + * @return The conjugate transpose of this tensor with its axes permuted by the {@code axes} array. * - * @return The element of this vector at index {@code idx}. + * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. + * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. + * @see #H(int, int) + * @see #H() */ @Override - public V get(int idx) { - ValidateParameters.validateTensorIndex(shape, idx); - return data[idx]; - } - - - /** - * Computes the Euclidean norm of this vector. - * - * @return The Euclidean norm of this vector. - */ - public double norm() { - return VectorNorms.norm(data); - } - - - /** - * Computes the p-norm of this vector. - * - * @param p {@code p} value in the p-norm. - * - * @return The Euclidean norm of this vector. - */ - public double norm(int p) { - return VectorNorms.norm(data, p); + public T H(int... axes) { + ValidateParameters.ensureValidAxes(shape, axes); + return conj(); } } diff --git a/src/main/java/org/flag4j/arrays/backend/ring_arrays/RingTensorMixin.java b/src/main/java/org/flag4j/arrays/backend/ring_arrays/RingTensorMixin.java index 0258f62e7..e91ac0aef 100644 --- a/src/main/java/org/flag4j/arrays/backend/ring_arrays/RingTensorMixin.java +++ b/src/main/java/org/flag4j/arrays/backend/ring_arrays/RingTensorMixin.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,7 +25,7 @@ package org.flag4j.arrays.backend.ring_arrays; import org.flag4j.algebraic_structures.Ring; -import org.flag4j.arrays.dense.Tensor; +import org.flag4j.arrays.backend.semiring_arrays.SemiringTensorMixin; import org.flag4j.linalg.VectorNorms; import org.flag4j.linalg.ops.common.ring_ops.CompareRing; import org.flag4j.linalg.ops.common.ring_ops.RingOps; @@ -45,7 +45,7 @@ */ public interface RingTensorMixin, U extends RingTensorMixin, V extends Ring> - extends TensorOverRing { + extends TensorOverRing, SemiringTensorMixin { /** @@ -86,20 +86,6 @@ default void subEq(V b) { } - /** - * Computes the element-wise absolute value of this tensor. - * - * @return The element-wise absolute value of this tensor. - */ - @Override - default Tensor abs() { - V[] data = getData(); - double[] abs = new double[data.length]; - RingOps.abs(data, abs); - return new Tensor(getShape(), abs); - } - - /** * Computes the element-wise conjugation of this tensor. * @@ -319,7 +305,7 @@ default double norm() { * * @return The Euclidean norm of this vector. */ - default double norm(int p) { + default double norm(double p) { return VectorNorms.norm(getData(), p); } } diff --git a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringMatrix.java b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringMatrix.java index d861791df..90903b6d5 100644 --- a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,12 +38,12 @@ import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatMult; import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatrixOps; import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatrixProperties; -import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; import org.flag4j.util.exceptions.LinearAlgebraException; import org.flag4j.util.exceptions.TensorShapeException; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -81,7 +81,7 @@ * * @param Type of this sparse COO matrix. * @param Type of dense matrix which is similar to {@code T}. - * @param Type of sparse COO vector which is similar to {@code T}. + * @param Type of sparse COO vector which is similar to {@code T}. * @param Type of the semiring element in this matrix. */ public abstract class AbstractCooSemiringMatrix, @@ -94,7 +94,7 @@ public abstract class AbstractCooSemiringMatrix 0) ? entries[0].getZero() : null; + this.zeroElement = (entries.length > 0 && entries[0] != null) ? entries[0].getZero() : null; } @@ -321,7 +323,7 @@ public T set(W value, int row, int col) { idx = -idx - 1; // No non-zero element with these indices exists. Insert new value. - destEntries = (W[]) new Semiring[data.length + 1]; + destEntries = makeEmptyDataArray(data.length + 1); destRowIndices = new int[data.length + 1]; destColIndices = new int[data.length + 1]; @@ -388,7 +390,7 @@ public T setCol(V col, int colIndex) { */ @Override public T flatten() { - return flatten(0); + return flatten(1); } @@ -404,7 +406,7 @@ public T flatten() { public T flatten(int axis) { ValidateParameters.ensureValidAxes(shape, axis); int[] dims = {1, 1}; - dims[1-axis] = shape.totalEntriesIntValueExact(); + dims[axis] = shape.totalEntriesIntValueExact(); Shape flatShape = new Shape(dims); int[] destIndices = new int[data.length]; @@ -413,8 +415,8 @@ public T flatten(int axis) { destIndices[i] = shape.getFlatIndex(rowIndices[i], colIndices[i]); return (axis == 0) - ? makeLikeTensor(flatShape, data.clone(), new int[data.length], destIndices) - : makeLikeTensor(flatShape, data.clone(), destIndices, new int[data.length]); + ? makeLikeTensor(flatShape, data.clone(), destIndices, new int[data.length]) + : makeLikeTensor(flatShape, data.clone(), new int[data.length], destIndices); } @@ -624,7 +626,7 @@ public boolean isI() { @Override public U mult(T b) { ValidateParameters.ensureMatMultShapes(shape, b.shape); - W[] dest = (W[]) new Semiring[numRows*b.numCols]; + W[] dest = makeEmptyDataArray(numRows*b.numCols); CooSemiringMatMult.standard( data, rowIndices, colIndices, shape, b.data, b.rowIndices, b.colIndices, b.shape, dest); @@ -647,7 +649,7 @@ public U mult(T b) { @Override public U multTranspose(T b) { ValidateParameters.ensureEquals(numCols, b.numCols); - return mult(b.T()); + return mult(b.H()); } @@ -667,14 +669,14 @@ public T stack(T b) { ValidateParameters.ensureEquals(numCols, b.numCols); Shape destShape = new Shape(numRows+b.numRows, numCols); - W[] destEntries = (W[]) new Semiring[data.length + b.data.length]; + W[] destEntries = makeEmptyDataArray(data.length + b.data.length); int[] destRowIndices = new int[destEntries.length]; int[] destColIndices = new int[destEntries.length]; CooConcat.stack(data, rowIndices, colIndices, numRows, b.data, b.rowIndices, b.colIndices, destEntries, destRowIndices, destColIndices); - return makeLikeTensor(destShape, (W[]) destEntries, destRowIndices, destColIndices); + return makeLikeTensor(destShape, destEntries, destRowIndices, destColIndices); } @@ -694,7 +696,7 @@ public T augment(T b) { ValidateParameters.ensureEquals(numRows, b.numRows); Shape destShape = new Shape(numRows, numCols + b.numCols); - W[] destEntries = (W[]) new Semiring[data.length + b.data.length]; + W[] destEntries = makeEmptyDataArray(data.length + b.data.length); int[] destRowIndices = new int[destEntries.length]; int[] destColIndices = new int[destEntries.length]; CooConcat.augment(data, rowIndices, colIndices, numCols, @@ -717,7 +719,7 @@ public T augment(V b) { ValidateParameters.ensureEquals(numRows, b.size); Shape destShape = new Shape(numRows, numCols + 1); - W[] destEntries = (W[]) new Semiring[nnz + b.data.length]; + W[] destEntries = makeEmptyDataArray(nnz + b.data.length); int[] destRowIndices = new int[destEntries.length]; int[] destColIndices = new int[destEntries.length]; CooConcat.augmentVector( @@ -770,7 +772,7 @@ public T swapCols(int colIndex1, int colIndex2) { */ @Override public boolean isSymmetric() { - return CooSemiringMatrixProperties.isSymmetric(shape, data, rowIndices, colIndices); + return CooProperties.isSymmetric(shape, data, rowIndices, colIndices, zeroElement); } @@ -812,9 +814,10 @@ public boolean isOrthogonal() { @Override public V getRow(int rowIdx, int start, int stop) { SparseVectorData data = CooGetSet.getRow(shape, this.data, rowIndices, colIndices, rowIdx, start, stop); - return makeLikeVector(data.shape(), - (W[]) data.data().toArray(new Semiring[data.data().size()]), - data.indicesToArray()); + W[] dest = makeEmptyDataArray(data.data().size()); + data.data().toArray(dest); + + return makeLikeVector(data.shape(), dest, data.indicesToArray()); } @@ -833,9 +836,9 @@ public V getRow(int rowIdx, int start, int stop) { @Override public V getCol(int colIdx, int start, int stop) { SparseVectorData data = CooGetSet.getCol(shape, this.data, rowIndices, colIndices, colIdx, start, stop); - return makeLikeVector(data.shape(), - (W[]) data.data().toArray(new Semiring[data.data().size()]), - data.indicesToArray()); + W[] dest = makeEmptyDataArray(data.data().size()); + data.data().toArray(dest); + return makeLikeVector(data.shape(), dest, data.indicesToArray()); } @@ -856,8 +859,10 @@ public V getCol(int colIdx, int start, int stop) { @Override public V getDiag(int diagOffset) { SparseVectorData data = CooGetSet.getDiag(shape, this.data, rowIndices, colIndices, diagOffset); + W[] dest = makeEmptyDataArray(data.data().size()); + data.data().toArray(dest); return makeLikeVector(data.shape(), - (W[]) data.data().toArray(new Semiring[data.data().size()]), + dest, data.indicesToArray()); } @@ -871,6 +876,7 @@ public V getDiag(int diagOffset) { */ @Override public T removeRow(int rowIndex) { + ValidateParameters.ensureValidArrayIndices(numRows, rowIndex); Shape shape = new Shape(numRows-1, numCols); // Find the start and end index within the data array which have the given row index. @@ -878,12 +884,22 @@ public T removeRow(int rowIndex) { int size = data.length - (startEnd[1]-startEnd[0]); // Initialize arrays. - W[] entries = (W[]) new Semiring[size]; + W[] entries = makeEmptyDataArray(size); int[] rowIndices = new int[size]; int[] colIndices = new int[size]; - copyRanges(this.data, this.rowIndices, this.colIndices, entries, rowIndices, colIndices, startEnd); + // Shift all row indices occurring after removed row. + if (startEnd[0] > 0) { + for(int i=startEnd[0], length=rowIndices.length; i rowIndex) + rowIndices[i]--; + } + } + return makeLikeTensor(shape, entries, rowIndices, colIndices); } @@ -897,22 +913,41 @@ public T removeRow(int rowIndex) { */ @Override public T removeRows(int... rowIdxs) { - // TODO: This should be doable for a general COO matrix. Return SparseMatrixData object. - Shape shape = new Shape(numRows-rowIdxs.length, numCols); + ValidateParameters.ensureValidArrayIndices(numRows, rowIdxs); + // Ensure the indices are sorted. + Arrays.sort(rowIdxs); + + Shape shape = new Shape(numRows - rowIdxs.length, numCols); List entries = new ArrayList<>(nnz); - List rowIndices = new ArrayList<>(nnz); - List colIndices = new ArrayList<>(nnz); - - for(int i=0; i newRowIndices = new ArrayList<>(nnz); + List newColIndices = new ArrayList<>(nnz); + + int j = 0; // Points into the rowIdxs array + int removeCount = 0; // Tracks number of removed rows. + + for (int i = 0; i < nnz; i++) { + int oldRow = rowIndices[i]; + + // Advance j while rowIdxs[j] < oldRow, updating removeCount + while (j < rowIdxs.length && rowIdxs[j] < oldRow) { + removeCount++; + j++; } + + // If oldRow is one of the removed rows, skip this entry. + if (j < rowIdxs.length && rowIdxs[j] == oldRow) + continue; + + // Otherwise, shift oldRow by however many removed rows lie below it. + int newRow = oldRow - removeCount; + + // Keep the entry + entries.add(data[i]); + newRowIndices.add(newRow); + newColIndices.add(colIndices[i]); } - return makeLikeTensor(shape, entries, rowIndices, colIndices); + return makeLikeTensor(shape, entries, newRowIndices, newColIndices); } @@ -925,6 +960,8 @@ public T removeRows(int... rowIdxs) { */ @Override public T removeCol(int colIndex) { + ValidateParameters.ensureValidArrayIndices(numRows, colIndex); + Shape shape = new Shape(numRows, numCols-1); List destEntries = new ArrayList<>(data.length); List destRowIndices = new ArrayList<>(data.length); @@ -954,23 +991,35 @@ public T removeCol(int colIndex) { */ @Override public T removeCols(int... colIdxs) { - Shape shape = new Shape(numRows, numCols-1); + ValidateParameters.ensureValidArrayIndices(numRows, colIdxs); + + // Ensure the indices are sorted. + Arrays.sort(colIdxs); + + Shape shape = new Shape(numRows, numCols - colIdxs.length); List destEntries = new ArrayList<>(data.length); - List destRowIndices = new ArrayList<>(data.length); - List destColIndices = new ArrayList<>(data.length); + List destRowIdx = new ArrayList<>(data.length); + List destColIdx = new ArrayList<>(data.length); - for(int i = 0; i< data.length; i++) { - int idx = Arrays.binarySearch(colIdxs, colIndices[i]); + for (int i = 0; i < data.length; i++) { + int oldCol = colIndices[i]; - if(idx < 0) { - // Then entry is not in the specified column, so copy it with the appropriate column index shift. - destEntries.add(data[i]); - destRowIndices.add(rowIndices[i]); - destColIndices.add(colIndices[i] + (idx+1)); - } + // Check if oldCol is being removed. + int idx = Arrays.binarySearch(colIdxs, oldCol); + + // If idx >= 0, oldCol is in colIdxs then skip this entry. + if (idx >= 0) continue; + + // Otherwise, shift column index. + int insertionPoint = -idx - 1; + int newCol = oldCol - insertionPoint; + + destEntries.add(data[i]); + destRowIdx.add(rowIndices[i]); + destColIdx.add(newCol); } - return makeLikeTensor(shape, destEntries, destRowIndices, destColIndices); + return makeLikeTensor(shape, destEntries, destRowIdx, destColIdx); } @@ -1073,7 +1122,7 @@ public T getTriL(int diagOffset) { */ @Override public T copy() { - return makeLikeTensor(shape, data); + return makeLikeTensor(shape, data.clone()); } @@ -1212,10 +1261,11 @@ public void sortIndices() { * @return A dense matrix equivalent to this sparse COO matrix. */ public U toDense() { - W[] entries = (W[]) new Semiring[shape.totalEntriesIntValueExact()]; + W[] entries = makeEmptyDataArray(shape.totalEntriesIntValueExact()); + Arrays.fill(entries, getZeroElement()); for(int i = 0; i< nnz; i++) - entries[rowIndices[i]*numCols + colIndices[i]] = this.data[i]; + entries[rowIndices[i]*numCols + colIndices[i]] = data[i]; return makeLikeDenseTensor(shape, entries); } @@ -1226,7 +1276,7 @@ public U toDense() { * @return A sparse CSR matrix equivalent to this sparse COO matrix. */ public AbstractCsrSemiringMatrix toCsr() { - W[] csrEntries = (W[]) new Semiring[data.length]; + W[] csrEntries = makeEmptyDataArray(data.length); int[] csrRowPointers = new int[numRows + 1]; int[] csrColPointers = new int[colIndices.length]; CooConversions.toCsr(shape, data, rowIndices, colIndices, csrEntries, csrRowPointers, csrColPointers); @@ -1257,7 +1307,7 @@ public AbstractCsrSemiringMatrix toCsr() { public V toVector() { int[] destIndices = new int[data.length]; for(int i = 0; i< data.length; i++) - destIndices[i] = rowIndices[i]*colIndices[i]; + destIndices[i] = rowIndices[i]*numCols + colIndices[i]; return makeLikeVector(new Shape(numRows*numCols), data.clone(), destIndices); } @@ -1271,8 +1321,7 @@ public V toVector() { * @see #coalesce(BinaryOperator) */ public T coalesce() { - SparseMatrixData mat = SparseUtils.coalesce(Semiring::add, shape, data, rowIndices, colIndices); - return makeLikeTensor(mat.shape(), mat.data(), mat.rowData(), mat.colData()); + return coalesce(Semiring::add); } diff --git a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringTensor.java b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringTensor.java index cc5e8bbc7..08b0e681e 100644 --- a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringTensor.java +++ b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringTensor.java @@ -38,6 +38,7 @@ import org.flag4j.util.exceptions.TensorShapeException; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.Arrays; import java.util.List; import java.util.function.BinaryOperator; @@ -81,7 +82,7 @@ public abstract class AbstractCooSemiringTensorThe non-zero indices of this sparse tensor. * @@ -113,7 +114,7 @@ protected AbstractCooSemiringTensor(Shape shape, V[] data, int[][] indices) { ValidateParameters.validateTensorIndices(shape, indices); this.indices = indices; this.nnz = data.length; - sparsity = BigDecimal.valueOf(nnz).divide(new BigDecimal(shape.totalEntries())).doubleValue(); + sparsity = BigDecimal.valueOf(nnz).divide(new BigDecimal(shape.totalEntries()), RoundingMode.HALF_UP).doubleValue(); // Attempt to set the zero element for the semiring. this.zeroElement = (data.length > 0 && data[0] != null) ? data[0].getZero() : null; @@ -251,7 +252,7 @@ public U tensorDot(T src2, int[] aAxes, int[] bAxes) { CooTensorDot problem = new CooTensorDot<>(shape, data, indices, src2.shape, src2.data, src2.indices, aAxes, bAxes); - V[] dest = (V[]) new Semiring[problem.getOutputSize()]; + V[] dest= makeEmptyDataArray(problem.getOutputSize()); problem.compute(dest); return makeLikeDenseTensor(problem.getOutputShape(), dest); } @@ -291,7 +292,7 @@ public T tensorTr(int axis1, int axis2) { */ @Override public T T() { - V[] destEntries = (V[]) new Semiring[nnz]; + V[] destEntries= makeEmptyDataArray(nnz); int[][] destIndices = new int[nnz][rank]; CooTranspose.tensorTranspose(shape, data, indices,0, shape.getRank()-1, destEntries, destIndices); return makeLikeTensor(shape.swapAxes(0, rank-1), destEntries, destIndices); @@ -312,7 +313,7 @@ public T T() { */ @Override public T T(int axis1, int axis2) { - V[] destEntries = (V[]) new Semiring[nnz]; + V[] destEntries= makeEmptyDataArray(nnz); int[][] destIndices = new int[nnz][rank]; CooTranspose.tensorTranspose(shape, data, indices, axis1, axis2, destEntries, destIndices); return makeLikeTensor(shape.swapAxes(axis1, axis2), destEntries, destIndices); @@ -335,7 +336,7 @@ public T T(int axis1, int axis2) { */ @Override public T T(int... axes) { - V[] destEntries = (V[]) new Semiring[nnz]; + V[] destEntries= makeEmptyDataArray(nnz); int[][] destIndices = new int[nnz][rank]; CooTranspose.tensorTranspose(shape, data, indices, axes, destEntries, destIndices); return makeLikeTensor(shape.permuteAxes(axes), destEntries, destIndices); @@ -446,7 +447,7 @@ public T set(V value, int... target) { destIndices[idx] = target; } else { // Target not found, insert new value and index. - destEntries = (V[]) new Semiring[nnz + 1]; + destEntries= makeEmptyDataArray(nnz + 1); destIndices = new int[nnz + 1][rank]; int insertionPoint = - (idx + 1); CooGetSet.cooInsertNewValue(value, target, data, indices, insertionPoint, destEntries, destIndices); @@ -522,7 +523,7 @@ public void sortIndices() { * @throws ArithmeticException If the number of data in the dense tensor exceeds 2,147,483,647. */ public U toDense() { - V[] denseEntries = (V[]) new Semiring[shape.totalEntriesIntValueExact()]; + V[] denseEntries= makeEmptyDataArray(shape.totalEntriesIntValueExact()); CooConversions.toDense(shape, data, indices, denseEntries); return makeLikeDenseTensor(shape, denseEntries); } diff --git a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringVector.java b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringVector.java index 69b6b9e18..1e52e7bef 100644 --- a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringVector.java +++ b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCooSemiringVector.java @@ -25,7 +25,6 @@ package org.flag4j.arrays.backend.semiring_arrays; -import org.flag4j.algebraic_structures.Field; import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseVectorData; @@ -42,6 +41,7 @@ import org.flag4j.util.exceptions.TensorShapeException; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.Arrays; import java.util.List; import java.util.function.BinaryOperator; @@ -133,10 +133,10 @@ protected AbstractCooSemiringVector(Shape shape, Y[] entries, int[] indices) { this.indices = indices; this.nnz = entries.length; - sparsity = BigDecimal.valueOf(nnz).divide(new BigDecimal(shape.totalEntries())).doubleValue(); + sparsity = BigDecimal.valueOf(nnz).divide(new BigDecimal(shape.totalEntries()), RoundingMode.HALF_UP).doubleValue(); // Attempt to set the zero element for the semiring. - this.zeroElement = (entries.length > 0) ? entries[0].getZero() : null; + this.zeroElement = (entries.length > 0 && entries[0] != null) ? entries[0].getZero() : null; } @@ -314,7 +314,7 @@ public T set(Y value, int... target) { destIndices[idx] = target[0]; } else { // Target not found, insert new value and index. - destEntries = (Y[]) new Semiring[nnz + 1]; + destEntries = makeEmptyDataArray(nnz + 1); destIndices = new int[nnz + 1]; int insertionPoint = - (idx + 1); CooGetSet.cooInsertNewValue(value, target[0], data, indices, insertionPoint, destEntries, destIndices); @@ -364,6 +364,7 @@ public T flatten(int axis) { @Override public T reshape(Shape newShape) { ValidateParameters.ensureRank(newShape, 1); + ValidateParameters.ensureBroadcastable(shape, newShape); return copy(); } @@ -378,7 +379,7 @@ public T reshape(Shape newShape) { */ @Override public T join(T b) { - Y[] destEntries = (Y[]) new Semiring[this.data.length + b.data.length]; + Y[] destEntries = makeEmptyDataArray(this.data.length + b.data.length); int[] destIndices = new int[this.indices.length + b.indices.length]; CooConcat.join(data, indices, size, b.data, b.indices, destEntries, destIndices); return makeLikeTensor(new Shape(shape.get(0) + b.shape.get(0)), destEntries, destIndices); @@ -452,11 +453,11 @@ public int length() { */ @Override public V repeat(int n, int axis) { - Y[] tiledEntries = (Y[]) new Field[n*data.length]; + Y[] tiledEntries = makeEmptyDataArray(n*data.length); int[] tiledRows = new int[tiledEntries.length]; int[] tiledCols = new int[tiledEntries.length]; Shape tiledShape = CooConcat.repeat(data, indices, size, n, axis, tiledEntries, tiledRows, tiledCols); - return makeLikeMatrix(tiledShape, data, tiledRows, tiledCols); + return makeLikeMatrix(tiledShape, tiledEntries, tiledRows, tiledCols); } @@ -488,8 +489,8 @@ public V repeat(int n, int axis) { @Override public V stack(T b, int axis) { ValidateParameters.ensureEquals(size, b.size); - Y[] destEntries = (Y[]) new Semiring[data.length + b.data.length]; - int[][] destIndices = new int[2][indices.length + indices.length]; // Row and column indices. + Y[] destEntries = makeEmptyDataArray(data.length + b.data.length); + int[][] destIndices = new int[2][indices.length + b.indices.length]; // Row and column indices. CooConcat.stack(data, indices, b.data, b.indices, destEntries, destIndices[0], destIndices[1]); V mat = makeLikeMatrix(new Shape(2, size), destEntries, destIndices[0], destIndices[1]); @@ -510,7 +511,7 @@ public V stack(T b, int axis) { @Override public W outer(T b) { Shape destShape = new Shape(size, b.size); - Y[] dest = (Y[]) new Semiring[size*b.size]; + Y[] dest = makeEmptyDataArray(size*b.size); CooSemiringVectorOps.outerProduct(data, indices, size, b.data, b.indices, dest); return makeLikeDenseMatrix(shape, dest); } @@ -662,7 +663,8 @@ public void setZeroElement(Y zeroElement) { * @return A dense matrix equivalent to this sparse COO matrix. */ public U toDense() { - Y[] entries = (Y[]) new Semiring[shape.totalEntriesIntValueExact()]; + Y[] entries = makeEmptyDataArray(shape.totalEntriesIntValueExact()); + Arrays.fill(entries, zeroElement); for(int i = 0; i< nnz; i++) entries[indices[i]] = this.data[i]; diff --git a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCsrSemiringMatrix.java b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCsrSemiringMatrix.java index 1b0be5162..2decc53c9 100644 --- a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCsrSemiringMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractCsrSemiringMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,7 +24,6 @@ package org.flag4j.arrays.backend.semiring_arrays; -import org.flag4j.algebraic_structures.Field; import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseMatrixData; @@ -41,6 +40,7 @@ import org.flag4j.util.exceptions.TensorShapeException; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.Arrays; import java.util.List; @@ -88,7 +88,7 @@ public abstract class AbstractCsrSemiringMatrixPointers indicating starting index of each row within the {@link #colIndices} and {@link #data} arrays. * Has length {@link #numRows numRows + 1}. @@ -120,7 +120,7 @@ public abstract class AbstractCsrSemiringMatrix 0) ? entries[0].getZero() : null; + this.zeroElement = (entries.length > 0 && entries[0] != null) ? entries[0].getZero() : null; } @@ -375,7 +376,7 @@ public T reshape(Shape newShape) { */ @Override public T T() { - W[] dest = (W[]) new Semiring[data.length]; + W[] dest = makeEmptyDataArray(data.length); int[] destRowPointers = new int[numCols+1]; int[] destColIndices = new int[data.length]; CsrOps.transpose(data, rowPointers, colIndices, dest, destRowPointers, destColIndices); @@ -526,6 +527,7 @@ public int numCols() { */ @Override public W get(int row, int col) { + ValidateParameters.validateTensorIndex(shape, row, col); int loc = Arrays.binarySearch(colIndices, rowPointers[row], rowPointers[row+1], col); if(loc >= 0) return (W) data[loc]; @@ -544,8 +546,8 @@ public W get(int row, int col) { */ @Override public W tr() { + ValidateParameters.ensureSquare(shape); W tr = (W) SemiringCsrOps.trace(data, rowPointers, colIndices); - return (tr == null) ? (W) zeroElement : tr; } @@ -606,14 +608,14 @@ public boolean isI() { @Override public U mult(T b) { Shape destShape = new Shape(numRows, b.numCols); - W[] destArray = (W[]) new Semiring[numRows*b.numCols]; + W[] destArray = makeEmptyDataArray(numRows*b.numCols); SemiringCsrMatMult.standard( shape, data, rowPointers, colIndices, b.shape, b.data, b.rowPointers, b.colIndices, destArray, zeroElement); - return makeLikeDenseTensor(shape, destArray); + return makeLikeDenseTensor(destShape, destArray); } @@ -769,7 +771,7 @@ public T swapCols(int colIndex1, int colIndex2) { */ @Override public boolean isSymmetric() { - return CsrProperties.isSymmetric(shape, data, rowPointers, colIndices); + return CsrProperties.isSymmetric(shape, data, rowPointers, colIndices, zeroElement); } @@ -780,7 +782,7 @@ public boolean isSymmetric() { */ @Override public boolean isHermitian() { - // For a semiring matrix, same as isSymmetric. + // For a general semiring matrix, same as isSymmetric. return isSymmetric(); } @@ -885,7 +887,7 @@ public T setSliceCopy(T values, int rowStart, int colStart) { @Override public T getSlice(int rowStart, int rowEnd, int colStart, int colEnd) { SparseMatrixData> sliceData = CsrOps.getSlice( - data, rowPointers, colIndices, + shape, data, rowPointers, colIndices, rowStart, rowEnd, colStart, colEnd); return makeLikeTensor(sliceData.shape(), (List) sliceData.data(), sliceData.rowData(), sliceData.colData()); @@ -926,7 +928,7 @@ public T set(W value, int row, int col) { newColIndices = colIndices.clone(); } else { loc = -loc - 1; // Compute insertion index as specified by Arrays.binarySearch. - newEntries = (W[]) new Field[data.length + 1]; + newEntries = makeEmptyDataArray(data.length + 1); newColIndices = new int[data.length + 1]; CsrOps.insertNewValue( @@ -1024,7 +1026,7 @@ public void sortIndices() { * @return A dense matrix which is equivalent to this sparse CSR matrix. */ public U toDense() { - W[] dest = (W[]) new Semiring[shape.totalEntriesIntValueExact()]; + W[] dest = makeEmptyDataArray(shape.totalEntriesIntValueExact()); CsrConversions.toDense(shape, data, rowPointers, colIndices, dest, zeroElement); return makeLikeDenseTensor(shape, dest); } diff --git a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringMatrix.java b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringMatrix.java index 2341060ef..1e3cc24bf 100644 --- a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringMatrix.java +++ b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,7 @@ import org.flag4j.arrays.SparseMatrixData; import org.flag4j.arrays.backend.MatrixMixin; import org.flag4j.linalg.ops.TransposeDispatcher; +import org.flag4j.linalg.ops.dense.DenseOps; import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringConversions; import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringMatMultDispatcher; import org.flag4j.util.ArrayUtils; @@ -83,6 +84,14 @@ protected AbstractDenseSemiringMatrix(Shape shape, V[] entries) { protected abstract U makeLikeVector(Shape shape, V[] entries); + /** + * Constructs a vector of a similar type as this matrix. + * @param entries Entries of the vector. + * @return A vector of a similar type as this matrix. + */ + protected abstract U makeLikeVector(V[] entries); + + /** * Constructs a sparse COO matrix which is of a similar type as this dense matrix. * @param shape Shape of the COO matrix. @@ -103,7 +112,7 @@ protected AbstractDenseSemiringMatrix(Shape shape, V[] entries) { * @param colIndices Non-zero column indices of the CSR matrix. * @return A sparse CSR matrix which is of a similar type as this dense matrix. */ - protected abstract AbstractCsrSemiringMatrix makeLikeCsrMatrix( + public abstract AbstractCsrSemiringMatrix makeLikeCsrMatrix( Shape shape, V[] entries, int[] rowPointers, int[] colIndices); @@ -117,7 +126,7 @@ protected AbstractDenseSemiringMatrix(Shape shape, V[] entries) { */ @Override public T T() { - V[] dest = (V[]) new Semiring[data.length]; + V[] dest = makeEmptyDataArray(data.length); TransposeDispatcher.dispatch(data, shape, dest); return makeLikeTensor(shape.swapAxes(0, 1), dest); } @@ -301,8 +310,7 @@ public U mult(U b) { */ @Override public T mult(T b) { - V[] dest = (V[]) new Semiring[numRows*b.numCols]; - System.out.printf("Shapes: %s, %s.\n", shape, b.shape); + V[] dest = makeEmptyDataArray(numRows*b.numCols); DenseSemiringMatMultDispatcher.dispatch(data, shape, b.data, b.shape, dest); return makeLikeTensor(new Shape(numRows, b.numCols), dest); } @@ -320,7 +328,7 @@ public T mult(T b) { */ @Override public T multTranspose(T b) { - V[] dest = (V[]) new Semiring[numRows*b.numRows]; + V[] dest = makeEmptyDataArray(numRows*b.numRows); DenseSemiringMatMultDispatcher.dispatchTranspose(data, shape, b.data, b.shape, dest); return makeLikeTensor(new Shape(numRows, b.numRows), dest); } @@ -341,7 +349,7 @@ public T multTranspose(T b) { public T stack(T b) { ValidateParameters.ensureArrayLengthsEq(this.numCols, b.numCols); Shape stackedShape = new Shape(this.numRows + b.numRows, this.numCols); - V[] stackedEntries = (V[]) new Semiring[stackedShape.totalEntries().intValueExact()]; + V[] stackedEntries = makeEmptyDataArray(stackedShape.totalEntries().intValueExact()); System.arraycopy(this.data, 0, stackedEntries, 0, this.data.length); System.arraycopy(b.data, 0, stackedEntries, this.data.length, b.data.length); @@ -367,7 +375,7 @@ public T augment(T b) { int augNumCols = numCols + b.numCols; Shape augShape = new Shape(numRows, augNumCols); - V[] augEntries = (V[]) new Semiring[numRows*augNumCols]; + V[] augEntries = makeEmptyDataArray(numRows*augNumCols); // Copy data from this matrix. for(int i=0; i 0) ? data[0].getZero() : null); T result = makeLikeTensor(shape, copyEntries); @@ -793,7 +783,7 @@ public T getTriU(int diagOffset) { @Override public T getTriL(int diagOffset) { ValidateParameters.ensureInRange(diagOffset, -numRows+1, numCols-1, "diagOffset"); - V[] copyEntries = (V[]) new Semiring[data.length]; + V[] copyEntries = makeEmptyDataArray(data.length); Arrays.fill(copyEntries, (data.length > 0) ? data[0].getZero() : null); T result = makeLikeTensor(shape, copyEntries); @@ -830,9 +820,18 @@ public T getTriL(int diagOffset) { public U getDiag(int diagOffset) { ValidateParameters.ensureInRange(diagOffset, -(numRows-1), numCols-1, "diagOffset"); + // Check for some quick returns. - if(numRows == 1 && diagOffset > 0) return makeLikeVector(shape, (V[]) new Semiring[]{data[diagOffset]}); - if(numCols == 1 && diagOffset < 0) return makeLikeVector(shape, (V[]) new Semiring[]{data[-diagOffset]}); + if(numRows == 1 && diagOffset > 0) { + V[] dest = makeEmptyDataArray(1); + dest[0] = data[diagOffset]; + return makeLikeVector(shape, dest); + } + if(numCols == 1 && diagOffset < 0) { + V[] dest = makeEmptyDataArray(1); + dest[0] = data[-diagOffset]; + return makeLikeVector(shape, dest); + } // Compute the length of the diagonal. int newSize = Math.min(numRows, numCols); @@ -847,7 +846,7 @@ else if(diagOffset < 0) { idx = -diagOffset*numCols; } - V[] diag = (V[]) new Semiring[newSize]; + V[] diag = makeEmptyDataArray(newSize); for(int i=0; i toCoo(double estimatedSparsity) { SparseMatrixData data = DenseSemiringConversions.toCoo(shape, this.data, estimatedSparsity); - V[] cooEntries = (V[]) data.data().toArray(new Semiring[data.data().size()]); + V[] cooEntries = makeEmptyDataArray(data.data().size()); + data.data().toArray(cooEntries); int[] rowIndices = ArrayUtils.fromIntegerList(data.rowData()); int[] colIndices = ArrayUtils.fromIntegerList(data.colData()); diff --git a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringTensor.java b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringTensor.java index 3c4058653..754b1b52e 100644 --- a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringTensor.java +++ b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringTensor.java @@ -28,9 +28,11 @@ import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseTensorData; import org.flag4j.arrays.backend.AbstractTensor; +import org.flag4j.arrays.backend.VectorMixin; import org.flag4j.linalg.ops.TransposeDispatcher; import org.flag4j.linalg.ops.common.semiring_ops.CompareSemiring; import org.flag4j.linalg.ops.dense.DenseSemiringTensorDot; +import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringConversions; import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringElemMult; import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringOps; @@ -65,6 +67,7 @@ public abstract class AbstractDenseSemiringTensor toCoo() { public AbstractTensor toCoo(double estimatedSparsity) { SparseTensorData data = DenseSemiringConversions.toCooTensor(shape, this.data, estimatedSparsity); V[] cooEntries = data.data().toArray(makeEmptyDataArray(data.data().size())); - return makeLikeCooTensor(data.shape(), cooEntries, data.indicesToArray()); + + // TODO: First check if this tensor is a vector then delegate to specialized toCooVector + // or toCooTensor methods. + if(this instanceof VectorMixin) { + return makeLikeCooTensor( + data.shape(), cooEntries, + RealDenseTranspose.standardIntMatrix(data.indicesToArray())); + } else { + return makeLikeCooTensor(data.shape(), cooEntries, data.indicesToArray()); + } } } diff --git a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringVector.java b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringVector.java index 43c230ee1..b50322d39 100644 --- a/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringVector.java +++ b/src/main/java/org/flag4j/arrays/backend/semiring_arrays/AbstractDenseSemiringVector.java @@ -56,11 +56,13 @@ public abstract class AbstractDenseSemiringVector makeLikeCooTensor(Shape shape, T[] entries, int[][] } + /** + * Constructs a vector of a similar type as this matrix. + * + * @param shape Shape of the vector to construct. Must be rank 1. + * @param entries Entries of the vector. + * + * @return A vector of a similar type as this matrix. + */ + @Override + protected FieldVector makeLikeVector(Shape shape, T[] entries) { + return new FieldVector<>(shape, entries); + } + + /** * Constructs a tensor of the same type as this tensor with the given the shape and data. * diff --git a/src/main/java/org/flag4j/arrays/dense/FieldVector.java b/src/main/java/org/flag4j/arrays/dense/FieldVector.java index 48a984c04..d98eba950 100644 --- a/src/main/java/org/flag4j/arrays/dense/FieldVector.java +++ b/src/main/java/org/flag4j/arrays/dense/FieldVector.java @@ -32,6 +32,7 @@ import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; import org.flag4j.util.ValidateParameters; +import org.flag4j.util.exceptions.LinearAlgebraException; import java.util.Arrays; @@ -98,6 +99,18 @@ public FieldVector(int size, T fillValue) { } + /** + * Constructs a dense complex vector with the given shape and entries. + * @param shape The shape of the vector. Must be rank-1 and satisfy {@code shape.totalEntriesIntValueExact() == data.length}. + * @param data The entries of the vector. + * @throws LinearAlgebraException If {@code shape.getRank() != 1} + * @throws IllegalArgumentException If {@code shape.totalEntriesIntValueExact() != data.length} + */ + public FieldVector(Shape shape, T[] entries) { + super(shape, entries); + } + + /** * Creates a vector with the specified {@code data}. * diff --git a/src/main/java/org/flag4j/arrays/dense/Matrix.java b/src/main/java/org/flag4j/arrays/dense/Matrix.java index 41c94c55c..2ea9601e9 100644 --- a/src/main/java/org/flag4j/arrays/dense/Matrix.java +++ b/src/main/java/org/flag4j/arrays/dense/Matrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -113,6 +113,8 @@ public class Matrix extends AbstractDenseDoubleTensor implements MatrixMixin { private static final long serialVersionUID = 1L; + // TODO: Add norm methods. + /** * The number of rows in this matrix. */ @@ -1171,6 +1173,7 @@ public Matrix setValues(Matrix values) { */ @Override public Matrix setCol(Vector values, int colIndex) { + System.out.println("values: " + values); return setCol(values.data, colIndex); } @@ -1501,8 +1504,7 @@ public Vector getDiag(int diagOffset) { if(diagOffset > 0) { newSize = Math.min(newSize, numCols - diagOffset); idx = diagOffset; - } - else if(diagOffset < 0) { + } else if(diagOffset < 0) { newSize = Math.min(newSize, numRows + diagOffset); idx = -diagOffset*numCols; } diff --git a/src/main/java/org/flag4j/arrays/dense/RingMatrix.java b/src/main/java/org/flag4j/arrays/dense/RingMatrix.java index 132cc1904..30860770f 100644 --- a/src/main/java/org/flag4j/arrays/dense/RingMatrix.java +++ b/src/main/java/org/flag4j/arrays/dense/RingMatrix.java @@ -34,6 +34,7 @@ import org.flag4j.arrays.sparse.CsrRingMatrix; import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; +import org.flag4j.linalg.ops.common.ring_ops.RingOps; import org.flag4j.util.ArrayUtils; import org.flag4j.util.StringUtils; import org.flag4j.util.ValidateParameters; @@ -189,6 +190,19 @@ protected RingVector makeLikeVector(Shape shape, T[] entries) { } + /** + * Constructs a vector of a similar type as this matrix. + * + * @param entries Entries of the vector. + * + * @return A vector of a similar type as this matrix. + */ + @Override + protected RingVector makeLikeVector(T[] entries) { + return new RingVector<>(entries); + } + + /** * Constructs a sparse COO matrix which is of a similar type as this dense matrix. * @@ -216,7 +230,7 @@ protected CooRingMatrix makeLikeCooMatrix(Shape shape, T[] entries, int[] row * @return A sparse CSR matrix which is of a similar type as this dense matrix. */ @Override - protected CsrRingMatrix makeLikeCsrMatrix( + public CsrRingMatrix makeLikeCsrMatrix( Shape shape, T[] entries, int[] rowPointers, int[] colIndices) { return new CsrRingMatrix(shape, entries, rowPointers, colIndices); } @@ -455,6 +469,19 @@ public RingMatrix sub(RingMatrix b) { } + /** + * Computes the element-wise absolute value of this tensor. + * + * @return The element-wise absolute value of this tensor. + */ + @Override + public Matrix abs() { + double[] dest = new double[data.length]; + RingOps.abs(data, dest); + return new Matrix(shape, dest); + } + + /** *

    {@inheritDoc} *

    This method will throw an {@code UnsupportedOperationException} as division is not defined for a general ring. diff --git a/src/main/java/org/flag4j/arrays/dense/RingTensor.java b/src/main/java/org/flag4j/arrays/dense/RingTensor.java index 8a6c2451a..40f2b5e2a 100644 --- a/src/main/java/org/flag4j/arrays/dense/RingTensor.java +++ b/src/main/java/org/flag4j/arrays/dense/RingTensor.java @@ -29,6 +29,7 @@ import org.flag4j.arrays.backend.ring_arrays.AbstractDenseRingTensor; import org.flag4j.arrays.sparse.CooRingTensor; import org.flag4j.io.PrintOptions; +import org.flag4j.linalg.ops.common.ring_ops.RingOps; import org.flag4j.linalg.ops.dense.DenseEquals; import org.flag4j.util.ArrayUtils; import org.flag4j.util.StringUtils; @@ -209,6 +210,19 @@ public RingMatrix toMatrix(Shape matShape) { } + /** + * Computes the element-wise absolute value of this tensor. + * + * @return The element-wise absolute value of this tensor. + */ + @Override + public Tensor abs() { + double[] dest = new double[data.length]; + RingOps.abs(data, dest); + return new Tensor(shape, dest); + } + + /** * Checks if an object is equal to this tensor object. * @param object Object to check equality with this tensor. diff --git a/src/main/java/org/flag4j/arrays/dense/RingVector.java b/src/main/java/org/flag4j/arrays/dense/RingVector.java index 7217e4de6..aeb876b04 100644 --- a/src/main/java/org/flag4j/arrays/dense/RingVector.java +++ b/src/main/java/org/flag4j/arrays/dense/RingVector.java @@ -30,6 +30,7 @@ import org.flag4j.arrays.sparse.CooRingVector; import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; +import org.flag4j.linalg.ops.common.ring_ops.RingOps; import java.util.Arrays; @@ -103,7 +104,7 @@ public RingVector(Shape shape, T[] data) { * @param entries Entries of the dense vector to construct. */ @Override - protected RingVector makeLikeTensor(T[] entries) { + public RingVector makeLikeTensor(T[] entries) { return new RingVector<>(entries); } @@ -154,6 +155,19 @@ protected CooRingVector makeLikeCooTensor(Shape shape, T[] data, int[][] indi } + /** + * Computes the element-wise absolute value of this tensor. + * + * @return The element-wise absolute value of this tensor. + */ + @Override + public Vector abs() { + double[] dest = new double[data.length]; + RingOps.abs(data, dest); + return new Vector(shape, dest); + } + + /** * Checks if an object is equal to this vector object. * @param object Object to check equality with this vector. diff --git a/src/main/java/org/flag4j/arrays/dense/SemiringMatrix.java b/src/main/java/org/flag4j/arrays/dense/SemiringMatrix.java index da369a445..3bea826c1 100644 --- a/src/main/java/org/flag4j/arrays/dense/SemiringMatrix.java +++ b/src/main/java/org/flag4j/arrays/dense/SemiringMatrix.java @@ -196,6 +196,19 @@ protected SemiringVector makeLikeVector(Shape shape, T[] entries) { } + /** + * Constructs a vector of a similar type as this matrix. + * + * @param entries Entries of the vector. + * + * @return A vector of a similar type as this matrix. + */ + @Override + protected SemiringVector makeLikeVector(T[] entries) { + return new SemiringVector<>(entries); + } + + /** * Constructs a sparse COO matrix which is of a similar type as this dense matrix. * @@ -223,7 +236,7 @@ protected CooSemiringMatrix makeLikeCooMatrix(Shape shape, T[] entries, int[] * @return A sparse CSR matrix which is of a similar type as this dense matrix. */ @Override - protected CsrSemiringMatrix makeLikeCsrMatrix( + public CsrSemiringMatrix makeLikeCsrMatrix( Shape shape, T[] entries, int[] rowPointers, int[] colIndices) { return new CsrSemiringMatrix(shape, entries, rowPointers, colIndices); } diff --git a/src/main/java/org/flag4j/arrays/dense/SemiringVector.java b/src/main/java/org/flag4j/arrays/dense/SemiringVector.java index f8a7eb322..ee510925a 100644 --- a/src/main/java/org/flag4j/arrays/dense/SemiringVector.java +++ b/src/main/java/org/flag4j/arrays/dense/SemiringVector.java @@ -103,7 +103,7 @@ public SemiringVector(Shape shape, T[] data) { * @param entries Entries of the dense vector to construct. */ @Override - protected SemiringVector makeLikeTensor(T[] entries) { + public SemiringVector makeLikeTensor(T[] entries) { return new SemiringVector<>(entries); } diff --git a/src/main/java/org/flag4j/arrays/dense/package-info.java b/src/main/java/org/flag4j/arrays/dense/package-info.java index df88ac295..d431d50df 100644 --- a/src/main/java/org/flag4j/arrays/dense/package-info.java +++ b/src/main/java/org/flag4j/arrays/dense/package-info.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -63,6 +63,18 @@ * {@code T extends Field}. *

  • {@link org.flag4j.arrays.dense.FieldTensor FieldTensor<T>} - Dense tensor parameterized over a field element * {@code T extends Field}.
  • + *
  • {@link org.flag4j.arrays.dense.RingVector RingVector<T>} - Dense vector parameterized over a ring element + * {@code T extends Ring}.
  • + *
  • {@link org.flag4j.arrays.dense.RingMatrix RingMatrix<T>} - Dense matrix parameterized over a ring element + * {@code T extends Ring}.
  • + *
  • {@link org.flag4j.arrays.dense.RingTensor RingTensor<T>} - Dense tensor parameterized over a ring element + * {@code T extends Ring}.
  • + *
  • {@link org.flag4j.arrays.dense.SemiringVector SemiringVector<T>} - Dense vector parameterized over a semiring element + * {@code T extends semiring}.
  • + *
  • {@link org.flag4j.arrays.dense.SemiringMatrix SemiringMatrix<T>} - Dense matrix parameterized over a semiring element + * {@code T extends semiring}.
  • + *
  • {@link org.flag4j.arrays.dense.SemiringTensor SemiringTensor<T>} - Dense tensor parameterized over a semiring element + * {@code T extends semiring}.
  • * * * @see org.flag4j.arrays.sparse diff --git a/src/main/java/org/flag4j/arrays/sparse/CooCMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CooCMatrix.java index a792023af..3eb874a81 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooCMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooCMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,9 +38,10 @@ import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; import org.flag4j.linalg.ops.dense_sparse.coo.field_ops.DenseCooFieldMatrixOps; import org.flag4j.linalg.ops.dense_sparse.coo.real_complex.RealComplexDenseCooMatOps; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldEquals; import org.flag4j.linalg.ops.sparse.coo.real_complex.RealComplexCooConcats; import org.flag4j.linalg.ops.sparse.coo.real_complex.RealComplexSparseMatOps; +import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingMatrixOps; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatMult; import org.flag4j.util.ArrayUtils; import org.flag4j.util.StringUtils; @@ -339,6 +340,17 @@ public CsrCMatrix makeLikeCsrMatrix(Shape shape, Complex128[] entries, int[] row } + /** + * Checks if a matrix is Hermitian. That is, if the matrix is square and equal to its conjugate transpose. + * + * @return {@code true} if this matrix is Hermitian; {@code false} otherwise. + */ + @Override + public boolean isHermitian() { + return CooRingMatrixOps.isHermitian(shape, data, rowIndices, colIndices); + } + + /** * Converts this sparse COO matrix to an equivalent sparse CSR matrix. * @@ -567,9 +579,7 @@ public R accept(MatrixVisitor visitor) { public boolean equals(Object object) { if(this == object) return true; if(object == null || object.getClass() != getClass()) return false; - - return CooFieldEquals.cooMatrixEquals(this.dropZeros(), - ((CooCMatrix) object).dropZeros()); + return CooSemiringEquals.cooMatrixEquals(this, ((CooCMatrix) object)); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooCTensor.java b/src/main/java/org/flag4j/arrays/sparse/CooCTensor.java index 938e38e32..7034100b8 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooCTensor.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooCTensor.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -33,7 +33,7 @@ import org.flag4j.io.PrintOptions; import org.flag4j.linalg.ops.common.complex.Complex128Ops; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldEquals; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; @@ -315,7 +315,7 @@ public boolean equals(Object object) { if(this == object) return true; if(object == null || object.getClass() != getClass()) return false; - return CooFieldEquals.cooTensorEquals(this, (CooCTensor) object); + return CooSemiringEquals.cooTensorEquals(this, (CooCTensor) object); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooCVector.java b/src/main/java/org/flag4j/arrays/sparse/CooCVector.java index a475688a6..aea9061f4 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooCVector.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooCVector.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -33,8 +33,8 @@ import org.flag4j.io.PrintOptions; import org.flag4j.linalg.ops.common.complex.Complex128Ops; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldEquals; import org.flag4j.linalg.ops.sparse.coo.real_complex.RealComplexSparseVectorOps; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.util.ArrayUtils; import org.flag4j.util.StringUtils; import org.flag4j.util.ValidateParameters; @@ -394,7 +394,7 @@ public boolean equals(Object object) { if(this == object) return true; if(object == null || object.getClass() != getClass()) return false; - return CooFieldEquals.cooVectorEquals(this, (CooCVector) object); + return CooSemiringEquals.cooVectorEquals(this, (CooCVector) object); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooFieldMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CooFieldMatrix.java index e5b83d722..39a2ff797 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooFieldMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooFieldMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -34,7 +34,7 @@ import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldEquals; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatMult; import org.flag4j.util.ArrayUtils; import org.flag4j.util.StringUtils; @@ -385,7 +385,7 @@ public boolean equals(Object object) { CooFieldMatrix src2 = (CooFieldMatrix) object; - return CooFieldEquals.cooMatrixEquals(this, src2); + return CooSemiringEquals.cooMatrixEquals(this, src2); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooFieldTensor.java b/src/main/java/org/flag4j/arrays/sparse/CooFieldTensor.java index 3e3b94a14..8704b04c9 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooFieldTensor.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooFieldTensor.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -33,7 +33,7 @@ import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldEquals; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; @@ -325,7 +325,7 @@ public boolean equals(Object object) { CooFieldTensor src2 = (CooFieldTensor) object; - return CooFieldEquals.cooTensorEquals(this, src2); + return CooSemiringEquals.cooTensorEquals(this, src2); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooFieldVector.java b/src/main/java/org/flag4j/arrays/sparse/CooFieldVector.java index 520f1ccf2..e1cd352c7 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooFieldVector.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooFieldVector.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -32,7 +32,7 @@ import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldEquals; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.util.ArrayUtils; import org.flag4j.util.StringUtils; import org.flag4j.util.ValidateParameters; @@ -257,7 +257,7 @@ public boolean equals(Object object) { CooFieldVector src2 = (CooFieldVector) object; - return CooFieldEquals.cooVectorEquals(this, src2); + return CooSemiringEquals.cooVectorEquals(this, src2); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CooMatrix.java index cf15feb6a..dead3aeab 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -446,7 +446,7 @@ public Matrix toDense() { public CsrMatrix toCsr() { int[] csrRowPointers = new int[numRows + 1]; - // Copy the non-zero data anc column indices. Count number of data per row. + // Copy the non-zero data and column indices. Count number of data per row. for(int i = 0, size = data.length; i= this.numCols}. */ public CooMatrix setCol(CooVector values, int colIndex) { - return RealSparseMatrixGetSet.setCol(this, colIndex, values); + return RealCooMatrixGetSet.setCol(this, colIndex, values); } @@ -1637,7 +1637,7 @@ public CooMatrix setCol(CooVector values, int colIndex) { */ @Override public CooMatrix setRow(CooVector values, int rowIndex) { - return RealSparseMatrixGetSet.setRow(this, rowIndex, values); + return RealCooMatrixGetSet.setRow(this, rowIndex, values); } @@ -1651,7 +1651,7 @@ public CooMatrix setRow(CooVector values, int rowIndex) { * If this matrix is sparse a copy will be created with the new row and returned. */ public CooMatrix setRow(double[] row, int rowIdx) { - return RealSparseMatrixGetSet.setRow(this, rowIdx, row); + return RealCooMatrixGetSet.setRow(this, rowIdx, row); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooRingMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CooRingMatrix.java index d10120b5c..f0a4a43ef 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooRingMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooRingMatrix.java @@ -27,11 +27,11 @@ import org.flag4j.algebraic_structures.Ring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingMatrix; -import org.flag4j.arrays.backend.ring_arrays.AbstractCsrRingMatrix; import org.flag4j.arrays.backend.smart_visitors.MatrixVisitor; import org.flag4j.arrays.dense.RingMatrix; import org.flag4j.arrays.dense.RingTensor; import org.flag4j.arrays.dense.RingVector; +import org.flag4j.linalg.ops.common.ring_ops.RingOps; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringMatMult; import org.flag4j.util.ArrayUtils; @@ -242,7 +242,7 @@ public RingMatrix makeLikeDenseTensor(Shape shape, T[] entries) { * @return A CSR matrix of a similar type to this sparse COO matrix. */ @Override - public AbstractCsrRingMatrix, CooRingVector, T> makeLikeCsrMatrix( + public CsrRingMatrix makeLikeCsrMatrix( Shape shape, T[] entries, int[] rowPointers, int[] colIndices) { return new CsrRingMatrix<>(shape, entries, rowPointers, colIndices); } @@ -370,14 +370,15 @@ public R accept(MatrixVisitor visitor) { /** - * {@inheritDoc} + * Computes the element-wise absolute value of this tensor. * - *

    Warning: This method will throw a {@link UnsupportedOperationException} as subtraction is not supported for - * ring tensors. + * @return The element-wise absolute value of this tensor. */ @Override - public CooRingMatrix sub(CooRingMatrix b) { - throw new UnsupportedOperationException("Subtraction not supported for matrix type: " + getClass().getName()); + public CooMatrix abs() { + double[] dest = new double[data.length]; + RingOps.abs(data, dest); + return new CooMatrix(shape, dest, rowIndices.clone(), colIndices.clone()); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooRingTensor.java b/src/main/java/org/flag4j/arrays/sparse/CooRingTensor.java index 4606ef8e7..372625615 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooRingTensor.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooRingTensor.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -31,8 +31,9 @@ import org.flag4j.arrays.dense.RingVector; import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; +import org.flag4j.linalg.ops.common.ring_ops.RingOps; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; -import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingEquals; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; @@ -263,6 +264,20 @@ public CooRingMatrix toMatrix() { return mat; } + + /** + * Computes the element-wise absolute value of this tensor. + * + * @return The element-wise absolute value of this tensor. + */ + @Override + public CooTensor abs() { + double[] dest = new double[data.length]; + RingOps.abs(data, dest); + return new CooTensor(shape, dest, ArrayUtils.deepCopy(indices, null)); + } + + /** * Checks if an object is equal to this tensor object. * @param object Object to check equality with this tensor. @@ -276,7 +291,7 @@ public boolean equals(Object object) { CooRingTensor src2 = (CooRingTensor) object; - return CooRingEquals.cooTensorEquals(this, src2); + return CooSemiringEquals.cooTensorEquals(this, src2); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CooRingVector.java b/src/main/java/org/flag4j/arrays/sparse/CooRingVector.java index 5bb0fc619..c0f3e4ff3 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CooRingVector.java +++ b/src/main/java/org/flag4j/arrays/sparse/CooRingVector.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -31,8 +31,9 @@ import org.flag4j.arrays.dense.RingVector; import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; +import org.flag4j.linalg.ops.common.ring_ops.RingOps; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; -import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingEquals; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.util.ArrayUtils; import org.flag4j.util.StringUtils; @@ -271,7 +272,7 @@ public boolean equals(Object object) { CooRingVector src2 = (CooRingVector) object; - return CooRingEquals.cooVectorEquals(this, src2); + return CooSemiringEquals.cooVectorEquals(this, src2); } @@ -344,6 +345,20 @@ public CooRingTensor toTensor(Shape newShape) { } + + /** + * Computes the element-wise absolute value of this tensor. + * + * @return The element-wise absolute value of this tensor. + */ + @Override + public CooVector abs() { + double[] dest = new double[data.length]; + RingOps.abs(data, dest); + return new CooVector(shape, dest, indices.clone()); + } + + /** * Formats this tensor as a human-readable string. Specifically, a string containing the * shape and flatten data of this tensor. @@ -361,19 +376,19 @@ public String toString() { int precision = PrintOptions.getPrecision(); if(size > 0) { - int stopIndex = Math.min(maxCols -1, size-1); + int stopIndex = Math.min(maxCols - 1, size - 1); int width; String value; // Get data up until the stopping point. - for(int i = 0; i src2 = (CooSemiringVector) object; - return CooFieldEquals.cooVectorEquals(this, src2); + return CooSemiringEquals.cooVectorEquals(this, src2); } diff --git a/src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java index 5df486ed4..0a6cef74f 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CsrCMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -36,15 +36,14 @@ import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; import org.flag4j.linalg.ops.common.complex.Complex128Ops; -import org.flag4j.linalg.ops.dense_sparse.csr.field_ops.DenseCsrFieldMatMult; import org.flag4j.linalg.ops.dense_sparse.csr.real_field_ops.RealFieldDenseCsrMatMult; +import org.flag4j.linalg.ops.dense_sparse.csr.semiring_ops.DenseCsrSemiringMatMult; import org.flag4j.linalg.ops.sparse.SparseUtils; import org.flag4j.linalg.ops.sparse.csr.CsrConversions; import org.flag4j.linalg.ops.sparse.csr.real_complex.RealComplexCsrMatMult; import org.flag4j.linalg.ops.sparse.csr.semiring_ops.SemiringCsrMatMult; import org.flag4j.util.ArrayUtils; import org.flag4j.util.StringUtils; -import org.flag4j.util.ValidateParameters; import org.flag4j.util.exceptions.LinearAlgebraException; import java.util.List; @@ -78,6 +77,7 @@ public class CsrCMatrix extends AbstractCsrFieldMatrix { private static final long serialVersionUID = 1L; + // TODO: Implement coalesce and and drop zero methods for all CSR classes. /** * Creates a complex sparse CSR matrix with the specified {@code shape}, non-zero data, row pointers, and non-zero column @@ -93,7 +93,6 @@ public class CsrCMatrix extends AbstractCsrFieldMatrix entries, List rowPointe super(shape, entries.toArray(new Complex128[0]), ArrayUtils.fromIntegerList(rowPointers), ArrayUtils.fromIntegerList(colIndices)); - ValidateParameters.ensureRank(shape, 2); setZeroElement(Complex128.ZERO); } @@ -125,7 +123,6 @@ public CsrCMatrix(Shape shape, List entries, List rowPointe */ public CsrCMatrix(Shape shape) { super(shape, new Complex128[0], new int[0], new int[0]); - ValidateParameters.ensureRank(shape, 2); setZeroElement(Complex128.ZERO); } @@ -145,7 +142,6 @@ public CsrCMatrix(Shape shape) { */ public CsrCMatrix(int rows, int cols, Complex128[] entries, int[] rowPointers, int[] colIndices) { super(new Shape(rows, cols), entries, rowPointers, colIndices); - ValidateParameters.ensureRank(shape, 2); setZeroElement(Complex128.ZERO); } @@ -167,7 +163,6 @@ public CsrCMatrix(int rows, int cols, List entries, List ro super(new Shape(rows, cols), entries.toArray(new Complex128[0]), ArrayUtils.fromIntegerList(rowPointers), ArrayUtils.fromIntegerList(colIndices)); - ValidateParameters.ensureRank(shape, 2); setZeroElement(Complex128.ZERO); } @@ -179,7 +174,6 @@ public CsrCMatrix(int rows, int cols, List entries, List ro */ public CsrCMatrix(int rows, int cols) { super(new Shape(rows, cols), new Complex128[0], new int[0], new int[0]); - ValidateParameters.ensureRank(shape, 2); setZeroElement(Complex128.ZERO); } @@ -374,7 +368,7 @@ public CMatrix mult(Matrix b) { * @return The result of multiplying this matrix with the matrix {@code b}. */ public CMatrix mult(CMatrix b) { - return (CMatrix) DenseCsrFieldMatMult.standard(this, b); + return (CMatrix) DenseCsrSemiringMatMult.standard(this, b); } @@ -625,4 +619,18 @@ public String toString() { return result.toString(); } + + + /** + * Computes the matrix-vector multiplication of a vector with this matrix. + * + * @param b Vector in the matrix-vector multiplication. + * + * @return The result of multiplying this matrix with {@code b}. + * + * @throws LinearAlgebraException If the number of columns in this matrix do not equal the size of {@code b}. + */ + public CVector mult(CVector b) { + return (CVector) DenseCsrSemiringMatMult.standardVector(this, b); + } } diff --git a/src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java index e03f2d380..b147f51ef 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CsrMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -169,6 +169,7 @@ public class CsrMatrix extends AbstractDoubleTensor * values in row {@code i}. * @param colIndices Column indices for each non-zero value in this sparse CSR matrix. Must satisfy * {@code data.length == colData.length}. + * @throws TensorShapeException If {@code shape.getRank() != 2}. */ public CsrMatrix(Shape shape, double[] entries, int[] rowPointers, int[] colIndices) { super(shape, entries); @@ -212,6 +213,7 @@ public CsrMatrix(int numRows, int numCols, double[] entries, int[] rowPointers, */ public CsrMatrix(int numRows, int numCols) { super(new Shape(numRows, numCols), new double[0]); + this.rowPointers = new int[0]; this.colIndices = new int[0]; this.nnz = 0; @@ -220,6 +222,23 @@ public CsrMatrix(int numRows, int numCols) { } + /** + * Constructs zero matrix with the specified {@code shape}. + * @param shape Shape of the zero matrix to construct. Must be rank 2. + * @throws TensorShapeException If {@code shape.getRank() != 2}. + */ + public CsrMatrix(Shape shape) { + super(shape, new double[0]); + ValidateParameters.ensureRank(shape, 2); + + this.rowPointers = new int[0]; + this.colIndices = new int[0]; + this.nnz = 0; + this.numRows = shape.get(0); + this.numCols = shape.get(1); + } + + /** * Gets the length of the data array which backs this matrix. * diff --git a/src/main/java/org/flag4j/arrays/sparse/CsrRingMatrix.java b/src/main/java/org/flag4j/arrays/sparse/CsrRingMatrix.java index 5ecf7a466..dc62862e8 100644 --- a/src/main/java/org/flag4j/arrays/sparse/CsrRingMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/CsrRingMatrix.java @@ -35,6 +35,7 @@ import org.flag4j.arrays.dense.RingVector; import org.flag4j.io.PrettyPrint; import org.flag4j.io.PrintOptions; +import org.flag4j.linalg.ops.common.ring_ops.RingOps; import org.flag4j.linalg.ops.sparse.SparseUtils; import org.flag4j.linalg.ops.sparse.csr.semiring_ops.SemiringCsrMatMult; import org.flag4j.util.ArrayUtils; @@ -385,6 +386,19 @@ public CooRingTensor toTensor(Shape shape) { } + /** + * Computes the element-wise absolute value of this tensor. + * + * @return The element-wise absolute value of this tensor. + */ + @Override + public CsrMatrix abs() { + double[] dest = new double[data.length]; + RingOps.abs(data, dest); + return new CsrMatrix(shape, dest, rowPointers.clone(), colIndices.clone()); + } + + /** * Drops any explicit zeros in this sparse COO matrix. * @return A copy of this Csr matrix with any explicitly stored zeros removed. diff --git a/src/main/java/org/flag4j/arrays/sparse/PermutationMatrix.java b/src/main/java/org/flag4j/arrays/sparse/PermutationMatrix.java index 7f0e1f6db..6f22a5a1e 100644 --- a/src/main/java/org/flag4j/arrays/sparse/PermutationMatrix.java +++ b/src/main/java/org/flag4j/arrays/sparse/PermutationMatrix.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -33,38 +33,70 @@ import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; import org.flag4j.util.exceptions.LinearAlgebraException; +import org.flag4j.util.exceptions.TensorShapeException; import java.io.Serializable; import java.util.Arrays; /** - * A permutation matrix is a square matrix containing only zeros and ones such that each row and column have exactly a single - * one. The identity matrix is a special case of a permutation matrix. Permutation matrices are commonly used to - * track or apply row/column swaps in a matrix.

    + *

    Represents a square permutation matrix with rows and columns equal to {@link #size}, where each row and column contains exactly + * one entry of {@code 1}, and all other entries are {@code 0}. Internally, this class stores a permutation array so that + * {@code permutation[i] = j} indicates that there is a {@code 1} at row {@code i}, column {@code j}. * - * All permutation matrices are {@link Matrix#isOrthogonal() orthogonal}/{@link CMatrix#isUnitary() unitary} meaning - * their inverse is equal to their transpose.

    + *

    Permutation matrices are orthogonal (and unitary in the complex case), so their inverse is equal to their + * transpose. * - * When a permutation matrix is left multiplied to a second matrix, it has the result of swapping rows - * in the second matrix.

    + *

    Permutation matrices are useful for permuting rows or columns of another matrix. + *

      + *
    • Left multiplying another matrix by a permutation matrix permutes the rows of that matrix.
    • + *
    • Right multiplying another matrix by a permutation matrix permutes the columns of that matrix.
    • + *
    + *

    * - * Similarly, when a permutation matrix is right multiplied to another matrix, it has the result of swapping columns in - * the other matrix. + *

    The determinant of any permutation matrix is always {@code +1} or + * {@code -1}, depending on the parity of the permutation (i.e. the number of swaps in the matrix). + * + *

    The identity matrix is a special case of a permutation matrix, corresponding to the identity permutation + * {@code {0, 1, ..., n-1}}. + * + *

    + *

    Example usage:

    + *
    {@code
    + *         // Construct matrices to permute.
    + *         Matrix a = new Vector(ArrayUtils.range(0, 5)).repeat(5, 1);
    + *         Matrix b = a.T();
    + *
    + *         // Create matrix to permute rows according to (4, 2, 3, 0, 1)
    + *         PermutationMatrix p1 = new PermutationMatrix(4, 2, 3, 0, 1);
    + *
    + *         // Permute rows of a according to (0, 1, 2, 3, 4) -> (4, 2, 3, 0, 1)
    + *         Matrix aPerm = p1.leftMult(a);
    + *
    + *         // Permute columns of b according to (0, 1, 2, 3, 4) -> (4, 2, 3, 0, 1)
    + *         Matrix bPerm = p1.T().rightMult(b);
    + *
    + *         // Display original matrices are their permuted counterparts.
    + *         System.out.println("a:\n" + a + "\naPerm:\n" + aPerm + "\n");
    + *         System.out.println("b:\n" + b + "\nbPerm:\n" + bPerm);
    + * }
    */ public class PermutationMatrix implements Serializable { private static final long serialVersionUID = 1L; /** - * Tracks row/column swaps within the permutation matrix. For an {@code n-by-n} permutation matrix, this array will - * have size {@code n}. Each entry of the array represents a 1 in the permutation matrix. The index of an entry - * corresponds to the row index of the 1, and the value of this array corresponds to the column index of the 1. + * Describes the permutation represented by this permutation matrix. + * {@code permutation[i] = j} indicates that there is a {@code 1} at row {@code i}, column {@code j} with in permutation matrix. */ - public final int[] swapPointers; + protected final int[] permutation; /** * Size of this permutation matrix. */ public final int size; + /** + * Shape of this permutation matrix. + */ + public final Shape shape; /** @@ -73,7 +105,8 @@ public class PermutationMatrix implements Serializable { */ public PermutationMatrix(int size) { this.size = size; - swapPointers = ArrayUtils.intRange(0, size); + shape = new Shape(size, size); + permutation = ArrayUtils.intRange(0, size); } @@ -84,52 +117,92 @@ public PermutationMatrix(int size) { */ public PermutationMatrix(Shape shape) { ValidateParameters.ensureSquareMatrix(shape); + this.shape = shape; this.size = shape.get(0); - swapPointers = ArrayUtils.intRange(0, size); + permutation = ArrayUtils.intRange(0, size); } /** - * Copy constructor which creates a copy of the {@code src} permutation matrix. + * Copy constructor which creates a deep copy of the {@code src} permutation matrix. * @param src The permutation matrix to copy. */ public PermutationMatrix(PermutationMatrix src) { this.size = src.size; - this.swapPointers = src.swapPointers.clone(); + this.shape = src.shape; + this.permutation = src.permutation.clone(); } /** - * Creates a permutation matrix where the position of its ones are specified by a {@link #swapPointers swap pointer} - * array. - * @param swapPointers An array which defines row/column swaps within the permutation matrix. - * For an {@code n-by-n} permutation matrix, this array will have size {@code n}. - * Each entry of the array represents a 1 in the permutation matrix. The index of an entry - * corresponds to the row index of the 1, and the value of this array corresponds to - * the column index of the 1. This must be a permutation matrix. However, the validity of this - * is not enforced by this constructor. + *

    Constructs a permutation matrix from the specified {@code permutation}. + * + *

    This constructor will explicitly verify that {@code permutation} is a valid permutation. It is highly recommended + * to do this. However, there is a + * + * @param permutation Array specifying the permutation. Must contain a permutation of {@code {0, 1, ..., permutation.length-1}}. + * {@code permutation[i] = j} indicates that there is a {@code 1} at row {@code i}, column {@code j}. + * @throws IllegalArgumentException {@code permutation} is not a permutation of + * {@code {0, 1, ..., permutation.length-1}. */ - public PermutationMatrix(int[] swapPointers) { - this.size = swapPointers.length; - this.swapPointers = swapPointers.clone(); + public PermutationMatrix(int... permutation) { + this(permutation, true); } /** - * Creates a permutation matrix with the specified column swaps. - * @param colSwaps Array specifying column swaps. The entry {@code x} at index {@code i} indicates that column - * {@code i} has been swapped with column {@code x}. Must be a - * {@link ValidateParameters#ensurePermutation(int...) permutation array}. - * @return A permutation matrix with the specified column swaps. - * @throws IllegalArgumentException If {@code colSwaps} is not a + *

    Constructs a permutation matrix from the specified {@code permutation}. This constructor also accepts a flag indicating if an + * explicit check should be made to enforce that the {@code permutation} array is a valid permutation. + * + *

    It is highly recommended to use {@link #PermutationMatrix(int[])} or set {@code ensurePermutation = true}. However, + * if there is absolute confidence in the validity of the {@code permutation} array, then setting + * {@code ensurePermutation = false} may yield very slight performance benefits. + * + * @param permutation Array specifying the permutation. Must contain a permutation of {@code {0, 1, ..., permutation.length-1}}. + * {@code permutation[i] = j} indicates that there is a {@code 1} at row {@code i}, column {@code j}. + * @param ensurePermutation Flag indicating if an explicit check should be made to verify that {@code permutation} is a valid + * permutation of {@code {0, 1, ..., permutation.length-1}}. + *

      + *
    • If {@code true}: an explicit check will be made that {@code permutation} is a valid permutation.
    • + *
    • If {@code false}: NO check will be made to ensure {@code permutation} is a valid permutation.
    • + *
    + * + * @throws IllegalArgumentException If {@code ensurePermutation == true} and {@code permutation} is not a permutation of + * {@code {0, 1, ..., permutation.length-1}. + */ + public PermutationMatrix(int[] permutation, boolean ensurePermutation) { + if(ensurePermutation) ValidateParameters.ensurePermutation(permutation); + this.size = permutation.length; + this.shape = new Shape(size, size); + this.permutation = permutation.clone(); + } + + + /** + * Returns the permutation represented by this permutation matrix. + * @return The permutation represented by this permutation matrix. + */ + public int[] getPermutation() { + return permutation; + } + + + /** + * Creates a permutation matrix with the specified column permutation. That is, a permutation matrix such that + * right multiplying it with another matrix results in permuting th columns of that matrix according to {@code colPermutation}. + * + * @param colPermutation Array specifying column permutation. The entry {@code x} at index {@code i} indicates that column + * {@code i} has been swapped with column {@code x}. * {@link ValidateParameters#ensurePermutation(int...) permutation array}. + * @return A permutation matrix that when right multiplied to a matrix results in permuting the columns of that matrix according + * to {@code colPermutation}. + * @throws IllegalArgumentException If {@code colPermutation} is not a valid permutation. */ - public static PermutationMatrix fromColSwaps(int[] colSwaps) { - int[] rowPerm = new int[colSwaps.length]; + public static PermutationMatrix fromColSwaps(int[] colPermutation) { + int[] rowPerm = new int[colPermutation.length]; - for (int i=0; i 0) { + value = String.valueOf(arr[arr.length-1]); + width = PrintOptions.getPadding() + value.length(); + value = PrintOptions.useCentering() ? StringUtils.center(value, width) : value; + result.append(String.format("%-" + width + "s", value)); + } return result.append("]").toString(); } diff --git a/src/main/java/org/flag4j/linalg/Condition.java b/src/main/java/org/flag4j/linalg/Condition.java index 260275b5f..ac2b7cdff 100644 --- a/src/main/java/org/flag4j/linalg/Condition.java +++ b/src/main/java/org/flag4j/linalg/Condition.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,6 +30,8 @@ import org.flag4j.linalg.decompositions.svd.ComplexSVD; import org.flag4j.linalg.decompositions.svd.RealSVD; +import java.util.function.BiFunction; + /** *

    Utility class for computing the condition number of a matrix. @@ -51,96 +53,183 @@ *

    When the L2-norm is used to compute the condition number then, *

      *     cond(A) = σmax(A) / σmin(A)
    - * where σmax(A) and σmin(A) are the maximum and minimum singular values of A. + * where σmax(A) and σmin(A) are the maximum and minimum singular values of the matrix A. + * + *

    This class supports the computation of the condition number of a real or complex matrix using the following norms. + *

      + *
    • Induced (operator) norm: {@link #cond(Matrix, double)} and {@link #cond(CMatrix, double)}.
    • + *
    • Schatten norm: {@link #condSchatten(Matrix, double)} and {@link #condSchatten(CMatrix, double)}.
    • + *
    • Frobenius norm: {@link #condFro(Matrix)} and {@link #condFro(CMatrix)}.
    • + *
    • Entry-wise norm: {@link #condEntryWise(Matrix, double)} and {@link #condEntryWise(CMatrix, double)}
    • + *
    */ public final class Condition { private Condition() { // Hide default constructor for utility class + } + + /** + *

    Computes the condition number of a matrix. + * + *

    This method computes the condition number using the matrix operator norm induced by the + * vector p-norm ({@link MatrixNorms#inducedNorm(Matrix, double)}). {@code p} must be one of the following: + *

      + *
    • {@code p=1}: Maximum absolute column sum of the matrix.
    • + *
    • {@code p=-1}: Minimum absolute column sum of the matrix.
    • + *
    • {@code p=2}: Spectral norm. Equivalent to the maximum singular value of the matrix.
    • + *
    • {@code p=-2}: Equivalent to the minimum singular value of the matrix.
    • + *
    • {@code p=Double.POSITIVE_INFINITY}: Maximum absolute row sum of the matrix.
    • + *
    • {@code p=Double.NEGATIVE_INFINITY}: Minimum absolute row sum of the matrix.
    • + *
    + * When {@code p < 1}, the "norm" is not a true mathematical norm but may still serve useful numerical purposes. + * + *

    To compute the condition number using other norms see one the below methods: + *

      + *
    • Schatten norm: {@link #condSchatten(Matrix, double)}.
    • + *
    • Frobenius norm: {@link #condFro(Matrix)}
    • + *
    • Entry-wise norm: {@link #condEntryWise(Matrix, double)}
    • + *
    + * + * @param src The matrix to compute the condition number of. + * @param p The p-value to use in the induced norm during condition number computation. + * Must be one of the following: {@code 1}, {@code -1}, {@code 2}, {@code -2}, {@link Double#POSITIVE_INFINITY} or + * {@link Double#NEGATIVE_INFINITY}. + * @return The condition number of {@code src} as computed using the matrix operator norm induced by vector p-norm. + */ + public static double cond(Matrix src, double p) { + if(p == 2.0 || p == -2.0) { + // Special case for spectral norm. No need to invert matrix explicitly. + Vector s = new RealSVD(false).decompose(src).getSingularValues(); + return (p==2) ? s.max()/s.min() : s.min()/s.max(); + } + + return cond(src, p, MatrixNorms::inducedNorm); } /** - * Computes the condition number of this matrix using the 2-norm. - * Specifically, the condition number is computed as the norm of this matrix multiplied by the norm - * of the inverse of this matrix. + *

    Computes the condition number of a matrix. + * + *

    This method computes the condition number using the matrix operator norm induced by the + * vector p-norm ({@link MatrixNorms#inducedNorm(CMatrix, double)}). {@code p} must be one of the following: + *

      + *
    • {@code p=1}: Maximum absolute column sum of the matrix.
    • + *
    • {@code p=-1}: Minimum absolute column sum of the matrix.
    • + *
    • {@code p=2}: Spectral norm. Equivalent to the maximum singular value of the matrix.
    • + *
    • {@code p=-2}: Equivalent to the minimum singular value of the matrix.
    • + *
    • {@code p=Double.POSITIVE_INFINITY}: Maximum absolute row sum of the matrix.
    • + *
    • {@code p=Double.NEGATIVE_INFINITY}: Minimum absolute row sum of the matrix.
    • + *
    + * When {@code p < 1}, the "norm" is not a true mathematical norm but may still serve useful numerical purposes. * + *

    To compute the condition number using other norms see one the below methods: + *

      + *
    • Schatten norm: {@link #condSchatten(CMatrix, double)}.
    • + *
    • Frobenius norm: {@link #condFro(CMatrix)}
    • + *
    • Entry-wise norm: {@link #condEntryWise(CMatrix, double)}
    • + *
    + * + * @param src The matrix to compute the condition number of. + * @param p The p-value to use in the induced norm during condition number computation. + * Must be one of the following: {@code 1}, {@code -1}, {@code 2}, {@code -2}, {@link Double#POSITIVE_INFINITY} or + * {@link Double#NEGATIVE_INFINITY}. + * @return The condition number of {@code src} as computed using the matrix operator norm induced by vector p-norm. + */ + public static double cond(CMatrix src, double p) { + if(p == 2.0 || p == -2.0) { + // Special case for spectral norm. No need to invert matrix explicitly. + Vector s = new ComplexSVD(false).decompose(src).getSingularValues(); + return (p==2) ? s.max()/s.min() : s.min()/s.max(); + } + + return cond(src, p, MatrixNorms::inducedNorm); + } + + + /** + * Computes the condition number of a matrix using the {@link MatrixNorms#schattenNorm(Matrix, double) Schatten norm}. * @param src Matrix to compute the condition number of. - * @return The condition number of this matrix (Assuming 2-norm). This value may be - * {@link Double#POSITIVE_INFINITY infinite}. + * @param p The p value in the Schatten norm. + * @return The condition number of {@code src}. */ - public static double cond(Matrix src) { - return cond(src, 2); + public static double condSchatten(Matrix src, double p) { + return cond(src, p, MatrixNorms::schattenNorm); } /** - * Computes the condition number of this matrix using a specified norm. The condition number of a matrix is defined - * as the norm of a matrix multiplied by the norm of the inverse of the matrix. + * Computes the condition number of a matrix using the {@link MatrixNorms#schattenNorm(CMatrix, double) Schatten norm}. * @param src Matrix to compute the condition number of. - * @param p Specifies the order of the norm to be used when computing the condition number. - * Common {@code p} values include:
    - * - {@code p} = {@link Double#POSITIVE_INFINITY}, {@link MatrixNorms#infNorm(Matrix)}.
    - * - {@code p} = 2, The standard matrix 2-norm (the largest singular value).
    - * - {@code p} = -2, The Smallest singular value.
    - * - {@code p} = 1, Maximum absolute row sum.
    - * @return The condition number of this matrix using the specified norm. This value may be - * {@link Double#POSITIVE_INFINITY infinite}. + * @param p The p value in the Schatten norm. + * @return The condition number of {@code src}. */ - public static double cond(Matrix src, double p) { - double cond; - - if(p==2 || p==-2) { - // Compute the singular value decomposition of the matrix. - Vector s = new RealSVD(false).decompose(src).getS().getDiag(); - cond = p==2 ? s.max()/s.min() : s.min()/s.max(); - } else { - cond = MatrixNorms.norm(src, p)*MatrixNorms.norm(Invert.inv(src), p); - } + public static double condSchatten(CMatrix src, double p) { + return cond(src, p, MatrixNorms::schattenNorm); + } - return cond; + + /** + * Computes the condition number of a matrix using the Frobenius norm. + * @param src Matrix to compute the condition number of. + * @return The condition number of {@code src}. + */ + public static double condFro(Matrix src) { + return cond(src, 2, MatrixNorms::schattenNorm); } /** - * Computes the condition number of this matrix using the 2-norm. - * Specifically, the condition number is computed as the norm of this matrix multiplied by the norm - * of the inverse of this matrix. - * + * Computes the condition number of a matrix using the Frobenius norm. * @param src Matrix to compute the condition number of. - * @return The condition number of this matrix (Assuming 2-norm). This value may be - * {@link Double#POSITIVE_INFINITY infinite}. + * @return The condition number of {@code src}. */ - public static double cond(CMatrix src) { - return cond(src,2); + public static double condFro(CMatrix src) { + return cond(src, 2, MatrixNorms::schattenNorm); } /** - * Computes the condition number of this matrix using a specified norm. The condition number of a matrix is defined - * as the norm of a matrix multiplied by the norm of the inverse of the matrix. + * Computes the condition number of a matrix using an {@link MatrixNorms#entryWiseNorm(Matrix, double) entry-wise norm}. * @param src Matrix to compute the condition number of. - * @param p Specifies the order of the norm to be used when computing the condition number. - * Common {@code p} values include:
    - * - {@code p} = {@link Double#POSITIVE_INFINITY}, {@link MatrixNorms#infNorm(CMatrix)}.
    - * - {@code p} = 2, The standard matrix 2-norm (the largest singular value).
    - * - {@code p} = -2, The Smallest singular value.
    - * - {@code p} = 1, Maximum absolute row sum.
    - * @return The condition number of this matrix using the specified norm. This value may be - * {@link Double#POSITIVE_INFINITY infinite}. + * @return The condition number of {@code src}. */ - public static double cond(CMatrix src, double p) { - double cond; - - if(p==2 || p==-2) { - // Compute the singular value decomposition of the matrix. - Vector s = new ComplexSVD(false).decompose(src).getS().getDiag(); - cond = p==2 ? s.max()/s.min() : s.min()/s.max(); - } else { - cond = MatrixNorms.norm(src, p) * MatrixNorms.norm(Invert.inv(src), p); - } + public static double condEntryWise(Matrix src, double p) { + return cond(src, p, MatrixNorms::entryWiseNorm); + } + + + /** + * Computes the condition number of a matrix using an {@link MatrixNorms#entryWiseNorm(CMatrix, double) entry-wise norm}. + * @param src Matrix to compute the condition number of. + * @return The condition number of {@code src}. + */ + public static double condEntryWise(CMatrix src, double p) { + return cond(src, p, MatrixNorms::entryWiseNorm); + } + + + /** + * Computes the condition number of a matrix using the specified norm. + * @param src Matrix to compute the condition number of. + * @param p The p-value in the norm. + * @param norm The norm to apply when computing the condition number. + * @return + */ + private static double cond(Matrix src, double p, BiFunction norm) { + return norm.apply(src, p) * norm.apply(Invert.inv(src), p); + } + - return cond; + /** + * Computes the condition number of a matrix using the specified norm. + * @param src Matrix to compute the condition number of. + * @param p The p-value in the norm. + * @param norm The norm to apply when computing the condition number. + * @return + */ + private static double cond(CMatrix src, double p, BiFunction norm) { + return norm.apply(src, p) * norm.apply(Invert.inv(src), p); } } diff --git a/src/main/java/org/flag4j/linalg/MatrixNorms.java b/src/main/java/org/flag4j/linalg/MatrixNorms.java index b662316a1..3191276dd 100644 --- a/src/main/java/org/flag4j/linalg/MatrixNorms.java +++ b/src/main/java/org/flag4j/linalg/MatrixNorms.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,66 +24,258 @@ package org.flag4j.linalg; -import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.algebraic_structures.Ring; import org.flag4j.arrays.Shape; +import org.flag4j.arrays.backend.ring_arrays.AbstractDenseRingMatrix; import org.flag4j.arrays.dense.CMatrix; import org.flag4j.arrays.dense.Matrix; +import org.flag4j.arrays.dense.Vector; import org.flag4j.arrays.sparse.CooCMatrix; import org.flag4j.arrays.sparse.CooMatrix; +import org.flag4j.arrays.sparse.CsrCMatrix; import org.flag4j.arrays.sparse.CsrMatrix; +import org.flag4j.linalg.decompositions.svd.ComplexSVD; +import org.flag4j.linalg.decompositions.svd.RealSVD; import org.flag4j.linalg.ops.common.real.RealProperties; import org.flag4j.linalg.ops.common.ring_ops.CompareRing; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldNorms; import org.flag4j.linalg.ops.sparse.coo.real.RealSparseNorms; +import org.flag4j.linalg.ops.sparse.coo.ring_ops.CooRingNorms; import org.flag4j.util.ValidateParameters; +import org.flag4j.util.exceptions.LinearAlgebraException; + +import java.util.function.Function; + /** - * Utility class containing static methods for computing norms of matrices. + *

    A utility class providing a range of matrix norm computations for both real and complex matrices.

    + * + *

    Overview

    + *

    This class includes static methods to compute:

    + *
      + *
    • Schatten norms (p-norms of the singular values), including: + *
        + *
      • Nuclear norm (p=1)
      • + *
      • Frobenius norm (p=2)
      • + *
      • Spectral norm (p = {@link Double#POSITIVE_INFINITY})
      • + *
      + *
    • + *
    • Induced norms (operator norms) for specific values of p: + *
        + *
      • p = 1 or -1 (maximum/minimum absolute column sums)
      • + *
      • p = 2 or -2 (largest/smallest singular value)
      • + *
      • p = {@link Double#POSITIVE_INFINITY} or {@link Double#NEGATIVE_INFINITY} + * (maximum/minimum absolute row sum)
      • + *
      + *
    • + *
    • Lp,q norms for both dense and sparse (COO/CSR) matrices.
    • + *
    • Common norms like the Frobenius norm, maximum absolute value (max norm), + * and infinite norm (maximum row sum) for real and complex matrices.
    • + *
    • Entry-wise p-norms, computed by flattening the matrix and computing the vector p-norm.
    • + *
    + * + *

    Example Usage:

    + *
    {@code
    + * Matrix A = ...; // some real matrix
    + * double fro = MatrixNorms.norm(A); // Frobenius norm
    + * double nuc = MatrixNorms.shattenNorm(A, 1.0); // nuclear norm
    + * double spec = MatrixNorms.inducedNorm(A, Double.POSITIVE_INFINITY); // spectral norm
    + * }
    + * + *

    For complex matrices, use the corresponding overloads that accept a {@code CMatrix} or other complex matrix type.

    */ public final class MatrixNorms { private MatrixNorms() { // Hide default constructor for utility class + } + + /** + *

    Computes the Schatten p-norm of a real dense matrix. This is equivalent to the p-norm of the vector of singular values of the + * matrix. + * + *

    This method accepts values of p which are negative. When {@code p < 0} the result is not a true norm but may still have + * numerical uses. + * + * @param src The matrix to compute the norm of. + * @param p The p value in the Schatten p-norm. Some common cases include: + *

      + *
    • {@code p=1}: The nuclear (or trace) norm. Equivalent to the sum of singular values.
    • + *
    • {@code p=2}: Frobenius (or L2, 2) norm. Equivalent to the square root of the sum of the absolute squares + * of all entries in the matrix.
    • + *
    • {@code p=Double.POSITIVE_INFINITY}: The spectral norm. Equivalent to the maximum singular value.
    • + *
    • {@code p=Double.NEGATIVE_INFINITY}: The minimum singular value.
    • + *
    + * @return The Schatten p-norm of {@code src}. + */ + public static double schattenNorm(Matrix src, double p) { + if(p == 1.0) { + return nuclearNorm(src); // Nuclear norm. + } else if(p == 2.0) { + return VectorNorms.norm(src.data); // Frobenius norm. + } else if(p == Double.POSITIVE_INFINITY) { + return svdBasedNorm(src, RealProperties::max); // Spectral norm. + } else { + Vector sigmas = new RealSVD(false, true).decompose(src).getSingularValues(); + return VectorNorms.norm(sigmas.data, p); + } } /** - * Computes the 2-norm of this tensor. This is equivalent to {@link #norm(Matrix, double) norm(2)}. - * This will be equal to the largest singular value of the matrix. + *

    Computes the Schatten p-norm of a complex dense matrix. This is equivalent to the p-norm of + * the vector of singular values of the matrix. * - * @param src Matrix to compute norm of. + * @param src The matrix to compute the norm of. + * @param p The p value in the Schatten p-norm. Must be greater than or equal to 1. Some common cases include: + *

      + *
    • {@code p=1}: The nuclear (or trace) norm. Equivalent to the sum of singular values.
    • + *
    • {@code p=2}: Frobenius (or L2, 2) norm. Equivalent to the square root of the sum of the absolute squares + * of all entries in the matrix.
    • + *
    • {@code p=Double.POSITIVE_INFINITY}: The spectral norm. Equivalent to the largest singular value.
    • + *
    + * @return The Schatten p-norm of {@code src}. + * @throws IllegalArgumentException If {@code p < 1}. + */ + public static double schattenNorm(CMatrix src, double p) { + ValidateParameters.ensureGreaterEq(1, p, "p"); + + if(p == 1.0) { + return nuclearNorm(src); // Nuclear norm. + } else if(p == 2.0) { + return VectorNorms.norm(src.data); // Frobenius norm. + } else if(p == Double.POSITIVE_INFINITY) { + return svdBasedNorm(src, RealProperties::max); // Spectral norm. + } else { + Vector sigmas = new ComplexSVD(false, true).decompose(src).getSingularValues(); + return VectorNorms.norm(sigmas.data, p); + } + } + + + /** + *

    Computes the matrix operator norm of a real dense matrix "induced" by the vector p-norm. + * Specifically, this method computes the operator norm of the matrix as: + *

    +     *     ||A||p = supx≠0(||Ax||p / ||x||p).
    * - * @return the 2-norm of this tensor. + *

    This method supports a limited set of {@code p} values which yield simple formulas. When {@code p < 1}, the result this method + * returns is not a true mathematical norm. However, these values may still be useful for numerical purposes. + *

      + *
    • {@code p=1}: The maximum absolute column sum.
    • + *
    • {@code p=-1}: The minimum absolute column sum.
    • + *
    • {@code p=2}: The spectral norm. Equivalent to the largest singular value of the matrix.
    • + *
    • {@code p=-2}: The smallest singular value of the matrix.
    • + *
    • {@code p=Double.POSITIVE_INFINITY}: The maximum absolute row sum.
    • + *
    • {@code p=Double.NEGATIVE_INFINITY}: The minimum absolute row sum.
    • + *
    + * + * @param src Matrix to compute the norm of. + * @param p The p value in the "induced" p-norm. Must be one of the following: {@code 1}, {@code -1}, {@code 2}, {@code -2}, + * {@link Double#POSITIVE_INFINITY} or {@link Double#NEGATIVE_INFINITY}. + * @return Norm of the matrix. + * @throws LinearAlgebraException If {@code p} is not one of the following: {@code 1}, {@code -1}, {@code 2}, {@code -2}, + * {@link Double#POSITIVE_INFINITY} or {@link Double#NEGATIVE_INFINITY}. */ - public static double norm(Matrix src) { - return TensorNorms.tensorNormL2(src.data); + public static double inducedNorm(Matrix src, double p) { + if(p == 1.0) { + return colBasedNorm(src.shape, src.data, RealProperties::max); + } else if(p == -1.0) { + return colBasedNorm(src.shape, src.data, RealProperties::min); + } else if(p == 2.0) { + return svdBasedNorm(src, RealProperties::max); + } else if(p == -2.0) { + return svdBasedNorm(src, RealProperties::min); + } else if(p == Double.POSITIVE_INFINITY) { + return rowBasedNorm(src.shape, src.data, RealProperties::max); + } else if(p == Double.NEGATIVE_INFINITY) { + return rowBasedNorm(src.shape, src.data, RealProperties::min); + } else { + throw new LinearAlgebraException("Unsupported norm type: p = " + p + ".\n" + + "Supported values are: 1, -1, 2, -2, Double.POSITIVE_INFINITY, and Double.NEGATIVE_INFINITY."); + } } /** - * Computes the p-norm of this tensor. Equivalent to calling {@link #norm(Matrix, double, double) norm(p, p)} + *

    Computes the matrix operator norm of a complex dense matrix "induced" by the vector p-norm. + * Specifically, this method computes the operator norm of the matrix as: + *

    +     *     ||A||p = supx≠0(||Ax||p / ||x||p).
    * - * @param src Matrix to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. - */ - public static double norm(Matrix src, double p) { - double norm; - - if(Double.isInfinite(p)) { - if(p > 0) { - norm = maxNorm(src); - } else { - norm = src.minAbs(); - } + *

    This method supports a limited set of {@code p} values which yield simple formulas. When {@code p < 1}, the result this method + * returns is not a true mathematical norm. However, these values may still be useful for numerical purposes. + *

      + *
    • {@code p=1}: The maximum absolute column sum.
    • + *
    • {@code p=-1}: The minimum absolute column sum.
    • + *
    • {@code p=2}: The spectral norm. Equivalent to the largest singular value of the matrix.
    • + *
    • {@code p=-2}: The smallest singular value of the matrix.
    • + *
    • {@code p=Double.POSITIVE_INFINITY}: The maximum absolute row sum.
    • + *
    • {@code p=Double.NEGATIVE_INFINITY}: The minimum absolute row sum.
    • + *
    + * + * @param src Matrix to compute the norm of. + * @param p The p value in the "induced" p-norm. Must be one of the following: {@code 1}, {@code -1}, {@code 2}, {@code -2}, + * {@link Double#POSITIVE_INFINITY} or {@link Double#NEGATIVE_INFINITY}. + * @return Norm of the matrix. + * @throws LinearAlgebraException If {@code p} is not one of the following: {@code 1}, {@code -1}, {@code 2}, {@code -2}, + * {@link Double#POSITIVE_INFINITY} or {@link Double#NEGATIVE_INFINITY}. + */ + public static double inducedNorm(CMatrix src, double p) { + if(p == 1.0) { + return colBasedNorm(src.shape, src.data, RealProperties::max); + } else if(p == -1.0) { + return colBasedNorm(src.shape, src.data, RealProperties::min); + } else if(p == 2.0) { + return svdBasedNorm(src, RealProperties::max); + } else if(p == -2.0) { + return svdBasedNorm(src, RealProperties::min); + } else if(p == Double.POSITIVE_INFINITY) { + return rowBasedNorm(src.shape, src.data, RealProperties::max); + } else if(p == Double.NEGATIVE_INFINITY) { + return rowBasedNorm(src.shape, src.data, RealProperties::min); } else { - norm = TensorNorms.tensorNormLp(src.data, p); + throw new LinearAlgebraException("Unsupported norm type: p = " + p + ".\n" + + "Supported values are: 1, -1, 2, -2, Double.POSITIVE_INFINITY, and Double.NEGATIVE_INFINITY."); } + } - return norm; + + /** + *

    Computes the Frobenius (or L2, 2) norm of a real dense matrix. + * + *

    The Frobenius norm is defined as the square root of the sum of absolute squares of all entries in the matrix. + * + *

    This method is equivalent to {@link #norm(Matrix, double, double) norm(src, 2, 2)}. + * However, this method should generally be preferred over + * {@link #norm(Matrix, double, double)} as it may be slightly more efficient. + * + * @param src Matrix to compute theFrobenius norm of. + * + * @return the Frobenius of this tensor. + * @see #norm(Matrix, double, double) + */ + public static double norm(Matrix src) { + return VectorNorms.norm(src.data); + } + + + /** + *

    Computes the Frobenius (or L2, 2) norm of a real dense matrix. + * + *

    The Frobenius norm is defined as the square root of the sum of absolute squares of all entries in the matrix. + * + *

    This method is equivalent to {@link #norm(AbstractDenseRingMatrix, double, double) norm(src, 2, 2)}. + * However, this method should generally be preferred over + * {@link #norm(AbstractDenseRingMatrix, double, double)} as it may be slightly more efficient. + * + * @param src Matrix to compute theFrobenius norm of. + * + * @return the Frobenius of this tensor. + * @see #norm(AbstractDenseRingMatrix, double, double) + */ + public static double norm(AbstractDenseRingMatrix src) { + return VectorNorms.norm(src.data); } @@ -95,114 +287,155 @@ public static double norm(Matrix src, double p) { * @see #infNorm(Matrix) */ public static double maxNorm(Matrix src) { - return matrixMaxNorm(src.data); + return RealProperties.maxAbs(src.data); } /** - * Computes the infinite norm of this matrix. that is the maximum row sum in the matrix. + * Computes the maximum norm of this matrix. That is, the maximum value in the matrix. * * @param src Matrix to compute norm of. - * @return The infinite norm of this matrix. - * @see #maxNorm(Matrix) + * @return The maximum norm of this matrix. + * @see #infNorm(AbstractDenseRingMatrix) */ - public static double infNorm(Matrix src) { - return matrixInfNorm(src.data, src.shape); + public static double maxNorm(AbstractDenseRingMatrix src) { + return CompareRing.maxAbs(src.data); } /** - * Computes the Lp, q norm of this matrix. + * Computes the infinite norm of this matrix. That is the maximum row sum in the matrix. * - * @param p P value in the Lp, q norm. - * @param q Q value in the Lp, q norm. - * @return The Lp, q norm of this matrix. + * @param src Matrix to compute norm of. + * @return The infinite norm of this matrix. + * @see #maxNorm(Matrix) */ - public static double norm(Matrix src, double p, double q) { - return matrixNormLpq(src.data, src.shape, p, q); + public static double infNorm(Matrix src) { + return rowBasedNorm(src.shape, src.data, RealProperties::max); } /** - * Computes the Lp, q norm of this matrix. + * Computes the infinite norm of this matrix. That is the maximum row sum in the matrix. * * @param src Matrix to compute norm of. - * @param p P value in the Lp, q norm. - * @param q Q value in the Lp, q norm. - * @return The Lp, q norm of this matrix. + * @return The infinite norm of this matrix. + * @see #maxNorm(Matrix) */ - public static double norm(CMatrix src, double p, double q) { - return matrixNormLpq(src.data, src.shape, p, q); + public static double infNorm(AbstractDenseRingMatrix src) { + return rowBasedNorm(src.shape, src.data, RealProperties::max); } /** - * Computes the max norm of a matrix. + *

    Computes the Lp, q norm of a real dense matrix. + *

    Some common special cases are: + *

      + *
    • {@code p=2}, {@code q=1}: The sum of Euclidean norms of the column vectors of the matrix.
    • + *
    • {@code p=2}, {@code q=2}: The Frobenius norm. Equivalent to the Euclidean norm of the vector of singular values of + * the matrix.
    • + *
    * - * @param src Matrix to compute norm of. - * @return The max norm of this matrix. + *

    The Lp, q norm is computed as if by: + *

    {@code
    +     *      double norm = 0;
    +     *      for(int j=0; j
    +     *
    +     * @param p p value in the Lp, q norm.
    +     * @param q q value in the Lp, q norm.
    +     * @return The Lp, q norm of {@code src}.
          */
    -    public static double maxNorm(CMatrix src) {
    -        return matrixMaxNorm(src.data);
    +    public static double norm(Matrix src, double p, double q) {
    +        if(p == q) return VectorNorms.norm(src.data, p);
    +        return matrixNormLpq(src.data, src.shape, p, q);
         }
     
     
         /**
    -     * Computes the 2-norm of this tensor. This is equivalent to {@link #norm(CMatrix, double) norm(2)}.
    +     * 

    Computes the Lp, q norm of a real dense matrix. + *

    Some common special cases are: + *

      + *
    • {@code p=2}, {@code q=1}: The sum of Euclidean norms of the column vectors of the matrix.
    • + *
    • {@code p=2}, {@code q=2}: The Frobenius norm. Equivalent to the Euclidean norm of the vector of singular values of + * the matrix.
    • + *
    * - * @param src Matrix to compute norm of. - * @return the 2-norm of this tensor. + *

    The Lp, q norm is computed as if by: + *

    {@code
    +     *      double norm = 0;
    +     *      for(int j=0; j
    +     *
    +     * @param p p value in the Lp, q norm.
    +     * @param q q value in the Lp, q norm.
    +     * @return The Lp, q norm of {@code src}.
          */
    -    public static double norm(CMatrix src) {
    -        return matrixNormL2(src.data, src.shape);
    +    public static double norm(AbstractDenseRingMatrix src, double p, double q) {
    +        if(p == q) return VectorNorms.norm(src.data, p);
    +        return matrixNormLpq(src.data, src.shape, p, q);
         }
     
     
         /**
    -     * Computes the p-norm of this tensor.
    +     * 

    Computes the entry-wise p-norm of a real dense matrix. * - * @param src Matrix to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. - */ - public static double norm(CMatrix src, double p) { - double norm; - - if(Double.isInfinite(p)) { - if(p > 0) { - norm = maxNorm(src); - } else { - norm = src.minAbs(); - } - } else { - norm = matrixNormLp(src.data, src.shape, p); - } - - return norm; + *

    The entry-wise p-norm of a matrix is equivalent to the + * vector ℓp norm computed on the flattened matrix as if by {@code src.toVector().norm(p);}. + * + * @param src The matrix to compute the entry-wise norm of. + * @param p The p value in the ℓp vector norm. + * @return The entry-wise norm of {@code src}. + */ + public static double entryWiseNorm(Matrix src, double p) { + return VectorNorms.norm(src.data, p); } /** - * Computes the maximum/infinite norm of this tensor. + *

    Computes the entry-wise p-norm of a complex dense matrix. * - * @param src Matrix to compute norm of. - * @return The maximum/infinite norm of this tensor. + *

    The entry-wise p-norm of a matrix is equivalent to the + * vector ℓp norm computed on the flattened matrix as if by {@code src.toVector().norm(p);}. + * + * @param src The matrix to compute the entry-wise norm of. + * @param p The p value in the ℓp vector norm. + * @return The entry-wise norm of {@code src}. */ - public static double infNorm(CMatrix src) { - return matrixInfNorm(src.data, src.shape); + public static double entryWiseNorm(CMatrix src, double p) { + return VectorNorms.norm(src.data, p); } // ------------------------------ Sparse COO Matrices ------------------------------ /** - * Computes the Lp, q norm of this matrix. + *

    Computes the Lp, q norm of a real COO matrix. + *

    Some common special cases are: + *

      + *
    • {@code p=2}, {@code q=1}: The sum of Euclidean norms of the column vectors of the matrix.
    • + *
    • {@code p=2}, {@code q=2}: The Frobenius norm. Equivalent to the Euclidean norm of the vector of singular values of + * the matrix.
    • + *
    * - * @param src Matrix to compute norm of. - * @param p P value in the Lp, q norm. - * @param q Q value in the Lp, q norm. - * @return The Lp, q norm of this matrix. + * @param p p value in the Lp, q norm. + * @param q q value in the Lp, q norm. + * @return The Lp, q norm of {@code src}. */ public static double norm(CooMatrix src, double p, double q) { // Sparse implementation is usually only faster for very sparse matrices. @@ -212,21 +445,31 @@ public static double norm(CooMatrix src, double p, double q) { /** - * Computes the max norm of a matrix. + *

    Computes the Lp, q norm of a complex COO matrix. + *

    Some common special cases are: + *

      + *
    • {@code p=2}, {@code q=1}: The sum of Euclidean norms of the column vectors of the matrix.
    • + *
    • {@code p=2}, {@code q=2}: The Frobenius norm. Equivalent to the Euclidean norm of the vector of singular values of + * the matrix.
    • + *
    * - * @param src Matrix to compute norm of. - * @return The max norm of this matrix. + * @param p p value in the Lp, q norm. + * @param q q value in the Lp, q norm. + * @return The Lp, q norm of {@code src}. */ - public static double maxNorm(CooMatrix src) { - return matrixMaxNorm(src.data); + public static double norm(CooCMatrix src, double p, double q) { + // Sparse implementation is usually only faster for very sparse matrices. + return src.sparsity()>=0.95 ? CooRingNorms.matrixNormLpq(src, p, q) : + norm(src.toDense(), p, q); } /** - * Computes the 2-norm of this tensor. This is equivalent to {@link #norm(CooMatrix, double) norm(2)}. + * Computes the Frobenius (L2, 2) norm of this complex COO matrix. This is equivalent to + * {@link #norm(CooMatrix, double, double) norm(src, 2, 2)}. * - * @param src Matrix to compute norm of. - * @return the 2-norm of this tensor. + * @param src Matrix to compute the L2, 2 norm of. + * @return the Frobenius (L2, 2) norm of this tensor. */ public static double norm(CooMatrix src) { // Sparse implementation is usually only faster for very sparse matrices. @@ -236,33 +479,27 @@ public static double norm(CooMatrix src) { /** - * Computes the p-norm of this tensor. + * Computes the Frobenius (L2, 2) norm of this complex COO matrix. This is equivalent to + * {@link #norm(CooCMatrix, double, double) norm(src, 2, 2)}. * - * @param src Matrix to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. + * @param src Matrix to compute the L2, 2 norm of. + * @return the Frobenius (L2, 2) norm of this tensor. */ - public static double norm(CooMatrix src, double p) { + public static double norm(CooCMatrix src) { // Sparse implementation is usually only faster for very sparse matrices. - return src.sparsity()>=0.95 ? RealSparseNorms.matrixNormLp(src, p) : - norm(src.toDense(), p); + return src.sparsity()>=0.95 ? CooRingNorms.matrixNormL22(src) : + norm(src.toDense()); } /** - * Computes the Lp, q norm of this matrix. + * Computes the max norm of a matrix. * * @param src Matrix to compute norm of. - * @param p P value in the Lp, q norm. - * @param q Q value in the Lp, q norm. - * @return The Lp, q norm of this matrix. + * @return The max norm of this matrix. */ - public static double norm(CooCMatrix src, double p, double q) { - // Sparse implementation is usually only faster for very sparse matrices. - return src.sparsity()>=0.95 ? CooFieldNorms.matrixNormLpq(src, p, q) : - norm(src.toDense(), p, q); + public static double maxNorm(CooCMatrix src) { + return CompareRing.maxAbs(src.data); } @@ -272,52 +509,97 @@ public static double norm(CooCMatrix src, double p, double q) { * @param src Matrix to compute norm of. * @return The max norm of this matrix. */ - public static double maxNorm(CooCMatrix src) { - return matrixMaxNorm(src.data); + public static double maxNorm(CooMatrix src) { + return RealProperties.maxAbs(src.data); } + // ------------------------------ Sparse CSR Matrices ------------------------------ /** - * Computes the 2-norm of this tensor. This is equivalent to {@link #norm(CooCMatrix, double) norm(2)}. + *

    Computes the Lp, q norm of a real CSR matrix. + *

    Some common special cases are: + *

      + *
    • {@code p=2}, {@code q=1}: The sum of Euclidean norms of the column vectors of the matrix.
    • + *
    • {@code p=2}, {@code q=2}: The Frobenius norm. Equivalent to the Euclidean norm of the vector of singular values of + * the matrix.
    • + *
    * - * @param src Matrix to compute the norm. - * @return the 2-norm of this tensor. + * @param p p value in the Lp, q norm. + * @param q q value in the Lp, q norm. + * @return The Lp, q norm of {@code src}. */ - public static double norm(CooCMatrix src) { - // Sparse implementation is usually only faster for very sparse matrices. - return src.sparsity()>=0.95 ? CooFieldNorms.matrixNormL2(src) : - norm(src.toDense()); + public static double norm(CsrMatrix src, double p, double q) { + if(p == 0 || q == 0) + throw new IllegalArgumentException("p and q must be non-zero for norm."); + + double norm = 0; + double qOverP = q / p; + + // stores intermediate column norms. + double[] colNorms = new double[src.numCols]; + + // Accumulate column-wise norms. + for (int row = 0; row < src.numRows; row++) { + int start = src.rowPointers[row]; + int end = src.rowPointers[row + 1]; + + for (int idx = start; idx < end; idx++) { + int col = src.colIndices[idx]; + double value = src.data[idx]; + + colNorms[col] += Math.pow(Math.abs(value), p); + } + } + + // Compute the q-norm of the column norms. + for (double colNorm : colNorms) + if (colNorm > 0) norm += Math.pow(colNorm, qOverP); + + return Math.pow(norm, 1.0 / q); } /** - * Computes the p-norm of this tensor. + *

    Computes the Lp, q norm of a complex CSR matrix. + *

    Some common special cases are: + *

      + *
    • {@code p=2}, {@code q=1}: The sum of Euclidean norms of the column vectors of the matrix.
    • + *
    • {@code p=2}, {@code q=2}: The Frobenius norm. Equivalent to the Euclidean norm of the vector of singular values of + * the matrix.
    • + *
    * - * @param src Matrix to compute the norm. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. + * @param p p value in the Lp, q norm. + * @param q q value in the Lp, q norm. + * @return The Lp, q norm of {@code src}. */ - public static double norm(CooCMatrix src, double p) { - // Sparse implementation is usually only faster for very sparse matrices. - return src.sparsity()>=0.95 ? CooFieldNorms.matrixNormLp(src, p) : - norm(src.toDense(), p); - } + public static double norm(CsrCMatrix src, double p, double q) { + if(p == 0 || q == 0) + throw new IllegalArgumentException("p and q must be non-zero for norm."); + double norm = 0; + double qOverP = q / p; - // CSR Matrices + // stores intermediate column norms. + double[] colNorms = new double[src.numCols]; - /** - * Computes the Lp, q norm of this matrix. - * - * @param src Matrix to compute norm of. - * @param p P value in the Lp, q norm. - * @param q Q value in the Lp, q norm. - * @return The Lp, q norm of this matrix. - */ - public static double norm(CsrMatrix src, double p, double q) { - return matrixNormLpq(src, p, q); + // Accumulate column-wise norms. + for (int row = 0; row < src.numRows; row++) { + int start = src.rowPointers[row]; + int end = src.rowPointers[row + 1]; + + for (int idx = start; idx < end; idx++) { + int col = src.colIndices[idx]; + double value = src.data[idx].mag(); + + colNorms[col] += Math.pow(value, p); + } + } + + // Compute the q-norm of the column norms. + for (double colNorm : colNorms) + if (colNorm > 0) norm += Math.pow(colNorm, qOverP); + + return Math.pow(norm, 1.0 / q); } @@ -328,66 +610,73 @@ public static double norm(CsrMatrix src, double p, double q) { * @return The max norm of this matrix. */ public static double maxNorm(CsrMatrix src) { - return matrixMaxNorm(src.data); + return RealProperties.maxAbs(src.data); } /** - * Computes the 2-norm of this tensor. This is equivalent to {@link #norm(CsrMatrix, double) norm(2)}. + * Computes the max norm of a matrix. * - * @param src Matrix to compute the norm of. - * @return the 2-norm of this tensor. + * @param src Matrix to compute norm of. + * @return The max norm of this matrix. */ - public double norm(CsrMatrix src) { - return TensorNorms.tensorNormL2(src.data); // Zeros do not contribute to this norm. + public static double maxNorm(CsrCMatrix src) { + return CompareRing.maxAbs(src.data); } /** - * Computes the p-norm of this tensor. + * Computes the Frobenius (L2, 2) of this matrix. This is equivalent to {@link #norm(CsrMatrix, double, double) norm + * (src, 2, 2)}. * - * @param src Matrix to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. + * @param src Matrix to compute the norm of. + * @return the Frobenius of this matrix. */ - public double norm(CsrMatrix src, double p) { - return TensorNorms.tensorNormLp(src.data, p); // Zeros do not contribute to this norm. + public static double norm(CsrMatrix src) { + return VectorNorms.norm(src.data); // Zeros do not contribute to this norm. } - // -------------------------------------------------- Low-level implementations -------------------------------------------------- /** - * Computes the infinity/maximum norm of a matrix. That is, the maximum value in this matrix. - * @param src Entries of the matrix. - * @return The infinity norm of the matrix. + * Computes the Frobenius of this matrix. This is equivalent to {@link #norm(CsrCMatrix, double, double) norm(src, 2, 2)}. + * + * @param src Matrix to compute the norm of. + * @return the Frobenius of this matrix. */ - private static double matrixMaxNorm(double[] src) { - return RealProperties.maxAbs(src); + public static double norm(CsrCMatrix src) { + return VectorNorms.norm(src.data); // Zeros do not contribute to this norm. } + // -------------------------------------------------- Low-level implementations -------------------------------------------------- /** - * Computes the infinity/maximum norm of a matrix. That is, the maximum value in this matrix. + * Compute the Lp, q norm of a matrix. * @param src Entries of the matrix. * @param shape Shape of the matrix. - * @return The infinity norm of the matrix. + * @param p First parameter in Lp, q norm. + * @param q Second parameter in Lp, q norm. + * @return The Lp, q norm of the matrix. */ - private static double matrixInfNorm(double[] src, Shape shape) { + private static double matrixNormLpq(double[] src, Shape shape, double p, double q) { + if(p == 0 || q == 0) + throw new LinearAlgebraException("p and q must be non-zero for norm but got p=" + p + " and q=" + q + "."); + + double norm = 0; + double colSum; int rows = shape.get(0); int cols = shape.get(1); - double[] rowSums = new double[rows]; - for(int i=0; ip, q norm of a matrix. * @param src Entries of the matrix. @@ -395,10 +684,10 @@ private static double matrixInfNorm(double[] src, Shape shape) { * @param p First parameter in Lp, q norm. * @param q Second parameter in Lp, q norm. * @return The Lp, q norm of the matrix. - * @throws IllegalArgumentException If {@code p} or {@code q} is less than 1. */ - private static double matrixNormLpq(Complex128[] src, Shape shape, double p, double q) { - ValidateParameters.ensureGreaterEq(1, p, q); + private static > double matrixNormLpq(T[] src, Shape shape, double p, double q) { + if(p == 0 || q == 0) + throw new IllegalArgumentException("p and q must be non-zero for norm but got p=" + p + " and q=" + q + "."); double norm = 0; double colSum; @@ -406,154 +695,162 @@ private static double matrixNormLpq(Complex128[] src, Shape shape, double p, dou int cols = shape.get(1); for(int j=0; jp norm of a matrix. This is equivalent to passing {@code q=1} to - * {@link #matrixNormLpq(Complex128[], Shape, double, double)} - * @param src Entries of the matrix. + * Helper method for computing a matrix norm which is based on the absolute row sums. + * * @param shape Shape of the matrix. - * @param p Parameter in Lp norm. - * @return The Lp norm of the matrix. - * @throws IllegalArgumentException If {@code p} is less than 1. + * @param src Entries of the matrix. + * @param aggregator Operation to apply to absolute row sums. + * + * @return The row-based matrix norm. */ - private static double matrixNormLp(Complex128[] src, Shape shape, double p) { - ValidateParameters.ensureGreaterEq(1, p); - - double norm = 0; - double colSum; + private static double rowBasedNorm(Shape shape, double[] src, Function aggregator) { int rows = shape.get(0); int cols = shape.get(1); + double[] rowSums = new double[rows]; - for(int j=0; j2 norm of a matrix. This is equivalent to passing {@code q=1} to - * {@link #matrixNormLpq(Complex128[], Shape, double, double)} - * @param src Entries of the matrix. + * Helper method for computing a matrix norm which is based on the absolute row sums. + * * @param shape Shape of the matrix. - * @return The L2 norm of the matrix. + * @param src Entries of the matrix. + * @param aggregator Operation to apply to absolute row sums. + * + * @return The row-based matrix norm. */ - private static double matrixNormL2(Complex128[] src, Shape shape) { - double norm = 0; + private static > double rowBasedNorm(Shape shape, T[] src, Function aggregator) { int rows = shape.get(0); int cols = shape.get(1); + double[] rowSums = new double[rows]; - double colSum; - - for(int j=0; j aggregator) { + Vector sigmas = new RealSVD(false, true).decompose(src).getSingularValues(); + return aggregator.apply(sigmas.data); } /** - * Computes the infinity/maximum norm of a matrix. That is, the maximum absolute value in this matrix. + * Helper for computing an SVD based norm of a complex dense matrix. + * + * @param src Matrix for which to compute SVD based norm. + * @param aggregator Operation to apply to the vector of singular values. + * + * @return The result of applying the {@code aggregator} function to the singular values of {@code src}. + */ + private static double svdBasedNorm(CMatrix src, Function aggregator) { + Vector sigmas = new ComplexSVD(false, true).decompose(src).getSingularValues(); + return aggregator.apply(sigmas.data); + } + + + /** + * Helper method for computing a matrix norm which is based on the absolute column sums. + * + * @param shape Shape of the matrix. * @param src Entries of the matrix. - * @return The infinity norm of the matrix. + * @param aggregator Operation to apply to absolute column sums. + * + * @return The column-based matrix norm. */ - private static double matrixInfNorm(Complex128[] src, Shape shape) { + private static double colBasedNorm(Shape shape, double[] src, Function aggregator) { int rows = shape.get(0); int cols = shape.get(1); - double[] rowSums = new double[rows]; + double[] colSums = new double[cols]; for(int i=0; ip,q norm of a sparse CSR matrix. - * @param src Sparse CSR matrix to compute norm of. - * @return The Lp,q norm of the matrix. + * Helper method for computing a matrix norm which is based on the absolute column sums. + * + * @param shape Shape of the matrix. + * @param src Entries of the matrix. + * @param aggregator Operation to apply to absolute column sums. + * + * @return The column-based matrix norm. */ - public static double matrixNormLpq(CsrMatrix src, double p, double q) { - CsrMatrix tSrc = src.T(); - double norm = 0; - double pOverQ = p/q; - - for(int i=0; i> double colBasedNorm(Shape shape, T[] src, Function aggregator) { + int rows = shape.get(0); + int cols = shape.get(1); + double[] colSums = new double[cols]; - norm += Math.pow(colNorm, pOverQ); + for(int i=0; ip, q norm of a matrix. - * @param src Entries of the matrix. - * @param shape Shape of the matrix. - * @param p First parameter in Lp, q norm. - * @param q Second parameter in Lp, q norm. - * @return The Lp, q norm of the matrix. - * @throws IllegalArgumentException If {@code p} or {@code q} is less than 1. + * Computes the nuclear norm of a real dense matrix. Equivalent to the sum of singular values of the matrix. + * @param src The matrix to compute the norm of. + * @return The nuclear norm of {@code src}. */ - public static double matrixNormLpq(double[] src, Shape shape, double p, double q) { - ValidateParameters.ensureGreaterEq(1, p, q); - - double norm = 0; - double colSum; - int rows = shape.get(0); - int cols = shape.get(1); + private static double nuclearNorm(Matrix src) { + return new RealSVD(false, true) + .decompose(src) + .getSingularValues() + .sum(); + } - for(int j=0; j tol) { // Then the least squares solution does not provide an "exact" solution. // Hence, the column of src2 cannot be expressed as a linear combination of the columns of src1 diff --git a/src/main/java/org/flag4j/linalg/TensorNorms.java b/src/main/java/org/flag4j/linalg/TensorNorms.java deleted file mode 100644 index 20c45a9ec..000000000 --- a/src/main/java/org/flag4j/linalg/TensorNorms.java +++ /dev/null @@ -1,240 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg; - -import org.flag4j.algebraic_structures.Complex128; -import org.flag4j.arrays.dense.CTensor; -import org.flag4j.arrays.dense.Tensor; -import org.flag4j.arrays.sparse.CooCTensor; -import org.flag4j.arrays.sparse.CooTensor; -import org.flag4j.linalg.ops.common.ring_ops.CompareRing; -import org.flag4j.util.ValidateParameters; - -/** - * This utility class provides static methods useful for computing norms of a tensor. - */ -public final class TensorNorms { - - private TensorNorms() { - // Hide default constructor for utility class - - } - - // TODO: Ensure the below infNorm methods correct? These seems to be a max norm. Are the same? -// /** -// * Computes the infinity norm of a tensor, matrix, or vector. That is, the largest absolute value. -// * @param src The tensor, matrix, or vector to compute the norm of. -// * @return The infinity norm of the source tensor, matrix, or vector. -// */ -// public static double infNorm(DoubleTensorBase src) { -// return src.maxAbs(); -// } - - -// /** -// * Computes the infinity norm of a tensor, matrix, or vector. That is, the largest value by magnitude. -// * @param src The tensor, matrix, or vector to compute the norm of. -// * @return The infinity norm of the source tensor, matrix, or vector. -// */ -// public static double infNorm(FieldTensorBase src) { -// return src.maxAbs(); -// } - - - /** - * Computes the 2-norm of this tensor as if the tensor was a vector (i.e. as if by {@code VectorNorm(Tensor.toVector())}). - * This is equivalent to {@link #norm(Tensor, double) norm(src, 2)}. - * - * @param src Tensor to compute norm of. - * @return the 2-norm of this tensor. - */ - public static double norm(Tensor src) { - return tensorNormL2(src.data); - } - - - /** - * Computes the p-norm of this tensor. - * - * @param src Tensor to compute norm of. - * @param p The {@code p} value in the p-norm.
    - * - If {@code p} is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. - */ - public static double norm(Tensor src, double p) { - return tensorNormLp(src.data, p); - } - - - /** - * Computes the 2-norm of this tensor. This is equivalent to {@link #norm(CTensor, double) norm(src, 2)}. - * - * @return the 2-norm of this tensor. - */ - public double norm(CTensor src) { - return tensorNormL2(src.data); - } - - - /** - * Computes the p-norm of this tensor. - * - * @param src Tensor to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. - */ - public double norm(CTensor src, double p) { - return tensorNormLp(src.data, p); - } - - - /** - * Computes the maximum/infinite norm of this tensor. - * - * @param src Tensor to compute norm of. - * @return The maximum/infinite norm of this tensor. - */ - public double infNorm(CTensor src) { - return CompareRing.maxAbs(src.data); - } - - - /** - * Computes the 2-norm of this tensor. This is equivalent to {@link #norm(CooTensor, double) norm(src, 2)}. - * - * @param src Tensor to compute norm of. - * @return the 2-norm of this tensor. - */ - public static double norm(CooTensor src) { - return tensorNormL2(src.data); - } - - - /** - * Computes the p-norm of this tensor. - * - * @param src Tensor to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. - */ - public double norm(CooTensor src, double p) { - return tensorNormLp(src.data, p); - } - - - /** - * Computes the 2-norm of this tensor. This is equivalent to {@link #norm(CooTensor, double) norm(src, 2)}. - * - * @param src Tensor to compute norm of. - * @return the 2-norm of this tensor. - */ - public static double norm(CooCTensor src) { - return tensorNormL2(src.data); - } - - - /** - * Computes the p-norm of this tensor. - * - * @param src Tensor to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this tensor. - * @throws IllegalArgumentException If p is less than 1. - */ - public double norm(CooCTensor src, double p) { - return tensorNormLp(src.data, p); - } - - - // ---------------------------- Low-level implementations ---------------------------- - - /** - * Computes the L2 norm of a tensor. - * @param src Entries of the tensor. - * @return The L2 norm of the tensor. - */ - public static double tensorNormL2(double[] src) { - double norm = 0; - - for(double value : src) - norm += Math.pow(Math.abs(value), 2); - - return Math.sqrt(norm); - } - - - /** - * Computes the Lp norm of a tensor. - * @param src Entries of the tensor. - * @param p The {@code p} parameter of the Lp norm. - * @return The Lp norm of the tensor. - */ - public static double tensorNormLp(double[] src, double p) { - ValidateParameters.ensureNotEquals(0, p); - double norm = 0; - - for(double value : src) - norm += Math.pow(Math.abs(value), p); - - return Math.pow(norm, 1.0/p); - } - - - /** - * Computes the L2 norm of a tensor (i.e. the Frobenius norm). - * @param src Entries of the tensor. - * @return The L2 norm of the tensor. - */ - public static double tensorNormL2(Complex128[] src) { - double norm = 0; - - for(Complex128 value : src) - norm += Complex128.pow((Complex128) value, 2).mag(); - - return Math.sqrt(norm); - } - - - /** - * Computes the Lp norm of a tensor (i.e. the Frobenius norm). - * @param src Entries of the tensor. - * @param p The {@code p} parameter of the Lp norm. - * @return The Lp norm of the tensor. - */ - public static double tensorNormLp(Complex128[] src, double p) { - double norm = 0; - - for(Complex128 value : src) - norm += Complex128.pow((Complex128) value, p).mag(); - - return Math.pow(norm, 1.0/p); - } -} diff --git a/src/main/java/org/flag4j/linalg/VectorNorms.java b/src/main/java/org/flag4j/linalg/VectorNorms.java index 0f66b9ba5..f221adef9 100644 --- a/src/main/java/org/flag4j/linalg/VectorNorms.java +++ b/src/main/java/org/flag4j/linalg/VectorNorms.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -24,296 +24,369 @@ package org.flag4j.linalg; -import org.flag4j.algebraic_structures.Field; import org.flag4j.algebraic_structures.Ring; -import org.flag4j.arrays.dense.CVector; -import org.flag4j.arrays.dense.FieldVector; -import org.flag4j.arrays.dense.Vector; -import org.flag4j.arrays.sparse.CooCVector; -import org.flag4j.arrays.sparse.CooVector; import org.flag4j.linalg.ops.common.real.RealProperties; import org.flag4j.linalg.ops.common.ring_ops.CompareRing; /** - * Utility class for computing norms of vectors. + * A utility class for computing vector norms, including various types of ℓp norms, + * with support for both dense and sparse vectors. This class provides methods to compute norms + * for vectors with real entries as well as vectors with entries that belong to a {@link Ring}. + * + *

    The methods in this class utilize scaling internally when computing the ℓp norm to protect against + * overflow and underflow for very large or very small values of {@code p} (in absolute value). + * + *

    Note: When {@code p < 1}, the results of the ℓp norm methods are not + * technically true mathematical norms but may still be useful for numerical tasks. However, {@code p = 0} + * will result in {@link Double#NaN}. + * + *

    This class is designed to be stateless and is not intended to be instantiated. */ public final class VectorNorms { + // TODO: This class currently uses scaling to avoid potential over/underflow issues. This works for most + // reasonable inputs. However, for very large (or small) values over/underflow may still occur. + // To better combat this, multiple accumulators can be used with different scaling as is done by LAPACK/BLAS drnm2. + // Specifically, a "big" accumulator which scales values down to ovoid overflow, a "small" accumulator which scales values up + // to avoid underflow, and a "medium" accumulator which applies no scaling. + private VectorNorms() { // Hide default constructor for utility class - } /** - * Computes the 2-norm of this vector. This is equivalent to {@link #norm(Vector, double) norm(src, 2)}. + *

    Computes the Euclidean (ℓ2) norm of a real dense or sparse vector. + *

    Zeros do not contribute to this norm so this function may be called on the entries of a dense vector or the non-zero entries + * of a sparse vector. * - * @param src Vector to compute norm of. - * @return the 2-norm of this vector. + * @param src Entries of the vector (or non-zero data if vector is sparse) to compute norm of. + * @return Euclidean (ℓ2) norm */ - public static double norm(Vector src) { - return norm(src.data); + public static double norm(double... src) { + return scaledL2Norm(src); } /** - * Computes the p-norm of this vector. + *

    Computes the Euclidean (ℓ2) norm of a dense or sparse vector whose entries are members of a + * {@link Ring}. + *

    Zeros do not contribute to this norm so this function may be called on the entries of a dense vector or the non-zero entries + * of a sparse vector. * - * @param src Vector to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this vector. - * @throws IllegalArgumentException If p is less than 1. + * @param src Entries of the vector (or non-zero data if vector is sparse) to compute norm of. + * @return Euclidean (ℓ2) norm */ - public static double norm(Vector src, double p) { - return norm(src.data, p); + public static > double norm(T... src) { + return scaledL2Norm(src); } /** - * Computes the 2-norm of this vector. This is equivalent to {@link #norm(Vector, double) norm(2)}. + *

    Computes the ℓp norm (or p-norm) of a real dense or sparse vector. + *

    Some common norms: + *

      + *
    • {@code p=1}: The taxicab, city block, or Manhattan norm.
    • + *
    • {@code p=2}: The Euclidean or ℓ2 norm.
    • + *
    * - * @param src Vector to compute norm of. - * @return the 2-norm of this vector. - */ - public static double norm(CooVector src) { - return norm(src.data); - } - - - /** - * Computes the p-norm of this vector. + *

    Zeros do not contribute to this norm so this function may be called on the entries of a dense vector or the non-zero entries + * of a sparse vector. * - * @param src Vector to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this vector. - * @throws IllegalArgumentException If p is less than 1. - */ - public static double norm(CooVector src, double p) { - return norm(src.data, p); - } - - - /** - * Computes the 2-norm of this vector. This is equivalent to {@link #norm(CVector, double) norm(2)}. + * @param src Entries of the vector (or non-zero data if vector is sparse). + * @param p The {@code p} value in the {@code p}-norm. When {@code p < 1}, the result of this method is not technically a + * true mathematical norm. However, it may be useful for various numerical tasks. + *

      + *
    • If {@code p} is finite, then the norm is computed as if by: + *
      {@code
      +     *     int norm = 0;
            *
      -     * @return the 2-norm of this vector.
      -     */
      -    @Deprecated
      -    public static double norm(CVector src) {
      -        return VectorNorms.norm(src.data);
      -    }
      -
      -
      -    /**
      -     * Computes the p-norm of this vector. Warning, if p is large in absolute value, overflow issues may occur.
      +     *     for(double v : src)
      +     *         norm += Math.pow(Math.abs(v), p);
            *
      -     * @param p The p value in the p-norm. 
      - * - If p is {@link Double#POSITIVE_INFINITY}, then this method computes the maximum/infinite norm.
      - * - If p is {@link Double#NEGATIVE_INFINITY}, then this method computes the minimum norm. - * @return The p-norm of this vector. + * return Math.pow(norm, 1.0/p); + * }
      + *
    • + *
    • If {@code p} is {@link Double#POSITIVE_INFINITY}, then this method computes the maximum/infinite norm.
    • + *
    • If {@code p} is {@link Double#NEGATIVE_INFINITY}, then this method computes the minimum norm.
    • + *
    + * + *

    Warning, if {@code p} is very large in absolute value, overflow errors may occur. + * @return The {@code p}-norm of the vector. */ - @Deprecated - public static double norm(CVector src, double p) { - return VectorNorms.norm(src.data, p); - } + public static double norm(double[] src, double p) { + if (src.length == 0) return 0; + if(p == Double.POSITIVE_INFINITY) { + return RealProperties.maxAbs(src); // Maximum norm. + } else if(p == Double.NEGATIVE_INFINITY) { + return RealProperties.minAbs(src); // Minimum "norm". + } else if(p == 1) { + double norm = 0; + for(double v : src) + norm += Math.abs(v); - /** - * Computes the 2-norm of this vector. This is equivalent to {@link #norm(CVector, double) norm(2)}. - * - * @return the 2-norm of this vector. - */ - public static > double norm(FieldVector src) { - return VectorNorms.norm(src.data); + return norm; + } else if(p == 2) { + return scaledL2Norm(src); + } else { + return scaledLpNorm(src, p); + } } /** - * Computes the p-norm of this vector. Warning, if p is large in absolute value, overflow issues may occur. + *

    Computes the ℓp norm (or p-norm) of a dense or sparse vector whose entries are members of a {@link Ring}. + *

    Some common norms: + *

      + *
    • {@code p=1}: The taxicab, city block, or Manhattan norm.
    • + *
    • {@code p=2}: The Euclidean or ℓ2 norm.
    • + *
    * - * @param p The {@code p} value in the p-norm: + *

    Zeros do not contribute to this norm so this function may be called on the entries of a dense vector or the non-zero entries + * of a sparse vector. + * + * @param src Entries of the vector (or non-zero data if vector is sparse). + * @param p The {@code p} value in the {@code p}-norm. When {@code p < 1}, the result of this method is not technically a + * true mathematical norm. However, it may be useful for various numerical tasks. *

      + *
    • If {@code p} is finite, then the norm is computed as if by: + *
      {@code
      +     *     int norm = 0;
      +     *
      +     *     for(double v : src)
      +     *         norm += Math.pow(Math.abs(v), p);
      +     *
      +     *     return Math.pow(norm, 1.0/p);
      +     *     }
      + *
    • *
    • If {@code p} is {@link Double#POSITIVE_INFINITY}, then this method computes the maximum/infinite norm.
    • *
    • If {@code p} is {@link Double#NEGATIVE_INFINITY}, then this method computes the minimum norm.
    • *
    * - * @return The p-norm of this vector. + *

    Warning, if {@code p} is very large in absolute value, overflow errors may occur. + * @return The {@code p}-norm of the vector. */ - public static > double norm(FieldVector src, double p) { - return VectorNorms.norm(src.data, p); - } + public static > double norm(T[] src, double p) { + if (src.length == 0) return 0; + if(p == Double.POSITIVE_INFINITY) { + return CompareRing.maxAbs(src); // Maximum norm. + } else if(p == Double.NEGATIVE_INFINITY) { + return CompareRing.minAbs(src); // Minimum "norm". + } else if(p == 1) { + double norm = 0; + for(T v : src) + norm += v.abs(); - /** - * Computes the 2-norm of this vector. This is equivalent to {@link #norm(CooCVector, double) norm(src, 2)}. - * - * @param src Vector to compute norm of. - * @return the 2-norm of this vector. - */ - public static double norm(CooCVector src) { - return VectorNorms.norm(src.data); + return norm; + } else if(p == 2) { + return scaledL2Norm(src); + } else { + return scaledLpNorm(src, p); + } } /** - * Computes the p-norm of this vector. + * Computes the scaled ℓp norm of a vector. + * This method uses scaling to protect against numerical instability such as overflow or underflow + * when computing the ℓp norm for large or small values of {@code p}. * - * @param src Vector to compute norm of. - * @param p The p value in the p-norm.
    - * - If p is inf, then this method computes the maximum/infinite norm. - * @return The p-norm of this vector. - * @throws IllegalArgumentException If p is less than 1. + * @param src The input vector (or non-zero values if vector is sparse) whose ℓp norm is to be computed. + * @param p The value of {@code p} for the ℓp norm. + * @return The scaled ℓp norm of the input vector. */ - public static double norm(CooCVector src, double p) { - return VectorNorms.norm(src.data, p); - } + private static double scaledLpNorm(double[] src, double p) { + // Find the maximum absolute value in the vector. + double maxAbs = RealProperties.maxAbs(src); + if (maxAbs == 0.0) return 0.0; // Quick return for zero vector. - /** - * Computes the infinity norm of a tensor, matrix, or vector. That is, the largest absolute value. - * @param src The vector to compute the norm of. - * @return The infinity norm of the source vector. - */ - public static double infNorm(CooVector src) { - return src.maxAbs(); - } + double maxInv = 1.0 / maxAbs; + // Compute the p-norm using scaled values. + double sum = 0; + for (double v : src) + sum += Math.pow(Math.abs(v) * maxInv, p); - /** - * Computes the infinity norm of a vector. That is, the largest absolute value. - * @param src The vector to compute the norm of. - * @return The infinity norm of the source vector. - */ - public static double infNorm(CooCVector src) { - return src.maxAbs(); + // Ensure result is properly scaled back up. + return maxAbs * Math.pow(sum, 1.0 / p); } /** - * Computes the infinity norm of a tensor, matrix, or vector. That is, the largest absolute value. - * @param src The vector to compute the norm of. - * @return The infinity norm of the source vector. + * Computes the scaled ℓ2 norm (Euclidean norm) of a vector. + * This method uses scaling to protect against numerical instability such as overflow or underflow + * when computing the ℓ2 norm for vectors with very large or very small values. + * + * @param src The input vector (or non-zero entries if the vector is sparse) whose ℓ2 norm is to be computed. + * @return The scaled ℓ2 norm of the input vector. */ - public static double infNorm(Vector src) { - return src.maxAbs(); - } + private static double scaledL2Norm(double[] src) { + // Find the maximum absolute value in the vector. + double maxAbs = RealProperties.maxAbs(src); + if (maxAbs == 0.0) return 0.0; // Quick return for zero vector. - /** - * Computes the infinity norm of a vector. That is, the largest absolute value. - * @param src The vector to compute the norm of. - * @return The infinity norm of the source vector. - */ - @Deprecated - public static double infNorm(CVector src) { - return src.maxAbs(); - } + double maxInv = 1.0 / maxAbs; + // Compute norm as a = |max(src)|, ||src|| = a*||src * (1/a)|| to help protect against overflow. + double sum = 0; + for(double v : src) { + double vScaled = v*maxInv; + sum += vScaled*vScaled; + } - /** - * Computes the infinity norm of a vector. That is, the largest absolute value. - * @param src The vector to compute the norm of. - * @return The infinity norm of the source vector. - */ - public static double infNorm(FieldVector src) { - return src.maxAbs(); + // Ensure result is properly scaled back up. + return Math.sqrt(sum)*maxAbs; } - // ---------------------------------------------- Low-level Implementations ---------------------------------------------- /** - * Computes the 2-norm of a vector. - * @param src Entries of the vector (or non-zero data if vector is sparse). - * @return The 2-norm of the vector. + * Computes the scaled ℓp norm of a vector. + * This method uses scaling to protect against numerical instability such as overflow or underflow + * when computing the ℓp norm for large or small values of {@code p}. + * + * @param src The input vector (or non-zero values if vector is sparse) whose ℓp norm is to be computed. + * @param p The value of {@code p} for the ℓp norm. + * @return The scaled ℓp norm of the input vector. */ - public static double norm(double... src) { - double norm = 0; - double maxAbs = RealProperties.maxAbs(src); - if(maxAbs == 0) return 0; // Early return for zero norm. + private static > double scaledLpNorm(T[] src, double p) { + // Find the maximum absolute value in the vector. + double maxAbs = CompareRing.maxAbs(src); - // Compute norm as a = |max(src)|, ||src|| = a*||src * (1/a)|| to help protect against overflow. - for(double v : src) { - double vScaled = v/maxAbs; - norm += vScaled*vScaled; - } + if (maxAbs == 0.0) return 0.0; // Quick return for zero vector. - return Math.sqrt(norm)*maxAbs; + double maxInv = 1.0 / maxAbs; + + // Compute the p-norm using scaled values. + double sum = 0; + for (T v : src) + sum += Math.pow(v.abs() * maxInv, p); + + // Ensure result is properly scaled back up. + return maxAbs * Math.pow(sum, 1.0 / p); } /** - * Computes the 2-norm of a vector. - * @param src Entries of the vector (or non-zero data if vector is sparse). - * @return The 2-norm of the vector. + * Computes the scaled ℓ2 norm (Euclidean norm) of a vector. + * This method uses scaling to protect against numerical instability such as overflow or underflow + * when computing the ℓ2 norm for vectors with very large or very small values. + * + * @param src The input vector (or non-zero entries if the vector is sparse) whose ℓ2 norm is to be computed. + * @return The scaled ℓ2 norm of the input vector. */ - public static > double norm(T... src) { - double norm = 0; - double scaledMag; + private static > double scaledL2Norm(T[] src) { + // Find the maximum absolute value in the vector. double maxAbs = CompareRing.maxAbs(src); - if(maxAbs == 0) return 0; // Early return for zero norm. + if (maxAbs == 0.0) return 0.0; // Quick return for zero vector. - double maxAbsRecip = 1.0 / maxAbs; + double maxInv = 1.0 / maxAbs; - // Compute norm as a = |max(src)|, ||src|| = a*||src * (1/a)|| to help protect against over/underflow. - for(Ring value : src) { - scaledMag = value.mag() * maxAbsRecip; - norm += scaledMag*scaledMag; + // Compute norm as a = |max(src)|, ||src|| = a*||src * (1/a)|| to help protect against overflow. + double sum = 0; + for(T v : src) { + double vScaled = v.mag() * maxInv; + sum += vScaled*vScaled; } - return Math.sqrt(norm)*maxAbs; + // Ensure result is properly scaled back up. + return Math.sqrt(sum)*maxAbs; } /** - * Computes the {@code p}-norm of a vector. - * @param src Entries of the vector (or non-zero data if vector is sparse). - * @param p The {@code p} value in the {@code p}-norm: + *

    Computes the ℓ2 (Euclidean) norm of a sub-vector within {@code src}, + * starting at index {@code start} and considering {@code n} elements spaced by {@code stride}. + * + *

    More formally, this method examines and computes the norm of the elements at indices: + * {@code start}, {@code start + stride}, {@code start + 2*stride}, ..., {@code start + (n-1)*stride}. + * + *

    This method may be used to compute the norm of a row or column in a + * {@link org.flag4j.arrays.dense.Matrix matrix} {@code a} as follows: *

      - *
    • If {@code p} is {@link Double#POSITIVE_INFINITY}, then this method computes the maximum/infinite norm.
    • - *
    • If {@code p} is {@link Double#NEGATIVE_INFINITY}, then this method computes the minimum norm.
    • + *
    • Norm of row {@code i}: + *
      {@code norm(a.data, i*a.numCols, a.numCols, 1);}
    • + *
    • Norm of column {@code j}: + *
      {@code norm(a.data, j, a.numRows, a.numRows);}
    • *
    * - *

    Warning, if {@code p} is very large in absolute value, overflow errors may occur. - * @return The {@code p}-norm of the vector. + * @param src The array to containing sub-vector elements to compute norm of. + * @param start The starting index in {@code src} to search. Must be positive but this is not explicitly enforced. + * @param n The number of elements to consider within {@code src1}. Must be positive but this is not explicitly enforced. + * @param stride The gap (in indices) between consecutive elements of the sub-vector within {@code src}. + * Must be positive but this is not explicitly enforced. + * @return The ℓ2 (Euclidean) norm of the specified sub-vector of {@code src}. + * + * @throws IndexOutOfBoundsException If {@code start + (n-1)*stride} exceeds {@code src.length - 1}. */ - public static double norm(double[] src, double p) { - if(Double.isInfinite(p)) { - if(p > 0) return RealProperties.maxAbs(src); // Maximum norm. - else return RealProperties.minAbs(src); // Minimum norm. - } else { - double norm = 0; + public static double norm(double[] src, final int start, final int n, final int stride) { + // Find the maximum absolute value in the vector. + double maxAbs = RealProperties.maxAbs(src, start, n, stride); - for(double v : src) - norm += Math.pow(Math.abs(v), p); + if (maxAbs == 0.0) return 0.0; // Quick return for zero vector. + + double maxInv = 1.0 / maxAbs; - return Math.pow(norm, 1.0/p); + final int end = start + n*stride; + + // Compute norm as a = |max(src)|, ||src|| = a*||src * (1/a)|| to help protect against overflow. + double sum = 0; + for(int i=start; i - * - If {@code p} is {@link Double#POSITIVE_INFINITY}, then this method computes the maximum/infinite norm.
    - * - If {@code p} is {@link Double#NEGATIVE_INFINITY}, then this method computes the minimum norm.
    - * Warning, if {@code p} is large in absolute value, overflow errors may occur. - * @return The {@code p}-norm of the vector. + *

    Computes the ℓ2 (Euclidean) norm of a sub-vector within {@code src}, + * starting at index {@code start} and considering {@code n} elements spaced by {@code stride}. + * + *

    More formally, this method examines and computes the norm of the elements at indices: + * {@code start}, {@code start + stride}, {@code start + 2*stride}, ..., {@code start + (n-1)*stride}. + * + *

    This method may be used to compute the norm of a row or column in a + * {@link org.flag4j.arrays.dense.Matrix matrix} {@code a} as follows: + *

      + *
    • Norm of row {@code i}: + *
      {@code norm(a.data, i*a.numCols, a.numCols, 1);}
    • + *
    • Norm of column {@code j}: + *
      {@code norm(a.data, j, a.numRows, a.numRows);}
    • + *
    + * + * @param src The array to containing sub-vector elements to compute norm of. + * @param start The starting index in {@code src} to search. Must be positive but this is not explicitly enforced. + * @param n The number of elements to consider within {@code src1}. Must be positive but this is not explicitly enforced. + * @param stride The gap (in indices) between consecutive elements of the sub-vector within {@code src}. + * Must be positive but this is not explicitly enforced. + * @return The ℓ2 (Euclidean) norm of the specified sub-vector of {@code src}. + * + * @throws IndexOutOfBoundsException If {@code start + (n-1)*stride} exceeds {@code src.length - 1}. */ - public static > double norm(T[] src, double p) { - if(Double.isInfinite(p)) { - if(p > 0) return CompareRing.maxAbs(src); // Maximum norm. - else return CompareRing.minAbs(src); // Minimum norm. - } else { - double norm = 0; + public static > double norm(T[] src, final int start, final int n, final int stride) { + // Find the maximum absolute value in the vector. + double maxAbs = CompareRing.maxAbs(src, start, n, stride); - for(T value : src) - norm += Math.pow(value.mag(), p); + if (maxAbs == 0.0) return 0.0; // Quick return for zero vector. - return Math.pow(norm, 1.0/p); + double maxInv = 1.0 / maxAbs; + final int end = start + n*stride; + + // Compute norm as a = |max(src)|, ||src|| = a*||src * (1/a)|| to help protect against overflow. + double sum = 0; + for(int i=start; i createRealLU() { - return new RealLU(); - } - - - /** - * Constructs a decomposer to compute the LU decomposition of a complex dense matrix. - * @return A decomposer to compute the LU decomposition of a complex dense matrix. - */ - public static LU createComplexLU() { - return new ComplexLU(); - } - - - /** - * Constructs a decomposer to compute the Cholesky decomposition of a real dense matrix. - * @return A decomposer to compute the Cholesky decomposition of a real dense matrix. - */ - public static Cholesky createRealChol() { - return new RealCholesky(); - } - - - /** - * Constructs a decomposer to compute the Cholesky decomposition of a complex dense matrix. - * @return A decomposer to compute the Cholesky decomposition of a complex dense matrix. - */ - public static Cholesky createComplexChol() { - return new ComplexCholesky(); - } - - - /** - * Constructs a decomposer to compute the QR decomposition of a real dense matrix. - * @return A decomposer to compute the QR decomposition of a real dense matrix. - */ - public static RealQR createRealQR() { - return new RealQR(); - } - - - /** - * Constructs a decomposer to compute the QR decomposition of a complex dense matrix. - * @return A decomposer to compute the QR decomposition of a complex dense matrix. - */ - public static ComplexQR createComplexQR() { - return new ComplexQR(); - } - - - /** - * Constructs a decomposer to compute the Hessenburg decomposition of a real dense matrix. - * @return A decomposer to compute the Hessenburg decomposition of a real dense matrix. - */ - public static RealHess createRealHess() { - return new RealHess(); - } - - - /** - * Constructs a decomposer to compute the Hessenburg decomposition of a complex dense matrix. - * @return A decomposer to compute the Hessenburg decomposition of a complex dense matrix. - */ - public static ComplexQR createComplexHess() { - return new ComplexQR(); - } - - - /** - * Constructs a decomposer to compute the Schur decomposition of a real dense matrix. - * @return A decomposer to compute the Schur decomposition of a real dense matrix. - */ - public static RealSchur createRealSchur() { - return new RealSchur(); - } - - - /** - * Constructs a decomposer to compute the Schur decomposition of a complex dense matrix. - * @return A decomposer to compute the Schur decomposition of a complex dense matrix. - */ - public static ComplexQR createComplexSchur() { - return new ComplexQR(); - } - - - /** - * Constructs a decomposer to compute the singular value decomposition of a real dense matrix. - * @return A decomposer to compute the singular value decomposition of a real dense matrix. - */ - public static RealSVD createRealSVD() { - return new RealSVD(); - } - - - /** - * Constructs a decomposer to compute the singular value decomposition of a complex dense matrix. - * @return A decomposer to compute the singular value decomposition of a complex dense matrix. - */ - public static ComplexSVD createComplexSVD() { - return new ComplexSVD(); - } -} diff --git a/src/main/java/org/flag4j/linalg/decompositions/balance/Balancer.java b/src/main/java/org/flag4j/linalg/decompositions/balance/Balancer.java new file mode 100644 index 000000000..d04449517 --- /dev/null +++ b/src/main/java/org/flag4j/linalg/decompositions/balance/Balancer.java @@ -0,0 +1,686 @@ +/* + * MIT License + * + * Copyright (c) 2025. Jacob Watters + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +package org.flag4j.linalg.decompositions.balance; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.backend.MatrixMixin; +import org.flag4j.arrays.dense.Matrix; +import org.flag4j.arrays.sparse.PermutationMatrix; +import org.flag4j.linalg.decompositions.Decomposition; +import org.flag4j.util.ArrayUtils; +import org.flag4j.util.Flag4jConstants; +import org.flag4j.util.ValidateParameters; + +import java.util.Arrays; + + +/** + *

    Base class for all decompositions which implement matrix balancing. Balancing a matrix involves computing a + * diagonal similarity transformation to "balance" the rows and columns of the matrix. This balancing is achieved + * by attempting to scale the entries of the matrix by similarity transformations such that the 1-norms of corresponding + * rows and columns have the similar 1-norms. Rows and columns may also be permuted during balancing if requested. + * + *

    Balancing is often used as a preprocessing step to improve the conditioning of eigenvalue problems. Because the + * balancing transformation is a similarity transformation, the eigenvalues are preserved. Further, when permutations are + * done during balancing it is possible to isolate decoupled eigenvalues. + * + *

    The similarity transformation of a square matrix A into the balanced matrix B can be described as: + *

    + *     B = T-1 A T
    + *       = D-1 P-1 A P D.
    + * Solving for A, balancing may be viewed as the following decomposition: + *
    + *     A = T B T-1
    + *       = P D B D-1 P-1.
    + * Where P is a permutation matrix, and D is a diagonal scaling matrix. + * + *

    When permutations are used during balancing we obtain a specific form. First, + *

    + *             [ T1  X   Y  ]
    + *   P-1 A P = [  0  B1  Z  ]
    + *             [  0  0   T2 ]
    + * Where T1 and T2 are upper triangular matrices whose eigenvalues lie along the diagonal. These are also + * eigenvalues of A. Then, if scaling is applied we obtain: + *
    + *                  [ T1     X*D1       Y    ]
    + *   D-1 P-1 A P D = [  0  D1-1*B*1D1  D1-1*Z  ]
    + *                  [  0      0         T2   ]
    + * Where D1 is a diagonal matrix such that, + *
    + *         [ I1 0  0  ]
    + *     D = [ 0  D1 0  ]
    + *         [ 0  0  I2 ]
    + * Where I1 and I2 are identity matrices with equivalent shapes to T1 and T2. + * + *

    Once balancing has been applied, one need only compute the eigenvalues of B1 and combine them with the diagonal + * entries of T1 and T2 to obtain all eigenvalues of A. + * + *

    The code in this class if heavily based on LAPACK's reference implementations of + * xGEBAL (v 3.12.1). + * + * @param The type of matrix being balanced. + * + * @see #getB() + * @see #getBSubMatrix() + * @see #getD(boolean) + * @see #getD() + * @see #getP() + * @see #getT() + */ +public abstract class Balancer> implements Decomposition { + + /** + * Simple scaling factor used to help ensure safe scaling without over/underflow. + */ + private static final double FACTOR = 0.95; + + // Scaling factor to keep values as powers of two. + private static final double BASE_SCALE = 2.0; + + // Some constants which specify "safe" maximum and minimum values to avoid under/overflow. + private static final double SAFE_MIN_1 = Flag4jConstants.SAFE_MIN_F64 / (Flag4jConstants.EPS_F64*2.0); + private static final double SAFE_MAX_1 = 1.0 / SAFE_MIN_1; + private static final double SAFE_MIN_2 = SAFE_MIN_1*BASE_SCALE; + private static final double SAFE_MAX_2 = 1.0 / SAFE_MIN_2; + + /** + *

    Stores both the scaling and permutation information for the balanced matrix. + * + *

    Let {@code perm[j]} be the index of the row and column swapped with row and column {@code j} and + * {@code scale[j]} be the scaling factor applied to row and column {@code j}. Then, + *

      + *
    • {@code scalePerm[j] = perm[j]} for {@code j = 0, ..., iLow-1}.
    • + *
    • {@code scalePerm[j] = scale[j]} for {@code j = iLow, ..., iHigh-1}.
    • + *
    • {@code scalePerm[j] = perm[j]} for {@code j = iHigh, ..., size-1}.
    • + *
    + * + * The order which row and column swaps are made is {@code size-1} to {@code iHigh}, then from {@code 0} to {@code iLow}. + */ + protected double[] scalePerm; + /** + * Stores the balanced matrix. + */ + protected T balancedMatrix; + /** + * This size of the matrix to be balanced. + */ + protected int size; + /** + * Tracks the ending row/column index of the un-permuted submatrix to be balanced (exclusive). + */ + protected int iHigh; + /** + * Tracks the starting row/column index of the un-permuted submatrix to be balanced (inclusive). + */ + protected int iLow; + /** + * Flag indicating if scaling should be done during balancing. + *
      + *
    • If {@code true}, then scaling will be performed during balancing.
    • + *
    • If {@code false}, the no scaling will be done during balancing.
    • + *
    + */ + protected boolean doScaling; + /** + * Flag indicating if permutations should be done during balancing. + *
      + *
    • If {@code true}: Then row/column permutations will be performed during balancing.
    • + *
    • If {@code false}: Then row/column permutations will be performed during balancing.
    • + *
    + */ + protected boolean doPermutations; + /** + * Flag indicating if the balancing should be done in-place or if a copy should be made. + *
      + *
    • If {@code true}, the balancing will be done in-place and the matrix to be balanced will be overwritten.
    • + *
    • If {@code false}, a copy will be made of the matrix before balancing is applied and the original matrix will remain + * unmodified.
    • + *
    + */ + public final boolean inPlace; + + + /** + * @param doPermutations Flag indicating if row/column permutations should be used when balancing the matrix. + *
      + *
    • If {@code true}, permutations will be used and P will be computed.
    • + *
    • If {@code false}, permutations will not be used and the row and column positions will not be affected.
    • + *
    + * @param doScaling Flag indicating if row/column scaling should be done when balancing the matrix. + *
      + *
    • If {@code true}, scaling will be used and D will be computed.
    • + *
    • If {@code false}, scaling will not be used.
    • + *
    + * @param inPlace Flag indicating if the balancing should be done in-place or if a copy should be made. + *
      + *
    • If {@code true}, the balancing will be done in-place and the matrix to be balanced will be overwritten.
    • + *
    • If {@code false}, a copy will be made of the matrix before balancing is applied and the original matrix will remain + * unmodified.
    • + *
    + */ + protected Balancer(boolean doPermutations, boolean doScaling, boolean inPlace) { + this.inPlace = inPlace; + this.doPermutations = doPermutations; + this.doScaling = doScaling; + } + + + /** + * Swaps two rows, over a specified range, within the {@link #balancedMatrix} matrix. + * @param rowIdx1 Index of the first row to swap. + * @param rowIdx2 Index of the second row to swap. + * @param start Index of the column specifying the start of the range for the row swap (inclusive). + * @param stop Index of the column specifying the end of the range for the row swap (exclusive). + */ + protected abstract void swapRows(int rowIdx1, int rowIdx2, int start, int stop); + + + /** + * Swaps two columns, over a specified range, within the {@link #balancedMatrix} matrix. + * @param colIdx1 Index of the first column to swap. + * @param colIdx2 Index of the second column to swap. + * @param start Index of the row specifying the start of the range for the column swap (inclusive). + * @param stop Index of the row specifying the end of the range for the column swap (exclusive). + */ + protected abstract void swapCols(int colIdx1, int colIdx2, int start, int stop); + + + /** + * Checks if a value within {@link #balancedMatrix} is zero. + * @param idx Index of value within {@link #balancedMatrix}'s 1D data array to check if it is zero. + */ + protected abstract boolean isZero(int idx); + + + /** + * Computes the ℓ2 norm of a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. + * @param start Starting index within {@link #balancedMatrix}'s 1D data array to compute norm of. + * @param n The number of elements in the vector to compute norm of. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to norm of. + * @return The norm of the vector containing the specified elements from {@link #balancedMatrix}'s 1D data array. + */ + protected abstract double vectorNorm(int start, int n, int stride); + + + /** + * Computes the maximum absolute value of a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. + * @param start Starting index within {@link #balancedMatrix}'s 1D data array to compute maximum absolute value of. + * @param n The number of elements in the vector to compute maximum absolute value of. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to compute maximum absolute + * value of. + * @return The maximum absolute value of the vector containing the specified elements from {@link #balancedMatrix}'s 1D data + * array. + */ + protected abstract double vectorMaxAbs(int start, int n, int stride); + + + /** + * Scales a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. This operation must be done in-place. + * + * @param factor Factor to scale elements by. + * @param start Starting index within {@link #balancedMatrix}'s 1D data array begin scaling. + * @param n The number of elements to scale. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to scale. + */ + protected abstract void vectorScale(double factor, int start, int n, int stride); + + + /** + *

    Performs basic setup for balancing. + *

    Specifically, copies the matrix to be balanced if an out-of-place computation was requested, initializes the matrix size, + * {@link #iLow}, {@link #iHigh}, and {@link #scalePerm}. + * @param src The matrix to balance. + */ + private void setUp(T src) { + ValidateParameters.ensureSquare(src.getShape()); + balancedMatrix = inPlace ? src : src.copy(); + size = balancedMatrix.numRows(); + + iLow = 0; + iHigh = size; + + scalePerm = new double[size]; + } + + + /** + * Balances a matrix so that the rows and columns have roughly similar sized norms. + * @param src Matrix to balance. Must be square. If {@link #inPlace == true} then {@code src} will be modified. + * Otherwise, {@code src} will not be modified. + * @return A reference to this balancer object. + * @throws org.flag4j.util.exceptions.TensorShapeException If {@code src} is not a square matrix. + */ + @Override + public Balancer decompose(T src) { + setUp(src); + + if (doPermutations) doIterativePermutations(); + + // Initialize scaling factors for remaining un-permuted rows and columns. + for(int i = iLow; iPerforms the permutation step of matrix balancing. + * + *

    This is, identifies rows and columns which are decoupled from the rest of the matrix and hence isolate an + * eigenvalue. Rows which isolate an eigenvalue are pushed to the bottom of the matrix. Similarly, columns which isolate an + * eigenvalue are pushed to the left of the matrix. To ensure that the row/column swaps are similarity transforms, if any two + * rows are swapped the same columns are swapped. + * + *

    Such row and column permutations transform the original matrix {@code A} into the following form: + *

    +     *             [ T1  X  Y  ]
    +     *   P-1 A P = [  0  B  Z  ]
    +     *             [  0  0  T2 ]
    + *

    Where {@code T1} and {@code T2} are upper-triangular matrices whose eigenvalues are the diagonal elements of the matrix. + * {@code P} is the permutation matrix representing the row and column swaps performed within this method. + * + *

    {@link #iLow} and {@link #iHigh} Specify the starting (inclusive) and ending (exclusive) row/column index of the submatrix + * {@code B}. + */ + protected void doIterativePermutations() { + boolean notConverged = true; + + // Find rows isolating eigenvalues and push to the bottom of the matrix. + while (notConverged) { + notConverged = false; + + for(int i=iHigh-1; i>=0; i--) { + int rowOffset = i*size; + boolean canSwap = true; + + for(int j=0; jPerforms the scaling step of matrix balancing. + * + *

    That is, computes scaling factors such that when a column is scaled by such value and the row is scaled by the reciprocal + * of that value, there ℓ1 norms are "close". Scaling need only be done for rows/column of the matrix which do not + * isolate eigenvalues; rows between {@link #iLow} (inclusive) to {@link #iHigh} (exclusive). + * + *

    D1 is the diagonal matrix describing such scaling and is the diagonal matrix computed by this method. \ + * The diagonal values of D1 are stored in {@link #scalePerm} between indices {@link #iLow} (inclusive) to + * {@link #iHigh} (exclusive). + */ + protected void doIterativeScaling() { + int n = iHigh - iLow; + int bRowOffset = iLow*size; + + boolean notConverged = true; + + while (notConverged) { + notConverged = false; // Set true if any scaling is applied. + + // Process each column/row i in the sub-block. + for (int i = iLow; i < iHigh; i++) { + int rowStart = i*size + iLow; + int colStart = i + bRowOffset; + + // Compute the row/column l2 norms. + double colNorm = vectorNorm(colStart, n, size); + double rowNorm = vectorNorm(rowStart, n, 1); + + // Compute the row/column maximum absolute values. + double colMaxAbs = vectorMaxAbs(i, iHigh - 1, size); + double rowMaxAbs = vectorMaxAbs(rowStart, size - iLow, 1); + + // Avoid division by zero. + if (colNorm == 0.0 || rowNorm == 0.0) + continue; + + // Report if any NaN value is encountered. + if (Double.isNaN(colNorm + colMaxAbs + rowNorm + rowMaxAbs)) + throw new IllegalArgumentException("NaN encountered in balancing step."); + + double g = rowNorm / BASE_SCALE; + double f = 1.0; + double s = colNorm + rowNorm; + + // Scale up colNorm and down rowNorm and avoid under/overflow. + while (colNorm < g + && Math.max(Math.max(f, colNorm), colMaxAbs) < SAFE_MAX_2 + && Math.min(Math.min(rowNorm, g), rowMaxAbs) > SAFE_MIN_2) { + + f *= BASE_SCALE; // multiply f by 2 + colNorm *= BASE_SCALE; + colMaxAbs *= BASE_SCALE; + rowNorm /= BASE_SCALE; + g /= BASE_SCALE; + rowMaxAbs /= BASE_SCALE; + } + + // Now consider if colNorm should be scaled down and rowNorm up while avoiding under/overflow. + g = colNorm / BASE_SCALE; + while (g >= rowNorm + && Math.max(rowNorm, rowMaxAbs) < SAFE_MAX_2 + && Math.min(Math.min(f, colNorm), Math.min(g, colMaxAbs)) > SAFE_MIN_2) { + + f /= BASE_SCALE; // divide f by 2 + colNorm /= BASE_SCALE; + g /= BASE_SCALE; + colMaxAbs /= BASE_SCALE; + rowNorm *= BASE_SCALE; + rowMaxAbs *= BASE_SCALE; + } + + // Now we check if this scaling factor f actually improves (colNorm + rowNorm) + // enough relative to factor*s. If not, skip. + if ( (colNorm + rowNorm) >= FACTOR*s ) + continue; + + // Ensure we don't underflow or overflow scalePerm[i] + if (f < 1.0 && scalePerm[i] < 1.0) { + if (f * scalePerm[i] <= SAFE_MIN_1) + continue; // scalePerm[i] would underflow + } + if (f > 1.0 && scalePerm[i] > 1.0) { + if (scalePerm[i] >= SAFE_MAX_1/ f) + continue; // scalePerm[i] would overflow. + } + + // All checks have passed, apply the scaling + double fInv = 1.0 / f; + scalePerm[i] *= f; + notConverged = true; // Indicate another pass should be performed. + + // Scale row i by fInv and column by f (the diagonal entries of inv(D) and D respectively). + vectorScale(fInv, rowStart, size - iLow, 1); + vectorScale(f, i, iHigh, size); + } + } + } + + + /** + * Ensures that {@link #decompose(MatrixMixin)} has been called on this instance. + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not been called on this instance. + */ + private void ensureHasBalanced() { + // If balancedMatrix has not been instantiated, then balance(...) has not been called. + if(balancedMatrix == null) + throw new IllegalStateException("No matrix has been balanced by this balancer. Must call balance(...) first."); + } + + + /** + * Gets the starting index (inclusive) for the sub-matrix B1 of the balanced matrix which did not isolate eigenvalues. + * @return The starting index (inclusive) for the sub-matrix of the balanced matrix which did not isolate eigenvalues. + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not yet been called on this instance. + */ + public int getILow() { + ensureHasBalanced(); + return iLow; + } + + + /** + * Gets the starting index (exclusive) for the sub-matrix B1 of the balanced matrix which did not isolate eigenvalues. + * @return The starting index (exclusive) for the sub-matrix of the balanced matrix which did not isolate eigenvalues. + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not yet been called on this instance. + */ + public int getIHigh() { + ensureHasBalanced(); + return iHigh; + } + + + /** + * Gets the full balanced matrix, B, for the last matrix balanced by this balancer. + * @return The full balanced matrix for the last matrix balanced by this balancer. + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not yet been called on this instance. + */ + public T getB() { + ensureHasBalanced(); + return balancedMatrix; + } + + + /** + * Gets the sub-matrix B1 of the full balanced matrix which did not isolate eigenvalues. + * @return The sub-matrix of the full balanced matrix which did not isolate eigenvalues. + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not yet been called on this instance. + */ + public T getBSubMatrix() { + ensureHasBalanced(); + return balancedMatrix.getSlice(iLow, iHigh, iLow, iHigh); + } + + + /** + *

    Gets the raw scaling factors and permutation data stored in a single array. + * + *

    Let {@code perm[j]} be the index of the row and column swapped with row and column {@code j} and + * {@code scale[j]} be the scaling factor applied to row and column {@code j}. Then, + *

      + *
    • {@code scalePerm[j] = perm[j]} for {@code j = 0, ..., iLow-1}.
    • + *
    • {@code scalePerm[j] = scale[j]} for {@code j = iLow, ..., iHigh-1}.
    • + *
    • {@code scalePerm[j] = perm[j]} for {@code j = iHigh, ..., size-1}.
    • + *
    + * + * The order which row and column swaps are made is {@code size-1} to {@code iHigh}, then from {@code 0} to {@code iLow}. + * @return The raw scaling factors and permutation data stored in a single array. + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not yet been called on this instance. + */ + public double[] getScalePerm() { + ensureHasBalanced(); + return scalePerm; + } + + + /** + * Gets the diagonal scaling matrix for the last matrix balanced by this balancer. + * @param full Flag indicating if the full diagonal scaling matrix should be constructed or if only the scaling factors should + * be returned. If the last matrix balanced had shape n-by-n then, + *
      + *
    • If {@code true}: The full n-by-n diagonal scaling matrix will be created.
    • + *
    • If {@code false}: A matrix of shape 1-by-n containing only the scaling factors + * (i.e. the diagonal entries of the full scaling matrix). + *
    • + *
    + * @return If {@code full == true} then the full n-by-n scaling matrix is returned. Otherwise if {@code full == false} + * a matrix of shape 1-by-n containing only the diagonal scaling factors is returned. + * @see #getD() + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not yet been called on this instance. + */ + public Matrix getD(boolean full) { + ensureHasBalanced(); + + if (full) { + Matrix D = Matrix.I(size); + + for(int i=iLow; iGets the diagonal scaling factors for the last matrix balanced by this balancer. + * + *

    Note, this method will not construct the full diagonal scaling matrix. If the full matrix is desired, use + * {@link #getD(boolean)}. + * + * @return A 1-by-n matrix containing the diagonal elements of the full n-by-n diagonal scaling matrix. + * @see #getD(boolean) + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not yet been called on this instance. + */ + public Matrix getD() { + return getD(false); + } + + + /** + *

    Gets the permutation matrix for the last matrix balanced by this balancer. + * @return The permutation matrix for the last matrix balanced by this balancer. + * @throws IllegalStateException If {@link #decompose(MatrixMixin)} has not yet been called on this instance. + */ + public PermutationMatrix getP() { + ensureHasBalanced(); + int[] swapPointers = ArrayUtils.intRange(0, size); + + int temp; + int value; + int count = 1; + for(int i=size-1; i>=iHigh; i--) { + value = (int) scalePerm[i]; + + if(size - count != value) { + temp = swapPointers[size - count]; + swapPointers[size - count] = swapPointers[value]; + swapPointers[value] = temp; + } + + count++; + } + + for(int i=0; i balancer = new RealBalancer(true, true).decompose(a); + Matrix B = balancer.getB(); + PermutationMatrix P = balancer.getP(); + + System.out.println("A:\n" + a + "\n"); + System.out.println("B:\n" + B + "\n"); + System.out.println("P:\n" + P.inv().toDense() + "\n"); + System.out.println("perm:\n" + Arrays.toString(P.getPermutation()) + "\n"); + + System.out.println("lo: " + balancer.getILow()); + System.out.println("hi: " + balancer.getIHigh()); + System.out.println("ps: " + Arrays.toString(balancer.getScalePerm())); + } +} diff --git a/src/main/java/org/flag4j/linalg/decompositions/balance/ComplexBalancer.java b/src/main/java/org/flag4j/linalg/decompositions/balance/ComplexBalancer.java new file mode 100644 index 000000000..69f5880ef --- /dev/null +++ b/src/main/java/org/flag4j/linalg/decompositions/balance/ComplexBalancer.java @@ -0,0 +1,227 @@ +/* + * MIT License + * + * Copyright (c) 2025. Jacob Watters + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +package org.flag4j.linalg.decompositions.balance; + +import org.flag4j.arrays.dense.CMatrix; +import org.flag4j.linalg.VectorNorms; +import org.flag4j.linalg.ops.common.field_ops.FieldOps; +import org.flag4j.linalg.ops.common.real.RealProperties; +import org.flag4j.linalg.ops.dense.DenseOps; + +/** + *

    Instances of this class may be used to balance complex dense matrices. Balancing a matrix involves computing a + * diagonal similarity transformation to "balance" the rows and columns of the matrix. This balancing is achieved + * by attempting to scale the entries of the matrix by similarity transformations such that the 1-norms of corresponding + * rows and columns have the similar 1-norms. Rows and columns may also be permuted during balancing if requested. + * + *

    Balancing is often used as a preprocessing step to improve the conditioning of eigenvalue problems. Because the + * balancing transformation is a similarity transformation, the eigenvalues are preserved. Further, when permutations are + * done during balancing it is possible to isolate decoupled eigenvalues. + * + *

    The similarity transformation of a square matrix A into the balanced matrix B can be described as: + *

    + *     B = T-1 A T
    + *       = D-1 P-1 A P D.
    + * Solving for A, balancing may be viewed as the following decomposition: + *
    + *     A = T B T-1
    + *       = P D B D-1 P-1.
    + * Where P is a permutation matrix, and D is a diagonal scaling matrix. + * + *

    When permutations are used during balancing we obtain a specific form. First, + *

    + *             [ T1  X   Y  ]
    + *   P-1 A P = [  0  B1  Z  ]
    + *             [  0  0   T2 ]
    + * Where T1 and T2 are upper triangular matrices whose eigenvalues lie along the diagonal. These are also + * eigenvalues of A. Then, if scaling is applied we obtain: + *
    + *                  [ T1     X*D1       Y    ]
    + *   D-1 P-1 A P D = [  0  D1-1*B*1D1  D1-1*Z  ]
    + *                   [  0      0         T2   ]
    + * Where D1 is a diagonal matrix such that, + *
    + *         [ I1 0  0  ]
    + *     D = [ 0  D1 0  ]
    + *         [ 0  0  I2 ]
    + * Where I1 and I2 are identity matrices with equivalent shapes to T1 and T2. + * + *

    Once balancing has been applied, one need only compute the eigenvalues of B1 and combine them with the diagonal + * entries of T1 and T2 to obtain all eigenvalues of A. + * + * @param The type of matrix being balanced. + * + * @see #getB() + * @see #getBSubMatrix() + * @see #getD(boolean) + * @see #getD() + * @see #getP() + * @see #getT() + */ +public class ComplexBalancer extends Balancer { + + + /** + *

    Constructs a complex balancer which will perform both the permutations and scaling steps out-of-place. + * + *

    To specify if permutations or scaling should be or should not be performed, use {@link #ComplexBalancer(boolean, boolean)}. + * To specify if the balancing should be done in-place, use {@link #ComplexBalancer(boolean, boolean, boolean)}. + */ + public ComplexBalancer() { + super(true, true, false); + } + + + /** + *

    Constructs a complex balancer optionally performing the permutation and scaling steps out-of-place. + * + *

    To specify if the balancing should be done in-place, use {@link #ComplexBalancer(boolean, boolean, boolean)}. + * + * @param doPermutations Flag indicating if the permutation step should be performed during balancing. + *

      + *
    • If {@code true}: the permutation step will be performed.
    • + *
    • If {@code false}: the permutation step will not be performed.
    • + *
    + * @param doScaling Flag indicating if the scaling step should be performed during balancing. + *
      + *
    • If {@code true}: the scaling step will be performed.
    • + *
    • If {@code false}: the scaling step will not be performed.
    • + *
    + */ + public ComplexBalancer(boolean doPermutations, boolean doScaling) { + super(doPermutations, doScaling, false); + } + + + /** + *

    Constructs a complex balancer optionally performing the permutation and scaling steps in/out-of-place. + * + * @param doPermutations Flag indicating if the permutation step should be performed during balancing. + *

      + *
    • If {@code true}: the permutation step will be performed.
    • + *
    • If {@code false}: the permutation step will not be performed.
    • + *
    + * @param doScaling Flag indicating if the scaling step should be performed during balancing. + *
      + *
    • If {@code true}: the scaling step will be performed.
    • + *
    • If {@code false}: the scaling step will not be performed.
    • + *
    + * @param inPlace Flag indicating if the balancing should be done in or out-of-place. + *
      + *
    • If {@code true}: balancing will be done in-place and the source matrix will be overwritten.
    • + *
    • If {@code false}: balancing will be done out-of-place.
    • + *
    + */ + public ComplexBalancer(boolean doPermutations, boolean doScaling, boolean inPlace) { + super(doScaling, doScaling, inPlace); + } + + + /** + * Swaps two rows, over a specified range, within the {@link #balancedMatrix} matrix. + * + * @param rowIdx1 Index of the first row to swap. + * @param rowIdx2 Index of the second row to swap. + * @param start Index of the column specifying the start of the range for the row swap (inclusive). + * @param stop Index of the column specifying the end of the range for the row swap (exclusive). + */ + @Override + protected void swapRows(int rowIdx1, int rowIdx2, int start, int stop) { + DenseOps.swapRowsUnsafe(balancedMatrix.shape, balancedMatrix.data, rowIdx1, rowIdx2, start, stop); + } + + + /** + * Swaps two columns, over a specified range, within the {@link #balancedMatrix} matrix. + * + * @param colIdx1 Index of the first column to swap. + * @param colIdx2 Index of the second column to swap. + * @param start Index of the row specifying the start of the range for the column swap (inclusive). + * @param stop Index of the row specifying the end of the range for the column swap (exclusive). + */ + @Override + protected void swapCols(int colIdx1, int colIdx2, int start, int stop) { + DenseOps.swapColsUnsafe(balancedMatrix.shape, balancedMatrix.data, colIdx1, colIdx2, start, stop); + } + + + /** + * Checks if a value within {@link #balancedMatrix} is zero. + * + * @param idx Index of value within flat data {@link #balancedMatrix} to check if it is zero. + */ + @Override + protected boolean isZero(int idx) { + return balancedMatrix.data[idx].isZero(); + } + + + /** + * Computes the ℓ2 norm of a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. + * + * @param start Starting index within {@link #balancedMatrix}'s 1D data array to compute norm of. + * @param n The number of elements in the vector to compute norm of. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to norm of. + * + * @return The norm of the vector containing the specified elements from {@link #balancedMatrix}'s 1D data array. + */ + @Override + protected double vectorNorm(int start, int n, int stride) { + return VectorNorms.norm(balancedMatrix.data, start, n, stride); + } + + + /** + * Computes the maximum absolute value of a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. + * + * @param start Starting index within {@link #balancedMatrix}'s 1D data array to compute maximum absolute value of. + * @param n The number of elements in the vector to compute maximum absolute value of. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to compute maximum absolute + * value of. + * + * @return The maximum absolute value of the vector containing the specified elements from {@link #balancedMatrix}'s 1D data + * array. + */ + @Override + protected double vectorMaxAbs(int start, int n, int stride) { + return RealProperties.maxAbs(start, n, stride); + } + + + /** + * Scales a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. This operation must be done in-place. + * + * @param start Starting index within {@link #balancedMatrix}'s 1D data array begin scaling. + * @param n The number of elements to scale. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to scale. + */ + @Override + protected void vectorScale(double factor, int start, int n, int stride) { + FieldOps.scalMult(balancedMatrix.data, factor, start, n, stride, balancedMatrix.data); + } +} diff --git a/src/main/java/org/flag4j/linalg/decompositions/balance/RealBalancer.java b/src/main/java/org/flag4j/linalg/decompositions/balance/RealBalancer.java new file mode 100644 index 000000000..b4097a182 --- /dev/null +++ b/src/main/java/org/flag4j/linalg/decompositions/balance/RealBalancer.java @@ -0,0 +1,227 @@ +/* + * MIT License + * + * Copyright (c) 2025. Jacob Watters + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +package org.flag4j.linalg.decompositions.balance; + +import org.flag4j.arrays.dense.Matrix; +import org.flag4j.linalg.VectorNorms; +import org.flag4j.linalg.ops.common.real.RealOps; +import org.flag4j.linalg.ops.common.real.RealProperties; +import org.flag4j.linalg.ops.dense.real.RealDenseOps; + +/** + *

    Instances of this class may be used to balance real dense matrices. Balancing a matrix involves computing a + * diagonal similarity transformation to "balance" the rows and columns of the matrix. This balancing is achieved + * by attempting to scale the entries of the matrix by similarity transformations such that the 1-norms of corresponding + * rows and columns have the similar 1-norms. Rows and columns may also be permuted during balancing if requested. + * + *

    Balancing is often used as a preprocessing step to improve the conditioning of eigenvalue problems. Because the + * balancing transformation is a similarity transformation, the eigenvalues are preserved. Further, when permutations are + * done during balancing it is possible to isolate decoupled eigenvalues. + * + *

    The similarity transformation of a square matrix A into the balanced matrix B can be described as: + *

    + *     B = T-1 A T
    + *       = D-1 P-1 A P D.
    + * Solving for A, balancing may be viewed as the following decomposition: + *
    + *     A = T B T-1
    + *       = P D B D-1 P-1.
    + * Where P is a permutation matrix, and D is a diagonal scaling matrix. + * + *

    When permutations are used during balancing we obtain a specific form. First, + *

    + *             [ T1  X   Y  ]
    + *   P-1 A P = [  0  B1  Z  ]
    + *             [  0  0   T2 ]
    + * Where T1 and T2 are upper triangular matrices whose eigenvalues lie along the diagonal. These are also + * eigenvalues of A. Then, if scaling is applied we obtain: + *
    + *                  [ T1     X*D1       Y    ]
    + *   D-1 P-1 A P D = [  0  D1-1*B*1D1  D1-1*Z  ]
    + *                   [  0      0         T2   ]
    + * Where D1 is a diagonal matrix such that, + *
    + *         [ I1 0  0  ]
    + *     D = [ 0  D1 0  ]
    + *         [ 0  0  I2 ]
    + * Where I1 and I2 are identity matrices with equivalent shapes to T1 and T2. + * + *

    Once balancing has been applied, one need only compute the eigenvalues of B1 and combine them with the diagonal + * entries of T1 and T2 to obtain all eigenvalues of A. + * + * @param The type of matrix being balanced. + * + * @see #getB() + * @see #getBSubMatrix() + * @see #getD(boolean) + * @see #getD() + * @see #getP() + * @see #getT() + */ +public class RealBalancer extends Balancer { + + + /** + *

    Constructs a real balancer which will perform both the permutations and scaling steps out-of-place. + * + *

    To specify if permutations or scaling should be or should not be performed, use {@link #RealBalancer(boolean, boolean)}. + * To specify if the balancing should be done in-place, use {@link #RealBalancer(boolean, boolean, boolean)}. + */ + public RealBalancer() { + super(true, true, false); + } + + + /** + *

    Constructs a real balancer optionally performing the permutation and scaling steps out-of-place. + * + *

    To specify if the balancing should be done in-place, use {@link #RealBalancer(boolean, boolean, boolean)}. + * + * @param doPermutations Flag indicating if the permutation step should be performed during balancing. + *

      + *
    • If {@code true}: the permutation step will be performed.
    • + *
    • If {@code false}: the permutation step will not be performed.
    • + *
    + * @param doScaling Flag indicating if the scaling step should be performed during balancing. + *
      + *
    • If {@code true}: the scaling step will be performed.
    • + *
    • If {@code false}: the scaling step will not be performed.
    • + *
    + */ + public RealBalancer(boolean doPermutations, boolean doScaling) { + super(doPermutations, doScaling, false); + } + + + /** + *

    Constructs a real balancer optionally performing the permutation and scaling steps in/out-of-place. + * + * @param doPermutations Flag indicating if the permutation step should be performed during balancing. + *

      + *
    • If {@code true}: the permutation step will be performed.
    • + *
    • If {@code false}: the permutation step will not be performed.
    • + *
    + * @param doScaling Flag indicating if the scaling step should be performed during balancing. + *
      + *
    • If {@code true}: the scaling step will be performed.
    • + *
    • If {@code false}: the scaling step will not be performed.
    • + *
    + * @param inPlace Flag indicating if the balancing should be done in or out-of-place. + *
      + *
    • If {@code true}: balancing will be done in-place and the source matrix will be overwritten.
    • + *
    • If {@code false}: balancing will be done out-of-place.
    • + *
    + */ + public RealBalancer(boolean doPermutations, boolean doScaling, boolean inPlace) { + super(doScaling, doScaling, inPlace); + } + + + /** + * Swaps two rows, over a specified range, within the {@link #balancedMatrix} matrix. + * + * @param rowIdx1 Index of the first row to swap. + * @param rowIdx2 Index of the second row to swap. + * @param start Index of the column specifying the start of the range for the row swap (inclusive). + * @param stop Index of the column specifying the end of the range for the row swap (exclusive). + */ + @Override + protected void swapRows(int rowIdx1, int rowIdx2, int start, int stop) { + RealDenseOps.swapRowsUnsafe(balancedMatrix.shape, balancedMatrix.data, rowIdx1, rowIdx2, start, stop); + } + + + /** + * Swaps two columns, over a specified range, within the {@link #balancedMatrix} matrix. + * + * @param colIdx1 Index of the first column to swap. + * @param colIdx2 Index of the second column to swap. + * @param start Index of the row specifying the start of the range for the column swap (inclusive). + * @param stop Index of the row specifying the end of the range for the column swap (exclusive). + */ + @Override + protected void swapCols(int colIdx1, int colIdx2, int start, int stop) { + RealDenseOps.swapColsUnsafe(balancedMatrix.shape, balancedMatrix.data, colIdx1, colIdx2, start, stop); + } + + + /** + * Checks if a value within {@link #balancedMatrix} is zero. + * + * @param idx Index of value within flat data {@link #balancedMatrix} to check if it is zero. + */ + @Override + protected boolean isZero(int idx) { + return balancedMatrix.data[idx] == 0.0; + } + + + /** + * Computes the ℓ2 norm of a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. + * + * @param start Starting index within {@link #balancedMatrix}'s 1D data array to compute norm of. + * @param n The number of elements in the vector to compute norm of. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to norm of. + * + * @return The norm of the vector containing the specified elements from {@link #balancedMatrix}'s 1D data array. + */ + @Override + protected double vectorNorm(int start, int n, int stride) { + return VectorNorms.norm(balancedMatrix.data, start, n, stride); + } + + + /** + * Computes the maximum absolute value of a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. + * + * @param start Starting index within {@link #balancedMatrix}'s 1D data array to compute maximum absolute value of. + * @param n The number of elements in the vector to compute maximum absolute value of. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to compute maximum absolute + * value of. + * + * @return The maximum absolute value of the vector containing the specified elements from {@link #balancedMatrix}'s 1D data + * array. + */ + @Override + protected double vectorMaxAbs(int start, int n, int stride) { + return RealProperties.maxAbs(start, n, stride); + } + + + /** + * Scales a vector with {@code n} elements from {@link #balancedMatrix}'s 1D data array + * starting at index {@code start} and spaced by {@code stride}. This operation must be done in-place. + * + * @param start Starting index within {@link #balancedMatrix}'s 1D data array begin scaling. + * @param n The number of elements to scale. + * @param stride The spacing between each element within {@link #balancedMatrix}'s 1D data array to scale. + */ + @Override + protected void vectorScale(double factor, int start, int n, int stride) { + RealOps.scalMult(balancedMatrix.data, factor, start, n, stride, balancedMatrix.data); + } +} diff --git a/src/main/java/org/flag4j/linalg/decompositions/schur/RealSchur.java b/src/main/java/org/flag4j/linalg/decompositions/schur/RealSchur.java index cf3225aca..51c004109 100644 --- a/src/main/java/org/flag4j/linalg/decompositions/schur/RealSchur.java +++ b/src/main/java/org/flag4j/linalg/decompositions/schur/RealSchur.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -142,7 +142,7 @@ public RealSchur setExceptionalThreshold(int exceptionalThreshold) { * *

    By default, this is computed as *

    -     *     {@code maxIterations = }{@link #DEFAULT_MAX_ITERS_FACTOR}{@code * src.numRows;}
    + * {@code maxIterations = DEFAULT_MAX_ITERS_FACTOR * src.numRows;}
    * * where {@code src} is the matrix * being decomposed. diff --git a/src/main/java/org/flag4j/linalg/decompositions/schur/Schur.java b/src/main/java/org/flag4j/linalg/decompositions/schur/Schur.java index 8d4ee23bd..d7cadffe9 100644 --- a/src/main/java/org/flag4j/linalg/decompositions/schur/Schur.java +++ b/src/main/java/org/flag4j/linalg/decompositions/schur/Schur.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -69,6 +69,12 @@ public abstract class Schur, U> implements Dec *Decomposer to compute the Hessenburg decomposition as a setup step for the implicit double step QR algorithm. */ protected UnitaryDecomposition hess; +// /** +// *

    Balancer to scale rows and columns of matrix to be decomposed so that all row and columns have roughly similar sized norms. +// *

    This is done to attempt to improve the condition number and improve numerical stability when computing the Schur +// * decomposition. +// */ +// protected RealMatrixBalancerOld balancer; /** *Stores the number of rows in the matrix being decomposed. */ diff --git a/src/main/java/org/flag4j/linalg/decompositions/svd/RealSVD.java b/src/main/java/org/flag4j/linalg/decompositions/svd/RealSVD.java index 39cdd765e..1fbbc33d7 100644 --- a/src/main/java/org/flag4j/linalg/decompositions/svd/RealSVD.java +++ b/src/main/java/org/flag4j/linalg/decompositions/svd/RealSVD.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,6 +27,7 @@ import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.CMatrix; +import org.flag4j.arrays.dense.CVector; import org.flag4j.arrays.dense.Matrix; import org.flag4j.linalg.DirectSum; import org.flag4j.linalg.Eigen; @@ -104,7 +105,6 @@ protected Matrix invDirectSum(Matrix src) { @Override protected Matrix makeEigenPairs(Matrix B, double[] eigVals) { CMatrix[] pairs = Eigen.getEigenPairs(B); - double[] vals = pairs[0].toReal().data; System.arraycopy(vals, 0, eigVals, 0, eigVals.length); @@ -121,6 +121,7 @@ protected Matrix makeEigenPairs(Matrix B, double[] eigVals) { */ @Override protected void makeEigenVals(Matrix B, double[] eigVals) { + CVector valsTest = Eigen.getEigenValues(B); double[] vals = Eigen.getEigenValues(B).toReal().data; System.arraycopy(vals, 0, eigVals, 0, eigVals.length); } diff --git a/src/main/java/org/flag4j/linalg/decompositions/svd/SVD.java b/src/main/java/org/flag4j/linalg/decompositions/svd/SVD.java index 78fe786c5..daaf82217 100644 --- a/src/main/java/org/flag4j/linalg/decompositions/svd/SVD.java +++ b/src/main/java/org/flag4j/linalg/decompositions/svd/SVD.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ import org.flag4j.arrays.Shape; import org.flag4j.arrays.backend.MatrixMixin; import org.flag4j.arrays.dense.Matrix; +import org.flag4j.arrays.dense.Vector; import org.flag4j.linalg.decompositions.Decomposition; import org.flag4j.util.Flag4jConstants; @@ -105,6 +106,16 @@ public Matrix getS() { } + /** + * Gets the singular values of the last matrix decomposed. + * + * @return The singular values of the last matrix decomposed. + */ + public Vector getSingularValues() { + return S.getDiag(); + } + + /** * Gets the unitary matrix V corresponding to M=USVH in the SVD. * @return V corresponding to M=USVH in the SVD. Note that the hermitian transpose has @@ -152,13 +163,24 @@ public SVD decompose(T src) { if(computeUV) initUV(src.getShape(), stopIdx); // Initialize the U and V matrices. S = new Matrix(stopIdx); // initialize the S matrix. - for(int j=0; j 0) { + S.set(sigma, idx, idx); - if(computeUV && singularVecs != null) { - // Extract left and right singular vectors and normalize. - extractNormalizedCols(singularVecs, j); + if(computeUV && singularVecs != null) { + // Extract left and right singular vectors and normalize. + extractNormalizedCols(singularVecs, idx); + } + + idx++; } + + j++; } return this; diff --git a/src/main/java/org/flag4j/linalg/ops/MatrixMultiplyDispatcher.java b/src/main/java/org/flag4j/linalg/ops/MatrixMultiplyDispatcher.java index 2e99698da..54855e166 100644 --- a/src/main/java/org/flag4j/linalg/ops/MatrixMultiplyDispatcher.java +++ b/src/main/java/org/flag4j/linalg/ops/MatrixMultiplyDispatcher.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2022-2024. Jacob Watters + * Copyright (c) 2022-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ import org.flag4j.algebraic_structures.Field; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.*; -import org.flag4j.linalg.ops.dense.real.RealDenseMatrixMultiplication; +import org.flag4j.linalg.ops.dense.real.RealDenseMatMult; import org.flag4j.linalg.ops.dense.real_field_ops.RealFieldDenseMatMult; import org.flag4j.linalg.ops.dense.real_field_ops.RealFieldDenseMatMultTranspose; import org.flag4j.linalg.ops.dense.semiring_ops.DenseSemiringMatMult; @@ -82,16 +82,16 @@ public static double[] dispatch(Matrix A, Vector b) { switch(algorithm) { case STANDARD_VECTOR: - dest = RealDenseMatrixMultiplication.standardVector(A.data, A.shape, b.data, bMatShape); + dest = RealDenseMatMult.standardVector(A.data, A.shape, b.data, bMatShape); break; case BLOCKED_VECTOR: - dest = RealDenseMatrixMultiplication.blockedVector(A.data, A.shape, b.data, bMatShape); + dest = RealDenseMatMult.blockedVector(A.data, A.shape, b.data, bMatShape); break; case CONCURRENT_STANDARD_VECTOR: - dest = RealDenseMatrixMultiplication.concurrentStandardVector(A.data, A.shape, b.data, bMatShape); + dest = RealDenseMatMult.concurrentStandardVector(A.data, A.shape, b.data, bMatShape); break; default: - dest = RealDenseMatrixMultiplication.concurrentBlockedVector(A.data, A.shape, b.data, bMatShape); + dest = RealDenseMatMult.concurrentBlockedVector(A.data, A.shape, b.data, bMatShape); break; } diff --git a/src/main/java/org/flag4j/linalg/ops/RealDenseMatrixMultiplyDispatcher.java b/src/main/java/org/flag4j/linalg/ops/RealDenseMatrixMultiplyDispatcher.java index acfe850c3..9e6e1ac6f 100644 --- a/src/main/java/org/flag4j/linalg/ops/RealDenseMatrixMultiplyDispatcher.java +++ b/src/main/java/org/flag4j/linalg/ops/RealDenseMatrixMultiplyDispatcher.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2023-2024. Jacob Watters + * Copyright (c) 2023-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,8 +26,8 @@ import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; -import org.flag4j.linalg.ops.dense.real.RealDenseMatrixMultTranspose; -import org.flag4j.linalg.ops.dense.real.RealDenseMatrixMultiplication; +import org.flag4j.linalg.ops.dense.real.RealDenseMatMult; +import org.flag4j.linalg.ops.dense.real.RealDenseMatMultTranspose; import org.flag4j.util.ValidateParameters; import java.util.HashMap; @@ -73,22 +73,22 @@ private RealDenseMatrixMultiplyDispatcher() { algorithmMap = new HashMap<>(); RealDenseTensorBinaryOperation[] algorithms = { - RealDenseMatrixMultiplication::standard, - RealDenseMatrixMultiplication::reordered, - RealDenseMatrixMultiplication::blocked, - RealDenseMatrixMultiplication::blockedReordered, - RealDenseMatrixMultiplication::concurrentStandard, - RealDenseMatrixMultiplication::concurrentReordered, - RealDenseMatrixMultiplication::concurrentBlocked, - RealDenseMatrixMultiplication::concurrentBlockedReordered, - RealDenseMatrixMultiplication::standardVector, - RealDenseMatrixMultiplication::blockedVector, - RealDenseMatrixMultiplication::concurrentStandardVector, - RealDenseMatrixMultiplication::concurrentBlockedVector, - RealDenseMatrixMultTranspose::multTranspose, - RealDenseMatrixMultTranspose::multTransposeBlocked, - RealDenseMatrixMultTranspose::multTransposeConcurrent, - RealDenseMatrixMultTranspose::multTransposeBlockedConcurrent, + RealDenseMatMult::standard, + RealDenseMatMult::reordered, + RealDenseMatMult::blocked, + RealDenseMatMult::blockedReordered, + RealDenseMatMult::concurrentStandard, + RealDenseMatMult::concurrentReordered, + RealDenseMatMult::concurrentBlocked, + RealDenseMatMult::concurrentBlockedReordered, + RealDenseMatMult::standardVector, + RealDenseMatMult::blockedVector, + RealDenseMatMult::concurrentStandardVector, + RealDenseMatMult::concurrentBlockedVector, + RealDenseMatMultTranspose::multTranspose, + RealDenseMatMultTranspose::multTransposeBlocked, + RealDenseMatMultTranspose::multTransposeConcurrent, + RealDenseMatMultTranspose::multTransposeBlockedConcurrent, }; for(int i = 0; i< algorithms.length; i++) { diff --git a/src/main/java/org/flag4j/linalg/ops/TransposeDispatcher.java b/src/main/java/org/flag4j/linalg/ops/TransposeDispatcher.java index e7623d226..24ad2017a 100644 --- a/src/main/java/org/flag4j/linalg/ops/TransposeDispatcher.java +++ b/src/main/java/org/flag4j/linalg/ops/TransposeDispatcher.java @@ -24,17 +24,13 @@ package org.flag4j.linalg.ops; -import org.flag4j.algebraic_structures.Field; import org.flag4j.algebraic_structures.Ring; import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.backend.field_arrays.AbstractDenseFieldMatrix; import org.flag4j.arrays.backend.primitive_arrays.AbstractDoubleTensor; import org.flag4j.arrays.backend.semiring_arrays.AbstractDenseSemiringTensor; import org.flag4j.arrays.dense.Matrix; import org.flag4j.linalg.ops.dense.DenseTranspose; -import org.flag4j.linalg.ops.dense.field_ops.DenseFieldHermitianTranspose; -import org.flag4j.linalg.ops.dense.field_ops.DenseFieldTranspose; import org.flag4j.linalg.ops.dense.real.RealDenseTranspose; import org.flag4j.linalg.ops.dense.ring_ops.DenseRingHermitianTranspose; @@ -134,36 +130,6 @@ public static double[] dispatch(double[] src, Shape shape) { } - /** - * Dispatches a matrix transpose problem to the appropriate algorithm based on its shape and size. - * @param src Matrix to transpose. - * @return The transpose of the source matrix. - */ - public static > AbstractDenseFieldMatrix dispatch(AbstractDenseFieldMatrix src) { - - T[] dest; - - TransposeAlgorithms algorithm = chooseAlgorithmComplex(src.shape); // TODO: Need an updated method for this. Or at least a name change. - - switch(algorithm) { - case STANDARD: - dest = DenseFieldTranspose.standardMatrix(src.data, src.numRows, src.numCols); - break; - case BLOCKED: - dest = DenseFieldTranspose.blockedMatrix(src.data, src.numRows, src.numCols); - break; - case CONCURRENT_STANDARD: - dest = DenseFieldTranspose.standardMatrixConcurrent(src.data, src.numRows, src.numCols); - break; - default: - dest = DenseFieldTranspose.blockedMatrixConcurrent(src.data, src.numRows, src.numCols); - break; - } - - return src.makeLikeTensor(new Shape(src.numCols, src.numRows), dest); - } - - /** * Dispatches a matrix transpose problem to the appropriate algorithm based on its shape and size. * @param src Matrix to transpose. @@ -211,13 +177,13 @@ public static Object[] dispatch(Object[] src, Shape shape, Object[] dest) { * @return If {@code dest != null} a reference to the {@code dest} array will be returned. Otherwise, if {@code dest == null} * then a new array will be created and returned. */ - public static > void dispatchHermitian(T[] src, Shape shape, T[] dest) { + public static > void dispatchHermitian(T[] src, Shape shape, T[] dest) { TransposeAlgorithms algorithm = chooseAlgorithmHermitian(shape); if(algorithm == TransposeAlgorithms.BLOCKED) - DenseFieldHermitianTranspose.blockedMatrixHerm(src, shape.get(0), shape.get(1), dest); + DenseRingHermitianTranspose.blockedMatrixHerm(src, shape.get(0), shape.get(1), dest); else - DenseFieldHermitianTranspose.blockedMatrixConcurrentHerm(src, shape.get(0), shape.get(1), dest); + DenseRingHermitianTranspose.blockedMatrixConcurrentHerm(src, shape.get(0), shape.get(1), dest); } @@ -251,24 +217,6 @@ public static > T dispatchTensor( } - /** - * Dispatches a tensor transpose problem to the appropriate algorithm based on its shape and size. - * @param src Tensor to transpose. - * @param axes Permutation of axes in the tensor transpose. - * @return The result of the tensor transpose. - * @throws ArrayIndexOutOfBoundsException If either axis is not within the {@code src} tensor. - */ - public static > T dispatchTensor(T src, int[] axes) { - TransposeAlgorithms algorithm = chooseAlgorithmTensor(src.data.length); - - double[] dest = algorithm == TransposeAlgorithms.STANDARD ? - RealDenseTranspose.standard(src.data, src.shape, axes): - RealDenseTranspose.standardConcurrent(src.data, src.shape, axes); - - return src.makeLikeTensor(src.shape.permuteAxes(axes), dest); - } - - /** * Dispatches a tensor transpose problem to the appropriate algorithm based on its shape and size. * @param src Tensor to transpose. diff --git a/src/main/java/org/flag4j/linalg/ops/common/complex/Complex128Ops.java b/src/main/java/org/flag4j/linalg/ops/common/complex/Complex128Ops.java index 664b4ce12..1eb70ed05 100644 --- a/src/main/java/org/flag4j/linalg/ops/common/complex/Complex128Ops.java +++ b/src/main/java/org/flag4j/linalg/ops/common/complex/Complex128Ops.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2023-2024. Jacob Watters + * Copyright (c) 2023-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -29,19 +29,18 @@ import org.flag4j.util.ErrorMessages; /** - * This class provides low level methods for computing ops on complex tensors. These methods can be applied to + * This class provides low level methods for computing operations on complex tensors. These methods can be applied to * either sparse or dense complex tensors. */ public final class Complex128Ops { private Complex128Ops() { // Hide constructor for utility class. for utility class. - } /** - * Computes the element-wise square root of a tensor as complex values. This allows for the square root of a negative number. + * Computes the element-wise square root of a tensor as complex values. This allows for the square root of negative numbers. * @param src Elements of the tensor. * @return The element-wise square root of the tensor. */ diff --git a/src/main/java/org/flag4j/linalg/ops/common/complex/Complex64Ops.java b/src/main/java/org/flag4j/linalg/ops/common/complex/Complex64Ops.java index 3da385456..c611d76fa 100644 --- a/src/main/java/org/flag4j/linalg/ops/common/complex/Complex64Ops.java +++ b/src/main/java/org/flag4j/linalg/ops/common/complex/Complex64Ops.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,7 +37,6 @@ public final class Complex64Ops { private Complex64Ops() { // Hide constructor for utility class. for utility class. - } diff --git a/src/main/java/org/flag4j/linalg/ops/common/field_ops/FieldOps.java b/src/main/java/org/flag4j/linalg/ops/common/field_ops/FieldOps.java index 95fb862b8..edba1f403 100644 --- a/src/main/java/org/flag4j/linalg/ops/common/field_ops/FieldOps.java +++ b/src/main/java/org/flag4j/linalg/ops/common/field_ops/FieldOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -210,6 +210,45 @@ public static > void scalMult(double[] entries, V factor, V[] } + /** + *

    Scales entries by the specified {@code factor} within {@code src} starting at index {@code start} + * and scaling a total of {@code n} elements spaced by {@code stride}. + * + *

    More formally, this method scales elements by the specified {@code factor} at indices: + * {@code start}, {@code start + stride}, {@code start + 2*stride}, ..., {@code start + (n-1)*stride}. + * + *

    This method may be used to scale a row or column of a + * {@link org.flag4j.arrays.dense.Matrix matrix} {@code a} as follows: + *

      + *
    • Maximum absolute value within row {@code i}: + *
      {@code scale(a.data, i*a.numCols, a.numCols, 1, dest);}
    • + *
    • Maximum absolute value within column {@code j}: + *
      {@code scale(a.data, j, a.numRows, a.numRows, dest);}
    • + *
    + * + * @param src The array containing values to scale. + * @param factor Factor by which to scale elements. + * @param start The starting index in {@code src} to begin scaling. + * @param n The number of elements to scale within {@code src1}. + * @param stride The gap (in indices) between consecutive elements to scale within {@code src}. + * @param dest The array to store the result in. May be {@code null} or the same array as {@code src} to perform the operation + * in-place. Assumed to be at least as large as {@code src} but this is not explicitly enforced. + * + * @return If {@code dest == null} a new array containing all elements of {@code src} with the appropriate values scaled. + * Otherwise, A reference to the {@code dest} array. + */ + public static > T[] scalMult( + T[] src, double factor, int start, int n, int stride, T[] dest) { + if(dest==null) dest = src.clone(); + int stop = start + n*stride; + + for(int i=start; i> boolean allClose(T[] src1, T[] src2) { - return allClose(src1, src2, 1e-05, 1e-08); - } - - - /** - * Checks if all data of two arrays are 'close'. - * @param src1 First array in comparison. - * @param src2 Second array in comparison. - * @return True if both arrays have the same length and all data are 'close' element-wise, i.e. - * elements {@code a} and {@code b} at the same positions in the two arrays respectively and satisfy - * {@code |a-b| <= (absTol + relTol*|b|)}. Otherwise, returns false. - * @see #allClose(Field[], Field[]) - */ - public static > boolean allClose(T[] src1, T[] src2, double relTol, double absTol) { - if (src1.length != src2.length) return false; - - for(int i=0; i tol) return false; - } - - return true; // If we reach this point, all data must be close. - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/common/real/RealOps.java b/src/main/java/org/flag4j/linalg/ops/common/real/RealOps.java index e0697d833..d5e569235 100644 --- a/src/main/java/org/flag4j/linalg/ops/common/real/RealOps.java +++ b/src/main/java/org/flag4j/linalg/ops/common/real/RealOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2022-2024. Jacob Watters + * Copyright (c) 2022-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -85,20 +85,58 @@ public static double[] scalMult(double[] src, double factor, int start, int stop } + /** + *

    Scales entries by the specified {@code factor} within {@code src} starting at index {@code start} + * and scaling a total of {@code n} elements spaced by {@code stride}. + * + *

    More formally, this method scales elements by the specified {@code factor} at indices: + * {@code start}, {@code start + stride}, {@code start + 2*stride}, ..., {@code start + (n-1)*stride}. + * + *

    This method may be used to scale a row or column of a + * {@link org.flag4j.arrays.dense.Matrix matrix} {@code a} as follows: + *

      + *
    • Maximum absolute value within row {@code i}: + *
      {@code scale(a.data, i*a.numCols, a.numCols, 1, dest);}
    • + *
    • Maximum absolute value within column {@code j}: + *
      {@code scale(a.data, j, a.numRows, a.numRows, dest);}
    • + *
    + * + * @param src The array containing values to scale. + * @param factor Factor by which to scale elements. + * @param start The starting index in {@code src} to begin scaling. + * @param n The number of elements to scale within {@code src1}. + * @param stride The gap (in indices) between consecutive elements to scale within {@code src}. + * @param dest The array to store the result in. May be {@code null} or the same array as {@code src} to perform the operation + * in-place. Assumed to be at least as large as {@code src} but this is not explicitly enforced. + * + * @return If {@code dest == null} a new array containing all elements of {@code src} with the appropriate values scaled. + * Otherwise, A reference to the {@code dest} array. + */ + public static double[] scalMult(double[] src, double factor, int start, int n, int stride, double[] dest) { + if(dest==null) dest = src.clone(); + int stop = start + n*stride; + + for(int i=start; iall of the elements of a tensor are finite. + * Checks if all elements of a tensor are finite. * @param src Entries of the tensor. * @return {@code false} is any entry of {@code src} is not {@link Double#isFinite(double) finite}. Otherwise, returns {@code * true}. @@ -262,11 +262,14 @@ public static double maxAbs(double... entries) { */ public static int argmin(double... entries) { double currMin = (entries.length==0) ? 0 : Double.MAX_VALUE; + double curr; int mindex = -1; for(int i=0, size=entries.length; i currMax) { - currMax = entries[i]; + curr = Math.max(entries[i], currMax); + + if (curr != currMax) { + currMax = curr; maxdex = i; } } @@ -303,11 +309,15 @@ public static int argmax(double... entries) { */ public static int argminAbs(double... entries) { double currMin = (entries.length==0) ? 0 : Double.MAX_VALUE; + double curr; int mindex = -1; for(int i=0, size=entries.length; i currMax) { - currMax = entries[i]; + curr = Math.abs(entries[i]); + curr = Math.max(curr, currMax); + + if (curr != currMax) { + currMax = curr; maxdex = i; } } return maxdex; } + + + /** + *

    Returns the maximum absolute value among {@code n} elements in the array {@code src}, + * starting at index {@code start} and advancing by {@code stride} for each subsequent element. + * + *

    More formally, this method examines the elements at indices: + * {@code start}, {@code start + stride}, {@code start + 2*stride}, ..., {@code start + (n-1)*stride}. + * + *

    This method will propagate {@link Double#NaN} values meaning if at least one element considered is {@link Double#NaN} + * the result of this method will be {@link Double#NaN}. + * + *

    This method may be used to find the maximum absolute value within the row or column of a + * {@link org.flag4j.arrays.dense.Matrix matrix} {@code a} as follows: + *

      + *
    • Maximum absolute value within row {@code i}: + *
      {@code maxAbs(a.data, factor, i*a.numCols, a.numCols, 1);}
    • + *
    • Maximum absolute value within column {@code j}: + *
      {@code maxAbs(a.data, factor, j, a.numRows, a.numRows);}
    • + *
    + * + * @param src The array to search for maximum absolute value within. + * @param start The starting index in {@code src} to search. + * @param n The number of elements to consider within {@code src1}. + * @param stride The gap (in indices) between consecutive elements to search within {@code src}. + * @return + *
      + *
    • If any element of {@code src} is {@link Double#NaN} then the result will be {@link Double#NaN}.
    • + *
    • Otherwise, the maximum absolute value found among all elements considered in {@code src}.
    • + *
    + * + * @throws IndexOutOfBoundsException If the specified range extends beyond the array bounds. + */ + public static double maxAbs(double[] src, final int start, final int n, final int stride) { + double currMax = 0; + final int end = start + n*stride; + + for(int i=start; iReturns the minimum absolute value among {@code n} elements in the array {@code src}, + * starting at index {@code start} and advancing by {@code stride} for each subsequent element. + * + *

    More formally, this method examines the elements at indices: + * {@code start}, {@code start + stride}, {@code start + 2*stride}, ..., {@code start + (n-1)*stride}. + * + *

    This method will propagate {@link Double#NaN} values meaning if at least one element considered is {@link Double#NaN} + * the result of this method will be {@link Double#NaN}. + * + *

    This method may be used to find the minimum absolute value within the row or column of a + * {@link org.flag4j.arrays.dense.Matrix matrix} {@code a} as follows: + *

      + *
    • Minimum absolute value within row {@code i}: + *
      {@code maxAbs(a.data, i*a.numCols, a.numCols, 1);}
    • + *
    • Minimum absolute value within column {@code j}: + *
      {@code maxAbs(a.data, j, a.numRows, a.numRows);}
    • + *
    + * + * @param src The array to search for Minimum absolute value within. + * @param start The starting index in {@code src} to search. + * @param n The number of elements to consider within {@code src1}. + * @param stride The gap (in indices) between consecutive elements to search within {@code src}. + * @return + *
      + *
    • If {@code src.length == 0} then {@link Double#POSITIVE_INFINITY} will be returned.
    • + *
    • If any element of {@code src} is {@link Double#NaN} then the result will be {@link Double#NaN}.
    • + *
    • Otherwise, the minimum absolute value found among all elements considered inn{@code src}.
    • + *
    + * + * @throws IndexOutOfBoundsException If the specified range extends beyond the array bounds. + */ + public static double minAbs(double[] src, final int start, final int n, final int stride) { + double currMin = Double.POSITIVE_INFINITY; + final int end = start + n*stride; + + for(int i=start; i> int argminAbs(T... values) { return mindex; } + + + /** + *

    Returns the maximum absolute value among {@code n} elements in the array {@code src}, + * starting at index {@code start} and advancing by {@code stride} for each subsequent element. + * + *

    More formally, this method examines the elements at indices: + * {@code start}, {@code start + stride}, {@code start + 2*stride}, ..., {@code start + (n-1)*stride}. + * + *

    This method may be used to find the maximum absolute value within the row or column of a + * {@link org.flag4j.arrays.dense.RingMatrix matrix} {@code a} as follows: + *

      + *
    • Maximum absolute value within row {@code i}: + *
      {@code maxAbs(a.data, i*a.numCols, a.numCols, 1);}
    • + *
    • Maximum absolute value within column {@code j}: + *
      {@code maxAbs(a.data, j, a.numRows, a.numRows);}
    • + *
    + * + * @param src The array to search for maximum absolute value within. + * @param start The starting index in {@code src} to search. + * @param n The number of elements to consider within {@code src1}. + * @param stride The gap (in indices) between consecutive elements to search within {@code src}. + * @return The maximum absolute value found among all elements considered in {@code src}. + * + * + * @throws IndexOutOfBoundsException If the specified range extends beyond the array bounds. + */ + public static > double maxAbs(T[] src, final int start, final int n, final int stride) { + double currMax = 0; + final int end = start + n*stride; + + for(int i=start; iReturns the minimum absolute value among {@code n} elements in the array {@code src}, + * starting at index {@code start} and advancing by {@code stride} for each subsequent element. + * + *

    More formally, this method examines the elements at indices: + * {@code start}, {@code start + stride}, {@code start + 2*stride}, ..., {@code start + (n-1)*stride}. + * + *

    This method may be used to find the minimum absolute value within the row or column of a + * {@link org.flag4j.arrays.dense.RingMatrix matrix} {@code a} as follows: + *

      + *
    • Minimum absolute value within row {@code i}: + *
      {@code maxAbs(a.data, i*a.numCols, a.numCols, 1);}
    • + *
    • Minimum absolute value within column {@code j}: + *
      {@code maxAbs(a.data, j, a.numRows, a.numRows);}
    • + *
    + * + * @param src The array to search for Minimum absolute value within. + * @param start The starting index in {@code src} to search. + * @param n The number of elements to consider within {@code src1}. + * @param stride The gap (in indices) between consecutive elements to search within {@code src}. + * @return + *
      + *
    • If {@code src.length == 0} then {@link Double#POSITIVE_INFINITY} will be returned.
    • + *
    • Otherwise, the minimum absolute value found among all elements considered inn{@code src}.
    • + *
    + * + * @throws IndexOutOfBoundsException If the specified range extends beyond the array bounds. + */ + public static > double minAbs(T[] src, final int start, final int n, final int stride) { + double currMin = Double.POSITIVE_INFINITY; + final int end = start + n*stride; + + for(int i=start; i> boolean allClose(T[] src1, T[] src2) { diff --git a/src/main/java/org/flag4j/linalg/ops/dense/DenseOps.java b/src/main/java/org/flag4j/linalg/ops/dense/DenseOps.java new file mode 100644 index 000000000..5086657e8 --- /dev/null +++ b/src/main/java/org/flag4j/linalg/ops/dense/DenseOps.java @@ -0,0 +1,141 @@ +/* + * MIT License + * + * Copyright (c) 2025. Jacob Watters + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +package org.flag4j.linalg.ops.dense; + +import org.flag4j.arrays.Shape; +import org.flag4j.util.ValidateParameters; + +public final class DenseOps { + + private DenseOps() { + // Hide default constructor for utility class. + } + + + /** + * Swaps specified rows in the matrix. This is done in place. + * + * @param shape Shape of the matrix. + * @param data Data of the matrix (modified). + * @param rowIndex1 Index of the first row to swap. + * @param rowIndex2 Index of the second row to swap. + * + * @throws org.flag4j.util.exceptions.LinearAlgebraException If {@code shape.getRank() != 2}. + * @throws IndexOutOfBoundsException If either index is outside the matrix bounds. + */ + public static void swapRows(Shape shape, T[] data, int rowIdx1, int rowIdx2) { + ValidateParameters.ensureRank(shape, 2); + int numRows = shape.get(0); + int numCols = shape.get(1); + ValidateParameters.ensureValidArrayIndices(numRows, rowIdx1, rowIdx2); + + swapRowsUnsafe(shape, data, rowIdx1, rowIdx2, 0, numCols); + } + + + /** + *

    Swaps two rows, over a specified range of columns, within a matrix. Specifically, all elements in the matrix within rows + * {@code rowIdx1} + * and {@code rowIdx2} and between columns {@code start} (inclusive) and {@code stop} (exclusive). + * This operation is done in place. + * + *

    No bounds checking is done within this method to ensure that the indices provided are valid. As such, it is + * highly recommended to us {@link #swapRows(Shape, Object[], int, int)} in most cases. + * + * @param shape Shape of the matrix. + * @param data Data of the matrix (modified). + * @param rowIdx1 Index of the first row to swap. + * @param rowIdx2 Index of the second row to swap. + * @param start Index of the column specifying the start of the range for the row swap (inclusive). + * @param stop Index of the column specifying the end of the range for the row swap (exclusive). + */ + public static void swapRowsUnsafe(Shape shape, T[] data, int rowIdx1, int rowIdx2, int start, int stop) { + // Quick return when indices are equal. + if(rowIdx1 == rowIdx2) return; + + final int cols = shape.get(1); + final int rowOffset1 = rowIdx1*cols; + final int rowOffset2 = rowIdx2*cols; + T temp; + + for(int j=start; j void swapCols(Shape shape, T[] data, int colIdx1, int colIdx2) { + ValidateParameters.ensureRank(shape, 2); + int numRows = shape.get(0); + int numCols = shape.get(1); + ValidateParameters.ensureValidArrayIndices(numCols, colIdx1, colIdx2); + + swapColsUnsafe(shape, data, colIdx1, colIdx2, 0, numRows); + } + + + /** + *

    Swaps two columns, over a specified range of rows, within a matrix. Specifically, all elements in the matrix within columns + * {@code colIdx1} and {@code colIdx2} and between rows {@code start} (inclusive) and {@code stop} (exclusive). This operation + * is done in place. + * + *

    No bounds checking is done within this method to ensure that the indices provided are valid. As such, it is + * highly recommended to us {@link #swapCols(Shape, Object[], int, int)} in most cases. + * + * @param shape Shape of the matrix. + * @param data Data of the matrix (modified). + * @param colIdx1 Index of the first column to swap. + * @param colIdx2 Index of the second column to swap. + * @param start Index of the row specifying the start of the range for the row swap (inclusive). + * @param stop Index of the row specifying the end of the range for the row swap (exclusive). + */ + public static void swapColsUnsafe(Shape shape, T[] data, int colIdx1, int colIdx2, int start, int stop) { + if(colIdx1 == colIdx2) return; + + final int cols = shape.get(1); + int rowOffset = start*cols; + T temp; + + for(int i=start; i> void standardHerm(T[] src, - Shape shape, - int[] axes, - T[] dest) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - Shape destShape = shape.permuteAxes(axes); - - for(int i=0, size=src.length; i> void standardHerm(T[] src, - Shape shape, - int axis1, int axis2, - T[] dest) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - Shape destShape = shape.swapAxes(axis1, axis2); - int[] destIndices; - - for(int i=0, size=src.length; i> void standardConcurrentHerm(T[] src, - Shape shape, - int axis1, int axis2, - T[] dest) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + " tensor."); - } - - Shape destShape = shape.swapAxes(axis1, axis2); - - // Compute transpose concurrently - ThreadManager.concurrentOperation(src.length, (startIdx, endIdx) -> { - for(int i=startIdx; i> T[] standardConcurrentHerm(T[] src, - Shape shape, - int[] axes, - T[] dest) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - Shape destShape = shape.permuteAxes(axes); - - ThreadManager.concurrentOperation(src.length, (startIdx, endIdx) -> { - for(int i=startIdx; i> void standardMatrixHerm(T[] src, int numRows, int numCols, T[] dest) { - int destIndex, srcIndex, end; - - for (int i=0; i> void blockedMatrixHerm(T[] src, int numRows, int numCols, T[] dest) { - int blockSize = Configurations.getBlockSize(); - int blockRowEnd; - int blockColEnd; - int srcIndex, destIndex, end; - - for(int i=0; i> void standardMatrixConcurrentHerm(T[] src, - int numRows, - int numCols, - T[] dest) { - // Compute transpose concurrently. - ThreadManager.concurrentOperation(numCols, (startIdx, endIdx) -> { - for(int i=startIdx; i> void blockedMatrixConcurrentHerm(T[] src, - int numRows, - int numCols, - T[] dest) { - int blockSize = Configurations.getBlockSize(); - - // Compute transpose concurrently. - ThreadManager.concurrentBlockedOperation(numCols, blockSize, (startIdx, endIdx) -> { - for(int i=startIdx; i> boolean isHermitian(Shape shape, T[] src) { - if(shape.get(0)!=shape.get(1)) return false; - - int numCols = shape.get(1); - - for(int i=0, rows=shape.get(0); i> boolean isAntiHermitian(T[] src, Shape shape) { - if(shape.get(0)!=shape.get(1)) return false; - - int numCols = shape.get(1); - - for(int i=0, numRows=shape.get(0); i> boolean isCloseToIdentity(Shape shape, T[] src) { - int numRows = shape.get(0); - int numCols = shape.get(1); - - if(src == null || numRows!=numCols) return false; - if(src.length == 0) return true; - - // Tolerances corresponds to the allClose(...) methods. - double diagTol = 1.001E-5; - double nonDiagTol = 1e-08; - final T ONE = src[0].getOne(); - int rows = numRows; - int cols = numCols; - int pos = 0; - - for(int i=0; i diagTol) - || (i!=j && src[pos].mag() > nonDiagTol)) { - return false; - } - - pos++; - } - } - - return true; // If we make it here, the matrix is "close" to the identity matrix. - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/dense/field_ops/DenseFieldSetOps.java b/src/main/java/org/flag4j/linalg/ops/dense/field_ops/DenseFieldSetOps.java deleted file mode 100644 index 2f7aae130..000000000 --- a/src/main/java/org/flag4j/linalg/ops/dense/field_ops/DenseFieldSetOps.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.dense.field_ops; - - -import org.flag4j.algebraic_structures.Field; -import org.flag4j.arrays.Shape; -import org.flag4j.util.ErrorMessages; -import org.flag4j.util.ValidateParameters; - -/** - * This class contains low-level implementations of setting ops for dense field tensors. - */ -public final class DenseFieldSetOps { - - private DenseFieldSetOps() { - // Hide constructor for utility class.. - throw new IllegalArgumentException(ErrorMessages.getUtilityClassErrMsg(getClass())); - } - - - /** - * Sets the value of this matrix using a 2D array. - * - * @param src New values of the matrix. - * @param dest Destination array for values. - * @throws IllegalArgumentException If the source and destination arrays have different number of total data. - */ - public static > void setValues(T[] src, final T[] dest) { - ValidateParameters.ensureArrayLengthsEq(src.length, dest.length); - System.arraycopy(src, 0, dest, 0, src.length); - } - - - /** - * Sets the value of this matrix using a 2D array. - * - * @param src New values of the matrix. - * @param dest Destination array for values. - * @throws IllegalArgumentException If the source and destination arrays have different number of total data. - */ - public static > void setValues(T[][] src, final T[] dest) { - ValidateParameters.ensureTotalEntriesEq(src, dest); - int count = 0; - final int cols = src[0].length; - - for(T[] vals : src) { - for(int j = 0; j < cols; j++) - dest[count++] = vals[j]; - } - } - - - /** - * Sets an element of a tensor to the specified value. - * @param src Elements of the tensor. This will be modified. - * @param shape Shape of the tensor. - * @param value Value to set specified index to. - * @param indices Indices of tensor value to be set. - */ - public static > void set(T[] src, Shape shape, T value, int... indices) { - src[shape.getFlatIndex(indices)] = value; - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/dense/field_ops/DenseFieldTranspose.java b/src/main/java/org/flag4j/linalg/ops/dense/field_ops/DenseFieldTranspose.java deleted file mode 100644 index c868b2b12..000000000 --- a/src/main/java/org/flag4j/linalg/ops/dense/field_ops/DenseFieldTranspose.java +++ /dev/null @@ -1,522 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.dense.field_ops; - -import org.flag4j.algebraic_structures.Field; -import org.flag4j.arrays.Shape; -import org.flag4j.arrays.dense.FieldTensor; -import org.flag4j.concurrency.Configurations; -import org.flag4j.concurrency.ThreadManager; -import org.flag4j.util.ArrayUtils; - -/** - * Utility class for computing the transpose of a dense {@link FieldTensor field tensor}. - */ -public final class DenseFieldTranspose { - - private DenseFieldTranspose() { - // Hide default constructor for utility class. - } - - - /** - * Transposes tensor along specified axes using a standard transpose algorithm. In this context, transposing a - * tensor is equivalent to swapping a pair of axes. - * @param src Entries of the tensor. - * @param shape Shape of the tensor to transpose. - * @param axis1 First axis to swap in transpose. - * @param axis2 Second axis to swap in transpose. - * @return The transpose of the tensor along the specified axes. - */ - public static > T[] standard(T[] src, Shape shape, int axis1, int axis2) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - T[] dest = (T[]) new Field[shape.totalEntriesIntValueExact()]; - Shape destShape = shape.swapAxes(axis1, axis2); - int[] destIndices; - - for(int i=0; i> T[] standardConcurrent(T[] src, Shape shape, int axis1, int axis2) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " - + shape.getRank() + " tensor."); - } - - T[] dest = (T[]) new Field[shape.totalEntriesIntValueExact()]; - Shape destShape = shape.swapAxes(axis1, axis2); - - // Compute transpose concurrently - ThreadManager.concurrentOperation(src.length, (startIdx, endIdx) -> { - for(int i=startIdx; i> T[] standard(T[] src, Shape shape, int[] axes) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - T[] dest = (T[]) new Field[shape.totalEntries().intValue()]; - Shape destShape = shape.permuteAxes(axes); - int[] destIndices; - - for(int i=0; i> T[] standardConcurrent(T[] src, Shape shape, int[] axes) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - T[] dest = (T[]) new Field[shape.totalEntries().intValue()]; - Shape destShape = shape.permuteAxes(axes); - - // Compute transpose concurrently. - ThreadManager.concurrentOperation(src.length, (startIdx, endIdx) -> { - for(int i=startIdx; i> T[] standardMatrix(T[] src, int numRows, int numCols) { - T[] dest = (T[]) new Field[numRows*numCols]; - - int destIndex, srcIndex, end; - - for (int i=0; i> T[] blockedMatrix(T[] src, int numRows, int numCols) { - T[] dest = (T[]) new Field[numRows*numCols]; - int blockSize = Configurations.getBlockSize(); - int blockRowEnd; - int blockColEnd; - int srcIndex, destIndex, end; - - for(int i=0; i> T[] standardMatrixConcurrent(T[] src, int numRows, int numCols) { - T[] dest = (T[]) new Field[src.length]; - - // Compute transpose concurrently. - ThreadManager.concurrentOperation(numCols, (startIdx, endIdx) -> { - for(int i=startIdx; i> T[] blockedMatrixConcurrent(T[] src, int numRows, int numCols) { - T[] dest = (T[]) new Field[src.length]; - int blockSize = Configurations.getBlockSize(); - - // Compute transpose concurrently. - ThreadManager.concurrentBlockedOperation(numCols, blockSize, (startIdx, endIdx) -> { - for(int i=startIdx; i> T[] standardHerm(T[] src, Shape shape, int axis1, int axis2) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - T[] dest = (T[]) new Field[shape.totalEntries().intValue()]; - Shape destShape = shape.swapAxes(axis1, axis2); - int[] destIndices; - - for(int i=0; i> Field[] standardConcurrentHerm(T[] src, Shape shape, int axis1, int axis2) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - T[] dest = (T[]) new Field[shape.totalEntries().intValue()]; - Shape destShape = shape.swapAxes(axis1, axis2); - - // Compute transpose concurrently - ThreadManager.concurrentOperation(src.length, (startIdx, endIdx) -> { - for(int i=startIdx; i> T[] standardConcurrentHerm(T[] src, Shape shape, int[] axes) { - if(shape.getRank() < 2) { // Can't transpose tensor with less than 2 axes. - throw new IllegalArgumentException("Tensor transpose not defined for rank " + shape.getRank() + - " tensor."); - } - - T[] dest = (T[]) new Field[shape.totalEntries().intValue()]; - Shape destShape = shape.permuteAxes(axes); - - ThreadManager.concurrentOperation(src.length, (startIdx, endIdx) -> { - for(int i=startIdx; i> T[] standardMatrixHerm(T[] src, int numRows, int numCols) { - T[] dest = (T[]) new Field[numRows*numCols]; - - int destIndex, srcIndex, end; - - for (int i=0; i> T[] blockedMatrixHerm(T[] src, int numRows, int numCols) { - T[] dest = (T[]) new Field[numRows*numCols]; - int blockSize = Configurations.getBlockSize(); - int blockRowEnd; - int blockColEnd; - int srcIndex, destIndex, end; - - for(int i=0; i> T[] standardMatrixConcurrentHerm(T[] src, int numRows, int numCols) { - T[] dest = (T[]) new Field[src.length]; - - // Compute transpose concurrently. - ThreadManager.concurrentOperation(numCols, (startIdx, endIdx) -> { - for(int i=startIdx; i> T[] blockedMatrixConcurrentHerm(T[] src, int numRows, int numCols) { - T[] dest = (T[]) new Field[src.length]; - int blockSize = Configurations.getBlockSize(); - - // Compute transpose concurrently. - ThreadManager.concurrentBlockedOperation(numCols, blockSize, (startIdx, endIdx) -> { - for(int i=startIdx; i * WARNING: These methods do not perform any sanity checks. */ -public final class RealDenseMatrixMultiplication { +public final class RealDenseMatMult { - private RealDenseMatrixMultiplication() { + private RealDenseMatMult() { // Hide default constructor. } diff --git a/src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultTranspose.java b/src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseMatMultTranspose.java similarity index 98% rename from src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultTranspose.java rename to src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseMatMultTranspose.java index f721c1a62..dadea0d15 100644 --- a/src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseMatrixMultTranspose.java +++ b/src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseMatMultTranspose.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2023-2024. Jacob Watters + * Copyright (c) 2023-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -33,9 +33,9 @@ * two real dense matrices.
    * WARNING: These methods do not perform any sanity checks. */ -public final class RealDenseMatrixMultTranspose { +public final class RealDenseMatMultTranspose { - private RealDenseMatrixMultTranspose() { + private RealDenseMatMultTranspose() { // Hide default constructor. } diff --git a/src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseOps.java b/src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseOps.java index 681f39dbe..81f98b596 100644 --- a/src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseOps.java +++ b/src/main/java/org/flag4j/linalg/ops/dense/real/RealDenseOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2022-2024. Jacob Watters + * Copyright (c) 2022-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -36,7 +36,6 @@ public final class RealDenseOps { private RealDenseOps() { // Hide constructor for utility class. - } @@ -235,7 +234,7 @@ public static double[] add(double[] src, double b, double[] dest) { * @param src Entries of the tensor to compute the trace of. * @param axis1 First axis for 2D sub-array. * @param axis2 Second axis for 2D sub-array. - * @param destShape The resulting shape of the tensor trace. Use {@link #getTrShape(Shape, int, int)} to compute this. + * @param destShape The resulting shape of the tensor trace. * @param dest Array to store the result of the generalized tensor trace of. Must satisfy * {@code dest.length == destShape.totalEntriesIntValueExact()}. * @@ -286,4 +285,62 @@ public static void tensorTr(Shape shape, double[] src, dest[i] = sum; } } + + + /** + *

    Swaps two rows, over a specified range of columns, within a matrix. Specifically, all elements in the matrix within rows + * {@code rowIdx1} + * and {@code rowIdx2} and between columns {@code start} (inclusive) and {@code stop} (exclusive). This operation is done in place. + *

    No bounds checking is done within this method to ensure that the indices provided are valid. + * + * @param shape Shape of the matrix. + * @param data Data of the matrix (modified). + * @param rowIdx1 Index of the first row to swap. + * @param rowIdx2 Index of the second row to swap. + * @param start Index of the column specifying the start of the range for the row swap (inclusive). + * @param stop Index of the column specifying the end of the range for the row swap (exclusive). + */ + public static void swapRowsUnsafe(Shape shape, double[] data, int rowIdx1, int rowIdx2, int start, int stop) { + if(rowIdx1 == rowIdx2) return; + + final int cols = shape.get(1); + final int rowOffset1 = rowIdx1*cols; + final int rowOffset2 = rowIdx2*cols; + double temp; + + for(int j=start; jSwaps two columns, over a specified range of rows, within a matrix. Specifically, all elements in the matrix within columns + * {@code colIdx1} and {@code colIdx2} and between rows {@code start} (inclusive) and {@code stop} (exclusive). This operation + * is done in place. + *

    No bounds checking is done within this method to ensure that the indices provided are valid. + * + * @param shape Shape of the matrix. + * @param data Data of the matrix (modified). + * @param colIdx1 Index of the first column to swap. + * @param colIdx2 Index of the second column to swap. + * @param start Index of the row specifying the start of the range for the row swap (inclusive). + * @param stop Index of the row specifying the end of the range for the row swap (exclusive). + */ + public static void swapColsUnsafe(Shape shape, double[] data, int colIdx1, int colIdx2, int start, int stop) { + if(colIdx1 == colIdx2) return; + + final int cols = shape.get(1); + int rowOffset = start*cols; + double temp; + + for(int i=start; i> void sub(Shape shape1, T[] src1, for(int i=0, size=src1.length; i> boolean isHermitian(Shape shape, T[] src) { + if(shape.get(0)!=shape.get(1)) return false; + + int numCols = shape.get(1); + + for(int i=0, rows=shape.get(0); iChecks if a matrix is the identity matrix approximately. + * + *

    Specifically, if the diagonal data are no farther than + * {@code 1.001E-5} in absolute value from {@code 1.0} and the non-diagonal data are no larger than {@code 1e-08} in absolute + * value. + * + *

    These thresholds correspond to the thresholds from the + * {@link org.flag4j.linalg.ops.common.ring_ops.RingProperties#allClose(Ring[], Ring[])} method. + * + * + * @param src Matrix of interest to check if it is the identity matrix. + * @return True if the {@code src} matrix is close the identity matrix or if the matrix has zero data. + */ + public static > boolean isCloseToIdentity(Shape shape, T[] src) { + int numRows = shape.get(0); + int numCols = shape.get(1); + + if(src == null || numRows!=numCols) return false; + if(src.length == 0) return true; + + // Tolerances corresponds to the allClose(...) methods. + double diagTol = 1.001E-5; + double nonDiagTol = 1e-08; + final T ONE = src[0].getOne(); + int rows = numRows; + int cols = numCols; + int pos = 0; + + for(int i=0; i diagTol) + || (i!=j && src[pos].mag() > nonDiagTol)) { + return false; + } + + pos++; + } + } + + return true; // If we make it here, the matrix is "close" to the identity matrix. + } } diff --git a/src/main/java/org/flag4j/linalg/ops/dense/semiring_ops/DenseSemiringOps.java b/src/main/java/org/flag4j/linalg/ops/dense/semiring_ops/DenseSemiringOps.java index 289a47035..186773b68 100644 --- a/src/main/java/org/flag4j/linalg/ops/dense/semiring_ops/DenseSemiringOps.java +++ b/src/main/java/org/flag4j/linalg/ops/dense/semiring_ops/DenseSemiringOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,7 +37,6 @@ public final class DenseSemiringOps { private DenseSemiringOps() { // Hide constructor for utility class. - } diff --git a/src/main/java/org/flag4j/linalg/ops/dense_sparse/coo/field_ops/DenseCooFieldEquals.java b/src/main/java/org/flag4j/linalg/ops/dense_sparse/coo/field_ops/DenseCooFieldEquals.java index 2cbf4917d..de10129d1 100644 --- a/src/main/java/org/flag4j/linalg/ops/dense_sparse/coo/field_ops/DenseCooFieldEquals.java +++ b/src/main/java/org/flag4j/linalg/ops/dense_sparse/coo/field_ops/DenseCooFieldEquals.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,7 +37,6 @@ public final class DenseCooFieldEquals { private DenseCooFieldEquals() { // Hide constructor for utility class. for utility class. - } diff --git a/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/real/RealCsrDenseMatrixMultiplication.java b/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/real/RealCsrDenseMatrixMultiplication.java index cb829739c..bad0c8a7b 100644 --- a/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/real/RealCsrDenseMatrixMultiplication.java +++ b/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/real/RealCsrDenseMatrixMultiplication.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,6 @@ public final class RealCsrDenseMatrixMultiplication { private RealCsrDenseMatrixMultiplication() { // Hide default constructor for utility method. - } diff --git a/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/real_field_ops/RealFieldDenseCsrMatMult.java b/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/real_field_ops/RealFieldDenseCsrMatMult.java index fb9b99e7d..98599262b 100644 --- a/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/real_field_ops/RealFieldDenseCsrMatMult.java +++ b/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/real_field_ops/RealFieldDenseCsrMatMult.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -46,7 +46,6 @@ public final class RealFieldDenseCsrMatMult { private RealFieldDenseCsrMatMult() { // Hide default constructor for utility method. - } diff --git a/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.java b/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/semiring_ops/DenseCsrSemiringMatMult.java similarity index 80% rename from src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.java rename to src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/semiring_ops/DenseCsrSemiringMatMult.java index 6830112f2..1f710ca2e 100644 --- a/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/field_ops/DenseCsrFieldMatMult.java +++ b/src/main/java/org/flag4j/linalg/ops/dense_sparse/csr/semiring_ops/DenseCsrSemiringMatMult.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -22,28 +22,27 @@ * SOFTWARE. */ -package org.flag4j.linalg.ops.dense_sparse.csr.field_ops; +package org.flag4j.linalg.ops.dense_sparse.csr.semiring_ops; import org.flag4j.algebraic_structures.Complex128; -import org.flag4j.algebraic_structures.Field; +import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.backend.field_arrays.AbstractCsrFieldMatrix; -import org.flag4j.arrays.backend.field_arrays.AbstractDenseFieldMatrix; -import org.flag4j.arrays.backend.field_arrays.AbstractDenseFieldVector; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCsrSemiringMatrix; +import org.flag4j.arrays.backend.semiring_arrays.AbstractDenseSemiringMatrix; +import org.flag4j.arrays.backend.semiring_arrays.AbstractDenseSemiringVector; import org.flag4j.util.ValidateParameters; import java.util.Arrays; /** - * This class contains low-level implementations of complex-complex sparse-sparse matrix multiplication where the sparse matrices - * are in CSR format. + * This class contains low-level implementations of {@link Semiring semiring} sparse-dense matrix multiplication where the sparse + * matrices are in CSR format. */ -public final class DenseCsrFieldMatMult { +public final class DenseCsrSemiringMatMult { - private DenseCsrFieldMatMult() { + private DenseCsrSemiringMatMult() { // Hide default constructor for utility method. - } @@ -57,12 +56,12 @@ private DenseCsrFieldMatMult() { * @throws IllegalArgumentException If {@code src1} does not have the same number of columns as {@code src2} has * rows. */ - public static > AbstractDenseFieldMatrix standard( - AbstractCsrFieldMatrix src1, AbstractDenseFieldMatrix src2) { + public static > AbstractDenseSemiringMatrix standard( + AbstractCsrSemiringMatrix src1, AbstractDenseSemiringMatrix src2) { // Ensure matrices have shapes conducive to matrix multiplication. ValidateParameters.ensureMatMultShapes(src1.shape, src2.shape); - - T[] destEntries = (T[]) new Field[src1.numRows*src2.numCols]; + + T[] destEntries = src2.makeEmptyDataArray(src1.numRows*src2.numCols); Arrays.fill(destEntries, src2.getZeroElement()); int rows1 = src1.numRows; int cols2 = src2.numCols; @@ -98,12 +97,12 @@ private DenseCsrFieldMatMult() { * @throws IllegalArgumentException If {@code src1} does not have the same number of columns as {@code src2} has * rows. */ - public static > AbstractDenseFieldMatrix standard( - AbstractDenseFieldMatrix src1, AbstractCsrFieldMatrix src2) { + public static > AbstractDenseSemiringMatrix standard( + AbstractDenseSemiringMatrix src1, AbstractCsrSemiringMatrix src2) { // Ensure matrices have shapes conducive to matrix multiplication. ValidateParameters.ensureMatMultShapes(src1.shape, src2.shape); - T[] destEntries = (T[]) new Field[src1.numRows * src2.numCols]; + T[] destEntries = src1.makeEmptyDataArray(src1.numRows*src2.numCols); Arrays.fill(destEntries, src1.getZeroElement()); int rows1 = src1.numRows; int cols1 = src1.numCols; @@ -138,12 +137,12 @@ private DenseCsrFieldMatMult() { * @throws IllegalArgumentException If the number of columns in {@code src1} does not equal the length of * {@code src2}. */ - public static > AbstractDenseFieldVector standardVector( - AbstractCsrFieldMatrix src1, AbstractDenseFieldVector src2) { + public static > AbstractDenseSemiringVector standardVector( + AbstractCsrSemiringMatrix src1, AbstractDenseSemiringVector src2) { // Ensure the matrix and vector have shapes conducive to multiplication. ValidateParameters.ensureEquals(src1.numCols, src2.size); - T[] destEntries = (T[]) new Field[src1.numRows]; + T[] destEntries = src2.makeEmptyDataArray(src1.numRows); Arrays.fill(destEntries, Complex128.ZERO); int rows1 = src1.numRows; diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/SparseUtils.java b/src/main/java/org/flag4j/linalg/ops/sparse/SparseUtils.java index cf004e6ef..103c0aa41 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/SparseUtils.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/SparseUtils.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,13 +25,14 @@ package org.flag4j.linalg.ops.sparse; import org.flag4j.algebraic_structures.Complex128; -import org.flag4j.algebraic_structures.Field; import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.*; -import org.flag4j.arrays.backend.field_arrays.AbstractCsrFieldMatrix; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCsrSemiringMatrix; +import org.flag4j.arrays.sparse.CooCMatrix; import org.flag4j.arrays.sparse.CooMatrix; import org.flag4j.arrays.sparse.CsrFieldMatrix; import org.flag4j.arrays.sparse.CsrMatrix; +import org.flag4j.linalg.ops.sparse.coo.semiring_ops.CooSemiringEquals; import org.flag4j.util.ErrorMessages; import org.flag4j.util.ValidateParameters; @@ -301,52 +302,15 @@ public static boolean CSREquals(CsrMatrix src1, CsrMatrix src2) { * @param src2 Second CSR matrix in the equality comparison. * @return True if all non-zero values stored in the two matrices are equal and occur at the same indices. */ - public static > boolean CSREquals( - AbstractCsrFieldMatrix src1, - AbstractCsrFieldMatrix src2) { + public static > boolean CSREquals( + AbstractCsrSemiringMatrix src1, + AbstractCsrSemiringMatrix src2) { + if(src1 == src2) return true; if(!src1.shape.equals(src2.shape)) return false; - final Complex128 ZERO = Complex128.ZERO; - // Compare row by row - for (int i=0; i void copyRanges( */ public static SparseMatrixData coalesce( BinaryOperator aggregator, Shape shape, T[] data, int[] rowIndices, int[] colIndices) { - HashMap, T> coalescedValues = new HashMap<>(); + HashMap, T> coalescedValues = new LinkedHashMap<>(); List destRowIndices = new ArrayList<>(data.length); List destColIndices = new ArrayList<>(data.length); - for(int i = 0; i< data.length; i++) { + for(int i = 0; i idx = new Pair<>(rowIndices[i], colIndices[i]); T value = data[i]; @@ -785,4 +749,88 @@ public static > SparseMatrixData dropZerosCsr( return new SparseMatrixData<>(shape, destData, destRowIndices, destColIndices); } + + + /** + * Validates that the specified slice is a valid slice of a matrix with the specified {@code shape}. + * @param shape Shape of the matrix. + * @param rowStart Starting row index of the slice (inclusive). + * @param rowEnd Ending row index of the slice (exclusive). + * @param colStart Starting column index of the slice (inclusive). + * @param colEnd Ending column index of the slice (exclusive). + * @throws IllegalArgumentException If any of the following are {@code true}: + *

      + *
    • {@code rowStart >= rowEnd}
    • + *
    • {@code colStart >= colEnd}
    • + *
    • {@code rowStart < 0 || rowEnd > shape.get(0)}
    • + *
    • {@code colStart < 0 || colEnd > shape.get(1)}
    • + *
    + */ + public static void validateSlice(Shape shape, int rowStart, int rowEnd, int colStart, int colEnd) { + if(rowStart >= rowEnd) { + throw new IllegalArgumentException("rowStart must be greater than rowEnd but got: rowStart=" + + rowStart + " and rowEnd=" + rowEnd + "."); + } + if(colStart >= colEnd) { + throw new IllegalArgumentException("colStart must be greater than colEnd but got: colStart=" + + colStart + " and colEnd=" + colEnd + "."); + } + if(rowStart < 0 || rowEnd > shape.get(0)) { + throw new IllegalArgumentException("Invalid range specified for row indices: [" + rowStart + ", " + rowEnd + ").\n" + + "Out of bounds for matrix with shape: " + shape); + } + if(colStart < 0 || colEnd > shape.get(1)) { + throw new IllegalArgumentException("Invalid range specified for column indices: [" + colStart + ", " + colEnd + ").\n" + + "Out of bounds for matrix with shape: " + shape); + } + } + + + /** + * Validates that the provided arguments specify a valid CSR matrix. + * @param shape Shape of the CSR matrix. + * @param nnz The number of non-zero entries of the CSR matrix. + * @param rowPointers The non-zero row pointers of the CSR matrix. + * @param colIndices The non-zero column indices of the CSR matrix. + * + * @throws IllegalArgumentException If any of the following are {@code true}: + *
      + *
    • {@code shape.getRank() != 2}
    • + *
    • {@code rowPointers.length != shape.get(0) + 1}
    • + *
    • {@code nnz != colIndices.length}
    • + *
    + */ + public static void validateCsrMatrix(Shape shape, int nnz, int[] rowPointers, int[] colIndices) { + if (shape.getRank() != 2) { + throw new IllegalArgumentException("Invalid CSR definition: shape must be of rank 2 but got: " + shape); + } + if (rowPointers.length != shape.get(0) + 1) { + throw new IllegalArgumentException("Invalid CSR definition: the number of row pointers must be " + + "equal to the number of rows plus 1 but got row pointer length " + + rowPointers.length + " for shape " + shape + "."); + } + + if (nnz != colIndices.length) { + throw new IllegalArgumentException("Illegal CSR definition: the number of column indices must be equal" + + "to the number of non-zero entries but got " + colIndices.length + " for nnz=" + nnz + "."); + } + } + + + // TODO: TEMP + public static void main(String[] args) { + Complex128[] bNnz = new Complex128[]{new Complex128(234.5, -0.2), Complex128.ZERO, Complex128.ZERO, Complex128.ZERO, + new Complex128(345.1, 2.5), + new Complex128(9.4, -1), + Complex128.ZERO, new Complex128(235.1, 94.2), new Complex128(3.12, 4), + new Complex128(0, 1), new Complex128(2,9733)}; + int[][] bIndices = new int[][]{ + {0, 0, 0, 0, 0, 1, 5, 5, 12, 67, 67}, + {0, 1, 2, 3, 5, 14, 45, 5002, 142, 15, 60001}}; + Shape bShape = new Shape(900, 450000); + + CooCMatrix b = new CooCMatrix(bShape, bNnz, bIndices[0], bIndices[1]); + + b.coalesce(); + } } diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/CooConversions.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/CooConversions.java index 589ab24dc..10349e6d6 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/CooConversions.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/CooConversions.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -69,7 +69,7 @@ public static void toCsr(Shape shape, T[] entries, int[] rowIndices, int[] c T[] destEntries, int[] destRowPointers, int[] destColIndices) { final int numRows = shape.get(0); - // Copy the non-zero data anc column indices. Count number of data per row. + // Copy the non-zero data and column indices. Count number of data per row. for(int i=0, size=entries.length; i indices = new ArrayList<>(); - for(int i=0; i list : keys) { List subList = list.subList(start, stop); - for (int i = 0; i < key.size(); i++) { + for (int i = 0; i < key.size(); i++) Collections.swap(subList, swapFrom.get(i), swapTo.get(i)); - } } // Find ranges which have the same value in the sorted key list. @@ -269,9 +266,8 @@ private void sparseSortHelper(int keyIdx, int start, int stop) { */ public void unwrap(Object[] values, int[][] indices) { // Copy over data values. - for(int i=0; i boolean isSymmetric(Shape shape, T[] data, int[] rowIndices, int[] colIndices, T zeroValue) { + if(shape.get(0) != shape.get(1)) return false; // Early return for non-square matrix. + + Map, T> dataMap = new HashMap, T>(); + + for(int i = 0, size=data.length; i < size; i++) { + if(rowIndices[i] == colIndices[i] || Objects.equals(data[i], zeroValue)) + continue; // This value is zero or on the diagonal. No need to consider. + + var p1 = new Pair<>(rowIndices[i], colIndices[i]); + var p2 = new Pair<>(colIndices[i], rowIndices[i]); + + if(!dataMap.containsKey(p2)) { + dataMap.put(p1, data[i]); + } else if(!dataMap.get(p2).equals(data[i])){ + return false; // Not symmetric. + } else { + dataMap.remove(p2); + } + } + + // If there are any remaining values a value with the transposed indices was not found in the matrix. + return dataMap.isEmpty(); + } +} diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldElementSearch.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldElementSearch.java deleted file mode 100644 index ebb3ba580..000000000 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldElementSearch.java +++ /dev/null @@ -1,110 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.sparse.coo.field_ops; - -import org.flag4j.arrays.backend.field_arrays.AbstractCooFieldMatrix; - -import java.util.Arrays; - -/** - * Utility class for searching for specific elements within a sparse COO matrix. - */ -public final class CooFieldElementSearch { - - private CooFieldElementSearch() { - // Hide default constructor in utility class. - } - - - /** - * Preforms a binary search along the row and column indices of the non-zero values of a sparse matrix for the location - * of an entry with the specified target indices. - * - * @param src Source matrix to search within. - * @param rowKey Target row index. - * @param colKey Target col index. - * @return The location of the non-zero element (within the non-zero values array of {@code src}) with the specified - * row and column indices. If this value does not exist, then (-(insertion point) - 1) - * will be returned. The insertion point is defined as the point at which the - * value, with the row and column key, would be inserted into the array: the index of the first - * element greater than the key, or {@code src.data.length} if all - * elements in the array are less than the specified key. Note - * that this guarantees that the return value will be >= 0 if - * and only if the key is found. - */ - public static int matrixBinarySearch(AbstractCooFieldMatrix src, int rowKey, int colKey) { - int rowIdx = Arrays.binarySearch(src.rowIndices, rowKey); - if(rowIdx<0) return rowIdx; - - // Find range of same valued row indices. - int lowerBound = rowIdx; - for(int i=rowIdx; i>=0; i--) { - if(src.rowIndices[i] == rowKey) lowerBound = i; - else break; - } - - int upperBound = rowIdx + 1; - for(int i=upperBound; i src, int rowKey) { - int rowIdx = Arrays.binarySearch(src.rowIndices, rowKey); - - if(rowIdx < 0) return new int[]{rowIdx, rowIdx}; // Row not found. - - // Find first entry with the specified row key. - int lowerBound = rowIdx; - for(int i=rowIdx; i>=0; i--) { - if(src.rowIndices[i] == rowKey) lowerBound = i; - else break; - } - - int upperBound = rowIdx + 1; - for(int i=upperBound; iThis utility class contains methods for checking the equality, or approximately equal, of sparse COO tensors whose data are - * {@link Field field} elements. - */ -public final class CooFieldEquals { - - private CooFieldEquals() { - // Hide default constructor for utility class. - } - - - /** - * Checks if two real sparse tensors are real. Assumes the indices of each sparse tensor are sorted. Any explicitly stored - * zero's will be ignored. - * @param a First tensor in the equality check. - * @param b Second tensor in the equality check. - * @return True if the tensors are equal. False otherwise. - */ - public static > boolean cooTensorEquals( - AbstractCooFieldTensor a, - AbstractCooFieldTensor b) { - if (a == b) return true; - if (a == null || b == null) return false; - - a = a.coalesce().dropZeros(); - b = b.coalesce().dropZeros(); - return a.shape.equals(b.shape) - && Arrays.equals(a.data, b.data) - && Arrays.deepEquals(a.indices, b.indices); - } - - - /** - * Checks if two real sparse matrices are real. Assumes the indices of each sparse matrix are sorted. Any explicitly stored - * zero's will be ignored. - * @param a First matrix in the equality check. - * @param b Second matrix in the equality check. - * @return True if the matrices are equal. False otherwise. - */ - public static > boolean cooMatrixEquals( - AbstractCooFieldMatrix a, - AbstractCooFieldMatrix b) { - // Early return if possible. - if (a == b) return true; - if (a == null || b == null) return false; - - a = a.coalesce().dropZeros(); - b = b.coalesce().dropZeros(); - return a.shape.equals(b.shape) - && Arrays.equals(a.data, b.data) - && Arrays.equals(a.rowIndices, b.rowIndices) - && Arrays.equals(a.colIndices, b.colIndices); - } - - - /** - * Checks if two real sparse vectors are real. Assumes the indices of each sparse vector are sorted. Any explicitly stored - * zero's will be ignored. - * @param a First vector in the equality check. - * @param b Second vector in the equality check. - * @return True if the vectors are equal. False otherwise. - */ - public static > boolean cooVectorEquals( - AbstractCooFieldVector a, - AbstractCooFieldVector b) { - // Early returns if possible. - if(a == b) return true; - if(a==null || b==null || !a.shape.equals(b.shape)) return false; - - a = a.coalesce().dropZeros(); - b = b.coalesce().dropZeros(); - return a.shape.equals(b.shape) - && Arrays.equals(a.data, b.data) - && Arrays.equals(a.indices, b.indices); - } - - - /** - * Checks if two real sparse vectors are real. Assumes the indices of each sparse vector are sorted. Any explicitly stored - * zero's will be ignored. - * @param a First vector in the equality check. - * @param b Second vector in the equality check. - * @return True if the vectors are equal. False otherwise. - */ - public static > boolean cooVectorEquals( - AbstractCooSemiringVector a, - AbstractCooSemiringVector b) { - // Early returns if possible. - if(a == b) return true; - if(a==null || b==null || !a.shape.equals(b.shape)) return false; - - a = a.coalesce().dropZeros(); - b = b.coalesce().dropZeros(); - return a.shape.equals(b.shape) - && Arrays.equals(a.data, b.data) - && Arrays.equals(a.indices, b.indices); - } - - - /** - * Checks that all non-zero data are "close" according to {@link RealProperties#allClose(double[], double[])} and - * * all indices are the same. - * @param src1 First matrix in comparison. - * @param src2 Second matrix in comparison. - * @param relTol Relative tolerance. - * @param absTol Absolute tolerance. - * @return True if all data are "close". Otherwise, false. - */ - public static > boolean allClose(AbstractCooFieldMatrix src1, - AbstractCooFieldMatrix src2, - double relTol, double absTol) { - // TODO: We need to first check if values are "close" to zero and remove them. Then do the indices and entry check. - return src1.shape.equals(src2.shape) - && Arrays.equals(src1.rowIndices, src2.rowIndices) - && Arrays.equals(src1.colIndices, src2.colIndices) - && RingProperties.allClose(src1.data, src2.data, relTol, absTol); - } - - - /** - * Checks that all non-zero data are "close" according to - * - * {@link RingProperties#allClose(Ring[], Ring[], double, double)} and all indices - * are the same. - * @param src1 First tensor in comparison. - * @param src2 Second tensor in comparison. - * @param relTol Relative tolerance. - * @param absTol Absolute tolerance. - * @return True if all data are "close". Otherwise, false. - */ - public static > boolean allClose(AbstractCooFieldTensor src1, - AbstractCooFieldTensor src2, - double relTol, double absTol) { - // TODO: We need to first check if values are "close" to zero and remove them. Then do the indices and entry check. - return src1.shape.equals(src2.shape) - && Arrays.deepEquals(src1.indices, src2.indices) - && RingProperties.allClose(src1.data, src2.data, relTol, absTol); - } - - - /** - * Checks that all non-zero data are "close" according to - * {@link RingProperties#allClose(Ring[], Ring[])} )} and all indices are the same. - * @param src1 First vector in comparison. - * @param src2 Second vector in comparison. - * @param relTol Relative tolerance. - * @param absTol Absolute tolerance. - * @return True if all data are "close". Otherwise, false. - */ - public static > boolean allClose(AbstractCooFieldVector src1, - AbstractCooFieldVector src2, - double relTol, double absTol) { - // TODO: We need to first check if values are "close" to zero and remove them. Then do the indices and entry check. - return src1.shape.equals(src2.shape) - && Arrays.equals(src1.indices, src2.indices) - && RingProperties.allClose(src1.data, src2.data, relTol, absTol); - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixGetSet.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixGetSet.java deleted file mode 100644 index c71a1d41c..000000000 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixGetSet.java +++ /dev/null @@ -1,680 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.sparse.coo.field_ops; - - -import org.flag4j.algebraic_structures.Field; -import org.flag4j.arrays.Shape; -import org.flag4j.arrays.backend.MatrixMixin; -import org.flag4j.arrays.backend.field_arrays.AbstractCooFieldMatrix; -import org.flag4j.arrays.backend.field_arrays.AbstractCooFieldVector; -import org.flag4j.linalg.ops.sparse.SparseElementSearch; -import org.flag4j.util.ArrayUtils; -import org.flag4j.util.ValidateParameters; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -/** - * This class contains methods for getting/setting elements and slices from/to a sparse - * {@link Field} matrix. - */ -public final class CooFieldMatrixGetSet { - - private CooFieldMatrixGetSet() { - // Hide default constructor in utility class. - - } - - - /** - * Gets the specified element from a sparse matrix. - * @param src Source matrix to get value from. - * @param row Row index of the value to get from the sparse matrix. - * @param col Column index of the value to get from the sparse matrix. - * @return The value in the sparse matrix at the specified indices. - */ - public static > V matrixGet(AbstractCooFieldMatrix src, int row, int col) { - V zero = src.data.length > 0 ? src.data[0].getZero() : null; - int idx = SparseElementSearch.matrixBinarySearch(src.rowIndices, src.colIndices, row, col); - - return idx<0 ? zero : src.data[idx]; - } - - - /** - * Sets the specified element from a sparse matrix. - * @param src Sparse matrix to set value in. - * @param row Row index of the value to set in the sparse matrix. - * @param col Column index of the value to set in the sparse matrix. - * @param value Value to set. - * @return The - */ - public static > AbstractCooFieldMatrix - matrixSet(AbstractCooFieldMatrix src, int row, int col, V value) { - // Find position of row index within the row indices if it exits. - int idx = SparseElementSearch.matrixBinarySearch(src.rowIndices, src.colIndices, row, col); - V[] destEntries; - int[] destRowIndices; - int[] destColIndices; - - if(idx < 0) { - idx = -idx - 1; - - // No non-zero element with these indices exists. Insert new value. - destEntries = src.makeEmptyDataArray(src.data.length + 1); - System.arraycopy(src.data, 0, destEntries, 0, idx); - destEntries[idx] = value; - System.arraycopy(src.data, idx, destEntries, -idx, src.data.length - idx); - - destRowIndices = new int[src.data.length + 1]; - System.arraycopy(src.rowIndices, 0, destRowIndices, 0, idx); - destRowIndices[idx] = row; - System.arraycopy(src.rowIndices, idx, destRowIndices, -idx, src.rowIndices.length - idx); - - destColIndices = new int[src.data.length + 1]; - System.arraycopy(src.colIndices, 0, destColIndices, 0, idx); - destColIndices[idx] = col; - System.arraycopy(src.colIndices, idx, destColIndices, idx+1, src.colIndices.length - idx); - } else { - // Value with these indices exists. Simply update value. - destEntries = Arrays.copyOf(src.data, src.data.length); - destEntries[idx] = value; - destRowIndices = src.rowIndices.clone(); - destColIndices = src.colIndices.clone(); - } - - return src.makeLikeTensor(src.shape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Sets a specified row of a complex sparse matrix to the values of a dense array. - * @param src Source matrix to set the row of. - * @param rowIdx Index of the row to set. - * @param row Dense array containing the data of the row to set. - * @return A copy of the {@code src} matrix with the specified row set to the dense {@code row} array. - */ - public static > AbstractCooFieldMatrix setRow( - AbstractCooFieldMatrix src, int rowIdx, V[] row) { - ValidateParameters.ensureIndicesInBounds(src.numRows, rowIdx); - ValidateParameters.ensureEquals(src.numCols, row.length); - - int[] startEnd = SparseElementSearch.matrixFindRowStartEnd(src.rowIndices, rowIdx); - int start = startEnd[0]; - int end = startEnd[1]; - - V[] destEntries; - int[] destRowIndices ; - int[] destColIndices; - - if(start<0) { - // No data with row index found. - destEntries = src.makeEmptyDataArray(src.data.length + row.length); - destRowIndices = new int[destEntries.length]; - destColIndices = new int[destEntries.length]; - - System.arraycopy(src.data, 0, destEntries, 0, -start-1); - System.arraycopy(row, 0, destEntries, -start-1, row.length); - System.arraycopy( - src.data, -start-1, - destEntries, -start-1+row.length, destEntries.length-(row.length - start - 1) - ); - - System.arraycopy(src.rowIndices, 0, destRowIndices, 0, -start-1); - Arrays.fill(destRowIndices, -start-1, -start-1+row.length, rowIdx); - System.arraycopy( - src.rowIndices, -start-1, - destRowIndices, -start-1+row.length, destRowIndices.length-(row.length - start - 1) - ); - - System.arraycopy(src.colIndices, 0, destColIndices, 0, -start-1); - System.arraycopy(ArrayUtils.intRange(0, src.numCols), 0, destColIndices, -start-1, row.length); - System.arraycopy( - src.colIndices, -start-1, - destColIndices, -start-1+row.length, destColIndices.length-(row.length - start - 1) - ); - - } else { - // Entries with row index found. - destEntries = src.makeEmptyDataArray(src.data.length + row.length - (end-start)); - destRowIndices = new int[destEntries.length]; - destColIndices = new int[destEntries.length]; - - System.arraycopy(src.data, 0, destEntries, 0, start); - System.arraycopy(row, 0, destEntries, start, row.length); - System.arraycopy( - src.data, end, - destEntries, start + row.length, destEntries.length-(start + row.length) - ); - - System.arraycopy(src.rowIndices, 0, destRowIndices, 0, start); - Arrays.fill(destRowIndices, start, start+row.length, rowIdx); - System.arraycopy( - src.rowIndices, end, - destRowIndices, start + row.length, destEntries.length-(start + row.length) - ); - - System.arraycopy(src.colIndices, 0, destColIndices, 0, start); - System.arraycopy(ArrayUtils.intRange(0, src.numCols), 0, destColIndices, start, row.length); - System.arraycopy( - src.colIndices, end, - destColIndices, start + row.length, destEntries.length-(start + row.length) - ); - } - - return src.makeLikeTensor(src.shape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Sets a specified row of a complex sparse matrix to the values of a sparse vector. - * @param src Source matrix to set the row of. - * @param rowIdx Index of the row to set. - * @param row Dense array containing the data of the row to set. - * @return A copy of the {@code src} matrix with the specified row set to the dense {@code row} array. - */ - public static > AbstractCooFieldMatrix setRow( - AbstractCooFieldMatrix src, int rowIdx, AbstractCooFieldVector row) { - ValidateParameters.ensureIndicesInBounds(src.numRows, rowIdx); - ValidateParameters.ensureEquals(src.numCols, row.size); - - int[] startEnd = SparseElementSearch.matrixFindRowStartEnd(src.rowIndices, rowIdx); - int start = startEnd[0]; - int end = startEnd[1]; - - V[] destEntries; - int[] destRowIndices ; - int[] destColIndices; - - if(start<0) { - // No data with row index found. - destEntries = src.makeEmptyDataArray(src.data.length + row.data.length); - destRowIndices = new int[destEntries.length]; - destColIndices = new int[destEntries.length]; - - System.arraycopy(src.data, 0, destEntries, 0, -start-1); - System.arraycopy(row.data, 0, destEntries, -start-1, row.data.length); - System.arraycopy( - src.data, -start-1, - destEntries, -start-1+row.data.length, destEntries.length-(row.data.length - start - 1) - ); - - System.arraycopy(src.rowIndices, 0, destRowIndices, 0, -start-1); - Arrays.fill(destRowIndices, -start-1, -start-1+row.data.length, rowIdx); - System.arraycopy( - src.rowIndices, -start-1, - destRowIndices, -start-1+row.data.length, destRowIndices.length-(row.data.length - start - 1) - ); - - System.arraycopy(src.colIndices, 0, destColIndices, 0, -start-1); - System.arraycopy(row.indices, 0, destColIndices, -start-1, row.data.length); - System.arraycopy( - src.colIndices, -start-1, - destColIndices, -start-1+row.data.length, destColIndices.length-(row.data.length - start - 1) - ); - - } else { - // Entries with row index found. - destEntries = src.makeEmptyDataArray(src.data.length + row.data.length - (end-start)); - destRowIndices = new int[destEntries.length]; - destColIndices = new int[destEntries.length]; - - System.arraycopy(src.data, 0, destEntries, 0, start); - System.arraycopy(row.data, 0, destEntries, start, row.data.length); - int length = destEntries.length - (start + row.data.length); - - System.arraycopy( - src.data, end, - destEntries, start + row.data.length, length - ); - - System.arraycopy(src.rowIndices, 0, destRowIndices, 0, start); - Arrays.fill(destRowIndices, start, start+row.data.length, rowIdx); - System.arraycopy( - src.rowIndices, end, - destRowIndices, start + row.data.length, length - ); - - System.arraycopy(src.colIndices, 0, destColIndices, 0, start); - System.arraycopy(row.indices, 0, destColIndices, start, row.data.length); - System.arraycopy( - src.colIndices, end, - destColIndices, start + row.data.length, length - ); - } - - return src.makeLikeTensor(src.shape, destEntries, destRowIndices, destColIndices); - } - - - /** - * Sets a column of a sparse matrix to the data of a dense array. - * @param src Source matrix to set column of. - * @param colIdx The index of the column to set within the {@code src} matrix. - * @param col The dense array containing the new column data for the {@code src} array. - * @return A copy of the {@code src} matrix with the specified column set to the dense array. - * @throws IllegalArgumentException If the {@code colIdx} is not within the range of the matrix. - * @throws IllegalArgumentException If the {@code col} array does not have the same length as the number of - * rows in {@code src} matrix. - */ - public static > AbstractCooFieldMatrix - setCol(AbstractCooFieldMatrix src, int colIdx, V[] col) { - ValidateParameters.ensureIndicesInBounds(src.numCols, colIdx); - ValidateParameters.ensureEquals(src.numRows, col.length); - - Integer[] colIndices = new Integer[col.length]; - Arrays.fill(colIndices, colIdx); - - // Initialize destination arrays with the new column and the appropriate indices. - List destEntries = Arrays.asList(col); - List destRowIndices = IntStream.of( - ArrayUtils.intRange(0, col.length) - ).boxed().collect(Collectors.toList()); - List destColIndices = new ArrayList<>(Arrays.asList(colIndices)); - - // Add all data in old matrix that are NOT in the specified column. - for(int i = 0; i dest = src.makeLikeTensor(src.shape, destEntries, destRowIndices, destColIndices); - dest.sortIndices(); // Ensure the indices are sorted properly. - - return dest; - } - - - /** - * Sets a column of a sparse matrix to the data of a sparse vector. - * @param src Source matrix to set column of. - * @param colIdx The index of the column to set within the {@code src} matrix. - * @param col The dense array containing the new column data for the {@code src} array. - * @return A copy of the {@code src} matrix with the specified column set to the dense array. - * @throws IllegalArgumentException If the {@code colIdx} is not within the range of the matrix. - * @throws IllegalArgumentException If the {@code col} array does not have the same length as the number of - * rows in {@code src} matrix. - */ - public static > AbstractCooFieldMatrix setCol( - AbstractCooFieldMatrix src, - int colIdx, - AbstractCooFieldVector col) { - ValidateParameters.ensureIndicesInBounds(src.numCols, colIdx); - ValidateParameters.ensureEquals(src.numRows, col.size); - - int[] colIndices = new int[col.data.length]; - Arrays.fill(colIndices, colIdx); - - // Initialize destination arrays with the new column and the appropriate indices. - List destEntries = new ArrayList(Arrays.asList(col.data)); - List destRowIndices = ArrayUtils.toArrayList(col.indices); - List destColIndices = ArrayUtils.toArrayList(colIndices); - - // Add all data in old matrix that are NOT in the specified column. - for(int i = 0; i dest = src.makeLikeTensor( - src.shape, - destEntries, - destRowIndices, - destColIndices - ); - dest.sortIndices(); - - return dest; - } - - - /** - * Copies a sparse matrix and sets a slice of the sparse matrix to the data of another sparse matrix. - * @param src Source sparse matrix to copy and set values of. - * @param values Values of the slice to be set. - * @param row Starting row index of slice. - * @param col Starting column index of slice. - * @return A copy of the {@code src} matrix with the specified slice set to the {@code values} matrix. - * @throws IllegalArgumentException If the {@code values} matrix does not fit in the {@code src} - * matrix given the row and - * column index. - */ - public static > AbstractCooFieldMatrix setSlice( - AbstractCooFieldMatrix src, - AbstractCooFieldMatrix values, - int row, int col) { - // Ensure the values matrix fits inside the src matrix. - setSliceParamCheck(src, values, row, col); - - // Initialize lists to new values for the specified slice. - List entries = new ArrayList<>(Arrays.asList(values.data)); - List rowIndices = ArrayUtils.toArrayList(ArrayUtils.shift(row, values.rowIndices)); - List colIndices = ArrayUtils.toArrayList(ArrayUtils.shift(col, values.colIndices)); - - int[] rowRange = ArrayUtils.intRange(row, values.numRows + row); - int[] colRange = ArrayUtils.intRange(col, values.numCols + col); - - copyValuesNotInSlice(src, entries, rowIndices, colIndices, rowRange, colRange); - - // Create matrix and ensure data are properly sorted. - AbstractCooFieldMatrix mat = src.makeLikeTensor(src.shape, entries, rowIndices, colIndices); - mat.sortIndices(); - - return mat; - } - - - /** - * Copies a sparse matrix and sets a slice of the sparse matrix to the data of a dense array. - * @param src Source sparse matrix to copy and set values of. - * @param values Dense values of the slice to be set. - * @param row Starting row index of slice. - * @param col Starting column index of slice. - * @return A copy of the {@code src} matrix with the specified slice set to the {@code values} array. - * @throws IllegalArgumentException If the {@code values} array does not fit in the {@code src} matrix - * given the row and column index. - */ - public static > AbstractCooFieldMatrix setSlice( - AbstractCooFieldMatrix src, V[][] values, int row, int col) { - // Ensure the values matrix fits inside the src matrix. - setSliceParamCheck(src, values.length, values[0].length, row, col); - - // Flatten values. - V[] flatValues = ArrayUtils.flatten(values); - int[] sliceRows = ArrayUtils.intRange(row, values.length + row, values[0].length); - int[] sliceCols = ArrayUtils.repeat(values.length, ArrayUtils.intRange(col, values[0].length + col)); - - return setSlice(src, flatValues, values.length, values[0].length, sliceRows, sliceCols, row, col); - } - - - /** - * Sets a slice of a sparse matrix to values given in a 1d dense array. - * @param src Source sparse matrix to copy non-slice from. - * @param values Dense value for slice. - * @param numRows Number of rows in the matrix represented by {@code values}. - * @param numCols Number of columns in the matrix represented by {@code values}. - * @param sliceRows Row indices for slice. - * @param sliceCols Column indices for slice. - * @param row Starting row index of slice. - * @param col Starting column index of slice. - * @return A copy of the {@code src} matrix with the specified slice set to the specified values. - */ - private static > AbstractCooFieldMatrix setSlice( - AbstractCooFieldMatrix src, V[] values, - int numRows, int numCols, int[] sliceRows, int[] sliceCols, int row, int col) { - // Copy vales and row/column indices (with appropriate shifting) to destination lists. - List entries = new ArrayList<>(Arrays.asList(values)); - List rowIndices = ArrayUtils.toArrayList(sliceRows); - List colIndices = ArrayUtils.toArrayList(sliceCols); - - int[] rowRange = ArrayUtils.intRange(row, numRows + row); - int[] colRange = ArrayUtils.intRange(col, numCols + col); - - copyValuesNotInSlice(src, entries, rowIndices, colIndices, rowRange, colRange); - - // Create matrix and ensure data are properly sorted. - AbstractCooFieldMatrix mat = src.makeLikeTensor(src.shape, entries, rowIndices, colIndices); - mat.sortIndices(); - - return mat; - } - - - /** - * Gets a specified row from this sparse matrix. - * @param src Source sparse matrix to extract row from. - * @param rowIdx Index of the row to extract from the {@code src} matrix. - * @return Returns the specified row from this sparse matrix. - */ - public static > AbstractCooFieldVector getRow(AbstractCooFieldMatrix src, int rowIdx) { - ValidateParameters.ensureIndicesInBounds(src.numRows, rowIdx); - - List entries = new ArrayList<>(); - List indices = new ArrayList<>(); - - for(int i = 0; i> AbstractCooFieldVector getRow( - AbstractCooFieldMatrix src, - int rowIdx, int start, int end) { - ValidateParameters.ensureIndicesInBounds(src.numRows, rowIdx); - ValidateParameters.ensureIndicesInBounds(src.numCols, start, end-1); - ValidateParameters.ensureLessEq(end-1, start); - - List entries = new ArrayList<>(); - List indices = new ArrayList<>(); - - for(int i = 0; i= start && src.colIndices[i] < end) { - entries.add(src.data[i]); - indices.add(src.colIndices[i]); - } - } - - return src.makeLikeVector(new Shape(end-start), - entries.toArray(src.makeEmptyDataArray(entries.size())), - ArrayUtils.fromIntegerList(indices)); - } - - - /** - * Gets a specified column from this sparse matrix. - * @param src Source sparse matrix to extract column from. - * @param colIdx Index of the column to extract from the {@code src} matrix. - * @return Returns the specified column from this sparse matrix. - */ - public static > AbstractCooFieldVector getCol(AbstractCooFieldMatrix src, int colIdx) { - ValidateParameters.ensureIndicesInBounds(src.numCols, colIdx); - - List entries = new ArrayList<>(); - List indices = new ArrayList<>(); - - for(int i = 0; i> AbstractCooFieldVector getCol( - AbstractCooFieldMatrix src, - int colIdx, int start, int end) { - ValidateParameters.ensureIndicesInBounds(src.numCols, colIdx); - ValidateParameters.ensureIndicesInBounds(src.numRows, start, end); - ValidateParameters.ensureLessEq(end, start); - - List entries = new ArrayList<>(); - List indices = new ArrayList<>(); - - for(int i = 0; i= start && src.rowIndices[i] < end) { - entries.add(src.data[i]); - indices.add(src.rowIndices[i]); - } - } - - return src.makeLikeVector(new Shape(end-start), - (T[]) entries.toArray(new Field[entries.size()]), - ArrayUtils.fromIntegerList(indices)); - } - - - /** - * Gets a specified rectangular slice of a sparse matrix. - * @param src Sparse matrix to extract slice from. - * @param rowStart Starting row index of the slice (inclusive). - * @param rowEnd Ending row index of the slice (exclusive). - * @param colStart Staring column index of a slice (inclusive). - * @param colEnd Ending column index of the slice (exclusive). - * @return The specified slice of the sparse matrix. - */ - public static > AbstractCooFieldMatrix getSlice( - AbstractCooFieldMatrix src, - int rowStart, int rowEnd, int colStart, int colEnd) { - ValidateParameters.ensureIndicesInBounds(src.numRows, rowStart, rowEnd-1); - ValidateParameters.ensureIndicesInBounds(src.numCols, colStart, colEnd-1); - - Shape shape = new Shape(rowEnd-rowStart, colEnd-colStart); - List entries = new ArrayList<>(); - List rowIndices = new ArrayList<>(); - List colIndices = new ArrayList<>(); - - int start = SparseElementSearch.matrixBinarySearch(src.rowIndices, src.colIndices, rowStart, colStart); - - if(start < 0) { - // If no item with the specified indices is found, then begin search at the insertion point. - start = -start - 1; - } - - for(int i = start; i= rowStart && row < rowEnd && col >= colStart && col < colEnd; - } - - - /** - * Copies values in sparse matrix which do not fall in the specified row and column ranges. - * @param src Source sparse matrix to copy from. - * @param entries Destination list to add copied values to. - * @param rowIndices Destination list to add copied row indices to. - * @param colIndices Destination list to add copied column indices to. - * @param rowRange List of row indices to NOT copy from. - * @param colRange List of column indices to NOT copy from. - */ - private static > void copyValuesNotInSlice( - AbstractCooFieldMatrix src, List entries, - List rowIndices, - List colIndices, int[] rowRange, int[] colRange) { - // Copy values not in slice. - for(int i = 0; i, U extends MatrixMixin> void setSliceParamCheck( - T src, U values, int row, int col) { - - ValidateParameters.ensureIndicesInBounds(src.numRows(), row); - ValidateParameters.ensureIndicesInBounds(src.numCols(), col); - ValidateParameters.ensureLessEq(src.numRows(), values.numRows() + row); - ValidateParameters.ensureLessEq(src.numCols(), values.numCols() + col); - } - - - /** - * Checks that parameters are valued for setting a slice of a matrix. - * @param src Matrix to set slice of. - * @param valueRows Number of rows in slice to set. - * @param valueCols Number of columns in slice to be set. - * @param row Starting row for slice. - * @param col Ending row for slice. - */ - private static > void setSliceParamCheck( - T src, int valueRows, int valueCols, int row, int col) { - ValidateParameters.ensureIndicesInBounds(src.numRows(), row); - ValidateParameters.ensureIndicesInBounds(src.numCols(), col); - ValidateParameters.ensureLessEq(src.numRows(), valueRows + row); - ValidateParameters.ensureLessEq(src.numCols(), valueCols + col); - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixManipulations.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixManipulations.java deleted file mode 100644 index 29687ab10..000000000 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixManipulations.java +++ /dev/null @@ -1,224 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.sparse.coo.field_ops; - -import org.flag4j.algebraic_structures.Field; -import org.flag4j.arrays.Shape; -import org.flag4j.arrays.backend.field_arrays.AbstractCooFieldMatrix; -import org.flag4j.linalg.ops.sparse.SparseElementSearch; -import org.flag4j.util.ArrayUtils; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** - * This class contains implementations for sparse {@link Field} matrix manipulations. - */ -public final class CooFieldMatrixManipulations { - - private CooFieldMatrixManipulations() { - // Hide default constructor for utility class. - } - - - /** - * Removes a specified row from a sparse matrix. - * @param src Source matrix to remove row of. - * @param rowIdx Row to remove from the {@code src} matrix. - * @return A sparse matrix which has one less row than the {@code src} matrix with the specified row removed. - */ - public static > AbstractCooFieldMatrix removeRow( - AbstractCooFieldMatrix src, int rowIdx) { - Shape shape = new Shape(src.numRows-1, src.numCols); - - // Find the start and end index within the data array which have the given row index. - int[] startEnd = SparseElementSearch.matrixFindRowStartEnd(src.rowIndices, rowIdx); - int size = src.data.length - (startEnd[1]-startEnd[0]); - - // Initialize arrays. - V[] entries = (V[]) new Field[size]; - int[] rowIndices = new int[size]; - int[] colIndices = new int[size]; - - copyRanges(src, entries, rowIndices, colIndices, startEnd); - - return src.makeLikeTensor(shape, (V[]) entries, rowIndices, colIndices); - } - - - /** - * Removes multiple rows from a sparse matrix. - * @param src The source sparse matrix to remove rows from. - * @param rowIdxs Indices of rows to remove from the {@code src} matrix. - * @return A copy of the {@code src} matrix with the specified rows removed. - */ - public static > AbstractCooFieldMatrix - removeRows(AbstractCooFieldMatrix src, int... rowIdxs) { - Shape shape = new Shape(src.numRows-rowIdxs.length, src.numCols); - List entries = new ArrayList<>(src.data.length); - List rowIndices = new ArrayList<>(src.data.length); - List colIndices = new ArrayList<>(src.data.length); - - for(int i = 0; i> AbstractCooFieldMatrix - removeCol(AbstractCooFieldMatrix src, int colIdx) { - Shape shape = new Shape(src.numRows, src.numCols-1); - List entries = new ArrayList<>(src.data.length); - List rowIndices = new ArrayList<>(src.data.length); - List colIndices = new ArrayList<>(src.data.length); - - for(int i = 0; i> AbstractCooFieldMatrix - removeCols(AbstractCooFieldMatrix src, int... colIdxs) { - Shape shape = new Shape(src.numRows, src.numCols-1); - List entries = new ArrayList<>(src.data.length); - List rowIndices = new ArrayList<>(src.data.length); - List colIndices = new ArrayList<>(src.data.length); - - for(int i = 0; i> void - copyRanges(AbstractCooFieldMatrix src, V[] entries, int[] rowIndices, int[] colIndices, int[] startEnd) { - - if(startEnd[0] > 0) { - System.arraycopy(src.data, 0, entries, 0, startEnd[0]); - System.arraycopy(src.data, startEnd[1], entries, startEnd[0], entries.length - startEnd[0]); - - System.arraycopy(src.rowIndices, 0, rowIndices, 0, startEnd[0]); - System.arraycopy(src.rowIndices, startEnd[1], rowIndices, startEnd[0], entries.length - startEnd[0]); - - System.arraycopy(src.colIndices, 0, colIndices, 0, startEnd[0]); - System.arraycopy(src.colIndices, startEnd[1], colIndices, startEnd[0], entries.length - startEnd[0]); - } else { - System.arraycopy(src.data, 0, entries, 0, entries.length); - System.arraycopy(src.rowIndices, 0, rowIndices, 0, rowIndices.length); - System.arraycopy(src.colIndices, 0, colIndices, 0, colIndices.length); - } - } - - - /** - * Swaps two rows, in place, in a sparse matrix. - * @param src The source sparse matrix to swap rows within. - * @param rowIdx1 Index of the first row in the swap. - * @param rowIdx2 Index of the second row in the swap. - * @return A reference to the {@code src} sparse matrix. - */ - public static > AbstractCooFieldMatrix - swapRows(AbstractCooFieldMatrix src, int rowIdx1, int rowIdx2) { - for(int i = 0; i> AbstractCooFieldMatrix - swapCols(AbstractCooFieldMatrix src, int colIdx1, int colIdx2) { - for(int i = 0; i> AbstractCooFieldMatrix - add(AbstractCooFieldMatrix src1, AbstractCooFieldMatrix src2) { - ValidateParameters.ensureEqualShape(src1.shape, src2.shape); - - int initCapacity = Math.max(src1.data.length, src2.data.length); - - List sum = new ArrayList<>(initCapacity); - List rowIndices = new ArrayList<>(initCapacity); - List colIndices = new ArrayList<>(initCapacity); - - int src1Counter = 0; - int src2Counter = 0; - - // Flags which indicate if a value should be added from the corresponding matrix - boolean add1; - boolean add2; - - while(src1Counter < src1.data.length || src2Counter < src2.data.length) { - - if(src1Counter >= src1.data.length || src2Counter >= src2.data.length) { - add1 = src2Counter >= src2.data.length; - add2 = !add1; - } else if(src1.rowIndices[src1Counter] == src2.rowIndices[src2Counter] - && src1.colIndices[src1Counter] == src2.colIndices[src2Counter]) { - // Found matching indices. - add1 = true; - add2 = true; - } else if(src1.rowIndices[src1Counter] == src2.rowIndices[src2Counter]) { - // Matching row indices. - add1 = src1.colIndices[src1Counter] < src2.colIndices[src2Counter]; - add2 = !add1; - } else { - add1 = src1.rowIndices[src1Counter] < src2.rowIndices[src2Counter]; - add2 = !add1; - } - - if(add1 && add2) { - sum.add(src1.data[src1Counter].add(src2.data[src2Counter])); - rowIndices.add(src1.rowIndices[src1Counter]); - colIndices.add(src1.colIndices[src1Counter]); - src1Counter++; - src2Counter++; - } else if(add1) { - sum.add(src1.data[src1Counter]); - rowIndices.add(src1.rowIndices[src1Counter]); - colIndices.add(src1.colIndices[src1Counter]); - src1Counter++; - } else { - sum.add(src2.data[src2Counter]); - rowIndices.add(src2.rowIndices[src2Counter]); - colIndices.add(src2.colIndices[src2Counter]); - src2Counter++; - } - } - - return src1.makeLikeTensor(src1.shape, sum, rowIndices, colIndices); - } - - - /** - * Adds a double all data (including zero values) of a real sparse matrix. - * @param src Sparse matrix to add double value to. - * @param a Double value to add to the sparse matrix. - * @return The result of the matrix addition. - * @throws ArithmeticException If the {@code src} sparse matrix is too large to be converted to a dense matrix. - * That is, there are more than {@link Integer#MAX_VALUE} data in the matrix (including zero data). - */ - public static > FieldMatrix - add(AbstractCooFieldMatrix src, double a) { - V[] sum = (V[]) new Field[src.totalEntries().intValueExact()]; - Arrays.fill(sum, a); - - int row; - int col; - - for(int i = 0; i(src.shape, sum); - } - - - /** - * Computes the subtraction between two real sparse matrices. This method assumes that the indices of the two matrices are sorted - * lexicographically. - * @param src1 First matrix in the subtraction. - * @param src2 Second matrix in the subtraction. - * @return The difference of the two matrices {@code src1} and {@code src2}. - * @throws IllegalArgumentException If the two matrices do not have the same shape. - */ - public static > AbstractCooFieldMatrix - sub(AbstractCooFieldMatrix src1, AbstractCooFieldMatrix src2) { - ValidateParameters.ensureEqualShape(src1.shape, src2.shape); - - int initCapacity = Math.max(src1.data.length, src2.data.length); - - List sum = new ArrayList<>(initCapacity); - List rowIndices = new ArrayList<>(initCapacity); - List colIndices = new ArrayList<>(initCapacity); - - int src1Counter = 0; - int src2Counter = 0; - - // Flags which indicate if a value should be added from the corresponding matrix - boolean add1; - boolean add2; - - while(src1Counter < src1.data.length || src2Counter < src2.data.length) { - - if(src1Counter >= src1.data.length || src2Counter >= src2.data.length) { - add1 = src2Counter >= src2.data.length; - add2 = !add1; - } else if(src1.rowIndices[src1Counter] == src2.rowIndices[src2Counter] - && src1.colIndices[src1Counter] == src2.colIndices[src2Counter]) { - // Found matching indices. - add1 = true; - add2 = true; - } else if(src1.rowIndices[src1Counter] == src2.rowIndices[src2Counter]) { - // Matching row indices. - add1 = src1.colIndices[src1Counter] < src2.colIndices[src2Counter]; - add2 = !add1; - } else { - add1 = src1.rowIndices[src1Counter] < src2.rowIndices[src2Counter]; - add2 = !add1; - } - - if(add1 && add2) { - sum.add(src1.data[src1Counter].sub(src2.data[src2Counter])); - rowIndices.add(src1.rowIndices[src1Counter]); - colIndices.add(src1.colIndices[src1Counter]); - src1Counter++; - src2Counter++; - } else if(add1) { - sum.add(src1.data[src1Counter]); - rowIndices.add(src1.rowIndices[src1Counter]); - colIndices.add(src1.colIndices[src1Counter]); - src1Counter++; - } else { - sum.add(src2.data[src2Counter].addInv()); - rowIndices.add(src2.rowIndices[src2Counter]); - colIndices.add(src2.colIndices[src2Counter]); - src2Counter++; - } - } - - return src1.makeLikeTensor(src1.shape, sum, rowIndices, colIndices); - } - - - /** - * Subtracts a double from all data (including zero values) of a real sparse matrix. - * @param src Sparse matrix to subtract double value from. - * @param a Double value to subtract from the sparse matrix. - * @return The result of the matrix subtraction. - * @throws ArithmeticException If the {@code src} sparse matrix is too large to be converted to a dense matrix. - * That is, there are more than {@link Integer#MAX_VALUE} data in the matrix (including zero data). - */ - public static > FieldMatrix - sub(AbstractCooFieldMatrix src, double a) { - V[] sum = (V[]) new Field[src.totalEntries().intValueExact()]; - Arrays.fill(sum, -a); - - int row; - int col; - - for(int i = 0; i(src.shape, sum); - } - - - - /** - * Multiplies two sparse matrices element-wise. This method assumes that the indices of the two matrices are sorted - * lexicographically. - * @param src1 First matrix in the element-wise multiplication. - * @param src2 Second matrix in the element-wise multiplication. - * @return The element-wise product of the two matrices {@code src1} and {@code src2}. - * @throws IllegalArgumentException If the two matrices do not have the same shape. - */ - public static > AbstractCooFieldMatrix - elemMult(AbstractCooFieldMatrix src1, AbstractCooFieldMatrix src2) { - ValidateParameters.ensureEqualShape(src1.shape, src2.shape); - - int initCapacity = Math.max(src1.data.length, src2.data.length); - - List product = new ArrayList<>(initCapacity); - List rowIndices = new ArrayList<>(initCapacity); - List colIndices = new ArrayList<>(initCapacity); - - int src1Counter = 0; - int src2Counter = 0; - - while(src1Counter < src1.data.length && src2Counter < src2.data.length) { - if(src1.rowIndices[src1Counter] == src2.rowIndices[src2Counter] - && src1.colIndices[src1Counter] == src2.colIndices[src2Counter]) { - product.add(src1.data[src1Counter].mult(src2.data[src2Counter])); - rowIndices.add(src1.rowIndices[src1Counter]); - colIndices.add(src1.colIndices[src1Counter]); - src1Counter++; - src2Counter++; - } else if(src1.rowIndices[src1Counter] == src2.rowIndices[src2Counter]) { - // Matching row indices. - - if(src1.colIndices[src1Counter] < src2.colIndices[src2Counter]) { - src1Counter++; - } else { - src2Counter++; - } - } else { - if(src1.rowIndices[src1Counter] < src2.rowIndices[src2Counter]) { - src1Counter++; - } else { - src2Counter++; - } - } - } - - return src1.makeLikeTensor(src1.shape, product, rowIndices, colIndices); - } - - /** - * Adds a sparse vector to each column of a sparse matrix as if the vector is a column vector. - * @param src The source sparse matrix. - * @param col Sparse vector to add to each column of the sparse matrix. - * @return A dense copy of the {@code src} matrix with the {@code col} vector added to each row of the matrix. - */ - public static > FieldMatrix addToEachCol(CooFieldMatrix src, CooFieldVector col) { - ValidateParameters.ensureEquals(src.numRows, col.size); - T[] destEntries = (T[]) new Field[src.totalEntries().intValueExact()]; - - // Add values from sparse matrix. - for(int i = 0; i(src.shape, destEntries); - } - - - /** - * Adds a sparse vector to each row of a sparse matrix as if the vector is a row vector. - * @param src The source sparse matrix. - * @param row Sparse vector to add to each row of the sparse matrix. - * @return A dense copy of the {@code src} matrix with the {@code row} vector added to each row of the matrix. - */ - public static > FieldMatrix addToEachRow(CooFieldMatrix src, CooFieldVector row) { - ValidateParameters.ensureEquals(src.numCols, row.size); - T[] destEntries = (T[]) new Field[src.totalEntries().intValueExact()]; - - // Add values from sparse matrix. - for(int i = 0; i(src.shape, destEntries); - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixProperties.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixProperties.java deleted file mode 100644 index 968d612cc..000000000 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldMatrixProperties.java +++ /dev/null @@ -1,264 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.sparse.coo.field_ops; - - -import org.flag4j.algebraic_structures.Field; -import org.flag4j.arrays.Pair; -import org.flag4j.arrays.Shape; -import org.flag4j.arrays.backend.field_arrays.AbstractCooFieldMatrix; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.IntStream; - -/** - * This class contains low level implementations for methods to evaluate certain properties of a sparse coo field matrix. - * For example, if the matrix is symmetric. - */ -public final class CooFieldMatrixProperties { - - private CooFieldMatrixProperties() { - // Hide public constructor for utility class. - - } - - - /** - * Checks if a complex sparse matrix is the identity matrix. - * @param src Matrix to check if it is the identity matrix. - * @return {@code true} if the {@code src} matrix is the identity matrix; {@code false} otherwise. - */ - public static > boolean isIdentity(AbstractCooFieldMatrix src) { - // Ensure the matrix is square and there are at least the same number of non-zero data as data on the diagonal. - if(!src.isSquare() || src.data.length> boolean isCloseToIdentity(AbstractCooFieldMatrix src) { - // Ensure the matrix is square and there are the same number of non-zero data as data on the diagonal. - boolean result = src.isSquare() && src.data.length==src.numRows; - - // Tolerances corresponds to the allClose(...) methods. - double diagTol = 1.E-5; - double nonDiagTol = 1e-08; - - final T ONE = src.data.length > 0 ? src.data[0].getOne() : null; - - for(int i = 0; i diagTol ) { - return false; // Diagonal value is not close to one. - } else if((src.rowIndices[i] != i && src.colIndices[i] != i) && src.data[i].mag() > nonDiagTol) { - return false; // Non-diagonal value is not close to zero. - } - } - - return true; - } - - - /** - * Checks if a sparse COO matrix is hermitian. That is, the matrix is equal to its conjugate transpose. - * @param shape Shape of the COO matrix. - * @param entries Non-zero data of the COO matrix. - * @param rowIndices Non-zero row indices of the COO matrix. - * @param colIndices Non-zero column indices of the COO matrix. - * @return {@code true} if the {@code src} matrix is hermitian. {@code false} otherwise. - */ - public static > boolean isHermitian(Shape shape, T[] entries, int[] rowIndices, int[] colIndices) { - // Check if the matrix is square. - if (shape.get(0) != shape.get(1)) - return false; - - // Build a map from (row, col) to value for quick access. - Map, T> matrixMap = new HashMap<>(); - int nnz = entries.length; // Number of non-zero data. - - for (int i = 0; i < nnz; i++) { - int row = rowIndices[i]; - int col = colIndices[i]; - T value = entries[i]; - - matrixMap.put(new Pair<>(row, col), value); - } - - // Iterate over the data to check for Hermitian property. - for (Map.Entry, T> entry : matrixMap.entrySet()) { - int row = entry.getKey().first(); - int col = entry.getKey().second(); - T value = entry.getValue(); - - // Skip data where row > col to avoid redundant checks. - if (row > col) continue; - - if (row == col) { - // Diagonal data must be real: value == value.conj() - if (!value.equals(value.conj())) return false; - - } else { - // Get the symmetric value at (col, row). - T symValue = matrixMap.get(new Pair<>(col, row)); - - if (symValue == null) // Missing symmetric entry implies zero. - symValue = value.getZero(); - - // Check if value equals the conjugate of the symmetric value. - if (!value.equals(symValue.conj())) return false; - } - } - - // If all checks pass, the matrix is Hermitian. - return true; - } - - - /** - * Checks if a sparse matrix is symmetric. - * @param src Matrix to check if it is the hermitian matrix. - * @return True if the {@code src} matrix is hermitian. False otherwise. - */ - public static > boolean isSymmetric(AbstractCooFieldMatrix src) { - boolean result = src.isSquare(); - - List entries = Arrays.asList(src.data); - List rowIndices = IntStream.of(src.rowIndices).boxed().collect(Collectors.toList()); - List colIndices = IntStream.of(src.colIndices).boxed().collect(Collectors.toList()); - - T value; - int row; - int col; - - while(result && entries.size() > 0) { - // Extract value of interest. - value = entries.remove(0); - row = rowIndices.remove(0); - col = colIndices.remove(0); - - // Find indices of first and last value whose row index matched the value of interests column index. - int rowStart = rowIndices.indexOf(col); - int rowEnd = rowIndices.lastIndexOf(col); - - if(rowStart == -1) { - // Then no non-zero value was found. - result = value.equals(0); - } else { - // At least one entry has a row-index matching the specified column index. - List colIdxRange = colIndices.subList(rowStart, rowEnd + 1); - - // Search for element whose column index matches the specified row index - int idx = colIdxRange.indexOf(row); - - if(idx == -1) { - // Then no non-zero value was found. - result = value.equals(0); - } else { - // Check that value with opposite row/column indices is equal. - result = value.equals(entries.get(idx + rowStart)); - - // Remove the value and the indices. - entries.remove(idx + rowStart); - rowIndices.remove(idx + rowStart); - colIndices.remove(idx + rowStart); - } - } - } - - return result; - } - - - /** - * Checks if a sparse matrix is anti-hermitian. - * @param src Matrix to check if it is the anti-hermitian matrix. - * @return True if the {@code src} matrix is anti-hermitian. False otherwise. - */ - public static > boolean isAntiHermitian(AbstractCooFieldMatrix src) { - boolean result = src.isSquare(); - - List entries = Arrays.asList(src.data); - List rowIndices = IntStream.of(src.rowIndices).boxed().collect(Collectors.toList()); - List colIndices = IntStream.of(src.colIndices).boxed().collect(Collectors.toList()); - - V value; - int row; - int col; - - while(result && entries.size() > 0) { - // Extract value of interest. - value = entries.remove(0); - row = rowIndices.remove(0); - col = colIndices.remove(0); - - // Find indices of first and last value whose row index matched the value of interests column index. - int rowStart = rowIndices.indexOf(col); - int rowEnd = rowIndices.lastIndexOf(col); - - if(rowStart == -1) { - // Then no non-zero value was found. - result = value.equals(0); - } else { - // At least one entry has a row-index matching the specified column index. - List colIdxRange = colIndices.subList(rowStart, rowEnd + 1); - - // Search for element whose column index matches the specified row index - int idx = colIdxRange.indexOf(row); - - if(idx == -1) { - // Then no non-zero value was found. - result = value.equals(0); - } else { - // Check that value with opposite row/column indices is equal. - result = value.equals(entries.get(idx + rowStart).addInv().conj()); - - // Remove the value and the indices. - entries.remove(idx + rowStart); - rowIndices.remove(idx + rowStart); - colIndices.remove(idx + rowStart); - } - } - } - - return result; - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldVectorOps.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldVectorOps.java deleted file mode 100644 index 8f2211efe..000000000 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldVectorOps.java +++ /dev/null @@ -1,373 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.sparse.coo.field_ops; - -import org.flag4j.algebraic_structures.Field; -import org.flag4j.arrays.Shape; -import org.flag4j.arrays.backend.field_arrays.AbstractCooFieldMatrix; -import org.flag4j.arrays.backend.field_arrays.AbstractCooFieldVector; -import org.flag4j.arrays.backend.field_arrays.AbstractDenseFieldMatrix; -import org.flag4j.arrays.backend.field_arrays.AbstractDenseFieldVector; -import org.flag4j.util.ArrayUtils; -import org.flag4j.util.ValidateParameters; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - - -/** - * This utility class contains methods for computing ops between two sparse coo - * {@link Field} vectors. - */ -public final class CooFieldVectorOps { - - private CooFieldVectorOps() { - - } - - - /** - * Adds a real number to each entry of a sparse vector, including the zero data. - * @param src Sparse vector to add value to. - * @param a Value to add to the {@code src} sparse vector. - * @return The result of adding the specified value to the sparse vector. - */ - public static > AbstractDenseFieldVector add(AbstractCooFieldVector src, T a) { - T[] dest = (T[]) new Field[src.size]; - Arrays.fill(dest, a); - - for(int i = 0; i> AbstractDenseFieldVector sub(AbstractCooFieldVector src, T a) { - T[] dest = (T[]) new Field[src.size]; - Arrays.fill(dest, a.addInv()); - - for(int i = 0; i> T add( - AbstractCooFieldVector src1, - AbstractCooFieldVector src2) { - ValidateParameters.ensureEqualShape(src1.shape, src2.shape); - List values = new ArrayList<>(src1.data.length); - List indices = new ArrayList<>(src1.data.length); - - int src1Counter = 0; - int src2Counter = 0; - - while(src1Counter < src1.data.length && src2Counter < src2.data.length) { - if(src1.indices[src1Counter] == src2.indices[src2Counter]) { - values.add(src1.data[src1Counter].add(src2.data[src2Counter])); - indices.add(src1.indices[src1Counter]); - src1Counter++; - src2Counter++; - - } else if(src1.indices[src1Counter] < src2.indices[src2Counter]) { - values.add(src1.data[src1Counter]); - indices.add(src1.indices[src1Counter]); - src1Counter++; - } else { - values.add(src2.data[src2Counter]); - indices.add(src2.indices[src2Counter]); - src2Counter++; - } - } - - // Finish inserting the rest of the values. - if(src1Counter < src1.data.length) { - for(int i = src1Counter; i> AbstractCooFieldVector sub( - AbstractCooFieldVector src1, - AbstractCooFieldVector src2) { - ValidateParameters.ensureEqualShape(src1.shape, src2.shape); - List values = new ArrayList<>(src1.data.length); - List indices = new ArrayList<>(src1.data.length); - - int src1Counter = 0; - int src2Counter = 0; - - while(src1Counter < src1.data.length && src2Counter < src2.data.length) { - if(src1.indices[src1Counter] == src2.indices[src2Counter]) { - values.add(src1.data[src1Counter].sub(src2.data[src2Counter])); - indices.add(src1.indices[src1Counter]); - src1Counter++; - src2Counter++; - - } else if(src1.indices[src1Counter] < src2.indices[src2Counter]) { - values.add(src1.data[src1Counter]); - indices.add(src1.indices[src1Counter]); - src1Counter++; - } else { - values.add(src2.data[src2Counter].addInv()); - indices.add(src2.indices[src2Counter]); - src2Counter++; - } - } - - // Finish inserting the rest of the values. - if(src1Counter < src1.data.length) { - for(int i = src1Counter; i> AbstractCooFieldVector elemMult( - AbstractCooFieldVector src1, AbstractCooFieldVector src2) { - ValidateParameters.ensureEqualShape(src1.shape, src2.shape); - List values = new ArrayList<>(src1.data.length); - List indices = new ArrayList<>(src1.data.length); - - int src1Counter = 0; - int src2Counter = 0; - - while(src1Counter < src1.data.length && src2Counter < src2.data.length) { - if(src1.indices[src1Counter]==src2.indices[src2Counter]) { - // Then indices match, add product of elements. - values.add(src1.data[src1Counter].mult(src2.data[src2Counter])); - indices.add(src1.indices[src1Counter]); - src1Counter++; - src2Counter++; - } else if(src1.indices[src1Counter] < src2.indices[src2Counter]) { - src1Counter++; - } else { - src2Counter++; - } - } - - return src2.makeLikeTensor(src1.shape, values, indices); - } - - - /** - * Computes the inner product of two complex sparse vectors. Both sparse vectors are assumed - * to have their indices sorted lexicographically. - * @param src1 First sparse vector in the inner product. Indices assumed to be sorted lexicographically. - * @param src2 Second sparse vector in the inner product. Indices assumed to be sorted lexicographically. - * @return The result of the vector inner product. - * @throws IllegalArgumentException If the two vectors do not have the same size (full size including zeros). - */ - public static > T inner( - AbstractCooFieldVector src1, - AbstractCooFieldVector src2) { - ValidateParameters.ensureEqualShape(src1.shape, src2.shape); - - T product = null; - if(src1.nnz > 0) product = src1.data[0].getZero(); - else if(src2.nnz > 0) product = src2.data[0].getZero(); - - int src1Counter = 0; - int src2Counter = 0; - - while(src1Counter < src1.data.length && src2Counter < src2.data.length) { - if(src1.indices[src1Counter]==src2.indices[src2Counter]) { - // Then indices match, add product of elements. - product = product.add(src1.data[src1Counter].mult(src2.data[src2Counter].conj())); - } else if(src1.indices[src1Counter] < src2.indices[src2Counter]) { - src1Counter++; - } else { - src2Counter++; - } - } - - return product; - } - - - /** - * Computes the vector outer product between two complex sparse vectors. - * @param src1 Entries of the first sparse vector in the outer product. - * @param src2 Second sparse vector in the outer product. - * @return The matrix resulting from the vector outer product. - */ - public static > AbstractDenseFieldMatrix outerProduct( - AbstractCooFieldVector src1, AbstractCooFieldVector src2) { - T[] dest = (T[]) new Field[src2.size*src1.size]; - Arrays.fill(dest, src1.getZeroElement()); - - int destRow; - int index1; - int index2; - - for(int i=0; i> AbstractCooFieldMatrix stack(AbstractCooFieldVector src1, - AbstractCooFieldVector src2) { - ValidateParameters.ensureEqualShape(src1.shape, src2.shape); - - Field[] entries = new Field[src1.data.length + src2.data.length]; - int[][] indices = new int[2][src1.indices.length + src2.indices.length]; // Row and column indices. - - // Copy values from vector src1. - System.arraycopy(src1.data, 0, entries, 0, src1.data.length); - // Copy values from vector src2. - System.arraycopy(src2.data, 0, entries, src1.data.length, src2.data.length); - - // Set row indices to 1 for src2 values (this vectors row indices are 0 which was implicitly set already). - Arrays.fill(indices[0], src1.indices.length, entries.length, 1); - - // Copy indices from src1 vector to the column indices. - System.arraycopy(src1.indices, 0, indices[1], 0, src1.data.length); - // Copy indices from src2 vector to the column indices. - System.arraycopy(src2.indices, 0, indices[1], src1.data.length, src2.data.length); - - return src1.makeLikeMatrix(new Shape(2, src1.size), (T[]) entries, indices[0], indices[1]); - } - - - /** - * Repeats a vector {@code n} times along a certain axis to create a matrix. - * - * @param src The vector to repeat. - * @param n Number of times to repeat vector. - * @param axis Axis along which to repeat vector. If {@code axis=0} then each row of the resulting matrix will be equivalent to - * this vector. If {@code axis=1} then each column of the resulting matrix will be equivalent to this vector. - * - * @return A matrix whose rows/columns are this vector repeated. - */ - public static > AbstractCooFieldMatrix repeat(AbstractCooFieldVector src, - int n, - int axis) { - ValidateParameters.ensureInRange(axis, 0, 1, "axis"); - ValidateParameters.ensureGreaterEq(0, n, "n"); - - Shape tiledShape; - Field[] tiledEntries = new Field[n*src.data.length]; - int[] tiledRows = new int[tiledEntries.length]; - int[] tiledCols = new int[tiledEntries.length]; - int nnz = src.nnz; - - if(axis==0) { - tiledShape = new Shape(n, src.size); - - for(int i=0; i entries = new ArrayList<>(); diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/real/RealSparseMatrixManipulations.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/real/RealSparseMatrixManipulations.java index 06efa50fb..a98c4e36a 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/real/RealSparseMatrixManipulations.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/real/RealSparseMatrixManipulations.java @@ -28,6 +28,7 @@ import org.flag4j.arrays.sparse.CooMatrix; import org.flag4j.linalg.ops.sparse.SparseElementSearch; import org.flag4j.util.ArrayUtils; +import org.flag4j.util.ValidateParameters; import java.util.ArrayList; import java.util.Arrays; @@ -63,6 +64,17 @@ public static CooMatrix removeRow(CooMatrix src, int rowIdx) { copyRanges(src, entries, rowIndices, colIndices, startEnd); + // Shift all row indices occurring after removed row. + if (startEnd[0] > 0) { + for(int i=startEnd[0], length=rowIndices.length; i rowIdx) + rowIndices[i]--; + } + } + return new CooMatrix(shape, entries, rowIndices, colIndices); } @@ -77,23 +89,40 @@ public static CooMatrix removeRow(CooMatrix src, int rowIdx) { * @return A copy of the {@code src} matrix with the specified rows removed. */ public static CooMatrix removeRows(CooMatrix src, int... rowIdxs) { - Shape shape = new Shape(src.numRows-rowIdxs.length, src.numCols); - List entries = new ArrayList<>(src.data.length); - List rowIndices = new ArrayList<>(src.data.length); - List colIndices = new ArrayList<>(src.data.length); + // Ensure the indices are sorted. + Arrays.sort(rowIdxs); - for(int i = 0; i entries = new ArrayList<>(src.nnz); + List newRowIndices = new ArrayList<>(src.nnz); + List newColIndices = new ArrayList<>(src.nnz); - if(idx < 0) { - // Then copy the entry over and apply proper shift to row index. - entries.add(src.data[i]); - rowIndices.add(src.rowIndices[i] + (idx+1)); - colIndices.add(src.colIndices[i]); + int j = 0; // Points into the rowIdxs array + int removeCount = 0; // Tracks number of removed rows. + + for (int i = 0; i < src.nnz; i++) { + int oldRow = src.rowIndices[i]; + + // Advance j while rowIdxs[j] < oldRow, updating removeCount + while (j < rowIdxs.length && rowIdxs[j] < oldRow) { + removeCount++; + j++; } + + // If oldRow is one of the removed rows, skip this entry. + if (j < rowIdxs.length && rowIdxs[j] == oldRow) + continue; + + // Otherwise, shift oldRow by however many removed rows lie below it. + int newRow = oldRow - removeCount; + + // Keep the entry + entries.add(src.data[i]); + newRowIndices.add(newRow); + newColIndices.add(src.colIndices[i]); } - return new CooMatrix(shape, entries, rowIndices, colIndices); + return new CooMatrix(shape, entries, newRowIndices, newColIndices); } @@ -131,23 +160,35 @@ public static CooMatrix removeCol(CooMatrix src, int colIdx) { * @return A copy of the {@code src} sparse matrix with the specified columns removed. */ public static CooMatrix removeCols(CooMatrix src, int... colIdxs) { - Shape shape = new Shape(src.numRows, src.numCols-1); - List entries = new ArrayList<>(src.data.length); - List rowIndices = new ArrayList<>(src.data.length); - List colIndices = new ArrayList<>(src.data.length); + ValidateParameters.ensureValidArrayIndices(src.numRows, colIdxs); - for(int i = 0; i destEntries = new ArrayList<>(src.data.length); + List destRowIdx = new ArrayList<>(src.data.length); + List destColIdx = new ArrayList<>(src.data.length); + + for (int i = 0; i < src.data.length; i++) { + int oldCol = src.colIndices[i]; + + // Check if oldCol is being removed. + int idx = Arrays.binarySearch(colIdxs, oldCol); + + // If idx >= 0, oldCol is in colIdxs then skip this entry + if (idx >= 0) continue; + + // Otherwise, shift column index. + int insertionPoint = -idx - 1; + int newCol = oldCol - insertionPoint; + + destEntries.add(src.data[i]); + destRowIdx.add(src.rowIndices[i]); + destColIdx.add(newCol); } - return new CooMatrix(shape, entries, rowIndices, colIndices); + return new CooMatrix(shape, destEntries, destRowIdx, destColIdx); } @@ -213,9 +254,9 @@ public static CooMatrix swapRows(CooMatrix src, int rowIdx1, int rowIdx2) { */ public static CooMatrix swapCols(CooMatrix src, int colIdx1, int colIdx2) { for(int i = 0; i diagTol ) { + for(int i=0, size=src.data.length; i diagTol ) { return false; // Diagonal value is not close to one. - } else if((src.rowIndices[i] != i && src.colIndices[i] != i) && Math.abs(src.data[i]) > nonDiagTol) { + } else if(row != col && Math.abs(src.data[i]) > nonDiagTol) { return false; // Non-diagonal value is not close to zero. } } - return true; } /** - * Checks if a real sparse matrix is symmetric. - * @param src Matrix to check if it is the symmetric matrix. - * @return True if the {@code src} matrix is symmetric. False otherwise. + * Checks if a sparse COO matrix is symmetric. + * @param shape The shape of the COO matrix. + * @param data Non-zero entries of the COO matrix. + * @param rowIndices Non-zero row indices of the COO matrix. + * @param colIndices Non-zero column indices of the COO matrix. + * @return {@code true} if the specified COO matrix is symmetric + * (i.e. equal to its transpose); {@code false} otherwise. */ - public static boolean isSymmetric(CooMatrix src) { - boolean result = src.isSquare(); - - List entries = DoubleStream.of(src.data).boxed().collect(Collectors.toList()); - List rowIndices = IntStream.of(src.rowIndices).boxed().collect(Collectors.toList()); - List colIndices = IntStream.of(src.colIndices).boxed().collect(Collectors.toList()); - - double value; - int row; - int col; - - while(result && !entries.isEmpty()) { - // Extract value of interest. - value = entries.remove(0); - row = rowIndices.remove(0); - col = colIndices.remove(0); - - // Find indices of first and last value whose row index matched the value of interests column index. - int rowStart = rowIndices.indexOf(col); - int rowEnd = rowIndices.lastIndexOf(col); - - if(rowStart == -1) { - // Then no non-zero value was found. - result = value == 0; + public static boolean isSymmetric(Shape shape, double[] data, int[] rowIndices, int[] colIndices) { + if(shape.get(0) != shape.get(1)) return false; // Early return for non-square matrix. + + Map, Double> dataMap = new HashMap, Double>(); + + for(int i = 0, size=data.length; i < size; i++) { + if(rowIndices[i] == colIndices[i] || data[i] == 0d) + continue; // This value is zero or on the diagonal. No need to consider. + + var p1 = new Pair<>(rowIndices[i], colIndices[i]); + var p2 = new Pair<>(colIndices[i], rowIndices[i]); + + if(!dataMap.containsKey(p2)) { + dataMap.put(p1, data[i]); + } else if(dataMap.get(p2) != data[i]){ + return false; // Not symmetric. } else { - // At least one entry has a row-index matching the specified column index. - List colIdxRange = colIndices.subList(rowStart, rowEnd + 1); - - // Search for element whose column index matches the specified row index - int idx = colIdxRange.indexOf(row); - - if(idx == -1) { - // Then no non-zero value was found. - result = value == 0; - } else { - // Check that value with opposite row/column indices is equal. - result = value == entries.get(idx + rowStart); - - // Remove the value and the indices. - entries.remove(idx + rowStart); - rowIndices.remove(idx + rowStart); - colIndices.remove(idx + rowStart); - } + dataMap.remove(p2); } } - return result; + // If there are any remaining values a value with the transposed indices was not found in the matrix. + return dataMap.isEmpty(); } /** - * Checks if a real sparse matrix is anti-symmetric. - * @param src Matrix to check if it is the anti-symmetric matrix. - * @return True if the {@code src} matrix is anti-symmetric. False otherwise. + * Checks if a sparse COO matrix is symmetric. + * @param shape The shape of the COO matrix. + * @param data Non-zero entries of the COO matrix. + * @param rowIndices Non-zero row indices of the COO matrix. + * @param colIndices Non-zero column indices of the COO matrix. + * @return {@code true} if the specified COO matrix is symmetric + * (i.e. equal to its transpose); {@code false} otherwise. */ - public static boolean isAntiSymmetric(CooMatrix src) { - boolean result = src.isSquare(); - - List entries = DoubleStream.of(src.data).boxed().collect(Collectors.toList()); - List rowIndices = IntStream.of(src.rowIndices).boxed().collect(Collectors.toList()); - List colIndices = IntStream.of(src.colIndices).boxed().collect(Collectors.toList()); - - double value; - int row; - int col; - - while(result && !entries.isEmpty()) { - // Extract value of interest. - value = entries.remove(0); - row = rowIndices.remove(0); - col = colIndices.remove(0); - - // Find indices of first and last value whose row index matched the value of interests column index. - int rowStart = rowIndices.indexOf(col); - int rowEnd = rowIndices.lastIndexOf(col); - - if(rowStart == -1) { - // Then no non-zero value was found. - result = value == 0; + public static boolean isAntiSymmetric(Shape shape, double[] data, int[] rowIndices, int[] colIndices) { + if(shape.get(0) != shape.get(1)) return false; // Early return for non-square matrix. + + Map, Double> dataMap = new HashMap, Double>(); + + for(int i = 0, size=data.length; i < size; i++) { + if(rowIndices[i] == colIndices[i] || data[i] == 0d) + continue; // This value is zero or on the diagonal. No need to consider. + + var p1 = new Pair<>(rowIndices[i], colIndices[i]); + var p2 = new Pair<>(colIndices[i], rowIndices[i]); + + if(!dataMap.containsKey(p2)) { + dataMap.put(p1, data[i]); + } else if(dataMap.get(p2) != -data[i]){ + return false; // Not symmetric. } else { - // At least one entry has a row-index matching the specified column index. - List colIdxRange = colIndices.subList(rowStart, rowEnd + 1); - - // Search for element whose column index matches the specified row index - int idx = colIdxRange.indexOf(row); - - if(idx == -1) { - // Then no non-zero value was found. - result = value == 0; - } else { - // Check that value with opposite row/column indices is equal. - result = value == -entries.get(idx + rowStart); - - // Remove the value and the indices. - entries.remove(idx + rowStart); - rowIndices.remove(idx + rowStart); - colIndices.remove(idx + rowStart); - } + dataMap.remove(p2); } } - return result; + // If there are any remaining values a value with the transposed indices was not found in the matrix. + return dataMap.isEmpty(); } } diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingEquals.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingEquals.java deleted file mode 100644 index 2ca9fbabf..000000000 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingEquals.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.sparse.coo.ring_ops; - - -import org.flag4j.algebraic_structures.Ring; -import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingMatrix; -import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingTensor; -import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingVector; - -import java.util.Arrays; - -/** - * A utility class for checking equality between COO {@link Ring} tensors. - */ -public class CooRingEquals { - - private CooRingEquals() { - // Hide default constructor. - } - - /** - * Checks if two real sparse tensors are real. Assumes the indices of each sparse tensor are sorted. Any explicitly stored - * zero's will be ignored. - * @param a First tensor in the equality check. - * @param b Second tensor in the equality check. - * @return True if the tensors are equal. False otherwise. - */ - public static > boolean cooTensorEquals( - AbstractCooRingTensor a, - AbstractCooRingTensor b) { - if (a == b) return true; - if (a == null || b == null) return false; - - a = a.coalesce().dropZeros(); - b = b.coalesce().dropZeros(); - return a.shape.equals(b.shape) - && Arrays.equals(a.data, b.data) - && Arrays.deepEquals(a.indices, b.indices); - } - - - /** - * Checks if two real sparse matrices are real. Assumes the indices of each sparse matrix are sorted. Any explicitly stored - * zero's will be ignored. - * @param a First matrix in the equality check. - * @param b Second matrix in the equality check. - * @return True if the matrices are equal. False otherwise. - */ - public static > boolean cooMatrixEquals( - AbstractCooRingMatrix a, - AbstractCooRingMatrix b) { - // Early return if possible. - if (a == b) return true; - if (a == null || b == null) return false; - - a = a.coalesce().dropZeros(); - b = b.coalesce().dropZeros(); - return a.shape.equals(b.shape) - && Arrays.equals(a.data, b.data) - && Arrays.equals(a.rowIndices, b.rowIndices) - && Arrays.equals(a.colIndices, b.colIndices); - } - - - /** - * Checks if two real sparse vectors are real. Assumes the indices of each sparse vector are sorted. Any explicitly stored - * zero's will be ignored. - * @param a First vector in the equality check. - * @param b Second vector in the equality check. - * @return True if the vectors are equal. False otherwise. - */ - public static > boolean cooVectorEquals( - AbstractCooRingVector a, - AbstractCooRingVector b) { - // Early returns if possible. - if(a == b) return true; - if(a==null || b==null || !a.shape.equals(b.shape)) return false; - - a = a.coalesce().dropZeros(); - b = b.coalesce().dropZeros(); - return a.shape.equals(b.shape) - && Arrays.equals(a.data, b.data) - && Arrays.equals(a.indices, b.indices); - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldHermTranspose.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingHermTranspose.java similarity index 90% rename from src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldHermTranspose.java rename to src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingHermTranspose.java index 9c44c4ad3..48e68dab6 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldHermTranspose.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingHermTranspose.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -22,22 +22,22 @@ * SOFTWARE. */ -package org.flag4j.linalg.ops.sparse.coo.field_ops; +package org.flag4j.linalg.ops.sparse.coo.ring_ops; - -import org.flag4j.algebraic_structures.Field; +import org.flag4j.algebraic_structures.Ring; import org.flag4j.arrays.Shape; import org.flag4j.linalg.ops.sparse.coo.CooDataSorter; import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; + /** - * Utility class for computing the Hermitian transpose of a COO matrix or tensor. + * Utility class for computing the Hermitian transpose of a COO ring matrix or tensor. */ -public final class CooFieldHermTranspose { +public final class CooRingHermTranspose { - private CooFieldHermTranspose() { - + private CooRingHermTranspose() { + // Hide default constructor for utility class. } /** @@ -55,9 +55,9 @@ private CooFieldHermTranspose() { * {@code [srcEntries.length][shape.getRank()]} * * @throws IndexOutOfBoundsException If either {@code axis1} or {@code axis2} are out of bounds for the rank of this tensor. - * @see #tensorTranspose(Shape, Object[], int[][], int[], Object[], int[][]) + * @see #tensorHermTranspose(Shape, Ring[], int[][], int[], Ring[], int[][]) */ - public static > void tensorHermTranspose( + public static > void tensorHermTranspose( Shape shape, T[] srcEntries, int[][] srcIndices, int axis1, int axis2, T[] destEntries, int[][] destIndices) { @@ -98,9 +98,9 @@ public static > void tensorHermTranspose( * @throws IndexOutOfBoundsException If any element of {@code axes} is out of bounds for the rank of this tensor. * @throws IllegalArgumentException If {@code axes} is not a permutation of {@code {1, 2, 3, ... N-1}}. */ - public static > void tensorHermTranspose( + public static > void tensorHermTranspose( Shape shape, T[] srcEntries, int[][] srcIndices, int[] axes, - T[] destEntries, int[][] destIndices) { + T[] destEntries, int[][] destIndices) { int rank = shape.getRank(); ValidateParameters.ensurePermutation(axes, rank); diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingMatrixOps.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingMatrixOps.java index d8c5cacf9..bdb87c683 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingMatrixOps.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingMatrixOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,12 +25,16 @@ package org.flag4j.linalg.ops.sparse.coo.ring_ops; import org.flag4j.algebraic_structures.Ring; +import org.flag4j.arrays.Pair; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseMatrixData; +import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingMatrix; import org.flag4j.util.ValidateParameters; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** * Utility class for computing ops on sparse COO {@link Ring} matrices. @@ -115,4 +119,70 @@ public static > SparseMatrixData sub( return new SparseMatrixData(shape1, diff, rowIndices, colIndices); } + + + /** + * Checks if a real sparse matrix is close to the identity matrix. + * @param src Matrix to check if it is the identity matrix. + * @return {@code true} if the {@code src} matrix is the identity matrix; {@code false} otherwise. + */ + public static > boolean isCloseToIdentity(AbstractCooRingMatrix src) { + // Ensure the matrix is square and there are the same number of non-zero data as data on the diagonal. + if(!src.isSquare() || src.data.length < src.numRows) return false; + + // Tolerances corresponds to the allClose(...) methods. + double diagTol = 1.E-5; + double nonDiagTol = 1e-08; + + final T ONE = src.data.length > 0 ? src.data[0].getOne() : null; + + for(int i=0, size=src.data.length; i diagTol ) { + return false; // Diagonal value is not close to one. + } else if(row != col && src.data[i].mag() > nonDiagTol) { + return false; // Non-diagonal value is not close to zero. + } + } + + return true; + } + + + /** + * Checks if a sparse COO {@link Ring} matrix is Hermitian. + * @param shape The shape of the COO matrix. + * @param data Non-zero entries of the COO matrix. + * @param rowIndices Non-zero row indices of the COO matrix. + * @param colIndices Non-zero column indices of the COO matrix. + * @return {@code true} if the specified COO matrix is Hermitian + * (i.e. equal to its conjugate transpose); {@code false} otherwise. + * @param The ring to which the data values of the COO matrix belong. + */ + public static > boolean isHermitian(Shape shape, T[] data, int[] rowIndices, int[] colIndices) { + if(shape.get(0) != shape.get(1)) return false; // Early return for non-square matrix. + + Map, T> dataMap = new HashMap, T>(); + + for(int i = 0, size=data.length; i < size; i++) { + if(rowIndices[i] == colIndices[i] || data[i].isZero()) + continue; // This value is zero or on the diagonal. No need to consider. + + var p1 = new Pair<>(rowIndices[i], colIndices[i]); + var p2 = new Pair<>(colIndices[i], rowIndices[i]); + + if(!dataMap.containsKey(p2)) { + dataMap.put(p1, data[i]); + } else if(!dataMap.get(p2).equals(data[i].conj())){ + return false; // Not Hermitian. + } else { + dataMap.remove(p2); + } + } + + // If there are any remaining values a value with the transposed indices was not found in the matrix. + return dataMap.isEmpty(); + } } diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldNorms.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingNorms.java similarity index 64% rename from src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldNorms.java rename to src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingNorms.java index 802dc3fdb..7ba124ac1 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/field_ops/CooFieldNorms.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingNorms.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -22,10 +22,10 @@ * SOFTWARE. */ -package org.flag4j.linalg.ops.sparse.coo.field_ops; +package org.flag4j.linalg.ops.sparse.coo.ring_ops; -import org.flag4j.algebraic_structures.Field; -import org.flag4j.arrays.backend.field_arrays.AbstractCooFieldMatrix; +import org.flag4j.algebraic_structures.Ring; +import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingMatrix; import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; @@ -33,21 +33,23 @@ /** - * This utility class contains low level implementations of norms for sparse field coo tensors, matrices and vectors. + * This utility class contains low level implementations of norms for sparse COO ring tensors, matrices and vectors. */ -public final class CooFieldNorms { +public final class CooRingNorms { - private CooFieldNorms() { + private CooRingNorms() { // Hide default constructor for utility class. } /** - * Computes the L2 norm of a matrix. + * Computes the L2, 2 norm of a matrix. Also called the Frobenius norm. * @param src Source matrix to compute norm of. - * @return The L2 of the {@code src} matrix. + * @return The L2, 2 of the {@code src} matrix. */ - public static > double matrixNormL2(AbstractCooFieldMatrix src) { + public static > double matrixNormL22(AbstractCooRingMatrix src) { + ValidateParameters.ensureSquare(src.shape); + double norm = 0; double[] colSums = new double[ArrayUtils.numUnique(src.colIndices)]; @@ -68,40 +70,13 @@ public static > double matrixNormL2(AbstractCooFieldMatrixp norm of a matrix. - * @param src Source matrix to compute norm of. - * @param p Parameter for Lp norm - * @return The Lp of the {@code src} matrix. - */ - public static > double matrixNormLp(AbstractCooFieldMatrix src, double p) { - ValidateParameters.ensureGreaterEq(1, p); - - double norm = 0; - double[] colSums = new double[ArrayUtils.numUnique(src.colIndices)]; - - // Create a mapping from the unique column indices to a unique position in the colSums array. - HashMap columnMap = ArrayUtils.createUniqueMapping(src.colIndices); - - // Compute the column sums. - for(int i = 0; ip, q norm of a matrix. * @param src Source matrix to compute norm of. * @param p First parameter for Lp, q norm * @return The Lp, q of the {@code src} matrix. */ - public static > double matrixNormLpq(AbstractCooFieldMatrix src, double p, double q) { + public static > double matrixNormLpq(AbstractCooRingMatrix src, double p, double q) { ValidateParameters.ensureGreaterEq(1, p, q); double norm = 0; diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingVectorOps.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingVectorOps.java index dc348394b..7a1de5b33 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingVectorOps.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/ring_ops/CooRingVectorOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,6 +27,7 @@ import org.flag4j.algebraic_structures.Ring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseVectorData; +import org.flag4j.arrays.backend.ring_arrays.AbstractCooRingVector; import org.flag4j.util.ValidateParameters; import java.util.ArrayList; @@ -94,4 +95,37 @@ public static > SparseVectorData sub( return new SparseVectorData(shape1, values, indices); } + + + /** + * Computes the inner product of two complex sparse vectors. Both sparse vectors are assumed + * to have their indices sorted lexicographically. + * @param src1 First sparse vector in the inner product. Indices assumed to be sorted lexicographically. + * @param src2 Second sparse vector in the inner product. Indices assumed to be sorted lexicographically. + * @return The result of the vector inner product. + * @throws IllegalArgumentException If the two vectors do not have the same size (full size including zeros). + */ + public static > T inner( + AbstractCooRingVector src1, + AbstractCooRingVector src2) { + ValidateParameters.ensureEqualShape(src1.shape, src2.shape); + + T product = src1.getZeroElement(); + + int src1Counter = 0; + int src2Counter = 0; + + while(src1Counter < src1.data.length && src2Counter < src2.data.length) { + if(src1.indices[src1Counter]==src2.indices[src2Counter]) { + // Then indices match, add product of elements. + product = product.add(src1.data[src1Counter].mult(src2.data[src2Counter].conj())); + } else if(src1.indices[src1Counter] < src2.indices[src2Counter]) { + src1Counter++; + } else { + src2Counter++; + } + } + + return product; + } } diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringEquals.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringEquals.java index 0202406de..5ea0fb831 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringEquals.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringEquals.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -77,6 +77,7 @@ public static > boolean cooMatrixEquals( a = a.coalesce().dropZeros(); b = b.coalesce().dropZeros(); + return a.shape.equals(b.shape) && Arrays.equals(a.data, b.data) && Arrays.equals(a.rowIndices, b.rowIndices) diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringMatrixOps.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringMatrixOps.java index 26d2c3033..2f5bb3b16 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringMatrixOps.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringMatrixOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,6 +27,7 @@ import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseMatrixData; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCooSemiringMatrix; import org.flag4j.util.ValidateParameters; import java.util.ArrayList; @@ -170,4 +171,26 @@ public static > SparseMatrixData elemMult( return new SparseMatrixData(shape1, product, rowIndices, colIndices); } + + + /** + * Checks if a complex sparse matrix is the identity matrix. + * @param src Matrix to check if it is the identity matrix. + * @return {@code true} if the {@code src} matrix is the identity matrix; {@code false} otherwise. + */ + public static > boolean isIdentity(AbstractCooSemiringMatrix src) { + // Ensure the matrix is square and there are at least the same number of non-zero data as data on the diagonal. + if(!src.isSquare() || src.data.length> boolean isIdentity( return true; // If we make it to this point the matrix must be an identity matrix. } - - - /** - * Checks if a sparse matrix is symmetric. - * @param shape Shape of the matrix. - * @param entries Non-zero data of the matrix. - * @param rowIndices Non-zero row indices of the matrix. - * @param colIndices Non-zero column indices of the matrix. - * @return True if the {@code src} matrix is hermitian. False otherwise. - */ - public static > boolean isSymmetric( - Shape shape, T[] entries, int[] rowIndices, int[] colIndices) { - if (shape.get(0) != shape.get(1)) return false; // Quick return for non-square matrix. - - List entriesList = Arrays.asList(entries); - List rowIndicesList = IntStream.of(rowIndices).boxed().collect(Collectors.toList()); - List colIndicesList = IntStream.of(colIndices).boxed().collect(Collectors.toList()); - - boolean result = true; - - while(result && entriesList.size() > 0) { - // Extract value of interest. - T value = entriesList.remove(0); - int row = rowIndicesList.remove(0); - int col = colIndicesList.remove(0); - - // Find indices of first and last value whose row index matched the value of interests column index. - int rowStart = rowIndicesList.indexOf(col); - int rowEnd = rowIndicesList.lastIndexOf(col); - - if(rowStart == -1) { - // Then no non-zero value was found. - result = value.equals(0); - } else { - // At least one entry has a row-index matching the specified column index. - List colIdxRange = colIndicesList.subList(rowStart, rowEnd + 1); - - // Search for element whose column index matches the specified row index - int idx = colIdxRange.indexOf(row); - - if(idx == -1) { - // Then no non-zero value was found. - result = value.equals(0); - } else { - // Check that value with opposite row/column indices is equal. - result = value.equals(entriesList.get(idx + rowStart)); - - // Remove the value and the indices. - entriesList.remove(idx + rowStart); - rowIndicesList.remove(idx + rowStart); - colIndicesList.remove(idx + rowStart); - } - } - } - - return result; - } } diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringTensorOps.java b/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringTensorOps.java index 8316b4062..11d69ca86 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringTensorOps.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/coo/semiring_ops/CooSemiringTensorOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -40,7 +40,6 @@ public final class CooSemiringTensorOps { private CooSemiringTensorOps() { // Hide constructor for utility class. for utility class. - } diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrConversions.java b/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrConversions.java index 34a9582ca..63c34f44f 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrConversions.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrConversions.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrOps.java b/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrOps.java index c32c183c1..24f27e791 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrOps.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,6 +27,7 @@ import org.flag4j.algebraic_structures.Field; import org.flag4j.arrays.Shape; import org.flag4j.arrays.SparseMatrixData; +import org.flag4j.linalg.ops.sparse.SparseUtils; import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; @@ -354,8 +355,8 @@ private static void moveAndShiftRight(T[] entries, int[] rowPointers, int[] // Shift data in row to right. for(int j=currPos; j>newPos; j--) { - entries[j] = entries[j-1]; - colIndices[j] = colIndices[j-1]; + entries[j] = entries[j - 1]; + colIndices[j] = colIndices[j - 1]; } entries[newPos] = value; // Move non-zero value to new location. @@ -393,6 +394,7 @@ private static void moveAndShiftLeft(T[] entries, int[] rowPointers, int[] c /** * Gets a specified slice of a CSR matrix. * + * @param shape Shape of the CSR matrix. * @param entries Non-zero data of the CSR matrix. * @param rowPointers Non-zero row pointers of the CSR matrix. * @param colIndices Non-zero column indices of the CSR matrix. @@ -405,9 +407,10 @@ private static void moveAndShiftLeft(T[] entries, int[] rowPointers, int[] c * @throws IllegalArgumentException If {@code rowEnd} is not greater than {@code rowStart} * or if {@code colEnd} is not greater than {@code colStart}. */ - public static SparseMatrixData getSlice(T[] entries, int[] rowPointers, int[] colIndices, + public static SparseMatrixData getSlice(Shape shape, T[] entries, int[] rowPointers, int[] colIndices, int rowStart, int rowEnd, int colStart, int colEnd) { + SparseUtils.validateSlice(shape, rowStart, rowEnd, colStart, colEnd); List slice = new ArrayList<>(); List sliceRowIndices = new ArrayList<>(); List sliceColIndices = new ArrayList<>(); @@ -425,13 +428,33 @@ public static SparseMatrixData getSlice(T[] entries, int[] rowPointers, i // Add value if it is within the slice. if(col >= colStart && col < colEnd) { slice.add(entries[j]); - sliceRowIndices.add(i); - sliceColIndices.add(col); + sliceRowIndices.add(i-rowStart); + sliceColIndices.add(col-colStart); } } } + // Matrix has been constructed as COO matrix. Now must be converted to CSR matrix. + int size = rowEnd-rowStart + 1; + List sliceRowPointers = new ArrayList<>(size); + for(int i=0; i(new Shape(rowEnd-rowStart, colEnd-colStart), - slice, sliceRowIndices, sliceColIndices); + slice, sliceRowPointers, sliceColIndices); } } diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrProperties.java b/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrProperties.java index 212e04512..bc221b61b 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrProperties.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/csr/CsrProperties.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -40,14 +40,19 @@ private CsrProperties() { /** - * Checks if a CSR matrix is symmetric. + * Checks if a sparse CSR matrix is symmetric. * @param shape Shape of the CSR matrix. * @param values Non-zero values of a CSR matrix. * @param rowPointers Non-zero row pointers of the CSR matrix. * @param colIndices Non-zero column indices of the CSR matrix. + * @param zeroValue Any value in {@code values} equal to {@code zeroValue} + * will be considered zero and will not be considered when determining the symmetry. Equality is determined according to + * {@link Objects#equals(Object, Object)} where if one of the parameters is {@code null} then the result will always be {@code + * false}. This means passing {@code zeroValue = null} will result in all items in {@code values} being considered. This is + * useful if there is no definable zero value for the values of the CSR matrix. * @return {@code true} if the CSR matrix is symmetric; {@code false} otherwise. */ - public static boolean isSymmetric(Shape shape, T[] values, int[] rowPointers, int[] colIndices) { + public static boolean isSymmetric(Shape shape, T[] values, int[] rowPointers, int[] colIndices, T zeroValue) { int numRows = shape.get(0); int numCols = shape.get(1); @@ -60,7 +65,7 @@ public static boolean isSymmetric(Shape shape, T[] values, int[] rowPointers for (int idx = rowStart; idx < rowEnd; idx++) { int j = colIndices[idx]; - if (j >= i) { + if (j >= i && !Objects.equals(values[idx], zeroValue)) { T val1 = values[idx]; // Search for the value with swapped row and column indices. diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/csr/field_ops/CsrFieldMatrixProperties.java b/src/main/java/org/flag4j/linalg/ops/sparse/csr/field_ops/CsrFieldMatrixProperties.java deleted file mode 100644 index 7d7757131..000000000 --- a/src/main/java/org/flag4j/linalg/ops/sparse/csr/field_ops/CsrFieldMatrixProperties.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * MIT License - * - * Copyright (c) 2024. Jacob Watters - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ - -package org.flag4j.linalg.ops.sparse.csr.field_ops; - -import org.flag4j.algebraic_structures.Field; -import org.flag4j.arrays.backend.field_arrays.AbstractCsrFieldMatrix; - -/** - * This utility class contains methods usefully for determining properties of a sparse CSR - * {@link Field} matrix. - */ -public final class CsrFieldMatrixProperties { - - private CsrFieldMatrixProperties() { - // Hide default constructor for utility class. - } - - - /** - * Checks if the {@code src} matrix is the identity matrix. - * @param src The matrix to check if it is the identity matrix. - * @return True if the {@code src} matrix is the identity matrix. False otherwise. - */ - public static > boolean isIdentity(AbstractCsrFieldMatrix src) { - if(src.isSquare() && src.colIndices.length >= src.numCols) { - int diagCount = 0; - - for(int i=0; i> boolean isCloseToIdentity(AbstractCsrFieldMatrix src) { - if(src.isSquare() && src.colIndices.length >= src.numCols) { - // Tolerances corresponds to the allClose(...) methods. - double diagTol = 1.001E-5; - double nonDiagTol = 1e-08; - int diagCount = 0; - - final T ONE = src.nnz > 0 ? src.data[0].getOne() : null; - - for(int i=0; i diagTol) { - if(src.colIndices[j] != i) return false; // Diagonal value not close to one. - diagCount++; - } else if(src.data[i].abs() > nonDiagTol) { - return false; // Non-diagonal value is not close to one. - } - } - } - - return diagCount == src.numCols; - } else { - return false; - } - } - - - /** - * Checks if the {@code src} matrix is symmetric. - * @param src Source matrix to check symmetry of. - * @return {@code true} if {@code src} is symmetric; {@code false} otherwise. - */ - public static > boolean isSymmetric(AbstractCsrFieldMatrix src) { - // Check for early returns. - if(!src.isSquare()) return false; - if(src.data.length == 0) return true; - - return src.T().equals(src); - } - - - /** - * Checks if the {@code src} matrix is anti-symmetric. - * @param src Source matrix to check symmetry of. - * @return {@code true} if {@code src} is symmetric; {@code false} otherwise. - */ - public static > boolean isAntiSymmetric(AbstractCsrFieldMatrix src) { - // Check for early returns. - if(!src.isSquare()) return false; - if(src.data.length == 0) return true; - - return src.T().mult(-1).equals(src); - } - - - /** - * Checks if the {@code src} matrix is Hermitian. - * @param src Source matrix to check. - * @return {@code true} if {@code src} is Hermitian; {@code false} otherwise. - */ - public static > boolean isHermitian(AbstractCsrFieldMatrix src) { - // Check for early returns. - if(!src.isSquare()) return false; - if(src.data.length == 0) return true; - - return src.H().equals(src); - } - - - /** - * Checks if the {@code src} matrix is anti-Hermitian. - * @param src Source matrix to check. - * @return {@code true} if {@code src} is Hermitian; {@code false} otherwise. - */ - public static > boolean isAntiHermitian(AbstractCsrFieldMatrix src) { - // Check for early returns. - if(!src.isSquare()) return false; - if(src.data.length == 0) return true; - - return src.H().mult(-1).equals(src); - } -} diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/csr/real/RealCsrOps.java b/src/main/java/org/flag4j/linalg/ops/sparse/csr/real/RealCsrOps.java index f5c1f1393..f569e987c 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/csr/real/RealCsrOps.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/csr/real/RealCsrOps.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,14 +25,10 @@ package org.flag4j.linalg.ops.sparse.csr.real; import org.flag4j.arrays.Shape; -import org.flag4j.arrays.dense.CMatrix; -import org.flag4j.arrays.dense.CVector; import org.flag4j.arrays.dense.Matrix; -import org.flag4j.arrays.dense.Vector; -import org.flag4j.arrays.sparse.CooCVector; import org.flag4j.arrays.sparse.CooMatrix; -import org.flag4j.arrays.sparse.CooVector; import org.flag4j.arrays.sparse.CsrMatrix; +import org.flag4j.linalg.ops.sparse.SparseUtils; import org.flag4j.util.ArrayUtils; import org.flag4j.util.ValidateParameters; @@ -43,7 +39,7 @@ /** - * This class contains low-level implementations for ops on CSR matrices. + * This utility class contains low-level implementations for operations on real CSR matrices. */ public final class RealCsrOps { @@ -57,8 +53,11 @@ private RealCsrOps() { * @param src1 First CSR matrix in the element-wise product. * @param src2 Second CSR matrix in the element-wise product. * @return The element-wise product of {@code src1} and {@code src2}. + * @throws org.flag4j.util.exceptions.TensorShapeException If {@code !src1.shape.equals(src2.shape)} */ public static CsrMatrix elemMult(CsrMatrix src1, CsrMatrix src2) { + ValidateParameters.ensureEqualShape(src1.shape, src2.shape); + int numRows = src1.numRows; int[] rowPointers = new int[numRows + 1]; List colIndices = new ArrayList<>(); @@ -188,9 +187,6 @@ public static CsrMatrix applyBinOpp(CsrMatrix src1, CsrMatrix src2, } - - - /** * Transposes a sparse CSR matrix. * @param src The matrix to transpose. @@ -242,6 +238,7 @@ public static CsrMatrix transpose(CsrMatrix src) { * @throws IllegalArgumentException If {@code rowEnd} is not greater than {@code rowStart} or if {@code colEnd} is not greater than {@code colStart}. */ public static CsrMatrix getSlice(CsrMatrix src, int rowStart, int rowEnd, int colStart, int colEnd) { + SparseUtils.validateSlice(src.shape, rowStart, rowEnd, colStart, colEnd); List slice = new ArrayList<>(); List sliceRowIndices = new ArrayList<>(); List sliceColIndices = new ArrayList<>(); @@ -259,8 +256,8 @@ public static CsrMatrix getSlice(CsrMatrix src, int rowStart, int rowEnd, int co // Add value if it is within the slice. if(col >= colStart && col < colEnd) { slice.add(src.data[j]); - sliceRowIndices.add(i); - sliceColIndices.add(col); + sliceRowIndices.add(i-rowStart); + sliceColIndices.add(col-colStart); } } } @@ -269,216 +266,6 @@ public static CsrMatrix getSlice(CsrMatrix src, int rowStart, int rowEnd, int co new Shape(rowEnd-rowStart, colEnd-colStart), slice, sliceRowIndices, sliceColIndices).toCsr(); } - - - /** - * Adds a vector to each column of a matrix. The vector need not be a column vector. If it is a row vector it will be - * treated as if it were a column vector. - * - * @param src1 CSR matrix to add vector to each column of. - * @param src2 Vector to add to each column of this matrix. - * @return The result of adding the vector src2 to each column of this matrix. - */ - public static Matrix addToEachCol(CsrMatrix src1, Vector src2) { - ValidateParameters.ensureEquals(src1.numRows, src2.size); - Matrix sum = src2.repeat(src1.numCols, 1); - - for(int i=0; i> boolean allClose( - AbstractCsrFieldMatrix src1, - AbstractCsrFieldMatrix src2, + public static > boolean allClose( + AbstractCsrRingMatrix src1, + AbstractCsrRingMatrix src2, double relTol, double absTol) { boolean close = src1.shape.equals(src2.shape); @@ -78,8 +81,8 @@ public static > boolean allClose( && Arrays.equals(ArrayUtils.fromIntegerList(src1ColIndices), ArrayUtils.fromIntegerList(src2ColIndices)) - && RingProperties.allClose(src1Entries.toArray(new Field[0]), - src2Entries.toArray(new Field[0]), relTol, absTol); + && RingProperties.allClose(src1Entries.toArray(new Ring[0]), + src2Entries.toArray(new Ring[0]), relTol, absTol); } return close; @@ -94,8 +97,8 @@ public static > boolean allClose( * @param rowPointers Row pointers for data. * @param aTol Absolute tolerance for value to be considered close to zero. */ - private static > void removeCloseToZero( - AbstractCsrFieldMatrix src, + private static > void removeCloseToZero( + AbstractCsrRingMatrix src, List entries, int[] rowPointers, List colIndices, double aTol) { for(int i=0, size=src.numRows; i> void removeCloseToZero( // Accumulate row pointers. int size = rowPointers.length-1; - for(int i=0; i> boolean isCloseToIdentity(AbstractCsrRingMatrix src) { + if(src.isSquare() && src.colIndices.length >= src.numCols) { + // Tolerances corresponds to the allClose(...) methods. + double diagTol = 1.001E-5; + double nonDiagTol = 1e-08; + int diagCount = 0; + + final T ONE = src.nnz > 0 ? src.data[0].getOne() : null; + + for(int i=0; i diagTol) { + if(src.colIndices[j] != i) return false; // Diagonal value not close to one. + diagCount++; + } else if(src.data[i].abs() > nonDiagTol) { + return false; // Non-diagonal value is not close to one. + } + } + } + + return diagCount == src.numCols; + } else { + return false; + } + } + + + /** + * Checks if a sparse CSR matrix is Hermitian. + * @param shape Shape of the CSR matrix. + * @param values Non-zero values of a CSR matrix. + * @param rowPointers Non-zero row pointers of the CSR matrix. + * @param colIndices Non-zero column indices of the CSR matrix. + * @return {@code true} if the CSR matrix is Hermitian (i.e. equal to its conjugate transpose); {@code false} otherwise. + */ + public static > boolean isHermitian(Shape shape, T[] values, int[] rowPointers, int[] colIndices) { + int numRows = shape.get(0); + int numCols = shape.get(1); + + if(numRows != numCols) return false; // Early return for non-square matrix. + + for (int i = 0; i < numRows; i++) { + int rowStart = rowPointers[i]; + int rowEnd = rowPointers[i + 1]; + + for (int idx = rowStart; idx < rowEnd; idx++) { + int j = colIndices[idx]; + + if (j >= i && !values[idx].isZero()) { + T val1 = values[idx]; + + // Search for the value with swapped row and column indices. + int pos = Arrays.binarySearch(colIndices, rowPointers[j], rowPointers[j + 1], i); + + if (pos >= 0) { + T val2 = values[pos]; + + // Ensure values are Equal + if (!Objects.equals(val1, val2.conj())) return false; + + } else { + // Corresponding value not found. + return false; + } + } + } } + + return true; } } diff --git a/src/main/java/org/flag4j/linalg/ops/sparse/csr/semiring_ops/SemiringCsrProperties.java b/src/main/java/org/flag4j/linalg/ops/sparse/csr/semiring_ops/SemiringCsrProperties.java index 132052cd1..812e8e0c0 100644 --- a/src/main/java/org/flag4j/linalg/ops/sparse/csr/semiring_ops/SemiringCsrProperties.java +++ b/src/main/java/org/flag4j/linalg/ops/sparse/csr/semiring_ops/SemiringCsrProperties.java @@ -1,7 +1,7 @@ /* * MIT License * - * Copyright (c) 2024. Jacob Watters + * Copyright (c) 2024-2025. Jacob Watters * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ import org.flag4j.algebraic_structures.Semiring; import org.flag4j.arrays.Shape; +import org.flag4j.arrays.backend.semiring_arrays.AbstractCsrSemiringMatrix; /** * Utility class containing methods useful for determining certain properties of a @@ -116,4 +117,31 @@ public static > boolean isIdentity(Shape shape, T[] entrie return diagCount == numCols; } + + + /** + * Checks if the {@code src} matrix is the identity matrix. + * @param src The matrix to check if it is the identity matrix. + * @return True if the {@code src} matrix is the identity matrix. False otherwise. + */ + public static > boolean isIdentity(AbstractCsrSemiringMatrix src) { + if(src.isSquare() && src.colIndices.length >= src.numCols) { + int diagCount = 0; + + for(int i=0; i 0) { - boolean equal = true; - double base = values[0]; - - for(double v : values) { - if(v != base) { - equal = false; - break; - } - } + boolean equal = allEqual(values); - if(!equal) - throw new IllegalArgumentException("Expecting values to be equal but got: " + Arrays.toString(values)); + if(!equal) { + throw new IllegalArgumentException("Expecting values to be equal but got: " + Arrays.toString(values)); } } @@ -240,19 +229,10 @@ public static void ensureEquals(double... values) { * @throws IllegalArgumentException If any of the specified values are not equal. */ public static void ensureEquals(int... values) { - if(values.length > 0) { - boolean equal = true; - double base = values[0]; - - for(double v : values) { - if(v != base) { - equal = false; - break; - } - } + boolean equal = allEqual(values); - if(!equal) - throw new IllegalArgumentException("Expecting values to be equal but got: " + Arrays.toString(values)); + if(!equal) { + throw new IllegalArgumentException("Expecting values to be equal but got: " + Arrays.toString(values)); } } @@ -436,7 +416,7 @@ public static void ensureSquareMatrix(Shape shape) { /** * Checks if a shape represents a square tensor. * @param shape Shape to check. - * @throws IllegalArgumentException If all axis of the shape are not the same length. + * @throws TensorShapeException If all axis of the shape are not the same length. */ public static void ensureSquare(Shape shape) { ValidateParameters.ensureEquals(shape.getDims()); @@ -465,9 +445,8 @@ public static void ensureSquareMatrix(int numRows, int numCols) { * @throws LinearAlgebraException If the specified shape does not have the expected rank. */ public static void ensureRank(Shape shape, int expRank) { - if(shape.getRank() != expRank) { + if(shape.getRank() != expRank) throw new LinearAlgebraException(ErrorMessages.shapeRankErr(shape.getRank(), expRank)); - } } @@ -477,9 +456,8 @@ public static void ensureRank(Shape shape, int expRank) { * @throws IllegalArgumentException If the axis is not a valid 2D axis. */ public static void ensureAxis2D(int axis) { - if(!(axis == 0 || axis==1)) { + if(!(axis == 0 || axis==1)) throw new IllegalArgumentException(ErrorMessages.getAxisErr(axis, 0, 1)); - } } @@ -490,7 +468,7 @@ public static void ensureAxis2D(int axis) { */ public static void ensurePermutation(int... axes) { if (axes == null) - throw new IllegalArgumentException("Array is not a permutation of v"); + throw new IllegalArgumentException("Array is not a permutation."); ensurePermutation(axes, axes.length); } @@ -605,9 +583,8 @@ public static void ensureValidArrayIndices(int length, int... indices) { * @param indices Indices to validate. */ public static void validateTensorIndices(Shape shape, int[]... indices) { - for(int[] index : indices) { + for(int[] index : indices) validateTensorIndex(shape, index); - } } @@ -655,4 +632,44 @@ public static void ensureLengthEqualsRank(Shape shape, int size) { if(shape.getRank() != size) throw new LinearAlgebraException("Array length of " + size + " does not match rank of " + shape.getRank()); } + + + /** + * Checks that all values in an array are equal. + * @param values Values of interest. + * @return {@code ture} if all entries in {@code values} are equal; {@code false} otherwise. + */ + private static boolean allEqual(int... values) { + for(int i=0, size=values.length-1; i boolean allEqual(T[] values) { + for(int i=0, size=values.length-1; i new Complex64("sdf")); + + // ---------- sub-case 11 ---------- + assertThrows(RuntimeException.class, () -> new Complex64("1.023*i")); + + // ---------- sub-case 12 ---------- + assertThrows(RuntimeException.class, () -> new Complex64("1.13 - 2ei")); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ConversionTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ConversionTest.java new file mode 100644 index 000000000..32406a369 --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ConversionTest.java @@ -0,0 +1,83 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class Complex64ConversionTest { + Complex64 n; + float[] expPolar, actPolar; + Complex64 expRect, actRect; + float[] polar; + + /* + Note: These test take into consideration precision errors from double floating point errors. + */ + + @Test + void toPolarTestCase() { + // --------------- Sub-case 1 --------------- + n = new Complex64(0); + expPolar = new float[]{0, 0}; + actPolar = n.toPolar(); + Assertions.assertArrayEquals(expPolar, actPolar); + + // --------------- Sub-case 2 --------------- + n = new Complex64(1, 3); + expPolar = new float[]{(float) Math.sqrt(10), (float) Math.atan(3)}; + actPolar = n.toPolar(); + Assertions.assertArrayEquals(expPolar, actPolar); + + // --------------- Sub-case 3 --------------- + n = new Complex64(2.42f, -1.35f); + expPolar = new float[]{2.771082820848197f, -0.5088510437828061f}; + actPolar = n.toPolar(); + Assertions.assertArrayEquals(expPolar, actPolar); + + // --------------- Sub-case 4 --------------- + n = new Complex64(1, 1); + expPolar = new float[]{(float) Math.sqrt(2), (float) (Math.PI/4.0)}; + actPolar = n.toPolar(); + Assertions.assertArrayEquals(expPolar, actPolar); + + // --------------- Sub-case 5 --------------- + n = new Complex64((float) (-Math.sqrt(3.0)/2.0), (float) (-1.0/2.0)); + expPolar = new float[]{0.9999999999999999f, (float) (-5.0*Math.PI/6.0)}; + actPolar = n.toPolar(); + Assertions.assertArrayEquals(expPolar, actPolar); + } + + + @Test + void fromPolarTestCase() { + // --------------- Sub-case 1 --------------- + expRect = new Complex64(0); + polar = new float[]{0, 0}; + actRect = Complex64.fromPolar(polar[0], polar[1]); + Assertions.assertEquals(expRect, actRect); + + // --------------- Sub-case 2 --------------- + expRect = new Complex64(1.0000001f, 3); + polar = new float[]{(float) Math.sqrt(10), (float) Math.atan(3)}; + actRect = Complex64.fromPolar(polar[0], polar[1]); + Assertions.assertEquals(expRect, actRect); + + // --------------- Sub-case 3 --------------- + expRect = new Complex64(2.42f, -1.3499999999999999f); + polar = new float[]{2.771082820848197f, -0.5088510437828061f}; + actRect = Complex64.fromPolar(polar[0], polar[1]); + Assertions.assertEquals(expRect, actRect); + + // --------------- Sub-case 4 --------------- + expRect = new Complex64(0.99999994f, 0.9999999403953552f); + polar = new float[]{(float) Math.sqrt(2), (float) (Math.PI/4.0)}; + actRect = Complex64.fromPolar(polar[0], polar[1]); + Assertions.assertEquals(expRect, actRect); + + // --------------- Sub-case 5 --------------- + expRect = new Complex64(-0.8660254037844387f, -0.5000000596046448f); + polar = new float[]{1, (float) (-5.0*Math.PI/6.0)}; + actRect = Complex64.fromPolar(polar[0], polar[1]); + Assertions.assertEquals(expRect, actRect); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ExponentialTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ExponentialTest.java new file mode 100644 index 000000000..072b7a38f --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ExponentialTest.java @@ -0,0 +1,319 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class Complex64ExponentialTest { + float a, b; + Complex64 aComplex, bComplex; + Complex64 expResult, actResult; + + @Test + void powerOneDoubleTestCase() { + // ------------ Sub-case 1 --------------- + a = 4; + bComplex = new Complex64(6, 9); + expResult = new Complex64(4079.525f, -367.0057067871094f); + actResult = Complex64.pow(a, bComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 --------------- + a = 7.243867f; + bComplex = new Complex64(-4.3f, 13.45f); + expResult = new Complex64( 1.4114029E-5f, 2.000083914026618E-4f); + actResult = Complex64.pow(a, bComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 --------------- + a = Float.NaN; + bComplex = new Complex64(-4.3f, 13.45f); + actResult = Complex64.pow(a, bComplex); + Assertions.assertTrue(Float.isNaN(actResult.re)); + Assertions.assertTrue(Float.isNaN(actResult.im)); + + // ------------ Sub-case 4 --------------- + a = 4.545f; + bComplex = new Complex64(2.34f); + expResult = new Complex64((float) Math.pow(a, 2.34f)); + actResult = Complex64.pow(a, bComplex); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void powTestCase() { + // ------------ Sub-case 1 --------------- + aComplex = new Complex64(5); + bComplex = new Complex64(3); + expResult = new Complex64((float) Math.pow(5, 3)); + actResult = Complex64.pow(aComplex, bComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 --------------- + aComplex = new Complex64(3.4113f); + bComplex = new Complex64(-6.133f, 1.3f); + expResult = new Complex64(-1.3164215E-5f, 5.388559657149017E-4f); + actResult = Complex64.pow(aComplex, bComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 --------------- + aComplex = new Complex64(5, 1.34f); + bComplex = new Complex64(3, 4); + expResult = new Complex64(22.98758f, 42.894046783447266f); + actResult = Complex64.pow(aComplex, bComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 4 --------------- + aComplex = new Complex64(-8.4f, 2.234f); + bComplex = new Complex64(1.65901f, -4.192436f); + expResult = new Complex64(-2644103.5f, 5805822.5f); + actResult = Complex64.pow(aComplex, bComplex); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void expTestCase() { + // ------------ Sub-case 1 --------------- + aComplex = new Complex64(5); + expResult = new Complex64((float) Math.exp(5)); + actResult = Complex64.exp(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 --------------- + aComplex = new Complex64(5, 1.34f); + expResult = new Complex64(33.949924f, 144.47792053222656f); + actResult = Complex64.exp(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 --------------- + aComplex = new Complex64(-23.23f, -13.32f); + expResult = new Complex64(5.9455524E-11f, -5.579295359048331E-11f); + actResult = Complex64.exp(aComplex); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void lnTestCase() { + // ------------ Sub-case 1 --------------- + aComplex = new Complex64(1); + expResult = new Complex64(0); + actResult = Complex64.ln(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 --------------- + aComplex = new Complex64(-1); + expResult = new Complex64((float) 0, (float) Math.PI); + actResult = Complex64.ln(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 --------------- + aComplex = new Complex64(0); + expResult = new Complex64(Float.NEGATIVE_INFINITY); + actResult = Complex64.ln(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 4 --------------- + aComplex = new Complex64(146.1417912f); + expResult = new Complex64(4.984577323028071f); + actResult = Complex64.ln(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 5 --------------- + aComplex = new Complex64(142.18623f, -92.394356f); + expResult = new Complex64(5.133259841229789f, -0.5762432330428644f); + actResult = Complex64.ln(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 6 --------------- + aComplex = new Complex64(-8.5464f, -9.72352f); + expResult = new Complex64(2.5607536790655163f, -2.2918540198902058f); + actResult = Complex64.ln(aComplex); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void lnDoubleTestCase() { + // ------------ Sub-case 1 --------------- + a = 1; + expResult = new Complex64(0); + actResult = Complex64.ln(a); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 --------------- + a = (float) Math.E; + expResult = new Complex64(0.99999994f); + actResult = Complex64.ln(a); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 --------------- + a = 0; + expResult = new Complex64(Float.NEGATIVE_INFINITY); + actResult = Complex64.ln(a); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 4 --------------- + a = 146.1417912f; + expResult = new Complex64(4.984577323028071f); + actResult = Complex64.ln(a); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 5 --------------- + a = -1; + expResult = new Complex64((float) 0, (float) Math.PI); + actResult = Complex64.ln(a); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void logTestCase() { + // ------------ Sub-case 1 --------------- + aComplex = new Complex64(1); + expResult = new Complex64(0); + actResult = Complex64.log(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 --------------- + aComplex = new Complex64(-1); + expResult = new Complex64(0, 1.364376425743103f); + actResult = Complex64.log(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 --------------- + aComplex = new Complex64(0); + expResult = new Complex64(Float.NEGATIVE_INFINITY); + actResult = Complex64.log(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 4 --------------- + aComplex = new Complex64(146.1417912f); + expResult = new Complex64(2.164774426011174f); + actResult = Complex64.log(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 5 --------------- + aComplex = new Complex64(142.18623f, -92.394356f); + expResult = new Complex64(2.2293463f, -0.25025925040245056f); + actResult = Complex64.log(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 6 --------------- + aComplex = new Complex64(-8.5464f, -9.72352f); + expResult = new Complex64(1.1121211f, -0.9953395128250122f); + actResult = Complex64.log(aComplex); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void logDoubleTestCase() { + // ------------ Sub-case 1 --------------- + a = 1; + expResult = new Complex64(0); + actResult = Complex64.log(a); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 --------------- + a = -1; + expResult = new Complex64(0, 1.364376425743103f); + actResult = Complex64.log(a); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 --------------- + a = 0; + expResult = new Complex64(Float.NEGATIVE_INFINITY); + actResult = Complex64.log(a); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 4 --------------- + a = 146.1417912f; + expResult = new Complex64(2.164774426011174f); + actResult = Complex64.log(a); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 5 --------------- + a = -984.593465f; + expResult = new Complex64( 2.9932568f, 1.364376425743103f); + actResult = Complex64.log(a); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void logDoubleBaseDoubleTestCase() { + // ------------ Sub-case 1 ------------ + a = 10; + b = 12.23423f; + expResult = new Complex64(1.0875766408496945f); + actResult = Complex64.log(a, b); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 ------------ + a = 985.343242f; + b = 34.532f; + expResult = new Complex64((float) (Math.log(b)/Math.log(a))); + actResult = Complex64.log(a, b); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 ------------ + a = -985.343242f; + b = 34.532f; + expResult = new Complex64(0.4254609128939375f, -0.1939107511743641f); + actResult = Complex64.log(a, b); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 ------------ + a = 98.4715f; + b = -0.3096712f; + expResult = new Complex64(-0.25540384666493776f, 0.6844776272773743f); + actResult = Complex64.log(a, b); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void logBaseDoubleTestCase() { + // ------------ Sub-case 1 ------------ + a = 2; + bComplex = new Complex64(14.32f, 785.234981f); + expResult = new Complex64(9.61722f,2.239873170852661f); + actResult = Complex64.log(a, bComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 ------------ + a = -42; + bComplex = new Complex64(0.23423f, -18.343f); + expResult = new Complex64(0.25081712f,-0.627661943435669f); + actResult = Complex64.log(a, bComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 ------------ + a = -23.123f; + bComplex = new Complex64(-123.34f, 895); + expResult = new Complex64(1.3551072366805246f,-0.8117131305115636f); + actResult = Complex64.log(a, bComplex); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void logBaseTestCase() { + // ------------ Sub-case 1 ------------ + aComplex = new Complex64(93.23487f, -6.32465f); + bComplex = new Complex64(-345.2f, 14.556f); + expResult = new Complex64(1.2776989f, 0.7021597623825073f); + actResult = Complex64.log(aComplex, bComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 ------------ + aComplex = new Complex64(12.1843f); + bComplex = new Complex64(0); + expResult = new Complex64(Float.NEGATIVE_INFINITY); + actResult = Complex64.log(aComplex, bComplex); + Assertions.assertEquals(expResult, actResult); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64GettersSettersTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64GettersSettersTest.java new file mode 100644 index 000000000..7b9a08320 --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64GettersSettersTest.java @@ -0,0 +1,23 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class Complex64GettersSettersTest { + Complex64 num; + long expReLong, expImLong; + int expReInt, expImInt; + float expReFloat, expImFloat; + + + @Test + void gettersTestCase() { + num = new Complex64(692.13f, -9673.134f); + expReFloat = 692.13f; + expImFloat = -9673.134f; + + Assertions.assertEquals(expReFloat, num.re()); + Assertions.assertEquals(expImFloat, num.im()); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64MinMaxSumTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64MinMaxSumTest.java new file mode 100644 index 000000000..700811c28 --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64MinMaxSumTest.java @@ -0,0 +1,202 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class Complex64MinMaxSumTest { + + Complex64 n1, n2, n3, n4, n5; + Complex64 sum, min, max; + Complex64 expSum, expMin, expMax; + int expArg, arg; + + @Test + void sumTestCase() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(93.13f, -6456.331f); + n2 = new Complex64(1.3f, 7.5f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(0, -2); + + sum = Complex64.sum(n1, n2, n3, n4); + + expSum = new Complex64(93.13f + 1.3f + -4.2e-8f + 0, + -6456.331f + 7.5f + 0 + -2); + + Assertions.assertEquals(sum, expSum); + + // ------------ Sub-case 2 ------------ + sum = Complex64.sum(); + expSum = Complex64.ZERO; + Assertions.assertEquals(sum, expSum); + } + + + @Test + void minTestCase() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(93.13f, -6456.331f); + n2 = new Complex64(1.3f, 7.5f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(0, -2); + + min = Complex64.min(n1, n2, n3, n4); + + expMin = new Complex64(-4.2e-8f); + + Assertions.assertEquals(expMin, min); + + // ------------ Sub-case 2 ------------ + min = Complex64.min(); + Assertions.assertNull(min); + } + + + @Test + void minReal() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(93.13f, -6456.331f); + n2 = new Complex64(1.3f, 7.5f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(0, -2); + n5 = new Complex64(-9347, 100); + + min = Complex64.minRe(n1, n2, n3, n4, n5); + + expMin = new Complex64(-9347); + + Assertions.assertEquals(expMin, min); + + // ------------ Sub-case 2 ------------ + min = Complex64.minRe(); + Assertions.assertTrue(Double.isNaN(min.re)); + } + + + @Test + void argminTestCase() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(93.13f, -6456.331f); + n2 = new Complex64(1.3f, 7.5f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(0, -2); + + arg = Complex64.argmin(n1, n2, n3, n4); + + expArg = 2; + + Assertions.assertEquals(expMin, min); + + // ------------ Sub-case 2 ------------ + arg = Complex64.argmin(); + expArg = -1; + Assertions.assertEquals(expMin, min); + } + + + @Test + void argminRealTestCase() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(93.13f, -6456.331f); + n2 = new Complex64(1.3f, 7.5f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(-122, -2); + n5 = new Complex64(0, -2); + + arg = Complex64.argminReal(n1, n2, n3, n4); + + expArg = 3; + + Assertions.assertEquals(expMin, min); + + // ------------ Sub-case 2 ------------ + arg = Complex64.argminReal(); + expArg = -1; + Assertions.assertEquals(expMin, min); + } + + + @Test + void maxTestCase() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(1.3f, 7.5f); + n2 = new Complex64(93.13f, -6456.331f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(0, -2); + + max = Complex64.max(n1, n2, n3, n4); + + expMax = new Complex64(93.13f, -6456.331f); + + Assertions.assertEquals(expMax, max); + + // ------------ Sub-case 2 ------------ + max = Complex64.max(); + Assertions.assertNull(max); + } + + + @Test + void maxReal() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(93.13f, -6456.331f); + n2 = new Complex64(1.3f, 7.5f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(104.43f, -2); + n5 = new Complex64(0, 100); + + max = Complex64.maxRe(n1, n2, n3, n4, n5); + + expMax = new Complex64(104.43f); + + Assertions.assertEquals(expMax, max); + + // ------------ Sub-case 2 ------------ + max = Complex64.maxRe(); + Assertions.assertTrue(Double.isNaN(max.re)); + } + + + @Test + void argmaxTestCase() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(1.3f, 7.5f); + n2 = new Complex64(93.13f, -6456.331f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(0, -2); + + arg = Complex64.argmax(n1, n2, n3, n4); + + expArg = 1; + + Assertions.assertEquals(expMin, min); + + // ------------ Sub-case 2 ------------ + arg = Complex64.argmax(); + expArg = -1; + Assertions.assertEquals(expMin, min); + } + + + @Test + void argmaxRealTestCase() { + // ------------ Sub-case 1 ------------ + n1 = new Complex64(93.13f, -6456.331f); + n2 = new Complex64(1e10f, 7.5f); + n3 = new Complex64(-4.2e-8f); + n4 = new Complex64(-122, -2); + n5 = new Complex64(0, -2); + + arg = Complex64.argmaxReal(n1, n2, n3, n4); + + expArg = 1; + + Assertions.assertEquals(expMin, min); + + // ------------ Sub-case 2 ------------ + arg = Complex64.argmaxReal(); + expArg = -1; + Assertions.assertEquals(expMin, min); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64PropertiesTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64PropertiesTest.java new file mode 100644 index 000000000..8dcbddb29 --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64PropertiesTest.java @@ -0,0 +1,546 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +class Complex64PropertiesTest { + Complex64 a; + boolean expResult; + boolean result; + + @Test + void isIntTestCase() { + // ------------- Sub-case 1 -------------- + a = new Complex64(5); + expResult = true; + result = a.isInt(); + + assertEquals(expResult, result); + + // ------------- Sub-case 2 -------------- + a = new Complex64(-4); + expResult = true; + result = a.isInt(); + + assertEquals(expResult, result); + + // ------------- Sub-case 3 -------------- + a = new Complex64(2,-1); + expResult = false; + result = a.isInt(); + + assertEquals(expResult, result); + + // ------------- Sub-case 4 -------------- + a = new Complex64(Float.POSITIVE_INFINITY); + expResult = false; + result = a.isInt(); + + assertEquals(expResult, result); + + // ------------- Sub-case 5 -------------- + a = new Complex64(23.5f); + expResult = false; + result = a.isInt(); + + assertEquals(expResult, result); + } + + @Test + void isDoubleTestCase() { + // ------------- Sub-case 1 -------------- + a = new Complex64(5); + expResult = true; + result = a.isFloat(); + + assertEquals(expResult, result); + + // ------------- Sub-case 2 -------------- + a = new Complex64(-4); + expResult = true; + result = a.isFloat(); + + assertEquals(expResult, result); + + // ------------- Sub-case 3 -------------- + a = new Complex64(2,-1); + expResult = false; + result = a.isFloat(); + + assertEquals(expResult, result); + + // ------------- Sub-case 4 -------------- + a = new Complex64(Float.POSITIVE_INFINITY); + expResult = true; + result = a.isFloat(); + + assertEquals(expResult, result); + + // ------------- Sub-case 5 -------------- + a = new Complex64(223.54268f); + expResult = true; + result = a.isFloat(); + + assertEquals(expResult, result); + } + + @Test + void isNaNTestCase() { + // ------------- Sub-case 1 -------------- + a = new Complex64(5); + expResult = false; + result = a.isNaN(); + + assertEquals(expResult, result); + + // ------------- Sub-case 2 -------------- + a = new Complex64(-4); + expResult = false; + result = a.isNaN(); + + assertEquals(expResult, result); + + // ------------- Sub-case 3 -------------- + a = new Complex64(2,-1); + expResult = false; + result = a.isNaN(); + + assertEquals(expResult, result); + + // ------------- Sub-case 4 -------------- + a = new Complex64(Float.POSITIVE_INFINITY); + expResult = false; + result = a.isNaN(); + + assertEquals(expResult, result); + + // ------------- Sub-case 5 -------------- + a = new Complex64(223.54268f); + expResult = false; + result = a.isNaN(); + + assertEquals(expResult, result); + + // ------------- Sub-case 6 -------------- + a = new Complex64(223.54268f, Float.NaN); + expResult = true; + result = a.isNaN(); + + assertEquals(expResult, result); + + // ------------- Sub-case 7 -------------- + a = new Complex64(223.54268f,12434.33f); + expResult = false; + result = a.isNaN(); + + assertEquals(expResult, result); + + // ------------- Sub-case 8 -------------- + a = new Complex64(Float.NaN,12434.33f); + expResult = true; + result = a.isNaN(); + + assertEquals(expResult, result); + + + // ------------- Sub-case 9 -------------- + a = new Complex64(Float.NaN,Float.NaN); + expResult = true; + result = a.isNaN(); + + assertEquals(expResult, result); + } + + @Test + void isFiniteTestCase() { + // ------------- Sub-case 1 -------------- + a = new Complex64(5); + expResult = true; + result = a.isFinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 2 -------------- + a = new Complex64(-4); + expResult = true; + result = a.isFinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 3 -------------- + a = new Complex64(2,-1); + expResult = true; + result = a.isFinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 4 -------------- + a = new Complex64(Float.POSITIVE_INFINITY); + expResult = false; + result = a.isFinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 5 -------------- + a = new Complex64(223.54268f); + expResult = true; + result = a.isFinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 6 -------------- + a = new Complex64(223.54268f, Float.NEGATIVE_INFINITY); + expResult = false; + result = a.isFinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 7 -------------- + a = new Complex64(223.54268f,12434.33f); + expResult = true; + result = a.isFinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 8 -------------- + a = new Complex64(Float.NaN,12434.33f); + expResult = false; + result = a.isFinite(); + + assertEquals(expResult, result); + + + // ------------- Sub-case 9 -------------- + a = new Complex64(Float.NaN,Float.NaN); + expResult = false; + result = a.isFinite(); + + assertEquals(expResult, result); + + + // ------------- Sub-case 10 -------------- + a = new Complex64(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY); + expResult = false; + result = a.isFinite(); + + assertEquals(expResult, result); + } + + + @Test + void isInfiniteTestCase() { + // ------------- Sub-case 1 -------------- + a = new Complex64(5); + expResult = false; + result = a.isInfinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 2 -------------- + a = new Complex64(-4); + expResult = false; + result = a.isInfinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 3 -------------- + a = new Complex64(2,-1); + expResult = false; + result = a.isInfinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 4 -------------- + a = new Complex64(Float.POSITIVE_INFINITY); + expResult = true; + result = a.isInfinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 5 -------------- + a = new Complex64(223.54268f); + expResult = false; + result = a.isInfinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 6 -------------- + a = new Complex64(223.54268f, Float.NEGATIVE_INFINITY); + expResult = true; + result = a.isInfinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 7 -------------- + a = new Complex64(223.54268f,12434.33f); + expResult = false; + result = a.isInfinite(); + + assertEquals(expResult, result); + + // ------------- Sub-case 8 -------------- + a = new Complex64(Float.NaN,12434.33f); + expResult = false; + result = a.isInfinite(); + + assertEquals(expResult, result); + + + // ------------- Sub-case 9 -------------- + a = new Complex64(Float.NaN,Float.NaN); + expResult = false; + result = a.isInfinite(); + + assertEquals(expResult, result); + + + // ------------- Sub-case 10 -------------- + a = new Complex64(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY); + expResult = true; + result = a.isInfinite(); + + assertEquals(expResult, result); + } + + + @Test + void isRealTestCase() { + // ------------- Sub-case 1 -------------- + a = new Complex64(5); + expResult = true; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 2 -------------- + a = new Complex64(-4); + expResult = true; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 3 -------------- + a = new Complex64(2,-1); + expResult = false; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 4 -------------- + a = new Complex64(Float.POSITIVE_INFINITY); + expResult = true; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 5 -------------- + a = new Complex64(223.54268f); + expResult = true; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 6 -------------- + a = new Complex64(223.54268f, Float.NEGATIVE_INFINITY); + expResult = false; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 7 -------------- + a = new Complex64(223.54268f,12434.33f); + expResult = false; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 8 -------------- + a = new Complex64(Float.NaN,12434.33f); + expResult = false; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 9 -------------- + a = new Complex64(Float.NaN,Float.NaN); + expResult = false; + result = a.isReal(); + + assertEquals(expResult, result); + + // ------------- Sub-case 10 -------------- + a = new Complex64(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY); + expResult = false; + result = a.isReal(); + + assertEquals(expResult, result); + } + + + @Test + void isImaginaryTestCase() { + // ------------- Sub-case 1 -------------- + a = new Complex64(5); + expResult = false; + result = a.isImaginary(); + + assertEquals(expResult, result); + + // ------------- Sub-case 2 -------------- + a = new Complex64(-4); + expResult = false; + result = a.isImaginary(); + + assertEquals(expResult, result); + + // ------------- Sub-case 3 -------------- + a = new Complex64(2,-1); + expResult = false; + result = a.isImaginary(); + + assertEquals(expResult, result); + + // ------------- Sub-case 4 -------------- + a = new Complex64(0, -342); + expResult = true; + result = a.isImaginary(); + + assertEquals(expResult, result); + + // ------------- Sub-case 5 -------------- + a = new Complex64(223.54268f); + expResult = false; + result = a.isImaginary(); + + assertEquals(expResult, result); + + // ------------- Sub-case 6 -------------- + a = new Complex64(223.54268f, Float.NEGATIVE_INFINITY); + expResult = false; + result = a.isImaginary(); + + assertEquals(expResult, result); + + // ------------- Sub-case 7 -------------- + a = new Complex64(0,12434.33f); + expResult = true; + result = a.isImaginary(); + + assertEquals(expResult, result); + + // ------------- Sub-case 8 -------------- + a = new Complex64(Float.NaN,12434.33f); + expResult = false; + result = a.isImaginary(); + + assertEquals(expResult, result); + + + // ------------- Sub-case 9 -------------- + a = new Complex64(Float.NaN,Float.NaN); + expResult = false; + result = a.isImaginary(); + + assertEquals(expResult, result); + + + // ------------- Sub-case 10 -------------- + a = new Complex64(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY); + expResult = false; + result = a.isImaginary(); + + assertEquals(expResult, result); + + // ------------- Sub-case 11 -------------- + a = new Complex64(0, Float.NEGATIVE_INFINITY); + expResult = true; + result = a.isImaginary(); + + assertEquals(expResult, result); + } + + + @Test + void isComplexTestCase() { + // ------------- Sub-case 1 -------------- + a = new Complex64(5); + expResult = false; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 2 -------------- + a = new Complex64(-4); + expResult = false; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 3 -------------- + a = new Complex64(2,-1); + expResult = true; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 4 -------------- + a = new Complex64(0, -342); + expResult = true; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 5 -------------- + a = new Complex64(223.54268f); + expResult = false; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 6 -------------- + a = new Complex64(223.54268f, Float.NEGATIVE_INFINITY); + expResult = true; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 7 -------------- + a = new Complex64(0,12434.33f); + expResult = true; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 8 -------------- + a = new Complex64(Float.NaN,12434.33f); + expResult = true; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 9 -------------- + a = new Complex64(Float.NaN,Float.NaN); + expResult = true; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 10 -------------- + a = new Complex64(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY); + expResult = true; + result = a.isComplex(); + + assertEquals(expResult, result); + + // ------------- Sub-case 11 -------------- + a = new Complex64(0, Float.NEGATIVE_INFINITY); + expResult = true; + result = a.isComplex(); + + assertEquals(expResult, result); + } + +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64RoundTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64RoundTest.java new file mode 100644 index 000000000..1a1bf1eaf --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64RoundTest.java @@ -0,0 +1,147 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class Complex64RoundTest { + + Complex64 n, expRound, actRound; + boolean expNearZero, actNearZero; + + @Test + void roundTestCase() { + // -------------- Sub-case 1 -------------- + n = new Complex64(0); + expRound = new Complex64(0); + actRound = Complex64.round(n); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 2 -------------- + n = new Complex64(13, 4); + expRound = new Complex64(13, 4); + actRound = Complex64.round(n); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 3 -------------- + n = new Complex64(-0.133f, 13.413f); + expRound = new Complex64(0, 13); + actRound = Complex64.round(n); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 4 -------------- + n = new Complex64(-0.893f, 16.5f); + expRound = new Complex64(-1, 17); + actRound = Complex64.round(n); + Assertions.assertEquals(expRound, actRound); + + + // -------------- Sub-case 5 -------------- + n = new Complex64(9.3E10f, 0.1993312f); + expRound = new Complex64(9.3E10f, 0); + actRound = Complex64.round(n); + Assertions.assertEquals(expRound, actRound); + } + + + @Test + void roundDecimalsTestCase() { + // -------------- Sub-case 1 -------------- + n = new Complex64(0); + expRound = new Complex64(0); + actRound = Complex64.round(n, 1); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 2 -------------- + n = new Complex64(13, 4); + expRound = new Complex64(13, 4); + actRound = Complex64.round(n, 2); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 3 -------------- + n = new Complex64(-0.133f, 13.41562f); + expRound = new Complex64(-0.13f, 13.42f); + actRound = Complex64.round(n, 2); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 4 -------------- + n = new Complex64(-0.89242993f, 16.99999999f); + expRound = new Complex64(-0.89243f, 17); + actRound = Complex64.round(n, 6); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 5 -------------- + n = new Complex64(9.3E10f, 0.1993312f); + expRound = new Complex64(9.3E10f, 0.1993f); + actRound = Complex64.round(n, 4); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 9 -------------- + n = new Complex64(8234.5f, 123.34f); + Assertions.assertThrows(IllegalArgumentException.class, () -> Complex64.round(n, -1)); + + // -------------- Sub-case 10 -------------- + n = new Complex64(8234.5f, 123.34f); + Assertions.assertThrows(IllegalArgumentException.class, () -> Complex64.round(n, -100)); + } + + + @Test + void nearZeroTestCase() { + // -------------- Sub-case 1 -------------- + n = new Complex64(24.25f, 0.3422f); + expNearZero = false; + actNearZero = Complex64.nearZero(n, 0.001f); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 2 -------------- + n = new Complex64(13, 4); + expNearZero = true; + actNearZero = Complex64.nearZero(n, 15); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 3 -------------- + n = new Complex64(-0.133f, 13.41562f); + expNearZero = false; + actNearZero = Complex64.nearZero(n, 1); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 4 -------------- + n = new Complex64(-0.0001231f, 0.0000001313f); + expNearZero = true; + actNearZero = Complex64.nearZero(n, 0.0005f); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 5 -------------- + n = new Complex64(9.3E10f, 0.1993312f); + expNearZero = false; + actNearZero = Complex64.nearZero(n, 13100); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 6 -------------- + n = new Complex64(Float.POSITIVE_INFINITY); + expNearZero = false; + actNearZero = Complex64.nearZero(n, 13100); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 7 -------------- + n = new Complex64(Float.NEGATIVE_INFINITY); + expNearZero = false; + actNearZero = Complex64.nearZero(n, 1000000); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 8 -------------- + n = new Complex64(Float.NaN); + expNearZero = false; + actNearZero = Complex64.nearZero(n, 13); + Assertions.assertEquals(expRound, actRound); + + // -------------- Sub-case 9 -------------- + n = new Complex64(8234.5f, 123.34f); + Assertions.assertThrows(IllegalArgumentException.class, () -> Complex64.nearZero(n, -1)); + + // -------------- Sub-case 10 -------------- + n = new Complex64(8234.5f, 123.34f); + Assertions.assertThrows(IllegalArgumentException.class, () -> Complex64.nearZero(n, -100)); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64SqrtTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64SqrtTest.java new file mode 100644 index 000000000..6d422949b --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64SqrtTest.java @@ -0,0 +1,152 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class Complex64SqrtTest { + float a; + Complex64 aComplex; + Complex64 expResult, actResult; + + @Test + void sqrtDoubleTestCase() { + // ------------- Sub-case 1 ------------- + a = 1; + expResult = new Complex64(1); + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 2 ------------- + a = 4; + expResult = new Complex64(2); + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 3 ------------- + a = 2; + expResult = Complex64.ROOT_TWO; + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + + + // ------------- Sub-case 4 ------------- + a = 3; + expResult = Complex64.ROOT_THREE; + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 5 ------------- + a = 763.3422f; + expResult = new Complex64((float) Math.sqrt(a)); + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 6 ------------- + a = 0; + expResult = new Complex64(0); + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 6 ------------- + a = -1; + expResult = new Complex64(0, 1); + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 6 ------------- + a = -56.3947f; + expResult = new Complex64(0, (float) Math.sqrt(56.3947f)); + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 7 ------------- + a = -0.0f; + expResult = new Complex64(0, -0.0f); + actResult = Complex64.sqrt(a); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void sqrtTestCase() { + // ------------- Sub-case 1 ------------- + aComplex = new Complex64(1); + expResult = new Complex64(1); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 2 ------------- + aComplex = new Complex64(4); + expResult = new Complex64(2); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 3 ------------- + aComplex = new Complex64(2); + expResult = Complex64.ROOT_TWO; + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + + // ------------- Sub-case 4 ------------- + aComplex = new Complex64(3); + expResult = Complex64.ROOT_THREE; + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 5 ------------- + aComplex = new Complex64(763.3422f); + expResult = new Complex64((float) Math.sqrt(763.3422f)); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 6 ------------- + aComplex = new Complex64(0); + expResult = new Complex64(0); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 6 ------------- + aComplex = new Complex64(-1); + expResult = new Complex64(0, 1); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 6 ------------- + aComplex = new Complex64(-56.3947f); + expResult = new Complex64(0, (float) Math.sqrt(56.3947f)); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 7 ------------- + aComplex = new Complex64(-0.0f); + expResult = new Complex64(0, -0.0f); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 7 ------------- + aComplex = new Complex64(14.3f, 7683.453f); + expResult = new Complex64(62.0393677722319f, 61.92401112313579f); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 8 ------------- + aComplex = new Complex64(-84.3453f, 32.337847f); + expResult = new Complex64(1.7301267f, 9.345514338779571f); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 9 ------------- + aComplex = new Complex64(0.34534f, -9753246.45756f); + expResult = new Complex64(2208.307814017331f, -2208.307735826237f); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + + // ------------- Sub-case 10 ------------- + aComplex = new Complex64(-74.2346f, -634.2146f); + expResult = new Complex64(16.797466883551945f, -18.87828100500743f); + actResult = aComplex.sqrt(); + Assertions.assertEquals(expResult, actResult); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ToStringTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ToStringTest.java new file mode 100644 index 000000000..9b14f7983 --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64ToStringTest.java @@ -0,0 +1,105 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class Complex64ToStringTest { + Complex64 a; + String expStr; + + @Test + void realToStringTestCase() { + // ---------- Sub-case 1 ------------ + a = new Complex64(1); + expStr = "1"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 2 ------------ + a = new Complex64(93.234f); + expStr = "93.234"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 3 ------------ + a = new Complex64(-1.23e-5f); + expStr = "-1.23E-5"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 4 ------------ + a = Complex64.ZERO; + expStr = "0"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + } + + + @Test + void imaginaryToStringTestCase() { + // ---------- Sub-case 1 ------------ + a = new Complex64(0, 1); + expStr = "i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 2 ------------ + a = new Complex64(0, 93.234f); + expStr = "93.23400115966797i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 3 ------------ + a = new Complex64(0, -1.23e-5f); + expStr = "-1.2299999980314169E-5i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 4 ------------ + a = new Complex64(0, -1); + expStr = "-i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 5 ------------ + a = new Complex64(0, 24); + expStr = "24i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 6 ------------ + a = new Complex64(0, -56); + expStr = "-56i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + } + + + @Test + void complexToStringTestCase() { + // ---------- Sub-case 1 ------------ + a = new Complex64(234.3f, 1); + expStr = "234.3 + i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 2 ------------ + a = new Complex64(1.341f, 93.234f); + expStr = "1.341 + 93.23400115966797i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 3 ------------ + a = new Complex64(-9.324f, -1.23e-5f); + expStr = "-9.324 - 1.2299999980314169E-5i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + + // ---------- Sub-case 4 ------------ + a = new Complex64(994.242f, -1); + expStr = "994.242 - i"; + Assertions.assertEquals(expStr, a.toString()); + Assertions.assertEquals(expStr.length(), Complex64.length(a)); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64TrigTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64TrigTest.java new file mode 100644 index 000000000..b2742385c --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64TrigTest.java @@ -0,0 +1,173 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +// Note, all tests account for floating point precision errors. +class Complex64TrigTest { + double a; + Complex64 aComplex; + Complex64 expResult, actResult; + + @Test + void sinTestCase() { + // ------------ Sub-case 1 -------------- + aComplex = new Complex64((float) Math.PI); + expResult = new Complex64(-8.742278E-8f); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 -------------- + aComplex = new Complex64((float) Math.PI/2); + expResult = new Complex64(1); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 -------------- + aComplex = new Complex64((float) (3*Math.PI/2)); + expResult = new Complex64(-1); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 4 -------------- + aComplex = new Complex64(0); + expResult = new Complex64(0); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 5 -------------- + aComplex = new Complex64((float) (-600*Math.PI)); + expResult = new Complex64(2.5747626E-5f); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 6 -------------- + aComplex = new Complex64((float) Math.PI/4); + expResult = new Complex64(0.7071067811865475f); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 7 -------------- + aComplex = new Complex64(63425.234432673f); + expResult = new Complex64(0.37064958f); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 7 -------------- + aComplex = new Complex64(355.34f, (float) Math.PI); + expResult = new Complex64( -3.8660564f, -10.887526512145996f); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 7 -------------- + aComplex = new Complex64(-2.3f, 8.099867543f); + expResult = new Complex64( -1228.1888f, -1097.3673095703125f); + actResult = Complex64.sin(aComplex); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void cosTestCase() { + // ------------ Sub-case 1 -------------- + aComplex = new Complex64((float) Math.PI); + expResult = new Complex64(-1); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 -------------- + aComplex = new Complex64((float) Math.PI/2); + expResult = new Complex64(-4.371139E-8f); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 -------------- + aComplex = new Complex64((float) (3*Math.PI/2)); + expResult = new Complex64(1.1924881E-8f); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 4 -------------- + aComplex = new Complex64(0); + expResult = new Complex64(1); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 5 -------------- + aComplex = new Complex64((float) (-600*Math.PI)); + expResult = new Complex64(1); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 6 -------------- + aComplex = new Complex64((float) Math.PI/4); + expResult = new Complex64((float) Math.sqrt(2)/2); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 7 -------------- + aComplex = new Complex64(63425.234432673f); + expResult = new Complex64(-0.9287728f); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 7 -------------- + aComplex = new Complex64(355.34f, (float) Math.PI); + expResult = new Complex64( -10.928267f, -3.851644277572632f); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 7 -------------- + aComplex = new Complex64(-2.3f, 8.099867543f); + expResult = new Complex64( -1097.3676f, -1228.1885986328125f); + actResult = Complex64.cos(aComplex); + Assertions.assertEquals(expResult, actResult); + } + + + @Test + void tanTestCase() { + // ------------ Sub-case 1 -------------- + aComplex = new Complex64((float) Math.PI); + expResult = new Complex64(8.742278E-8f); + actResult = Complex64.tan(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 2 -------------- + aComplex = new Complex64(0); + expResult = new Complex64(0); + actResult = Complex64.tan(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 3 -------------- + aComplex = new Complex64((float) (-600*Math.PI)); + expResult = new Complex64(2.5747626E-5f); + actResult = Complex64.tan(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 4 -------------- + aComplex = new Complex64((float) Math.PI/4); + expResult = new Complex64(0.9999999999999999f); + actResult = Complex64.tan(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 5 -------------- + aComplex = new Complex64(63425.234432673f); + expResult = new Complex64(-0.39907455f); + actResult = Complex64.tan(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 6 -------------- + aComplex = new Complex64(355.34f, (float) Math.PI); + expResult = new Complex64( 0.002341809f, 0.9970974326133728f); + actResult = Complex64.tan(aComplex); + Assertions.assertEquals(expResult, actResult); + + // ------------ Sub-case 7 -------------- + aComplex = new Complex64(-2.3f, 8.099867543f); + expResult = new Complex64( 1.587595E-7f, 1.0000000206720314f); + actResult = Complex64.tan(aComplex); + Assertions.assertEquals(expResult, actResult); + } +} diff --git a/src/test/java/org/flag4j/algebraic_structures/fields/Complex64UnaryOpsTest.java b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64UnaryOpsTest.java new file mode 100644 index 000000000..9e09b25e0 --- /dev/null +++ b/src/test/java/org/flag4j/algebraic_structures/fields/Complex64UnaryOpsTest.java @@ -0,0 +1,328 @@ +package org.flag4j.algebraic_structures.fields; + +import org.flag4j.algebraic_structures.Complex64; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class Complex64UnaryOpsTest { + Complex64 a; + Complex64 expValue, value; + float expValueFloat, valueFloat; + double expValueDouble, valueDouble; + + + @Test + void magTestCase() { + // ----------- Sub-case 1 -------------- + a = new Complex64(0); + expValueDouble = 0; + + valueDouble = a.mag(); + + Assertions.assertEquals(expValueDouble, valueDouble); + + // ----------- Sub-case 2 -------------- + a = new Complex64(2.4f); + expValueDouble = 2.4f; + + valueDouble= a.mag(); + + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 3 -------------- + a = new Complex64(-10.394f); + expValueDouble = 10.394f; + + valueDouble = a.mag(); + + Assertions.assertEquals(expValueDouble, valueDouble); + + // ----------- Sub-case 4 -------------- + a = new Complex64(2, 8); + expValueDouble = Math.sqrt(4+64); + + valueDouble = a.mag(); + + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(-8.42f, 1.94f); + expValueDouble = Math.sqrt(Math.pow(-8.42f, 2) + Math.pow(1.94f, 2)); + + valueDouble = a.mag(); + + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); + expValueDouble = Float.POSITIVE_INFINITY; + + valueDouble = a.mag(); + + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(Float.NaN, 2.3f); + valueDouble = a.mag(); + Assertions.assertTrue(Double.isNaN(valueDouble)); + } + + + @Test + void magDoubleTestCase() { + // ----------- Sub-case 1 -------------- + a = new Complex64(0); + expValueDouble = 0; + valueDouble = a.mag(); + Assertions.assertEquals(expValueDouble, valueDouble); + + // ----------- Sub-case 2 -------------- + a = new Complex64(2.4f); + expValueDouble = 2.4f; + valueDouble = a.mag(); + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 3 -------------- + a = new Complex64(-10.394f); + expValueDouble = 10.394f; + valueDouble = a.mag(); + Assertions.assertEquals(expValueDouble, valueDouble); + + // ----------- Sub-case 4 -------------- + a = new Complex64(2, 8); + expValueDouble = Math.sqrt(4+64); + valueDouble = a.mag(); + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(-8.42f, 1.94f); + expValueDouble = Math.sqrt(Math.pow(-8.42f, 2) + Math.pow(1.94f, 2)); + valueDouble = a.mag(); + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); + expValueDouble = Float.POSITIVE_INFINITY; + valueDouble = a.mag(); + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(Float.NaN, 2.3f); + valueDouble = a.mag(); + Assertions.assertTrue(Double.isNaN(valueDouble)); + } + + + @Test + void absTestCase() { + // ----------- Sub-case 1 -------------- + a = new Complex64(0); + expValueDouble = 0; + valueDouble = a.abs(); + Assertions.assertEquals(expValueDouble, valueDouble); + + // ----------- Sub-case 2 -------------- + a = new Complex64(2.4f); + expValueDouble = 2.4f; + valueDouble = a.abs(); + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 3 -------------- + a = new Complex64(-10.394f); + expValueDouble = 10.394f; + valueDouble = a.abs(); + Assertions.assertEquals(expValueDouble, valueDouble); + + // ----------- Sub-case 4 -------------- + a = new Complex64(2, 8); + expValueDouble = Math.sqrt(4+64); + valueDouble = a.abs(); + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(-8.42f, 1.94f); + expValueDouble = Math.sqrt(Math.pow(-8.42f, 2) + Math.pow(1.94f, 2)); + valueDouble = a.abs(); + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); + expValueDouble = Float.POSITIVE_INFINITY; + valueDouble = a.abs(); + Assertions.assertEquals(expValueDouble, valueDouble); + + + // ----------- Sub-case 5 -------------- + a = new Complex64(Float.NaN, 2.3f); + valueDouble = a.abs(); + Assertions.assertTrue(Double.isNaN(valueDouble)); + } + + + @Test + void addInvTestCase() { + // ---------- Sub-case 1 ------------ + a = new Complex64(4); + expValue = new Complex64(-4); + value = a.addInv(); + Assertions.assertEquals(expValue, value); + + // ---------- Sub-case 2 ------------ + a = new Complex64(-2.445f); + expValue = new Complex64(2.445f); + value = a.addInv(); + Assertions.assertEquals(expValue, value); + + + // ---------- Sub-case 3 ------------ + a = new Complex64(13.4f, -123); + expValue = new Complex64(-13.4f, 123); + value = a.addInv(); + Assertions.assertEquals(expValue, value); + + + // ---------- Sub-case 4 ------------ + a = new Complex64(Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); + expValue = new Complex64(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY); + value = a.addInv(); + Assertions.assertEquals(expValue, value); + + // ---------- Sub-case 5 ------------ + a = Complex64.NaN; + value = a.addInv(); + Assertions.assertTrue(Float.isNaN(value.re)); + Assertions.assertTrue(Float.isNaN(value.im)); + } + + + @Test + void multInvTestCase() { + // ---------- Sub-case 1 ------------ + a = new Complex64(4); + expValue = new Complex64(1.0f/4.0f); + value = a.multInv(); + Assertions.assertEquals(expValue, value); + + // ---------- Sub-case 2 ------------ + a = new Complex64(-2.445f); + expValue = new Complex64(-0.40899795501022496f); + value = a.multInv(); + Assertions.assertEquals(expValue, value); + + + // ---------- Sub-case 3 ------------ + a = new Complex64(13.4f, -123); + expValue = new Complex64(8.753272678814991E-4f, 0.00803472101688385f); + value = a.multInv(); + Assertions.assertEquals(expValue, value); + + // ---------- Sub-case 4 ------------ + a = new Complex64(Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY); + value = a.multInv(); + Assertions.assertTrue(Float.isNaN(value.re)); + Assertions.assertTrue(Float.isNaN(value.im)); + + // ---------- Sub-case 5 ------------ + a = Complex64.NaN; + value = a.multInv(); + Assertions.assertTrue(Float.isNaN(value.re)); + Assertions.assertTrue(Float.isNaN(value.im)); + } + + + @Test + void conjTestCase() { + // --------- Sub-case 1 ----------- + a = new Complex64(0, 0); + expValue = new Complex64(0, 0); + value = a.conj(); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 2 ----------- + a = new Complex64(14.234f, 0); + expValue = new Complex64(14.234f, 0); + value = a.conj(); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 3 ----------- + a = new Complex64(1.451f, -9.3f); + expValue = new Complex64(1.451f, 9.3f); + value = a.conj(); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 4 ----------- + a = new Complex64(24, Float.POSITIVE_INFINITY); + expValue = new Complex64(24, Float.NEGATIVE_INFINITY); + value = a.conj(); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 5 ----------- + a = new Complex64(123.3f, Float.NEGATIVE_INFINITY); + expValue = new Complex64(123.3f, Float.POSITIVE_INFINITY); + value = a.conj(); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 6 ----------- + a = Complex64.NaN; + value = a.conj(); + Assertions.assertTrue(Float.isNaN(value.re)); + Assertions.assertTrue(Float.isNaN(value.im)); + } + + + @Test + void sgnTestCase() { + // --------- Sub-case 1 ----------- + a = new Complex64(0); + expValue = new Complex64(0); + value = Complex64.sgn(a); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 2 ----------- + a = new Complex64(1.23f); + expValue = new Complex64(1); + value = Complex64.sgn(a); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 3 ----------- + a = new Complex64(-32974.234f); + expValue = new Complex64(-1); + value = Complex64.sgn(a); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 4 ----------- + a = new Complex64(1.4f, 13.4f); + expValue = a.div((float) a.mag()); + value = Complex64.sgn(a); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 5 ----------- + a = new Complex64(-13.13f, 4141.2f); + expValue = a.div((float) a.mag()); + value = Complex64.sgn(a); + Assertions.assertEquals(expValue, value); + + // --------- Sub-case 6 ----------- + a = new Complex64(Float.POSITIVE_INFINITY, 4141.2f); + value = Complex64.sgn(a); + Assertions.assertEquals(0, value.im); + Assertions.assertTrue(Float.isNaN(value.re)); + + // --------- Sub-case 7 ----------- + a = Complex64.NaN; + value = Complex64.sgn(a); + Assertions.assertTrue(Float.isNaN(value.im)); + Assertions.assertTrue(Float.isNaN(value.re)); + } +} diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixAddTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixAddTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixAddTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixAddTests.java index 747f5df90..f9dffb0fd 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixAddTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixAddTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixConstructorTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixConstructorTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixConstructorTests.java index a0623dfbc..6045b3a4b 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixConversionTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixConversionTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_matrix/CMatrixConversionTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixConversionTests.java index 62d9b72d5..514718086 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixConversionTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixConversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; diff --git a/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixCsrMatMultTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixCsrMatMultTests.java new file mode 100644 index 000000000..0919d808b --- /dev/null +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixCsrMatMultTests.java @@ -0,0 +1,269 @@ +package org.flag4j.arrays.dense.complex_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.dense.CMatrix; +import org.flag4j.arrays.dense.CVector; +import org.flag4j.arrays.sparse.CsrCMatrix; +import org.flag4j.arrays.sparse.CsrMatrix; +import org.flag4j.util.exceptions.LinearAlgebraException; +import org.junit.jupiter.api.Test; + +import static org.flag4j.algebraic_structures.Complex128.ZERO; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CMatrixCsrMatMultTests { + CMatrix A; + Complex128[][] aEntries; + + CsrMatrix Breal; + CsrCMatrix B; + Shape bShape; + double[] bRealEntries; + Complex128[] bEntries; + int[] bRowPointers; + int[] bColIndices; + + CMatrix exp; + Complex128[][] expEntries; + + + @Test + void standardRealTests() { + // ------------------------ Sub-case 1 ------------------------ + aEntries = new Complex128[][]{ + {new Complex128("0.1013+0.5667i"), new Complex128("0.56204+0.08795i"), new Complex128("0.26919+0.74589i"), new Complex128("0.40605+0.37181i"), new Complex128("0.41665+0.56373i")}, + {new Complex128("0.88631+0.70806i"), new Complex128("0.94873+0.70971i"), new Complex128("0.73508+0.92932i"), new Complex128("0.32551+0.08181i"), new Complex128("0.80165+0.87963i")}, + {new Complex128("0.01923+0.20639i"), new Complex128("0.01025+0.53356i"), new Complex128("0.77862+0.04428i"), new Complex128("0.24381+0.01189i"), new Complex128("0.5903+0.51795i")}, + {new Complex128("0.65994+0.40064i"), new Complex128("0.21257+0.16288i"), new Complex128("0.85927+0.11806i"), new Complex128("0.8716+0.70231i"), new Complex128("0.83819+0.21429i")}, + {new Complex128("0.29866+0.71364i"), new Complex128("0.46553+0.89626i"), new Complex128("0.25626+0.07154i"), new Complex128("0.27779+0.59077i"), new Complex128("0.3676+0.45885i")}}; + A = new CMatrix(aEntries); + bShape = new Shape(5, 5); + bRealEntries = new double[]{0.35336, 0.80623, 0.7923, 0.96503, 0.33155, 0.26233, 0.93305, 0.97519}; + bRowPointers = new int[]{0, 2, 3, 7, 8, 8}; + bColIndices = new int[]{1, 4, 4, 0, 1, 2, 3, 4}; + Breal = new CsrMatrix(bShape, bRealEntries, bRowPointers, bColIndices); + expEntries = new Complex128[][]{ + {new Complex128("0.2597764257+0.7198062267i"), new Complex128("0.1250453125+0.4475489415i"), new Complex128("0.0706166127+0.1956693237i"), new Complex128("0.2511677295+0.6959526645i"), new Complex128("0.9229512905+0.8891587199i")}, + {new Complex128("0.7093742524+0.8968216796i"), new Complex128("0.5569022755999999+0.5583161276i"), new Complex128("0.1928335364+0.2437885156i"), new Complex128("0.685866394+0.8671020260000001i"), new Complex128("1.7836825872+1.2129427407000002i")}, + {new Complex128("0.7513916586+0.0427315284i"), new Complex128("0.26494657380000003+0.0876110044i"), new Complex128("0.2042553846+0.0116159724i"), new Complex128("0.726491391+0.041315454i"), new Complex128("0.2613859518+0.6007324068000001i")}, + {new Complex128("0.8292213281+0.1139314418i"), new Complex128("0.5180873669+0.18071294340000002i"), new Complex128("0.2254122991+0.0309706798i"), new Complex128("0.8017418735+0.11015588300000001i"), new Complex128("1.5504582412+1.1369435001000001i")}, + {new Complex128("0.2472985878+0.06903824620000001i"), new Complex128("0.19049750059999998+0.27589091740000005i"), new Complex128("0.0672246858+0.018767088200000004i"), new Complex128("0.239103393+0.066750397i"), new Complex128("0.8805261008999998+1.8615777715i")}}; + exp = new CMatrix(expEntries); + assertEquals(exp, A.mult(Breal)); + + // ------------------------ Sub-case 2 ------------------------ + aEntries = new Complex128[][]{ + {new Complex128("0.66805+0.1356i"), new Complex128("0.02304+0.45698i"), new Complex128("0.21488+0.42425i")}, + {new Complex128("0.41083+0.05417i"), new Complex128("0.73696+0.55199i"), new Complex128("0.14116+0.82652i")}, + {new Complex128("0.37502+0.66468i"), new Complex128("0.0909+0.08818i"), new Complex128("0.67084+0.0016i")}, + {new Complex128("0.57517+0.93909i"), new Complex128("0.42799+0.32726i"), new Complex128("0.92712+0.44066i")}, + {new Complex128("0.9064+0.65531i"), new Complex128("0.12811+0.44938i"), new Complex128("0.3343+0.45097i")}, + {new Complex128("0.49565+0.04573i"), new Complex128("0.84561+0.56477i"), new Complex128("0.5758+0.99876i")}, + {new Complex128("0.82457+0.04468i"), new Complex128("0.73855+0.10166i"), new Complex128("0.23595+0.92517i")}, + {new Complex128("0.93964+0.6983i"), new Complex128("0.80439+0.29492i"), new Complex128("0.15538+0.7558i")}, + {new Complex128("0.17902+0.59831i"), new Complex128("0.78638+0.43736i"), new Complex128("0.33244+0.23329i")}, + {new Complex128("0.6007+0.16697i"), new Complex128("0.54194+0.05845i"), new Complex128("0.64488+0.21455i")}, + {new Complex128("0.26206+0.22261i"), new Complex128("0.4165+0.64287i"), new Complex128("0.27519+0.17844i")}}; + A = new CMatrix(aEntries); + + bShape = new Shape(3, 10); + bRealEntries = new double[]{0.58579, 0.44167, 0.79592}; + bRowPointers = new int[]{0, 1, 3, 3}; + bColIndices = new int[]{1, 4, 5}; + Breal = new CsrMatrix(bShape, bRealEntries, bRowPointers, bColIndices); + + expEntries = new Complex128[][]{ + {new Complex128("0.0"), new Complex128("0.39133700950000005+0.07943312400000001i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.010176076800000001+0.20183435660000001i"), new Complex128("0.0183379968+0.3637195216i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.2406601057+0.031732244300000004i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.32549312319999996+0.24379742329999998i"), new Complex128("0.5865612031999999+0.43933988079999997i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.21968296580000002+0.38936289720000006i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.040147802999999996+0.038946460599999996i"), new Complex128("0.072349128+0.0701842256i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.3369288343+0.5501095311i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.1890303433+0.1445409242i"), new Complex128("0.34064580079999995+0.2604727792i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.5309600560000001+0.3838740449i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0565823437+0.1984776646i"), new Complex128("0.10196531119999999+0.3576705296i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.2903468135+0.026788176700000003i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.3734805687+0.2494419659i"), new Complex128("0.6730379111999999+0.44951173839999997i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.4830248603+0.0261730972i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.3261953785+0.0449001722i"), new Complex128("0.587826716+0.08091322719999999i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.5504317156+0.40905715700000006i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.35527493130000004+0.13025731640000002i"), new Complex128("0.6402300888+0.2347327264i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.10486812580000002+0.35048401490000003i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.3473204546+0.19316879120000002i"), new Complex128("0.6258955695999999+0.3481035712i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.35188405300000003+0.09780935630000001i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.2393586398+0.025815611500000002i"), new Complex128("0.43134088479999994+0.046521524i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, + {new Complex128("0.0"), new Complex128("0.15351212740000003+0.1304027119i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.183955555+0.2839363929i"), new Complex128("0.33150068+0.5116730904i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}}; + exp = new CMatrix(expEntries); + + assertEquals(exp, A.mult(Breal)); + + // ------------------------ Sub-case 3 ------------------------ + A = new CMatrix(24, 516); + Breal = new CsrMatrix(15, 12); + assertThrows(LinearAlgebraException.class, ()->A.mult(Breal)); + } + + + @Test + void standardTests() { + // ------------------------ Sub-case 1 ------------------------ + aEntries = new Complex128[][]{ + {new Complex128(0.2259, 0.42103), new Complex128(0.28846, 0.02369), new Complex128(0.01058, 0.56772), new Complex128(0.95379, 0.05841), new Complex128(0.74607, 0.44243)}, + {new Complex128(0.54801, 0.91401), new Complex128(0.14567, 0.17494), new Complex128(0.57924, 0.7371), new Complex128(0.12672, 0.1154), new Complex128(0.6019, 0.78311)}, + {new Complex128(0.12358, 0.62565), new Complex128(0.18628, 0.85896), new Complex128(0.50953, 0.95163), new Complex128(0.57341, 0.24745), new Complex128(0.95169, 0.32643)}}; + A = new CMatrix(aEntries); + bShape = new Shape(5, 5); + bEntries = new Complex128[]{new Complex128(0.46134, 0.36271), new Complex128(0.84641, 0.7844), new Complex128(0.37528, 0.27378), + new Complex128(0.71721, 0.72083), new Complex128(0.83229, 0.48221), new Complex128(0.58432, 0.8158), new Complex128(0.5328, 0.79834), new Complex128(0.41335, 0.35008)}; + bRowPointers = new int[]{0, 3, 4, 4, 7, 8}; + bColIndices = new int[]{0, 1, 2, 4, 0, 3, 4, 1}; + B = new CsrCMatrix(bShape, bEntries, bRowPointers, bColIndices); + expEntries = new Complex128[][]{ + {new Complex128(0.7171689077, 0.7847153040000001), new Complex128(0.01445022709999999, 0.9776225884), new Complex128(-0.030493841400000013, 0.2198510404), new Complex128(0.5096676948, 0.8122320131999999), new Complex128(0.6513582065000001, 1.0174908833)}, + {new Complex128(-0.02888087890000001, 0.7775899977), new Complex128(-0.2784640836999999, 1.7378979185999999), new Complex128(-0.044580465000000014, 0.4930438506), new Complex128(-0.020098289600000013, 0.17080870399999998), new Complex128(-0.04623803950000001, 0.3931227883)}, + {new Complex128(0.18800343009999992, 0.8159152694), new Complex128(-0.10705606509999999, 1.0945900442), new Complex128(-0.12491335460000003, 0.2686276644), new Complex128(0.13318522119999995, 0.6123778619999999), new Complex128(-0.377598643, 1.3399484134000001)}}; + exp = new CMatrix(expEntries); + assertEquals(exp, A.mult(B)); + + // ------------------------ Sub-case 2 ------------------------ + aEntries = new Complex128[][]{ + {new Complex128(0.296, 0.48224), new Complex128(0.21107, 0.36182), new Complex128(0.54805, 0.07346)}, + {new Complex128(0.16943, 0.1518), new Complex128(0.45495, 0.5998), new Complex128(0.31595, 0.12229)}, + {new Complex128(0.75293, 0.12555), new Complex128(0.71752, 0.99826), new Complex128(0.20345, 0.82965)}, + {new Complex128(0.90122, 0.14502), new Complex128(0.57932, 0.31246), new Complex128(0.47947, 0.64057)}, + {new Complex128(0.93961, 0.56065), new Complex128(0.39266, 0.27952), new Complex128(0.79273, 0.12665)}, + {new Complex128(0.5556, 0.21159), new Complex128(0.06416, 0.39728), new Complex128(0.12435, 0.77808)}, + {new Complex128(0.06261, 0.11049), new Complex128(0.87951, 0.69049), new Complex128(0.85275, 0.32878)}, + {new Complex128(0.45795, 0.11123), new Complex128(0.82144, 0.3513), new Complex128(0.20372, 0.88586)}, + {new Complex128(0.22205, 0.53269), new Complex128(0.06012, 0.99425), new Complex128(0.03268, 0.60358)}, + {new Complex128(0.7371, 0.98015), new Complex128(0.55131, 0.96068), new Complex128(0.86053, 0.08451)}, + {new Complex128(0.39417, 0.60287), new Complex128(0.75213, 0.70276), new Complex128(0.47889, 0.34281)}}; + A = new CMatrix(aEntries); + + bShape = new Shape(3, 10); + bEntries = new Complex128[]{new Complex128(0.40213, 0.62104), new Complex128(0.92285, 0.9137), new Complex128(0.45774, 0.49795)}; + bRowPointers = new int[]{0, 1, 1, 3}; + bColIndices = new int[]{4, 1, 9}; + B = new CsrCMatrix(bShape, bEntries, bRowPointers, bColIndices); + + expEntries = new Complex128[][]{ + {ZERO, new Complex128(0.4386475405, 0.568545846), ZERO, ZERO, new Complex128(-0.1804598496, 0.3777510112), ZERO, ZERO, ZERO, ZERO, new Complex128(0.214285, 0.30652707790000006)}, + {ZERO, new Complex128(0.1798380845, 0.40153884149999997), ZERO, ZERO, new Complex128(-0.0261409861, 0.1662661412), ZERO, ZERO, ZERO, ZERO, new Complex128(0.0837286475, 0.21330432710000002)}, + {ZERO, new Complex128(-0.5702973725, 0.9515347675), ZERO, ZERO, new Complex128(0.2248041689, 0.5180870687), ZERO, ZERO, ZERO, ZERO, new Complex128(-0.3199970145, 0.4810719185)}, + {ZERO, new Complex128(-0.14280991949999994, 1.0292417635), ZERO, ZERO, new Complex128(0.2723443778, 0.6180105614), ZERO, ZERO, ZERO, ZERO, new Complex128(-0.09949923369999997, 0.5319665982999999)}, + {ZERO, new Complex128(0.6158507755, 0.8411963535), ZERO, ZERO, new Complex128(0.029659293299999945, 0.8089895789), ZERO, ZERO, ZERO, ZERO, new Complex128(0.2997988627, 0.4527126745000001)}, + {ZERO, new Complex128(-0.5961752985, 0.831669723), ZERO, ZERO, new Complex128(0.09201757439999997, 0.43013651070000003), ZERO, ZERO, ZERO, ZERO, new Complex128(-0.330524967, 0.4180784217)}, + {ZERO, new Complex128(0.4865540514999999, 1.082572298), ZERO, ZERO, new Complex128(-0.04344135030000001, 0.0833146581), ZERO, ZERO, ZERO, ZERO, new Complex128(0.226621784, 0.5751226197)}, + {ZERO, new Complex128(-0.62140728, 1.003654865), ZERO, ZERO, new Complex128(0.1150771543, 0.32913418790000004), ZERO, ZERO, ZERO, ZERO, new Complex128(-0.3478631942, 0.5069359304)}, + {ZERO, new Complex128(-0.521332308, 0.586873519), ZERO, ZERO, new Complex128(-0.2415288311, 0.3521125617), ZERO, ZERO, ZERO, ZERO, new Complex128(-0.2855937178, 0.29255571519999996)}, + {ZERO, new Complex128(0.7169233235, 0.8642563145), ZERO, ZERO, new Complex128(-0.3123023330000001, 0.8519163035), ZERO, ZERO, ZERO, ZERO, new Complex128(0.3518172477, 0.4671845209)}, + {ZERO, new Complex128(0.1287181395, 0.7539240014999999), ZERO, ZERO, new Complex128(-0.21589880270000003, 0.48722744990000005), ZERO, ZERO, ZERO, ZERO, new Complex128(0.048504869099999987, 0.3953811249)}}; + exp = new CMatrix(expEntries); + + assertEquals(exp, A.mult(B)); + + // ------------------------ Sub-case 3 ------------------------ + A = new CMatrix(24, 516); + B = new CsrCMatrix(15, 12); + assertThrows(LinearAlgebraException.class, ()->A.mult(B)); + } + + + @Test + void standardVectorTests() { + Shape aShape; + Complex128[] aData; + Complex128[] bData; + Complex128[] expData; + int[] aRowPointers; + int[] aColIndices; + CsrCMatrix A; + CVector B; + CVector exp; + + + // ------------------------ Sub-case 1 ------------------------ + aShape = new Shape(11, 3); + aData = new Complex128[]{new Complex128(0.66236, 0.99511), new Complex128(0.64644, 0.38726), new Complex128(0.74006, 0.0656)}; + aRowPointers = new int[]{0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3}; + aColIndices = new int[]{0, 2, 0}; + A = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + Shape bShape = new Shape(3); + bData = new Complex128[]{new Complex128(0.3658, 0.98422), new Complex128(0.41882, 0.92433), new Complex128(0.0877, 0.37298)}; + B = new CVector(bShape, bData); + + Shape expShape = new Shape(11); + expData = new Complex128[]{ZERO, ZERO, ZERO, + new Complex128(-0.7371158762000001, 1.0159191972), ZERO, + ZERO, ZERO, new Complex128(-0.08774744679999999, 0.2750718932), + ZERO, new Complex128(0.20614911600000002, 0.7523783332), ZERO}; + exp = new CVector(expShape, expData); + + assertEquals(exp, A.mult(B)); + + // ------------------------ Sub-case 2 ------------------------ + aShape = new Shape(25, 12); + aData = new Complex128[]{new Complex128(0.93957, 0.74939), new Complex128(0.85329, 0.4632), new Complex128(0.74275, 0.36821)}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}; + aColIndices = new int[]{3, 11, 3}; + A = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + bShape = new Shape(12); + bData = new Complex128[]{new Complex128(0.71476, 0.21926), new Complex128(0.70663, 0.32465), new Complex128(0.86338, 0.91874), new Complex128(0.858, 0.18053), new Complex128(0.91248, 0.40531), new Complex128(0.19292, 0.76875), new Complex128(0.09796, 0.25519), new Complex128(0.33066, 0.93037), new Complex128(0.19831, 0.98659), new Complex128(0.63703, 0.12761), new Complex128(0.43991, 0.98135), new Complex128(0.54505, 0.15659)}; + B = new CVector(bShape, bData); + + expShape = new Shape(25); + expData = new Complex128[]{ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new Complex128(0.6708636833, 0.8125971921), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, new Complex128(0.3925532265, 0.38608384110000005), new Complex128(0.5708065487, 0.4500128375), ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, ZERO}; + exp = new CVector(expShape, expData); + + assertEquals(exp, A.mult(B)); + + // ------------------------ Sub-case 3 ------------------------ + aShape = new Shape(8, 12); + aData = new Complex128[]{new Complex128(0.49723, 0.92929), new Complex128(0.97348, 0.99936), new Complex128(0.446, 0.17502), new Complex128(0.70681, 0.54578), new Complex128(0.87631, 0.11464), new Complex128(0.8615, 0.34637), new Complex128(0.31813, 0.74686), new Complex128(0.58007, 0.62119)}; + aRowPointers = new int[]{0, 1, 1, 3, 3, 5, 6, 7, 8}; + aColIndices = new int[]{2, 3, 8, 4, 9, 3, 8, 4}; + A = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + bShape = new Shape(12); + bData = new Complex128[]{new Complex128(0.18108, 0.66938), new Complex128(0.50518, 0.39486), new Complex128(0.49211, 0.5359), new Complex128(0.82519, 0.59277), new Complex128(0.35919, 0.5141), new Complex128(0.06634, 0.15878), new Complex128(0.99391, 0.72529), new Complex128(0.25949, 0.64906), new Complex128(0.8205, 0.25117), new Complex128(0.94748, 0.6847), new Complex128(0.81247, 0.26422), new Complex128(0.30642, 0.20797)}; + B = new CVector(bShape, bData); + + expShape = new Shape(8); + expData = new Complex128[]{new Complex128(-0.2533146557, 0.7237784589), ZERO, new Complex128(0.5328985606, 1.6573373480000002), ZERO, new Complex128(0.7250857767000001, 1.2680383034), new Complex128(0.5055834401, 0.7964924153), new Complex128(0.07343683880000001, 0.6927033421000001), new Complex128(-0.11099843570000001, 0.5213392231)}; + exp = new CVector(expShape, expData); + + assertEquals(exp, A.mult(B)); + + // ------------------------ Sub-case 4 ------------------------ + aShape = new Shape(12, 12); + aData = new Complex128[]{}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + aColIndices = new int[]{}; + A = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + bShape = new Shape(12); + bData = new Complex128[]{new Complex128(0.5048, 0.33308), new Complex128(0.89792, 0.43697), new Complex128(0.11148, 0.44758), new Complex128(0.48047, 0.79639), new Complex128(0.90456, 0.12293), new Complex128(0.2917, 0.64718), new Complex128(0.8388, 0.82837), new Complex128(0.79096, 0.5637), new Complex128(0.35008, 0.1461), new Complex128(0.24666, 0.35544), new Complex128(0.58987, 0.27922), new Complex128(0.68159, 0.04372)}; + B = new CVector(bShape, bData); + + expShape = new Shape(12); + expData = new Complex128[]{ZERO, ZERO, ZERO, ZERO, ZERO, + ZERO, ZERO, ZERO, ZERO, ZERO, ZERO, + ZERO}; + exp = new CVector(expShape, expData); + assertEquals(exp, A.mult(B)); + + // ------------------------ Sub-case 5 ------------------------ + aShape = new Shape(10, 10); + A = new CsrCMatrix(aShape); + bShape = new Shape(6); + B = new CVector(bShape); + + CsrCMatrix finalA = A; + CVector finalB = B; + assertThrows(IllegalArgumentException.class, ()-> finalA.mult(finalB)); + + aShape = new Shape(10, 10); + A = new CsrCMatrix(aShape); + bShape = new Shape(32); + B = new CVector(bShape); + + CsrCMatrix finalA1 = A; + CVector finalB1 = B; + assertThrows(IllegalArgumentException.class, ()-> finalA1.mult(finalB1)); + } +} diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixDirectSumTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixDirectSumTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixDirectSumTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixDirectSumTests.java index 9f25235d8..67b02e1f4 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixDirectSumTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixDirectSumTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixElemDivTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixElemDivTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixElemDivTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixElemDivTests.java index e04a2c49c..a85d5582c 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.CustomAssertions; import org.flag4j.algebraic_structures.Complex128; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixElemMultTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixElemMultTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixElemMultTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixElemMultTests.java index 4f5c75ac0..dbe4f57fe 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixEqualsTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixEqualsTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixEqualsTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixEqualsTests.java index feca678e4..8f59b8e04 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixEqualsTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixEqualsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixInversionTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixInversionTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_matrix/CMatrixInversionTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixInversionTests.java index cbf2fe889..ecbee7c11 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixInversionTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixInversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixMatVecMultTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixMatVecMultTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixMatVecMultTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixMatVecMultTests.java index c850eb73d..cc277aab1 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixMatVecMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixMatVecMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixMultTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixMultTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixMultTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixMultTests.java index e48866995..3e8af1745 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixNormTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixNormTests.java new file mode 100644 index 000000000..e21e3753a --- /dev/null +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixNormTests.java @@ -0,0 +1,30 @@ +package org.flag4j.arrays.dense.complex_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.dense.CMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class CMatrixNormTests { + + @Test + void testCMatrixNorm() { + Shape aShape; + Complex128[] aData; + CMatrix a; + double exp, p, q; + + // --------------------- Sub-case 1 --------------------- + aShape = new Shape(5, 5); + aData = new Complex128[]{new Complex128(0.967, 0.682), new Complex128(0.575, 0.711), new Complex128(0.969, 0.409), new Complex128(0.76, 0.337), new Complex128(0.29, 0.104), new Complex128(0.647, 0.085), new Complex128(0.008, 0.563), new Complex128(0.662, 0.6), new Complex128(0.337, 0.836), new Complex128(0.783, 0.89), new Complex128(0.9, 0.439), new Complex128(0.633, 0.53), new Complex128(0.002, 0.758), new Complex128(0.033, 0.509), new Complex128(0.438, 0.833), new Complex128(0.72, 0.67), new Complex128(0.043, 0.722), new Complex128(0.642, 0.79), new Complex128(0.602, 0.604), new Complex128(0.398, 0.693), new Complex128(0.599, 0.455), new Complex128(0.981, 0.875), new Complex128(0.15, 0.586), new Complex128(0.151, 0.755), new Complex128(0.663, 0.114)}; + a = new CMatrix(aShape, aData); + + p = 1; + q = 1; + exp = 21.011943301767335; + + assertEquals(exp, a.norm(p, q)); + } +} diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixPropertiesTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixPropertiesTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixPropertiesTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixPropertiesTests.java index 9b636033d..fe1a1dcc7 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixPropertiesTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixPropertiesTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixRemoveRowColTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixRemoveRowColTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixRemoveRowColTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixRemoveRowColTests.java index 0b6884e18..521fc4361 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixRemoveRowColTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixRemoveRowColTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixReshapeTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixReshapeTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_matrix/CMatrixReshapeTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixReshapeTests.java index 529cf1872..1f1745300 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixReshapeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixReshapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -111,13 +111,13 @@ void flattenTestCase() { assertArrayEquals(A.data, B.data); // --------------- Sub-case 2 --------------- - expShape = new Shape(1, entries.length); + expShape = new Shape(entries.length, 1); B = A.flatten(1); assertEquals(expShape, B.shape); assertArrayEquals(A.data, B.data); // --------------- Sub-case 2 --------------- - expShape = new Shape(entries.length, 1); + expShape = new Shape(1, entries.length); B = A.flatten(0); assertEquals(expShape, B.shape); assertArrayEquals(A.data, B.data); diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixRoundTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixRoundTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixRoundTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixRoundTests.java index c0742ab26..13a20c9cc 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixRoundTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixRoundTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixScaleOpsTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixScaleOpsTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixScaleOpsTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixScaleOpsTests.java index 8f7105831..81f5b60c1 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixScaleOpsTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixScaleOpsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixSetOperationTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixSetOperationTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_matrix/CMatrixSetOperationTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixSetOperationTests.java index d1b74924c..346a02896 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixSetOperationTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixSetOperationTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -128,7 +128,10 @@ void setColumnCVectorTestCase() { {new Complex128(-445, 0.32), new Complex128(4)}, {new Complex128(94.1), new Complex128(-1334.5)}}; exp = new CMatrix(entriesExp); - entriesA = new Complex128[][]{{new Complex128(0), new Complex128(0)}, {new Complex128(1), new Complex128(4)}, {new Complex128(1331.14), new Complex128(-1334.5)}}; + entriesA = new Complex128[][]{{ + new Complex128(0), new Complex128(0)}, + {new Complex128(1), new Complex128(4)}, + {new Complex128(1331.14), new Complex128(-1334.5)}}; A = new CMatrix(entriesA); A.setCol(valuesVec, col); diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixStackTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixStackTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_matrix/CMatrixStackTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixStackTests.java index b660bef3b..9005bf231 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixStackTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixStackTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixSubTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixSubTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixSubTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixSubTests.java index 7ea83c776..42a7ca261 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixSubTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixToStringTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixToStringTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixToStringTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixToStringTests.java index 0f70d622f..a11f58fff 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixToStringTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixUnaryOpsTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixUnaryOpsTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_matrix/CMatrixUnaryOpsTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixUnaryOpsTests.java index c4ba0fe09..ad065f7be 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixUnaryOpsTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixUnaryOpsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixZeroOnesTests.java b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixZeroOnesTests.java similarity index 95% rename from src/test/java/org/flag4j/complex_matrix/CMatrixZeroOnesTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixZeroOnesTests.java index 8e8dac4bf..56dd6fb21 100644 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixZeroOnesTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_matrix/CMatrixZeroOnesTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_matrix; +package org.flag4j.arrays.dense.complex_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorAddTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorAddTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_tensor/CTensorAddTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorAddTests.java index 0e5da5f6a..45b6fd69e 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorAddTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorAddTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorConstructorTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_tensor/CTensorConstructorTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorConstructorTests.java index ac82bbba6..a27ab45f4 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorConversionTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorConversionTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_tensor/CTensorConversionTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorConversionTests.java index 988134a3e..54e2ab298 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorConversionTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorConversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorDotTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorDotTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_tensor/CTensorDotTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorDotTests.java index 145833436..6356c9e1c 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorDotTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorDotTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorElemDivTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorElemDivTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_tensor/CTensorElemDivTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorElemDivTests.java index 7e18db6b6..42aa6cb46 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorElemMultTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorElemMultTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_tensor/CTensorElemMultTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorElemMultTests.java index 25eb7046b..03fd1f785 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorReshapeTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorReshapeTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_tensor/CTensorReshapeTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorReshapeTests.java index a5d2afaff..932bf5356 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorReshapeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorReshapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorSubTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorSubTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_tensor/CTensorSubTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorSubTests.java index e5d21fb36..892ebcdbc 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorSubTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorToStringTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorToStringTests.java similarity index 97% rename from src/test/java/org/flag4j/complex_tensor/CTensorToStringTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorToStringTests.java index f085768b8..c3878ee2c 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorToStringTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorTransposeTest.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorTransposeTest.java similarity index 97% rename from src/test/java/org/flag4j/complex_tensor/CTensorTransposeTest.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorTransposeTest.java index 60ee34c9c..27d6eec20 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorTransposeTest.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorTransposeTest.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -92,7 +92,7 @@ void transposeTestCase() { // -------------------- Sub-case 7 -------------------- aAxes = new int[]{0, 1, 3, 2, 4}; - assertThrows(IllegalArgumentException.class, ()->A.T(aAxes)); + assertThrows(LinearAlgebraException.class, ()->A.T(aAxes)); // -------------------- Sub-case 8 -------------------- assertThrows(LinearAlgebraException.class, ()->A.T(-1, 0)); diff --git a/src/test/java/org/flag4j/complex_tensor/CTensorTransposeTests.java b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorTransposeTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_tensor/CTensorTransposeTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorTransposeTests.java index 5df3b5307..8b0ad4e85 100644 --- a/src/test/java/org/flag4j/complex_tensor/CTensorTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_tensor/CTensorTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_tensor; +package org.flag4j.arrays.dense.complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -96,7 +96,7 @@ void transposeTestCase() { // -------------------- Sub-case 7 -------------------- aAxes = new int[]{0, 1, 3, 2, 4}; - assertThrows(IllegalArgumentException.class, ()->A.T(aAxes)); + assertThrows(LinearAlgebraException.class, ()->A.T(aAxes)); // -------------------- Sub-case 8 -------------------- assertThrows(LinearAlgebraException.class, ()->A.T(-1, 0)); diff --git a/src/test/java/org/flag4j/complex_vector/CVectorAddTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorAddTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_vector/CVectorAddTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorAddTests.java index 78e2ac5f8..5f0588c86 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorAddTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorAddTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorAggregateTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorAggregateTests.java similarity index 97% rename from src/test/java/org/flag4j/complex_vector/CVectorAggregateTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorAggregateTests.java index 8ea8755c4..d82aa6526 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorAggregateTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorAggregateTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorConstructorTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_vector/CVectorConstructorTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorConstructorTests.java index 971b101ec..9d99bda03 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorCrossTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorCrossTests.java similarity index 97% rename from src/test/java/org/flag4j/complex_vector/CVectorCrossTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorCrossTests.java index f32163539..ee8ce906a 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorCrossTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorCrossTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorElemDivTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemDivTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_vector/CVectorElemDivTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemDivTests.java index 9ba0beea1..a68ffe3c7 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorElemMultTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemMultTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_vector/CVectorElemMultTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemMultTests.java index bab9b4ef3..a7b737bc0 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorElemOppTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemOppTests.java similarity index 97% rename from src/test/java/org/flag4j/complex_vector/CVectorElemOppTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemOppTests.java index 37a1c31ea..63c52a25c 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorElemOppTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorElemOppTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorEqualsTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorEqualsTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_vector/CVectorEqualsTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorEqualsTests.java index 85d7ffda6..b8b8c743b 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorEqualsTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorEqualsTests.java @@ -22,7 +22,7 @@ * SOFTWARE. */ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorInnerProductTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorInnerProductTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_vector/CVectorInnerProductTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorInnerProductTests.java index acd260741..5152af116 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorInnerProductTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorInnerProductTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorNormTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorNormTests.java similarity index 69% rename from src/test/java/org/flag4j/complex_vector/CVectorNormTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorNormTests.java index 2051f8a11..2afbef4b6 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorNormTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorNormTests.java @@ -1,8 +1,7 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; -import org.flag4j.linalg.VectorNorms; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -26,7 +25,7 @@ static void setup() { void normTestCase() { // ------------------ Sub-case 1 ------------------ expNorm = 6126.638392078558; - assertEquals(expNorm, VectorNorms.norm(a)); + assertEquals(expNorm, a.norm()); } @@ -34,27 +33,27 @@ void normTestCase() { void pNormTestCase() { // ------------------ Sub-case 1 ------------------ expNorm = 6208.346603991548; - assertEquals(expNorm, VectorNorms.norm(a, 1)); + assertEquals(expNorm, a.norm(1)); // ------------------ Sub-case 2 ------------------ - expNorm =6126.347178284363; - assertEquals(expNorm, VectorNorms.norm(a, 4.15)); + expNorm = 6126.347178284364; + assertEquals(expNorm, a.norm(4.15), 1.0e-12); // ------------------ Sub-case 3 ------------------ - expNorm = 6126.347172780369; - assertEquals(expNorm, VectorNorms.norm(a, 45)); + expNorm = 6126.347172780367; + assertEquals(expNorm, a.norm(45), 1.0e-12); // ------------------ Sub-case 4 ------------------ expNorm = 6126.347172780367; - assertEquals(expNorm, VectorNorms.norm(a, Double.POSITIVE_INFINITY)); + assertEquals(expNorm, a.norm(Double.POSITIVE_INFINITY)); // ------------------ Sub-case 5 ------------------ expNorm = 0.009241438243709998; - assertEquals(expNorm, VectorNorms.norm(a, -1)); + assertEquals(expNorm, a.norm(-1), 1.0e-12); // ------------------ Sub-case 6 ------------------ expNorm = 0.009257; - assertEquals(expNorm, VectorNorms.norm(a, Double.NEGATIVE_INFINITY)); + assertEquals(expNorm, a.norm(Double.NEGATIVE_INFINITY)); } @@ -62,6 +61,6 @@ void pNormTestCase() { void infNormTestCase() { // ------------------ Sub-case 1 ------------------ expNorm = 6126.347172780367; - assertEquals(expNorm, VectorNorms.infNorm(a)); + assertEquals(expNorm, a.norm(Double.POSITIVE_INFINITY)); } } diff --git a/src/test/java/org/flag4j/complex_vector/CVectorOuterProductTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorOuterProductTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_vector/CVectorOuterProductTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorOuterProductTests.java index 82887b265..ae65661af 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorOuterProductTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorOuterProductTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorRepeatTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorRepeatTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_vector/CVectorRepeatTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorRepeatTests.java index dcba90900..6919cf57e 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorRepeatTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorRepeatTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorScalMultDivTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorScalMultDivTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_vector/CVectorScalMultDivTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorScalMultDivTests.java index 56e71084a..63059370d 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorScalMultDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorScalMultDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorSetTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorSetTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_vector/CVectorSetTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorSetTests.java index 8c2063e05..d11a9ed6e 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorSetTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorStackJoinExtendTest.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorStackJoinExtendTest.java similarity index 98% rename from src/test/java/org/flag4j/complex_vector/CVectorStackJoinExtendTest.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorStackJoinExtendTest.java index 4d8ac28ae..1a7918dc3 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorStackJoinExtendTest.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorStackJoinExtendTest.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorSubTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorSubTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_vector/CVectorSubTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorSubTests.java index d1af5f450..cd8c59441 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorSubTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/complex_vector/CVectorZeroOneTests.java b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorZeroOneTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_vector/CVectorZeroOneTests.java rename to src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorZeroOneTests.java index ae417c079..ed53422b5 100644 --- a/src/test/java/org/flag4j/complex_vector/CVectorZeroOneTests.java +++ b/src/test/java/org/flag4j/arrays/dense/complex_vector/CVectorZeroOneTests.java @@ -22,7 +22,7 @@ * SOFTWARE. */ -package org.flag4j.complex_vector; +package org.flag4j.arrays.dense.complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/matrix/MatrixAddSubEqTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixAddSubEqTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixAddSubEqTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixAddSubEqTests.java index 383fedc52..95b9b4fed 100644 --- a/src/test/java/org/flag4j/matrix/MatrixAddSubEqTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixAddSubEqTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/matrix/MatrixAddTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixAddTests.java similarity index 97% rename from src/test/java/org/flag4j/matrix/MatrixAddTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixAddTests.java index 47d467ad9..7f514ae6d 100644 --- a/src/test/java/org/flag4j/matrix/MatrixAddTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixAddTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -9,9 +9,7 @@ import org.flag4j.util.exceptions.LinearAlgebraException; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.*; class MatrixAddTests { diff --git a/src/test/java/org/flag4j/matrix/MatrixAggregationTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixAggregationTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixAggregationTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixAggregationTests.java index 317c00697..32afc2b5c 100644 --- a/src/test/java/org/flag4j/matrix/MatrixAggregationTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixAggregationTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.junit.jupiter.api.Assertions; diff --git a/src/test/java/org/flag4j/matrix/MatrixConstructorTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixConstructorTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixConstructorTests.java index 38568f542..659b27763 100644 --- a/src/test/java/org/flag4j/matrix/MatrixConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/matrix/MatrixConversionTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixConversionTests.java similarity index 97% rename from src/test/java/org/flag4j/matrix/MatrixConversionTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixConversionTests.java index 63b477cc5..08f9dc539 100644 --- a/src/test/java/org/flag4j/matrix/MatrixConversionTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixConversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/matrix/MatrixCsrCMatMultTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixCsrCMatMultTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixCsrCMatMultTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixCsrCMatMultTests.java index bc3ac91c2..bf7d318fe 100644 --- a/src/test/java/org/flag4j/matrix/MatrixCsrCMatMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixCsrCMatMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/matrix/MatrixDetTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixDetTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixDetTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixDetTests.java index 0e7ae9e60..271a54f07 100644 --- a/src/test/java/org/flag4j/matrix/MatrixDetTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixDetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.util.exceptions.LinearAlgebraException; diff --git a/src/test/java/org/flag4j/matrix/MatrixDirectSumTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixDirectSumTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixDirectSumTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixDirectSumTests.java index 11243062c..2189bcf01 100644 --- a/src/test/java/org/flag4j/matrix/MatrixDirectSumTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixDirectSumTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/matrix/MatrixElemDivTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemDivTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixElemDivTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemDivTests.java index 6bb538aa4..d633aeb4a 100644 --- a/src/test/java/org/flag4j/matrix/MatrixElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/matrix/MatrixElemMultTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemMultTests.java similarity index 96% rename from src/test/java/org/flag4j/matrix/MatrixElemMultTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemMultTests.java index 98bb2049a..a121b223b 100644 --- a/src/test/java/org/flag4j/matrix/MatrixElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -9,9 +9,7 @@ import org.flag4j.util.exceptions.LinearAlgebraException; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.*; class MatrixElemMultTests { diff --git a/src/test/java/org/flag4j/matrix/MatrixElemOppTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemOppTests.java similarity index 97% rename from src/test/java/org/flag4j/matrix/MatrixElemOppTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemOppTests.java index 54685db70..6034419cc 100644 --- a/src/test/java/org/flag4j/matrix/MatrixElemOppTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixElemOppTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.CustomAssertions; import org.flag4j.algebraic_structures.Complex128; diff --git a/src/test/java/org/flag4j/matrix/MatrixElementScalarTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixElementScalarTests.java similarity index 91% rename from src/test/java/org/flag4j/matrix/MatrixElementScalarTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixElementScalarTests.java index 675d54af3..712cd5a2a 100644 --- a/src/test/java/org/flag4j/matrix/MatrixElementScalarTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixElementScalarTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; @@ -24,8 +24,8 @@ void scalDivTestCase() { aEntries = new double[][]{{1.334, -2.3112, 334.3}, {4.13, -35.33, 6}}; A = new Matrix(aEntries); scalar = 1.44; - expEntries = new double[][]{{1.334/1.44, -2.3112/1.44, 334.3/1.44}, - {4.13/1.44, -35.33/1.44, 6/1.44}}; + expEntries = new double[][]{{1.334*(1.0/1.44), -2.3112*(1.0/1.44), 334.3*(1.0/1.44)}, + {4.13*(1.0/1.44), -35.33*(1.0/1.44), 6*(1.0/1.44)}}; expResult = new Matrix(expEntries); assertArrayEquals(expResult.data, A.div(scalar).data); diff --git a/src/test/java/org/flag4j/matrix/MatrixFibTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixFibTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixFibTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixFibTests.java index e5f927fec..31eeeb0dc 100644 --- a/src/test/java/org/flag4j/matrix/MatrixFibTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixFibTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/matrix/MatrixGetTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixGetTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixGetTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixGetTests.java index 20745f44f..8315f7600 100644 --- a/src/test/java/org/flag4j/matrix/MatrixGetTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixGetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/matrix/MatrixInversionTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixInversionTests.java similarity index 97% rename from src/test/java/org/flag4j/matrix/MatrixInversionTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixInversionTests.java index f2d388942..e821107b0 100644 --- a/src/test/java/org/flag4j/matrix/MatrixInversionTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixInversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.linalg.Invert; diff --git a/src/test/java/org/flag4j/matrix/MatrixMultTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixMultTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixMultTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixMultTests.java index 9a82e2ef7..8b3159203 100644 --- a/src/test/java/org/flag4j/matrix/MatrixMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; diff --git a/src/test/java/org/flag4j/matrix/MatrixPropertiesTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixPropertiesTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixPropertiesTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixPropertiesTests.java index 6ba236583..689b075f6 100644 --- a/src/test/java/org/flag4j/matrix/MatrixPropertiesTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixPropertiesTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.linalg.Invert; diff --git a/src/test/java/org/flag4j/matrix/MatrixRemoveRowColTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixRemoveRowColTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixRemoveRowColTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixRemoveRowColTests.java index e362d68b6..d2e3b4dd3 100644 --- a/src/test/java/org/flag4j/matrix/MatrixRemoveRowColTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixRemoveRowColTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/matrix/MatrixReshapeTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixReshapeTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixReshapeTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixReshapeTests.java index cea6a4aed..e04884ba3 100644 --- a/src/test/java/org/flag4j/matrix/MatrixReshapeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixReshapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/matrix/MatrixScalMultTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixScalMultTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixScalMultTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixScalMultTests.java index 1feeb29e7..b0209ce43 100644 --- a/src/test/java/org/flag4j/matrix/MatrixScalMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixScalMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/matrix/MatrixSetOperationTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixSetOperationTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixSetOperationTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixSetOperationTests.java index d4a024634..f87b258e9 100644 --- a/src/test/java/org/flag4j/matrix/MatrixSetOperationTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixSetOperationTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.util.exceptions.LinearAlgebraException; diff --git a/src/test/java/org/flag4j/matrix/MatrixStackTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixStackTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixStackTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixStackTests.java index cceef237e..4c8224c49 100644 --- a/src/test/java/org/flag4j/matrix/MatrixStackTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixStackTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/matrix/MatrixSubTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixSubTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixSubTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixSubTests.java index 4b9ac34e2..9e7679495 100644 --- a/src/test/java/org/flag4j/matrix/MatrixSubTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/matrix/MatrixToStringTest.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixToStringTest.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixToStringTest.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixToStringTest.java index 93941d5e2..452a4e02f 100644 --- a/src/test/java/org/flag4j/matrix/MatrixToStringTest.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixToStringTest.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.io.PrintOptions; diff --git a/src/test/java/org/flag4j/matrix/MatrixTransposeTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixTransposeTests.java similarity index 95% rename from src/test/java/org/flag4j/matrix/MatrixTransposeTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixTransposeTests.java index 70b37e3ab..75dc7faf1 100644 --- a/src/test/java/org/flag4j/matrix/MatrixTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/matrix/MatrixTriangularTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixTriangularTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixTriangularTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixTriangularTests.java index 7bb7141d6..a5469fd82 100644 --- a/src/test/java/org/flag4j/matrix/MatrixTriangularTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixTriangularTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/matrix/MatrixVectorCheckTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixVectorCheckTests.java similarity index 98% rename from src/test/java/org/flag4j/matrix/MatrixVectorCheckTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixVectorCheckTests.java index 8392c14ab..4c4e5e052 100644 --- a/src/test/java/org/flag4j/matrix/MatrixVectorCheckTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixVectorCheckTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/matrix/MatrixVectorTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixVectorTests.java similarity index 99% rename from src/test/java/org/flag4j/matrix/MatrixVectorTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixVectorTests.java index dde1b7c8e..9b42c4a1a 100644 --- a/src/test/java/org/flag4j/matrix/MatrixVectorTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixVectorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/matrix/MatrixZerosOnesTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixZerosOnesTests.java similarity index 96% rename from src/test/java/org/flag4j/matrix/MatrixZerosOnesTests.java rename to src/test/java/org/flag4j/arrays/dense/matrix/MatrixZerosOnesTests.java index 110bdd960..899f0bb9a 100644 --- a/src/test/java/org/flag4j/matrix/MatrixZerosOnesTests.java +++ b/src/test/java/org/flag4j/arrays/dense/matrix/MatrixZerosOnesTests.java @@ -1,4 +1,4 @@ -package org.flag4j.matrix; +package org.flag4j.arrays.dense.matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.util.ArrayUtils; diff --git a/src/test/java/org/flag4j/arrays/dense/matrix/NormTests.java b/src/test/java/org/flag4j/arrays/dense/matrix/NormTests.java new file mode 100644 index 000000000..a020059c0 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/dense/matrix/NormTests.java @@ -0,0 +1,274 @@ +package org.flag4j.arrays.dense.matrix; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.dense.Matrix; +import org.flag4j.linalg.MatrixNorms; +import org.flag4j.util.exceptions.LinearAlgebraException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class NormTests { + Shape aShape; + double[] aData; + Matrix a; + double exp; + double p; + double q; + + @Test + void schattenNormTests() { + // ---------------- Sub-case 1 ---------------- + aShape = new Shape(5, 5); + aData = new double[]{0.5508, 0.70815, 0.2909, 0.51083, 0.89295, 0.89629, 0.12559, 0.20724, 0.05147, 0.44081, 0.02988, 0.45683, 0.64914, 0.27849, 0.67625, 0.59086, 0.02398, 0.55885, 0.25925, 0.4151, 0.28353, 0.69314, 0.44045, 0.15687, 0.54465}; + a = new Matrix(aShape, aData); + p = 1; + + exp = 4.003647020302842; + + assertEquals(exp, MatrixNorms.schattenNorm(a, p)); + + // ---------------- Sub-case 2 ---------------- + aShape = new Shape(3, 2); + aData = new double[]{0.5508, 0.70815, 0.2909, 0.51083, 0.89295, 0.89629}; + a = new Matrix(aShape, aData); + + p = 1; + exp = 1.803745292776371; + + assertEquals(exp, MatrixNorms.schattenNorm(a, p)); + + // ---------------- Sub-case 3 ---------------- + aShape = new Shape(2, 3); + aData = new double[]{0.5508, 0.70815, 0.2909, 0.51083, 0.89295, 0.89629}; + a = new Matrix(aShape, aData); + + p = 1; + exp = 1.9501167666102215; + + assertEquals(exp, MatrixNorms.schattenNorm(a, p)); + + // ---------------- Sub-case 4 ---------------- + aShape = new Shape(9, 9); + aData = new double[]{0.5508, 0.70815, 0.2909, 0.51083, 0.89295, 0.89629, 0.12559, 0.20724, 0.05147, 0.44081, + 0.02988, 0.45683, 0.64914, 0.27849, 0.67625, 0.59086, 0.02398, 0.55885, 0.25925, 0.4151, 0.28353, + 0.69314, 0.44045, 0.15687, 0.54465, 0.78031, 0.30636, 0.22196, 0.38797, 0.93638, 0.976, 0.67238, + 0.90283, 0.84575, 0.37799, 0.09222, 0.65341, 0.55784, 0.36156, 0.22505, 0.40652, 0.46894, 0.26924, + 0.29179, 0.45769, 0.86053, 0.58625, 0.28349, 0.27798, 0.45462, 0.20541, 0.20138, 0.51404, 0.08723, + 0.48359, 0.36218, 0.70769, 0.74675, 0.69109, 0.68918, 0.3736, 0.66813, 0.33985, 0.57279, 0.32581, + 0.44515, 0.06153, 0.24268, 0.9716, 0.23058, 0.69148, 0.65048, 0.72394, 0.47509, 0.59666, 0.06697, + 0.07256, 0.19898, 0.15186, 0.1001, 0.12929}; + a = new Matrix(aShape, aData); + + p = 2; + exp = 4.622437431215267; + assertEquals(exp, MatrixNorms.schattenNorm(a, p)); + + // ---------------- Sub-case 5 ---------------- + aShape = new Shape(9, 9); + aData = new double[]{0.5508, 0.70815, 0.2909, 0.51083, 0.89295, 0.89629, 0.12559, 0.20724, 0.05147, 0.44081, 0.02988, 0.45683, 0.64914, 0.27849, 0.67625, 0.59086, 0.02398, 0.55885, 0.25925, 0.4151, 0.28353, 0.69314, 0.44045, 0.15687, 0.54465, 0.78031, 0.30636, 0.22196, 0.38797, 0.93638, 0.976, 0.67238, 0.90283, 0.84575, 0.37799, 0.09222, 0.65341, 0.55784, 0.36156, 0.22505, 0.40652, 0.46894, 0.26924, 0.29179, 0.45769, 0.86053, 0.58625, 0.28349, 0.27798, 0.45462, 0.20541, 0.20138, 0.51404, 0.08723, 0.48359, 0.36218, 0.70769, 0.74675, 0.69109, 0.68918, 0.3736, 0.66813, 0.33985, 0.57279, 0.32581, 0.44515, 0.06153, 0.24268, 0.9716, 0.23058, 0.69148, 0.65048, 0.72394, 0.47509, 0.59666, 0.06697, 0.07256, 0.19898, 0.15186, 0.1001, 0.12929}; + a = new Matrix(aShape, aData); + + p = Double.POSITIVE_INFINITY; + exp = 4.169634097557557; + assertEquals(exp, MatrixNorms.schattenNorm(a, p)); + + // ---------------- Sub-case 6 ---------------- + aShape = new Shape(9, 9); + aData = new double[]{0.5508, 0.70815, 0.2909, 0.51083, 0.89295, 0.89629, 0.12559, 0.20724, 0.05147, 0.44081, 0.02988, 0.45683, 0.64914, 0.27849, 0.67625, 0.59086, 0.02398, 0.55885, 0.25925, 0.4151, 0.28353, 0.69314, 0.44045, 0.15687, 0.54465, 0.78031, 0.30636, 0.22196, 0.38797, 0.93638, 0.976, 0.67238, 0.90283, 0.84575, 0.37799, 0.09222, 0.65341, 0.55784, 0.36156, 0.22505, 0.40652, 0.46894, 0.26924, 0.29179, 0.45769, 0.86053, 0.58625, 0.28349, 0.27798, 0.45462, 0.20541, 0.20138, 0.51404, 0.08723, 0.48359, 0.36218, 0.70769, 0.74675, 0.69109, 0.68918, 0.3736, 0.66813, 0.33985, 0.57279, 0.32581, 0.44515, 0.06153, 0.24268, 0.9716, 0.23058, 0.69148, 0.65048, 0.72394, 0.47509, 0.59666, 0.06697, 0.07256, 0.19898, 0.15186, 0.1001, 0.12929}; + a = new Matrix(aShape, aData); + + p = 3.25621; + exp = 4.215513043537819; + assertEquals(exp, MatrixNorms.schattenNorm(a, p)); + + // ---------------- Sub-case 7 ---------------- + aShape = new Shape(4, 4); + aData = new double[]{0.5508, 0.70815, 0.2909, 0.51083, 0.89295, + 0.89629, 0.12559, 0.20724, 0.05147, 0.44081, + 0.02988, 0.45683, 0.64914, 0.27849, 0.67625, 0.59086}; + a = new Matrix(aShape, aData); + + p = -1; + exp = 0.026667736708101627; + assertEquals(exp, MatrixNorms.schattenNorm(a, p)); + + // ---------------- Sub-case 8 ---------------- + aShape = new Shape(4, 4); + aData = new double[]{0.5508, 0.70815, 0.2909, 0.51083, 0.89295, 0.89629, 0.12559, 0.20724, 0.05147, 0.44081, 0.02988, 0.45683, 0.64914, 0.27849, 0.67625, 0.59086}; + a = new Matrix(aShape, aData); + + p = Double.NEGATIVE_INFINITY; + exp = 0.029990005262000484; + + assertEquals(exp, MatrixNorms.schattenNorm(a, p)); + } + + + @Test + void entryWiseNormTests() { + // ---------------- Sub-case 1 ---------------- + aShape = new Shape(4, 4); + aData = new double[]{0.03766, 0.53811, 0.6172, 0.83172, 0.46543, 0.06995, 0.05425, 0.57504, 0.33034, 0.42442, 0.29937, 0.92909, 0.40243, 0.48919, 0.89897, 0.2998}; + a = new Matrix(aShape, aData); + + p = 1; + exp = 7.26297; + assertEquals(exp, MatrixNorms.entryWiseNorm(a, p)); + + // ---------------- Sub-case 2 ---------------- + aShape = new Shape(7, 2); + aData = new double[]{0.52663, 0.10295, 0.80074, 0.70444, 0.50312, 0.94987, 0.65969, 0.4173, 0.09431, 0.36203, 0.77842, 0.06282, 0.2784, 0.40589}; + a = new Matrix(aShape, aData); + + p = -1; + exp = 0.017283789037010208; + assertEquals(exp, MatrixNorms.entryWiseNorm(a, p)); + + // ---------------- Sub-case 3 ---------------- + aShape = new Shape(3, 5); + aData = new double[]{0.58006, 1.53305, -0.16158, -0.87468, 0.35272, -1.96337, 1.03996, -0.03287, -1.32042, -0.95082, -0.18674, -0.74043, 1.61446, 0.3053, 1.06991}; + a = new Matrix(aShape, aData); + + p = 2; + exp = 3.9515551558468727; + assertEquals(exp, MatrixNorms.entryWiseNorm(a, p)); + + // ---------------- Sub-case 4 ---------------- + aShape = new Shape(3, 3); + aData = new double[]{1.30776, -0.03747, -1.68686, -0.46246, 0.99453, 1.47121, 1.18819, -0.04004, -0.09288}; + a = new Matrix(aShape, aData); + + p = -2; + exp = 0.026173754083315442; + assertEquals(exp, MatrixNorms.entryWiseNorm(a, p)); + + // ---------------- Sub-case 5 ---------------- + aShape = new Shape(3, 3); + aData = new double[]{-0.57719, 0.76596, -0.78453, -0.33543, 1.861, -0.38476, -0.84035, -1.7351, 0.90212}; + a = new Matrix(aShape, aData); + + p = 4.51; + exp = 2.1303977251997614; + assertEquals(exp, MatrixNorms.entryWiseNorm(a, p)); + + // ---------------- Sub-case 6 ---------------- + aShape = new Shape(3, 3); + aData = new double[]{-0.79318, 1.54067, -0.5029, 0.67312, -0.24927, -0.97936, 0.61532, 0.61237, 0.00226}; + a = new Matrix(aShape, aData); + + p = Double.POSITIVE_INFINITY; + exp = 1.54067; + assertEquals(exp, MatrixNorms.entryWiseNorm(a, p)); + + // ---------------- Sub-case 7 ---------------- + aShape = new Shape(3, 3); + aData = new double[]{0.87487, -0.41412, 1.05809, 1.97711, -1.33573, -1.05693, 0.92927, -1.43141, -0.71267}; + a = new Matrix(aShape, aData); + + p = Double.NEGATIVE_INFINITY; + exp = 0.41412; + assertEquals(exp, MatrixNorms.entryWiseNorm(a, p)); + } + + @Test + void inducedNormTests() { + // ---------------- Sub-case 1 ---------------- + aShape = new Shape(6, 8); + aData = new double[]{0.87001, -0.53075, 1.80658, -0.36132, -0.86666, 2.14868, -1.48927, 1.7104, -0.24429, + 0.00475, 0.00986, -1.52087, 0.01851, -0.91805, -0.68661, + 0.63289, 0.27013, -1.8278, 2.31575, 0.00153, 0.5054, -0.24858, + -0.32284, 1.44068, 0.09654, 1.80215, -0.85929, -0.83069, 0.12971, + -0.40369, -0.86397, -0.31808, -0.94131, 0.07882, -0.36221, -0.56579, + 0.25437, 0.86334, -0.44146, 0.53003, 0.06752, -0.16675, 0.65504, + -1.92004, -0.86963, -0.9875, -0.12287, -0.08252}; + a = new Matrix(aShape, aData); + + p = 1; + exp = 6.008729999999999; + assertEquals(exp, MatrixNorms.inducedNorm(a, p)); + + // ---------------- Sub-case 2 ---------------- + aShape = new Shape(6, 8); + aData = new double[]{-1.11826, -1.09012, -0.42251, -0.24597, 1.71915, 1.59003, 0.15207, -0.14672, 0.11485, 1.41536, 0.72123, 1.22943, -2.11986, 0.02405, 0.0345, 0.86933, 0.39555, -0.49977, 0.27804, -1.24306, -0.74671, -0.25535, -0.06707, 0.2038, -0.78898, -0.1462, -0.39705, -0.60432, 0.84509, -2.39218, -0.58742, 1.2572, 1.21936, 0.23098, 0.99622, -1.49354, 0.44016, 1.76301, -0.93744, -0.24744, 1.37773, 0.92328, 0.90134, -0.10557, 0.40794, -1.60648, 0.58006, 0.33817}; + a = new Matrix(aShape, aData); + + p = -1; + exp = 2.35856; + assertEquals(exp, MatrixNorms.inducedNorm(a, p)); + + // ---------------- Sub-case 3 ---------------- + aShape = new Shape(14, 14); + aData = new double[]{0.72041, 1.89717, 1.94906, 0.27556, -1.1979, -0.53392, -0.21621, -0.04478, 0.35898, -0.0922, 1.3103, + -0.21821, 0.22011, 0.3564, 0.72875, -1.29384, -0.13304, -1.66793, 0.50722, -1.38537, -0.42744, 0.79773, 0.4184, + -0.00807, 0.20668, -1.1379, -1.37254, -0.03686, 0.99935, 0.97841, 0.09745, -1.60085, 1.10699, -0.15013, 0.65144, + -0.3773, -1.08707, 0.83051, -0.36223, 1.48693, 0.35924, 1.24536, 0.13307, -0.58486, 0.32037, 0.18886, -0.06011, + 0.35101, -0.07892, -1.50245, -0.3265, 1.28271, 0.04595, 2.22292, -1.04199, 0.87992, -0.42606, 0.11407, 0.27266, + 0.05567, -1.2789, 1.05216, -0.57621, 0.13403, -0.73407, -0.19883, -0.48796, 0.10257, 0.5898, -0.12817, -0.28215, + -0.9266, -1.54239, 1.21185, -0.03149, -0.65298, -0.66052, 1.58396, -0.62336, -0.47375, -0.41352, -1.85793, + -1.49312, 0.94452, -0.05956, 2.71684, 0.27951, -2.07396, -0.99021, -1.36677, -0.50161, -1.16948, -0.26776, + -0.12959, 1.0008, 0.35709, 0.40961, -0.20094, 0.57074, 0.12553, 0.72861, 0.63256, 0.64524, 0.29983, -0.12953, + 0.52116, -1.69015, -0.10428, -1.16605, -0.55069, 0.02144, 1.13731, -0.46032, 1.13434, 1.01273, -0.18879, 0.41548, + 0.04057, 1.49239, -0.92921, -1.21722, -0.68579, -0.13984, 0.82782, 0.37276, 1.15156, 0.93135, 0.22033, 1.24912, + 0.93163, 0.35953, 1.3668, 0.31723, 1.11881, -0.65634, -0.47619, 0.65127, 0.81784, 2.03589, 0.76112, 0.47193, + -0.6429, 1.15648, 0.53338, 1.08118, -1.86393, -0.61871, -0.30993, -0.37079, -0.13606, -0.13351, 0.10444, 0.18017, + -0.01592, -0.6389, -0.97788, 0.78164, -0.34281, -0.61406, -1.70849, 3.08745, 0.28245, 2.32343, -1.26434, -0.55688, + -0.07774, -0.04553, -2.44962, 0.11056, 1.34567, -0.72284, -1.29676, -1.33067, 0.05912, 0.20892, 1.46179, -2.28954, + -1.45136, -0.44066, 0.84616, -0.52071, -0.82045, -0.54602, 1.38205, -0.22712, -0.10525, 0.06766, 0.0821, 1.88073, + -0.12256, 1.32759, 2.51665, 1.39866, 0.55066, 1.49893, 0.83932}; + a = new Matrix(aShape, aData); + + p = 2; + exp = 6.764801864669105; + assertEquals(exp, MatrixNorms.inducedNorm(a, p)); + + // ---------------- Sub-case 4 ---------------- + aShape = new Shape(5, 5); + aData = new double[]{0.58539, -2.14589, -2.27855, -0.80679, -1.05446, -0.35289, -0.43887, -1.46231, 0.43138, -0.19616, + 0.05008, 1.85505, -0.7718, -0.11618, -0.60505, 0.3251, -1.93755, 0.41506, -1.2915, -1.2645, 0.33895, 2.39226, + 1.35074, 0.46931, -0.31796}; + a = new Matrix(aShape, aData); + + p = -2; + exp = 0.4970402772750497; + assertEquals(exp, MatrixNorms.inducedNorm(a, p)); + + // ---------------- Sub-case 5 ---------------- + aShape = new Shape(5, 5); + aData = new double[]{-0.9525, 1.27424, 0.98569, -0.27018, 0.82013, -0.17059, 0.57118, 0.50879, 1.08135, -1.05082, -0.56911, + 0.17297, 0.21703, 0.57278, 0.15335, -2.26664, -1.46254, 0.70572, 0.29167, 0.8563, -1.05627, 0.01729, + 0.12783, 0.13725, -0.5995}; + a = new Matrix(aShape, aData); + + p = Double.POSITIVE_INFINITY; + exp = 5.582870000000001; + assertEquals(exp, MatrixNorms.inducedNorm(a, p)); + + // ---------------- Sub-case 6 ---------------- + aShape = new Shape(5, 5); + aData = new double[]{-0.32281, 2.19241, -0.56914, 0.69786, 0.35078, 0.28429, 0.4643, -0.13171, -0.23963, 0.89503, -0.61056, + -1.26108, -0.72803, -1.38127, 0.8772, 0.88868, 0.57504, -0.45355, 0.23406, 3.34191, -0.07215, 0.1964, + 0.63613, 0.20461, 1.03078}; + a = new Matrix(aShape, aData); + + p = Double.NEGATIVE_INFINITY; + exp = 2.0149600000000003; + assertEquals(exp, MatrixNorms.inducedNorm(a, p)); + + // ---------------- Sub-case 7 ---------------- + a = new Matrix(5, 5); + assertEquals(0, MatrixNorms.inducedNorm(a, 1), 1e-16); + assertEquals(0, MatrixNorms.inducedNorm(a, -1), 1e-16); + // TODO: SVD does not converge here (i.e. Schur decomp does not converge). Need to do some work. Balancing first may + // fix or partly fix this. +// assertEquals(0, MatrixNorms.inducedNorm(a, 2), 1e-16); +// assertEquals(0, MatrixNorms.inducedNorm(a, -2), 1e-16); + assertEquals(0, MatrixNorms.inducedNorm(a, Double.POSITIVE_INFINITY), 1e-16); + assertEquals(0, MatrixNorms.inducedNorm(a, Double.NEGATIVE_INFINITY), 1e-16); + + // ---------------- Sub-case 8 ---------------- + assertThrows(LinearAlgebraException.class, () -> MatrixNorms.inducedNorm(a, 2.2)); + assertThrows(LinearAlgebraException.class, () -> MatrixNorms.inducedNorm(a, 15.332)); + } +} diff --git a/src/test/java/org/flag4j/tensor/TensorAddTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorAddTests.java similarity index 99% rename from src/test/java/org/flag4j/tensor/TensorAddTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorAddTests.java index 2a3a80fbe..46461d77c 100644 --- a/src/test/java/org/flag4j/tensor/TensorAddTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorAddTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/tensor/TensorAggregateTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorAggregateTests.java similarity index 97% rename from src/test/java/org/flag4j/tensor/TensorAggregateTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorAggregateTests.java index 4c1e6507d..1f953bd43 100644 --- a/src/test/java/org/flag4j/tensor/TensorAggregateTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorAggregateTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Tensor; diff --git a/src/test/java/org/flag4j/tensor/TensorConstructorTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/tensor/TensorConstructorTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorConstructorTests.java index 31a73653b..0f15ffbf0 100644 --- a/src/test/java/org/flag4j/tensor/TensorConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/tensor/TensorConversionTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorConversionTests.java similarity index 98% rename from src/test/java/org/flag4j/tensor/TensorConversionTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorConversionTests.java index 1d9b54408..492e0d863 100644 --- a/src/test/java/org/flag4j/tensor/TensorConversionTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorConversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/tensor/TensorDotTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorDotTests.java similarity index 99% rename from src/test/java/org/flag4j/tensor/TensorDotTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorDotTests.java index 2c3312f8d..3c3b48faf 100644 --- a/src/test/java/org/flag4j/tensor/TensorDotTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorDotTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Tensor; diff --git a/src/test/java/org/flag4j/tensor/TensorElemDivTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorElemDivTests.java similarity index 99% rename from src/test/java/org/flag4j/tensor/TensorElemDivTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorElemDivTests.java index 058ea2f8e..3f9e936ec 100644 --- a/src/test/java/org/flag4j/tensor/TensorElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/tensor/TensorElemMultTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorElemMultTests.java similarity index 99% rename from src/test/java/org/flag4j/tensor/TensorElemMultTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorElemMultTests.java index beaebf7c4..3cee03531 100644 --- a/src/test/java/org/flag4j/tensor/TensorElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/tensor/TensorEqualsTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorEqualsTests.java similarity index 99% rename from src/test/java/org/flag4j/tensor/TensorEqualsTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorEqualsTests.java index ef65ec1f2..b713bd391 100644 --- a/src/test/java/org/flag4j/tensor/TensorEqualsTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorEqualsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.algebraic_structures.Complex128; diff --git a/src/test/java/org/flag4j/tensor/TensorGetSetTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorGetSetTests.java similarity index 98% rename from src/test/java/org/flag4j/tensor/TensorGetSetTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorGetSetTests.java index 6357c0a58..dc7fd7bab 100644 --- a/src/test/java/org/flag4j/tensor/TensorGetSetTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorGetSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/tensor/TensorPropertiesTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorPropertiesTests.java similarity index 99% rename from src/test/java/org/flag4j/tensor/TensorPropertiesTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorPropertiesTests.java index ced0ee1a4..766c62327 100644 --- a/src/test/java/org/flag4j/tensor/TensorPropertiesTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorPropertiesTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Tensor; diff --git a/src/test/java/org/flag4j/tensor/TensorReshapeTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorReshapeTests.java similarity index 98% rename from src/test/java/org/flag4j/tensor/TensorReshapeTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorReshapeTests.java index 2963c390c..a9e9bf90c 100644 --- a/src/test/java/org/flag4j/tensor/TensorReshapeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorReshapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Tensor; diff --git a/src/test/java/org/flag4j/tensor/TensorScalMultDivTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorScalMultDivTests.java similarity index 91% rename from src/test/java/org/flag4j/tensor/TensorScalMultDivTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorScalMultDivTests.java index 5cee483d6..da0bb5d00 100644 --- a/src/test/java/org/flag4j/tensor/TensorScalMultDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorScalMultDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -70,8 +70,8 @@ void realScalDivTestCase() { // ------------------------ Sub-case 1 ------------------------ expEntries = new double[]{ - 1.23/-1.4115, 2.556/-1.4115, -121.5/-1.4115, 15.61/-1.4115, 14.15/-1.4115, -99.23425/-1.4115, - 0.001345/-1.4115, 2.677/-1.4115, 8.14/-1.4115, -0.000194/-1.4115, 1/-1.4115, 234/-1.4115 + 1.23*(1.0/-1.4115), 2.556*(1.0/-1.4115), -121.5*(1.0/-1.4115), 15.61*(1.0/-1.4115), 14.15*(1.0/-1.4115), -99.23425*(1.0/-1.4115), + 0.001345*(1.0/-1.4115), 2.677*(1.0/-1.4115), 8.14*(1.0/-1.4115), -0.000194*(1.0/-1.4115), 1*(1.0/-1.4115), 234*(1.0/-1.4115) }; expShape = new Shape(2, 3, 2); exp = new Tensor(expShape, expEntries); diff --git a/src/test/java/org/flag4j/tensor/TensorSubTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorSubTests.java similarity index 99% rename from src/test/java/org/flag4j/tensor/TensorSubTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorSubTests.java index 314ce61f5..1fa29b3aa 100644 --- a/src/test/java/org/flag4j/tensor/TensorSubTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/tensor/TensorToStringTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorToStringTests.java similarity index 97% rename from src/test/java/org/flag4j/tensor/TensorToStringTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorToStringTests.java index 7bf07669f..b0e6cff0c 100644 --- a/src/test/java/org/flag4j/tensor/TensorToStringTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Tensor; diff --git a/src/test/java/org/flag4j/tensor/TensorTransposeTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorTransposeTests.java similarity index 98% rename from src/test/java/org/flag4j/tensor/TensorTransposeTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorTransposeTests.java index dc76f77e3..f1e3ee006 100644 --- a/src/test/java/org/flag4j/tensor/TensorTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Tensor; diff --git a/src/test/java/org/flag4j/tensor/TensorUnitaryOperationTests.java b/src/test/java/org/flag4j/arrays/dense/tensor/TensorUnitaryOperationTests.java similarity index 98% rename from src/test/java/org/flag4j/tensor/TensorUnitaryOperationTests.java rename to src/test/java/org/flag4j/arrays/dense/tensor/TensorUnitaryOperationTests.java index aff0ad3a1..9c3aaaac9 100644 --- a/src/test/java/org/flag4j/tensor/TensorUnitaryOperationTests.java +++ b/src/test/java/org/flag4j/arrays/dense/tensor/TensorUnitaryOperationTests.java @@ -1,4 +1,4 @@ -package org.flag4j.tensor; +package org.flag4j.arrays.dense.tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Tensor; diff --git a/src/test/java/org/flag4j/vector/VectorAddSubEqTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorAddSubEqTests.java similarity index 99% rename from src/test/java/org/flag4j/vector/VectorAddSubEqTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorAddSubEqTests.java index 0c111435c..e83a04e70 100644 --- a/src/test/java/org/flag4j/vector/VectorAddSubEqTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorAddSubEqTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.flag4j.arrays.sparse.CooVector; diff --git a/src/test/java/org/flag4j/vector/VectorAddSubTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorAddSubTests.java similarity index 99% rename from src/test/java/org/flag4j/vector/VectorAddSubTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorAddSubTests.java index 28333e6da..38c60c44c 100644 --- a/src/test/java/org/flag4j/vector/VectorAddSubTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorAddSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/vector/VectorAggregateTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorAggregateTests.java similarity index 98% rename from src/test/java/org/flag4j/vector/VectorAggregateTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorAggregateTests.java index 0fc4e60bb..a1eedd3f7 100644 --- a/src/test/java/org/flag4j/vector/VectorAggregateTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorAggregateTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/vector/VectorConstructorTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/vector/VectorConstructorTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorConstructorTests.java index 40a44c696..333605241 100644 --- a/src/test/java/org/flag4j/vector/VectorConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/vector/VectorConversionTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorConversionTests.java similarity index 98% rename from src/test/java/org/flag4j/vector/VectorConversionTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorConversionTests.java index 516478b08..7919bbe1d 100644 --- a/src/test/java/org/flag4j/vector/VectorConversionTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorConversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/vector/VectorCopyTransposeTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorCopyTransposeTests.java similarity index 97% rename from src/test/java/org/flag4j/vector/VectorCopyTransposeTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorCopyTransposeTests.java index e19639166..e779bfad1 100644 --- a/src/test/java/org/flag4j/vector/VectorCopyTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorCopyTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/vector/VectorCrossProductTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorCrossProductTests.java similarity index 97% rename from src/test/java/org/flag4j/vector/VectorCrossProductTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorCrossProductTests.java index 70d975f16..96b46e1b9 100644 --- a/src/test/java/org/flag4j/vector/VectorCrossProductTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorCrossProductTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/vector/VectorElemMultDivTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorElemMultDivTests.java similarity index 99% rename from src/test/java/org/flag4j/vector/VectorElemMultDivTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorElemMultDivTests.java index a1a9840b5..399b9a490 100644 --- a/src/test/java/org/flag4j/vector/VectorElemMultDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorElemMultDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/vector/VectorElementwiseTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorElementwiseTests.java similarity index 97% rename from src/test/java/org/flag4j/vector/VectorElementwiseTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorElementwiseTests.java index 1a1f0ce12..e144138a0 100644 --- a/src/test/java/org/flag4j/vector/VectorElementwiseTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorElementwiseTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/vector/VectorEqualsTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorEqualsTests.java similarity index 99% rename from src/test/java/org/flag4j/vector/VectorEqualsTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorEqualsTests.java index fb571fd73..8ad855662 100644 --- a/src/test/java/org/flag4j/vector/VectorEqualsTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorEqualsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/vector/VectorInnerProductTest.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorInnerProductTest.java similarity index 96% rename from src/test/java/org/flag4j/vector/VectorInnerProductTest.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorInnerProductTest.java index de32e4c64..b481f71ad 100644 --- a/src/test/java/org/flag4j/vector/VectorInnerProductTest.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorInnerProductTest.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; @@ -90,7 +90,7 @@ void normalizeTestCase() { Vector exp; // ----------------------- Sub-case 1 ----------------------- - expEntries = new double[]{0.0046451435284722955, 0.026012803759444855, -0.043455317708858326, 0.9987058586215436}; + expEntries = new double[]{0.0046451435284722955, 0.026012803759444852, -0.043455317708858326, 0.9987058586215435}; exp = new Vector(expEntries); assertEquals(exp, a.normalize()); diff --git a/src/test/java/org/flag4j/vector/VectorNormTest.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorNormTest.java similarity index 97% rename from src/test/java/org/flag4j/vector/VectorNormTest.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorNormTest.java index 1d21380ea..4759d6373 100644 --- a/src/test/java/org/flag4j/vector/VectorNormTest.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorNormTest.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.flag4j.linalg.VectorNorms; diff --git a/src/test/java/org/flag4j/vector/VectorOuterProductTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorOuterProductTests.java similarity index 99% rename from src/test/java/org/flag4j/vector/VectorOuterProductTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorOuterProductTests.java index 0faeb8ad5..42ac18da0 100644 --- a/src/test/java/org/flag4j/vector/VectorOuterProductTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorOuterProductTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/vector/VectorPerpParallelTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorPerpParallelTests.java similarity index 98% rename from src/test/java/org/flag4j/vector/VectorPerpParallelTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorPerpParallelTests.java index eda3f694c..a0feac8a7 100644 --- a/src/test/java/org/flag4j/vector/VectorPerpParallelTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorPerpParallelTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/vector/VectorRepeatTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorRepeatTests.java similarity index 98% rename from src/test/java/org/flag4j/vector/VectorRepeatTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorRepeatTests.java index ad6129db8..ffaf6da70 100644 --- a/src/test/java/org/flag4j/vector/VectorRepeatTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorRepeatTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.dense.Vector; diff --git a/src/test/java/org/flag4j/vector/VectorScaleMultDivTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorScaleMultDivTests.java similarity index 98% rename from src/test/java/org/flag4j/vector/VectorScaleMultDivTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorScaleMultDivTests.java index 3af06b4ec..80faaf315 100644 --- a/src/test/java/org/flag4j/vector/VectorScaleMultDivTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorScaleMultDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/vector/VectorSetTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorSetTests.java similarity index 93% rename from src/test/java/org/flag4j/vector/VectorSetTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorSetTests.java index 8b21b64b6..4a3408ed9 100644 --- a/src/test/java/org/flag4j/vector/VectorSetTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/vector/VectorShapeTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorShapeTests.java similarity index 99% rename from src/test/java/org/flag4j/vector/VectorShapeTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorShapeTests.java index 2817d1349..356df85df 100644 --- a/src/test/java/org/flag4j/vector/VectorShapeTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorShapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/vector/VectorStackJoinTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorStackJoinTests.java similarity index 98% rename from src/test/java/org/flag4j/vector/VectorStackJoinTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorStackJoinTests.java index 725d779ce..7107af926 100644 --- a/src/test/java/org/flag4j/vector/VectorStackJoinTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorStackJoinTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.dense.Vector; diff --git a/src/test/java/org/flag4j/vector/VectorToStringTest.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorToStringTest.java similarity index 96% rename from src/test/java/org/flag4j/vector/VectorToStringTest.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorToStringTest.java index d20fa3021..abea945e6 100644 --- a/src/test/java/org/flag4j/vector/VectorToStringTest.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorToStringTest.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; diff --git a/src/test/java/org/flag4j/vector/VectorZeroOnesTests.java b/src/test/java/org/flag4j/arrays/dense/vector/VectorZeroOnesTests.java similarity index 97% rename from src/test/java/org/flag4j/vector/VectorZeroOnesTests.java rename to src/test/java/org/flag4j/arrays/dense/vector/VectorZeroOnesTests.java index 888475d5f..8df20e3c3 100644 --- a/src/test/java/org/flag4j/vector/VectorZeroOnesTests.java +++ b/src/test/java/org/flag4j/arrays/dense/vector/VectorZeroOnesTests.java @@ -1,4 +1,4 @@ -package org.flag4j.vector; +package org.flag4j.arrays.dense.vector; import org.flag4j.arrays.dense.Vector; import org.junit.jupiter.api.Test; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAddSubTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAddSubTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAddSubTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAddSubTests.java index d9c03f2c9..b937891cb 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAddSubTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAddSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAugmentTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAugmentTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAugmentTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAugmentTests.java index e6e069f0a..81be10dd4 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAugmentTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAugmentTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAugmentVectorTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAugmentVectorTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAugmentVectorTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAugmentVectorTests.java index 52bd46b81..7a3d96d95 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixAugmentVectorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixAugmentVectorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixDirectSumTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixDirectSumTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixDirectSumTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixDirectSumTests.java index 0e696b7fa..64860b093 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixDirectSumTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixDirectSumTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixElemDivTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixElemDivTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixElemDivTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixElemDivTests.java index df3f0a1d4..243afdcef 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixElemMultTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixElemMultTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixElemMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixElemMultTests.java index e92c9f022..127057514 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixGetRowColTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixGetRowColTests.java similarity index 97% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixGetRowColTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixGetRowColTests.java index 08b81f040..7572f103c 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixGetRowColTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixGetRowColTests.java @@ -1,10 +1,9 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooCMatrix; import org.flag4j.arrays.sparse.CooCVector; -import org.flag4j.linalg.ops.sparse.coo.field_ops.CooFieldMatrixGetSet; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -114,7 +113,7 @@ void getRowSliceTest() { expIndices = new int[]{}; exp = new CooCVector(expShape.get(0), expEntries, expIndices); - assertEquals(exp, CooFieldMatrixGetSet.getRow(a, 2, 1, 3)); + assertEquals(exp, a.getRow(2, 1, 3)); // --------------------- Sub-case 2 --------------------- aShape = new Shape(23, 11); @@ -128,7 +127,7 @@ void getRowSliceTest() { expIndices = new int[]{}; exp = new CooCVector(expShape.get(0), expEntries, expIndices); - assertEquals(exp, CooFieldMatrixGetSet.getRow(a,18, 0, 7)); + assertEquals(exp, a.getRow(18, 0, 7)); // --------------------- Sub-case 3 --------------------- aShape = new Shape(1000, 5); @@ -142,7 +141,7 @@ void getRowSliceTest() { expIndices = new int[]{}; exp = new CooCVector(expShape.get(0), expEntries, expIndices); - assertEquals(exp, CooFieldMatrixGetSet.getRow(a,0, 1, 4)); + assertEquals(exp, a.getRow(0, 1, 4)); // --------------------- Sub-case 4 --------------------- aShape = new Shape(3, 5); @@ -152,7 +151,7 @@ void getRowSliceTest() { a = new CooCMatrix(aShape, aEntries, aRowIndices, aColIndices); CooCMatrix final0a = a; - assertThrows(Exception.class, ()->CooFieldMatrixGetSet.getRow(final0a,-1, 1, 3)); + assertThrows(Exception.class, ()->final0a.getRow(-1, 1, 3)); // --------------------- Sub-case 5 --------------------- aShape = new Shape(3, 5); @@ -162,7 +161,7 @@ void getRowSliceTest() { a = new CooCMatrix(aShape, aEntries, aRowIndices, aColIndices); CooCMatrix final1a = a; - assertThrows(Exception.class, ()->CooFieldMatrixGetSet.getRow(final1a,3, 1, 3)); + assertThrows(Exception.class, ()->final1a.getRow(3, 1, 3)); // --------------------- Sub-case 6 --------------------- aShape = new Shape(3, 5); @@ -172,7 +171,7 @@ void getRowSliceTest() { a = new CooCMatrix(aShape, aEntries, aRowIndices, aColIndices); CooCMatrix final2a = a; - assertThrows(Exception.class, ()->CooFieldMatrixGetSet.getRow(final2a,2, -1, 3)); + assertThrows(Exception.class, ()->final2a.getRow(2, -1, 3)); // --------------------- Sub-case 7 --------------------- aShape = new Shape(3, 5); @@ -182,7 +181,7 @@ void getRowSliceTest() { a = new CooCMatrix(aShape, aEntries, aRowIndices, aColIndices); CooCMatrix final3a = a; - assertThrows(Exception.class, ()->CooFieldMatrixGetSet.getRow(final3a,2, 1, 6)); + assertThrows(Exception.class, ()->final3a.getRow(2, 1, 6)); } diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixGetSliceTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixGetSliceTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixGetSliceTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixGetSliceTests.java index 472ff7088..721a282e7 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixGetSliceTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixGetSliceTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixIsCloseToITests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixIsCloseToITests.java new file mode 100644 index 000000000..703407802 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixIsCloseToITests.java @@ -0,0 +1,91 @@ +package org.flag4j.arrays.sparse.complex_coo_matrix; + + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooCMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class CooCMatrixIsCloseToITests { + + @Test + void testCooCMatrixIsCloseToI() { + Shape aShape; + int[] aRowIndices, aColIndices; + Complex128[] aData; + CooCMatrix a; + + // ---------------------- Sub-case 1 ---------------------- + aShape = new Shape(50, 12); + aRowIndices = new int[]{0, 5, 14, 23, 49}; + aColIndices = new int[]{1, 3, 3, 1, 2}; + aData = new Complex128[]{new Complex128(1, 234), new Complex128(3, 456), new Complex128(5, 789), + new Complex128(7, 89), new Complex128(9, 99)}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isCloseToI()); + + // ---------------------- Sub-case 2 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 3, 4, 5, 12, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new Complex128[]{Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isCloseToI()); + + // ---------------------- Sub-case 3 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new Complex128[]{Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + new Complex128(5, -1)}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isCloseToI()); + + // ---------------------- Sub-case 4 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new Complex128[]{Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isCloseToI()); + + // ---------------------- Sub-case 5 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new Complex128[]{Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, new Complex128(1.000000000000001, -15.1e-56), Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, + Complex128.ONE, Complex128.ONE, Complex128.ONE, new Complex128(0.99999999, 1.4e-14), Complex128.ONE, Complex128.ONE, + Complex128.ONE}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isCloseToI()); + + // ---------------------- Sub-case 6 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 0, 3, 4, 5, 6, 7, 8, 9, 10, 13, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new Complex128[]{Complex128.ONE, Complex128.ONE, Complex128.ONE, new Complex128(1.24e-16, 1.2e-12), + Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + new Complex128(1.000000000000001, -15.1e-56), Complex128.ONE, Complex128.ONE, + new Complex128(1.0e-18, 25.6e-21), Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, Complex128.ONE, + new Complex128(0.99999999, 1.4e-14), Complex128.ONE, Complex128.ONE, Complex128.ONE}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isCloseToI()); + } +} diff --git a/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixIsSymmTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixIsSymmTests.java new file mode 100644 index 000000000..edc388513 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixIsSymmTests.java @@ -0,0 +1,75 @@ +package org.flag4j.arrays.sparse.complex_coo_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooCMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class CooCMatrixIsSymmTests { + + @Test + void isSymmetricTests() { + Shape aShape; + int[] aRowIndices, aColIndices; + Complex128[] aData; + CooCMatrix a; + + // -------------------- Sub-case 1 -------------------- + aShape = new Shape(51, 51); + aData = new Complex128[]{new Complex128(0.711, 0.3), new Complex128(0.875, 0.657), + new Complex128(0.057, 0.164), new Complex128(0.207, 0.887), + new Complex128(0.885, 0.926), new Complex128(0.939, 0.405), + new Complex128(0.869, 0.506), new Complex128(0.562, 0.55), + new Complex128(0.94, 0.756), new Complex128(0.193, 0.037), + new Complex128(0.727, 0.541), new Complex128(0.938, 0.119)}; + aRowIndices = new int[]{0, 1, 3, 4, 17, 22, 23, 29, 38, 41, 50, 50}; + aColIndices = new int[]{30, 25, 49, 11, 22, 37, 50, 15, 25, 4, 21, 32}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isSymmetric()); + assertFalse(a.isHermitian()); + + // -------------------- Sub-case 2 -------------------- + aShape = new Shape(51, 51); + aData = new Complex128[]{new Complex128(0.315, 0.311), new Complex128(0.155, 0.236), new Complex128(0.345, 0.92), new Complex128(0.155, 0.236), new Complex128(0.347, 0.256), new Complex128(0.315, 0.311), new Complex128(0.345, 0.92), new Complex128(0.347, 0.256), new Complex128(0.119, 0.913)}; + aRowIndices = new int[]{0, 6, 10, 14, 28, 29, 40, 42, 49}; + aColIndices = new int[]{29, 14, 40, 6, 42, 0, 10, 28, 49}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isSymmetric()); + assertFalse(a.isHermitian()); + + // -------------------- Sub-case 3 -------------------- + aShape = new Shape(51, 51); + aData = new Complex128[]{new Complex128(0.38, -0.82), new Complex128(0.38, 0.82), new Complex128(0.456, -0.305), new Complex128(0.843, -0.768), new Complex128(0.839, -0.306), new Complex128(0.533, -0.878), new Complex128(0.718, -0.497), new Complex128(0.533, 0.878), new Complex128(0.456, 0.305), new Complex128(0.718, 0.497), new Complex128(0.839, 0.306), new Complex128(0.843, 0.768)}; + aRowIndices = new int[]{1, 8, 9, 17, 26, 29, 32, 34, 36, 38, 40, 44}; + aColIndices = new int[]{8, 1, 36, 44, 40, 34, 38, 29, 9, 32, 26, 17}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isSymmetric()); + assertTrue(a.isHermitian()); + + // -------------------- Sub-case 4 -------------------- + aShape = new Shape(12, 12); + aData = new Complex128[]{}; + aRowIndices = new int[]{}; + aColIndices = new int[]{}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isSymmetric()); + assertTrue(a.isHermitian()); + + // -------------------- Sub-case 5 -------------------- + aShape = new Shape(12, 17); + aData = new Complex128[]{new Complex128(0.305, 0.406), new Complex128(0.599, 0.739), new Complex128(0.02, 0.945), new Complex128(0.549, 0.64), new Complex128(0.842, 0.842), new Complex128(0.48, 0.124), new Complex128(0.166, 0.889)}; + aRowIndices = new int[]{0, 0, 0, 0, 3, 3, 8}; + aColIndices = new int[]{2, 6, 11, 12, 5, 8, 0}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isSymmetric()); + assertFalse(a.isHermitian()); + } +} diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixMultTransposeTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixMultTransposeTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixMultTransposeTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixMultTransposeTests.java index 9df5ebfe8..a00774329 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixMultTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixMultTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -47,6 +47,7 @@ void complexSparseMultTransposeTest() { }; exp = new CMatrix(expEntries); + assertEquals(exp, a.multTranspose(b)); // --------------------- Sub-case 2 --------------------- diff --git a/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixNormTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixNormTests.java new file mode 100644 index 000000000..f04db0eb1 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixNormTests.java @@ -0,0 +1,44 @@ +package org.flag4j.arrays.sparse.complex_coo_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooCMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class CooCMatrixNormTests { + + + @Test + void cooCMatrixNormTests() { + Shape aShape; + int[] aRowIndices; + int[] aColIndices; + Complex128[] aData; + CooCMatrix a; + double exp; + + // --------------------- Sub-case 1 --------------------- + aShape = new Shape(12, 12); + aData = new Complex128[]{new Complex128(0.808, 0.929), new Complex128(0.231, 0.157), new Complex128(0.509, 0.895), new Complex128(0.25, 0.055), new Complex128(0.602, 0.428), new Complex128(0.39, 0.012), new Complex128(0.609, 0.324), new Complex128(0.248, 0.281), new Complex128(0.25, 0.694), new Complex128(0.176, 0.411), new Complex128(0.938, 0.241), new Complex128(0.831, 0.321), new Complex128(0.541, 0.092), new Complex128(0.406, 0.908)}; + aRowIndices = new int[]{1, 1, 2, 2, 4, 5, 5, 6, 6, 7, 7, 7, 9, 10}; + aColIndices = new int[]{1, 5, 1, 9, 0, 6, 7, 1, 2, 2, 9, 11, 7, 8}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + exp = 2.7927951947824603; + + assertEquals(exp, a.norm(), 1e-12); + + // --------------------- Sub-case 2 --------------------- + aShape = new Shape(300, 300); + aData = new Complex128[]{new Complex128(0.754, 0.419), new Complex128(0.936, 0.196), new Complex128(0.661, 0.887), new Complex128(0.476, 0.081), new Complex128(0.129, 0.368), new Complex128(0.029, 0.25), new Complex128(0.482, 0.922), new Complex128(0.581, 0.33), new Complex128(0.627, 0.832)}; + aRowIndices = new int[]{5, 6, 46, 81, 98, 111, 186, 271, 294}; + aColIndices = new int[]{69, 212, 265, 42, 246, 186, 95, 39, 109}; + a = new CooCMatrix(aShape, aData, aRowIndices, aColIndices); + + exp = 2.438246090943242; + + assertEquals(exp, a.norm(), 1e-12); + } +} diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixSetColTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixSetColTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixSetColTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixSetColTests.java index 4251c6b24..4e11c4ac2 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixSetColTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixSetColTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixSetTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixSetTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixSetTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixSetTests.java index 803a57c12..1e2e3e796 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixSetTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixStackTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixStackTests.java similarity index 98% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixStackTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixStackTests.java index c3a0bfadf..89231ebe5 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixStackTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixStackTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixToStringTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixToStringTests.java similarity index 99% rename from src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixToStringTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixToStringTests.java index 4837fa653..ae9c38dbc 100644 --- a/src/test/java/org/flag4j/complex_sparse_matrix/CooCMatrixToStringTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_coo_matrix/CooCMatrixToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.complex_sparse_matrix; +package org.flag4j.arrays.sparse.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexCsrDenseMatMultTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexCsrDenseMatMultTests.java new file mode 100644 index 000000000..cebdbb3b2 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexCsrDenseMatMultTests.java @@ -0,0 +1,161 @@ +package org.flag4j.arrays.sparse.complex_csr_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.dense.CMatrix; +import org.flag4j.arrays.dense.Matrix; +import org.flag4j.arrays.sparse.CsrCMatrix; +import org.flag4j.util.exceptions.LinearAlgebraException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class ComplexCsrDenseMatMultTests { + static CsrCMatrix A; + static CMatrix aDense; + static Complex128[][] aEntries; + static Matrix Breal; + static CMatrix B; + static double[][] bRealEntries; + static Complex128[][] bComplexEntries; + static CMatrix exp; + + private static void buildReal(boolean... args) { + aDense = new CMatrix(aEntries); + A = aDense.toCsr(); + Breal = new Matrix(bRealEntries); + if(args.length != 1 || args[0]) exp = aDense.mult(Breal); + } + + private static void buildComplex(boolean... args) { + aDense = new CMatrix(aEntries); + A = aDense.toCsr(); + B = new CMatrix(bComplexEntries); + if(args.length != 1 || args[0]) exp = aDense.mult(B); + } + + + @Test + void multRealDenseTests() { + // ---------------------- Sub-case 1 ---------------------- + aEntries = new Complex128[][]{ + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(80.1, 2.5)}, + {new Complex128(0), new Complex128(1.41, -92.2), new Complex128(0), new Complex128(0, 15.5), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(-9.25, 23.5), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(-999.1155, 2.25), new Complex128(-1, 1)}}; + bRealEntries = new double[][]{ + {0.72773, 0.90836}, + {0.02926, 0.3265}, + {0.23691, 0.77541}, + {0.6462, 0.36597}, + {0.18312, 0.77178}, + {0.40715, 0.35642}}; + buildReal(); + + assertEquals(exp, A.mult(Breal)); + + // ---------------------- Sub-case 2 ---------------------- + aEntries = new Complex128[][]{ + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(-77.3, -15122.1), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0, 803.2), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(-9.345, 58.1), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(1.45, -23), new Complex128(0)}, + {new Complex128(345), new Complex128(2.4, 5.61), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(4.45, -67.2), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(1)}}; + bRealEntries = new double[][]{ + {0.72773, 0.90836}, + {0.02926, 0.3265}, + {0.23691, 0.77541}, + {0.6462, 0.36597}}; + buildReal(); + + assertEquals(exp, A.mult(Breal)); + + // ---------------------- Sub-case 3 ---------------------- + aEntries = new Complex128[][]{ + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(-77.3, -15122.1), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0, 803.2), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(-9.345, 58.1), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(1.45, -23), new Complex128(0)}, + {new Complex128(345), new Complex128(2.4, 5.61), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(4.45, -67.2), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(1)}}; + bRealEntries = new double[][]{ + {0.72773, 0.90836}, + {0.02926, 0.3265}, + {0.23691, 0.77541}}; + buildReal(false); + + assertThrows(LinearAlgebraException.class, ()->A.mult(Breal)); + } + + + @Test + void multComplexDenseTests() { + // ---------------------- Sub-case 1 ---------------------- + aEntries = new Complex128[][]{ + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(80.1, 2.5)}, + {new Complex128(0), new Complex128(1.41, -92.2), new Complex128(0), new Complex128(0, 15.5), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(-9.25, 23.5), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(-999.1155, 2.25), new Complex128(-1, 1)}}; + bComplexEntries = new Complex128[][]{ + {new Complex128(0.60886, 0.33378), new Complex128(0.00204, 0.66152)}, + {new Complex128(0.11395, 0.22798), new Complex128(0.85626, 0.48514)}, + {new Complex128(0.63642, 0.52434), new Complex128(0.95994, 0.9354)}, + {new Complex128(0.19401, 0.93407), new Complex128(0.64822, 0.24427)}, + {new Complex128(0.49749, 0.11432), new Complex128(0.06738, 0.73179)}, + {new Complex128(0.08942, 0.10066), new Complex128(0.02026, 0.06551)}}; + buildComplex(); + assertEquals(exp, A.mult(B)); + + // ---------------------- Sub-case 2 ---------------------- + aEntries = new Complex128[][]{ + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(-77.3, -15122.1), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0, 803.2), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(-9.345, 58.1), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(1.45, -23), new Complex128(0)}, + {new Complex128(345), new Complex128(2.4, 5.61), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(4.45, -67.2), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(1)}}; + bComplexEntries = new Complex128[][]{ + {new Complex128(0.66751, 0.11856), new Complex128(0.98271, 0.49906)}, + {new Complex128(0.14152, 0.98128), new Complex128(0.30904, 0.21053)}, + {new Complex128(0.28185, 0.28402), new Complex128(0.76892, 0.97375)}, + {new Complex128(0.44435, 0.06128), new Complex128(0.57068, 0.89705)}}; + buildComplex(); + assertEquals(exp, A.mult(B)); + + // ---------------------- Sub-case 3 ---------------------- + aEntries = new Complex128[][]{ + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(-77.3, -15122.1), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0, 803.2), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(-9.345, 58.1), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(1.45, -23), new Complex128(0)}, + {new Complex128(345), new Complex128(2.4, 5.61), new Complex128(0), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(4.45, -67.2), new Complex128(0)}, + {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(1)}}; + bComplexEntries = new Complex128[][]{ + {new Complex128(0.57033, 0.74092), new Complex128(0.62504, 0.25253)}, + {new Complex128(0.69264, 0.37406), new Complex128(0.29895, 0.17085)}, + {new Complex128(0.95162, 0.22682), new Complex128(0.30524, 0.91462)}}; + buildComplex(false); + assertThrows(LinearAlgebraException.class, ()->A.mult(B)); + } +} diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrMatMultTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexCsrMatMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrMatMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexCsrMatMultTests.java index bf4b6b8d9..7afb540d4 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrMatMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexCsrMatMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexRealCsrCsrMatMultTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexRealCsrCsrMatMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexRealCsrCsrMatMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexRealCsrCsrMatMultTests.java index f292501a8..c65adb347 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexRealCsrCsrMatMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/ComplexRealCsrCsrMatMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixAddSubTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixAddSubTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixAddSubTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixAddSubTests.java index 2121d5057..552433d9e 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixAddSubTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixAddSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixEqualsTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixEqualsTests.java similarity index 97% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixEqualsTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixEqualsTests.java index 7bebd1dc5..b93f61cf3 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixEqualsTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixEqualsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -88,6 +88,9 @@ Complex128.ZERO, new Complex128(235.1, 94.2), new Complex128(3.12, 4), A = new CooCMatrix(aShape, aNnz, aIndices[0], aIndices[1]).toCsr(); B = new CooCMatrix(bShape, bNnz, bIndices[0], bIndices[1]).toCsr(); + CooCMatrix ta = new CooCMatrix(aShape, aNnz, aIndices[0], aIndices[1]); + CooCMatrix tb = new CooCMatrix(bShape, bNnz, bIndices[0], bIndices[1]); + assertEquals(A, B); // ---------------------- Sub-case 4 ---------------------- diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixGetRowColTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetRowColTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixGetRowColTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetRowColTests.java index 4a9d035f2..06491709b 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixGetRowColTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetRowColTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixGetSetTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetSetTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixGetSetTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetSetTests.java index 199a9bf9a..42848b422 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixGetSetTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetSliceTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetSliceTests.java new file mode 100644 index 000000000..0b01035cc --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixGetSliceTests.java @@ -0,0 +1,116 @@ +package org.flag4j.arrays.sparse.complex_csr_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CsrCMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CsrCMatrixGetSliceTests { + + @Test + void getSliceTests() { + int rowStart, rowEnd, colStart, colEnd; + int[] aRowPointers, aColIndices, expRowPointers, expColIndices; + Complex128[] aData, expData; + Shape aShape, expShape; + CsrCMatrix a, exp; + + // -------------------- sub-case 1 -------------------- + rowStart = 0; + rowEnd = 15; + colStart = 0; + colEnd = 156; + + aShape = new Shape(162, 525); + aData = new Complex128[]{new Complex128(0.78035, 0.12308), new Complex128(0.69964, 0.39359), new Complex128(0.71946, 0.23139), new Complex128(0.03003, 0.24849), new Complex128(0.75854, 0.87197), new Complex128(0.6154, 0.79933), new Complex128(0.22866, 0.81778), new Complex128(0.62108, 0.02225), new Complex128(0.12051, 0.77826)}; + aRowPointers = new int[]{0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9}; + aColIndices = new int[]{377, 323, 104, 450, 260, 373, 314, 507, 383}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + expShape = new Shape(15, 156); + expData = new Complex128[]{}; + expRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + expColIndices = new int[]{}; + exp = new CsrCMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.getSlice(rowStart, rowEnd, colStart, colEnd)); + + // -------------------- sub-case 2 -------------------- + rowStart = 15; + rowEnd = 25; + colStart = 6; + colEnd = 24; + + aShape = new Shape(25, 35); + aData = new Complex128[]{new Complex128(0.58041, 0.13741), new Complex128(0.31126, 0.11722), new Complex128(0.24317, 0.64169), new Complex128(0.37413, 0.97784), new Complex128(0.00337, 0.35128), new Complex128(0.76686, 0.96587), new Complex128(0.1325, 0.53405), new Complex128(0.17819, 0.94705), new Complex128(0.20021, 0.52085)}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 5, 7, 8, 8, 8, 9, 9, 9, 9}; + aColIndices = new int[]{24, 34, 10, 21, 27, 11, 13, 29, 15}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + expShape = new Shape(10, 18); + expData = new Complex128[]{new Complex128(0.24317, 0.64169), new Complex128(0.37413, 0.97784), new Complex128(0.76686, 0.96587), new Complex128(0.1325, 0.53405), new Complex128(0.20021, 0.52085)}; + expRowPointers = new int[]{0, 0, 2, 4, 4, 4, 4, 5, 5, 5, 5}; + expColIndices = new int[]{4, 15, 5, 7, 9}; + exp = new CsrCMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.getSlice(rowStart, rowEnd, colStart, colEnd)); + + // -------------------- sub-case 3 -------------------- + rowStart = 8; + rowEnd = 9; + colStart = 18; + colEnd = 21; + + aShape = new Shape(33, 21); + aData = new Complex128[]{new Complex128(0.38171, 0.9855), new Complex128(0.97057, 0.96037), new Complex128(0.87573, 0.34079), new Complex128(0.37559, 0.59789), new Complex128(0.87787, 0.76645), new Complex128(0.57899, 0.49322), new Complex128(0.31817, 0.39966)}; + aRowPointers = new int[]{0, 0, 0, 0, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 7, 7, 7, 7}; + aColIndices = new int[]{9, 13, 16, 8, 11, 2, 1}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + expShape = new Shape(1, 3); + expData = new Complex128[]{}; + expRowPointers = new int[]{0, 0}; + expColIndices = new int[]{}; + exp = new CsrCMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.getSlice(rowStart, rowEnd, colStart, colEnd)); + + // -------------------- sub-case 4 -------------------- + rowStart = 5; + rowEnd = 22; + colStart = 0; + colEnd = 55; + + aShape = new Shape(55, 55); + aData = new Complex128[]{}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + aColIndices = new int[]{}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + expShape = new Shape(17, 55); + expData = new Complex128[]{}; + expRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + expColIndices = new int[]{}; + exp = new CsrCMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.getSlice(rowStart, rowEnd, colStart, colEnd)); + + // -------------------- sub-case 4 -------------------- + aShape = new Shape(55, 55); + aData = new Complex128[]{}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + aColIndices = new int[]{}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + CsrCMatrix finalA = a; + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(-1, 2, 4, 5)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(1, 2, -4, 5)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(4, 2, 4, 5)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(1, 2, 4, 1)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(1, 56, 0, 5)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(1, 2, 0, 514)); + } +} diff --git a/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixIsSymmTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixIsSymmTests.java new file mode 100644 index 000000000..f866f9675 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixIsSymmTests.java @@ -0,0 +1,70 @@ +package org.flag4j.arrays.sparse.complex_csr_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CsrCMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class CsrCMatrixIsSymmTests { + + @Test + void isSymmetricTests() { + Shape aShape; + int[] aRowPointers, aColIndices; + Complex128[] aData; + CsrCMatrix a; + + // -------------------- Sub-case 1 -------------------- + aShape = new Shape(51, 51); + aData = new Complex128[]{new Complex128(0.711, 0.3), new Complex128(0.875, 0.657), new Complex128(0.057, 0.164), new Complex128(0.207, 0.887), new Complex128(0.885, 0.926), new Complex128(0.939, 0.405), new Complex128(0.869, 0.506), new Complex128(0.562, 0.55), new Complex128(0.94, 0.756), new Complex128(0.193, 0.037), new Complex128(0.727, 0.541), new Complex128(0.938, 0.119)}; + aRowPointers = new int[]{0, 1, 2, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 12}; + aColIndices = new int[]{30, 25, 49, 11, 22, 37, 50, 15, 25, 4, 21, 32}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + assertFalse(a.isSymmetric()); + assertFalse(a.isHermitian()); + + // -------------------- Sub-case 2 -------------------- + aShape = new Shape(51, 51); + aData = new Complex128[]{new Complex128(0.315, 0.311), new Complex128(0.155, 0.236), new Complex128(0.345, 0.92), new Complex128(0.155, 0.236), new Complex128(0.347, 0.256), new Complex128(0.315, 0.311), new Complex128(0.345, 0.92), new Complex128(0.347, 0.256), new Complex128(0.119, 0.913)}; + aRowPointers = new int[]{0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 8, 8, 8, 8, 8, 8, 8, 9, 9}; + aColIndices = new int[]{29, 14, 40, 6, 42, 0, 10, 28, 49}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + assertTrue(a.isSymmetric()); + assertFalse(a.isHermitian()); + + // -------------------- Sub-case 3 -------------------- + aShape = new Shape(51, 51); + aData = new Complex128[]{new Complex128(0.38, -0.82), new Complex128(0.38, 0.82), new Complex128(0.456, -0.305), new Complex128(0.843, -0.768), new Complex128(0.839, -0.306), new Complex128(0.533, -0.878), new Complex128(0.718, -0.497), new Complex128(0.533, 0.878), new Complex128(0.456, 0.305), new Complex128(0.718, 0.497), new Complex128(0.839, 0.306), new Complex128(0.843, 0.768)}; + aRowPointers = new int[]{0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12}; + aColIndices = new int[]{8, 1, 36, 44, 40, 34, 38, 29, 9, 32, 26, 17}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + assertFalse(a.isSymmetric()); + assertTrue(a.isHermitian()); + + // -------------------- Sub-case 4 -------------------- + aShape = new Shape(12, 12); + aData = new Complex128[]{}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + aColIndices = new int[]{}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + assertTrue(a.isSymmetric()); + assertTrue(a.isHermitian()); + + // -------------------- Sub-case 5 -------------------- + aShape = new Shape(12, 17); + aData = new Complex128[]{new Complex128(0.305, 0.406), new Complex128(0.599, 0.739), new Complex128(0.02, 0.945), new Complex128(0.549, 0.64), new Complex128(0.842, 0.842), new Complex128(0.48, 0.124), new Complex128(0.166, 0.889)}; + aRowPointers = new int[]{0, 4, 4, 4, 6, 6, 6, 6, 6, 7, 7, 7, 7}; + aColIndices = new int[]{2, 6, 11, 12, 5, 8, 0}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + assertFalse(a.isSymmetric()); + assertFalse(a.isHermitian()); + } +} diff --git a/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixNormTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixNormTests.java new file mode 100644 index 000000000..e6fcb730e --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixNormTests.java @@ -0,0 +1,88 @@ +package org.flag4j.arrays.sparse.complex_csr_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CsrCMatrix; +import org.flag4j.linalg.MatrixNorms; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CsrCMatrixNormTests { + + @Test + void csrLpqNorms() { + Shape aShape; + Complex128[] aData; + int[] aRowPointers, aColIndices; + CsrCMatrix a; + double exp, p, q; + + // ----------------------- Sub-case 1 ----------------------- + aShape = new Shape(32, 32); + aData = new Complex128[]{new Complex128(0.02477, 0.76398), new Complex128(0.40479, 0.85967), new Complex128(0.4548, 0.38816), new Complex128(0.91506, 0.39483), new Complex128(0.57948, 0.66513), new Complex128(0.66353, 0.70929), new Complex128(0.43548, 0.52384), new Complex128(0.34137, 0.33486), new Complex128(0.24652, 0.1701), new Complex128(0.95208, 0.50698)}; + aRowPointers = new int[]{0, 0, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 5, 5, 6, 6, 6, 7, 9, 9, 9, 10, 10}; + aColIndices = new int[]{2, 20, 11, 15, 22, 15, 21, 17, 24, 28}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 1; + q = 1; + exp = 7.700099282600347; + + assertEquals(exp, MatrixNorms.norm(a, p, q)); + + // ----------------------- Sub-case 2 ----------------------- + aShape = new Shape(32, 32); + aData = new Complex128[]{new Complex128(0.49025, 0.29208), new Complex128(0.17466, 0.43675), new Complex128(0.51197, 0.38465), new Complex128(0.4191, 0.46123), new Complex128(0.43736, 0.70804), new Complex128(0.49982, 0.30753), new Complex128(0.11769, 0.52149), new Complex128(0.05334, 0.41568), new Complex128(0.07662, 0.35198), new Complex128(0.63588, 0.99672)}; + aRowPointers = new int[]{0, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 4, 4, 5, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10}; + aColIndices = new int[]{13, 31, 7, 24, 16, 9, 12, 2, 12, 13}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 1; + q = 2; + exp = 2.4710068494660162; + + assertEquals(exp, MatrixNorms.norm(a, p, q)); + + // ----------------------- Sub-case 3 ----------------------- + aShape = new Shape(32, 32); + aData = new Complex128[]{new Complex128(0.72585, 0.65), new Complex128(0.07972, 0.38903), new Complex128(0.20433, 0.35512), new Complex128(0.52162, 0.26212), new Complex128(0.97118, 0.80551), new Complex128(0.90915, 0.03623), new Complex128(0.81902, 0.84555), new Complex128(0.84593, 0.62298), new Complex128(0.876, 0.03699), new Complex128(0.97226, 0.77725)}; + aRowPointers = new int[]{0, 1, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 7, 7, 7, 7, 8, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10}; + aColIndices = new int[]{25, 31, 5, 31, 4, 7, 8, 26, 12, 9}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 2; + q = 1; + exp = 8.611013354258432; + assertEquals(exp, MatrixNorms.norm(a, p, q)); + + // ----------------------- Sub-case 4 ----------------------- + aShape = new Shape(32, 32); + aData = new Complex128[]{new Complex128(0.27792, 0.15589), new Complex128(0.66318, 0.34308), new Complex128(0.52266, 0.94122), new Complex128(0.03301, 0.23813), new Complex128(0.29464, 0.79951), new Complex128(0.89043, 0.34081), new Complex128(0.60507, 0.00875), new Complex128(0.86107, 0.62376), new Complex128(0.48382, 0.17035), new Complex128(0.74956, 0.71469)}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 3, 3, 4, 4, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 10}; + aColIndices = new int[]{24, 30, 5, 1, 0, 9, 4, 27, 30, 5}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 4.12; + q = 9.3; + exp = 1.2907615018426493; + + assertEquals(exp, MatrixNorms.norm(a, p, q)); + + // ----------------------- Sub-case 5 ----------------------- + aShape = new Shape(32, 32); + aData = new Complex128[]{new Complex128(0.86738, 0.01318), new Complex128(0.83588, 0.48586), new Complex128(0.02535, 0.51349), new Complex128(0.00979, 0.97632), new Complex128(0.02841, 0.76271), new Complex128(0.02338, 0.49899), new Complex128(0.84757, 0.03883), new Complex128(0.7871, 0.17426), new Complex128(0.24826, 0.88086), new Complex128(0.34889, 0.24449)}; + aRowPointers = new int[]{0, 0, 0, 1, 3, 3, 3, 3, 5, 6, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10}; + aColIndices = new int[]{0, 15, 31, 1, 25, 20, 10, 4, 4, 31}; + a = new CsrCMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 0; + q = 0; + + CsrCMatrix finalA = a; + double finalP = p; + double finalQ = q; + assertThrows(IllegalArgumentException.class, () -> MatrixNorms.norm(finalA, finalP, finalQ)); + } +} diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixRowColSwapTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixRowColSwapTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixRowColSwapTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixRowColSwapTests.java index a426870ae..4fb8ab8a6 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixRowColSwapTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixRowColSwapTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToDenseTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToDenseTests.java similarity index 97% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToDenseTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToDenseTests.java index 19c11de49..124006e88 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToDenseTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToDenseTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToStringTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToStringTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToStringTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToStringTests.java index 516620130..71a08028b 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToStringTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToVectorTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToVectorTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToVectorTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToVectorTests.java index 2fdae7c7b..14ebb96d6 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixToVectorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixToVectorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixTransposeTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixTransposeTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixTransposeTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixTransposeTests.java index 9d4add396..a6d4c5122 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixTriDiagTests.java b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixTriDiagTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixTriDiagTests.java rename to src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixTriDiagTests.java index 67ed12588..27b72d965 100644 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/CsrCMatrixTriDiagTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/complex_csr_matrix/CsrCMatrixTriDiagTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_complex_matrix; +package org.flag4j.arrays.sparse.complex_csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixAddSubTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAddSubTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixAddSubTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAddSubTests.java index 2dbc79e29..badbedaec 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixAddSubTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAddSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixAugmentTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAugmentTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixAugmentTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAugmentTests.java index d7951beb2..8853b70e4 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixAugmentTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAugmentTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixAugmentVectorTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAugmentVectorTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixAugmentVectorTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAugmentVectorTests.java index 64c14fd85..c368c2720 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixAugmentVectorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixAugmentVectorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixConstructorTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixConstructorTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixConstructorTests.java index 2c12cdfe8..f8941acb5 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixDirectSumTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixDirectSumTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixDirectSumTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixDirectSumTests.java index e888cd54e..1a88d767a 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixDirectSumTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixDirectSumTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixElemDivTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixElemDivTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixElemDivTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixElemDivTests.java index 3758a2804..32e78ae55 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixElemMultTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixElemMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixElemMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixElemMultTests.java index 16a9a6bce..b14e8e810 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixEqualsTest.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixEqualsTest.java similarity index 97% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixEqualsTest.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixEqualsTest.java index a06768136..a17d76bcb 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixEqualsTest.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixEqualsTest.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixGetRowColTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetRowColTests.java similarity index 95% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixGetRowColTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetRowColTests.java index 46681e8df..6f9d8e230 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixGetRowColTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetRowColTests.java @@ -1,9 +1,9 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; import org.flag4j.arrays.sparse.CooVector; -import org.flag4j.linalg.ops.sparse.coo.real.RealSparseMatrixGetSet; +import org.flag4j.linalg.ops.sparse.coo.real.RealCooMatrixGetSet; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -114,7 +114,7 @@ void getRowSliceTest() { expIndices = new int[]{}; exp = new CooVector(expShape.get(0), expEntries, expIndices); - assertEquals(exp, RealSparseMatrixGetSet.getRow(a, 2, 1, 3)); + assertEquals(exp, RealCooMatrixGetSet.getRow(a, 2, 1, 3)); // --------------------- Sub-case 2 --------------------- aShape = new Shape(23, 11); @@ -128,7 +128,7 @@ void getRowSliceTest() { expIndices = new int[]{}; exp = new CooVector(expShape.get(0), expEntries, expIndices); - assertEquals(exp, RealSparseMatrixGetSet.getRow(a, 18, 0, 7)); + assertEquals(exp, RealCooMatrixGetSet.getRow(a, 18, 0, 7)); // --------------------- Sub-case 3 --------------------- aShape = new Shape(1000, 5); @@ -142,7 +142,7 @@ void getRowSliceTest() { expIndices = new int[]{}; exp = new CooVector(expShape.get(0), expEntries, expIndices); - assertEquals(exp, RealSparseMatrixGetSet.getRow(a, 0, 1, 4)); + assertEquals(exp, RealCooMatrixGetSet.getRow(a, 0, 1, 4)); // --------------------- Sub-case 4 --------------------- aShape = new Shape(3, 5); @@ -152,7 +152,7 @@ void getRowSliceTest() { a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); CooMatrix final0a = a; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.getRow(final0a, -1, 1, 3)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.getRow(final0a, -1, 1, 3)); // --------------------- Sub-case 5 --------------------- aShape = new Shape(3, 5); @@ -162,7 +162,7 @@ void getRowSliceTest() { a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); CooMatrix final1a = a; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.getRow(final1a, 3, 1, 3)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.getRow(final1a, 3, 1, 3)); // --------------------- Sub-case 6 --------------------- aShape = new Shape(3, 5); @@ -172,7 +172,7 @@ void getRowSliceTest() { a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); CooMatrix final2a = a; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.getRow(final2a, 2, -1, 3)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.getRow(final2a, 2, -1, 3)); // --------------------- Sub-case 7 --------------------- aShape = new Shape(3, 5); @@ -182,7 +182,7 @@ void getRowSliceTest() { a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); CooMatrix final3a = a; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.getRow(final3a, 2, 1, 6)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.getRow(final3a, 2, 1, 6)); } diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixGetSliceTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetSliceTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixGetSliceTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetSliceTests.java index 1785117c3..51e007991 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixGetSliceTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetSliceTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixGetTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixGetTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetTests.java index dec2893bf..a336e42e1 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixGetTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixGetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixIsSymmTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixIsSymmTests.java new file mode 100644 index 000000000..d9b4bd063 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixIsSymmTests.java @@ -0,0 +1,83 @@ +package org.flag4j.arrays.sparse.coo_matrix; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class CooMatrixIsSymmTests { + + @Test + void isSymmetricTests() { + Shape aShape; + int[] aRowIndices, aColIndices; + double[] aData; + CooMatrix a; + + // -------------------- Sub-case 1 -------------------- + aShape = new Shape(51, 51); + aData = new double[]{0.711, 0.875, 0.057, 0.207, + 0.885, 0.939, 0.869, 0.562, + 0.94, 0.193, 0.727, 0.938}; + aRowIndices = new int[]{0, 1, 3, 4, 17, 22, 23, 29, 38, 41, 50, 50}; + aColIndices = new int[]{30, 25, 49, 11, 22, 37, 50, 15, 25, 4, 21, 32}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isSymmetric()); + assertFalse(a.isHermitian()); + + // -------------------- Sub-case 2 -------------------- + aShape = new Shape(51, 51); + aData = new double[]{0.315, 0.155, + 0.345, 0.155, + 0.347, 0.315, + 0.345, 0.347, + 0.119}; + aRowIndices = new int[]{0, 6, 10, 14, 28, 29, 40, 42, 49}; + aColIndices = new int[]{29, 14, 40, 6, 42, 0, 10, 28, 49}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isSymmetric()); + assertTrue(a.isHermitian()); + + // -------------------- Sub-case 3 -------------------- + aShape = new Shape(51, 51); + aData = new double[]{0.38, 0.38, + 0.456, 0.843, + 0.839, 0.533, + 0.718, 0.533, + 0.456, 0.718, + 0.839, 0.843}; + aRowIndices = new int[]{1, 8, 9, 17, 26, 29, 32, 34, 36, 38, 40, 44}; + aColIndices = new int[]{8, 1, 36, 44, 40, 34, 38, 29, 9, 32, 26, 17}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isSymmetric()); + assertTrue(a.isHermitian()); + + // -------------------- Sub-case 4 -------------------- + aShape = new Shape(12, 12); + aData = new double[]{}; + aRowIndices = new int[]{}; + aColIndices = new int[]{}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isSymmetric()); + assertTrue(a.isHermitian()); + + // -------------------- Sub-case 5 -------------------- + aShape = new Shape(12, 17); + aData = new double[]{0.305, 0.599, + 0.02, 0.549, + 0.842, 0.48, + 0.166}; + aRowIndices = new int[]{0, 0, 0, 0, 3, 3, 8}; + aColIndices = new int[]{2, 6, 11, 12, 5, 8, 0}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isSymmetric()); + assertFalse(a.isHermitian()); + } +} diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixMatMultTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixMatMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixMatMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixMatMultTests.java index 89161c565..9a67dd067 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixMatMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixMatMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixMultTransposeTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixMultTransposeTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixMultTransposeTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixMultTransposeTests.java index 5557f713f..e976ddaa0 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixMultTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixMultTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetRowColTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetRowColTests.java new file mode 100644 index 000000000..9dbae1fd3 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetRowColTests.java @@ -0,0 +1,574 @@ +package org.flag4j.arrays.sparse.coo_matrix; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooMatrix; +import org.flag4j.arrays.sparse.CooVector; +import org.flag4j.linalg.ops.sparse.coo.real.RealCooMatrixGetSet; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CooMatrixSetRowColTests { + + @Test + void setColTest() { + Shape aShape; + int[] aRowIndices; + int[] aColIndices; + double[] aEntries; + CooMatrix a; + + double[] bEntries; + + Shape expShape; + int[] expRowIndices; + int[] expColIndices; + double[] expEntries; + CooMatrix exp; + + // --------------------- Sub-case 1 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.42216, 0.86886, 0.51801}; + aRowIndices = new int[]{0, 0, 1}; + aColIndices = new int[]{1, 2, 2}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bEntries = new double[]{0.30728, 0.13698, 0.23211, 0.05517, 0.12575}; + + expShape = new Shape(5, 3); + expEntries = new double[]{0.30728, 0.42216, 0.86886, 0.13698, 0.51801, 0.23211, 0.05517, 0.12575}; + expRowIndices = new int[]{0, 0, 0, 1, 1, 2, 3, 4}; + expColIndices = new int[]{0, 1, 2, 0, 2, 0, 0, 0}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); + + assertEquals(exp, RealCooMatrixGetSet.setCol(a, 0, bEntries)); + + // --------------------- Sub-case 2 --------------------- + aShape = new Shape(11, 23); + aEntries = new double[]{0.86291, 0.59273, 0.14697, 0.79343, 0.0691}; + aRowIndices = new int[]{4, 5, 6, 8, 10}; + aColIndices = new int[]{14, 3, 9, 15, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bEntries = new double[]{0.09599, 0.03342, 0.08342, 0.86195, 0.18126, 0.71121, 0.03191, 0.3479, 0.5699, 0.35584, 0.51796}; + + expShape = new Shape(11, 23); + expEntries = new double[]{0.09599, 0.03342, 0.08342, 0.86195, 0.86291, 0.18126, 0.59273, 0.71121, 0.14697, 0.03191, 0.3479, 0.79343, 0.5699, 0.35584, 0.0691, 0.51796}; + expRowIndices = new int[]{0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 10}; + expColIndices = new int[]{16, 16, 16, 16, 14, 16, 3, 16, 9, 16, 16, 15, 16, 16, 4, 16}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); + + assertEquals(exp, RealCooMatrixGetSet.setCol(a, 16, bEntries)); + + // --------------------- Sub-case 3 --------------------- + aShape = new Shape(5, 1000); + aEntries = new double[]{0.91557, 0.99112, 0.97331, 0.46736, 0.39273, 0.9236, 0.55027, 0.96506, 0.46553}; + aRowIndices = new int[]{0, 1, 2, 3, 3, 4, 4, 4, 4}; + aColIndices = new int[]{118, 335, 419, 424, 880, 134, 358, 492, 949}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bEntries = new double[]{0.86214, 0.01468, 0.80744, 0.38058, 0.27367}; + + expShape = new Shape(5, 1000); + expEntries = new double[]{0.91557, 0.86214, 0.99112, 0.01468, 0.97331, 0.80744, 0.46736, 0.39273, 0.38058, 0.9236, 0.55027, 0.96506, 0.46553, 0.27367}; + expRowIndices = new int[]{0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4}; + expColIndices = new int[]{118, 999, 335, 999, 419, 999, 424, 880, 999, 134, 358, 492, 949, 999}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); + + assertEquals(exp, RealCooMatrixGetSet.setCol(a, 999, bEntries)); + + // --------------------- Sub-case 4 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.20695, 0.08553, 0.58839, 0.42649}; + aRowIndices = new int[]{0, 0, 1, 2}; + aColIndices = new int[]{1, 2, 4, 1}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bEntries = new double[]{0.70299, 0.12535, 0.51468}; + + CooMatrix final0a = a; + double[] final0b = bEntries; + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setCol(final0a, 6, final0b)); + + // --------------------- Sub-case 5 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.8715, 0.48536, 0.74835, 0.61107}; + aRowIndices = new int[]{1, 1, 2, 2}; + aColIndices = new int[]{2, 4, 2, 3}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bEntries = new double[]{0.18264, 0.50269, 0.62068, 0.68308, 0.25792}; + + CooMatrix final1a = a; + double[] final1b = bEntries; + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setCol(final1a, 3, final1b)); + + // --------------------- Sub-case 6 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.97644, 0.04564, 0.1204, 0.19723}; + aRowIndices = new int[]{0, 2, 2, 2}; + aColIndices = new int[]{4, 1, 2, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bEntries = new double[]{0.69692, 0.15703}; + + CooMatrix final2a = a; + double[] final2b = bEntries; + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setCol(final2a, 3, final2b)); + + // --------------------- Sub-case 7 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.9503, 0.0484, 0.44488, 0.29844}; + aRowIndices = new int[]{0, 2, 2, 2}; + aColIndices = new int[]{0, 1, 3, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bEntries = new double[]{0.36708, 0.70117, 0.73955}; + + CooMatrix final3a = a; + double[] final3b = bEntries; + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setCol(final3a, 19, final3b)); + } + + + @Test + void setColSparseVectorTest() { + Shape aShape; + int[] aRowIndices; + int[] aColIndices; + double[] aEntries; + CooMatrix a; + + Shape bShape; + int[] bIndices; + double[] bEntries; + CooVector b; + + Shape expShape; + int[] expRowIndices; + int[] expColIndices; + double[] expEntries; + CooMatrix exp; + + // --------------------- Sub-case 1 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.69683, 0.7974, 0.01005}; + aRowIndices = new int[]{0, 3, 4}; + aColIndices = new int[]{1, 0, 2}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bShape = new Shape(5); + bEntries = new double[]{0.42925, 0.95116}; + bIndices = new int[]{2, 3}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + expShape = new Shape(5, 3); + expEntries = new double[]{0.69683, 0.42925, 0.95116, 0.01005}; + expRowIndices = new int[]{0, 2, 3, 4}; + expColIndices = new int[]{1, 0, 0, 2}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); + + assertEquals(exp, a.setCol(b, 0)); + + // --------------------- Sub-case 2 --------------------- + aShape = new Shape(11, 23); + aEntries = new double[]{0.09879, 0.44944, 0.39054, 0.51234, 0.10826}; + aRowIndices = new int[]{1, 3, 7, 8, 10}; + aColIndices = new int[]{14, 15, 16, 3, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bShape = new Shape(11); + bEntries = new double[]{0.42701, 0.22431, 0.48719, 0.79679}; + bIndices = new int[]{5, 6, 7, 10}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + expShape = new Shape(11, 23); + expEntries = new double[]{0.09879, 0.44944, 0.42701, 0.22431, 0.48719, 0.51234, 0.10826, 0.79679}; + expRowIndices = new int[]{1, 3, 5, 6, 7, 8, 10, 10}; + expColIndices = new int[]{14, 15, 16, 16, 16, 3, 4, 16}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); + + assertEquals(exp, a.setCol(b, 16)); + + // --------------------- Sub-case 3 --------------------- + aShape = new Shape(5, 1000); + aEntries = new double[]{0.548, 0.12782, 0.71044, 0.03123, 0.73197, 0.23329, 0.76449, 0.62306, 0.77283}; + aRowIndices = new int[]{0, 1, 1, 1, 2, 2, 2, 3, 4}; + aColIndices = new int[]{663, 597, 620, 926, 73, 153, 627, 66, 743}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bShape = new Shape(5); + bEntries = new double[]{0.92473, 0.36888}; + bIndices = new int[]{1, 4}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + expShape = new Shape(5, 1000); + expEntries = new double[]{0.548, 0.12782, 0.71044, 0.03123, 0.92473, 0.73197, 0.23329, 0.76449, 0.62306, 0.77283, 0.36888}; + expRowIndices = new int[]{0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4}; + expColIndices = new int[]{663, 597, 620, 926, 999, 73, 153, 627, 66, 743, 999}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); + + assertEquals(exp, a.setCol(b, 999)); + + // --------------------- Sub-case 4 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.38374, 0.24165, 0.20689, 0.73343}; + aRowIndices = new int[]{0, 1, 2, 2}; + aColIndices = new int[]{4, 0, 0, 1}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bShape = new Shape(3); + bEntries = new double[]{0.93917}; + bIndices = new int[]{2}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + CooMatrix final0a = a; + CooVector final0b = b; + assertThrows(Exception.class, ()->final0a.setCol(final0b, 6)); + + // --------------------- Sub-case 5 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.52077, 0.42897, 0.35701, 0.94909}; + aRowIndices = new int[]{0, 1, 2, 2}; + aColIndices = new int[]{2, 2, 0, 3}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bShape = new Shape(5); + bEntries = new double[]{0.41526, 0.41046}; + bIndices = new int[]{0, 2}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + CooMatrix final1a = a; + CooVector final1b = b; + assertThrows(Exception.class, ()->final1a.setCol(final1b, 3)); + + // --------------------- Sub-case 6 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.89024, 0.42578, 0.66571, 0.53301}; + aRowIndices = new int[]{0, 1, 1, 2}; + aColIndices = new int[]{2, 0, 1, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bShape = new Shape(2); + bEntries = new double[]{0.55374}; + bIndices = new int[]{1}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + CooMatrix final2a = a; + CooVector final2b = b; + assertThrows(Exception.class, ()->final2a.setCol(final2b, 3)); + + // --------------------- Sub-case 7 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.74812, 0.07704, 0.80715, 0.45783}; + aRowIndices = new int[]{0, 1, 2, 2}; + aColIndices = new int[]{0, 1, 3, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); + + bShape = new Shape(3); + bEntries = new double[]{0.838}; + bIndices = new int[]{0}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + CooMatrix final3a = a; + CooVector final3b = b; + assertThrows(Exception.class, ()->final3a.setCol(final3b, 19)); + } + + + @Test + void setRowCooVectorTests() { + Shape aShape; + int[] aRowIndices; + int[] aColIndices; + double[] aEntries; + CooMatrix a; + + Shape bShape; + int[] bIndices; + double[] bEntries; + CooVector b; + + Shape expShape; + int[] expRowIndices; + int[] expColIndices; + double[] expEntries; + CooMatrix exp; + + // --------------------- Sub-case 1 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.69683, 0.7974, 0.01005}; + aRowIndices = new int[]{1, 0, 2}; + aColIndices = new int[]{0, 3, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bShape = new Shape(5); + bEntries = new double[]{0.42925, 0.95116}; + bIndices = new int[]{2, 3}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + expShape = new Shape(3, 5); + expEntries = new double[]{0.69683, 0.42925, 0.95116, 0.01005}; + expRowIndices = new int[]{1, 0, 0, 2}; + expColIndices = new int[]{0, 2, 3, 4}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices).sortIndices(); + + assertEquals(exp, a.setRow(b, 0)); + + // --------------------- Sub-case 2 --------------------- + aShape = new Shape(23, 11); + aEntries = new double[]{0.09879, 0.44944, 0.39054, 0.51234, 0.10826}; + aColIndices = new int[]{1, 3, 7, 8, 10}; + aRowIndices = new int[]{14, 15, 16, 3, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bShape = new Shape(11); + bEntries = new double[]{0.42701, 0.22431, 0.48719, 0.79679}; + bIndices = new int[]{5, 6, 7, 10}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + expShape = new Shape(23, 11); + expEntries = new double[]{0.09879, 0.44944, 0.42701, 0.22431, 0.48719, 0.51234, 0.10826, 0.79679}; + expColIndices = new int[]{1, 3, 5, 6, 7, 8, 10, 10}; + expRowIndices = new int[]{14, 15, 16, 16, 16, 3, 4, 16}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices).sortIndices(); + + assertEquals(exp, a.setRow(b, 16)); + + // --------------------- Sub-case 3 --------------------- + aShape = new Shape(1000, 5); + aEntries = new double[]{0.548, 0.12782, 0.71044, 0.03123, 0.73197, 0.23329, 0.76449, 0.62306, 0.77283}; + aColIndices = new int[]{0, 1, 1, 1, 2, 2, 2, 3, 4}; + aRowIndices = new int[]{663, 597, 620, 926, 73, 153, 627, 66, 743}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bShape = new Shape(5); + bEntries = new double[]{0.92473, 0.36888}; + bIndices = new int[]{1, 4}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + expShape = new Shape(1000, 5); + expEntries = new double[]{0.548, 0.12782, 0.71044, 0.03123, 0.92473, 0.73197, 0.23329, 0.76449, 0.62306, 0.77283, 0.36888}; + expColIndices = new int[]{0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4}; + expRowIndices = new int[]{663, 597, 620, 926, 999, 73, 153, 627, 66, 743, 999}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices).sortIndices(); + + assertEquals(exp, a.setRow(b, 999)); + + // --------------------- Sub-case 4 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.38374, 0.24165, 0.20689, 0.73343}; + aColIndices = new int[]{0, 1, 2, 2}; + aRowIndices = new int[]{4, 0, 0, 1}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bShape = new Shape(3); + bEntries = new double[]{0.93917}; + bIndices = new int[]{2}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + CooMatrix final0a = a; + CooVector final0b = b; + assertThrows(Exception.class, ()->final0a.setRow(final0b, 6)); + + // --------------------- Sub-case 5 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.52077, 0.42897, 0.35701, 0.94909}; + aColIndices = new int[]{0, 1, 2, 2}; + aRowIndices = new int[]{2, 2, 0, 3}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bShape = new Shape(5); + bEntries = new double[]{0.41526, 0.41046}; + bIndices = new int[]{0, 2}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + CooMatrix final1a = a; + CooVector final1b = b; + assertThrows(Exception.class, ()->final1a.setRow(final1b, 3)); + + // --------------------- Sub-case 6 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.89024, 0.42578, 0.66571, 0.53301}; + aColIndices = new int[]{0, 1, 1, 2}; + aRowIndices = new int[]{2, 0, 1, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bShape = new Shape(2); + bEntries = new double[]{0.55374}; + bIndices = new int[]{1}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + CooMatrix final2a = a; + CooVector final2b = b; + assertThrows(Exception.class, ()->final2a.setRow(final2b, 3)); + + // --------------------- Sub-case 7 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.74812, 0.07704, 0.80715, 0.45783}; + aColIndices = new int[]{0, 1, 2, 2}; + aRowIndices = new int[]{0, 1, 3, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bShape = new Shape(3); + bEntries = new double[]{0.838}; + bIndices = new int[]{0}; + b = new CooVector(bShape.get(0), bEntries, bIndices); + + CooMatrix final3a = a; + CooVector final3b = b; + assertThrows(Exception.class, ()->final3a.setRow(final3b, 19)); + } + + @Test + void setRowDenseTests() { + Shape aShape; + int[] aRowIndices; + int[] aColIndices; + double[] aEntries; + CooMatrix a; + + double[] bEntries; + + Shape expShape; + int[] expRowIndices; + int[] expColIndices; + double[] expEntries; + CooMatrix exp; + + // --------------------- Sub-case 1 --------------------- + aShape = new Shape(3, 5); + aEntries = new double[]{0.42216, 0.86886, 0.51801}; + aColIndices = new int[]{0, 0, 1}; + aRowIndices = new int[]{1, 2, 2}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bEntries = new double[]{0.30728, 0.13698, 0.23211, 0.05517, 0.12575}; + + expShape = new Shape(3, 5); + expEntries = new double[]{0.30728, 0.42216, 0.86886, 0.13698, 0.51801, 0.23211, 0.05517, 0.12575}; + expColIndices = new int[]{0, 0, 0, 1, 1, 2, 3, 4}; + expRowIndices = new int[]{0, 1, 2, 0, 2, 0, 0, 0}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices).sortIndices(); + + assertEquals(exp, RealCooMatrixGetSet.setRow(a, 0, bEntries)); + + // --------------------- Sub-case 2 --------------------- + aShape = new Shape(23, 11); + aEntries = new double[]{0.86291, 0.59273, 0.14697, 0.79343, 0.0691}; + aColIndices = new int[]{4, 5, 6, 8, 10}; + aRowIndices = new int[]{14, 3, 9, 15, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bEntries = new double[]{0.09599, 0.03342, 0.08342, 0.86195, 0.18126, 0.71121, 0.03191, 0.3479, 0.5699, 0.35584, 0.51796}; + + expShape = new Shape(23, 11); + expEntries = new double[]{0.09599, 0.03342, 0.08342, 0.86195, 0.86291, 0.18126, 0.59273, 0.71121, 0.14697, 0.03191, 0.3479, 0.79343, 0.5699, 0.35584, 0.0691, 0.51796}; + expColIndices = new int[]{0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 10}; + expRowIndices = new int[]{16, 16, 16, 16, 14, 16, 3, 16, 9, 16, 16, 15, 16, 16, 4, 16}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices).sortIndices(); + + assertEquals(exp, RealCooMatrixGetSet.setRow(a, 16, bEntries)); + + // --------------------- Sub-case 3 --------------------- + aShape = new Shape(1000, 5); + aEntries = new double[]{0.91557, 0.99112, 0.97331, 0.46736, 0.39273, 0.9236, 0.55027, 0.96506, 0.46553}; + aColIndices = new int[]{0, 1, 2, 3, 3, 4, 4, 4, 4}; + aRowIndices = new int[]{118, 335, 419, 424, 880, 134, 358, 492, 949}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bEntries = new double[]{0.86214, 0.01468, 0.80744, 0.38058, 0.27367}; + + expShape = new Shape(1000, 5); + expEntries = new double[]{0.91557, 0.86214, 0.99112, 0.01468, 0.97331, 0.80744, 0.46736, 0.39273, 0.38058, 0.9236, 0.55027, 0.96506, 0.46553, 0.27367}; + expColIndices = new int[]{0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4}; + expRowIndices = new int[]{118, 999, 335, 999, 419, 999, 424, 880, 999, 134, 358, 492, 949, 999}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices).sortIndices(); + + assertEquals(exp, RealCooMatrixGetSet.setRow(a, 999, bEntries)); + + // --------------------- Sub-case 4 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.20695, 0.08553, 0.58839, 0.42649}; + aColIndices = new int[]{0, 0, 1, 2}; + aRowIndices = new int[]{1, 2, 4, 1}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bEntries = new double[]{0.70299, 0.12535, 0.51468}; + + CooMatrix final0a = a; + double[] final0b = bEntries; + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setRow(final0a, 6, final0b)); + + // --------------------- Sub-case 5 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.8715, 0.48536, 0.74835, 0.61107}; + aColIndices = new int[]{1, 1, 2, 2}; + aRowIndices = new int[]{2, 4, 2, 3}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bEntries = new double[]{0.18264, 0.50269, 0.62068, 0.68308, 0.25792}; + + CooMatrix final1a = a; + double[] final1b = bEntries; + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setRow(final1a, 3, final1b)); + + // --------------------- Sub-case 6 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.97644, 0.04564, 0.1204, 0.19723}; + aColIndices = new int[]{0, 2, 2, 2}; + aRowIndices = new int[]{4, 1, 2, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bEntries = new double[]{0.69692, 0.15703}; + + CooMatrix final2a = a; + double[] final2b = bEntries; + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setRow(final2a, 3, final2b)); + + // --------------------- Sub-case 7 --------------------- + aShape = new Shape(5, 3); + aEntries = new double[]{0.9503, 0.0484, 0.44488, 0.29844}; + aColIndices = new int[]{0, 2, 2, 2}; + aRowIndices = new int[]{0, 1, 3, 4}; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bEntries = new double[]{0.36708, 0.70117, 0.73955}; + + CooMatrix final3a = a; + double[] final3b = bEntries; + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setRow(final3a, 19, final3b)); + + // --------------------- Sub-case 8 --------------------- + aShape = new Shape(21, 32); + aEntries = new double[]{0.9503, 0.0484, 0.44488, 0.29844, -1.515, 20234.123}; + aRowIndices = new int[]{3, 15, 15, 15, 17, 20}; + aColIndices = new int[]{1, 5, 8, 13, 3, 0 }; + a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices).sortIndices(); + + bEntries = new double[]{-9.865780643419962, 4.755111047514857, -6.443769475611005, -6.389082329867075, -7.77229497365825, + 2.9256765346746505, -0.4403770113908312, -9.930275071364584, -1.301116729717764, 4.112954038109342, + 5.840693384369942, -0.9163855794658371, 7.877142592316218, -4.962843547135085, -9.342168182539481, + -7.875571161798163, -2.5778080197216173, -5.910510433509996, -5.982349806812259, 7.9221497570046004, + -6.637400014097297, -7.2100059336746, 6.580249567644238, -1.2577616266417806, 2.9633941118990776, + 1.68922841488396, 1.3879844182529482, 1.304876540209957, -1.7553804221712515, 1.075723178061402, + -7.01621785650135, 0.8428004605597845}; + + expShape = new Shape(21, 32); + expEntries = new double[]{0.9503, -9.865780643419962, 4.755111047514857, -6.443769475611005, -6.389082329867075, + -7.77229497365825, 2.9256765346746505, -0.4403770113908312, -9.930275071364584, -1.301116729717764, 4.112954038109342, + 5.840693384369942, -0.9163855794658371, 7.877142592316218, -4.962843547135085, -9.342168182539481, + -7.875571161798163, -2.5778080197216173, -5.910510433509996, -5.982349806812259, 7.9221497570046004, + -6.637400014097297, -7.2100059336746, 6.580249567644238, -1.2577616266417806, 2.9633941118990776, + 1.68922841488396, 1.3879844182529482, 1.304876540209957, -1.7553804221712515, 1.075723178061402, + -7.01621785650135, 0.8428004605597845, -1.515, 20234.123}; + expRowIndices = new int[]{3, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 17, 20}; + expColIndices = new int[]{1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, 31, 3, 0}; + exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); + + assertEquals(exp, a.setRow(bEntries, 15)); + } +} diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixSetSliceTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetSliceTests.java similarity index 61% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixSetSliceTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetSliceTests.java index 89bcfb113..92495adda 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixSetSliceTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetSliceTests.java @@ -1,9 +1,9 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.sparse.CooMatrix; -import org.flag4j.linalg.ops.sparse.coo.real.RealSparseMatrixGetSet; +import org.flag4j.linalg.ops.sparse.coo.real.RealCooMatrixGetSet; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -50,7 +50,7 @@ void setSliceTest() { expColIndices = new int[]{0, 2, 2, 0}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 2, 0)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 2, 0)); // --------------------- Sub-case 2 --------------------- aShape = new Shape(23, 11); @@ -71,7 +71,7 @@ void setSliceTest() { expColIndices = new int[]{8, 2, 10, 4, 8, 1, 2, 6, 7, 4, 7, 8, 9, 2, 4, 7, 8, 1, 5}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 18, 1)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 18, 1)); // --------------------- Sub-case 3 --------------------- aShape = new Shape(1000, 5); @@ -92,7 +92,7 @@ void setSliceTest() { expColIndices = new int[]{1, 0, 2, 1, 4, 2, 2, 2, 4, 1, 3}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 0, 0)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 0, 0)); // --------------------- Sub-case 4 --------------------- aShape = new Shape(3, 5); @@ -109,7 +109,7 @@ void setSliceTest() { CooMatrix final0a = a; CooMatrix final0b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final0a, final0b, -1, 2)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final0a, final0b, -1, 2)); // --------------------- Sub-case 5 --------------------- aShape = new Shape(3, 5); @@ -126,7 +126,7 @@ void setSliceTest() { CooMatrix final1a = a; CooMatrix final1b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final1a, final1b, 0, 16)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final1a, final1b, 0, 16)); // --------------------- Sub-case 6 --------------------- aShape = new Shape(3, 5); @@ -143,7 +143,7 @@ void setSliceTest() { CooMatrix final2a = a; CooMatrix final2b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final2a, final2b, 2, 0)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final2a, final2b, 2, 0)); // --------------------- Sub-case 7 --------------------- aShape = new Shape(3, 5); @@ -160,7 +160,7 @@ void setSliceTest() { CooMatrix final3a = a; CooMatrix final3b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final3a, final3b, 0, 4)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final3a, final3b, 0, 4)); } @@ -199,7 +199,7 @@ void setSliceDenseTest() { expColIndices = new int[]{0, 0, 1, 2, 0, 1, 2}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 2, 0)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 2, 0)); // --------------------- Sub-case 2 --------------------- aShape = new Shape(23, 11); @@ -222,7 +222,7 @@ void setSliceDenseTest() { expColIndices = new int[]{2, 3, 3, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 18, 1)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 18, 1)); // --------------------- Sub-case 3 --------------------- aShape = new Shape(1000, 5); @@ -243,7 +243,7 @@ void setSliceDenseTest() { expColIndices = new int[]{0, 1, 0, 1, 0, 1, 4, 1, 3, 3, 0, 4, 4, 3, 3}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 0, 0)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 0, 0)); // --------------------- Sub-case 4 --------------------- aShape = new Shape(3, 5); @@ -258,7 +258,7 @@ void setSliceDenseTest() { CooMatrix final0a = a; Matrix final0b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final0a, final0b, -1, 2)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final0a, final0b, -1, 2)); // --------------------- Sub-case 5 --------------------- aShape = new Shape(3, 5); @@ -275,7 +275,7 @@ void setSliceDenseTest() { CooMatrix final1a = a; Matrix final1b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final1a, final1b, 0, 16)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final1a, final1b, 0, 16)); // --------------------- Sub-case 6 --------------------- aShape = new Shape(3, 5); @@ -291,7 +291,7 @@ void setSliceDenseTest() { CooMatrix final2a = a; Matrix final2b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final2a, final2b, 2, 0)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final2a, final2b, 2, 0)); // --------------------- Sub-case 7 --------------------- aShape = new Shape(3, 5); @@ -306,7 +306,7 @@ void setSliceDenseTest() { CooMatrix final3a = a; Matrix final3b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final3a, final3b, 0, 4)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final3a, final3b, 0, 4)); } @@ -343,7 +343,7 @@ void setSliceDenseArrayTest() { expColIndices = new int[]{0, 0, 1, 2, 0, 1, 2}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 2, 0)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 2, 0)); // --------------------- Sub-case 2 --------------------- aShape = new Shape(23, 11); @@ -365,7 +365,7 @@ void setSliceDenseArrayTest() { expColIndices = new int[]{2, 3, 3, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 18, 1)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 18, 1)); // --------------------- Sub-case 3 --------------------- aShape = new Shape(1000, 5); @@ -385,7 +385,7 @@ void setSliceDenseArrayTest() { expColIndices = new int[]{0, 1, 0, 1, 0, 1, 4, 1, 3, 3, 0, 4, 4, 3, 3}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 0, 0)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 0, 0)); // --------------------- Sub-case 4 --------------------- aShape = new Shape(3, 5); @@ -399,7 +399,7 @@ void setSliceDenseArrayTest() { CooMatrix final0a = a; double[][] final0b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final0a, final0b, -1, 2)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final0a, final0b, -1, 2)); // --------------------- Sub-case 5 --------------------- aShape = new Shape(3, 5); @@ -415,7 +415,7 @@ void setSliceDenseArrayTest() { CooMatrix final1a = a; double[][] final1b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final1a, final1b, 0, 16)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final1a, final1b, 0, 16)); // --------------------- Sub-case 6 --------------------- aShape = new Shape(3, 5); @@ -430,7 +430,7 @@ void setSliceDenseArrayTest() { CooMatrix final2a = a; double[][] final2b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final2a, final2b, 2, 0)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final2a, final2b, 2, 0)); // --------------------- Sub-case 7 --------------------- aShape = new Shape(3, 5); @@ -444,145 +444,7 @@ void setSliceDenseArrayTest() { CooMatrix final3a = a; double[][] final3b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final3a, final3b, 0, 4)); - } - - - @Test - void setSliceDenseBoxedArrayTest() { - Shape aShape; - int[] aRowIndices; - int[] aColIndices; - double[] aEntries; - CooMatrix a; - - Double[][] b; - - Shape expShape; - int[] expRowIndices; - int[] expColIndices; - double[] expEntries; - CooMatrix exp; - - // --------------------- Sub-case 1 --------------------- - aShape = new Shape(5, 3); - aEntries = new double[]{0.27021, 0.29417, 0.06904}; - aRowIndices = new int[]{1, 2, 3}; - aColIndices = new int[]{0, 1, 1}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Double[][]{ - {0.88098, 0.32602, 0.83928}, - {0.08236, 0.83795, 0.84279}}; - - expShape = new Shape(5, 3); - expEntries = new double[]{0.27021, 0.88098, 0.32602, 0.83928, 0.08236, 0.83795, 0.84279}; - expRowIndices = new int[]{1, 2, 2, 2, 3, 3, 3}; - expColIndices = new int[]{0, 0, 1, 2, 0, 1, 2}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 2, 0)); - - // --------------------- Sub-case 2 --------------------- - aShape = new Shape(23, 11); - aEntries = new double[]{0.64914, 0.66932, 0.40628, 0.37954, 0.08519}; - aRowIndices = new int[]{0, 5, 6, 16, 22}; - aColIndices = new int[]{2, 3, 3, 9, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Double[][]{ - {0.58805, 0.25559, 0.41371, 0.98028, 0.63021, 0.31739, 0.53601, 0.31622, 0.79944}, - {0.47782, 0.07173, 0.48916, 0.19796, 0.33267, 0.11585, 0.55903, 0.85354, 0.76878}, - {0.37356, 0.92958, 0.878, 0.50643, 0.05278, 0.85421, 0.29942, 0.52806, 0.28666}, - {0.63041, 0.87807, 0.18841, 0.78023, 0.9306, 0.81551, 0.04105, 0.0534, 0.23816}, - {0.76035, 0.43175, 0.25131, 0.8096, 0.84916, 0.16624, 0.28679, 0.13698, 0.12409}}; - - expShape = new Shape(23, 11); - expEntries = new double[]{0.64914, 0.66932, 0.40628, 0.37954, 0.58805, 0.25559, 0.41371, 0.98028, 0.63021, 0.31739, 0.53601, 0.31622, 0.79944, 0.47782, 0.07173, 0.48916, 0.19796, 0.33267, 0.11585, 0.55903, 0.85354, 0.76878, 0.37356, 0.92958, 0.878, 0.50643, 0.05278, 0.85421, 0.29942, 0.52806, 0.28666, 0.63041, 0.87807, 0.18841, 0.78023, 0.9306, 0.81551, 0.04105, 0.0534, 0.23816, 0.76035, 0.43175, 0.25131, 0.8096, 0.84916, 0.16624, 0.28679, 0.13698, 0.12409}; - expRowIndices = new int[]{0, 5, 6, 16, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22}; - expColIndices = new int[]{2, 3, 3, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 18, 1)); - - // --------------------- Sub-case 3 --------------------- - aShape = new Shape(1000, 5); - aEntries = new double[]{0.09033, 0.75299, 0.78946, 0.91141, 0.66149, 0.23721, 0.78215, 0.14658, 0.48493}; - aRowIndices = new int[]{31, 88, 174, 224, 258, 291, 562, 595, 854}; - aColIndices = new int[]{4, 1, 3, 3, 0, 4, 4, 3, 3}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Double[][]{ - {0.92281, 0.27413}, - {0.0728, 0.35943}, - {0.65062, 0.823}}; - - expShape = new Shape(1000, 5); - expEntries = new double[]{0.92281, 0.27413, 0.0728, 0.35943, 0.65062, 0.823, 0.09033, 0.75299, 0.78946, 0.91141, 0.66149, 0.23721, 0.78215, 0.14658, 0.48493}; - expRowIndices = new int[]{0, 0, 1, 1, 2, 2, 31, 88, 174, 224, 258, 291, 562, 595, 854}; - expColIndices = new int[]{0, 1, 0, 1, 0, 1, 4, 1, 3, 3, 0, 4, 4, 3, 3}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 0, 0)); - - // --------------------- Sub-case 4 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.62672, 0.72454, 0.03301, 0.05962}; - aRowIndices = new int[]{0, 1, 2, 2}; - aColIndices = new int[]{4, 0, 0, 1}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Double[][]{ - {0.88379, 0.23006, 0.79116}}; - - CooMatrix final0a = a; - Double[][] final0b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final0a, final0b, -1, 2)); - - // --------------------- Sub-case 5 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.36772, 0.68086, 0.78025, 0.7059}; - aRowIndices = new int[]{0, 0, 1, 1}; - aColIndices = new int[]{0, 1, 2, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Double[][]{ - {0.37548}, - {0.90032}, - {0.54146}}; - - CooMatrix final1a = a; - Double[][] final1b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final1a, final1b, 0, 16)); - - // --------------------- Sub-case 6 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.08019, 0.07101, 0.62705, 0.60587}; - aRowIndices = new int[]{0, 1, 2, 2}; - aColIndices = new int[]{3, 0, 0, 2}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Double[][]{ - {0.77061, 0.2345}, - {0.34159, 0.25256}}; - - CooMatrix final2a = a; - Double[][] final2b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final2a, final2b, 2, 0)); - - // --------------------- Sub-case 7 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.63961, 0.47008, 0.91095, 0.36202}; - aRowIndices = new int[]{0, 2, 2, 2}; - aColIndices = new int[]{3, 2, 3, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Double[][]{ - {0.17845, 0.95046, 0.59089}}; - - CooMatrix final3a = a; - Double[][] final3b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final3a, final3b, 0, 4)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final3a, final3b, 0, 4)); } @@ -619,7 +481,7 @@ void setSliceDenseIntArrayTest() { expColIndices = new int[]{0, 0, 1, 2, 0, 1, 2}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 2, 0)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 2, 0)); // --------------------- Sub-case 2 --------------------- aShape = new Shape(23, 11); @@ -641,7 +503,7 @@ void setSliceDenseIntArrayTest() { expColIndices = new int[]{2, 3, 3, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 18, 1)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 18, 1)); // --------------------- Sub-case 3 --------------------- aShape = new Shape(1000, 5); @@ -661,7 +523,7 @@ void setSliceDenseIntArrayTest() { expColIndices = new int[]{0, 1, 0, 1, 0, 1, 4, 1, 3, 3, 0, 4, 4, 3, 3}; exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 0, 0)); + assertEquals(exp, RealCooMatrixGetSet.setSlice(a, b, 0, 0)); // --------------------- Sub-case 4 --------------------- aShape = new Shape(3, 5); @@ -675,7 +537,7 @@ void setSliceDenseIntArrayTest() { CooMatrix final0a = a; int[][] final0b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final0a, final0b, -1, 2)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final0a, final0b, -1, 2)); // --------------------- Sub-case 5 --------------------- aShape = new Shape(3, 5); @@ -691,7 +553,7 @@ void setSliceDenseIntArrayTest() { CooMatrix final1a = a; int[][] final1b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final1a, final1b, 0, 16)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final1a, final1b, 0, 16)); // --------------------- Sub-case 6 --------------------- aShape = new Shape(3, 5); @@ -706,7 +568,7 @@ void setSliceDenseIntArrayTest() { CooMatrix final2a = a; int[][] final2b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final2a, final2b, 2, 0)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final2a, final2b, 2, 0)); // --------------------- Sub-case 7 --------------------- aShape = new Shape(3, 5); @@ -720,144 +582,6 @@ void setSliceDenseIntArrayTest() { CooMatrix final3a = a; int[][] final3b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final3a, final3b, 0, 4)); - } - - - @Test - void setSliceDenseBoxedIntArrayTest() { - Shape aShape; - int[] aRowIndices; - int[] aColIndices; - double[] aEntries; - CooMatrix a; - - Integer[][] b; - - Shape expShape; - int[] expRowIndices; - int[] expColIndices; - double[] expEntries; - CooMatrix exp; - - // --------------------- Sub-case 1 --------------------- - aShape = new Shape(5, 3); - aEntries = new double[]{0.27021, 0.29417, 0.06904}; - aRowIndices = new int[]{1, 2, 3}; - aColIndices = new int[]{0, 1, 1}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Integer[][]{ - {88098, 32602, 83928}, - {8236, 83795, 84279}}; - - expShape = new Shape(5, 3); - expEntries = new double[]{0.27021, 88098, 32602, 83928, 8236, 83795, 84279}; - expRowIndices = new int[]{1, 2, 2, 2, 3, 3, 3}; - expColIndices = new int[]{0, 0, 1, 2, 0, 1, 2}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 2, 0)); - - // --------------------- Sub-case 2 --------------------- - aShape = new Shape(23, 11); - aEntries = new double[]{0.64914, 0.66932, 0.40628, 0.37954, 0.08519}; - aRowIndices = new int[]{0, 5, 6, 16, 22}; - aColIndices = new int[]{2, 3, 3, 9, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Integer[][]{ - {58805, 25559, 41371, 98028, 63021, 31739, 53601, 31622, 79944}, - {47782, 7173, 48916, 19796, 33267, 11585, 55903, 85354, 76878}, - {37356, 92958, 878, 50643, 5278, 85421, 29942, 52806, 28666}, - {63041, 87807, 18841, 78023, 9306, 81551, 4105, 534, 23816}, - {76035, 43175, 25131, 8096, 84916, 16624, 28679, 13698, 12409}}; - - expShape = new Shape(23, 11); - expEntries = new double[]{0.64914, 0.66932, 0.40628, 0.37954, 58805, 25559, 41371, 98028, 63021, 31739, 53601, 31622, 79944, 47782, 7173, 48916, 19796, 33267, 11585, 55903, 85354, 76878, 37356, 92958, 878, 50643, 5278, 85421, 29942, 52806, 28666, 63041, 87807, 18841, 78023, 9306, 81551, 4105, 534, 23816, 76035, 43175, 25131, 8096, 84916, 16624, 28679, 13698, 12409}; - expRowIndices = new int[]{0, 5, 6, 16, 18, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, 22}; - expColIndices = new int[]{2, 3, 3, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 18, 1)); - - // --------------------- Sub-case 3 --------------------- - aShape = new Shape(1000, 5); - aEntries = new double[]{0.09033, 0.75299, 0.78946, 0.91141, 0.66149, 0.23721, 0.78215, 0.14658, 0.48493}; - aRowIndices = new int[]{31, 88, 174, 224, 258, 291, 562, 595, 854}; - aColIndices = new int[]{4, 1, 3, 3, 0, 4, 4, 3, 3}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Integer[][]{ - {92281, 27413}, - {728, 35943}, - {65062, 823}}; - - expShape = new Shape(1000, 5); - expEntries = new double[]{92281, 27413, 728, 35943, 65062, 823, 0.09033, 0.75299, 0.78946, 0.91141, 0.66149, 0.23721, 0.78215, 0.14658, 0.48493}; - expRowIndices = new int[]{0, 0, 1, 1, 2, 2, 31, 88, 174, 224, 258, 291, 562, 595, 854}; - expColIndices = new int[]{0, 1, 0, 1, 0, 1, 4, 1, 3, 3, 0, 4, 4, 3, 3}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setSlice(a, b, 0, 0)); - - // --------------------- Sub-case 4 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.62672, 0.72454, 0.03301, 0.05962}; - aRowIndices = new int[]{0, 1, 2, 2}; - aColIndices = new int[]{4, 0, 0, 1}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Integer[][]{ - {88379, 23006, 79116}}; - - CooMatrix final0a = a; - Integer[][] final0b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final0a, final0b, -1, 2)); - - // --------------------- Sub-case 5 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.36772, 0.68086, 0.78025, 0.7059}; - aRowIndices = new int[]{0, 0, 1, 1}; - aColIndices = new int[]{0, 1, 2, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Integer[][]{ - {37548}, - {90032}, - {54146}}; - - CooMatrix final1a = a; - Integer[][] final1b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final1a, final1b, 0, 16)); - - // --------------------- Sub-case 6 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.08019, 0.07101, 0.62705, 0.60587}; - aRowIndices = new int[]{0, 1, 2, 2}; - aColIndices = new int[]{3, 0, 0, 2}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Integer[][]{ - {77061, 2345}, - {34159, 25256}}; - - CooMatrix final2a = a; - Integer[][] final2b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final2a, final2b, 2, 0)); - - // --------------------- Sub-case 7 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.63961, 0.47008, 0.91095, 0.36202}; - aRowIndices = new int[]{0, 2, 2, 2}; - aColIndices = new int[]{3, 2, 3, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - b = new Integer[][]{ - {17845, 95046, 59089}}; - - CooMatrix final3a = a; - Integer[][] final3b = b; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setSlice(final3a, final3b, 0, 4)); + assertThrows(Exception.class, ()-> RealCooMatrixGetSet.setSlice(final3a, final3b, 0, 4)); } } diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixSetTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixSetTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetTests.java index e7a20e9e3..414ee5544 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixSetTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixStackTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixStackTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixStackTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixStackTests.java index 67171d057..b678b0424 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixStackTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixStackTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixToStringTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixToStringTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_matrix/CooMatrixToStringTests.java rename to src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixToStringTests.java index 9d07cdbbc..fd09e6d14 100644 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixToStringTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/CooMatrixToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_matrix; +package org.flag4j.arrays.sparse.coo_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/arrays/sparse/coo_matrix/IsIdentityTests.java b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/IsIdentityTests.java new file mode 100644 index 000000000..d51bdb81b --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/coo_matrix/IsIdentityTests.java @@ -0,0 +1,88 @@ +package org.flag4j.arrays.sparse.coo_matrix; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class IsIdentityTests { + + @Test + void testCooMatrixIsCloseToI() { + Shape aShape; + int[] aRowIndices, aColIndices; + double[] aData; + CooMatrix a; + + // ---------------------- Sub-case 1 ---------------------- + aShape = new Shape(50, 12); + aRowIndices = new int[]{0, 5, 14, 23, 49}; + aColIndices = new int[]{1, 3, 3, 1, 2}; + aData = new double[]{1, 3, 5, 7, 9}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isCloseToI()); + + // ---------------------- Sub-case 2 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 3, 4, 5, 12, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new double[]{1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isCloseToI()); + + // ---------------------- Sub-case 3 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new double[]{1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, + 5}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertFalse(a.isCloseToI()); + + // ---------------------- Sub-case 4 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new double[]{1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isCloseToI()); + + // ---------------------- Sub-case 5 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new double[]{1, 1, 1, 1, 1, 1, + 1, 1, 1.000000000000001, 1, 1, + 1, 1, + 1, 1, 1, 0.99999999, 1, 1, + 1}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isCloseToI()); + + // ---------------------- Sub-case 6 ---------------------- + aShape = new Shape(20, 20); + aRowIndices = new int[]{0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aColIndices = new int[]{0, 1, 2, 0, 3, 4, 5, 6, 7, 8, 9, 10, 13, 11, 12, 13, 14, 15, 16, 17, 18, 19}; + aData = new double[]{1, 1, 1, 1.24e-16, + 1, 1, 1, 1, 1, + 1.00000000000000001, 1, 1, + 1.0e-18, 1, 1, 1, 1, 1, + 0.9999999999999, 1, 1, 1}; + a = new CooMatrix(aShape, aData, aRowIndices, aColIndices); + + assertTrue(a.isCloseToI()); + } +} diff --git a/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrElemMultTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrElemMultTests.java new file mode 100644 index 000000000..82b11cacc --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrElemMultTests.java @@ -0,0 +1,103 @@ +package org.flag4j.arrays.sparse.csr_matrix; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CsrMatrix; +import org.flag4j.util.exceptions.TensorShapeException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CsrElemMultTests { + + @Test + void elemMultTests() { + Shape aShape, bShape, expShape; + double[] aData, bData, expData; + int[] aRowPointers, bRowPointers, expRowPointers; + int[] aColIndices, bColIndices, expColIndices; + CsrMatrix a, b, exp; + + // ------------------- Sub-case 1 ------------------- + aShape = new Shape(12, 12); + aData = new double[]{0.02234, 0.43189, 0.87963, 0.79995, 0.72224, 0.74126, 0.44675, 0.44917, 0.69626, 0.81694, 0.32876, 0.02551, 0.75703, 0.58045, 0.64317, 0.23147, 0.52038, 0.37257, 0.88039, 0.48479, 0.02311, 0.72269, 0.80202, 0.0231, 0.12105, 0.47067, 0.2849, 0.3631, 0.70669}; + aRowPointers = new int[]{0, 4, 4, 6, 12, 14, 15, 17, 18, 24, 25, 28, 29}; + aColIndices = new int[]{0, 1, 2, 9, 1, 7, 1, 2, 3, 7, 8, 11, 1, 7, 2, 9, 11, 7, 0, 2, 4, 5, 6, 9, 1, 6, 9, 10, 4}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + bShape = new Shape(12, 12); + bData = new double[]{0.60519, 0.32299, 0.80625, 0.28575, 0.34192, 0.89753, 0.92196, 0.23285, 0.32287, 0.43847, 0.70881, 0.79414, 0.71232, 0.8306, 0.09576, 0.10497, 0.99244, 0.27191, 0.65053, 0.16913, 0.27482, 0.14437, 0.36674, 0.07254, 0.97292, 0.47324, 0.09344, 0.26302, 0.2397}; + bRowPointers = new int[]{0, 3, 3, 4, 6, 9, 10, 14, 17, 20, 25, 29, 29}; + bColIndices = new int[]{0, 6, 10, 2, 3, 8, 1, 4, 11, 2, 3, 4, 7, 9, 0, 4, 6, 0, 6, 10, 0, 2, 6, 8, 9, 4, 7, 10, 11}; + b = new CsrMatrix(bShape, bData, bRowPointers, bColIndices); + + expShape = new Shape(12, 12); + expData = new double[]{0.0135199446, 0.2380652192, 0.29507196280000003, 0.6979513788, 0.2820107499, 0.192258982, 0.23938684489999998, 0.5217380706, 0.09550256199999999}; + expRowPointers = new int[]{0, 1, 1, 1, 3, 4, 5, 6, 6, 8, 8, 9, 9}; + expColIndices = new int[]{0, 3, 8, 1, 2, 9, 0, 6, 10}; + exp = new CsrMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.elemMult(b)); + + // ------------------- Sub-case 2 ------------------- + aShape = new Shape(14, 16); + aData = new double[]{0.70348, 0.43878, 0.5199, 0.78916, 0.634, 0.24166, 0.14356, 0.99185, 0.90768}; + aRowPointers = new int[]{0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 8, 8, 9}; + aColIndices = new int[]{2, 5, 15, 4, 8, 12, 13, 6, 3}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + bShape = new Shape(14, 16); + bData = new double[]{0.88271, 0.31581, 0.41842, 0.67326}; + bRowPointers = new int[]{0, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 4}; + bColIndices = new int[]{14, 11, 8, 1}; + b = new CsrMatrix(bShape, bData, bRowPointers, bColIndices); + + expShape = new Shape(14, 16); + expData = new double[]{}; + expRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + expColIndices = new int[]{}; + exp = new CsrMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.elemMult(b)); + + // ------------------- Sub-case 3 ------------------- + aShape = new Shape(16, 5); + aData = new double[]{}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + aColIndices = new int[]{}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + bShape = new Shape(16, 5); + bData = new double[]{}; + bRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + bColIndices = new int[]{}; + b = new CsrMatrix(bShape, bData, bRowPointers, bColIndices); + + expShape = new Shape(16, 5); + expData = new double[]{}; + expRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + expColIndices = new int[]{}; + exp = new CsrMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.elemMult(b)); + + // ------------------- Sub-case 4 ------------------- + aShape = new Shape(16, 5); + a = new CsrMatrix(aShape); + bShape = new Shape(16, 4); + b = new CsrMatrix(bShape); + + CsrMatrix finalA = a; + CsrMatrix finalB = b; + assertThrows(TensorShapeException.class, ()-> finalA.elemMult(finalB)); + + aShape = new Shape(15156, 95314); + a = new CsrMatrix(aShape); + bShape = new Shape(132, 235); + b = new CsrMatrix(bShape); + + CsrMatrix finalA1 = a; + CsrMatrix finalB1 = b; + assertThrows(TensorShapeException.class, ()-> finalA1.elemMult(finalB1)); + } +} diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixAddSubTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixAddSubTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixAddSubTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixAddSubTests.java index 5c14825f0..61a29ac2a 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixAddSubTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixAddSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixGetRowColTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetRowColTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixGetRowColTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetRowColTests.java index 8da17bce6..b1e21e54b 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixGetRowColTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetRowColTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooVector; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixGetSetTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetSetTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixGetSetTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetSetTests.java index 04e110bd0..c59091638 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixGetSetTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.sparse.CsrMatrix; diff --git a/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetSliceTest.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetSliceTest.java new file mode 100644 index 000000000..25f2b4898 --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixGetSliceTest.java @@ -0,0 +1,114 @@ +package org.flag4j.arrays.sparse.csr_matrix; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CsrMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CsrMatrixGetSliceTest { + + @Test + void getSliceTests() { + int rowStart, rowEnd, colStart, colEnd; + int[] aRowPointers, aColIndices, expRowPointers, expColIndices; + double[] aData, expData; + Shape aShape, expShape; + CsrMatrix a, exp; + + // -------------------- sub-case 1 -------------------- + rowStart = 0; + rowEnd = 15; + colStart = 0; + colEnd = 156; + + aShape = new Shape(162, 525); + aData = new double[]{0.00689, 0.47811, 0.22132, 0.95089, 0.17458, 0.96716, 0.38071, 0.32721, 0.70462}; + aRowPointers = new int[]{0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9}; + aColIndices = new int[]{211, 479, 403, 342, 499, 197, 187, 256, 302}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + expShape = new Shape(15, 156); + expData = new double[]{}; + expRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + expColIndices = new int[]{}; + exp = new CsrMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.getSlice(rowStart, rowEnd, colStart, colEnd)); + + // -------------------- sub-case 2 -------------------- + rowStart = 15; + rowEnd = 25; + colStart = 6; + colEnd = 24; + + aShape = new Shape(25, 35); + aData = new double[]{0.21399, 0.92765, 0.47011, 0.39667, 0.72502, 0.97094, 0.47102, 0.14218, 0.44779}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2, 3, 3, 4, 5, 5, 5, 7, 8, 8, 8, 8, 9, 9, 9, 9}; + aColIndices = new int[]{14, 6, 18, 19, 0, 21, 27, 14, 11}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + expShape = new Shape(10, 18); + expData = new double[]{0.97094, 0.14218, 0.44779}; + expRowPointers = new int[]{0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + expColIndices = new int[]{15, 8, 5}; + exp = new CsrMatrix(expShape, expData, expRowPointers, expColIndices); + + assertEquals(exp, a.getSlice(rowStart, rowEnd, colStart, colEnd)); + + // -------------------- sub-case 3 -------------------- + rowStart = 8; + rowEnd = 9; + colStart = 18; + colEnd = 21; + + aShape = new Shape(33, 21); + aData = new double[]{0.85158, 0.21804, 0.43476, 0.49103, 0.96505, 0.23558, 0.52658}; + aRowPointers = new int[]{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 6, 6, 7, 7, 7, 7, 7}; + aColIndices = new int[]{18, 0, 0, 8, 7, 14, 17}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + expShape = new Shape(1, 3); + expData = new double[]{}; + expRowPointers = new int[]{0, 0}; + expColIndices = new int[]{}; + exp = new CsrMatrix(expShape, expData, expRowPointers, expColIndices); + assertEquals(exp, a.getSlice(rowStart, rowEnd, colStart, colEnd)); + + + // -------------------- sub-case 4 -------------------- + rowStart = 5; + rowEnd = 22; + colStart = 0; + colEnd = 55; + + aShape = new Shape(55, 55); + aData = new double[]{}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + aColIndices = new int[]{}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + expShape = new Shape(17, 55); + expData = new double[]{}; + expRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + expColIndices = new int[]{}; + exp = new CsrMatrix(expShape, expData, expRowPointers, expColIndices); + assertEquals(exp, a.getSlice(rowStart, rowEnd, colStart, colEnd)); + + // -------------------- sub-case 4 -------------------- + aShape = new Shape(55, 55); + aData = new double[]{}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + aColIndices = new int[]{}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + CsrMatrix finalA = a; + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(-1, 2, 4, 5)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(1, 2, -4, 5)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(4, 2, 4, 5)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(1, 2, 4, 1)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(1, 56, 0, 5)); + assertThrows(IllegalArgumentException.class, ()-> finalA.getSlice(1, 2, 0, 514)); + } +} diff --git a/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixNormTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixNormTests.java new file mode 100644 index 000000000..f1a1544ac --- /dev/null +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixNormTests.java @@ -0,0 +1,87 @@ +package org.flag4j.arrays.sparse.csr_matrix; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CsrMatrix; +import org.flag4j.linalg.MatrixNorms; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CsrMatrixNormTests { + + @Test + void csrLpqNorms() { + Shape aShape; + double[] aData; + int[] aRowPointers, aColIndices; + CsrMatrix a; + double exp, p, q; + + // ----------------------- Sub-case 1 ----------------------- + aShape = new Shape(32, 32); + aData = new double[]{0.59873, 0.14037, 0.51302, 0.81953, 0.85602, 0.26156, 0.09093, 0.47848, 0.49457, 0.22042}; + aRowPointers = new int[]{0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 6, 7, 7, 8, 8, 9, 10, 10, 10, 10}; + aColIndices = new int[]{6, 8, 9, 27, 1, 2, 4, 31, 5, 18}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 1; + q = 1; + exp = 4.473629999999999; + + assertEquals(exp, MatrixNorms.norm(a, p, q)); + + // ----------------------- Sub-case 2 ----------------------- + aShape = new Shape(32, 32); + aData = new double[]{0.25735, 0.03955, 0.30834, 0.78676, 0.06766, 0.86635, 0.29043, 0.6284, 0.87736, 0.9143}; + aRowPointers = new int[]{0, 0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 4, 4, 4, 4, 6, 6, 7, 7, 8, 8, 8, 10, 10, 10, 10, 10, 10}; + aColIndices = new int[]{9, 3, 9, 17, 8, 25, 9, 29, 24, 29}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 1; + q = 2; + exp = 2.2931029222867427; + + assertEquals(exp, MatrixNorms.norm(a, p, q)); + + // ----------------------- Sub-case 3 ----------------------- + aShape = new Shape(32, 32); + aData = new double[]{0.01011, 0.30401, 0.44966, 0.91805, 0.74716, 0.79202, 0.48672, 0.9156, 0.26773, 0.19309}; + aRowPointers = new int[]{0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 4, 4, 4, 6, 7, 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 10}; + aColIndices = new int[]{2, 1, 6, 22, 4, 29, 31, 29, 19, 4}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 2; + q = 1; + exp = 4.418614617567907; + assertEquals(exp, MatrixNorms.norm(a, p, q)); + + // ----------------------- Sub-case 4 ----------------------- + aShape = new Shape(32, 32); + aData = new double[]{0.26041, 0.0959, 0.91373, 0.93231, 0.1104, 0.02363, 0.94682, 0.72971, 0.98633, 0.96019}; + aRowPointers = new int[]{0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 5, 5, 5, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 10, 10}; + aColIndices = new int[]{22, 8, 5, 8, 16, 6, 29, 27, 5, 6}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 4.12; + q = 9.3; + exp = 1.1861666012978695; + + assertEquals(exp, MatrixNorms.norm(a, p, q)); + + // ----------------------- Sub-case 5 ----------------------- + aShape = new Shape(32, 32); + aData = new double[]{0.57274, 0.87413, 0.52516, 0.08473, 0.55946, 0.6395, 0.10405, 0.78955, 0.68463, 0.48435}; + aRowPointers = new int[]{0, 0, 1, 1, 1, 2, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 8, 8, 8, 9, 9, 9, 9, 9, 9, 10, 10, 10}; + aColIndices = new int[]{3, 4, 14, 10, 31, 5, 0, 26, 24, 8}; + a = new CsrMatrix(aShape, aData, aRowPointers, aColIndices); + + p = 0; + q = 0; + + CsrMatrix finalA = a; + double finalP = p; + double finalQ = q; + assertThrows(IllegalArgumentException.class, () -> MatrixNorms.norm(finalA, finalP, finalQ)); + } +} diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixRowColSwapTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixRowColSwapTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixRowColSwapTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixRowColSwapTests.java index cbdbe5a30..c1b522e48 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixRowColSwapTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixRowColSwapTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.sparse.CsrMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToDenseTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToDenseTests.java similarity index 96% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToDenseTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToDenseTests.java index 1ce4fff1f..2e51c4ea8 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToDenseTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToDenseTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.sparse.CsrMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToStringTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToStringTests.java similarity index 97% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToStringTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToStringTests.java index e5802d268..b29f5be19 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToStringTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToVectorTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToVectorTests.java similarity index 97% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToVectorTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToVectorTests.java index d2a9c6571..87a6019af 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixToVectorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixToVectorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.dense.Vector; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTransposeTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixTransposeTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTransposeTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixTransposeTests.java index 45a2f5878..6a366bedb 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.sparse.CsrMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTriDiagTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixTriDiagTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTriDiagTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixTriDiagTests.java index f8873a6ea..12d2ba4e2 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/CsrMatrixTriDiagTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/CsrMatrixTriDiagTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Vector; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/RealComplexCsrCsrMatMultTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealComplexCsrCsrMatMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/RealComplexCsrCsrMatMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealComplexCsrCsrMatMultTests.java index 0b49777e4..f7e6f61d4 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/RealComplexCsrCsrMatMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealComplexCsrCsrMatMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; @@ -8,8 +8,8 @@ import org.flag4j.util.exceptions.LinearAlgebraException; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; class RealComplexCsrCsrMatMultTests { diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/RealComplexCsrDenseMatMultTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealComplexCsrDenseMatMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/RealComplexCsrDenseMatMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealComplexCsrDenseMatMultTests.java index 5a4bc7b70..bdeadaada 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/RealComplexCsrDenseMatMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealComplexCsrDenseMatMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/RealCsrCsrMatMultTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrCsrMatMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/RealCsrCsrMatMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrCsrMatMultTests.java index 629bac5cc..9e742f05f 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/RealCsrCsrMatMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrCsrMatMultTests.java @@ -1,12 +1,12 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.sparse.CsrMatrix; import org.flag4j.util.exceptions.LinearAlgebraException; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; class RealCsrCsrMatMultTests { diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/RealCsrDenseMatMultTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrDenseMatMultTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_csr_matrix/RealCsrDenseMatMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrDenseMatMultTests.java index b9d2429b0..2d97e8ac5 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/RealCsrDenseMatMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrDenseMatMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.sparse.CsrMatrix; diff --git a/src/test/java/org/flag4j/sparse_csr_matrix/RealCsrEqualsTests.java b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrEqualsTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_csr_matrix/RealCsrEqualsTests.java rename to src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrEqualsTests.java index a94090d53..090e0a19a 100644 --- a/src/test/java/org/flag4j/sparse_csr_matrix/RealCsrEqualsTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/csr_matrix/RealCsrEqualsTests.java @@ -22,7 +22,7 @@ * SOFTWARE. */ -package org.flag4j.sparse_csr_matrix; +package org.flag4j.arrays.sparse.csr_matrix; import org.flag4j.arrays.Shape; import org.flag4j.arrays.dense.Matrix; diff --git a/src/test/java/org/flag4j/sparse_complex_tensor/ComplexSparseTensorElemBinOpsTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/ComplexSparseTensorElemBinOpsTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_tensor/ComplexSparseTensorElemBinOpsTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/ComplexSparseTensorElemBinOpsTests.java index fde6bc659..28e18df02 100644 --- a/src/test/java/org/flag4j/sparse_complex_tensor/ComplexSparseTensorElemBinOpsTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/ComplexSparseTensorElemBinOpsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_tensor; +package org.flag4j.arrays.sparse.sparse_complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorConstructorTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorConstructorTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorConstructorTests.java index 47a23ec54..5b381bcda 100644 --- a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_tensor; +package org.flag4j.arrays.sparse.sparse_complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorGetSetTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorGetSetTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorGetSetTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorGetSetTests.java index 14b3712a3..a683acee6 100644 --- a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorGetSetTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorGetSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_tensor; +package org.flag4j.arrays.sparse.sparse_complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorHermTransposeTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorHermTransposeTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorHermTransposeTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorHermTransposeTests.java index 6398ecd5f..8e0af289e 100644 --- a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorHermTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorHermTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_tensor; +package org.flag4j.arrays.sparse.sparse_complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorReshapeTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorReshapeTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorReshapeTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorReshapeTests.java index 5f55964e6..dd7bdaa2e 100644 --- a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorReshapeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorReshapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_tensor; +package org.flag4j.arrays.sparse.sparse_complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorToStringTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorToStringTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorToStringTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorToStringTests.java index 77987ca52..f30cac5cc 100644 --- a/src/test/java/org/flag4j/sparse_complex_tensor/CooCTensorToStringTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_tensor/CooCTensorToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_tensor; +package org.flag4j.arrays.sparse.sparse_complex_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorAddTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorAddTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorAddTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorAddTests.java index 4e53753c5..76568faeb 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorAddTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorAddTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorConstructorTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorConstructorTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorConstructorTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorConstructorTests.java index bba8bba53..091ee1088 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorConversionTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorConversionTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorConversionTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorConversionTests.java index 4b484c2d3..0001b4816 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorConversionTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorConversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorElemDivTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorElemDivTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorElemDivTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorElemDivTests.java index d235c3b41..997ead587 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorElemMultTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorElemMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorElemMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorElemMultTests.java index 82f729e31..08308bf78 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorJoinTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorJoinTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorJoinTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorJoinTests.java index 0be7a5062..1c806c661 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorJoinTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorJoinTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorRepeatTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorRepeatTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorRepeatTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorRepeatTests.java index f6c0271bd..a1497aee5 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorRepeatTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorRepeatTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorReshapeTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorReshapeTests.java similarity index 97% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorReshapeTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorReshapeTests.java index 624dfc90e..95b33ed4b 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorReshapeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorReshapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorSubTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorSubTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorSubTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorSubTests.java index 4f25b0031..8aec4f76a 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorSubTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorUnaryOpTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorUnaryOpTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_complex_vector/CooCVectorUnaryOpTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorUnaryOpTests.java index 8d4111cd1..cd5f428d5 100644 --- a/src/test/java/org/flag4j/sparse_complex_vector/CooCVectorUnaryOpTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_complex_vector/CooCVectorUnaryOpTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_vector; +package org.flag4j.arrays.sparse.sparse_complex_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.sparse.CooCVector; diff --git a/src/test/java/org/flag4j/sparse_tensor/CooTensorConstructorTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_tensor/CooTensorConstructorTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorConstructorTests.java index 6eb113d29..0071a406f 100644 --- a/src/test/java/org/flag4j/sparse_tensor/CooTensorConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_tensor; +package org.flag4j.arrays.sparse.sparse_tensor; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_tensor/CooTensorGetSetTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorGetSetTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_tensor/CooTensorGetSetTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorGetSetTests.java index 12117e47b..f5e586e50 100644 --- a/src/test/java/org/flag4j/sparse_tensor/CooTensorGetSetTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorGetSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_tensor; +package org.flag4j.arrays.sparse.sparse_tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooTensor; diff --git a/src/test/java/org/flag4j/sparse_tensor/CooTensorReshapeTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorReshapeTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_tensor/CooTensorReshapeTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorReshapeTests.java index 357f81e42..563da8049 100644 --- a/src/test/java/org/flag4j/sparse_tensor/CooTensorReshapeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorReshapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_tensor; +package org.flag4j.arrays.sparse.sparse_tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooTensor; diff --git a/src/test/java/org/flag4j/sparse_tensor/CooTensorSpElemBinOpsTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorSpElemBinOpsTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_tensor/CooTensorSpElemBinOpsTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorSpElemBinOpsTests.java index a8acf8cf0..0fc155436 100644 --- a/src/test/java/org/flag4j/sparse_tensor/CooTensorSpElemBinOpsTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorSpElemBinOpsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_tensor; +package org.flag4j.arrays.sparse.sparse_tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooTensor; diff --git a/src/test/java/org/flag4j/sparse_tensor/CooTensorToStringTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorToStringTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_tensor/CooTensorToStringTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorToStringTests.java index 5fc2de75c..afb831ecd 100644 --- a/src/test/java/org/flag4j/sparse_tensor/CooTensorToStringTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_tensor; +package org.flag4j.arrays.sparse.sparse_tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooTensor; diff --git a/src/test/java/org/flag4j/sparse_tensor/CooTensorTransposeTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorTransposeTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_tensor/CooTensorTransposeTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorTransposeTests.java index 49c12f49b..234d88c0b 100644 --- a/src/test/java/org/flag4j/sparse_tensor/CooTensorTransposeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/CooTensorTransposeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_tensor; +package org.flag4j.arrays.sparse.sparse_tensor; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooTensor; diff --git a/src/test/java/org/flag4j/sparse_tensor/RealComplexCooTensorBinOpsTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/RealComplexCooTensorBinOpsTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_tensor/RealComplexCooTensorBinOpsTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_tensor/RealComplexCooTensorBinOpsTests.java index 5a5d5ee06..41c503ca3 100644 --- a/src/test/java/org/flag4j/sparse_tensor/RealComplexCooTensorBinOpsTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_tensor/RealComplexCooTensorBinOpsTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_tensor; +package org.flag4j.arrays.sparse.sparse_tensor; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorAddTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorAddTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_vector/CooVectorAddTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorAddTests.java index 467aa76de..21b006469 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorAddTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorAddTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.algebraic_structures.Complex128; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorAggregateTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorAggregateTests.java similarity index 97% rename from src/test/java/org/flag4j/sparse_vector/CooVectorAggregateTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorAggregateTests.java index 96dca1f26..27386b1a6 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorAggregateTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorAggregateTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.arrays.sparse.CooVector; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorConstructorTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorConstructorTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_vector/CooVectorConstructorTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorConstructorTests.java index cd56a9f63..eee4b7673 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorConstructorTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooVector; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorConversionTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorConversionTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_vector/CooVectorConversionTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorConversionTests.java index ab54ce745..ce4404fbf 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorConversionTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorConversionTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorElemDivTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorElemDivTests.java similarity index 96% rename from src/test/java/org/flag4j/sparse_vector/CooVectorElemDivTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorElemDivTests.java index 952b11256..4e4405cb3 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorElemDivTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorElemDivTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; @@ -99,7 +99,7 @@ void doubleScalarDivTestCase() { // -------------------- Sub-case 1 -------------------- b = 24.56; - expValues = new double[]{1.34/b, 51.6/b, -0.00245/b, 99.2456/b, -1005.6/b}; + expValues = new double[]{1.34*(1.0/b), 51.6*(1.0/b), -0.00245*(1.0/b), 99.2456*(1.0/b), -1005.6*(1.0/b)}; expIndices = new int[]{2, 5, 81, 102, 104}; exp = new CooVector(151, expValues, expIndices); assertEquals(exp, a.div(b)); diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorElemMultTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorElemMultTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_vector/CooVectorElemMultTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorElemMultTests.java index 6345a524c..d92f6b9fe 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorElemMultTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorElemMultTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorInnerProdTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorInnerProdTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_vector/CooVectorInnerProdTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorInnerProdTests.java index 0f945c3a0..87fdbf1a6 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorInnerProdTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorInnerProdTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; @@ -156,7 +156,7 @@ void denseComplexInnerProdTestCase() { @Test void normalizeTestCase() { // ----------------------- Sub-case 1 ----------------------- - double[] expEntries = {0.0046451435284722955, 0.026012803759444855, -0.043455317708858326, 0.9987058586215436}; + double[] expEntries = {0.0046451435284722955, 0.026012803759444852, -0.043455317708858326, 0.9987058586215435}; int[] expIndices = {1, 2, 8, 13}; CooVector exp = new CooVector(sparseSize, expEntries, expIndices); diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorJoinTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorJoinTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_vector/CooVectorJoinTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorJoinTests.java index 2f9ba44ff..4ffd118eb 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorJoinTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorJoinTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooMatrix; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorOuterProductTest.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorOuterProductTest.java similarity index 99% rename from src/test/java/org/flag4j/sparse_vector/CooVectorOuterProductTest.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorOuterProductTest.java index f24ee55f1..13419eb42 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorOuterProductTest.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorOuterProductTest.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CMatrix; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorRepeatTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorRepeatTests.java similarity index 98% rename from src/test/java/org/flag4j/sparse_vector/CooVectorRepeatTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorRepeatTests.java index 5c776ffe1..41681a4eb 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorRepeatTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorRepeatTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.arrays.dense.Matrix; import org.flag4j.arrays.dense.Vector; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorReshapeTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorReshapeTests.java similarity index 97% rename from src/test/java/org/flag4j/sparse_vector/CooVectorReshapeTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorReshapeTests.java index e46dabb57..04c230ed9 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorReshapeTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorReshapeTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.arrays.Shape; import org.flag4j.arrays.sparse.CooVector; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorSetTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorSetTests.java similarity index 95% rename from src/test/java/org/flag4j/sparse_vector/CooVectorSetTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorSetTests.java index 50c69664a..e56f21d1a 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorSetTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorSetTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.arrays.sparse.CooVector; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorSubTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorSubTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_vector/CooVectorSubTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorSubTests.java index 8402be2cb..57b5088de 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorSubTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorSubTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.dense.CVector; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorToStringTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorToStringTests.java similarity index 96% rename from src/test/java/org/flag4j/sparse_vector/CooVectorToStringTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorToStringTests.java index 4ff53715c..0c26ceb95 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorToStringTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorToStringTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.arrays.sparse.CooVector; import org.flag4j.io.PrintOptions; diff --git a/src/test/java/org/flag4j/sparse_vector/CooVectorUnaryOpTests.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorUnaryOpTests.java similarity index 97% rename from src/test/java/org/flag4j/sparse_vector/CooVectorUnaryOpTests.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorUnaryOpTests.java index eca9139ba..c2d8fdd6d 100644 --- a/src/test/java/org/flag4j/sparse_vector/CooVectorUnaryOpTests.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/CooVectorUnaryOpTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.CustomAssertions; import org.flag4j.arrays.sparse.CooVector; diff --git a/src/test/java/org/flag4j/sparse_vector/SparseVectorCooSortTest.java b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/SparseVectorCooSortTest.java similarity index 95% rename from src/test/java/org/flag4j/sparse_vector/SparseVectorCooSortTest.java rename to src/test/java/org/flag4j/arrays/sparse/sparse_vector/SparseVectorCooSortTest.java index e47d68222..b11898496 100644 --- a/src/test/java/org/flag4j/sparse_vector/SparseVectorCooSortTest.java +++ b/src/test/java/org/flag4j/arrays/sparse/sparse_vector/SparseVectorCooSortTest.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_vector; +package org.flag4j.arrays.sparse.sparse_vector; import org.flag4j.arrays.sparse.CooVector; import org.junit.jupiter.api.BeforeAll; diff --git a/src/test/java/org/flag4j/sparse_complex_matrix/CooCMatrixConstructorTests.java b/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixConstructorTests.java similarity index 99% rename from src/test/java/org/flag4j/sparse_complex_matrix/CooCMatrixConstructorTests.java rename to src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixConstructorTests.java index 8fafe2973..17d321b99 100644 --- a/src/test/java/org/flag4j/sparse_complex_matrix/CooCMatrixConstructorTests.java +++ b/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixConstructorTests.java @@ -1,4 +1,4 @@ -package org.flag4j.sparse_complex_matrix; +package org.flag4j.complex_coo_matrix; import org.flag4j.algebraic_structures.Complex128; import org.flag4j.arrays.Shape; @@ -9,7 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertThrows; -public class CooCMatrixConstructorTests { +class CooCMatrixConstructorTests { Complex128[] expNonZero; double[] expNonZeroD; int[] expNonZeroI, expRowIndices, expColIndices; diff --git a/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixRemoveRowColTests.java b/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixRemoveRowColTests.java new file mode 100644 index 000000000..3c19bbee4 --- /dev/null +++ b/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixRemoveRowColTests.java @@ -0,0 +1,268 @@ +package org.flag4j.complex_coo_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooCMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CooCMatrixRemoveRowColTests { + + Shape expShape, actShape; + Complex128[] expData, actData; + int[] expRowIndices, actRowIndices; + int[] expColIndices, actColIndices; + CooCMatrix exp, act; + + @Test + void removeRowTests() { + // ------------------ sub-case 1 ------------------ + actShape = new Shape(12, 45); + actData = new Complex128[]{new Complex128(0.77866, 0.69048), new Complex128(0.76284, 0.7625), new Complex128(0.50343, 0.5486), new Complex128(0.94497, 0.41532), new Complex128(0.53088, 0.74234)}; + actRowIndices = new int[]{1, 3, 5, 7, 11}; + actColIndices = new int[]{13, 38, 3, 33, 3}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(11, 45); + expData = new Complex128[]{new Complex128(0.77866, 0.69048), new Complex128(0.76284, 0.7625), new Complex128(0.50343, 0.5486), new Complex128(0.94497, 0.41532), new Complex128(0.53088, 0.74234)}; + expRowIndices = new int[]{0, 2, 4, 6, 10}; + expColIndices = new int[]{13, 38, 3, 33, 3}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeRow(0)); + + // ------------------ sub-case 2 ------------------ + actShape = new Shape(12, 45); + actData = new Complex128[]{new Complex128(0.71598, 0.48605), new Complex128(0.50074, 0.4636), new Complex128(0.4353, 0.48429), new Complex128(0.26224, 0.1529), new Complex128(0.65118, 0.59656)}; + actRowIndices = new int[]{1, 4, 9, 10, 11}; + actColIndices = new int[]{25, 2, 16, 41, 10}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(11, 45); + expData = new Complex128[]{new Complex128(0.71598, 0.48605), new Complex128(0.4353, 0.48429), new Complex128(0.26224, 0.1529), new Complex128(0.65118, 0.59656)}; + expRowIndices = new int[]{1, 8, 9, 10}; + expColIndices = new int[]{25, 16, 41, 10}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeRow(4)); + + // ------------------ sub-case 3 ------------------ + actShape = new Shape(12, 45); + actData = new Complex128[]{new Complex128(0.64416, 0.53894), new Complex128(0.80314, 0.25352), new Complex128(0.1104, 0.96916), new Complex128(0.34511, 0.53383), new Complex128(0.44366, 0.67431)}; + actRowIndices = new int[]{1, 3, 4, 8, 11}; + actColIndices = new int[]{3, 24, 14, 27, 34}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(11, 45); + expData = new Complex128[]{new Complex128(0.64416, 0.53894), new Complex128(0.80314, 0.25352), new Complex128(0.1104, 0.96916), new Complex128(0.34511, 0.53383)}; + expRowIndices = new int[]{1, 3, 4, 8}; + expColIndices = new int[]{3, 24, 14, 27}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeRow(11)); + + // ------------------ sub-case 4 ------------------ + actShape = new Shape(12, 45); + actData = new Complex128[]{}; + actRowIndices = new int[]{}; + actColIndices = new int[]{}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(11, 45); + expData = new Complex128[]{}; + expRowIndices = new int[]{}; + expColIndices = new int[]{}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeRow(2)); + + // ------------------ sub-case 5 ------------------ + act = new CooCMatrix(new Shape(3, 9)); + assertThrows(IndexOutOfBoundsException.class, () -> act.removeRow(4)); + assertThrows(IndexOutOfBoundsException.class, () -> act.removeRow(-1)); + } + + @Test + void removeRowsTests() { + // ------------------ sub-case 1 ------------------ + actShape = new Shape(12, 45); + actData = new Complex128[]{new Complex128(0.69978, 0.68298), new Complex128(0.73937, 0.09706), new Complex128(0.89267, 0.05119), new Complex128(0.57729, 0.10399), new Complex128(0.55045, 0.73243)}; + actRowIndices = new int[]{1, 1, 8, 9, 11}; + actColIndices = new int[]{15, 23, 13, 25, 18}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(9, 45); + expData = new Complex128[]{new Complex128(0.69978, 0.68298), new Complex128(0.73937, 0.09706), new Complex128(0.89267, 0.05119), new Complex128(0.57729, 0.10399), new Complex128(0.55045, 0.73243)}; + expRowIndices = new int[]{0, 0, 5, 6, 8}; + expColIndices = new int[]{15, 23, 13, 25, 18}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeRows(0, 3, 4)); + + // ------------------ sub-case 2 ------------------ + actShape = new Shape(15, 12); + actData = new Complex128[]{new Complex128(0.47902, 0.11469), new Complex128(0.4156, 0.84995), new Complex128(0.01331, 0.61483), new Complex128(0.84897, 0.45753), new Complex128(0.89824, 0.53618), new Complex128(0.34686, 0.98037), new Complex128(0.50154, 0.17636), new Complex128(0.47826, 0.48458), new Complex128(0.63824, 0.79207), new Complex128(0.12708, 0.28845), new Complex128(0.20776, 0.21964)}; + actRowIndices = new int[]{0, 0, 1, 1, 1, 5, 6, 7, 8, 13, 13}; + actColIndices = new int[]{2, 10, 4, 5, 9, 6, 8, 10, 2, 1, 9}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(12, 12); + expData = new Complex128[]{new Complex128(0.01331, 0.61483), new Complex128(0.84897, 0.45753), new Complex128(0.89824, 0.53618), new Complex128(0.34686, 0.98037), new Complex128(0.50154, 0.17636), new Complex128(0.47826, 0.48458), new Complex128(0.63824, 0.79207), new Complex128(0.12708, 0.28845), new Complex128(0.20776, 0.21964)}; + expRowIndices = new int[]{0, 0, 0, 2, 3, 4, 5, 10, 10}; + expColIndices = new int[]{4, 5, 9, 6, 8, 10, 2, 1, 9}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeRows(4, 0, 3)); + + // ------------------ sub-case 3 ------------------ + actShape = new Shape(15, 12); + actData = new Complex128[]{}; + actRowIndices = new int[]{}; + actColIndices = new int[]{}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(12, 12); + expData = new Complex128[]{}; + expRowIndices = new int[]{}; + expColIndices = new int[]{}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeRows(2, 5, 9)); + + // ------------------ sub-case 4 ------------------ + act = new CooCMatrix(new Shape(3, 9)); + assertThrows(IndexOutOfBoundsException.class, () -> act.removeRows(2, 5, 9)); + assertThrows(IndexOutOfBoundsException.class, () -> act.removeRows(0, -1)); + } + + + @Test + void removeColTests() { + // ------------------ sub-case 1 ------------------ + actShape = new Shape(15, 24); + actData = new Complex128[]{new Complex128(0.50744, 0.22408), new Complex128(0.03128, 0.09434), new Complex128(0.67216, 0.71928), new Complex128(0.26396, 0.84781), new Complex128(0.74211, 0.78747), new Complex128(0.53803, 0.80745), new Complex128(0.45852, 0.08599)}; + actRowIndices = new int[]{0, 3, 3, 5, 13, 14, 14}; + actColIndices = new int[]{7, 7, 22, 9, 10, 0, 1}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(15, 23); + expData = new Complex128[]{new Complex128(0.50744, 0.22408), new Complex128(0.03128, 0.09434), new Complex128(0.67216, 0.71928), new Complex128(0.26396, 0.84781), new Complex128(0.74211, 0.78747), new Complex128(0.45852, 0.08599)}; + expRowIndices = new int[]{0, 3, 3, 5, 13, 14}; + expColIndices = new int[]{6, 6, 21, 8, 9, 0}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeCol(0)); + + // ------------------ sub-case 2 ------------------ + actShape = new Shape(24, 15); + actData = new Complex128[]{new Complex128(0.66056, 0.75806), new Complex128(0.82459, 0.65965), new Complex128(0.63419, 0.16823), new Complex128(0.74155, 0.35718), new Complex128(0.48856, 0.71093), new Complex128(0.3763, 0.7764), new Complex128(0.66567, 0.19264)}; + actRowIndices = new int[]{0, 7, 7, 13, 15, 18, 19}; + actColIndices = new int[]{9, 2, 13, 14, 0, 0, 8}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(24, 14); + expData = new Complex128[]{new Complex128(0.66056, 0.75806), new Complex128(0.82459, 0.65965), new Complex128(0.63419, 0.16823), new Complex128(0.74155, 0.35718), new Complex128(0.66567, 0.19264)}; + expRowIndices = new int[]{0, 7, 7, 13, 19}; + expColIndices = new int[]{8, 1, 12, 13, 7}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeCol(0)); + + // ------------------ sub-case 3 ------------------ + actShape = new Shape(24, 15); + actData = new Complex128[]{new Complex128(0.85192, 0.11152), new Complex128(0.46917, 0.82861), new Complex128(0.62357, 0.56506), new Complex128(0.72122, 0.04086), new Complex128(0.17467, 0.97519), new Complex128(0.93648, 0.79936), new Complex128(0.19383, 0.26116), new Complex128(0.35744, 0.12276), new Complex128(0.714, 0.69116), new Complex128(0.26821, 0.17701), new Complex128(0.306, 0.75897), new Complex128(0.50037, 0.63035)}; + actRowIndices = new int[]{1, 1, 4, 5, 12, 15, 16, 17, 17, 18, 21, 22}; + actColIndices = new int[]{1, 8, 11, 8, 2, 2, 4, 11, 14, 1, 1, 4}; + + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + expShape = new Shape(24, 14); + expData = new Complex128[]{new Complex128(0.85192, 0.11152), new Complex128(0.46917, 0.82861), new Complex128(0.62357, 0.56506), new Complex128(0.72122, 0.04086), new Complex128(0.17467, 0.97519), new Complex128(0.93648, 0.79936), new Complex128(0.35744, 0.12276), new Complex128(0.714, 0.69116), new Complex128(0.26821, 0.17701), new Complex128(0.306, 0.75897)}; + expRowIndices = new int[]{1, 1, 4, 5, 12, 15, 17, 17, 18, 21}; + expColIndices = new int[]{1, 7, 10, 7, 2, 2, 10, 13, 1, 1}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeCol(4)); + + // ------------------ sub-case 4 ------------------ + actShape = new Shape(24, 15); + actData = new Complex128[]{new Complex128(0.77821, 0.00128), new Complex128(0.2461, 0.13312), new Complex128(0.97996, 0.751), new Complex128(0.58622, 0.18542), new Complex128(0.32885, 0.45088), new Complex128(0.24512, 0.82565), new Complex128(0.08009, 0.86475), new Complex128(0.0346, 0.32301), new Complex128(0.68427, 0.36554), new Complex128(0.07324, 0.88487), new Complex128(0.54839, 0.4006), new Complex128(0.1038, 0.73258)}; + actRowIndices = new int[]{0, 2, 4, 5, 7, 8, 8, 20, 20, 21, 22, 22}; + actColIndices = new int[]{12, 0, 5, 12, 14, 9, 12, 8, 12, 11, 3, 9}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(24, 14); + expData = new Complex128[]{new Complex128(0.77821, 0.00128), new Complex128(0.2461, 0.13312), new Complex128(0.97996, 0.751), new Complex128(0.58622, 0.18542), new Complex128(0.24512, 0.82565), new Complex128(0.08009, 0.86475), new Complex128(0.0346, 0.32301), new Complex128(0.68427, 0.36554), new Complex128(0.07324, 0.88487), new Complex128(0.54839, 0.4006), new Complex128(0.1038, 0.73258)}; + expRowIndices = new int[]{0, 2, 4, 5, 8, 8, 20, 20, 21, 22, 22}; + expColIndices = new int[]{12, 0, 5, 12, 9, 12, 8, 12, 11, 3, 9}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeCol(14)); + + // ------------------ sub-case 5 ------------------ + actShape = new Shape(15, 24); + actData = new Complex128[]{new Complex128(0.50744, 0.22408), + new Complex128(0.03128, 0.09434), + new Complex128(0.67216, 0.71928), + new Complex128(0.26396, 0.84781), + new Complex128(0.74211, 0.78747), + new Complex128(0.53803, 0.80745), + new Complex128(0.45852, 0.08599)}; + actRowIndices = new int[]{0, 1, 2, 3, 5, 11, 13}; + actColIndices = new int[]{5, 5, 5, 5, 5, 5, 5}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(15, 23); + expData = new Complex128[]{}; + expRowIndices = new int[]{}; + expColIndices = new int[]{}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeCol(5)); + + // ------------------ sub-case 6 ------------------ + act = new CooCMatrix(new Shape(3, 9)); + assertThrows(IndexOutOfBoundsException.class, () -> act.removeCol(14)); + assertThrows(IndexOutOfBoundsException.class, () -> act.removeRow(-1)); + } + + + @Test + void removeColsTests() { + // ------------------ sub-case 1 ------------------ + actShape = new Shape(24, 15); + actData = new Complex128[]{new Complex128(0.34666, 0.12705), new Complex128(0.0907, 0.15388), new Complex128(0.95514, 0.67512), new Complex128(0.84671, 0.65075), new Complex128(0.57485, 0.02965), new Complex128(0.00081, 0.82227), new Complex128(0.94886, 0.3042), new Complex128(0.03307, 0.24848), new Complex128(0.82155, 0.00032), new Complex128(0.63371, 0.2528), new Complex128(0.01639, 0.46448), new Complex128(0.3547, 0.83087)}; + actRowIndices = new int[]{0, 0, 4, 6, 9, 9, 10, 14, 15, 16, 19, 19}; + actColIndices = new int[]{13, 14, 2, 12, 3, 4, 6, 2, 9, 6, 4, 5}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(24, 11); + expData = new Complex128[]{new Complex128(0.34666, 0.12705), new Complex128(0.95514, 0.67512), new Complex128(0.84671, 0.65075), new Complex128(0.57485, 0.02965), new Complex128(0.00081, 0.82227), new Complex128(0.94886, 0.3042), new Complex128(0.03307, 0.24848), new Complex128(0.82155, 0.00032), new Complex128(0.63371, 0.2528), new Complex128(0.01639, 0.46448)}; + expRowIndices = new int[]{0, 4, 6, 9, 9, 10, 14, 15, 16, 19}; + expColIndices = new int[]{10, 0, 9, 1, 2, 3, 0, 6, 3, 2}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeCols(14, 5, 0, 1)); + + // ------------------ sub-case 2 ------------------ + actShape = new Shape(24, 15); + actData = new Complex128[]{}; + actRowIndices = new int[]{}; + actColIndices = new int[]{}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(24, 11); + expData = new Complex128[]{}; + expRowIndices = new int[]{}; + expColIndices = new int[]{}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.removeCols(1, 5, 0, 2)); + + // ------------------ sub-case 4 ------------------ + act = new CooCMatrix(new Shape(9, 3)); + assertThrows(IndexOutOfBoundsException.class, () -> act.removeCols(2, 5, 9)); + assertThrows(IndexOutOfBoundsException.class, () -> act.removeCols(0, -1)); + } +} diff --git a/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixReshapeTests.java b/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixReshapeTests.java new file mode 100644 index 000000000..7a9ee26c6 --- /dev/null +++ b/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixReshapeTests.java @@ -0,0 +1,157 @@ +package org.flag4j.complex_coo_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooCMatrix; +import org.flag4j.util.exceptions.TensorShapeException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CooCMatrixReshapeTests { + + Shape expShape, actShape; + Complex128[] expData, actData; + int[] expRowIndices, actRowIndices; + int[] expColIndices, actColIndices; + CooCMatrix exp, act; + + @Test + void flattenTests() { + // ---------------- Sub-case 1 ---------------- + actShape = new Shape(54, 12); + actData = new Complex128[]{new Complex128(1, 2), new Complex128(3, 4), new Complex128(5, 6), + new Complex128(7, 8), new Complex128(9, 10)}; + actRowIndices = new int[]{0, 14, 14, 14, 45}; + actColIndices = new int[]{9, 4, 5, 11, 6}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(1, 54*12); + expData = new Complex128[]{new Complex128(1, 2), new Complex128(3, 4), new Complex128(5, 6), + new Complex128(7, 8), new Complex128(9, 10)}; + expRowIndices = new int[]{0, 0, 0, 0, 0}; + expColIndices = new int[]{9, 14*12 + 4, 14*12 + 5, 14*12 + 11, 45*12 + 6}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.flatten()); + + // ---------------- Sub-case 2 ---------------- + actShape = new Shape(54, 12); + actData = new Complex128[]{}; + actRowIndices = new int[]{}; + actColIndices = new int[]{}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(1, 54*12); + expData = new Complex128[]{}; + expRowIndices = new int[]{}; + expColIndices = new int[]{}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.flatten()); + + // ---------------- Sub-case 3 ---------------- + actShape = new Shape(25, 15); + actData = new Complex128[]{new Complex128(0.2777, 0.94248), new Complex128(0.38635, 0.24736), new Complex128(0.3829, 0.22189), new Complex128(0.49247, 0.31679), new Complex128(0.5719, 0.25363), new Complex128(0.24135, 0.72457), new Complex128(0.83898, 0.67385), new Complex128(0.43352, 0.61757)}; + actRowIndices = new int[]{6, 7, 8, 8, 15, 17, 19, 23}; + actColIndices = new int[]{2, 9, 2, 11, 6, 4, 4, 12}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(1, 375); + expData = new Complex128[]{new Complex128(0.2777, 0.94248), new Complex128(0.38635, 0.24736), new Complex128(0.3829, 0.22189), new Complex128(0.49247, 0.31679), new Complex128(0.5719, 0.25363), new Complex128(0.24135, 0.72457), new Complex128(0.83898, 0.67385), new Complex128(0.43352, 0.61757)}; + expRowIndices = new int[]{0, 0, 0, 0, 0, 0, 0, 0}; + expColIndices = new int[]{92, 114, 122, 131, 231, 259, 289, 357}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.flatten(1)); + + // ---------------- Sub-case 4 ---------------- + actShape = new Shape(12, 54); + actData = new Complex128[]{new Complex128(0.51154, 0.50554), new Complex128(0.79968, 0.71548), new Complex128(0.4903, 0.39299), new Complex128(0.93451, 0.13945), new Complex128(0.89019, 0.7986), new Complex128(0.59208, 0.65129)}; + actRowIndices = new int[]{2, 4, 5, 7, 7, 9}; + actColIndices = new int[]{36, 43, 12, 2, 37, 6}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(648, 1); + expData = new Complex128[]{new Complex128(0.51154, 0.50554), new Complex128(0.79968, 0.71548), new Complex128(0.4903, 0.39299), new Complex128(0.93451, 0.13945), new Complex128(0.89019, 0.7986), new Complex128(0.59208, 0.65129)}; + expRowIndices = new int[]{144, 259, 282, 380, 415, 492}; + expColIndices = new int[]{0, 0, 0, 0, 0, 0}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.flatten(0)); + + // ---------------- Sub-case 5 ---------------- + actShape = new Shape(5, 5); + actData = new Complex128[]{}; + actRowIndices = new int[]{}; + actColIndices = new int[]{}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(25, 1); + expData = new Complex128[]{}; + expRowIndices = new int[]{}; + expColIndices = new int[]{}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.flatten(0)); + } + + + @Test + void reshapeTests() { + // ---------------- Sub-case 1 ---------------- + actShape = new Shape(12, 4); + actData = new Complex128[]{}; + actRowIndices = new int[]{}; + actColIndices = new int[]{}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(3, 16); + expData = new Complex128[]{}; + expRowIndices = new int[]{}; + expColIndices = new int[]{}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.reshape(expShape)); + + // ---------------- Sub-case 2 ---------------- + actShape = new Shape(12, 4); + actData = new Complex128[]{new Complex128(0.34722, 0.89473), new Complex128(0.3139, 0.63988), new Complex128(0.18243, 0.63827), new Complex128(0.34544, 0.55157), new Complex128(0.85195, 0.06837)}; + actRowIndices = new int[]{0, 2, 4, 5, 10}; + actColIndices = new int[]{1, 1, 3, 3, 1}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(3, 16); + expData = new Complex128[]{new Complex128(0.34722, 0.89473), new Complex128(0.3139, 0.63988), new Complex128(0.18243, 0.63827), new Complex128(0.34544, 0.55157), new Complex128(0.85195, 0.06837)}; + expRowIndices = new int[]{0, 0, 1, 1, 2}; + expColIndices = new int[]{1, 9, 3, 7, 9}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.reshape(expShape)); + + // ---------------- Sub-case 3 ---------------- + actShape = new Shape(8, 8); + actData = new Complex128[]{new Complex128(0.07652, 0.43013), new Complex128(0.58944, 0.31322), new Complex128(0.66872, 0.3021)}; + actRowIndices = new int[]{4, 5, 6}; + actColIndices = new int[]{4, 0, 3}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(32, 2); + expData = new Complex128[]{new Complex128(0.07652, 0.43013), new Complex128(0.58944, 0.31322), new Complex128(0.66872, 0.3021)}; + expRowIndices = new int[]{18, 20, 25}; + expColIndices = new int[]{0, 0, 1}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.reshape(expShape)); + + // ---------------- Sub-case 4 ---------------- + actShape = new Shape(8, 8); + actData = new Complex128[]{new Complex128(0.07652, 0.43013), new Complex128(0.58944, 0.31322), new Complex128(0.66872, 0.3021)}; + actRowIndices = new int[]{4, 5, 6}; + actColIndices = new int[]{4, 0, 3}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + assertThrows(TensorShapeException.class, ()->act.reshape(new Shape(4, 5))); + } +} diff --git a/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixSwapRowColTests.java b/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixSwapRowColTests.java new file mode 100644 index 000000000..f07a1953f --- /dev/null +++ b/src/test/java/org/flag4j/complex_coo_matrix/CooCMatrixSwapRowColTests.java @@ -0,0 +1,155 @@ +package org.flag4j.complex_coo_matrix; + +import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.sparse.CooCMatrix; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class CooCMatrixSwapRowColTests { + Shape expShape, actShape; + Complex128[] expData, actData; + int[] expRowIndices, actRowIndices; + int[] expColIndices, actColIndices; + CooCMatrix exp, act; + + @Test + void swapRowTests() { + // ------------------ sub-case 1 ------------------ + actShape = new Shape(42, 15); + actData = new Complex128[]{new Complex128(0.83954, 0.27882), new Complex128(0.49378, 0.45173), new Complex128(0.57154, 0.29647), new Complex128(0.53119, 0.32832), new Complex128(0.2354, 0.52295), new Complex128(0.01532, 0.15957)}; + actRowIndices = new int[]{2, 2, 13, 15, 18, 21}; + actColIndices = new int[]{2, 14, 5, 6, 13, 3}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(42, 15); + expData = new Complex128[]{new Complex128(0.83954, 0.27882), new Complex128(0.49378, 0.45173), new Complex128(0.57154, 0.29647), new Complex128(0.53119, 0.32832), new Complex128(0.2354, 0.52295), new Complex128(0.01532, 0.15957)}; + expRowIndices = new int[]{2, 2, 13, 15, 18, 21}; + expColIndices = new int[]{2, 14, 5, 6, 13, 3}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.copy().swapRows(0, 5)); + assertEquals(exp, act.swapRows(5, 0)); + + // ------------------ sub-case 2 ------------------ + actShape = new Shape(13, 15); + actData = new Complex128[]{new Complex128(0.69961, 0.84185), new Complex128(0.2981, 0.15324), new Complex128(0.38024, 0.52313), new Complex128(0.02953, 0.6424), new Complex128(0.24341, 0.90827), new Complex128(0.34581, 0.51126), new Complex128(0.72608, 0.87854), new Complex128(0.55568, 0.57181), new Complex128(0.37162, 0.5097), new Complex128(0.58337, 0.23366), new Complex128(0.76915, 0.50693), new Complex128(0.4387, 0.26465), new Complex128(0.63875, 0.92151), new Complex128(0.97516, 0.75893), new Complex128(0.5171, 0.82702), new Complex128(0.82809, 0.53152)}; + actRowIndices = new int[]{0, 0, 1, 2, 2, 3, 4, 4, 5, 5, 6, 7, 8, 12, 12, 12}; + actColIndices = new int[]{6, 9, 4, 6, 10, 8, 2, 13, 1, 10, 10, 8, 5, 3, 10, 11}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(13, 15); + expData = new Complex128[]{new Complex128(0.37162, 0.5097), new Complex128(0.58337, 0.23366), new Complex128(0.38024, 0.52313), new Complex128(0.02953, 0.6424), new Complex128(0.24341, 0.90827), new Complex128(0.34581, 0.51126), new Complex128(0.72608, 0.87854), new Complex128(0.55568, 0.57181), new Complex128(0.69961, 0.84185), new Complex128(0.2981, 0.15324), new Complex128(0.76915, 0.50693), new Complex128(0.4387, 0.26465), new Complex128(0.63875, 0.92151), new Complex128(0.97516, 0.75893), new Complex128(0.5171, 0.82702), new Complex128(0.82809, 0.53152)}; + expRowIndices = new int[]{0, 0, 1, 2, 2, 3, 4, 4, 5, 5, 6, 7, 8, 12, 12, 12}; + expColIndices = new int[]{1, 10, 4, 6, 10, 8, 2, 13, 6, 9, 10, 8, 5, 3, 10, 11}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.copy().swapRows(0, 5)); + assertEquals(exp, act.swapRows(5, 0)); + + // ------------------ sub-case 3 ------------------ + actShape = new Shape(13, 15); + actData = new Complex128[]{new Complex128(0.24099, 0.0519), new Complex128(0.41046, 0.01209), new Complex128(0.44403, 0.26415), new Complex128(0.64789, 0.41217), new Complex128(0.34822, 0.81154), new Complex128(0.73094, 0.6667), new Complex128(0.58335, 0.76636), new Complex128(0.3706, 0.00392), new Complex128(0.805, 0.43727), new Complex128(0.64965, 0.42741), new Complex128(0.63493, 0.04803), new Complex128(0.17587, 0.57598), new Complex128(0.2884, 0.13797), new Complex128(0.74812, 0.06963), new Complex128(0.29151, 0.61113), new Complex128(0.79594, 0.98446), new Complex128(0.32668, 0.48901), new Complex128(0.74668, 0.44182), new Complex128(0.63456, 0.4588), new Complex128(0.2882, 0.3568)}; + actRowIndices = new int[]{1, 2, 3, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10, 10, 11, 11, 12, 12, 12, 12}; + actColIndices = new int[]{5, 1, 9, 14, 13, 14, 7, 12, 14, 8, 10, 8, 1, 2, 3, 12, 0, 1, 7, 8}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(13, 15); + expData = new Complex128[]{new Complex128(0.24099, 0.0519), new Complex128(0.41046, 0.01209), new Complex128(0.44403, 0.26415), new Complex128(0.64789, 0.41217), new Complex128(0.34822, 0.81154), new Complex128(0.32668, 0.48901), new Complex128(0.74668, 0.44182), new Complex128(0.63456, 0.4588), new Complex128(0.2882, 0.3568), new Complex128(0.58335, 0.76636), new Complex128(0.3706, 0.00392), new Complex128(0.805, 0.43727), new Complex128(0.64965, 0.42741), new Complex128(0.63493, 0.04803), new Complex128(0.17587, 0.57598), new Complex128(0.2884, 0.13797), new Complex128(0.74812, 0.06963), new Complex128(0.29151, 0.61113), new Complex128(0.79594, 0.98446), new Complex128(0.73094, 0.6667)}; + expRowIndices = new int[]{1, 2, 3, 3, 4, 5, 5, 5, 5, 6, 6, 6, 7, 8, 9, 10, 10, 11, 11, 12}; + expColIndices = new int[]{5, 1, 9, 14, 13, 0, 1, 7, 8, 7, 12, 14, 8, 10, 8, 1, 2, 3, 12, 14}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.copy().swapRows(5, 12)); + assertEquals(exp, act.swapRows(12, 5)); + + // ------------------ sub-case 3 ------------------ + actShape = new Shape(13, 15); + actData = new Complex128[]{new Complex128(0.24099, 0.0519), new Complex128(0.41046, 0.01209), new Complex128(0.44403, 0.26415), new Complex128(0.64789, 0.41217), new Complex128(0.34822, 0.81154), new Complex128(0.73094, 0.6667), new Complex128(0.58335, 0.76636), new Complex128(0.3706, 0.00392), new Complex128(0.805, 0.43727), new Complex128(0.64965, 0.42741), new Complex128(0.63493, 0.04803), new Complex128(0.17587, 0.57598), new Complex128(0.2884, 0.13797), new Complex128(0.74812, 0.06963), new Complex128(0.29151, 0.61113), new Complex128(0.79594, 0.98446), new Complex128(0.32668, 0.48901), new Complex128(0.74668, 0.44182), new Complex128(0.63456, 0.4588), new Complex128(0.2882, 0.3568)}; + actRowIndices = new int[]{1, 2, 3, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10, 10, 11, 11, 12, 12, 12, 12}; + actColIndices = new int[]{5, 1, 9, 14, 13, 14, 7, 12, 14, 8, 10, 8, 1, 2, 3, 12, 0, 1, 7, 8}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(13, 15); + expData = new Complex128[]{new Complex128(0.24099, 0.0519), new Complex128(0.41046, 0.01209), new Complex128(0.44403, 0.26415), new Complex128(0.64789, 0.41217), new Complex128(0.34822, 0.81154), new Complex128(0.73094, 0.6667), new Complex128(0.58335, 0.76636), new Complex128(0.3706, 0.00392), new Complex128(0.805, 0.43727), new Complex128(0.64965, 0.42741), new Complex128(0.63493, 0.04803), new Complex128(0.17587, 0.57598), new Complex128(0.2884, 0.13797), new Complex128(0.74812, 0.06963), new Complex128(0.29151, 0.61113), new Complex128(0.79594, 0.98446), new Complex128(0.32668, 0.48901), new Complex128(0.74668, 0.44182), new Complex128(0.63456, 0.4588), new Complex128(0.2882, 0.3568)}; + expRowIndices = new int[]{1, 2, 3, 3, 4, 5, 6, 6, 6, 7, 8, 9, 10, 10, 11, 11, 12, 12, 12, 12}; + expColIndices = new int[]{5, 1, 9, 14, 13, 14, 7, 12, 14, 8, 10, 8, 1, 2, 3, 12, 0, 1, 7, 8}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.copy().swapRows(5, 5)); + assertEquals(exp, act.swapRows(12, 12)); + + // ------------------ sub-case 4 ------------------ + actShape = new Shape(13, 15); + act = new CooCMatrix(actShape); + + assertThrows(IndexOutOfBoundsException.class, () -> act.swapRows(-1, 0)); + assertThrows(IndexOutOfBoundsException.class, () -> act.swapRows(0, -1)); + assertThrows(IndexOutOfBoundsException.class, () -> act.swapRows(1, 14)); + assertThrows(IndexOutOfBoundsException.class, () -> act.swapRows(13, 1)); + } + + + @Test + void swapColTests() { + // ------------------ sub-case 1 ------------------ + actShape = new Shape(15, 13); + actData = new Complex128[]{new Complex128(0.91602, 0.16221), new Complex128(0.96461, 0.28433), new Complex128(0.58055, 0.82438), new Complex128(0.52431, 0.01252), new Complex128(0.29084, 0.55349), new Complex128(0.57376, 0.65649), new Complex128(0.30012, 0.98787), new Complex128(0.04433, 0.47742), new Complex128(0.74451, 0.58319), new Complex128(0.03021, 0.14801), new Complex128(0.02442, 0.49737), new Complex128(0.79838, 0.67637), new Complex128(0.89686, 0.42408), new Complex128(0.50349, 0.86714), new Complex128(0.06367, 0.22256), new Complex128(0.33357, 0.6249), new Complex128(0.69218, 0.04192), new Complex128(0.76516, 0.11117), new Complex128(0.24567, 0.35127), new Complex128(0.56763, 0.58676)}; + actRowIndices = new int[]{0, 0, 1, 3, 4, 5, 5, 6, 6, 6, 8, 9, 10, 10, 11, 12, 13, 14, 14, 14}; + actColIndices = new int[]{3, 12, 4, 4, 8, 0, 12, 1, 4, 5, 7, 9, 0, 11, 0, 7, 4, 2, 6, 11}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(15, 13); + expData = new Complex128[]{new Complex128(0.91602, 0.16221), new Complex128(0.96461, 0.28433), new Complex128(0.58055, 0.82438), new Complex128(0.52431, 0.01252), new Complex128(0.29084, 0.55349), new Complex128(0.57376, 0.65649), new Complex128(0.30012, 0.98787), new Complex128(0.03021, 0.14801), new Complex128(0.04433, 0.47742), new Complex128(0.74451, 0.58319), new Complex128(0.02442, 0.49737), new Complex128(0.79838, 0.67637), new Complex128(0.89686, 0.42408), new Complex128(0.50349, 0.86714), new Complex128(0.06367, 0.22256), new Complex128(0.33357, 0.6249), new Complex128(0.69218, 0.04192), new Complex128(0.76516, 0.11117), new Complex128(0.24567, 0.35127), new Complex128(0.56763, 0.58676)}; + expRowIndices = new int[]{0, 0, 1, 3, 4, 5, 5, 6, 6, 6, 8, 9, 10, 10, 11, 12, 13, 14, 14, 14}; + expColIndices = new int[]{3, 12, 4, 4, 8, 5, 12, 0, 1, 4, 7, 9, 5, 11, 5, 7, 4, 2, 6, 11}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.copy().swapCols(0, 5)); + assertEquals(exp, act.swapCols(5, 0)); + + // ------------------ sub-case 2 ------------------ + actShape = new Shape(15, 13); + actData = new Complex128[]{new Complex128(0.11608, 0.61874), new Complex128(0.46118, 0.28252), new Complex128(0.58771, 0.70903), new Complex128(0.77484, 0.16391), new Complex128(0.29803, 0.72315), new Complex128(0.41697, 0.09332), new Complex128(0.16798, 0.65098), new Complex128(0.04629, 0.22586), new Complex128(0.4762, 0.55212), new Complex128(0.88635, 0.57277), new Complex128(0.98888, 0.28049), new Complex128(0.03914, 0.90305), new Complex128(0.44566, 0.92861), new Complex128(0.69347, 0.70765), new Complex128(0.16541, 0.37864), new Complex128(0.02777, 0.03835), new Complex128(0.08825, 0.34296), new Complex128(0.3324, 0.12397), new Complex128(0.79772, 0.86073), new Complex128(0.11217, 0.66587)}; + actRowIndices = new int[]{0, 0, 0, 2, 2, 3, 4, 4, 4, 4, 5, 5, 7, 7, 9, 10, 11, 12, 12, 13}; + actColIndices = new int[]{6, 7, 11, 4, 7, 10, 5, 8, 11, 12, 1, 8, 7, 10, 6, 5, 4, 5, 8, 7}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(15, 13); + expData = new Complex128[]{new Complex128(0.11608, 0.61874), new Complex128(0.46118, 0.28252), new Complex128(0.58771, 0.70903), new Complex128(0.77484, 0.16391), new Complex128(0.29803, 0.72315), new Complex128(0.41697, 0.09332), new Complex128(0.88635, 0.57277), new Complex128(0.04629, 0.22586), new Complex128(0.4762, 0.55212), new Complex128(0.16798, 0.65098), new Complex128(0.98888, 0.28049), new Complex128(0.03914, 0.90305), new Complex128(0.44566, 0.92861), new Complex128(0.69347, 0.70765), new Complex128(0.16541, 0.37864), new Complex128(0.02777, 0.03835), new Complex128(0.08825, 0.34296), new Complex128(0.79772, 0.86073), new Complex128(0.3324, 0.12397), new Complex128(0.11217, 0.66587)}; + expRowIndices = new int[]{0, 0, 0, 2, 2, 3, 4, 4, 4, 4, 5, 5, 7, 7, 9, 10, 11, 12, 12, 13}; + expColIndices = new int[]{6, 7, 11, 4, 7, 10, 5, 8, 11, 12, 1, 8, 7, 10, 6, 12, 4, 8, 12, 7}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.copy().swapCols(12, 5)); + assertEquals(exp, act.swapCols(5, 12)); + + + // ------------------ sub-case 2 ------------------ + actShape = new Shape(15, 13); + actData = new Complex128[]{new Complex128(0.11608, 0.61874), new Complex128(0.46118, 0.28252), new Complex128(0.58771, 0.70903), new Complex128(0.77484, 0.16391), new Complex128(0.29803, 0.72315), new Complex128(0.41697, 0.09332), new Complex128(0.16798, 0.65098), new Complex128(0.04629, 0.22586), new Complex128(0.4762, 0.55212), new Complex128(0.88635, 0.57277), new Complex128(0.98888, 0.28049), new Complex128(0.03914, 0.90305), new Complex128(0.44566, 0.92861), new Complex128(0.69347, 0.70765), new Complex128(0.16541, 0.37864), new Complex128(0.02777, 0.03835), new Complex128(0.08825, 0.34296), new Complex128(0.3324, 0.12397), new Complex128(0.79772, 0.86073), new Complex128(0.11217, 0.66587)}; + actRowIndices = new int[]{0, 0, 0, 2, 2, 3, 4, 4, 4, 4, 5, 5, 7, 7, 9, 10, 11, 12, 12, 13}; + actColIndices = new int[]{6, 7, 11, 4, 7, 10, 5, 8, 11, 12, 1, 8, 7, 10, 6, 5, 4, 5, 8, 7}; + act = new CooCMatrix(actShape, actData, actRowIndices, actColIndices); + + expShape = new Shape(15, 13); + expData = new Complex128[]{new Complex128(0.11608, 0.61874), new Complex128(0.46118, 0.28252), new Complex128(0.58771, 0.70903), new Complex128(0.77484, 0.16391), new Complex128(0.29803, 0.72315), new Complex128(0.41697, 0.09332), new Complex128(0.16798, 0.65098), new Complex128(0.04629, 0.22586), new Complex128(0.4762, 0.55212), new Complex128(0.88635, 0.57277), new Complex128(0.98888, 0.28049), new Complex128(0.03914, 0.90305), new Complex128(0.44566, 0.92861), new Complex128(0.69347, 0.70765), new Complex128(0.16541, 0.37864), new Complex128(0.02777, 0.03835), new Complex128(0.08825, 0.34296), new Complex128(0.3324, 0.12397), new Complex128(0.79772, 0.86073), new Complex128(0.11217, 0.66587)}; + expRowIndices = new int[]{0, 0, 0, 2, 2, 3, 4, 4, 4, 4, 5, 5, 7, 7, 9, 10, 11, 12, 12, 13}; + expColIndices = new int[]{6, 7, 11, 4, 7, 10, 5, 8, 11, 12, 1, 8, 7, 10, 6, 5, 4, 5, 8, 7}; + exp = new CooCMatrix(expShape, expData, expRowIndices, expColIndices); + + assertEquals(exp, act.copy().swapCols(5, 5)); + assertEquals(exp, act.swapCols(12, 12)); + + // ------------------ sub-case 3 ------------------ + actShape = new Shape(15, 13); + act = new CooCMatrix(actShape); + + assertThrows(IndexOutOfBoundsException.class, () -> act.swapCols(-1, 0)); + assertThrows(IndexOutOfBoundsException.class, () -> act.swapCols(0, -1)); + assertThrows(IndexOutOfBoundsException.class, () -> act.swapCols(1, 14)); + assertThrows(IndexOutOfBoundsException.class, () -> act.swapCols(13, 1)); + } +} diff --git a/src/test/java/org/flag4j/complex_matrix/CMatrixCsrMatMultTests.java b/src/test/java/org/flag4j/complex_matrix/CMatrixCsrMatMultTests.java deleted file mode 100644 index 31942d960..000000000 --- a/src/test/java/org/flag4j/complex_matrix/CMatrixCsrMatMultTests.java +++ /dev/null @@ -1,93 +0,0 @@ -package org.flag4j.complex_matrix; - -import org.flag4j.algebraic_structures.Complex128; -import org.flag4j.arrays.Shape; -import org.flag4j.arrays.dense.CMatrix; -import org.flag4j.arrays.sparse.CsrMatrix; -import org.flag4j.util.exceptions.LinearAlgebraException; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -class CMatrixCsrMatMultTests { - CMatrix A; - Complex128[][] aEntries; - - CsrMatrix B; - Shape bShape; - double[] bEntries; - int[] bRowPointers; - int[] bColIndices; - - CMatrix exp; - Complex128[][] expEntries; - - - @Test - void standardTests() { - // ------------------------ Sub-case 1 ------------------------ - aEntries = new Complex128[][]{ - {new Complex128("0.1013+0.5667i"), new Complex128("0.56204+0.08795i"), new Complex128("0.26919+0.74589i"), new Complex128("0.40605+0.37181i"), new Complex128("0.41665+0.56373i")}, - {new Complex128("0.88631+0.70806i"), new Complex128("0.94873+0.70971i"), new Complex128("0.73508+0.92932i"), new Complex128("0.32551+0.08181i"), new Complex128("0.80165+0.87963i")}, - {new Complex128("0.01923+0.20639i"), new Complex128("0.01025+0.53356i"), new Complex128("0.77862+0.04428i"), new Complex128("0.24381+0.01189i"), new Complex128("0.5903+0.51795i")}, - {new Complex128("0.65994+0.40064i"), new Complex128("0.21257+0.16288i"), new Complex128("0.85927+0.11806i"), new Complex128("0.8716+0.70231i"), new Complex128("0.83819+0.21429i")}, - {new Complex128("0.29866+0.71364i"), new Complex128("0.46553+0.89626i"), new Complex128("0.25626+0.07154i"), new Complex128("0.27779+0.59077i"), new Complex128("0.3676+0.45885i")}}; - A = new CMatrix(aEntries); - bShape = new Shape(5, 5); - bEntries = new double[]{0.35336, 0.80623, 0.7923, 0.96503, 0.33155, 0.26233, 0.93305, 0.97519}; - bRowPointers = new int[]{0, 2, 3, 7, 8, 8}; - bColIndices = new int[]{1, 4, 4, 0, 1, 2, 3, 4}; - B = new CsrMatrix(bShape, bEntries, bRowPointers, bColIndices); - expEntries = new Complex128[][]{ - {new Complex128("0.2597764257+0.7198062267i"), new Complex128("0.1250453125+0.4475489415i"), new Complex128("0.0706166127+0.1956693237i"), new Complex128("0.2511677295+0.6959526645i"), new Complex128("0.9229512905+0.8891587199i")}, - {new Complex128("0.7093742524+0.8968216796i"), new Complex128("0.5569022755999999+0.5583161276i"), new Complex128("0.1928335364+0.2437885156i"), new Complex128("0.685866394+0.8671020260000001i"), new Complex128("1.7836825872+1.2129427407000002i")}, - {new Complex128("0.7513916586+0.0427315284i"), new Complex128("0.26494657380000003+0.0876110044i"), new Complex128("0.2042553846+0.0116159724i"), new Complex128("0.726491391+0.041315454i"), new Complex128("0.2613859518+0.6007324068000001i")}, - {new Complex128("0.8292213281+0.1139314418i"), new Complex128("0.5180873669+0.18071294340000002i"), new Complex128("0.2254122991+0.0309706798i"), new Complex128("0.8017418735+0.11015588300000001i"), new Complex128("1.5504582412+1.1369435001000001i")}, - {new Complex128("0.2472985878+0.06903824620000001i"), new Complex128("0.19049750059999998+0.27589091740000005i"), new Complex128("0.0672246858+0.018767088200000004i"), new Complex128("0.239103393+0.066750397i"), new Complex128("0.8805261008999998+1.8615777715i")}}; - exp = new CMatrix(expEntries); - assertEquals(exp, A.mult(B)); - - // ------------------------ Sub-case 2 ------------------------ - aEntries = new Complex128[][]{ - {new Complex128("0.66805+0.1356i"), new Complex128("0.02304+0.45698i"), new Complex128("0.21488+0.42425i")}, - {new Complex128("0.41083+0.05417i"), new Complex128("0.73696+0.55199i"), new Complex128("0.14116+0.82652i")}, - {new Complex128("0.37502+0.66468i"), new Complex128("0.0909+0.08818i"), new Complex128("0.67084+0.0016i")}, - {new Complex128("0.57517+0.93909i"), new Complex128("0.42799+0.32726i"), new Complex128("0.92712+0.44066i")}, - {new Complex128("0.9064+0.65531i"), new Complex128("0.12811+0.44938i"), new Complex128("0.3343+0.45097i")}, - {new Complex128("0.49565+0.04573i"), new Complex128("0.84561+0.56477i"), new Complex128("0.5758+0.99876i")}, - {new Complex128("0.82457+0.04468i"), new Complex128("0.73855+0.10166i"), new Complex128("0.23595+0.92517i")}, - {new Complex128("0.93964+0.6983i"), new Complex128("0.80439+0.29492i"), new Complex128("0.15538+0.7558i")}, - {new Complex128("0.17902+0.59831i"), new Complex128("0.78638+0.43736i"), new Complex128("0.33244+0.23329i")}, - {new Complex128("0.6007+0.16697i"), new Complex128("0.54194+0.05845i"), new Complex128("0.64488+0.21455i")}, - {new Complex128("0.26206+0.22261i"), new Complex128("0.4165+0.64287i"), new Complex128("0.27519+0.17844i")}}; - A = new CMatrix(aEntries); - - bShape = new Shape(3, 10); - bEntries = new double[]{0.58579, 0.44167, 0.79592}; - bRowPointers = new int[]{0, 1, 3, 3}; - bColIndices = new int[]{1, 4, 5}; - B = new CsrMatrix(bShape, bEntries, bRowPointers, bColIndices); - - expEntries = new Complex128[][]{ - {new Complex128("0.0"), new Complex128("0.39133700950000005+0.07943312400000001i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.010176076800000001+0.20183435660000001i"), new Complex128("0.0183379968+0.3637195216i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.2406601057+0.031732244300000004i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.32549312319999996+0.24379742329999998i"), new Complex128("0.5865612031999999+0.43933988079999997i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.21968296580000002+0.38936289720000006i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.040147802999999996+0.038946460599999996i"), new Complex128("0.072349128+0.0701842256i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.3369288343+0.5501095311i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.1890303433+0.1445409242i"), new Complex128("0.34064580079999995+0.2604727792i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.5309600560000001+0.3838740449i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0565823437+0.1984776646i"), new Complex128("0.10196531119999999+0.3576705296i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.2903468135+0.026788176700000003i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.3734805687+0.2494419659i"), new Complex128("0.6730379111999999+0.44951173839999997i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.4830248603+0.0261730972i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.3261953785+0.0449001722i"), new Complex128("0.587826716+0.08091322719999999i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.5504317156+0.40905715700000006i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.35527493130000004+0.13025731640000002i"), new Complex128("0.6402300888+0.2347327264i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.10486812580000002+0.35048401490000003i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.3473204546+0.19316879120000002i"), new Complex128("0.6258955695999999+0.3481035712i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.35188405300000003+0.09780935630000001i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.2393586398+0.025815611500000002i"), new Complex128("0.43134088479999994+0.046521524i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}, - {new Complex128("0.0"), new Complex128("0.15351212740000003+0.1304027119i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.183955555+0.2839363929i"), new Complex128("0.33150068+0.5116730904i"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0"), new Complex128("0.0")}}; - exp = new CMatrix(expEntries); - - assertEquals(exp, A.mult(B)); - - // ------------------------ Sub-case 3 ------------------------ - A = new CMatrix(24, 516); - B = new CsrMatrix(15, 12); - assertThrows(LinearAlgebraException.class, ()->A.mult(B)); - } -} diff --git a/src/test/java/org/flag4j/linalg/decompositions/RealBalanceTest.java b/src/test/java/org/flag4j/linalg/decompositions/RealBalanceTest.java new file mode 100644 index 000000000..9413e2c2a --- /dev/null +++ b/src/test/java/org/flag4j/linalg/decompositions/RealBalanceTest.java @@ -0,0 +1,221 @@ +package org.flag4j.linalg.decompositions; + +import org.flag4j.arrays.Shape; +import org.flag4j.arrays.dense.Matrix; +import org.flag4j.linalg.decompositions.balance.RealBalancer; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RealBalanceTest { + static Shape aShape; + static double[] aData; + static Matrix a; + + static Shape permuteShape; + static double[] permuteData; + static Matrix permute; + static Matrix permuteAct; + + static Shape scaleShape; + static double[] scaleData; + static Matrix scale; + static Matrix scaleAct; + + static Shape permuteScaleShape; + static double[] permuteScaleData; + static Matrix permuteScale; + static Matrix permuteScaleAct; + + static RealBalancer scaler; + static RealBalancer permutor; + static RealBalancer permutorScaler; + + static void applyBalancers() { + permuteAct = permutor.decompose(a).getB(); + scaleAct = scaler.decompose(a).getB(); + permuteScaleAct = permutorScaler.decompose(a).getB(); + } + + + @BeforeAll + static void setUp() { + permutor = new RealBalancer(true, false); + scaler = new RealBalancer(false, true); + permutorScaler = new RealBalancer(true, true); + } + + + @Test + void testRealBalance() { + // ----------------- Sub-case 1 ----------------- + aShape = new Shape(5, 5); + aData = new double[]{0.0, 0.0, 0.0, 100.2331, -140.0, 0.0, 0.0, 1.2, 2.54, 142.0, 0.0, 3.4, 0.0, 4.12, + -10022.2212, 0.0, 0.0, 0.0, 10.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + a = new Matrix(aShape, aData); + + permuteShape = new Shape(5, 5); + permuteData = new double[]{0.0, 3.4, 0.0, 4.12, -10022.2212, 1.2, 0.0, 0.0, 2.54, 142.0, 0.0, 0.0, 0.0, 100.2331, + -140.0, 0.0, 0.0, 0.0, 10.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + permute = new Matrix(permuteShape, permuteData); + + scaleShape = new Shape(5, 5); + scaleData = new double[]{0.0, 0.0, 0.0, 12.5291375, -140.0, 0.0, 0.0, 2.4, 0.00031005859375, 0.138671875, 0.0, + 1.7, 0.0, 0.00025146484375, -4.8936626953125, 0.0, 0.0, 0.0, 10.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + scale = new Matrix(scaleShape, scaleData); + + permuteScaleShape = new Shape(5, 5); + permuteScaleData = new double[]{0.0, 1.7, 0.0, 2.06, -5011.1106, 2.4, 0.0, 0.0, 2.54, 142.0, 0.0, 0.0, 0.0, 100.2331, + -140.0, 0.0, 0.0, 0.0, 10.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}; + permuteScale = new Matrix(permuteScaleShape, permuteScaleData); + + applyBalancers(); + + assertEquals(permute, permuteAct); + assertEquals(scale, scaleAct); + assertEquals(permuteScale, permuteScaleAct); + + // ----------------- Sub-case 2 ----------------- + aShape = new Shape(8, 8); + aData = new double[]{6126.624252232483, 5230.904223529219, 4678.687283176287, 6395.835360234681, 5878.74954814557, + 7661.393200389235, 1826.4770608063936, 1726.6184250862443, 4122.877054785548, 1567.8469057995724, + 1744.7348734829984, 4521.044175276936, 2580.3262607639517, 5347.3621114901225, 2831.919283087424, + 6653.323956023176, 228.24242258927816, 237.69014045114713, 413.4411312454667, -7.448499599173523, + 678.4904375130443, 531.07217089675, 1012.9861688923027, 1016.378925274297, 1996.796215054365, 3739.871106883846, + 3730.836820889853, 3176.9808955419717, 986.9301184998963, 1315.6826089152178, 274.6334002238601, 508.9702755605151, + 97.04662999347228, 4135.047978558794, 267.844896567297, 112.08773771250591, 4308.866135869701, 2458.791420115115, + 3168.1532792779053, 4417.578092703147, 5062.9928116183555, 888.9483506139015, 1633.3083054064934, 1468.33492214697, + 4185.479384919863, 4687.930261898332, 3544.9689149538503, 1394.9297307361285, 622.9511113246125, 302.6582576726524, + 4511.255493925161, 1317.27667526414, 358.102095933098, 4519.853248988414, 4755.999451680448, 4168.373649093314, + 2604.2069814615215, 2041.019729854635, 866.1523849592389, 4646.076957590683, 704.2815854464203, 1084.4385900164125, + 2294.4890418962063, 6013.927731024033}; + a = new Matrix(aShape, aData); + + permuteShape = new Shape(8, 8); + permuteData = new double[]{6126.624252232483, 5230.904223529219, 4678.687283176287, 6395.835360234681, 5878.74954814557, + 7661.393200389235, 1826.4770608063936, 1726.6184250862443, 4122.877054785548, 1567.8469057995724, 1744.7348734829984, + 4521.044175276936, 2580.3262607639517, 5347.3621114901225, 2831.919283087424, 6653.323956023176, 228.24242258927816, + 237.69014045114713, 413.4411312454667, -7.448499599173523, 678.4904375130443, 531.07217089675, 1012.9861688923027, + 1016.378925274297, 1996.796215054365, 3739.871106883846, 3730.836820889853, 3176.9808955419717, 986.9301184998963, + 1315.6826089152178, 274.6334002238601, 508.9702755605151, 97.04662999347228, 4135.047978558794, 267.844896567297, + 112.08773771250591, 4308.866135869701, 2458.791420115115, 3168.1532792779053, 4417.578092703147, 5062.9928116183555, + 888.9483506139015, 1633.3083054064934, 1468.33492214697, 4185.479384919863, 4687.930261898332, 3544.9689149538503, + 1394.9297307361285, 622.9511113246125, 302.6582576726524, 4511.255493925161, 1317.27667526414, 358.102095933098, + 4519.853248988414, 4755.999451680448, 4168.373649093314, 2604.2069814615215, 2041.019729854635, 866.1523849592389, + 4646.076957590683, 704.2815854464203, 1084.4385900164125, 2294.4890418962063, 6013.927731024033}; + permute = new Matrix(permuteShape, permuteData); + + scaleShape = new Shape(8, 8); + scaleData = new double[]{6126.624252232483, 5230.904223529219, 2339.3436415881433, 6395.835360234681, 5878.74954814557, + 7661.393200389235, 1826.4770608063936, 1726.6184250862443, 4122.877054785548, 1567.8469057995724, 872.3674367414992, + 4521.044175276936, 2580.3262607639517, 5347.3621114901225, 2831.919283087424, 6653.323956023176, 456.4848451785563, + 475.38028090229426, 413.4411312454667, -14.896999198347046, 1356.9808750260886, 1062.1443417935, 2025.9723377846053, + 2032.757850548594, 1996.796215054365, 3739.871106883846, 1865.4184104449264, 3176.9808955419717, 986.9301184998963, + 1315.6826089152178, 274.6334002238601, 508.9702755605151, 97.04662999347228, 4135.047978558794, 133.9224482836485, + 112.08773771250591, 4308.866135869701, 2458.791420115115, 3168.1532792779053, 4417.578092703147, 5062.9928116183555, + 888.9483506139015, 816.6541527032467, 1468.33492214697, 4185.479384919863, 4687.930261898332, 3544.9689149538503, + 1394.9297307361285, 622.9511113246125, 302.6582576726524, 2255.6277469625807, 1317.27667526414, 358.102095933098, + 4519.853248988414, 4755.999451680448, 4168.373649093314, 2604.2069814615215, 2041.019729854635, 433.07619247961946, + 4646.076957590683, 704.2815854464203, 1084.4385900164125, 2294.4890418962063, 6013.927731024033}; + scale = new Matrix(scaleShape, scaleData); + + permuteScaleShape = new Shape(8, 8); + permuteScaleData = new double[]{6126.624252232483, 5230.904223529219, 2339.3436415881433, 6395.835360234681, 5878.74954814557, + 7661.393200389235, 1826.4770608063936, 1726.6184250862443, 4122.877054785548, 1567.8469057995724, 872.3674367414992, + 4521.044175276936, 2580.3262607639517, 5347.3621114901225, 2831.919283087424, 6653.323956023176, 456.4848451785563, + 475.38028090229426, 413.4411312454667, -14.896999198347046, 1356.9808750260886, 1062.1443417935, 2025.9723377846053, + 2032.757850548594, 1996.796215054365, 3739.871106883846, 1865.4184104449264, 3176.9808955419717, 986.9301184998963, + 1315.6826089152178, 274.6334002238601, 508.9702755605151, 97.04662999347228, 4135.047978558794, 133.9224482836485, + 112.08773771250591, 4308.866135869701, 2458.791420115115, 3168.1532792779053, 4417.578092703147, 5062.9928116183555, + 888.9483506139015, 816.6541527032467, 1468.33492214697, 4185.479384919863, 4687.930261898332, 3544.9689149538503, + 1394.9297307361285, 622.9511113246125, 302.6582576726524, 2255.6277469625807, 1317.27667526414, 358.102095933098, + 4519.853248988414, 4755.999451680448, 4168.373649093314, 2604.2069814615215, 2041.019729854635, 433.07619247961946, + 4646.076957590683, 704.2815854464203, 1084.4385900164125, 2294.4890418962063, 6013.927731024033}; + permuteScale = new Matrix(permuteScaleShape, permuteScaleData); + + applyBalancers(); + + assertEquals(permute, permuteAct); + assertEquals(scale, scaleAct); + assertEquals(permuteScale, permuteScaleAct); + + // ----------------- Sub-case 3 ----------------- + aShape = new Shape(11, 11); + aData = new double[]{1e-08, 0.02, 0.0, 5e-05, 30000.0, 0.0, -1000.0, 0.0, 70.0, 0.0, 9000000.0, 5000.0, 100.0, -0.002, 0.0, + 0.0, 0.0, 9e-09, 0.0, 0.0, -300000.0, 200.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1e-09, + 100000000.0, -10.0, 0.0, 0.0, 40.0, 0.0, 0.0, 0.0, -0.0003, 40000.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, -400.0, 0.0, 10.0, + 200.0, 0.0, 200000.0, 0.0, 1e-05, 0.0, 0.0, 0.0, 0.0, 5e-06, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2000000000.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 100.0, 0.0, 0.0, 0.0, 0.005, 0.0002, 1e-07, 0.0, 0.0, 10.0, 0.0, -90000.0, 0.0, 3e-06, 0.0, + 0.0, 0.0, 300000000.0, 0.0, 0.0, 0.0, 0.0005, 0.0, 0.0, 10000.0, 0.0, 0.0, 0.0, 0.0, 1e-08, 200000.0, 7000.0, 0.0, + 0.0, 0.2, 0.0, 0.0, 1000.0, 0.0, 0.0, -400.0, 0.01}; + a = new Matrix(aShape, aData); + + permuteShape = new Shape(11, 11); + permuteData = new double[]{0.0, 0.0, 5e-06, 0.0, 1e-05, 200.0, 0.0, 0.0, 0.0, 200000.0, 0.0, 0.0, 100.0, -300000.0, 0.0, 0.0, + 5000.0, 200.0, 0.0, 0.0, -0.002, 9e-09, 0.0, 0.0005, 1e-08, 0.0, 10000.0, 0.0, 200000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 100000000.0, -10.0, 0.0, 0.0, 40.0, 0.0, 1e-09, 0.0, 0.0, 40000.0, 0.0, 1.0, 0.0, -0.0003, 10.0, 0.0, -400.0, + 0.0, 0.0, 0.0, 0.02, 0.0, 5e-05, 30000.0, 1e-08, 9000000.0, 0.0, 70.0, 0.0, -1000.0, 0.0, 0.0, -400.0, 0.2, 0.0, + 7000.0, 0.01, 0.0, 0.0, 0.0, 1000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0002, 1e-07, 100.0, 0.005, 0.0, 0.0, 0.0, + 0.0, 3e-06, 10.0, 0.0, 0.0, 300000000.0, -90000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2000000000.0}; + permute = new Matrix(permuteShape, permuteData); + + scaleShape = new Shape(11, 11); + scaleData = new double[]{1e-08, 0.005, 0.0, 1.5625e-06, 7500.0, 0.0, -31.25, 0.0, 2.1875, 0.0, 281250.0, 20000.0, 100.0, + -0.00025, 0.0, 0.0, 0.0, 1.125e-09, 0.0, 0.0, -75000.0, 25.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 1e-09, 100000000.0, -80.0, 0.0, 0.0, 80.0, 0.0, 0.0, 0.0, -0.0012, 40000.0, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, + -50.0, 0.0, 1.25, 6400.0, 0.0, 200000.0, 0.0, 8e-05, 0.0, 0.0, 0.0, 0.0, 1e-05, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 2000000000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 50.0, 0.0, 0.0, 0.0, 0.0025, 0.0002, 5e-08, 0.0, 0.0, 320.0, 0.0, + -90000.0, 0.0, 2.4e-05, 0.0, 0.0, 0.0, 300000000.0, 0.0, 0.0, 0.0, 0.002, 0.0, 0.0, 40000.0, 0.0, 0.0, 0.0, 0.0, + 1e-08, 100000.0, 224000.0, 0.0, 0.0, 0.2, 0.0, 0.0, 1000.0, 0.0, 0.0, -800.0, 0.01}; + scale = new Matrix(scaleShape, scaleData); + + permuteScaleShape = new Shape(11, 11); + permuteScaleData = new double[]{0.0, 0.0, 1e-05, 0.0, 8e-05, 6400.0, 0.0, 0.0, 0.0, 200000.0, 0.0, 0.0, 100.0, -75000.0, 0.0, + 0.0, 20000.0, 25.0, 0.0, 0.0, -0.00025, 1.125e-09, 0.0, 0.002, 1e-08, 0.0, 40000.0, 0.0, 100000.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 100000000.0, -80.0, 0.0, 0.0, 0.00244140625, 0.0, 1e-09, 0.0, 0.0, 40000.0, 0.0, 0.125, 0.0, -0.0012, + 1.25, 0.0, -50.0, 0.0, 0.0, 0.0, 0.005, 0.0, 1.5625e-06, 7500.0, 1e-08, 281250.0, 0.0, 2.1875, 0.0, -31.25, 0.0, 0.0, + -800.0, 0.2, 0.0, 224000.0, 0.01, 0.0, 0.0, 0.0, 1000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0002, 0.0016384, + 1638400.0, 81.92, 0.0, 0.0, 0.0, 0.0, 2.4e-05, 320.0, 0.0, 0.0, 300000000.0, -90000.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2000000000.0}; + permuteScale = new Matrix(permuteScaleShape, permuteScaleData); + + applyBalancers(); + + assertEquals(permute, permuteAct); + assertEquals(scale, scaleAct); + assertEquals(permuteScale, permuteScaleAct); + + // ----------------- Sub-case 4 ----------------- + aShape = new Shape(8, 8); + aData = new double[]{3.0, 1.0, 0.0, 1.0, 0.0, 5.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 3.0, 0.0, 4.0, 1.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0}; + a = new Matrix(aShape, aData); + + permuteShape = new Shape(8, 8); + permuteData = new double[]{3.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, + 1.0, 1.0, 0.0, 0.0, 5.0, 0.0, 0.0, 3.0, 1.0, 0.0, 4.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0}; + permute = new Matrix(permuteShape, permuteData); + + scaleShape = new Shape(8, 8); + scaleData = new double[]{3.0, 1.0, 0.0, 2.0, 0.0, 5.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 1.5, 0.0, 2.0, 1.0, 0.0, 2.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.5, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 3.0, 0.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0}; + scale = new Matrix(scaleShape, scaleData); + + permuteScaleShape = new Shape(8, 8); + permuteScaleData = new double[]{3.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 3.0, 1.0, 1.0, 0.0, 0.0, 5.0, 0.0, 0.0, 3.0, 1.0, 0.0, 4.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 6.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0}; + permuteScale = new Matrix(permuteScaleShape, permuteScaleData); + + applyBalancers(); + + assertEquals(permute, permuteAct); + assertEquals(scale, scaleAct); + assertEquals(permuteScale, permuteScaleAct); + } +} diff --git a/src/test/java/org/flag4j/linalg/ops/common/real/MatrixNormTests.java b/src/test/java/org/flag4j/linalg/ops/common/real/MatrixNormTests.java index 64a493a55..b77f23383 100644 --- a/src/test/java/org/flag4j/linalg/ops/common/real/MatrixNormTests.java +++ b/src/test/java/org/flag4j/linalg/ops/common/real/MatrixNormTests.java @@ -2,6 +2,7 @@ import org.flag4j.arrays.dense.Matrix; import org.flag4j.linalg.MatrixNorms; +import org.flag4j.util.exceptions.LinearAlgebraException; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -38,11 +39,13 @@ void infNormTestCase() { @Test void lpNormTestCase() { // ---------------- Sub-case 1 ---------------- - aEntries = new double[][]{{1.1234, 99.234, 0.000123, -9.1}, {-932.45, 551.35, -0.92342, 124.5}}; + aEntries = new double[][]{ + {1.1234, 99.234, 0.000123, -9.1}, + {-932.45, 551.35, -0.92342, 124.5}}; A = new Matrix(aEntries); expNorm = 1094.9348777384303; - assertEquals(expNorm, MatrixNorms.norm(A)); + assertEquals(expNorm, MatrixNorms.norm(A), 1.0e-12); // ---------------- Sub-case 2 ---------------- aEntries = new double[][]{{1.1234, 99.234, 0.000123}, {-932.45, 551.35, -0.92342}, {123.445, 0.00013, 0}}; @@ -52,18 +55,24 @@ void lpNormTestCase() { assertEquals(expNorm, MatrixNorms.norm(A)); // ---------------- Sub-case 3 ---------------- - aEntries = new double[][]{{1.1234, 99.234, 0.000123}, {-932.45, 551.35, -0.92342}, {123.445, 0.00013, 0}}; + aEntries = new double[][]{ + {1.1234, 99.234, 0.000123}, + {-932.45, 551.35, -0.92342}, + {123.445, 0.00013, 0}}; A = new Matrix(aEntries); - expNorm = 1094.7776004801563; + expNorm = 1089.5874942580217; - assertEquals(expNorm, MatrixNorms.norm(A,2)); + assertEquals(expNorm, MatrixNorms.inducedNorm(A,2), 1.0e-12); // ---------------- Sub-case 4 ---------------- - aEntries = new double[][]{{1.1234, 99.234, 0.000123}, {-932.45, 551.35, -0.92342}, {123.445, 0.00013, 0}}; + aEntries = new double[][]{ + {1.1234, 99.234, 0.000123}, + {-932.45, 551.35, -0.92342}, + {123.445, 0.00013, 0}}; A = new Matrix(aEntries); - expNorm = 1708.5260729999998; + expNorm = 1057.0184; - assertEquals(expNorm, MatrixNorms.norm(A,1)); + assertEquals(expNorm, MatrixNorms.inducedNorm(A,1)); // ---------------- Sub-case 5 ---------------- aEntries = new double[][]{{1.1234, 99.234, 0.000123}, {-932.45, 551.35, -0.92342}, {123.445, 0.00013, 0}}; @@ -90,18 +99,18 @@ void lpNormTestCase() { aEntries = new double[][]{{1.1234, 99.234, 0.000123}, {-932.45, 551.35, -0.92342}, {123.445, 0.00013, 0}}; A = new Matrix(aEntries); - assertThrows(IllegalArgumentException.class, ()-> MatrixNorms.norm(A,0)); + assertThrows(LinearAlgebraException.class, ()-> MatrixNorms.inducedNorm(A,0)); // ---------------- Sub-case 10 ---------------- aEntries = new double[][]{{1.1234, 99.234, 0.000123}, {-932.45, 551.35, -0.92342}, {123.445, 0.00013, 0}}; A = new Matrix(aEntries); - assertThrows(IllegalArgumentException.class, ()-> MatrixNorms.norm(A, 0, 1)); + assertThrows(LinearAlgebraException.class, ()-> MatrixNorms.norm(A, 0, 1)); // ---------------- Sub-case 11 ---------------- aEntries = new double[][]{{1.1234, 99.234, 0.000123}, {-932.45, 551.35, -0.92342}, {123.445, 0.00013, 0}}; A = new Matrix(aEntries); - assertThrows(IllegalArgumentException.class, ()-> MatrixNorms.norm(A,1, -12)); + assertThrows(LinearAlgebraException.class, ()-> MatrixNorms.norm(A,1, 0)); } } diff --git a/src/test/java/org/flag4j/linalg/ops/dense/complex/ComplexDenseTransposeTests.java b/src/test/java/org/flag4j/linalg/ops/dense/complex/ComplexDenseTransposeTests.java index 2a1e4e5ee..00b13bb62 100644 --- a/src/test/java/org/flag4j/linalg/ops/dense/complex/ComplexDenseTransposeTests.java +++ b/src/test/java/org/flag4j/linalg/ops/dense/complex/ComplexDenseTransposeTests.java @@ -1,14 +1,16 @@ package org.flag4j.linalg.ops.dense.complex; import org.flag4j.algebraic_structures.Complex128; +import org.flag4j.linalg.ops.dense.ring_ops.DenseRingHermitianTranspose; import org.junit.jupiter.api.Test; -import static org.flag4j.linalg.ops.dense.field_ops.DenseFieldTranspose.*; +import static org.flag4j.linalg.ops.dense.DenseTranspose.*; import static org.junit.jupiter.api.Assertions.assertArrayEquals; class ComplexDenseTransposeTests { Complex128[] A; Complex128[] expTranspose, expTransposeH; + Complex128[] act; int numRows, numCols; @@ -35,15 +37,26 @@ void transposeTestCase() { new Complex128(3).conj(), new Complex128(7).conj(), new Complex128(11).conj(), new Complex128(4).conj(), new Complex128(8).conj(), new Complex128(12).conj()}; - assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols)); + assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols, null)); - assertArrayEquals(expTransposeH, standardMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrentHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrentHerm(A, numRows, numCols)); + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTransposeH, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); // ------------- Sub-case 2 --------------- numRows = 6; @@ -63,15 +76,26 @@ void transposeTestCase() { new Complex128(2).conj(), new Complex128(4).conj(), new Complex128(6).conj(), new Complex128(8).conj(), new Complex128(10).conj(), new Complex128(12).conj()}; - assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols)); + assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols, null)); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTransposeH, act); - assertArrayEquals(expTransposeH, standardMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrentHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrentHerm(A, numRows, numCols)); + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); // ------------- Sub-case 3 --------------- numRows = 12; @@ -91,15 +115,26 @@ void transposeTestCase() { new Complex128(7).conj(), new Complex128(8).conj(), new Complex128(9).conj(), new Complex128(10).conj(), new Complex128(11).conj(), new Complex128(12).conj()}; - assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols)); + assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols, null)); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTransposeH, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); - assertArrayEquals(expTransposeH, standardMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrentHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrentHerm(A, numRows, numCols)); + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); // ------------- Sub-case 3 --------------- numRows = 3; @@ -117,15 +152,26 @@ void transposeTestCase() { new Complex128(2).conj(), new Complex128(5).conj(), new Complex128(8).conj(), new Complex128(3).conj(), new Complex128(6).conj(), new Complex128(9).conj()}; - assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols)); + assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols, null)); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTransposeH, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); - assertArrayEquals(expTransposeH, standardMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrentHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrentHerm(A, numRows, numCols)); + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); // ------------- Sub-case 3 --------------- numRows = 1; @@ -134,14 +180,25 @@ void transposeTestCase() { expTranspose = new Complex128[]{new Complex128(1.13)}; expTransposeH = new Complex128[]{new Complex128(1.13).conj()}; - assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols)); + assertArrayEquals(expTranspose, standardMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrix(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, standardMatrixConcurrent(A, numRows, numCols, null)); + assertArrayEquals(expTranspose, blockedMatrixConcurrent(A, numRows, numCols, null)); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTransposeH, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); + + act = new Complex128[A.length]; + DenseRingHermitianTranspose.standardMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); - assertArrayEquals(expTransposeH, standardMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, standardMatrixConcurrentHerm(A, numRows, numCols)); - assertArrayEquals(expTranspose, blockedMatrixConcurrentHerm(A, numRows, numCols)); + act = new Complex128[A.length]; + DenseRingHermitianTranspose.blockedMatrixConcurrentHerm(A, numRows, numCols, act); + assertArrayEquals(expTranspose, act); } } diff --git a/src/test/java/org/flag4j/linalg/ops/dense/real/MatrixMultiplyTests.java b/src/test/java/org/flag4j/linalg/ops/dense/real/MatrixMultiplyTests.java index ad6cd1e1a..5f9620497 100644 --- a/src/test/java/org/flag4j/linalg/ops/dense/real/MatrixMultiplyTests.java +++ b/src/test/java/org/flag4j/linalg/ops/dense/real/MatrixMultiplyTests.java @@ -23,35 +23,35 @@ void squareTestCase() { B = new Matrix(entriesB); // ------------ Sub-case 1 ------------ - act = RealDenseMatrixMultiplication.standard(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.standard(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 2 ------------ - act = RealDenseMatrixMultiplication.reordered(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.reordered(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 3 ------------ - act = RealDenseMatrixMultiplication.blocked(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.blocked(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 4 ------------ - act = RealDenseMatrixMultiplication.blockedReordered(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.blockedReordered(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 5 ------------ - act = RealDenseMatrixMultiplication.concurrentStandard(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentStandard(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 6 ------------ - act = RealDenseMatrixMultiplication.concurrentReordered(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentReordered(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 7 ------------ - act = RealDenseMatrixMultiplication.concurrentBlocked(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentBlocked(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 8 ------------ - act = RealDenseMatrixMultiplication.concurrentBlockedReordered(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentBlockedReordered(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); } @@ -74,35 +74,35 @@ void rectangleTestCase() { B = new Matrix(entriesB); // ------------ Sub-case 1 ------------ - act = RealDenseMatrixMultiplication.standard(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.standard(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 2 ------------ - act = RealDenseMatrixMultiplication.reordered(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.reordered(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 3 ------------ - act = RealDenseMatrixMultiplication.blocked(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.blocked(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 4 ------------ - act = RealDenseMatrixMultiplication.blockedReordered(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.blockedReordered(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 5 ------------ - act = RealDenseMatrixMultiplication.concurrentStandard(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentStandard(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 6 ------------ - act = RealDenseMatrixMultiplication.concurrentReordered(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentReordered(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 7 ------------ - act = RealDenseMatrixMultiplication.concurrentBlocked(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentBlocked(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 8 ------------ - act = RealDenseMatrixMultiplication.concurrentBlockedReordered(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentBlockedReordered(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); } @@ -118,19 +118,19 @@ void columnVectorTestCase() { B = new Matrix(entriesB); // ------------ Sub-case 1 ------------ - act = RealDenseMatrixMultiplication.standardVector(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.standardVector(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 2 ------------ - act = RealDenseMatrixMultiplication.blockedVector(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.blockedVector(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 3 ------------ - act = RealDenseMatrixMultiplication.concurrentStandardVector(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentStandardVector(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 4 ------------ - act = RealDenseMatrixMultiplication.concurrentBlockedVector(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMult.concurrentBlockedVector(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); } } diff --git a/src/test/java/org/flag4j/linalg/ops/dense/real/RealDenseMatMultTransposeTests.java b/src/test/java/org/flag4j/linalg/ops/dense/real/RealDenseMatMultTransposeTests.java index ed4174831..478351bdb 100644 --- a/src/test/java/org/flag4j/linalg/ops/dense/real/RealDenseMatMultTransposeTests.java +++ b/src/test/java/org/flag4j/linalg/ops/dense/real/RealDenseMatMultTransposeTests.java @@ -19,19 +19,19 @@ void squareTestCase() { exp = A.mult(B.T()).data; // ------------ Sub-case 1 ------------ - act = RealDenseMatrixMultTranspose.multTranspose(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMultTranspose.multTranspose(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 2 ------------ - act = RealDenseMatrixMultTranspose.multTransposeBlocked(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMultTranspose.multTransposeBlocked(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 3 ------------ - act = RealDenseMatrixMultTranspose.multTransposeConcurrent(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMultTranspose.multTransposeConcurrent(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 4 ------------ - act = RealDenseMatrixMultTranspose.multTransposeBlockedConcurrent(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMultTranspose.multTransposeBlockedConcurrent(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); } @@ -52,19 +52,19 @@ void rectangleTestCase() { exp = A.mult(B.T()).data; // ------------ Sub-case 1 ------------ - act = RealDenseMatrixMultTranspose.multTranspose(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMultTranspose.multTranspose(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 2 ------------ - act = RealDenseMatrixMultTranspose.multTransposeBlocked(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMultTranspose.multTransposeBlocked(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 3 ------------ - act = RealDenseMatrixMultTranspose.multTransposeConcurrent(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMultTranspose.multTransposeConcurrent(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); // ------------ Sub-case 4 ------------ - act = RealDenseMatrixMultTranspose.multTransposeBlockedConcurrent(A.data, A.shape, B.data, B.shape); + act = RealDenseMatMultTranspose.multTransposeBlockedConcurrent(A.data, A.shape, B.data, B.shape); assertArrayEquals(exp, act); } } diff --git a/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrDenseMatMultTests.java b/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrDenseMatMultTests.java deleted file mode 100644 index ec48df5b6..000000000 --- a/src/test/java/org/flag4j/sparse_csr_complex_matrix/ComplexCsrDenseMatMultTests.java +++ /dev/null @@ -1,91 +0,0 @@ -package org.flag4j.sparse_csr_complex_matrix; - -import org.flag4j.algebraic_structures.Complex128; -import org.flag4j.arrays.dense.CMatrix; -import org.flag4j.arrays.dense.Matrix; -import org.flag4j.arrays.sparse.CsrCMatrix; -import org.flag4j.util.exceptions.LinearAlgebraException; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -class ComplexCsrDenseMatMultTests { - static CsrCMatrix A; - static CMatrix aDense; - static Complex128[][] aEntries; - static Matrix B; - static double[][] bEntries; - static CMatrix exp; - - private static void build(boolean... args) { - aDense = new CMatrix(aEntries); - A = aDense.toCsr(); - B = new Matrix(bEntries); - if(args.length != 1 || args[0]) exp = aDense.mult(B); - } - - - @Test - void multTests() { - // ---------------------- Sub-case 1 ---------------------- - aEntries = new Complex128[][]{ - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(80.1, 2.5)}, - {new Complex128(0), new Complex128(1.41, -92.2), new Complex128(0), new Complex128(0, 15.5), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(-9.25, 23.5), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(-999.1155, 2.25), new Complex128(-1, 1)}}; - bEntries = new double[][]{ - {0.72773, 0.90836}, - {0.02926, 0.3265}, - {0.23691, 0.77541}, - {0.6462, 0.36597}, - {0.18312, 0.77178}, - {0.40715, 0.35642}}; - build(); - - assertEquals(exp, A.mult(B)); - - // ---------------------- Sub-case 2 ---------------------- - aEntries = new Complex128[][]{ - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(-77.3, -15122.1), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0, 803.2), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(-9.345, 58.1), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(1.45, -23), new Complex128(0)}, - {new Complex128(345), new Complex128(2.4, 5.61), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(4.45, -67.2), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(1)}}; - bEntries = new double[][]{ - {0.72773, 0.90836}, - {0.02926, 0.3265}, - {0.23691, 0.77541}, - {0.6462, 0.36597}}; - build(); - - assertEquals(exp, A.mult(B)); - - // ---------------------- Sub-case 3 ---------------------- - aEntries = new Complex128[][]{ - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(-77.3, -15122.1), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0, 803.2), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(-9.345, 58.1), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(1.45, -23), new Complex128(0)}, - {new Complex128(345), new Complex128(2.4, 5.61), new Complex128(0), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(4.45, -67.2), new Complex128(0)}, - {new Complex128(0), new Complex128(0), new Complex128(0), new Complex128(1)}}; - bEntries = new double[][]{ - {0.72773, 0.90836}, - {0.02926, 0.3265}, - {0.23691, 0.77541}}; - build(false); - - assertThrows(LinearAlgebraException.class, ()->A.mult(B)); - } -} diff --git a/src/test/java/org/flag4j/sparse_matrix/CooMatrixSetColTests.java b/src/test/java/org/flag4j/sparse_matrix/CooMatrixSetColTests.java deleted file mode 100644 index 0197adae0..000000000 --- a/src/test/java/org/flag4j/sparse_matrix/CooMatrixSetColTests.java +++ /dev/null @@ -1,278 +0,0 @@ -package org.flag4j.sparse_matrix; - -import org.flag4j.arrays.Shape; -import org.flag4j.arrays.sparse.CooMatrix; -import org.flag4j.arrays.sparse.CooVector; -import org.flag4j.linalg.ops.sparse.coo.real.RealSparseMatrixGetSet; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -class CooMatrixSetColTests { - - @Test - void setColTest() { - Shape aShape; - int[] aRowIndices; - int[] aColIndices; - double[] aEntries; - CooMatrix a; - - double[] bEntries; - - Shape expShape; - int[] expRowIndices; - int[] expColIndices; - double[] expEntries; - CooMatrix exp; - - // --------------------- Sub-case 1 --------------------- - aShape = new Shape(5, 3); - aEntries = new double[]{0.42216, 0.86886, 0.51801}; - aRowIndices = new int[]{0, 0, 1}; - aColIndices = new int[]{1, 2, 2}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bEntries = new double[]{0.30728, 0.13698, 0.23211, 0.05517, 0.12575}; - - expShape = new Shape(5, 3); - expEntries = new double[]{0.30728, 0.42216, 0.86886, 0.13698, 0.51801, 0.23211, 0.05517, 0.12575}; - expRowIndices = new int[]{0, 0, 0, 1, 1, 2, 3, 4}; - expColIndices = new int[]{0, 1, 2, 0, 2, 0, 0, 0}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setCol(a, 0, bEntries)); - - // --------------------- Sub-case 2 --------------------- - aShape = new Shape(11, 23); - aEntries = new double[]{0.86291, 0.59273, 0.14697, 0.79343, 0.0691}; - aRowIndices = new int[]{4, 5, 6, 8, 10}; - aColIndices = new int[]{14, 3, 9, 15, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bEntries = new double[]{0.09599, 0.03342, 0.08342, 0.86195, 0.18126, 0.71121, 0.03191, 0.3479, 0.5699, 0.35584, 0.51796}; - - expShape = new Shape(11, 23); - expEntries = new double[]{0.09599, 0.03342, 0.08342, 0.86195, 0.86291, 0.18126, 0.59273, 0.71121, 0.14697, 0.03191, 0.3479, 0.79343, 0.5699, 0.35584, 0.0691, 0.51796}; - expRowIndices = new int[]{0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 10}; - expColIndices = new int[]{16, 16, 16, 16, 14, 16, 3, 16, 9, 16, 16, 15, 16, 16, 4, 16}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setCol(a, 16, bEntries)); - - // --------------------- Sub-case 3 --------------------- - aShape = new Shape(5, 1000); - aEntries = new double[]{0.91557, 0.99112, 0.97331, 0.46736, 0.39273, 0.9236, 0.55027, 0.96506, 0.46553}; - aRowIndices = new int[]{0, 1, 2, 3, 3, 4, 4, 4, 4}; - aColIndices = new int[]{118, 335, 419, 424, 880, 134, 358, 492, 949}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bEntries = new double[]{0.86214, 0.01468, 0.80744, 0.38058, 0.27367}; - - expShape = new Shape(5, 1000); - expEntries = new double[]{0.91557, 0.86214, 0.99112, 0.01468, 0.97331, 0.80744, 0.46736, 0.39273, 0.38058, 0.9236, 0.55027, 0.96506, 0.46553, 0.27367}; - expRowIndices = new int[]{0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4}; - expColIndices = new int[]{118, 999, 335, 999, 419, 999, 424, 880, 999, 134, 358, 492, 949, 999}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, RealSparseMatrixGetSet.setCol(a, 999, bEntries)); - - // --------------------- Sub-case 4 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.20695, 0.08553, 0.58839, 0.42649}; - aRowIndices = new int[]{0, 0, 1, 2}; - aColIndices = new int[]{1, 2, 4, 1}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bEntries = new double[]{0.70299, 0.12535, 0.51468}; - - CooMatrix final0a = a; - double[] final0b = bEntries; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setCol(final0a, 6, final0b)); - - // --------------------- Sub-case 5 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.8715, 0.48536, 0.74835, 0.61107}; - aRowIndices = new int[]{1, 1, 2, 2}; - aColIndices = new int[]{2, 4, 2, 3}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bEntries = new double[]{0.18264, 0.50269, 0.62068, 0.68308, 0.25792}; - - CooMatrix final1a = a; - double[] final1b = bEntries; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setCol(final1a, 3, final1b)); - - // --------------------- Sub-case 6 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.97644, 0.04564, 0.1204, 0.19723}; - aRowIndices = new int[]{0, 2, 2, 2}; - aColIndices = new int[]{4, 1, 2, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bEntries = new double[]{0.69692, 0.15703}; - - CooMatrix final2a = a; - double[] final2b = bEntries; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setCol(final2a, 3, final2b)); - - // --------------------- Sub-case 7 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.9503, 0.0484, 0.44488, 0.29844}; - aRowIndices = new int[]{0, 2, 2, 2}; - aColIndices = new int[]{0, 1, 3, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bEntries = new double[]{0.36708, 0.70117, 0.73955}; - - CooMatrix final3a = a; - double[] final3b = bEntries; - assertThrows(Exception.class, ()->RealSparseMatrixGetSet.setCol(final3a, 19, final3b)); - } - - - @Test - void setColSparseVectorTest() { - Shape aShape; - int[] aRowIndices; - int[] aColIndices; - double[] aEntries; - CooMatrix a; - - Shape bShape; - int[] bIndices; - double[] bEntries; - CooVector b; - - Shape expShape; - int[] expRowIndices; - int[] expColIndices; - double[] expEntries; - CooMatrix exp; - - // --------------------- Sub-case 1 --------------------- - aShape = new Shape(5, 3); - aEntries = new double[]{0.69683, 0.7974, 0.01005}; - aRowIndices = new int[]{0, 3, 4}; - aColIndices = new int[]{1, 0, 2}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bShape = new Shape(5); - bEntries = new double[]{0.42925, 0.95116}; - bIndices = new int[]{2, 3}; - b = new CooVector(bShape.get(0), bEntries, bIndices); - - expShape = new Shape(5, 3); - expEntries = new double[]{0.69683, 0.42925, 0.95116, 0.01005}; - expRowIndices = new int[]{0, 2, 3, 4}; - expColIndices = new int[]{1, 0, 0, 2}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, a.setCol(b, 0)); - - // --------------------- Sub-case 2 --------------------- - aShape = new Shape(11, 23); - aEntries = new double[]{0.09879, 0.44944, 0.39054, 0.51234, 0.10826}; - aRowIndices = new int[]{1, 3, 7, 8, 10}; - aColIndices = new int[]{14, 15, 16, 3, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bShape = new Shape(11); - bEntries = new double[]{0.42701, 0.22431, 0.48719, 0.79679}; - bIndices = new int[]{5, 6, 7, 10}; - b = new CooVector(bShape.get(0), bEntries, bIndices); - - expShape = new Shape(11, 23); - expEntries = new double[]{0.09879, 0.44944, 0.42701, 0.22431, 0.48719, 0.51234, 0.10826, 0.79679}; - expRowIndices = new int[]{1, 3, 5, 6, 7, 8, 10, 10}; - expColIndices = new int[]{14, 15, 16, 16, 16, 3, 4, 16}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, a.setCol(b, 16)); - - // --------------------- Sub-case 3 --------------------- - aShape = new Shape(5, 1000); - aEntries = new double[]{0.548, 0.12782, 0.71044, 0.03123, 0.73197, 0.23329, 0.76449, 0.62306, 0.77283}; - aRowIndices = new int[]{0, 1, 1, 1, 2, 2, 2, 3, 4}; - aColIndices = new int[]{663, 597, 620, 926, 73, 153, 627, 66, 743}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bShape = new Shape(5); - bEntries = new double[]{0.92473, 0.36888}; - bIndices = new int[]{1, 4}; - b = new CooVector(bShape.get(0), bEntries, bIndices); - - expShape = new Shape(5, 1000); - expEntries = new double[]{0.548, 0.12782, 0.71044, 0.03123, 0.92473, 0.73197, 0.23329, 0.76449, 0.62306, 0.77283, 0.36888}; - expRowIndices = new int[]{0, 1, 1, 1, 1, 2, 2, 2, 3, 4, 4}; - expColIndices = new int[]{663, 597, 620, 926, 999, 73, 153, 627, 66, 743, 999}; - exp = new CooMatrix(expShape, expEntries, expRowIndices, expColIndices); - - assertEquals(exp, a.setCol(b, 999)); - - // --------------------- Sub-case 4 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.38374, 0.24165, 0.20689, 0.73343}; - aRowIndices = new int[]{0, 1, 2, 2}; - aColIndices = new int[]{4, 0, 0, 1}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bShape = new Shape(3); - bEntries = new double[]{0.93917}; - bIndices = new int[]{2}; - b = new CooVector(bShape.get(0), bEntries, bIndices); - - CooMatrix final0a = a; - CooVector final0b = b; - assertThrows(Exception.class, ()->final0a.setCol(final0b, 6)); - - // --------------------- Sub-case 5 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.52077, 0.42897, 0.35701, 0.94909}; - aRowIndices = new int[]{0, 1, 2, 2}; - aColIndices = new int[]{2, 2, 0, 3}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bShape = new Shape(5); - bEntries = new double[]{0.41526, 0.41046}; - bIndices = new int[]{0, 2}; - b = new CooVector(bShape.get(0), bEntries, bIndices); - - CooMatrix final1a = a; - CooVector final1b = b; - assertThrows(Exception.class, ()->final1a.setCol(final1b, 3)); - - // --------------------- Sub-case 6 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.89024, 0.42578, 0.66571, 0.53301}; - aRowIndices = new int[]{0, 1, 1, 2}; - aColIndices = new int[]{2, 0, 1, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bShape = new Shape(2); - bEntries = new double[]{0.55374}; - bIndices = new int[]{1}; - b = new CooVector(bShape.get(0), bEntries, bIndices); - - CooMatrix final2a = a; - CooVector final2b = b; - assertThrows(Exception.class, ()->final2a.setCol(final2b, 3)); - - // --------------------- Sub-case 7 --------------------- - aShape = new Shape(3, 5); - aEntries = new double[]{0.74812, 0.07704, 0.80715, 0.45783}; - aRowIndices = new int[]{0, 1, 2, 2}; - aColIndices = new int[]{0, 1, 3, 4}; - a = new CooMatrix(aShape, aEntries, aRowIndices, aColIndices); - - bShape = new Shape(3); - bEntries = new double[]{0.838}; - bIndices = new int[]{0}; - b = new CooVector(bShape.get(0), bEntries, bIndices); - - CooMatrix final3a = a; - CooVector final3b = b; - assertThrows(Exception.class, ()->final3a.setCol(final3b, 19)); - } -} diff --git a/src/test/java/org/flag4j/tensor/TensorNormTests.java b/src/test/java/org/flag4j/tensor/TensorNormTests.java deleted file mode 100644 index f1b5d422f..000000000 --- a/src/test/java/org/flag4j/tensor/TensorNormTests.java +++ /dev/null @@ -1,52 +0,0 @@ -package org.flag4j.tensor; - -import org.flag4j.arrays.Shape; -import org.flag4j.arrays.dense.Tensor; -import org.flag4j.linalg.TensorNorms; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class TensorNormTests { - - static Shape aShape; - static double[] aEntries; - static Tensor A; - static double exp; - - @BeforeAll - static void setup() { - aShape = new Shape(2, 3, 1, 2); - aEntries = new double[]{1, -1.4133, 113.4, 0.4, 11.3, 445, 133.445, 9.8, 13384, -993.44, 11, 12}; - A = new Tensor(aShape, aEntries); - } - - -// @Test -// void infNormTestCase() { -// // ------------------------- Sub-case 1 ------------------------- -// exp = 13384; -// assertEquals(exp, TensorNorms.infNorm(A)); -// } - - - @Test - void normTestCase() { - // ------------------------- Sub-case 1 ------------------------- - exp = 13429.354528384523; - assertEquals(exp, TensorNorms.norm(A)); - } - - - @Test - void pnormTestCase() { - // ------------------------- Sub-case 1 ------------------------- - exp = 13429.354528384523; - assertEquals(exp, TensorNorms.norm(A, 2)); - - // ------------------------- Sub-case 2 ------------------------- - exp = 13384.105704217562; - assertEquals(exp, TensorNorms.norm(A, 4)); - } -} diff --git a/src/test/java/org/flag4j/util/ErrorMessagesTests.java b/src/test/java/org/flag4j/util/ErrorMessagesTests.java index 72d3d3a5a..6cded4a91 100644 --- a/src/test/java/org/flag4j/util/ErrorMessagesTests.java +++ b/src/test/java/org/flag4j/util/ErrorMessagesTests.java @@ -17,7 +17,7 @@ void EqualShapeErrMsgTestCase() { // --------- sub-case 1 --------- s1 = new Shape(2); s2 = new Shape(5); - expMsg = String.format("Expecting matrices to have the same shape but got shapes %s and %s.", + expMsg = String.format("Expecting matrices to have the same shape but got %s and %s.", "(2)", "(5)"); assertEquals(expMsg, ErrorMessages.equalShapeErrMsg(s1, s2)); @@ -26,7 +26,7 @@ void EqualShapeErrMsgTestCase() { // --------- sub-case 2 --------- s1 = new Shape(1, 2, 3, 4); s2 = new Shape(4, 3, 2, 1); - expMsg = String.format("Expecting matrices to have the same shape but got shapes %s and %s.", + expMsg = String.format("Expecting matrices to have the same shape but got %s and %s.", "(1, 2, 3, 4)", "(4, 3, 2, 1)"); assertEquals(expMsg, ErrorMessages.equalShapeErrMsg(s1, s2)); @@ -38,33 +38,12 @@ void matMultShapeErrMsgTestCase() { // --------- sub-case 1 --------- s1 = new Shape(10, 5); s2 = new Shape(14, 4); - expMsg = String.format("Expecting the number of columns in the first matrix to match the number " + - "rows/length in the second matrix/vector but got shapes (10, 5) and (14, 4)."); + expMsg = String.format("Cannot multiply matrices/vector with shapes (10, 5) and (14, 4)."); assertEquals(expMsg, ErrorMessages.matMultShapeErrMsg(s1, s2)); } - @Test - void vecRowOrientErrMsgTestCase() { - s1 = new Shape(14, 1); - expMsg = String.format("Expecting vector to be a row vector but got a vector with shape %s.", - s1); - - assertEquals(expMsg, ErrorMessages.vecRowOrientErrMsg(s1)); - } - - - @Test - void vecColOrientErrMsgTestCase() { - s1 = new Shape(1, 31); - expMsg = String.format("Expecting vector to be a column vector but got a row vector with shape %s.", - s1); - - assertEquals(expMsg, ErrorMessages.vecColOrientErrMsg(s1)); - } - - @Test void negativeDimErrMsgTestCase() { int[] dims = {-1, 1};