""" Basic modules of MESS
"""
import time
import torch
import torch.nn.functional as F
from scipy.signal import correlate
import numpy as np
from dataset_gpu import cpu2cuda
import config
# MESS params
cfg = config.Config()
min_sta = cfg.min_sta
freq_band = cfg.freq_band
samp_rate = cfg.samp_rate
temp_win_p = [int(samp_rate * win) for win in cfg.temp_win_p]
temp_win_s = [int(samp_rate * win) for win in cfg.temp_win_s]
pick_win_p = [int(samp_rate * win) for win in cfg.pick_win_p]
pick_win_s = [int(samp_rate * win) for win in cfg.pick_win_s]
amp_win = [int(samp_rate * win) for win in cfg.amp_win]
expand_len = int(samp_rate * cfg.expand_len)
det_gap = int(samp_rate * cfg.det_gap)
chn_p = cfg.chn_p
chn_s = cfg.chn_s
trig_thres = cfg.trig_thres
def mess_det(temp_pick_dict, data_dict):
    """ MESS detection (main)
    Input
        temp, norm_temp, dt_list = temp_pick_dict[net_sta]
        _, data, norm_data = data_dict[net_sta]
    Output
        dets = [[det_ot, det_cc], ...]
        *note: det_ot is in relative sec
    """
    t = time.time()
    # get batch input
    data_list, temp_list, dt_ot_list = [], [], []
    for net_sta, [temp, norm_temp, dt_list] in temp_pick_dict.items():
        if net_sta not in data_dict: continue
        data, norm_data = data_dict[net_sta][1:3]
        data_list.append([data, norm_data])
        temp_list.append([temp[0], norm_temp[0]])
        dt_ot_list.append(dt_list[0])
    dt_ot_list = np.array(dt_ot_list)
    num_sta = len(data_list)
    if num_sta < min_sta: return []
    # 1. match
    cc_mat = match_filter(data_list, temp_list)
    cc_cond = np.amax(cc_mat, axis=1) > trig_thres
    cc_mat = cc_mat[cc_cond]
    dt_ot_list = dt_ot_list[cc_cond]
    num_sta = len(cc_mat)
    if num_sta < min_sta: return []
    # 2. expand
    cc_expand = [expand_cc(cc_i) for cc_i in cc_mat]
    # 3. shift
    cc_holder = np.zeros([num_sta, int(86400*samp_rate)])
    cc_shift = shift_ot(cc_expand, dt_ot_list, cc_holder)
    # 4. stack & detect
    cc_stack = np.mean(cc_shift, axis=0)
    dets = det_cc_stack(cc_stack)
    print('{} detections, {} stations, {:.1f}s'.format(len(dets), num_sta, time.time()-t))
    return dets
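# Usage sketch for mess_det (hypothetical inputs; temp_pick_dict and
# data_dict are assumed to be built by the dataset_gpu readers):
#   dets = mess_det(temp_pick_dict, data_dict)
#   for det_ot, det_cc in dets:
#       print('det at {:.2f} s, stacked CC {:.3f}'.format(det_ot, det_cc))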
def cc_pick(det_ot, temp_pick_dict, data_dict):
    """ Cross-correlation phase picking (main)
    Input
        det_ot: detected origin time (relative sec to date)
        temp_pick_dict
        data_dict
    Output
        picks: [net_sta, tp, ts, dt_p, dt_s, s_amp, cc_p, cc_s]
    """
    det_ot = int(det_ot * samp_rate) # sec to idx
    picks = []
    for net_sta, [temp, norm_temp, dt_list] in temp_pick_dict.items():
        # get np data & temp
        if net_sta not in data_dict: continue
        data_np = data_dict[net_sta][0].numpy()
        temp = [di.numpy() for di in temp]
        norm_temp = [di.numpy() for di in norm_temp]
        # slice p&s data
        p_range = [temp_win_p[0] + pick_win_p[0], temp_win_p[1] + pick_win_p[1]]
        s_range = [temp_win_s[0] + pick_win_s[0], temp_win_s[1] + pick_win_s[1]]
        tp0, ts0 = int(det_ot + dt_list[1]), int(det_ot + dt_list[2])
        if tp0 < p_range[0] or ts0 < s_range[0] \
        or ts0 + s_range[1] > data_np.shape[-1]: continue
        data_p = data_np[:, tp0 - p_range[0] : tp0 + p_range[1]]
        data_s = data_np[:, ts0 - s_range[0] : ts0 + s_range[1]]
        # 1. pick by cc
        cc_p = [calc_cc(data_p[i], temp[1][i], norm_temp=norm_temp[1][i]) for i in chn_p]
        cc_s = [calc_cc(data_s[i], temp[2][i], norm_temp=norm_temp[2][i]) for i in chn_s]
        cc_p = np.mean(cc_p, axis=0)
        cc_s = np.mean(cc_s, axis=0)
        # [tp, ts], [dt_p, dt_s] (relative sec), & [cc_p, cc_s]
        tp_idx = tp0 + np.argmax(cc_p) - pick_win_p[0]
        ts_idx = ts0 + np.argmax(cc_s) - pick_win_s[0]
        tp, ts = tp_idx/samp_rate, ts_idx/samp_rate
        dt_p, dt_s = (tp_idx-tp0)/samp_rate, (ts_idx-ts0)/samp_rate
        cc_p_max, cc_s_max = np.amax(cc_p), np.amax(cc_s)
        # 2. get amplitude
        data_amp = data_np[:, tp_idx-amp_win[0] : ts_idx+amp_win[1]]
        if data_amp.shape[1] < amp_win[0]: continue
        s_amp = get_s_amp(data_amp)
        picks.append([net_sta, tp, ts, dt_p, dt_s, s_amp, cc_p_max, cc_s_max])
    return picks
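# Usage sketch for cc_pick (assumes dets from mess_det above; pick fields
# follow the docstring order):
#   for det_ot, det_cc in dets:
#       picks = cc_pick(det_ot, temp_pick_dict, data_dict)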
""" Base functions
"""
def calc_cc_gpu(data_mat, temp_mat, norm_data_mat, norm_temp_vec):
    num_sta, len_data = data_mat.shape
    _, len_temp = temp_mat.shape
    data_mat = data_mat.view([1, num_sta, len_data])
    temp_mat = temp_mat.view([num_sta, 1, len_temp])
    # conv1d performs cross-correlation (no kernel flip); groups=num_sta
    # slides each station's template along its own data trace
    cc_mat = F.conv1d(data_mat, temp_mat, groups=num_sta)[0,:,1:]
    cc_mat /= norm_data_mat
    cc_mat /= norm_temp_vec.view([num_sta,1])
    return cc_mat
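# Minimal CPU equivalence check (a sketch): conv1d without kernel flip
# matches numpy's 'valid' cross-correlation, up to float32 tolerance
#   x, h = torch.rand(1, 1, 100), torch.rand(1, 1, 10)
#   out = F.conv1d(x, h)[0,0].numpy()
#   ref = np.correlate(x[0,0].numpy(), h[0,0].numpy(), mode='valid')
#   assert np.allclose(out, ref, atol=1e-4)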
def calc_cc(data, temp, norm_data=None, norm_temp=None):
    ntemp, ndata = len(temp), len(data)
    if ntemp > ndata: return np.zeros(1)
    # test against None: norm_data may be an array, whose truth value is ambiguous
    if norm_temp is None:
        norm_temp = np.sqrt(np.sum(temp**2))
    if norm_data is None:
        data_cum = np.cumsum(data**2)
        norm_data = np.sqrt(data_cum[ntemp:] - data_cum[:-ntemp])
    # drop the first lag so cc aligns with the sliding-window norm above
    cc = correlate(data, temp, mode='valid')[1:]
    cc /= norm_data * norm_temp
    cc[np.isinf(cc)] = 0.
    cc[np.isnan(cc)] = 0.
    return cc
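# Example (synthetic, a sketch): a template cut from the data itself should
# give a CC peak of ~1 one sample before the cut (the first lag is dropped)
#   data = np.random.randn(1000); temp = data[500:600].copy()
#   cc = calc_cc(data, temp)   # np.argmax(cc) == 499, cc[499] ~= 1.0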
# 1. matched filter (calc cc traces)
def match_filter(data_list, temp_list):
    # accumulate normalized CC over the 3 channels, then average
    for i in range(3):
        data_mat = torch.stack([data[i] for [data,_] in data_list])
        norm_data = torch.stack([norm[i] for [_,norm] in data_list])
        temp_mat = cpu2cuda(torch.stack([temp[i] for [temp,_] in temp_list]))
        norm_temp = cpu2cuda(torch.tensor([norm[i] for [_,norm] in temp_list]))
        if i==0: cc_mat = calc_cc_gpu(data_mat, temp_mat, norm_data, norm_temp)
        else: cc_mat += calc_cc_gpu(data_mat, temp_mat, norm_data, norm_temp)
    cc_mat = (cc_mat/3.).cpu().numpy()
    cc_mat[np.isinf(cc_mat)] = 0.
    cc_mat[np.isnan(cc_mat)] = 0.
    return cc_mat
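# Shape note (a sketch): with data traces of len_data samples and templates
# of len_temp samples, cc_mat comes back as [num_sta, len_data - len_temp],
# one channel-averaged CC trace per station, thresholded in mess_det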
# 2. expand peak value in CC trace
def expand_cc(cc):
    trig_idxs = np.where(cc > trig_thres)[0]
    slide_idx = 0
    for trig_idx in trig_idxs:
        if trig_idx < slide_idx: continue
        cc_trig = cc[trig_idx : trig_idx + 2*expand_len]
        cc_max = np.amax(cc_trig)
        idx_max = trig_idx + np.argmax(cc_trig)
        idx_0 = max(0, idx_max - expand_len//2)
        idx_1 = idx_max + expand_len//2
        cc[idx_0:idx_1] = cc_max
        # next trig
        slide_idx = trig_idx + expand_len + det_gap
    return cc
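# Example (a sketch, assuming trig_thres < 0.9): a single peak becomes a
# plateau of width ~expand_len, tolerating small travel-time errors in the
# later shift-and-stack
#   cc = np.zeros(10000); cc[400] = 0.9
#   cc = expand_cc(cc)   # cc[400-expand_len//2 : 400+expand_len//2] == 0.9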
# 3. shift CC traces back to origin time
def shift_ot(cc_list, dt_ot_list, cc_holder):
    for i, dt_ot in enumerate(dt_ot_list):
        dt_ot = int(dt_ot)  # slicing needs an integer sample shift
        cc_i = cc_list[i][max(0,-dt_ot) : cc_holder.shape[1] - dt_ot]
        cc_holder[i][max(0,dt_ot) : max(0,dt_ot) + len(cc_i)] = cc_i
    return cc_holder
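# Example (a sketch): dt_ot is read here as the per-station sample offset
# between the CC trace and the common origin-time axis (negative when origin
# time precedes the template window -- an inference from the slicing above),
# so peaks from the same event line up before stacking
#   cc_holder = np.zeros([2, int(86400*samp_rate)])
#   cc_shift = shift_ot([cc0, cc1], np.array([dt0, dt1]), cc_holder)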
# 4. detect on stacked cc trace
def det_cc_stack(cc_stack):
    det_idxs = np.where(cc_stack > trig_thres)[0]
    slide_idx = 0
    dets = []
    for det_idx in det_idxs:
        if det_idx < slide_idx: continue
        # det ot (rel sec)
        cc_det = cc_stack[det_idx : det_idx + 2*expand_len]
        cc_max = np.amax(cc_det)
        det_ot = (det_idx + np.median(np.where(cc_det == cc_max)[0])) / samp_rate
        dets.append([det_ot, cc_max])
        # next det
        slide_idx = det_idx + expand_len + det_gap
    return dets
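# Example (a sketch): detections on the stacked trace; det_gap keeps one
# expanded peak from triggering twice
#   dets = det_cc_stack(np.mean(cc_shift, axis=0))   # [[det_ot_sec, cc_max], ...]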
# get S amplitude
def get_s_amp(velo):
    # remove mean; avoid in-place ops, since velo is a view into the day-long array
    velo = velo - np.reshape(np.mean(velo, axis=1), [velo.shape[0],1])
    # velocity to displacement
    disp = np.cumsum(velo, axis=1)
    disp /= samp_rate
    return np.amax(np.sum(disp**2, axis=0))**0.5
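# Example (a sketch): 3-channel velocity slice in, scalar S amplitude out
# (peak vector length of the integrated 3-component displacement)
#   s_amp = get_s_amp(np.random.randn(3, 500))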
# write detection to catalog
def write_ctlg(det_ot, det_cc, temp_name, temp_loc, out_ctlg):
    out_ctlg.write('{0},{1},{2[1]},{2[2]},{2[3]},{3:.3f}\n'.format(temp_name, det_ot, temp_loc, det_cc))
# write phase picks to phase file
def write_pha(det_ot, det_cc, temp_name, temp_loc, picks, out_pha):
    out_pha.write('{0},{1},{2[1]},{2[2]},{2[3]},{3:.3f}\n'.format(temp_name, det_ot, temp_loc, det_cc))
    for pick in picks:
        # net_sta, tp, ts, dt_p, dt_s, s_amp, cc_p, cc_s
        out_pha.write('{0[0]},{0[1]},{0[2]},{0[3]:.2f},{0[4]:.2f},{0[5]},{0[6]:.3f},{0[7]:.3f}\n'.format(pick))
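# Usage sketch (hypothetical file handles; temp_loc is assumed to hold the
# hypocenter at indices 1-3):
#   with open('out.ctlg','w') as out_ctlg, open('out.pha','w') as out_pha:
#       write_ctlg(det_ot, det_cc, temp_name, temp_loc, out_ctlg)
#       write_pha(det_ot, det_cc, temp_name, temp_loc, picks, out_pha)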