Skip to content

Commit

Permalink
1st working version of NYTRO crossval. Needs further testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
Raffaello Camoriano committed Oct 20, 2015
1 parent b66fae6 commit 90b195f
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 14 deletions.
55 changes: 52 additions & 3 deletions config_set.m
Original file line number Diff line number Diff line change
@@ -1,7 +1,56 @@
function [ output_args ] = config_set( input_args )
%CONFIG_SET Summary of this function goes here
% Detailed explanation goes here
function [ config ] = config_set( varargin )
%CONFIG_SET Constructs the default configuration stucture to be used by
%nytro_train

% Set default configuration fields
config = struct();

% data
config.data.shuffle = 1;

% crossValidation
config.crossValidation.storeTrainingError = 0;
config.crossValidation.validationPart = 0.2;
config.crossValidation.recompute = 0;
config.crossValidation.errorFunction = @rmse;
config.crossValidation.codingFunction = [];
config.crossValidation.stoppingRule = @windowLinearFitting;
config.crossValidation.windowSize = 10;
config.crossValidation.threshold = 0;

% filter
config.filter.fixedIterations = [];
config.filter.maxIterations = 500;
config.filter.gamma = [];

% kernel
config.kernel.kernelFunction = @gaussianKernel;
config.kernel.kernelParameters = 1;
config.kernel.m = 100;

% Parse function inputs
if ~isempty(varargin)

% Assign parsed parameters to object properties
fields = varargin(1:2:end);
for idx = 1:numel(fields)

currField = fields{idx};
% Parse current field
k = strfind(currField , '.');
k = [0 ; k ; (numel(currField)+1)];
tokens = cell(1,(numel(k) - 1));
for i = 1 : (numel(k) - 1);
tokens{i} = currField( (k(i)+1) : (k(i+1)-1) );
end

cmdStr = 'config';
for i = 1 : (numel(tokens) - 1)
cmdStr = strcat(cmdStr , '.(''' , tokens{i} , ''')');
end
cmdStr = strcat(cmdStr , '.(''' , tokens{end} , ''') = varargin{2*(idx-1) + 2};');
eval(cmdStr);
end
end
end

2 changes: 1 addition & 1 deletion examples/simpleUsage.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
[ training_output ] = nytro_train( Xtr , Ytr );

% Perform predictions on the test set and evaluate results
[ prediction_output ] = nytro_test( Xte , Yte , training_output);
[ prediction_output ] = nytro_test( Xtr , Xte , Yte , training_output);
2 changes: 1 addition & 1 deletion nytro_test.m
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

% Perform prediction
tic
output.YtePred = KnmTe * training_output.config.best.alpha;
output.YtePred = KnmTe * training_output.best.alpha;
if ~isempty(training_output.config.crossValidation.codingFunction)
output.YtePred = codingFunction(output.YtePred);
end
Expand Down
18 changes: 9 additions & 9 deletions nytro_train.m
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,17 @@
output.best.t = Inf;

% Error buffers
if config.crossValidation.storeValidationError == 1
output.errorPath.validation = zeros(1,config.filter.maxIterations) * NaN;
else
output.errorPath.training = [];
end
output.errorPath.validation = zeros(1,config.filter.maxIterations) * NaN;
if config.crossValidation.storeTrainingError == 1
output.errorPath.training = zeros(1,config.filter.maxIterations) * NaN;
else
output.errorPath.training = [];
end

% Init time structures
output.time.crossValidationTrain = 0;
output.time.crossValidationEval = 0;

% Subdivide training set in training1 and validation

ntr1 = floor( ntr * ( 1 - config.crossValidation.validationPart ));
Expand Down Expand Up @@ -146,7 +146,7 @@

% Compute kernel
tic
Knm = kernelFunction(Xtr, Xtr1(nysIdx,:), config.kernel.kernelParameters);
Knm = config.kernel.kernelFunction(X, Xtr1(nysIdx,:), config.kernel.kernelParameters);
Kmm = Knm(trainIdx(nysIdx),:);
R = chol( ( Kmm + Kmm') / 2 + 1e-10 * eye(config.kernel.m)); % Compute upper Cholesky factor of Kmm
output.time.kernelComputation = toc;
Expand Down Expand Up @@ -206,7 +206,7 @@
stop = config.crossValidation.stoppingRule(...
output.errorPath.validation(1:iter) , ...
config.crossValidation.windowSize , ...
config.crossValidation.thres);
config.crossValidation.threshold);

if stop == 1
break
Expand All @@ -216,7 +216,7 @@
output.time.crossValidationTotal = output.time.crossValidationEval + output.time.crossValidationTrain;
end

if config.crossValidation.retraining == 1
if config.crossValidation.recompute == 1

%%% Retrain on whole dataset

Expand Down Expand Up @@ -253,7 +253,7 @@

% Compute kernels
tic
Knm = kernelFunction(X, X(nysIdx,:), config.kernel.kernelParameters);
Knm = config.kernel.kernelFunction(X, X(nysIdx,:), config.kernel.kernelParameters);
Kmm = Knm(nysIdx,:);
R = chol( ( Kmm + Kmm') / 2 + 1e-10 * eye(config.kernel.m)); % Compute upper Cholesky factor of Kmm
output.time.kernelComputation = toc;
Expand Down

0 comments on commit 90b195f

Please sign in to comment.