diff --git a/config_set.m b/config_set.m index 084cb63..72d4391 100644 --- a/config_set.m +++ b/config_set.m @@ -1,7 +1,56 @@ -function [ output_args ] = config_set( input_args ) -%CONFIG_SET Summary of this function goes here -% Detailed explanation goes here +function [ config ] = config_set( varargin ) +%CONFIG_SET Constructs the default configuration stucture to be used by +%nytro_train + % Set default configuration fields + config = struct(); + % data + config.data.shuffle = 1; + + % crossValidation + config.crossValidation.storeTrainingError = 0; + config.crossValidation.validationPart = 0.2; + config.crossValidation.recompute = 0; + config.crossValidation.errorFunction = @rmse; + config.crossValidation.codingFunction = []; + config.crossValidation.stoppingRule = @windowLinearFitting; + config.crossValidation.windowSize = 10; + config.crossValidation.threshold = 0; + + % filter + config.filter.fixedIterations = []; + config.filter.maxIterations = 500; + config.filter.gamma = []; + + % kernel + config.kernel.kernelFunction = @gaussianKernel; + config.kernel.kernelParameters = 1; + config.kernel.m = 100; + + % Parse function inputs + if ~isempty(varargin) + + % Assign parsed parameters to object properties + fields = varargin(1:2:end); + for idx = 1:numel(fields) + + currField = fields{idx}; + % Parse current field + k = strfind(currField , '.'); + k = [0 ; k ; (numel(currField)+1)]; + tokens = cell(1,(numel(k) - 1)); + for i = 1 : (numel(k) - 1); + tokens{i} = currField( (k(i)+1) : (k(i+1)-1) ); + end + + cmdStr = 'config'; + for i = 1 : (numel(tokens) - 1) + cmdStr = strcat(cmdStr , '.(''' , tokens{i} , ''')'); + end + cmdStr = strcat(cmdStr , '.(''' , tokens{end} , ''') = varargin{2*(idx-1) + 2};'); + eval(cmdStr); + end + end end diff --git a/examples/simpleUsage.m b/examples/simpleUsage.m index c2f1882..3301a0b 100644 --- a/examples/simpleUsage.m +++ b/examples/simpleUsage.m @@ -8,4 +8,4 @@ [ training_output ] = nytro_train( Xtr , Ytr ); % Perform predictions on the test set and evaluate results -[ prediction_output ] = nytro_test( Xte , Yte , training_output); \ No newline at end of file +[ prediction_output ] = nytro_test( Xtr , Xte , Yte , training_output); \ No newline at end of file diff --git a/nytro_test.m b/nytro_test.m index 3ed3474..ed3f320 100644 --- a/nytro_test.m +++ b/nytro_test.m @@ -11,7 +11,7 @@ % Perform prediction tic - output.YtePred = KnmTe * training_output.config.best.alpha; + output.YtePred = KnmTe * training_output.best.alpha; if ~isempty(training_output.config.crossValidation.codingFunction) output.YtePred = codingFunction(output.YtePred); end diff --git a/nytro_train.m b/nytro_train.m index 97bf4a9..1bdbf10 100644 --- a/nytro_train.m +++ b/nytro_train.m @@ -105,17 +105,17 @@ output.best.t = Inf; % Error buffers - if config.crossValidation.storeValidationError == 1 - output.errorPath.validation = zeros(1,config.filter.maxIterations) * NaN; - else - output.errorPath.training = []; - end + output.errorPath.validation = zeros(1,config.filter.maxIterations) * NaN; if config.crossValidation.storeTrainingError == 1 output.errorPath.training = zeros(1,config.filter.maxIterations) * NaN; else output.errorPath.training = []; end + % Init time structures + output.time.crossValidationTrain = 0; + output.time.crossValidationEval = 0; + % Subdivide training set in training1 and validation ntr1 = floor( ntr * ( 1 - config.crossValidation.validationPart )); @@ -146,7 +146,7 @@ % Compute kernel tic - Knm = kernelFunction(Xtr, Xtr1(nysIdx,:), config.kernel.kernelParameters); + Knm = config.kernel.kernelFunction(X, Xtr1(nysIdx,:), config.kernel.kernelParameters); Kmm = Knm(trainIdx(nysIdx),:); R = chol( ( Kmm + Kmm') / 2 + 1e-10 * eye(config.kernel.m)); % Compute upper Cholesky factor of Kmm output.time.kernelComputation = toc; @@ -206,7 +206,7 @@ stop = config.crossValidation.stoppingRule(... output.errorPath.validation(1:iter) , ... config.crossValidation.windowSize , ... - config.crossValidation.thres); + config.crossValidation.threshold); if stop == 1 break @@ -216,7 +216,7 @@ output.time.crossValidationTotal = output.time.crossValidationEval + output.time.crossValidationTrain; end - if config.crossValidation.retraining == 1 + if config.crossValidation.recompute == 1 %%% Retrain on whole dataset @@ -253,7 +253,7 @@ % Compute kernels tic - Knm = kernelFunction(X, X(nysIdx,:), config.kernel.kernelParameters); + Knm = config.kernel.kernelFunction(X, X(nysIdx,:), config.kernel.kernelParameters); Kmm = Knm(nysIdx,:); R = chol( ( Kmm + Kmm') / 2 + 1e-10 * eye(config.kernel.m)); % Compute upper Cholesky factor of Kmm output.time.kernelComputation = toc;