Skip to content

Commit

Permalink
Adding linear algebra serialization tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Jul 20, 2023
1 parent c254340 commit 5bb7fbb
Show file tree
Hide file tree
Showing 11 changed files with 128 additions and 3 deletions.
23 changes: 22 additions & 1 deletion Math/src/test/java/org/tribuo/math/la/DenseMatrixTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,11 +23,18 @@
import static org.junit.jupiter.api.Assertions.fail;
import static org.tribuo.math.la.DenseVectorTest.makeMalformedProto;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;
import java.util.Random;

import org.junit.jupiter.api.Test;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.test.Helpers;

/**
* Matrices used -
Expand Down Expand Up @@ -1444,6 +1451,20 @@ public void matrixVectorTest() {
assertEquals(matrixMatrixOutput,matrixVectorOutput);
}

@Test
public void serialization431Test() throws URISyntaxException, IOException {
Path matrixPath = Paths.get(DenseMatrixTest.class.getResource("dense-matrix-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(matrixPath)) {
TensorProto proto = TensorProto.parseFrom(fis);
Tensor matrix = Tensor.deserialize(proto);
assertEquals(generateA(), matrix);
}
}

public void generateProtobuf() throws IOException {
Helpers.writeProtobuf(generateA(), Paths.get("src","test","resources","org","tribuo","math","la","dense-matrix-431.tribuo"));
}

@Test
public void serializationTest() {
DenseMatrix a = generateA();
Expand Down
24 changes: 24 additions & 0 deletions Math/src/test/java/org/tribuo/math/la/DenseSparseMatrixTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@

import org.junit.jupiter.api.Test;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.test.Helpers;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -53,6 +61,22 @@ public void testGetColumn() {
assertEquals(0, column.get(2));
}

@Test
public void serialization431Test() throws URISyntaxException, IOException {
Path matrixPath = Paths.get(DenseSparseMatrixTest.class.getResource("densesparse-matrix-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(matrixPath)) {
TensorProto proto = TensorProto.parseFrom(fis);
Tensor matrix = Tensor.deserialize(proto);
DenseSparseMatrix a = DenseSparseMatrix.createDiagonal(new DenseVector(new double[]{1,2,3,4,5,6}));
assertEquals(a, matrix);
}
}

public void generateProtobuf() throws IOException {
DenseSparseMatrix a = DenseSparseMatrix.createDiagonal(new DenseVector(new double[]{1,2,3,4,5,6}));
Helpers.writeProtobuf(a, Paths.get("src","test","resources","org","tribuo","math","la","densesparse-matrix-431.tribuo"));
}

@Test
public void serializationTest() {
DenseSparseMatrix a = DenseSparseMatrix.createDiagonal(new DenseVector(new double[]{1,2,3,4,5,6}));
Expand Down
22 changes: 21 additions & 1 deletion Math/src/test/java/org/tribuo/math/la/DenseVectorTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,9 +30,15 @@
import org.tribuo.test.MockOutputFactory;
import org.tribuo.util.MeanVarianceAccumulator;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.function.DoubleUnaryOperator;
Expand Down Expand Up @@ -302,6 +308,20 @@ public void subtract() {
assertEquals(cSubB, c.subtract(b), "C - B");
}

@Test
public void serialization431Test() throws URISyntaxException, IOException {
Path vectorPath = Paths.get(DenseVectorTest.class.getResource("dense-vector-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(vectorPath)) {
TensorProto proto = TensorProto.parseFrom(fis);
Tensor vector = Tensor.deserialize(proto);
assertEquals(generateVectorA(), vector);
}
}

public void generateProtobuf() throws IOException {
Helpers.writeProtobuf(generateVectorA(), Paths.get("src","test","resources","org","tribuo","math","la","dense-vector-431.tribuo"));
}

@Test
public void serializationTest() {
DenseVector a = generateVectorA();
Expand Down
23 changes: 22 additions & 1 deletion Math/src/test/java/org/tribuo/math/la/SparseVectorTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -29,16 +29,23 @@
import org.tribuo.impl.ListExample;
import org.tribuo.math.protos.SparseTensorProto;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.test.Helpers;
import org.tribuo.test.MockDataSourceProvenance;
import org.tribuo.test.MockOutput;
import org.tribuo.test.MockOutputFactory;
import org.tribuo.util.Util;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.IntBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
Expand Down Expand Up @@ -739,6 +746,20 @@ private static int[] slowIntersection(int[] firstIndices, int[] secondIndices) {
return Util.toPrimitiveInt(intersectIndices);
}

@Test
public void serialization431Test() throws URISyntaxException, IOException {
Path vectorPath = Paths.get(SparseVectorTest.class.getResource("sparse-vector-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(vectorPath)) {
TensorProto proto = TensorProto.parseFrom(fis);
Tensor vector = Tensor.deserialize(proto);
assertEquals(generateVectorA(), vector);
}
}

public void generateProtobuf() throws IOException {
Helpers.writeProtobuf(generateVectorA(), Paths.get("src","test","resources","org","tribuo","math","la","sparse-vector-431.tribuo"));
}

@Test
public void serializationTest() {
SparseVector a = generateVectorA();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,51 @@
import org.tribuo.math.la.DenseMatrixTest;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.DenseVectorTest;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.test.Helpers;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.test.Helpers.testProtoSerialization;

public class ShrinkingTensorTest {

@Test
public void serialization431Test() throws URISyntaxException, IOException {
Path matrixPath = Paths.get(ShrinkingTensorTest.class.getResource("shrinking-matrix-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(matrixPath)) {
TensorProto proto = TensorProto.parseFrom(fis);
Tensor matrix = Tensor.deserialize(proto);
DenseMatrix a = DenseMatrixTest.generateA();
ShrinkingMatrix sh = new ShrinkingMatrix(a,0.1,true);
assertEquals(sh, matrix);
}
Path vectorPath = Paths.get(ShrinkingTensorTest.class.getResource("shrinking-vector-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(vectorPath)) {
TensorProto proto = TensorProto.parseFrom(fis);
Tensor vector = Tensor.deserialize(proto);
DenseVector a = DenseVectorTest.generateVectorA();
ShrinkingVector sh = new ShrinkingVector(a,0.1,true);
assertEquals(sh, vector);
}
}

public void generateProtobuf() throws IOException {
DenseMatrix aMatrix = DenseMatrixTest.generateA();
ShrinkingMatrix shMatrix = new ShrinkingMatrix(aMatrix,0.1,true);
Helpers.writeProtobuf(shMatrix, Paths.get("src","test","resources","org","tribuo","math","optimizers","util","shrinking-matrix-431.tribuo"));
DenseVector aVec = DenseVectorTest.generateVectorA();
ShrinkingVector shVec = new ShrinkingVector(aVec,0.1,true);
Helpers.writeProtobuf(shVec, Paths.get("src","test","resources","org","tribuo","math","optimizers","util","shrinking-vector-431.tribuo"));
}

@Test
public void matrixSerializationTest() {
DenseMatrix a = DenseMatrixTest.generateA();
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 5bb7fbb

Please sign in to comment.