-
Notifications
You must be signed in to change notification settings - Fork 36
/
Copy pathclassRF_train.m
452 lines (407 loc) · 17.5 KB
/
classRF_train.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
%**************************************************************
%* mex interface to Andy Liaw et al.'s C code (used in R package randomForest)
%* Added by Abhishek Jaiantilal ( abhishek.jaiantilal@colorado.edu )
%* License: GPLv2
%* Version: 0.02
%
% Calls Classification Random Forest
% A wrapper matlab file that calls the mex file
% This does training given the data and labels
% Documentation copied from R-packages pdf
% http://cran.r-project.org/web/packages/randomForest/randomForest.pdf
% Tutorial on getting this working in tutorial_ClassRF.m
%**************************************************************
% function model = classRF_train(X,Y,ntree,mtry, extra_options)
%
%___Options
% requires 2 arguments and the rest 3 are optional
% X: data matrix
% Y: target values
% ntree (optional): number of trees (default is 500). also if set to 0
% will default to 500
% mtry (optional): number of predictors sampled for splitting at each node
%       (default is floor(sqrt(size(X,2))), D=number of features in X). also if set to 0
%       will default to floor(sqrt(D))
%
%
% Note: TRUE = 1 and FALSE = 0 below
% extra_options represent a structure containing various misc. options to
% control the RF
% extra_options.replace = 0 or 1 (default is 1) sampling with or without
% replacement
% extra_options.classwt = priors of classes. Here the function first gets
% the labels in ascending order and assumes the
% priors are given in the same order. So if the class
% labels are [-1 1 2] and classwt is [0.1 2 3] then
% there is a 1-1 correspondence. (ascending order of
% class labels). Once this is set the freq of labels in
% train data also affects.
% extra_options.cutoff (Classification only) = A vector of length equal to number of classes. The 'winning'
% class for an observation is the one with the maximum ratio of proportion
% of votes to cutoff. Default is 1/k where k is the number of classes (i.e., majority
% vote wins).
% extra_options.strata = (not yet stable in code) variable that is used for stratified
% sampling. I don't yet know how this works. Disabled
% by default
% extra_options.sampsize = Size(s) of sample to draw. For classification,
% if sampsize is a vector of the length the number of strata, then sampling is stratified by strata,
% and the elements of sampsize indicate the numbers to be
% drawn from the strata.
% extra_options.nodesize = Minimum size of terminal nodes. Setting this number larger causes smaller trees
% to be grown (and thus take less time). Note that the default values are different
% for classification (1) and regression (5).
% extra_options.importance = Should importance of predictors be assessed?
% extra_options.localImp = Should casewise importance measure be computed? (Setting this to TRUE will
% override importance.)
% extra_options.proximity = Should proximity measure among the rows be calculated?
% extra_options.oob_prox = Should proximity be calculated only on 'out-of-bag' data?
% extra_options.do_trace = If set to TRUE, give a more verbose output as randomForest is run. If set to
% some integer, then running output is printed for every
% do_trace trees.
% extra_options.keep_inbag Should an n by ntree matrix be returned that keeps track of which samples are
% 'in-bag' in which trees (but not how many times, if sampling with replacement)
% extra_options.categorical_feature a 1xD true/false vector to say which features are categorical (true), which are numeric (false)
% maximum of 32 categories per feature is permitted
%
% Options eliminated
% corr_bias which happens only for regression omitted
% norm_votes - always set to return total votes for each class.
%
%___Returns model which has
% importance = a matrix with nclass + 2 (for classification) or two (for regression) columns.
% For classification, the first nclass columns are the class-specific measures
% computed as mean decrease in accuracy. The nclass + 1st column is the
% mean decrease in accuracy over all classes. The last column is the mean decrease
% in Gini index. For Regression, the first column is the mean decrease in
% accuracy and the second the mean decrease in MSE. If importance=FALSE,
% the last measure is still returned as a vector.
% importanceSD = The 'standard errors' of the permutation-based importance measure. For classification,
% a p by nclass + 1 matrix corresponding to the first nclass + 1
% columns of the importance matrix. For regression, a length p vector.
% localImp = a p by n matrix containing the casewise importance measures, the [i,j] element
% of which is the importance of i-th variable on the j-th case. NULL if
% localImp=FALSE.
% ntree = number of trees grown.
% mtry = number of predictors sampled for splitting at each node.
% votes (classification only) a matrix with one row for each input data point and one
% column for each class, giving the fraction or number of 'votes' from the random
% forest.
% oob_times number of times cases are 'out-of-bag' (and thus used in computing OOB error
% estimate)
% proximity if proximity=TRUE when randomForest is called, a matrix of proximity
% measures among the input (based on the frequency that pairs of data points are
% in the same terminal nodes).
% errtr = first column is OOB Err rate, second is for class 1 and so on
function model=classRF_train(X,Y,ntree,mtry, extra_options,Xtst,Ytst)
% classRF_train  Train a classification random forest (mex wrapper).
%   Marshals the options, remaps class labels and categorical features to
%   the integer coding the C code expects, and calls mexClassRF_train.
%   See the header comment at the top of this file for the full
%   option/return documentation.

DEFAULTS_ON = 0;
% DEBUG_ON=1;

% Optional held-out test set: when supplied, the mex file also returns
% test-set predictions (outclts) and test-set error (errts).
if exist('Xtst','var') && exist('Ytst','var')
    if (size(Xtst,1) ~= length(Ytst))
        error('Size of Xtst and Ytst dont match');
    end
    fprintf('Test data available\n');
    tst_available = 1;
    tst_size = length(Ytst);
else
    % The mex interface always expects test arrays; pass a 1-row dummy.
    Xtst = X(1,:);
    Ytst = Y(1);
    tst_available = 0;
    tst_size = 0;
end

TRUE  = 1;
FALSE = 0;

% Remap the (arbitrary) class labels in Y/Ytst to consecutive integers
% 1..nclass; orig_labels/new_labels are saved in the model so predictions
% can be mapped back to the user's labels.
orig_labels = sort(unique([Y; Ytst]));
Y_new     = Y;
Y_new_tst = Ytst;
new_labels = 1:length(orig_labels);
for i = 1:length(orig_labels)
    % Two-step replacement (label -> Inf -> new label) so values that were
    % already remapped are never matched again when the original and new
    % label ranges overlap.
    Y_new(Y == orig_labels(i))       = Inf;
    Y_new(isinf(Y_new))              = new_labels(i);
    Y_new_tst(Ytst == orig_labels(i)) = Inf;
    Y_new_tst(isinf(Y_new_tst))      = new_labels(i);
end
Y    = Y_new;
Ytst = Y_new_tst;

% Pull any user-supplied options out of the extra_options struct.
if exist('extra_options','var')
    if isfield(extra_options,'DEBUG_ON');   DEBUG_ON = extra_options.DEBUG_ON;     end
    if isfield(extra_options,'replace');    replace = extra_options.replace;       end
    if isfield(extra_options,'classwt');    classwt = extra_options.classwt;       end
    if isfield(extra_options,'cutoff');     cutoff = extra_options.cutoff;         end
    if isfield(extra_options,'strata');     strata = extra_options.strata;         end
    if isfield(extra_options,'sampsize');   sampsize = extra_options.sampsize;     end
    if isfield(extra_options,'nodesize');   nodesize = extra_options.nodesize;     end
    if isfield(extra_options,'importance'); importance = extra_options.importance; end
    if isfield(extra_options,'localImp');   localImp = extra_options.localImp;     end
    if isfield(extra_options,'nPerm');      nPerm = extra_options.nPerm;           end
    if isfield(extra_options,'proximity');  proximity = extra_options.proximity;   end
    if isfield(extra_options,'oob_prox');   oob_prox = extra_options.oob_prox;     end
    %if isfield(extra_options,'norm_votes'); norm_votes = extra_options.norm_votes; end
    if isfield(extra_options,'do_trace');   do_trace = extra_options.do_trace;     end
    %if isfield(extra_options,'corr_bias'); corr_bias = extra_options.corr_bias; end
    if isfield(extra_options,'keep_inbag'); keep_inbag = extra_options.keep_inbag; end
    if isfield(extra_options,'print_verbose_tree_progression'); print_verbose_tree_progression = extra_options.print_verbose_tree_progression; end
end

keep_forest = 1; %always save the trees :)

% Set defaults for anything not supplied above.
if ~exist('DEBUG_ON','var'); DEBUG_ON = FALSE; end % separator added so the one-line guard parses
if ~exist('replace','var');  replace = TRUE;   end
%if ~exist('classwt','var'); classwt = []; end %will handle these three later
%if ~exist('cutoff','var'); cutoff = 1; end
%if ~exist('strata','var'); strata = 1; end
if ~exist('sampsize','var')
    if (replace)
        sampsize = size(X,1);
    else
        % ~63.2% is the expected fraction of unique rows in a bootstrap sample
        sampsize = ceil(0.632*size(X,1));
    end
end
if ~exist('nodesize','var');   nodesize = 1;       end %classification=1, regression=5
if ~exist('importance','var'); importance = FALSE; end
if ~exist('localImp','var');   localImp = FALSE;   end
if ~exist('nPerm','var');      nPerm = 1;          end
%if ~exist('proximity','var'); proximity = 1; end %will handle these two later
%if ~exist('oob_prox','var'); oob_prox = 1; end
%if ~exist('norm_votes','var'); norm_votes = TRUE; end
if ~exist('do_trace','var');   do_trace = FALSE;   end
%if ~exist('corr_bias','var'); corr_bias = FALSE; end
if ~exist('keep_inbag','var'); keep_inbag = FALSE; end
if ~exist('print_verbose_tree_progression','var'); print_verbose_tree_progression = FALSE; end

% BUGFIX: use short-circuit || instead of element-wise |. With |, MATLAB
% evaluates BOTH operands, so 'ntree<=0' / 'mtry<=0' ran even when the
% variable did not exist and a 2-argument call classRF_train(X,Y) raised
% an undefined-variable error instead of applying the defaults.
if ~exist('ntree','var') || ntree<=0
    ntree = 500;
    DEFAULTS_ON = 1;
end
if ~exist('mtry','var') || mtry<=0 || mtry>size(X,2)
    mtry = floor(sqrt(size(X,2)));
end

addclass = isempty(Y);
if (~addclass && length(unique(Y))<2)
    error('need atleast two classes for classification');
end

[N, D] = size(X);
n_size = N;
p_size = D;

if N==0; error(' data (X) has 0 rows'); end

if (mtry<1 || mtry>D)
    DEFAULTS_ON = 1;
end
mtry = max(1, min(D, round(mtry))); % clamp to a valid feature count
if DEFAULTS_ON
    fprintf('\tSetting to defaults %d trees and mtry=%d\n',ntree,mtry);
end

if ~isempty(Y)
    if length(Y) ~= N
        error('Y size is not the same as X size');
    end
    addclass = FALSE;
else
    % Unsupervised mode: synthesize a 2-class problem by duplicating X and
    % labelling the copies 1 and 2 (as in randomForest's addclass).
    if ~addclass
        addclass = TRUE;
    end
    Y_new = [ones(N,1); ones(N,1)*2];
    Y = Y_new;
    X = [X; X];
end

% The C code cannot handle missing values.
if any(isnan(X(:))); error('NaNs in X'); end
if any(isnan(Y(:))); error('NaNs in Y'); end

% Categorical features: the user flags categorical columns via
% extra_options.categorical_feature (1xD logical); each flagged column is
% remapped to consecutive integers 1..ncat(i), the coding the C code
% expects. The original-value <-> code tables are saved in the model so
% classRF_predict can apply the same mapping.
orig_uniques_in_feature   = cell(1,D);
mapped_uniques_in_feature = cell(1,D);
if exist('extra_options','var') && isfield(extra_options,'categorical_feature')
    ncat = ones(1,D);
    for i = 1:D
        if extra_options.categorical_feature(i)
            orig_uniques_in_feature{i} = sort(unique(X(:,i)));
            tmp_uniques_in_feature = orig_uniques_in_feature{i};
            mapped_uniques_in_feature{i} = 1:length(tmp_uniques_in_feature);
            tmp_mapped_uniques_in_feature = mapped_uniques_in_feature{i};
            % Replace into a copy so an already-remapped value cannot be
            % matched again (chained remapping) on later iterations.
            X_loc = X(:,i);
            for j = 1:length(tmp_uniques_in_feature)
                indices_to_change = find( X(:,i) == tmp_uniques_in_feature(j) );
                X_loc(indices_to_change) = tmp_mapped_uniques_in_feature(j);
            end
            X(:,i) = X_loc;
            ncat(i) = length(tmp_uniques_in_feature);
        else
            ncat(i) = 1; % numeric feature
        end
    end
else
    ncat = ones(1,D);
end
maxcat = max(ncat);
if maxcat > 32
    error('Can not handle categorical predictors with more than 32 categories');
end

%classRF - line 88 in randomForest.default.R
nclass = length(unique(Y));
if ~exist('cutoff','var')
    cutoff = ones(1,nclass) * (1/nclass); % default: plain majority vote
else
    if sum(cutoff)>1 || sum(cutoff)<0 || any(cutoff<=0) || length(cutoff)~=nclass
        error('Incorrect cutoff specified');
    end
end
if ~exist('classwt','var')
    classwt = ones(1,nclass); % equal priors
    ipi = 0;                  % tells the C code no user priors were given
else
    if length(classwt) ~= nclass
        error('Length of classwt not equal to the number of classes')
    end
    if any(classwt <= 0)
        error('classwt must be positive');
    end
    ipi = 1;
end

% Proximity defaults: unsupervised (addclass) mode computes proximities.
if ~exist('proximity','var')
    proximity = addclass;
    oob_prox = proximity;
end
if ~exist('oob_prox','var')
    oob_prox = proximity;
end

%i handle the proximity/importance output allocation in the mex file
if localImp
    importance = TRUE; % casewise importance implies permutation importance
end
if importance
    if (nPerm < 1)
        nPerm = int32(1);
    else
        nPerm = int32(nPerm);
    end
end

%somewhere near line 157 in randomForest.default.R
if addclass
    nsample = 2*n_size; % X was duplicated for unsupervised mode
else
    nsample = n_size;
end

% A vector-valued sampsize means stratified sampling (one count per stratum).
Stratify = (length(sampsize) > 1);
if (~Stratify && sampsize > N)
    error('Sampsize too large')
end
if Stratify
    if ~exist('strata','var')
        strata = Y; % default: stratify by class label
    end
    nsum = sum(sampsize);
    if (any(sampsize <= 0) || nsum == 0)
        error('Bad sampsize specification');
    end
else
    nsum = sampsize;
end

if Stratify
    strata = int32(strata);
else
    strata = int32(1);
end

Options = int32([addclass, importance, localImp, proximity, oob_prox, do_trace, keep_forest, replace, Stratify, keep_inbag]);

if DEBUG_ON
    % Echo the parameters being sent into the mex file.
    fprintf('size(x) %d\n',size(X));
    fprintf('size(y) %d\n',size(Y));
    fprintf('nclass %d\n',nclass);
    fprintf('size(ncat) %d\n',size(ncat));
    fprintf('maxcat %d\n',maxcat);
    fprintf('size(sampsize) %d\n',size(sampsize));
    fprintf('sampsize[0] %d\n',sampsize(1));
    fprintf('Stratify %d\n',Stratify);
    fprintf('Proximity %d\n',proximity);
    fprintf('oob_prox %d\n',oob_prox);
    fprintf('strata %d\n',strata);
    fprintf('ntree %d\n',ntree);
    fprintf('mtry %d\n',mtry);
    fprintf('ipi %d\n',ipi);
    fprintf('classwt %f\n',classwt);
    fprintf('cutoff %f\n',cutoff);
    fprintf('nodesize %f\n',nodesize);
    fprintf('print verbose %f\n',print_verbose_tree_progression);
end

% Train: all of the heavy lifting happens inside the compiled mex file.
[nrnodes,ntree,xbestsplit,classwt,cutoff,treemap,nodestatus,nodeclass,bestvar,ndbigtree,mtry ...
    outcl, counttr, prox, impmat, impout, impSD, errtr, inbag, outclts, proxts, errts] ...
    = mexClassRF_train(X',int32(Y_new),length(unique(Y)),ntree,mtry,int32(ncat), ...
    int32(maxcat), int32(sampsize), strata, Options, int32(ipi), ...
    classwt, cutoff, int32(nodesize),int32(nsum), int32(n_size), int32(p_size), int32(nsample),...
    int32(tst_available), Xtst',int32(Ytst),int32(tst_size), int32(print_verbose_tree_progression));

% Package everything classRF_predict needs into the model struct.
if maxcat ~= 1 % maxcat = 1: no categorical features exist so we dont have to save anything
    model.orig_uniques_in_feature   = orig_uniques_in_feature;
    model.mapped_uniques_in_feature = mapped_uniques_in_feature;
    model.ncat = ncat;
    model.categorical_feature = extra_options.categorical_feature;
end
model.nrnodes    = nrnodes;
model.ntree      = ntree;
model.xbestsplit = xbestsplit;
model.classwt    = classwt;
model.cutoff     = cutoff;
model.treemap    = treemap;
model.nodestatus = nodestatus;
model.nodeclass  = nodeclass;
model.bestvar    = bestvar;
model.ndbigtree  = ndbigtree;
model.mtry       = mtry;
model.orig_labels = orig_labels; % so predictions can be mapped back to user labels
model.new_labels  = new_labels;
model.nclass  = length(unique(Y));
model.outcl   = outcl;   %predicted label on training
model.outclts = outclts; %predicted label on test
model.counttr = counttr;
if proximity
    model.proximity = prox;
    if tst_available
        model.proximity_tst = proxts;
    else
        model.proximity_tst = [];
    end
else
    model.proximity = [];
end
model.localImp     = impmat;
model.importance   = impout;
model.importanceSD = impSD;
model.errtr = errtr';
model.errts = errts';
model.inbag = inbag;
model.votes = counttr';
model.oob_times = sum(counttr)';
clear mexClassRF_train
%keyboard
1;