Skip to content

Commit

Permalink
Merge pull request #4 from pogren/math-serialization-paramtests
Browse files Browse the repository at this point in the history
refactored tests to be ParameterizedTests
  • Loading branch information
Craigacp authored Aug 7, 2023
2 parents c339ec7 + fcd55e1 commit ece0bd8
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 47 deletions.
47 changes: 24 additions & 23 deletions Math/src/test/java/org/tribuo/math/KernelTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,28 @@

package org.tribuo.math;

import org.junit.jupiter.api.Test;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.OutputInfo;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.kernel.Linear;
import org.tribuo.math.kernel.Polynomial;
import org.tribuo.math.kernel.RBF;
import org.tribuo.math.kernel.Sigmoid;
import org.tribuo.math.protos.KernelProto;
import org.tribuo.protos.core.OutputDomainProto;
import org.tribuo.protos.core.OutputFactoryProto;
import org.tribuo.protos.core.OutputProto;
import org.tribuo.test.Helpers;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.test.Helpers.testProtoSerialization;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.test.Helpers.testProtoSerialization;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.tribuo.math.kernel.Kernel;
import org.tribuo.math.kernel.Linear;
import org.tribuo.math.kernel.Polynomial;
import org.tribuo.math.kernel.RBF;
import org.tribuo.math.kernel.Sigmoid;
import org.tribuo.math.protos.KernelProto;
import org.tribuo.test.Helpers;

public class KernelTest {

Expand Down Expand Up @@ -67,7 +65,9 @@ public void testSigmoid() {
testProtoSerialization(lin);
}

private void testProto(String name, Kernel actualKernel) throws URISyntaxException, IOException {
@ParameterizedTest
@MethodSource("load431Protobufs")
public void testProto(String name, Kernel actualKernel) throws URISyntaxException, IOException {
Path kernelPath = Paths.get(KernelTest.class.getResource(name).toURI());
try (InputStream fis = Files.newInputStream(kernelPath)) {
KernelProto proto = KernelProto.parseFrom(fis);
Expand All @@ -76,12 +76,13 @@ private void testProto(String name, Kernel actualKernel) throws URISyntaxExcepti
}
}

@Test
public void load431Protobufs() throws URISyntaxException, IOException {
testProto("linear-kernel-431.tribuo", new Linear());
testProto("poly-kernel-431.tribuo", new Polynomial(1,2,3));
testProto("rbf-kernel-431.tribuo", new RBF(1.0));
testProto("sigmoid-kernel-431.tribuo", new Sigmoid(1,2));

private static Stream<Arguments> load431Protobufs() throws URISyntaxException, IOException {
return Stream.of(
Arguments.of("linear-kernel-431.tribuo", new Linear()),
Arguments.of("poly-kernel-431.tribuo", new Polynomial(1,2,3)),
Arguments.of("rbf-kernel-431.tribuo", new RBF(1.0)),
Arguments.of("sigmoid-kernel-431.tribuo", new Sigmoid(1,2)));
}

public void generateProtobufs() throws IOException {
Expand Down
29 changes: 17 additions & 12 deletions Math/src/test/java/org/tribuo/math/distance/DistanceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,41 @@

package org.tribuo.math.distance;

import org.junit.jupiter.api.Test;
import org.tribuo.math.protos.DistanceProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.test.Helpers;
import static org.junit.jupiter.api.Assertions.assertEquals;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.tribuo.math.protos.DistanceProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.test.Helpers;

public class DistanceTest {

private void testProto(String name, Distance actualDistance) throws URISyntaxException, IOException {
@ParameterizedTest
@MethodSource("load431Protobufs")
public void testProto(String name, Distance actualDistance) throws URISyntaxException, IOException {
Path distancePath = Paths.get(DistanceTest.class.getResource(name).toURI());
try (InputStream fis = Files.newInputStream(distancePath)) {
DistanceProto proto = DistanceProto.parseFrom(fis);
Distance distance = ProtoUtil.deserialize(proto);
assertEquals(actualDistance, distance);
}
}

@Test
public void load431Protobufs() throws URISyntaxException, IOException {
testProto("cosine-distance-431.tribuo", new CosineDistance());
testProto("l1-distance-431.tribuo", new L1Distance());
testProto("l2-distance-431.tribuo", new L2Distance());
private static Stream<Arguments> load431Protobufs() throws URISyntaxException, IOException {
return Stream.of(
Arguments.of("cosine-distance-431.tribuo", new CosineDistance()),
Arguments.of("l1-distance-431.tribuo", new L1Distance()),
Arguments.of("l2-distance-431.tribuo", new L2Distance()));
}

public void generateProtobufs() throws IOException {
Expand Down
16 changes: 11 additions & 5 deletions Math/src/test/java/org/tribuo/math/util/MergerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.SparseVector;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.tribuo.math.protos.MergerProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.test.Helpers;
Expand All @@ -29,6 +32,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.test.Helpers.testProtoSerialization;
Expand Down Expand Up @@ -206,7 +210,9 @@ public static void testMerger(Merger merger) {
assertEquals(output,merged, "Merge A - B - A - B unsuccessful");
}

private void testProto(String name, Merger actualMerger) throws URISyntaxException, IOException {
@ParameterizedTest
@MethodSource("load431Protobufs")
public void testProto(String name, Merger actualMerger) throws URISyntaxException, IOException {
Path mergerPath = Paths.get(MergerTest.class.getResource(name).toURI());
try (InputStream fis = Files.newInputStream(mergerPath)) {
MergerProto proto = MergerProto.parseFrom(fis);
Expand All @@ -215,10 +221,10 @@ private void testProto(String name, Merger actualMerger) throws URISyntaxExcepti
}
}

@Test
public void load431Protobufs() throws URISyntaxException, IOException {
testProto("heap-merger-431.tribuo", new HeapMerger());
testProto("matrix-merger-431.tribuo", new MatrixHeapMerger());
private static Stream<Arguments> load431Protobufs() throws URISyntaxException, IOException {
return Stream.of(
Arguments.of("heap-merger-431.tribuo", new HeapMerger()),
Arguments.of("matrix-merger-431.tribuo", new MatrixHeapMerger()));
}

public void generateProtobufs() throws IOException {
Expand Down
20 changes: 13 additions & 7 deletions Math/src/test/java/org/tribuo/math/util/NormalizerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package org.tribuo.math.util;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.tribuo.math.protos.NormalizerProto;
import org.tribuo.protos.ProtoUtil;
import org.tribuo.test.Helpers;
Expand All @@ -27,6 +30,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.tribuo.test.Helpers.testProtoSerialization;
Expand Down Expand Up @@ -57,7 +61,9 @@ public void sigmoidNormalizerTest() {
testProtoSerialization(n);
}

private void testProto(String name, VectorNormalizer actualNormalizer) throws URISyntaxException, IOException {
@ParameterizedTest
@MethodSource("load431Protobufs")
public void testProto(String name, VectorNormalizer actualNormalizer) throws URISyntaxException, IOException {
Path normalizerPath = Paths.get(NormalizerTest.class.getResource(name).toURI());
try (InputStream fis = Files.newInputStream(normalizerPath)) {
NormalizerProto proto = NormalizerProto.parseFrom(fis);
Expand All @@ -66,12 +72,12 @@ private void testProto(String name, VectorNormalizer actualNormalizer) throws UR
}
}

@Test
public void load431Protobufs() throws URISyntaxException, IOException {
testProto("normalizer-431.tribuo", new Normalizer());
testProto("noop-normalizer-431.tribuo", new NoopNormalizer());
testProto("exp-normalizer-431.tribuo", new ExpNormalizer());
testProto("sigmoid-normalizer-431.tribuo", new SigmoidNormalizer());
private static Stream<Arguments> load431Protobufs() throws URISyntaxException, IOException {
return Stream.of(
Arguments.of("normalizer-431.tribuo", new Normalizer()),
Arguments.of("noop-normalizer-431.tribuo", new NoopNormalizer()),
Arguments.of("exp-normalizer-431.tribuo", new ExpNormalizer()),
Arguments.of("sigmoid-normalizer-431.tribuo", new SigmoidNormalizer()));
}

public void generateProtobufs() throws IOException {
Expand Down

0 comments on commit ece0bd8

Please sign in to comment.