Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deserialization tests for 4.3 protobufs in Classification #345

Merged
merged 4 commits into from
Aug 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,9 +20,21 @@
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.OutputInfo;
import org.tribuo.classification.ensemble.FullyWeightedVotingCombiner;
import org.tribuo.classification.ensemble.VotingCombiner;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.protos.core.EnsembleCombinerProto;
import org.tribuo.protos.core.OutputDomainProto;
import org.tribuo.protos.core.OutputFactoryProto;
import org.tribuo.protos.core.OutputProto;
import org.tribuo.test.Helpers;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

import static org.junit.jupiter.api.Assertions.assertEquals;

Expand Down Expand Up @@ -70,4 +82,85 @@ public void infoSerializationTest() {
assertEquals(immutableInfo,deserImInfo);
}

@Test
public void load431Protobufs() throws URISyntaxException, IOException {
Label test = new Label("TEST",1.0);
Label other = new Label("OTHER",1.0);
// Label
Path eventPath = Paths.get(SerializationTest.class.getResource("label-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(eventPath)) {
OutputProto proto = OutputProto.parseFrom(fis);
Label lbl = (Label) Output.deserialize(proto);
assertEquals(test, lbl);
}

// LabelFactory
Path factoryPath = Paths.get(SerializationTest.class.getResource("factory-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(factoryPath)) {
OutputFactoryProto proto = OutputFactoryProto.parseFrom(fis);
LabelFactory factory = (LabelFactory) OutputFactory.deserialize(proto);
assertEquals(new LabelFactory(), factory);
}

MutableLabelInfo info = new MutableLabelInfo();
for (int i = 0; i < 5; i++) {
info.observe(test);
info.observe(other);
}
for (int i = 0; i < 2; i++) {
info.observe(LabelFactory.UNKNOWN_LABEL);
}
ImmutableLabelInfo imInfo = (ImmutableLabelInfo) info.generateImmutableOutputInfo();

// MutableLabelInfo
Path mutablePath = Paths.get(SerializationTest.class.getResource("mutableinfo-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(mutablePath)) {
OutputDomainProto proto = OutputDomainProto.parseFrom(fis);
LabelInfo deserInfo = (LabelInfo) OutputInfo.deserialize(proto);
assertEquals(info, deserInfo);
}
// ImmutableLabelInfo
Path immutablePath = Paths.get(SerializationTest.class.getResource("immutableinfo-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(immutablePath)) {
OutputDomainProto proto = OutputDomainProto.parseFrom(fis);
LabelInfo deserInfo = (LabelInfo) OutputInfo.deserialize(proto);
assertEquals(imInfo, deserInfo);
}
// VotingCombiner
VotingCombiner comb = new VotingCombiner();
Path combinerPath = Paths.get(SerializationTest.class.getResource("vote-combiner-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(combinerPath)) {
EnsembleCombinerProto proto = EnsembleCombinerProto.parseFrom(fis);
VotingCombiner deserComb = (VotingCombiner) EnsembleCombiner.deserialize(proto);
assertEquals(comb, deserComb);
}
// MultiLabelVotingCombiner
FullyWeightedVotingCombiner fvComb = new FullyWeightedVotingCombiner();
Path fvCombinerPath = Paths.get(SerializationTest.class.getResource("fullvote-combiner-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(fvCombinerPath)) {
EnsembleCombinerProto proto = EnsembleCombinerProto.parseFrom(fis);
FullyWeightedVotingCombiner deserComb = (FullyWeightedVotingCombiner) EnsembleCombiner.deserialize(proto);
assertEquals(fvComb, deserComb);
}
}

public void generateProtobufs() throws IOException {
Label test = new Label("TEST",1.0);
Label other = new Label("OTHER",1.0);
Helpers.writeProtobuf(new LabelFactory(), Paths.get("src","test","resources","org","tribuo","classification","factory-clf-431.tribuo"));
Helpers.writeProtobuf(test, Paths.get("src","test","resources","org","tribuo","classification","label-clf-431.tribuo"));
MutableLabelInfo info = new MutableLabelInfo();
for (int i = 0; i < 5; i++) {
info.observe(test);
info.observe(other);
}
for (int i = 0; i < 2; i++) {
info.observe(LabelFactory.UNKNOWN_LABEL);
}
Helpers.writeProtobuf(info, Paths.get("src","test","resources","org","tribuo","classification","mutableinfo-clf-431.tribuo"));
ImmutableLabelInfo imInfo = (ImmutableLabelInfo) info.generateImmutableOutputInfo();
Helpers.writeProtobuf(imInfo, Paths.get("src","test","resources","org","tribuo","classification","immutableinfo-clf-431.tribuo"));
Helpers.writeProtobuf(new VotingCombiner(), Paths.get("src","test","resources","org","tribuo","classification","vote-combiner-clf-431.tribuo"));
Helpers.writeProtobuf(new FullyWeightedVotingCombiner(), Paths.get("src","test","resources","org","tribuo","classification","fullvote-combiner-clf-431.tribuo"));
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,18 +19,29 @@
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.evaluation.LabelEvaluation;
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.evaluation.Evaluator;
import org.junit.jupiter.api.Test;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.test.Helpers;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class TestDummyClassifier {

private static final Trainer<Label> constant = DummyClassifierTrainer.createConstantTrainer("Foo");
Expand Down Expand Up @@ -70,4 +81,32 @@ public void testSparseBinaryData() {
testDummyClassifier(p,false);
}

@Test
public void loadProtobufModel() throws IOException, URISyntaxException {
Path path = Paths.get(TestDummyClassifier.class.getResource("dummyclf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(path)) {
ModelProto proto = ModelProto.parseFrom(fis);
DummyClassifierModel deserModel = (DummyClassifierModel) Model.deserialize(proto);

assertEquals("4.3.1", deserModel.getProvenance().getTribuoVersion());

Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
DummyClassifierTrainer trainer = DummyClassifierTrainer.createMostFrequentTrainer();
Model<Label> model = trainer.train(p.getA());

List<Prediction<Label>> deserPredictions = deserModel.predict(p.getB());
List<Prediction<Label>> predictions = model.predict(p.getB());
assertEquals(p.getB().size(), deserPredictions.size());
assertTrue(Helpers.predictionListDistributionEquals(predictions, deserPredictions));
}
}

public void generateProtobuf() throws IOException {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();

DummyClassifierTrainer trainer = DummyClassifierTrainer.createMostFrequentTrainer();
Model<Label> model = trainer.train(p.getA());

Helpers.writeProtobuf(model, Paths.get("src","test","resources","org","tribuo","classification","baseline","dummyclf-431.tribuo"));
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
&org.tribuo.classification.LabelFactory
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
>org.tribuo.classification.ensemble.FullyWeightedVotingCombiner
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
*org.tribuo.classification.MutableLabelInfoV
?type.googleapis.com/tribuo.classification.MutableLabelInfoProto
OTHER
TEST
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1org.tribuo.classification.ensemble.VotingCombiner
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,7 @@
import org.tribuo.Example;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.SparseModel;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
Expand All @@ -35,11 +36,14 @@
import org.tribuo.dataset.DatasetView;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.test.Helpers;

import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashMap;
Expand Down Expand Up @@ -271,4 +275,38 @@ public void testPureLeaf() throws IOException {
assertTrue(predictions.get(3).distributionEquals(pred));
assertTrue(predictions.get(4).distributionEquals(pred));
}

@Test
public void loadProtobufModel() throws IOException, URISyntaxException {
Path path = Paths.get(TestCART.class.getResource("cart-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(path)) {
ModelProto proto = ModelProto.parseFrom(fis);
@SuppressWarnings("unchecked")
SparseModel<Label> deserModel = (SparseModel<Label>) Model.deserialize(proto);

assertEquals("4.3.1", deserModel.getProvenance().getTribuoVersion());

Pair<Dataset<Label>, Dataset<Label>> p = LabelledDataGenerator.denseTrainTest(1.0);

List<Prediction<Label>> deserOutput = deserModel.predict(p.getB());

CARTClassificationTrainer trainer = new CARTClassificationTrainer();
TreeModel<Label> model = trainer.train(p.getA());
List<Prediction<Label>> output = model.predict(p.getB());

assertEquals(deserOutput.size(), p.getB().size());
assertTrue(Helpers.predictionListDistributionEquals(deserOutput, output));
}
}

/**
* Test protobuf generation method.
* @throws IOException If the write failed.
*/
public void generateModel() throws IOException {
Pair<Dataset<Label>, Dataset<Label>> p = LabelledDataGenerator.denseTrainTest(1.0);
CARTClassificationTrainer trainer = new CARTClassificationTrainer();
TreeModel<Label> model = trainer.train(p.getA());
Helpers.writeProtobuf(model, Paths.get("src","test","resources","org","tribuo","classification","dtree","cart-clf-431.tribuo"));
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,6 +34,7 @@
import org.tribuo.classification.evaluation.LabelEvaluator;
import org.tribuo.classification.example.LabelledDataGenerator;
import org.tribuo.classification.liblinear.LinearClassificationType.LinearType;
import org.tribuo.common.liblinear.LibLinearModel;
import org.tribuo.data.text.TextDataSource;
import org.tribuo.data.text.TextFeatureExtractor;
import org.tribuo.data.text.impl.BasicPipeline;
Expand All @@ -45,12 +46,12 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.interop.onnx.OnnxTestUtils;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.test.Helpers;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.net.URISyntaxException;
import java.net.URL;
Expand Down Expand Up @@ -294,6 +295,33 @@ public void testOnnxSerialization() throws IOException, OrtException {
onnxFile.toFile().delete();
}

@Test
public void loadProtobufModel() throws IOException, URISyntaxException {
Path path = Paths.get(TestLibLinearModel.class.getResource("liblinear-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(path)) {
ModelProto proto = ModelProto.parseFrom(fis);
LibLinearClassificationModel model = (LibLinearClassificationModel) Model.deserialize(proto);

assertEquals("4.3.1", model.getProvenance().getTribuoVersion());

Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
List<Prediction<Label>> output = model.predict(p.getB());
assertEquals(output.size(), p.getB().size());
}
}

/**
* Test protobuf generation method.
* @throws IOException If the write failed.
*/
public void generateModel() throws IOException {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
LinearClassificationType type = new LinearClassificationType(LinearType.L2R_L2LOSS_SVC);
LibLinearClassificationTrainer trainer = new LibLinearClassificationTrainer(type,1.0,1000,0.01);
LibLinearModel<Label> model = trainer.train(p.getA());
Helpers.writeProtobuf(model, Paths.get("src","test","resources","org","tribuo","classification","liblinear","liblinear-clf-431.tribuo"));
}

private static int[] getIndices(FeatureNode[] nodes) {
int[] indices = new int[nodes.length];

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2023, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -50,10 +50,12 @@
import org.tribuo.dataset.DatasetView;
import org.tribuo.impl.ListExample;
import org.tribuo.interop.onnx.OnnxTestUtils;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.test.Helpers;
import org.tribuo.util.tokens.impl.BreakIteratorTokenizer;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.net.URISyntaxException;
import java.net.URL;
Expand All @@ -78,6 +80,7 @@
public class TestLibSVM {
private static final Logger logger = Logger.getLogger(TestLibSVM.class.getName());

private static final SVMParameters<Label> linearParams = new SVMParameters<>(new SVMClassificationType(SVMMode.C_SVC), KernelType.LINEAR);
private static final LibSVMClassificationTrainer C_RBF = new LibSVMClassificationTrainer(new SVMParameters<>(new SVMClassificationType(SVMMode.C_SVC), KernelType.RBF));
private static final LibSVMClassificationTrainer NU_RBF = new LibSVMClassificationTrainer(new SVMParameters<>(new SVMClassificationType(SVMMode.NU_SVC), KernelType.RBF));
private static final LibSVMClassificationTrainer C_LINEAR = new LibSVMClassificationTrainer(new SVMParameters<>(new SVMClassificationType(SVMMode.C_SVC), KernelType.LINEAR));
Expand Down Expand Up @@ -358,6 +361,32 @@ public void testEmptyExample() {
});
}

@Test
public void loadProtobufModel() throws IOException, URISyntaxException {
Path path = Paths.get(TestLibSVM.class.getResource("libsvm-clf-431.tribuo").toURI());
try (InputStream fis = Files.newInputStream(path)) {
ModelProto proto = ModelProto.parseFrom(fis);
LibSVMClassificationModel model = (LibSVMClassificationModel) Model.deserialize(proto);

assertEquals("4.3.1", model.getProvenance().getTribuoVersion());

Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
List<Prediction<Label>> output = model.predict(p.getB());
assertEquals(output.size(), p.getB().size());
}
}

/**
* Test protobuf generation method.
* @throws IOException If the write failed.
*/
public void generateModel() throws IOException {
Pair<Dataset<Label>,Dataset<Label>> p = LabelledDataGenerator.denseTrainTest();
LibSVMClassificationTrainer trainer = new LibSVMClassificationTrainer(linearParams);
LibSVMModel<Label> model = trainer.train(p.getA());
Helpers.writeProtobuf(model, Paths.get("src","test","resources","org","tribuo","classification","libsvm","libsvm-clf-431.tribuo"));
}

private static int[] getIndices(svm_node[] nodes) {
int[] indices = new int[nodes.length];

Expand Down
Binary file not shown.
Loading