-
Notifications
You must be signed in to change notification settings - Fork 0
/
TrainANN.m
76 lines (76 loc) · 2.12 KB
/
TrainANN.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
function results=TrainANN(x,t,hiddenLayerSize,trainFcn)
%TRAINANN Train a fitting network on (x,t) and report errors per data split.
%   results = TrainANN(x,t) trains a fitnet with 10 hidden neurons using
%   Levenberg-Marquardt ('trainlm') on inputs x and targets t. Samples are
%   indexed as columns (x(:,idx), t(:,idx)).
%
%   results = TrainANN(x,t,hiddenLayerSize,trainFcn) overrides the hidden
%   layer size and/or the training function. Pass [] for either argument
%   to keep its default.
%
%   If x is empty, no network is trained: outputs and errors are filled
%   with Inf, every split is empty, and results.net is [].
%
%   Returned struct fields:
%     net       - trained network ([] when x is empty)
%     Data      - all samples: x, t, y (network output), e (t - y),
%                 E (performance value, MSE)
%     TrainData - same fields restricted to the training indices
%     TestData  - same fields over the test AND validation indices combined

% Optional arguments, defaulting to the previously hard-coded values
if nargin < 3 || isempty(hiddenLayerSize)
    hiddenLayerSize = 10;
end
if nargin < 4 || isempty(trainFcn)
    trainFcn = 'trainlm'; % Levenberg-Marquardt
end

hasData = ~isempty(x);

if hasData
    % Create a Fitting Network
    net = fitnet(hiddenLayerSize,trainFcn);

    % Input and Output Pre/Post-Processing Functions
    % For a list of all processing functions type: help nnprocess
    net.input.processFcns = {'removeconstantrows','mapminmax'};
    net.output.processFcns = {'removeconstantrows','mapminmax'};

    % Division of Data for Training, Validation, Testing
    % For a list of all data division functions type: help nndivide
    net.divideFcn = 'dividerand'; % Divide data randomly
    net.divideMode = 'sample';    % Divide up every sample
    net.divideParam.trainRatio = 70/100;
    net.divideParam.valRatio = 15/100;
    net.divideParam.testRatio = 15/100;

    % Performance Function
    % For a list of all performance functions type: help nnperformance
    net.performFcn = 'mse'; % Mean squared error

    % No plots and no training window (suitable for batch / repeated calls)
    % For a list of all plot functions type: help nnplot
    net.plotFcns = {};
    net.trainParam.showWindow = false;

    % Train the Network
    [net,tr] = train(net,x,t);

    % Evaluate on all samples
    y = net(x);
    e = gsubtract(t,y);
    E = perform(net,t,y);
else
    % No inputs: return Inf-filled results with empty splits
    net = [];
    y = inf(size(t));
    e = inf(size(t));
    E = inf;
    tr.trainInd = [];
    tr.valInd = [];
    tr.testInd = [];
end

% All Data
Data.x = x;
Data.t = t;
Data.y = y;
Data.e = e;
Data.E = E;

% Train Data
TrainData = extractSubset(x,t,y,e,tr.trainInd);

% Validation and Test Data (validation samples are merged into TestData;
% there is no separate ValidationData output)
TestData = extractSubset(x,t,y,e,[tr.testInd tr.valInd]);

if hasData
    TrainData.E = perform(net,TrainData.t,TrainData.y);
    TestData.E = perform(net,TestData.t,TestData.y);
else
    TrainData.E = inf;
    TestData.E = inf;
end

% Export Results
results.net = net;
results.Data = Data;
results.TrainData = TrainData;
results.TestData = TestData;
end

function S = extractSubset(x,t,y,e,idx)
%EXTRACTSUBSET Select the sample columns idx from x, t, y and e.
%   The caller fills in S.E afterwards, so the field order stays x,t,y,e,E.
S.x = x(:,idx);
S.t = t(:,idx);
S.y = y(:,idx);
S.e = e(:,idx);
end