Skip to content

Commit 0d808ec

Browse files
committed
refactor model evaluation, add option to use cross validation
- change type of dataset - refactoring of model evaluation - add parameter for objective function - add option to use cross validation for model evaluation - adjust test example - update README
1 parent 6cf5de7 commit 0d808ec

File tree

7 files changed

+174
-121
lines changed

7 files changed

+174
-121
lines changed

README.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,21 @@ autotuner.bayesianOptimization();
3131
```javascript
3232
autotuner.gridSearchOptimizytion();
3333
```
34+
#### Optinal parameters
35+
The ojective function of the optimization can be specified (either 'error' or 'accuracy').
36+
```javascript
37+
autotuner.gridSearchOptimizytion('accuracy');
38+
```
39+
Also one can enable cross validation when evaluating a model.
40+
```javascript
41+
autotuner.gridSearchOptimizytion('accuracy', true);
42+
```
3443
When doing bayesian optimization the maximum number of domain points to be evaluated can be specified as an optional parameter.
3544
```javascript
36-
autotuner.bayesianOptimization(0.8);
45+
autotuner.bayesianOptimization('accuracy', true, 0.8);
3746
```
38-
In the example above the optimizytion search stops after 80% of the domain ponits have been evaluated. By default this value is set to 0.75.
39-
47+
In the example above the optimizytion search stops after 80% of the domain ponits have been evaluated. By default this value is set to 0.75.
48+
### Example
4049
An example usage can be found here:
4150
```bash
4251
tets/runExampleAutotuner.ts

src/autotuner.ts

Lines changed: 57 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import * as tensorflow from '@tensorflow/tfjs';
2-
import * as math from 'mathjs';
3-
import { DataSet, ModelDict, SequentialModelParameters, datasetType, BaysianOptimisationStep, LossFunction, DomainPointValue } from '../types/types';
2+
import { ModelDict, SequentialModelParameters, DataPoint, BaysianOptimisationStep, LossFunction, DomainPointValue } from '../types/types';
43
import * as bayesianOptimizer from './bayesianOptimizer';
54
import * as gridSearchOptimizer from './gridSearchOptimizer';
65
import * as paramspace from './paramspace';
76
import * as priors from './priors';
7+
import * as modelEvaluator from './modelEvaluater';
88

99
class AutotunerBaseClass {
10-
dataset: DataSet;
1110
metrics: string[] = [];
1211
observedValues: DomainPointValue[] = [];
1312
/**
@@ -24,14 +23,15 @@ class AutotunerBaseClass {
2423
paramspace: any;
2524
optimizer: any;
2625
priors: any;
26+
modelEvaluator: any;
2727

2828
/**
2929
* Returns the value of a domain point.
3030
*
3131
* @param {number} domainIndex Index of the domain point to be evaluated.
3232
* @return {Promise<number>} Value of the domain point
3333
*/
34-
evaluateModel: (domainIndex: number) => Promise<number>;
34+
evaluateModel: (domainIndex: number, objective: string, useCrossValidation: boolean) => Promise<number>;
3535

3636
/**
3737
* Decide whether to continue tuning the hyperparameters.
@@ -57,44 +57,63 @@ class AutotunerBaseClass {
5757
}
5858
}
5959

60-
constructor(metrics: string[], trainingSet: datasetType, testSet: datasetType, evaluationSet: datasetType, numberOfCategories: number) {
60+
checkObjective (objective: string): boolean {
61+
const allowedObjectives: string[] = ['error'].concat(this.metrics);
62+
if (!allowedObjectives.includes(objective)) {
63+
console.log("Invalid objective function selected!");
64+
console.log("Objective function must be one of the following: " + allowedObjectives.join());
65+
return true;
66+
}
67+
return false;
68+
}
69+
70+
constructor(metrics: string[], dataSet: DataPoint[], numberOfCategories: number, validationSetRatio: number = 0.25, testSetRatio: number = 0.25) {
6171
this.paramspace = new paramspace.Paramspace();
72+
this.modelEvaluator = new modelEvaluator.ModelEvaluater(dataSet, numberOfCategories, validationSetRatio, testSetRatio);
6273
this.metrics = metrics;
63-
64-
const dataset: DataSet = {trainingSet: trainingSet, testSet: testSet, evaluationSet: evaluationSet, numberOfCategories: numberOfCategories};
65-
this.dataset = dataset;
6674
}
6775

6876
/**
6977
* Search the best Parameters using bayesian optimization.
7078
*
79+
* @param {string} [objective='error'] Define the objective of the optimization. Set to 'error' by default.
80+
* @param {boolean} [useCrossValidation=false] Indicate wheter or not to use cross validation to evaluate the model. Set to 'false' by default.
7181
* @param {number} [maxIteration=0.75] Fraction of domain points that should be evaluated at most. (e.g. for 'maxIteration=0.75' the optimization stops if 75% of the domain has been evaluated)
7282
* @param {boolean} [stopingCriteria] Predicate on the observed values when to stop the optimization
7383
*/
74-
async bayesianOptimization(maxIteration: number = 0.75, stopingCriteria?: ((observedValues: DomainPointValue[]) => boolean)) {
84+
async bayesianOptimization(objective: string = 'error', useCrossValidation: boolean = false, maxIteration: number = 0.75, stopingCriteria?: ((observedValues: DomainPointValue[]) => boolean)) {
85+
if (this.checkObjective(objective)) {
86+
return;
87+
}
7588
this.initializePriors();
7689
this.optimizer = new bayesianOptimizer.Optimizer(this.paramspace.domainIndices, this.paramspace.modelsDomains, this.priors.mean, this.priors.kernel);
7790
this.maxIterations = maxIteration;
7891
if (stopingCriteria) {
7992
this.metricsStopingCriteria = stopingCriteria;
8093
}
8194

82-
this.tuneHyperparameters();
95+
this.tuneHyperparameters(objective, useCrossValidation);
8396
}
8497

8598
/**
8699
* Search the best Parameters using grid search.
100+
*
101+
* @param {string} [objective='error'] Define the objective of the optimization. Set to 'error' by default.
102+
* @param {boolean} [useCrossValidation=false] Indicate wheter or not to use cross validation to evaluate the model. Set to 'false' by default.
87103
*/
88-
async gridSearchOptimizytion() {
104+
async gridSearchOptimizytion(objective: string = 'error', useCrossValidation: boolean = false) {
105+
if (this.checkObjective(objective)) {
106+
return;
107+
}
89108
this.initializePriors();
90109
this.optimizer = new gridSearchOptimizer.Optimizer(this.paramspace.domainIndices, this.paramspace.modelsDomains);
91110
this.maxIterations = 1;
92111

93-
this.tuneHyperparameters();
112+
this.tuneHyperparameters(objective, useCrossValidation);
94113
}
95114

96115

97-
async tuneHyperparameters() {
116+
async tuneHyperparameters(objective: string, useCrossValidation: boolean) {
98117
console.log("============================");
99118
console.log("tuning the hyperparameters");
100119

@@ -104,13 +123,12 @@ class AutotunerBaseClass {
104123
var nextOptimizationPoint: BaysianOptimisationStep = this.optimizer.getNextPoint();
105124

106125
// Train a model given the params and obtain a quality metric value.
107-
var value = await this.evaluateModel(nextOptimizationPoint.nextPoint);
126+
var value = await this.evaluateModel(nextOptimizationPoint.nextPoint, objective, useCrossValidation);
108127

109128
// Report the obtained quality metric value.
110129
this.optimizer.addSample(nextOptimizationPoint.nextPoint, value);
111130

112131
optimizing = this.stopingCriteria();
113-
114132
}
115133
// keep observations for the next optimization run
116134
this.priors.commit(this.paramspace.observedValues);
@@ -123,10 +141,19 @@ class AutotunerBaseClass {
123141
class TensorflowlModelAutotuner extends AutotunerBaseClass {
124142
modelDict: ModelDict = {};
125143

126-
constructor(metrics: string[], trainingSet: datasetType, testSet: datasetType, evaluationSet: datasetType, numberOfCategories: number) {
127-
super(metrics, trainingSet, testSet, evaluationSet, numberOfCategories);
144+
/**
145+
* Initialize the autotuner.
146+
*
147+
* @param {string[]} metrics
148+
* @param {DataPoint[]} dataSet
149+
* @param {number} numberOfCategories
150+
* @param {number=0.25} validationSetRatio
151+
* @param {number=0.25} testSetRatio
152+
*/
153+
constructor(metrics: string[], dataSet: DataPoint[], numberOfCategories: number, validationSetRatio: number = 0.25, testSetRatio: number = 0.25) {
154+
super(metrics, dataSet, numberOfCategories, validationSetRatio, testSetRatio);
128155

129-
this.evaluateModel = async (point: number) => {
156+
this.evaluateModel = async (point: number, objective: string, useCrossValidation: boolean) => {
130157
const modelIdentifier = this.paramspace.domain[point]['model'];
131158
const model = this.modelDict[modelIdentifier];
132159
const params = this.paramspace.domain[point]['params'];
@@ -143,22 +170,21 @@ class TensorflowlModelAutotuner extends AutotunerBaseClass {
143170
optimizer: optimizerFunction
144171
});
145172

146-
let concatenatedTensorTrainData = tensorflow.tidy(() => tensorflow.concat(this.dataset.trainingSet.data));
147-
let concatenatedTrainLableData = tensorflow.tidy(() => tensorflow.oneHot(this.dataset.trainingSet.lables, this.dataset.numberOfCategories));
148-
await model.fit(concatenatedTensorTrainData, concatenatedTrainLableData, args);
149-
150-
let concatenatedTensorTestData = tensorflow.tidy(() => tensorflow.concat(this.dataset.trainingSet.data));
151-
let concatenatedTestLables = tensorflow.tidy(() => tensorflow.oneHot(this.dataset.trainingSet.lables, this.dataset.numberOfCategories));
152-
const evaluationResult = model.evaluate(concatenatedTensorTestData, concatenatedTestLables) as tensorflow.Tensor[];
153-
154-
const error = evaluationResult[0].dataSync()[0];
155-
const score = evaluationResult[1].dataSync()[0];
156-
// keep track of the scores
157-
this.observedValues.push({error: error, metricScores: [score]});
158-
return score;
173+
let dataPointValue: DomainPointValue = useCrossValidation
174+
? await this.modelEvaluator.EvaluateSequentialTensorflowModelCV(model, args)
175+
: await this.modelEvaluator.EvaluateSequentialTensorflowModel(model, args);
176+
this.observedValues.push(dataPointValue);
177+
return objective === 'error' ? dataPointValue.error : dataPointValue.metricScores[0];
159178
}
160179
}
161180

181+
/**
182+
* Add a new model and its range of parameters to the autotuner.
183+
*
184+
* @param {string} modelIdentifier Identifier of the model
185+
* @param {tensorflow.Sequential} model Actual Tensorflow model
186+
* @param {SequentialModelParameters} modelParameters Parameters of the Model: define lossfunction, optimizer, algorithm batch size and number of traning epochs.
187+
*/
162188
addModel(modelIdentifier: string, model: tensorflow.Sequential, modelParameters: SequentialModelParameters) {
163189
this.modelDict[modelIdentifier] = model;
164190

src/crossValidation.ts

Lines changed: 0 additions & 41 deletions
This file was deleted.

src/modelEvaluater.ts

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import * as tensorflow from '@tensorflow/tfjs';
2+
import * as math from 'mathjs';
3+
import { DataPoint, DomainPointValue } from '../types/types';
4+
5+
export class ModelEvaluater{
6+
trainData: DataPoint[] = [];
7+
validationData: DataPoint[] = [];
8+
testData: DataPoint[] = [];
9+
numberOfCategories: number;
10+
11+
constructor(dataSet: DataPoint[], numberOfCategories: number, validationSetRatio: number, testSetRatio: number) {
12+
this.numberOfCategories = numberOfCategories;
13+
14+
// shuffle the dataset
15+
tensorflow.util.shuffle(dataSet);
16+
17+
// create validation dataset
18+
const numSamplesValidation = Math.max(
19+
1,
20+
Math.round(dataSet.length * validationSetRatio)
21+
);
22+
this.validationData = dataSet.splice(0, numSamplesValidation);
23+
24+
// create test dataset
25+
const numSamplesTest = Math.max(
26+
1,
27+
Math.round(dataSet.length * testSetRatio)
28+
);
29+
this.testData = dataSet.splice(0, numSamplesTest);
30+
this.trainData = dataSet;
31+
}
32+
33+
ConcatenateTensorData(data: DataPoint[]) {
34+
const trainData: tensorflow.Tensor<tensorflow.Rank>[] = [];
35+
const trainLables: number[] = [];
36+
for (let i = 0; i < data.length; i++) {
37+
trainData.push(data[i].data);
38+
trainLables.push(data[i].lables);
39+
}
40+
41+
let concatenatedTensorData = tensorflow.tidy(() => tensorflow.concat(trainData));
42+
let concatenatedLables = tensorflow.tidy(() => tensorflow.oneHot(trainLables, this.numberOfCategories));
43+
return { concatenatedTensorData, concatenatedLables };
44+
}
45+
46+
EvaluateSequentialTensorflowModel = async (model: tensorflow.Sequential, args: any): Promise<DomainPointValue> => {
47+
var trainData = this.ConcatenateTensorData(this.trainData);
48+
await model.fit(trainData.concatenatedTensorData, trainData.concatenatedLables, args);
49+
50+
var validationData = this.ConcatenateTensorData(this.validationData);
51+
const evaluationResult = model.evaluate(validationData.concatenatedTensorData, validationData.concatenatedLables) as tensorflow.Tensor[];
52+
53+
const error = evaluationResult[0].dataSync()[0];
54+
const score = evaluationResult[1].dataSync()[0];
55+
return {error: error, metricScores: [score]} as DomainPointValue;
56+
}
57+
58+
EvaluateSequentialTensorflowModelCV = async (model: tensorflow.Sequential, args: any): Promise<DomainPointValue> => {
59+
const dataSet = this.trainData.concat(this.validationData);
60+
const dataSize = dataSet.length;
61+
const k = math.min(10, math.floor(math.nthRoot(dataSize) as number));
62+
63+
const dataFolds: DataPoint[][] = Array.from(Array(math.ceil(dataSet.length/k)), (_,i) => dataSet.slice(i*k,i*k+k));
64+
65+
var error = 0;
66+
var score = 0;
67+
for (let i = 0; i < k; i++) {
68+
var validationData = dataFolds[i];
69+
var trainData: DataPoint[] = [];
70+
71+
for (var j = 0; j < k; j++) {
72+
if (j !== i) {
73+
trainData = trainData.concat(dataFolds[j]);
74+
}
75+
}
76+
77+
var concatenatedTrainData = this.ConcatenateTensorData(trainData);
78+
await model.fit(concatenatedTrainData.concatenatedTensorData, concatenatedTrainData.concatenatedLables, args);
79+
80+
var concatenatedValidationData = this.ConcatenateTensorData(validationData);
81+
const evaluationResult = model.evaluate(concatenatedValidationData.concatenatedTensorData, concatenatedValidationData.concatenatedLables) as tensorflow.Tensor[];
82+
83+
const foldError = evaluationResult[0].dataSync()[0];
84+
const foldScore = evaluationResult[1].dataSync()[0];
85+
error += foldError;
86+
score += foldScore;
87+
}
88+
return {error: error/dataSize, metricScores: [score/k]} as DomainPointValue;
89+
90+
}
91+
92+
}

0 commit comments

Comments
 (0)