-
Notifications
You must be signed in to change notification settings - Fork 1
/
pret_cost.m
113 lines (99 loc) · 3.57 KB
/
pret_cost.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
function cost = pret_cost(data,samplerate,trialwindow,model,options)
% pret_cost
% cost = pret_cost(data,samplerate,trialwindow,model)
% cost = pret_cost(data,samplerate,trialwindow,model,options)
% options = pret_cost()
%
% Calculates the sum of the squared errors between an input pupil size
% time series "data" and a time series created from the specifications and
% parameters in "model". If an MxN matrix with M time series is input as
% "data", the squared errors between each time series and the single time
% series produced from "model" are all summed into one cost value.
% NaN residuals (e.g. missing samples in "data") are ignored in the sum.
%
% Inputs:
%
%   data = a single pupil size time series as a row vector OR an MxN
%   matrix with M time series (one per row). N must equal the number of
%   time points spanned by trialwindow at the given samplerate.
%
%   samplerate = sampling rate of data in Hz.
%
%   trialwindow = a 2 element vector containing the starting and ending
%   times (in ms) of the trial epoch.
%
%   model = model structure created by pret_model and filled in by user.
%   Parameter values in model.ampvals, model.boxampvals, model.latvals,
%   model.tmaxval, and model.yintval must be provided.
%   *Note - an optim structure from pret_estimate, pret_bootstrap, or
%   pret_optim can be input in the place of model*
%
%   options = options structure for pret_cost. Default options can be
%   returned by calling this function with no arguments.
%
% Outputs:
%
%   cost = sum of the squared errors between "data" and the time series
%   created using "model"
%
% Options
%
%   pret_calc_options = options structure for pret_calc, which pret_cost uses
%   to produce the time series from "model".
%
%
%   Copyright (C) 2019 Jacob Parker and Rachel Denison
%
%   This program is free software: you can redistribute it and/or modify
%   it under the terms of the GNU General Public License as published by
%   the Free Software Foundation, either version 3 of the License, or
%   (at your option) any later version.
%
%   This program is distributed in the hope that it will be useful,
%   but WITHOUT ANY WARRANTY; without even the implied warranty of
%   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%   GNU General Public License for more details.
%
%   You should have received a copy of the GNU General Public License
%   along with this program.  If not, see <https://www.gnu.org/licenses/>.
%

if nargin < 5
    opts = pret_default_options();
    options = opts.pret_cost;
    clear opts
    if nargin < 1
        % called with no arguments: hand back the default options
        cost = options;
        return
    end
end

%OPTIONS
pret_calc_options = options.pret_calc;

% time stamp (in ms) of every sample in the trial epoch
sfact = samplerate/1000;
time = trialwindow(1):1/sfact:trialwindow(2);

%check inputs
pret_model_check(model)

%samplerate, trialwindow vs data
% Compare the number of columns (time points): length() returns the largest
% dimension, which miscounts when there are more time series (rows) than
% time points, and would let a column vector through to fail later.
if size(data,2) ~= length(time)
    error('The number of time points does not equal the number of data points')
end

%sample rate vs model sample rate
if samplerate ~= model.samplerate
    error('The input sample rate and the sample rate in model do not match')
end

%model time window vs time points
datalb = find(model.window(1) == time, 1);
dataub = find(model.window(2) == time, 1);
if isempty(datalb) || isempty(dataub)
    error('Model time window does not fall on time points according to sample rate and trial window')
end

%crop data to match model.window
data = data(:,datalb:dataub);

%model prediction, forced to a row so it lines up with the columns of data
Ycalc = pret_calc(model,pret_calc_options);
Ycalc = reshape(Ycalc,1,[]);
if numel(Ycalc) ~= size(data,2)
    error('The model time series does not match the length of the model time window')
end

% residuals of every time series (row) against the single model prediction;
% bsxfun broadcasts the prediction row across all rows of data
residuals = bsxfun(@minus,data,Ycalc);
residuals = residuals(~isnan(residuals));

% toolbox-free equivalent of nansum: NaN residuals are dropped, and an
% all-NaN input yields 0
cost = sum(residuals.^2);