-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathDemo_pocs_cnn.m
210 lines (187 loc) · 7.63 KB
/
Demo_pocs_cnn.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
%%% Seismic interpolation demo based on the POCS-CNN method.
%%%
%%% Start from an empty workspace with no figure windows left open.
clear;
close all;
%%% Make the helper routines, the data sets and the masks reachable.
addpath utilities
addpath seismicData
addpath seismicData/masks
%%% Select the complete (fully sampled) data set. The .mat file is expected
%%% to provide the data in a variable named 'D'.
dataChoice = 1;
if dataChoice == 1
    Data = 'hyperbolic-events';
else
    error('Unexpected choice.');
end
load([Data '.mat']);
%%% Keep the data set name around; it is used when building output file names.
Dataname = Data;
%%% ------------------- Parameters setting -------------------------------
% Noise level, valid range [0, 255]. Use a non-zero value for simultaneous
% denoising and interpolation.
noiseL = 0;

% Down-sampling method, valid choices:
%   'regc'  : regular down-sampling
%   'iregc' : irregular down-sampling
%   'randc' : random down-sampling
sampleType = 'iregc';

% Sampling ratio, valid in [0, 1]. The smaller it is, the fewer traces are
% preserved by the down-sampling process.
Ratio = 0.5;

% Whether to load a prepared sampling matrix:
%   1 : load the sampling matrix generated by the program makeMask.m
%   0 : generate the sampling matrix on the fly
useMaskFile = 1;

% Pre-interpolation method, one of
%   '', 'shepard', 'nearest', 'linear', 'cubic', 'natural'.
% The empty string '' disables pre-interpolation.
interpMethod = 'shepard';

% Number of POCS iterations.
totalIter = 30;

%%% Two key parameters for the CNN-POCS method
% Upper bound of sigma.
lambda1 = 30;
% Lower bound of sigma. Setting lambda2 = 2 usually yields good results in
% noise-free cases; it needs to be fine-tuned for noisy data.
lambda2 = 10;

%%% Other settings, used for visualizing and saving the results.
dx = 10;            % step of the 'Distance (m)' plot axis
dt = 0.004;         % step of the 'Time (s)' plot axis
freqThresh = 50;    % highest frequency shown in the f-k spectra
showResult = 1;     % 1: draw the result figures
saveResult = 0;     % 1: save the reconstruction to a .mat file
saveSnr = 0;        % 1: save the per-iteration SNR values to a .mat file
useGPU = 0;         % 1: run the denoiser on the GPU
%%% -----------------------------------------------------------------------
%%% Map the original data onto the value range [0, 1] before anything else.
label = single(D);
[m, n] = size(label);
xmin = min(label(:));               % offset removed from the data
xmax = max(label(:)) - xmin;        % value range left after removing the offset
label = (label - xmin) ./ xmax;
% Optional additive Gaussian noise, scaled by noiseL/255 (zero when noiseL = 0).
nlabel = single((noiseL/255) * randn(size(label))) + label;
%%% Down-sample the (possibly noisy) data. The sampling mask is either read
%%% from a prepared file or generated by projMask.
if useMaskFile
    load(['mask', num2str(m), 'x', num2str(n), sampleType, ...
        num2str(fix(Ratio*100)), '.mat']);
else
    mask = projMask(D, Ratio, sampleType);
end
input = mask .* nlabel;
% Positions where the mask is zero are filled with the mean of the data.
input(mask == 0) = mean(nlabel(:));
% Quality of the down-sampled data, measured in the original value range.
SNRinput = CalSNR(D, input*xmax + xmin);
PSNRinput = Psnr(D, input*xmax + xmin);
fprintf('%s\n', ['Input SNR: ', num2str(SNRinput), ' PSNR: ', num2str(PSNRinput)]);
%%% Build the starting point of the iterations, pre-interpolating the
%%% down-sampled data when a method has been requested.
if strcmp(interpMethod, 'shepard')
    window = 10;    % default 10, from [5, 30]
    initD = shepard_initialize(input, mask, window);
elseif strcmp(interpMethod, '')
    % No pre-interpolation: start directly from the down-sampled data.
    initD = input;
else
    % Every other method name is handed to initInterp.
    initD = initInterp(input, mask, interpMethod);
end
% Quality of the starting point, measured in the original value range.
SNRinitD = CalSNR(D, initD*xmax + xmin);
PSNRinitD = Psnr(D, initD*xmax + xmin);
fprintf('%s\n', ['Pre-iterpolated SNR: ', num2str(SNRinitD), ' PSNR: ', num2str(PSNRinitD)]);
%%% Interpolate the down-sampled data 'input' using the CNN-POCS method.
%%% Each iteration re-inserts the known samples (projection step) and then
%%% subtracts the residual predicted by the CNN denoiser.
if totalIter < 1
    error('totalIter must be at least 1.');
end
inIter = 1;     % denoiser passes per POCS iteration
% Sigma schedule: geometric progression from lambda1 (first iteration) down
% to lambda2 (last iteration).
if totalIter > 1
    SigmaS = (lambda2/lambda1).^((0:totalIter-1)/(totalIter-1))*lambda1;
else
    % A single iteration would evaluate 0/0 above and yield NaN.
    SigmaS = lambda1;
end
% The network is reloaded only when this value changes between iterations.
ns = min(25,max(ceil(SigmaS/2),1));
% Prepend a value that differs from ns(1) so the first iteration always
% loads a network.
ns = [ns(1)-1,ns];
snrs = zeros(1, totalIter);     % SNR after every iteration

folderModel = 'models';
% Expected to provide 'CNNdenoiser', which is passed to loadmodel below.
load(fullfile(folderModel,'model.mat'));

output = single(initD);
input = single(input);
mask = single(mask);
if useGPU
    input = gpuArray(input);
    mask = gpuArray(mask);
    output = gpuArray(output);
end

tic; cput = cputime;
for itern = 1 : totalIter
    % Projection: keep the current estimate where the mask is 0 and restore
    % the observed samples where the mask is 1.
    output = (1 - mask).*output + mask.*input;

    if ns(itern+1) ~= ns(itern)
        [net] = loadmodel(SigmaS(itern), CNNdenoiser);
        net = vl_simplenn_tidy(net);
        if useGPU
            net = vl_simplenn_move(net, 'gpu');
        end
    end

    for k = 1 : inIter
        res = vl_simplenn(net,output,[],[],'conserveMemory',true,'mode','test');
        % The last layer's output is subtracted from the current estimate.
        output = output - res(end).x;
    end

    % Evaluate the SNR on host memory so that a gpuArray is never assigned
    % into the CPU array 'snrs'.
    if useGPU
        outputHost = gather(output);
    else
        outputHost = output;
    end
    snrs(itern) = CalSNR(D, outputHost*xmax+xmin);
end
if useGPU
    input = gather(input);
    mask = gather(mask);
    output = gather(output);
end
toc;
disp(['CPU time: ', num2str(cputime-cput)]);

% Undo the [0, 1] normalization before measuring the final quality.
pocscnnRecon = output*xmax + xmin;
SNRCur = CalSNR(D, pocscnnRecon);
PSNRCur = Psnr(D, pocscnnRecon);
disp(['CNN-POCS SNR: ', num2str(SNRCur), ' PSNR: ', num2str(PSNRCur)]);
%%% Finally we visualize the results.
if showResult
    % imagesc(x, y, C) maps x onto the columns of C and y onto its rows.
    % With [m, n] = size(label), the horizontal (distance) axis therefore
    % spans n samples and the vertical (time) axis spans m samples.
    x = (0:n-1)*dx; t = (0:m-1)*dt;

    % Figure 1: original, down-sampled and reconstructed data plus the error.
    fig1 = figure(1); set(gcf, 'color', 'white'), set(gcf, 'Position', [100, 100, 900, 700]), colormap(gray);
    sub1 = subplot(221);
    imagesc(x,t,D),
    xlabel('Distance (m)'); ylabel('Time (s)');
    title('Original Data')
    sub2 = subplot(222);
    % 'input' is still in the normalized [0, 1] value range.
    imagesc(x,t,input),
    xlabel('Distance (m)'); ylabel('Time (s)');
    title([num2str(fix(Ratio*100)), '% subsampled data'])
    sub3 = subplot(223);
    imagesc(x,t,pocscnnRecon),
    xlabel('Distance (m)'); ylabel('Time (s)');
    title(['Reconstructed data,', ' SNR ', num2str(SNRCur, '%2.2f'), 'dB'])
    sub4 = subplot(224);
    imagesc(x,t,D-pocscnnRecon);
    xlabel('Distance (m)'); ylabel('Time (s)');
    title('Reconstruction error')
    drawnow;

    % Figure 2: SNR recorded after every POCS iteration.
    figure(2); plot(snrs);

    % Figure 3: log-scaled magnitude of the waveNumFreq output for the
    % original, down-sampled and reconstructed data, restricted to
    % frequencies in [0, freqThresh].
    fig3 = figure(3); set(gcf, 'color', 'white'), set(gcf, 'Position', [100, 100, 900, 700]), colormap(jet);
    sub21 = subplot(131);
    [wn, k, f] = waveNumFreq(D, dx, dt);
    index = find(f>=0 & f<=freqThresh);
    imagesc(k/max(abs(k))/2, f(index), log10(1+abs(wn(index, :))))
    xlabel('Normalized Wavenumber'); ylabel('Frequency (Hz)');
    set(gca, 'xlim', [-0.5, 0.5]);
    set(gca, 'xtick', [-0.5:0.5:0.5]);
    sub22 = subplot(132);
    [wn, k, f] = waveNumFreq(input, dx, dt);
    index = find(f>=0 & f<=freqThresh);
    imagesc(k/max(abs(k))/2, f(index), log10(1+abs(wn(index, :))))
    xlabel('Normalized Wavenumber'); ylabel('Frequency (Hz)');
    set(gca, 'xlim', [-0.5, 0.5]);
    set(gca, 'xtick', [-0.5:0.5:0.5]);
    sub23 = subplot(133);
    [wn, k, f] = waveNumFreq(pocscnnRecon, dx, dt);
    index = find(f>=0 & f<=freqThresh);
    imagesc(k/max(abs(k))/2, f(index), log10(1+abs(wn(index, :))))
    xlabel('Normalized Wavenumber'); ylabel('Frequency (Hz)');
    set(gca, 'xlim', [-0.5, 0.5]);
    set(gca, 'xtick', [-0.5:0.5:0.5]);
end
%%% Optionally store the reconstruction; the file name encodes the data set,
%%% the sampling type and the sampling ratio in percent.
if saveResult
    save(strjoin({'seismicResult/pocscnn/results/', Dataname, '-pocscnn-', ...
        sampleType, num2str(fix(Ratio*100)), '.mat'}, ''), 'pocscnnRecon');
end
%%% Optionally store the per-iteration SNR values; the file name additionally
%%% encodes lambda1 and lambda2.
if saveSnr
    save(strjoin({'seismicResult/pocscnn/snrs/', Dataname, '-', sampleType, ...
        num2str(fix(Ratio*100)), '-lambda-', num2str(lambda1), '-', ...
        num2str(lambda2), '.mat'}, ''), 'snrs');
end