-
Notifications
You must be signed in to change notification settings - Fork 0
/
snowballSample.m
125 lines (115 loc) · 3.19 KB
/
snowballSample.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
function [train, test] = snowballSample(graph)
%SNOWBALLSAMPLE Partition the edges of a graph into two snowball-grown sets.
%   [TRAIN, TEST] = SNOWBALLSAMPLE(GRAPH) treats every nonzero entry (i,j)
%   of the adjacency matrix GRAPH as a directed edge and assigns each edge
%   to exactly one of the two outputs.
%
%   Each set starts from one randomly chosen seed edge. The sets then take
%   turns growing: a random edge (i,j) is popped from the set's frontier and
%   every edge still remaining in row i or row j of GRAPH is moved into the
%   set (and onto its frontier). If a set's frontier runs dry without
%   finding any edge, a new random seed edge is drawn from what is left.
%
%   Input:
%     graph - adjacency matrix; nonzero entries mark edges.
%
%   Outputs:
%     train, test - n-by-n sparse matrices (n = size(graph,1)) holding a 1
%                   for every edge assigned to that set.
%
%   A graph with fewer than two edges cannot seed both sets; in that case
%   the single edge (if any) goes to TRAIN and TEST is empty.
%
%   Requires RANDSAMPLE (Statistics and Machine Learning Toolbox).

% Get all edges
[I, J] = find(graph);
n_e = length(I);
n = size(graph, 1);

% Not enough edges to seed both sets: return the trivial split.
if n_e < 2
    train = sparse(I(:), J(:), ones(n_e, 1), n, n);
    test = sparse(n, n);
    return
end

% Frontier buffers: one row per edge, [0 0] marks an unused/consumed slot.
% Node indices are >= 1, so a zero in column 1 is a safe "empty" marker.
front_tr = zeros(n_e, 2);
front_te = zeros(n_e, 2);

%% Randomly sample train/test seeds
seed = randsample(n_e, 2);
train = [I(seed(1)) J(seed(1))];
test = [I(seed(2)) J(seed(2))];
graph(train(1), train(2)) = 0;
graph(test(1), test(2)) = 0;
front_tr(1, :) = train;
front_te(1, :) = test;
next_front_tr = 2;
next_front_te = 2;
n_e = n_e - 2;

%% Grab all edges until none left, alternating between the two sets
while n_e > 0
    [graph, train, front_tr, next_front_tr, n_e] = ...
        growEdgeSet(graph, train, front_tr, next_front_tr, n_e);
    % Check for remaining edges
    if n_e == 0
        break
    end
    [graph, test, front_te, next_front_te, n_e] = ...
        growEdgeSet(graph, test, front_te, next_front_te, n_e);
end

%% Convert edge lists to graphs
train = sparse(train(:, 1), train(:, 2), ones(size(train, 1), 1), n, n);
test = sparse(test(:, 1), test(:, 2), ones(size(test, 1), 1), n, n);

function [graph, edgeSet, front, next_front, n_e] = growEdgeSet(graph, edgeSet, front, next_front, n_e)
%GROWEDGESET Perform one growth step for a single edge set.
%   Pops random frontier edges until one has adjacent edges left in GRAPH,
%   moves those edges into EDGESET/FRONT and clears them from GRAPH. If the
%   frontier is exhausted without finding any, reseeds with one random
%   remaining edge. N_E (remaining edge count) must be > 0 on entry and is
%   returned decremented by the number of edges taken.

added = 0;
while added == 0 && nnz(front(:, 1)) > 0
    % Randomly choose edge on frontier. randsample(x,1) with scalar x
    % would sample from 1:x, so the single-candidate case is handled
    % separately.
    frontIdx = find(front(:, 1));
    if length(frontIdx) > 1
        idx = randsample(frontIdx, 1);
    else
        idx = frontIdx;
    end
    i = front(idx, 1);
    j = front(idx, 2);
    % Update frontier
    front(idx, :) = [0 0];

    neighbors_i = find(graph(i, :));
    neighbors_i = neighbors_i(:);
    if i == j
        % Self-loop: row j is row i. Listing it again would add every
        % edge twice and over-decrement the remaining edge count.
        neighbors_j = zeros(0, 1);
    else
        neighbors_j = find(graph(j, :));
        neighbors_j = neighbors_j(:);
    end

    % If any adjacent edges
    if ~isempty(neighbors_i) || ~isempty(neighbors_j)
        edges = [repmat(i, length(neighbors_i), 1) neighbors_i; ...
                 repmat(j, length(neighbors_j), 1) neighbors_j];
        added = size(edges, 1);
        % Add to set/frontier and update graph/remaining edge count
        graph(i, neighbors_i) = 0;
        graph(j, neighbors_j) = 0;
        edgeSet = [edgeSet; edges];
        front(next_front:next_front + added - 1, :) = edges;
        n_e = n_e - added;
        next_front = next_front + added;
    end
end

% If couldn't find any edges adjacent to frontier, reseed
if added == 0
    [I, J] = find(graph);
    seed = randsample(n_e, 1);
    i = I(seed);
    j = J(seed);
    graph(i, j) = 0;
    edgeSet = [edgeSet; i j];
    front(next_front, :) = [i j];
    n_e = n_e - 1;
    next_front = next_front + 1;
end