From 846b7714300d93552c16ca6879a3dd8145f3acec Mon Sep 17 00:00:00 2001 From: zhuwq Date: Tue, 26 Sep 2023 23:16:35 -0700 Subject: [PATCH] add autograd test --- adloc/seismic_ops.py | 470 +++++++++++++++++++++++++++++++++++++++++++ adloc/travel_time.py | 170 ++++++++++++++++ adloc_v3.py | 427 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 1067 insertions(+) create mode 100644 adloc/seismic_ops.py create mode 100644 adloc/travel_time.py create mode 100644 adloc_v3.py diff --git a/adloc/seismic_ops.py b/adloc/seismic_ops.py new file mode 100644 index 0000000..8c7a56a --- /dev/null +++ b/adloc/seismic_ops.py @@ -0,0 +1,470 @@ +import itertools +import numpy as np +import scipy.optimize +import shelve +from pathlib import Path +from numba import njit +from numba.typed import List +import time + +###################################### Eikonal Solver ###################################### +# |\nabla u| = f +# ((u - a1)^+)^2 + ((u - a2)^+)^2 + ((u - a3)^+)^2 = f^2 h^2 + + +@njit +def calculate_unique_solution(a, b, f, h): + d = abs(a - b) + if d >= f * h: + return min([a, b]) + f * h + else: + return (a + b + np.sqrt(2 * f * f * h * h - (a - b) ** 2)) / 2 + + +@njit +def sweeping_over_I_J_K(u, I, J, f, h): + m = len(I) + n = len(J) + + # for i, j in itertools.product(I, J): + for i in I: + for j in J: + if i == 0: + uxmin = u[i + 1, j] + elif i == m - 1: + uxmin = u[i - 1, j] + else: + uxmin = min([u[i - 1, j], u[i + 1, j]]) + + if j == 0: + uymin = u[i, j + 1] + elif j == n - 1: + uymin = u[i, j - 1] + else: + uymin = min([u[i, j - 1], u[i, j + 1]]) + + u_new = calculate_unique_solution(uxmin, uymin, f[i, j], h) + + u[i, j] = min([u_new, u[i, j]]) + + return u + + +@njit +def sweeping(u, v, h): + f = 1.0 / v ## slowness + + m, n = u.shape + # I = list(range(m)) + # I = List() + # [I.append(i) for i in range(m)] + I = np.arange(m) + iI = I[::-1] + # J = list(range(n)) + # J = List() + # [J.append(j) for j in range(n)] + J = np.arange(n) + iJ = J[::-1] + + u = sweeping_over_I_J_K(u, I, J, f, h) + u = sweeping_over_I_J_K(u, iI, J, f, h) + u = sweeping_over_I_J_K(u, iI, iJ, f, h) + u = sweeping_over_I_J_K(u, I, iJ, f, h) + + return u + + +def eikonal_solve(u, f, h): + print("Eikonal Solver: ") + t0 = time.time() + for i in range(50): + u_old = np.copy(u) + u = sweeping(u, f, h) + + err = np.max(np.abs(u - u_old)) + print(f"Iter {i}, error = {err:.3f}") + if err < 1e-6: + break + print(f"Time: {time.time() - t0:.3f}") + return u + + +###################################### Traveltime based on Eikonal Timetable ###################################### +@njit +def _get_index(ir, iz, nr, nz, order="C"): + if order == "C": + return ir * nz + iz + elif order == "F": + return iz * nr + ir + else: + raise ValueError("order must be either C or F") + + +def test_get_index(): + vr, vz = np.meshgrid(np.arange(10), np.arange(20), indexing="ij") + vr = vr.flatten() + vz = vz.flatten() + nr = 10 + nz = 20 + for ir in range(nr): + for iz in range(nz): + assert vr[_get_index(ir, iz, nr, nz)] == ir + assert vz[_get_index(ir, iz, nr, nz)] == iz + + +@njit +def _interp(time_table, r, z, rgrid0, zgrid0, nr, nz, h): + ir0 = np.floor((r - rgrid0) / h).clip(0, nr - 2).astype(np.int64) + iz0 = np.floor((z - zgrid0) / h).clip(0, nz - 2).astype(np.int64) + ir1 = ir0 + 1 + iz1 = iz0 + 1 + + ## https://en.wikipedia.org/wiki/Bilinear_interpolation + x1 = ir0 * h + rgrid0 + x2 = ir1 * h + rgrid0 + y1 = iz0 * h + zgrid0 + y2 = iz1 * h + zgrid0 + + Q11 = time_table[_get_index(ir0, iz0, nr, nz)] + Q12 = 
time_table[_get_index(ir0, iz1, nr, nz)] + Q21 = time_table[_get_index(ir1, iz0, nr, nz)] + Q22 = time_table[_get_index(ir1, iz1, nr, nz)] + + t = ( + 1 + / (x2 - x1) + / (y2 - y1) + * ( + Q11 * (x2 - r) * (y2 - z) + + Q21 * (r - x1) * (y2 - z) + + Q12 * (x2 - r) * (z - y1) + + Q22 * (r - x1) * (z - y1) + ) + ) + + return t + + +def traveltime(event_loc, station_loc, phase_type, eikonal): + r = np.linalg.norm(event_loc[:, :2] - station_loc[:, :2], axis=-1, keepdims=False) + z = event_loc[:, 2] - station_loc[:, 2] + + rgrid0 = eikonal["rgrid"][0] + zgrid0 = eikonal["zgrid"][0] + nr = eikonal["nr"] + nz = eikonal["nz"] + h = eikonal["h"] + + p_index = phase_type == "p" + s_index = phase_type == "s" + tt = np.zeros(len(phase_type), dtype=np.float32) + tt[phase_type == "p"] = _interp(eikonal["up"], r[p_index], z[p_index], rgrid0, zgrid0, nr, nz, h) + tt[phase_type == "s"] = _interp(eikonal["us"], r[s_index], z[s_index], rgrid0, zgrid0, nr, nz, h) + tt = tt[:, np.newaxis] + + return tt + + +def grad_traveltime(event_loc, station_loc, phase_type, eikonal): + r = np.linalg.norm(event_loc[:, :2] - station_loc[:, :2], axis=-1, keepdims=False) + z = event_loc[:, 2] - station_loc[:, 2] + + rgrid0 = eikonal["rgrid"][0] + zgrid0 = eikonal["zgrid"][0] + nr = eikonal["nr"] + nz = eikonal["nz"] + h = eikonal["h"] + + p_index = phase_type == "p" + s_index = phase_type == "s" + dt_dr = np.zeros(len(phase_type)) + dt_dz = np.zeros(len(phase_type)) + dt_dr[p_index] = _interp(eikonal["grad_up"][0], r[p_index], z[p_index], rgrid0, zgrid0, nr, nz, h) + dt_dr[s_index] = _interp(eikonal["grad_us"][0], r[s_index], z[s_index], rgrid0, zgrid0, nr, nz, h) + dt_dz[p_index] = _interp(eikonal["grad_up"][1], r[p_index], z[p_index], rgrid0, zgrid0, nr, nz, h) + dt_dz[s_index] = _interp(eikonal["grad_us"][1], r[s_index], z[s_index], rgrid0, zgrid0, nr, nz, h) + + dr_dxy = (event_loc[:, :-2] - station_loc[:, :-1]) / (r[:, np.newaxis] + 1e-6) + dt_dxy = dt_dr[:, np.newaxis] * dr_dxy + + grad = np.column_stack((dt_dxy, dt_dz[:, np.newaxis])) + + return grad + + +############################################# Seismic Ops for GaMMA ##################################################################### + + +def calc_time(event_loc, station_loc, phase_type, vel={"p": 6.0, "s": 6.0 / 1.75}, eikonal=None, **kwargs): + ev_loc = event_loc[:, :-1] + ev_t = event_loc[:, -1:] + + if eikonal is None: + v = np.array([vel[x] for x in phase_type])[:, np.newaxis] + tt = np.linalg.norm(ev_loc - station_loc, axis=-1, keepdims=True) / v + ev_t + else: + tt = traveltime(event_loc, station_loc, phase_type, eikonal) + ev_t + return tt + + +def calc_mag(data, event_loc, station_loc, weight, min=-2, max=8): + dist = np.linalg.norm(event_loc[:, :-1] - station_loc, axis=-1, keepdims=True) + # mag_ = ( data - 2.48 + 2.76 * np.log10(dist) ) + ## Picozzi et al. (2018) A rapid response magnitude scale... + c0, c1, c2, c3 = 1.08, 0.93, -0.015, -1.68 + mag_ = (data - c0 - c3 * np.log10(np.maximum(dist, 0.1))) / c1 + 3.5 + ## Atkinson, G. M. (2015). Ground-Motion Prediction Equation... 
+ # c0, c1, c2, c3, c4 = (-4.151, 1.762, -0.09509, -1.669, -0.0006) + # mag_ = (data - c0 - c3*np.log10(dist))/c1 + # mag = np.sum(mag_ * weight) / (np.sum(weight)+1e-6) + mu = np.sum(mag_ * weight) / (np.sum(weight) + 1e-6) + std = np.sqrt(np.sum((mag_ - mu) ** 2 * weight) / (np.sum(weight) + 1e-12)) + mask = np.abs(mag_ - mu) <= 2 * std + mag = np.sum(mag_[mask] * weight[mask]) / (np.sum(weight[mask]) + 1e-6) + mag = np.clip(mag, min, max) + return mag + + +def calc_amp(mag, event_loc, station_loc): + dist = np.linalg.norm(event_loc[:, :-1] - station_loc, axis=-1, keepdims=True) + # logA = mag + 2.48 - 2.76 * np.log10(dist) + ## Picozzi et al. (2018) A rapid response magnitude scale... + c0, c1, c2, c3 = 1.08, 0.93, -0.015, -1.68 + logA = c0 + c1 * (mag - 3.5) + c3 * np.log10(np.maximum(dist, 0.1)) + ## Atkinson, G. M. (2015). Ground-Motion Prediction Equation... + # c0, c1, c2, c3, c4 = (-4.151, 1.762, -0.09509, -1.669, -0.0006) + # logA = c0 + c1*mag + c3*np.log10(dist) + return logA + + +################################################ Earthquake Location ################################################ + + +def huber_loss_grad( + event_loc, phase_time, phase_type, station_loc, weight, vel={"p": 6.0, "s": 6.0 / 1.75}, sigma=1, eikonal=None +): + event_loc = event_loc[np.newaxis, :] + predict_time = calc_time(event_loc, station_loc, phase_type, vel, eikonal) + t_diff = predict_time - phase_time + + l1 = np.squeeze((np.abs(t_diff) > sigma)) + l2 = np.squeeze((np.abs(t_diff) <= sigma)) + + # loss + loss = np.sum((sigma * np.abs(t_diff[l1]) - 0.5 * sigma**2) * weight[l1]) + np.sum( + 0.5 * t_diff[l2] ** 2 * weight[l2] + ) + J = np.zeros([phase_time.shape[0], event_loc.shape[1]]) + + # gradient + if eikonal is None: + v = np.array([vel[p] for p in phase_type])[:, np.newaxis] + dist = np.linalg.norm(event_loc[:, :-1] - station_loc, axis=-1, keepdims=True) + J[:, :-1] = (event_loc[:, :-1] - station_loc) / (dist + 1e-6) / v + else: + grad = grad_traveltime(event_loc, station_loc, phase_type, eikonal) + J[:, :-1] = grad + J[:, -1] = 1 + + J_ = np.sum(sigma * np.sign(t_diff[l1]) * J[l1] * weight[l1], axis=0, keepdims=True) + np.sum( + t_diff[l2] * J[l2] * weight[l2], axis=0, keepdims=True + ) + + return loss, J_ + + +def calc_loc( + phase_time, + phase_type, + station_loc, + weight, + event_loc0, + eikonal=None, + vel={"p": 6.0, "s": 6.0 / 1.75}, + bounds=None, + max_iter=100, + convergence=1e-6, +): + + opt = scipy.optimize.minimize( + huber_loss_grad, + np.squeeze(event_loc0), + method="L-BFGS-B", + jac=True, + args=(phase_time, phase_type, station_loc, weight, vel, 1, eikonal), + bounds=bounds, + options={"maxiter": max_iter, "gtol": convergence, "iprint": -1}, + ) + + return opt.x[np.newaxis, :], opt.fun + + + +def initialize_eikonal(config): + path = Path("./eikonal") + path.mkdir(exist_ok=True) + rlim = [0, np.sqrt((config["xlim"][1] - config["xlim"][0]) ** 2 + (config["ylim"][1] - config["ylim"][0]) ** 2)] + zlim = config["zlim"] + h = config["h"] + + filename = f"timetable_{rlim[0]:.0f}_{rlim[1]:.0f}_{zlim[0]:.0f}_{zlim[1]:.0f}_{h:.3f}" + if (path / (filename + ".dir")).is_file(): + print("Loading precomputed timetable...") + with shelve.open(str(path / filename)) as db: + up = db["up"] + us = db["us"] + grad_up = db["grad_up"] + grad_us = db["grad_us"] + rgrid = db["rgrid"] + zgrid = db["zgrid"] + nr = db["nr"] + nz = db["nz"] + h = db["h"] + else: + edge_grids = 0 + + rgrid = np.arange(rlim[0] - edge_grids * h, rlim[1], h) + zgrid = np.arange(zlim[0] - edge_grids * h, zlim[1], h) + 
nr, nz = len(rgrid), len(zgrid) + + vel = config["vel"] + zz, vp, vs = vel["z"], vel["p"], vel["s"] + vp1d = np.interp(zgrid, zz, vp) + vs1d = np.interp(zgrid, zz, vs) + vp = np.ones((nr, nz)) * vp1d + vs = np.ones((nr, nz)) * vs1d + + up = 1000.0 * np.ones((nr, nz)) + up[edge_grids, edge_grids] = 0.0 + up = eikonal_solve(up, vp, h) + + grad_up = np.gradient(up, h) + + us = 1000.0 * np.ones((nr, nz)) + us[edge_grids, edge_grids] = 0.0 + us = eikonal_solve(us, vs, h) + + grad_us = np.gradient(us, h) + + with shelve.open(str(path / filename)) as db: + db["up"] = up + db["us"] = us + db["grad_up"] = grad_up + db["grad_us"] = grad_us + db["rgrid"] = rgrid + db["zgrid"] = zgrid + db["nr"] = nr + db["nz"] = nz + db["h"] = h + + up = up.flatten() + us = us.flatten() + grad_up = np.array([grad_up[0].flatten(), grad_up[1].flatten()]) + grad_us = np.array([grad_us[0].flatten(), grad_us[1].flatten()]) + config.update( + { + "up": up, + "us": us, + "grad_up": grad_up, + "grad_us": grad_us, + "rgrid": rgrid, + "zgrid": zgrid, + "nr": nr, + "nz": nz, + "h": h, + } + ) + + return config + + +def initialize_centers(X, phase_type, centers_init, station_locs, random_state): + n_samples, n_features = X.shape + n_components, _ = centers_init.shape + centers = centers_init.copy() + + means = np.zeros([n_components, n_samples, n_features]) + for i in range(n_components): + if n_features == 1: # (time,) + means[i, :, :] = calc_time(centers_init[i : i + 1, :], station_locs, phase_type) + elif n_features == 2: # (time, amp) + means[i, :, 0:1] = calc_time(centers_init[i : i + 1, :-1], station_locs, phase_type) + means[i, :, 1:2] = X[:, 1:2] + # means[i, :, 1:2] = calc_amp(self.centers_init[i, -1:], self.centers_init[i:i+1, :-1], self.station_locs) + else: + raise ValueError(f"n_features = {n_features} > 2!") + + ## performance is not good + # resp = np.zeros((n_samples, self.n_components)) + # dist = np.sum(np.abs(means - X), axis=-1).T # (n_components, n_samples, n_features) -> (n_samples, n_components) + # resp[np.arange(n_samples), np.argmax(dist, axis=1)] = 1.0 + + ## performance is ok + # sigma = np.array([1.0,1.0]) + # prob = np.sum(1.0/sigma * np.exp( - (means - X) ** 2 / (2 * sigma**2)), axis=-1).T # (n_components, n_samples, n_features) -> (n_samples, n_components) + # prob_sum = np.sum(prob, axis=1, keepdims=True) + # prob_sum[prob_sum == 0] = 1.0 + # resp = prob / prob_sum + + dist = np.linalg.norm(means - X, axis=-1).T # (n_components, n_samples, n_features) -> (n_samples, n_components) + resp = np.exp(-dist) + resp_sum = resp.sum(axis=1, keepdims=True) + resp_sum[resp_sum == 0] = 1.0 + resp = resp / resp_sum + + # dist = np.linalg.norm(means - X, axis=-1) # (n_components, n_samples, n_features) -> (n_components, n_samples) + # resp = np.exp(-dist/np.median(dist, axis=0, keepdims=True)).T + # resp /= np.sum(resp, axis=1, keepdims=True) # (n_components, n_samples) + + if n_features == 2: + for i in range(n_components): + centers[i, -1:] = calc_mag(X[:, 1:2], centers_init[i : i + 1, :-1], station_locs, resp[:, i : i + 1]) + + return resp, centers, means + + +######################################################################################################################### +## L2 norm +def diff_and_grad(vars, data, station_locs, phase_type, vel={"p": 6.0, "s": 6.0 / 1.75}): + """ + data: (n_sample, t) + """ + v = np.array([vel[p] for p in phase_type])[:, np.newaxis] + # loc, t = vars[:,:-1], vars[:,-1:] + dist = np.sqrt(np.sum((station_locs - vars[:, :-1]) ** 2, axis=1, keepdims=True)) + y = dist / v 
- (data - vars[:, -1:]) + J = np.zeros([data.shape[0], vars.shape[1]]) + J[:, :-1] = (vars[:, :-1] - station_locs) / (dist + 1e-6) / v + J[:, -1] = 1 + return y, J + + +def newton_method( + vars, data, station_locs, phase_type, weight, max_iter=20, convergence=1, vel={"p": 6.0, "s": 6.0 / 1.75} +): + for i in range(max_iter): + prev = vars.copy() + y, J = diff_and_grad(vars, data, station_locs, phase_type, vel=vel) + JTJ = np.dot(J.T, weight * J) + I = np.zeros_like(JTJ) + np.fill_diagonal(I, 1e-3) + vars -= np.dot(np.linalg.inv(JTJ + I), np.dot(J.T, y * weight)).T + if (np.sum(np.abs(vars - prev))) < convergence: + return vars + return vars + + +## l1 norm +# def loss_and_grad(vars, data, station_locs, phase_type, weight, vel={"p":6.0, "s":6.0/1.75}): + +# v = np.array([vel[p] for p in phase_type])[:, np.newaxis] +# vars = vars[np.newaxis, :] +# dist = np.sqrt(np.sum((station_locs - vars[:,:-1])**2, axis=1, keepdims=True)) +# J = np.zeros([data.shape[0], vars.shape[1]]) +# J[:, :-1] = (vars[:,:-1] - station_locs)/(dist + 1e-6)/v +# J[:, -1] = 1 + +# loss = np.sum(np.abs(dist/v - (data[:,-1:] - vars[:,-1:])) * weight) +# J = np.sum(np.sign(dist/v - (data[:,-1:] - vars[:,-1:])) * weight * J, axis=0, keepdims=True) + +# return loss, J diff --git a/adloc/travel_time.py b/adloc/travel_time.py new file mode 100644 index 0000000..3bc55f4 --- /dev/null +++ b/adloc/travel_time.py @@ -0,0 +1,170 @@ +# %% +import numpy as np +import torch +from numba import njit +from torch.autograd import Function +from torch import nn +import time + + +@njit +def _get_index(ir, iz, nr, nz, order="C"): + if order == "C": + return ir * nz + iz + elif order == "F": + return iz * nr + ir + else: + raise ValueError("order must be either C or F") + + +def test_get_index(): + vr, vz = np.meshgrid(np.arange(10), np.arange(20), indexing="ij") + vr = vr.flatten() + vz = vz.flatten() + nr = 10 + nz = 20 + for ir in range(nr): + for iz in range(nz): + assert vr[_get_index(ir, iz, nr, nz)] == ir + assert vz[_get_index(ir, iz, nr, nz)] == iz + + +@njit +def _interp(time_table, r, z, rgrid0, zgrid0, nr, nz, h): + ir0 = np.floor((r - rgrid0) / h).clip(0, nr - 2).astype(np.int64) + iz0 = np.floor((z - zgrid0) / h).clip(0, nz - 2).astype(np.int64) + ir1 = ir0 + 1 + iz1 = iz0 + 1 + + ## https://en.wikipedia.org/wiki/Bilinear_interpolation + x1 = ir0 * h + rgrid0 + x2 = ir1 * h + rgrid0 + y1 = iz0 * h + zgrid0 + y2 = iz1 * h + zgrid0 + + Q11 = time_table[_get_index(ir0, iz0, nr, nz)] + Q12 = time_table[_get_index(ir0, iz1, nr, nz)] + Q21 = time_table[_get_index(ir1, iz0, nr, nz)] + Q22 = time_table[_get_index(ir1, iz1, nr, nz)] + + t = ( + 1 + / (x2 - x1) + / (y2 - y1) + * ( + Q11 * (x2 - r) * (y2 - z) + + Q21 * (r - x1) * (y2 - z) + + Q12 * (x2 - r) * (z - y1) + + Q22 * (r - x1) * (z - y1) + ) + ) + + return t + + +class TravelTime(Function): + @staticmethod + def forward(r, z, timetable, rgrid0, zgrid0, nr, nz, h): + tt = _interp(timetable.numpy(), r.numpy(), z.numpy(), rgrid0, zgrid0, nr, nz, h) + tt = torch.from_numpy(tt) + return tt + + @staticmethod + def setup_context(ctx, inputs, output): + r, z, timetable, rgrid0, zgrid0, nr, nz, h = inputs + ctx.timetable = timetable + ctx.rgrid0 = rgrid0 + ctx.zgrid0 = zgrid0 + ctx.nr = nr + ctx.nz = nz + ctx.h = h + + @staticmethod + def backward(ctx, grad_output): + timetable = ctx.timetable + rgrid0 = ctx.rgrid0 + zgrid0 = ctx.zgrid0 + nr = ctx.nr + nz = ctx.nz + h = ctx.h + + grad_r = grad_z = grad_timetable = grad_rgrid0 = grad_zgrid0 = grad_nr = grad_nz = grad_h = None + + timetable 
= timetable.numpy().reshape(nr, nz) + grad_time_r, grad_time_z = np.gradient(timetable, h, edge_order=2) + grad_time_r = grad_time_r.flatten() + grad_time_z = grad_time_z.flatten() + grad_r = _interp(grad_time_r, r.numpy(), z.numpy(), rgrid0, zgrid0, nr, nz, h) + grad_z = _interp(grad_time_z, r.numpy(), z.numpy(), rgrid0, zgrid0, nr, nz, h) + grad_r = torch.from_numpy(grad_r) + grad_z = torch.from_numpy(grad_z) + + return grad_r, grad_z, grad_timetable, grad_rgrid0, grad_zgrid0, grad_nr, grad_nz, grad_h + + +class Test(nn.Module): + def __init__(self, timetable, rgrid0, zgrid0, nr, nz, h): + super().__init__() + self.timetable = timetable + self.rgrid0 = rgrid0 + self.zgrid0 = zgrid0 + self.nr = nr + self.nz = nz + self.h = h + + def forward(self, r, z): + tt = TravelTime.apply(r, z, self.timetable, self.rgrid0, self.zgrid0, self.nr, self.nz, self.h) + return tt + + +if __name__ == "__main__": + import matplotlib.pyplot as plt + + starttime = time.time() + rgrid0 = 0 + zgrid0 = 0 + nr0 = 20 + nz0 = 20 + h = 1 + r = rgrid0 + h * np.arange(0, nr0) + z = zgrid0 + h * np.arange(0, nz0) + r, z = np.meshgrid(r, z, indexing="ij") + timetalbe = np.sqrt(r**2 + z**2) + timetable = torch.from_numpy(timetalbe.flatten()) + grad_r, grad_z = np.gradient(timetalbe, h, edge_order=2) + + nr = 10000 + nz = 10000 + r = torch.linspace(0, 20, nr) + z = torch.linspace(0, 20, nz) + r, z = torch.meshgrid(r, z, indexing="ij") + r = r.flatten() + z = z.flatten() + + test = Test(timetable, rgrid0, zgrid0, nr0, nz0, h) + r.requires_grad = True + z.requires_grad = True + tt = test(r, z) + tt.backward(torch.ones_like(tt)) + + endtime = time.time() + print(f"Time elapsed: {endtime - starttime} seconds.") + tt = tt.detach().numpy() + + fig, ax = plt.subplots(3, 2) + im = ax[0, 0].imshow(tt.reshape(nr, nz)) + fig.colorbar(im, ax=ax[0, 0]) + im = ax[0, 1].imshow(timetable.reshape(nr0, nz0)) + fig.colorbar(im, ax=ax[0, 1]) + im = ax[1, 0].imshow(r.grad.reshape(nr, nz)) + fig.colorbar(im, ax=ax[1, 0]) + im = ax[1, 1].imshow(grad_r.reshape(nr0, nz0)) + fig.colorbar(im, ax=ax[1, 1]) + im = ax[2, 0].imshow(z.grad.reshape(nr, nz)) + fig.colorbar(im, ax=ax[2, 0]) + im = ax[2, 1].imshow(grad_z.reshape(nr0, nz0)) + fig.colorbar(im, ax=ax[2, 1]) + plt.show() + + +# %% diff --git a/adloc_v3.py b/adloc_v3.py new file mode 100644 index 0000000..bfdff50 --- /dev/null +++ b/adloc_v3.py @@ -0,0 +1,427 @@ +# %% +from datetime import datetime +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from matplotlib import pyplot as plt +from pyproj import Proj +from torch import nn +import torch.optim as optim +from tqdm.auto import tqdm +import utils +from torch.utils.data import Dataset, DataLoader + + +def get_args_parser(add_help=True): + import argparse + + parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help) + + parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") + parser.add_argument( + "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" + ) + parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run") + parser.add_argument( + "-j", "--workers", default=0, type=int, metavar="N", help="number of data loading workers (default: 4)" + ) + parser.add_argument("--opt", default="sgd", type=str, help="optimizer") + parser.add_argument( + "--lr", + default=0.02, + type=float, + help="initial 
learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu", + ) + parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") + parser.add_argument( + "--wd", + "--weight-decay", + default=1e-4, + type=float, + metavar="W", + help="weight decay (default: 1e-4)", + dest="weight_decay", + ) + parser.add_argument( + "--norm-weight-decay", + default=None, + type=float, + help="weight decay for Normalization layers (default: None, same value as --wd)", + ) + parser.add_argument( + "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)" + ) + parser.add_argument( + "--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)" + ) + parser.add_argument( + "--lr-steps", + default=[16, 22], + nargs="+", + type=int, + help="decrease lr every step-size epochs (multisteplr scheduler only)", + ) + parser.add_argument( + "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)" + ) + parser.add_argument("--print-freq", default=20, type=int, help="print frequency") + parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs") + parser.add_argument("--resume", default="", type=str, help="path of checkpoint") + + parser.add_argument( + "--sync-bn", + dest="sync_bn", + help="Use sync batch norm", + action="store_true", + ) + parser.add_argument( + "--test-only", + dest="test_only", + help="Only test the model", + action="store_true", + ) + + parser.add_argument( + "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." + ) + + # distributed training parameters + parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") + parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") + parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + parser.add_argument("--weights-backbone", default=None, type=str, help="the backbone weights enum name to load") + + # Mixed precision training parameters + parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") + + return parser + + +class PhaseDataset: + def __init__(self, picks, events, stations): + self.picks = picks + self.events = events + self.stations = stations + self.__cache() + + def __len__(self): + # return len(self.events) + return 1 + + def __cache(self): + event_index = [] + station_index = [] + phase_score = [] + phase_time = [] + phase_type = [] + + for i in range(len(self.events)): + phase_time.append( + self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]]["phase_time"].values + ) + phase_score.append( + self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]]["phase_score"].values + ) + phase_type.extend( + self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]]["phase_type"].values.tolist() + ) + event_index.extend([i] * len(self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]])) + station_index.append( + self.stations.loc[ + self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]]["station_id"], "index" + ].values + ) + + phase_time = np.concatenate(phase_time) + phase_score = np.concatenate(phase_score) + # phase_type = np.array([{"P": 0, "S": 1}[x.upper()] for x in phase_type]) + event_index = 
np.array(event_index) + station_index = np.concatenate(station_index) + + self.station_index = torch.tensor(station_index, dtype=torch.long) + self.event_index = torch.tensor(event_index, dtype=torch.long) + self.phase_weight = torch.tensor(phase_score, dtype=torch.float32) + self.phase_time = torch.tensor(phase_time[:, np.newaxis], dtype=torch.float32) + self.phase_type = torch.tensor([{"P": 0, "S": 1}[x.upper()] for x in phase_type], dtype=torch.long) + + def __getitem__(self, i): + # phase_time = self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]]["phase_time"].values + # phase_score = self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]]["phase_score"].values + # phase_type = self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]][ + # "phase_type" + # ].values.tolist() + # event_index = np.array([i] * len(self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]])) + # station_index = self.stations.loc[ + # self.picks[self.picks["event_index"] == self.events.loc[i, "event_index"]]["station_id"], "index" + # ].values + + return { + "event_index": self.event_index, + "station_index": self.station_index, + "phase_time": self.phase_time, + "phase_weight": self.phase_weight, + "phase_type": self.phase_type, + } + + +# %% +class TravelTime(nn.Module): + def __init__( + self, + num_event, + num_station, + station_loc, + station_dt=None, + event_loc=None, + event_time=None, + reg=0.1, + velocity={"P": 6.0, "S": 6.0 / 1.73}, + dtype=torch.float32, + ): + super().__init__() + self.num_event = num_event + self.event_loc = nn.Embedding(num_event, 3) + self.event_time = nn.Embedding(num_event, 1) + self.station_loc = nn.Embedding(num_station, 3) + self.station_dt = nn.Embedding(num_station, 2) # vp, vs + self.station_loc.weight = torch.nn.Parameter(torch.tensor(station_loc, dtype=dtype), requires_grad=False) + if station_dt is not None: + self.station_dt.weight = torch.nn.Parameter(torch.tensor(station_dt, dtype=dtype)) # , requires_grad=False) + else: + self.station_dt.weight = torch.nn.Parameter( + torch.zeros(num_station, 2, dtype=dtype) + ) # , requires_grad=False) + # self.register_buffer("station_loc", torch.tensor(station_loc, dtype=dtype)) + self.velocity = [velocity["P"], velocity["S"]] + self.reg = reg + if event_loc is not None: + self.event_loc.weight = torch.nn.Parameter(torch.tensor(event_loc, dtype=dtype).contiguous()) + if event_time is not None: + self.event_time.weight = torch.nn.Parameter(torch.tensor(event_time, dtype=dtype).contiguous()) + + def calc_time(self, event_loc, station_loc, phase_type): + dist = torch.linalg.norm(event_loc - station_loc, axis=-1, keepdim=True) + # velocity = torch.tensor([self.velocity[p] for p in phase_type]).unsqueeze(-1) + # tt = dist / velocity + # if isinstance(self.velocity, dict): + # self.velocity = torch.tensor([vel[p.upper()] for p in phase_type]).unsqueeze(-1) + # tt = dist / self.velocity + tt = dist / self.velocity[phase_type] + return tt + + def forward( + self, + station_index, + event_index=None, + phase_type=None, + phase_time=None, + phase_weight=None, + double_difference=False, + ): + loss = 0.0 + pred_time = torch.zeros(len(phase_type), 1, dtype=torch.float32) + for type in [0, 1]: + station_index_ = station_index[phase_type == type] + event_index_ = event_index[phase_type == type] + phase_weight_ = phase_weight[phase_type == type] + + station_loc_ = self.station_loc(station_index_) + station_dt_ = self.station_dt(station_index_)[:, type].unsqueeze(-1) 
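+            # station_dt_ is the learned per-station static correction for the
+            # current phase (self.station_dt has two columns: 0 = P, 1 = S); it is
+            # added to the predicted arrival time below and kept small by the L1
+            # regularization term so it does not trade off against event_time.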
+ + event_loc_ = self.event_loc(event_index_) + event_time_ = self.event_time(event_index_) + + tt_ = self.calc_time(event_loc_, station_loc_, type) + t_ = event_time_ + tt_ + station_dt_ + pred_time[phase_type == type] = t_ + + if double_difference: + t_ = t_[0] - t_[1] + + if phase_time is not None: + phase_time_ = phase_time[phase_type == type] + # loss = torch.mean(phase_weight * (t - phase_time) ** 2) + loss += torch.mean( + F.huber_loss(tt_ + station_dt_, phase_time_ - event_time_, reduction="none") * phase_weight_ + ) + loss += self.reg * torch.mean( + torch.abs(station_dt_) + ) ## prevent the trade-off between station_dt and event_time + + return {"phase_time": pred_time, "loss": loss} + + +def main(args): + # %% + data_path = Path("test_data") + figure_path = Path("figures") + figure_path.mkdir(exist_ok=True) + + config = { + "center": (-117.504, 35.705), + "xlim_degree": [-118.004, -117.004], + "ylim_degree": [35.205, 36.205], + "degree2km": 111.19492474777779, + "starttime": datetime(2019, 7, 4, 17, 0), + "endtime": datetime(2019, 7, 5, 0, 0), + } + + # %% + stations = pd.read_csv(data_path / "stations.csv", delimiter="\t") + picks = pd.read_csv(data_path / "picks_gamma.csv", delimiter="\t", parse_dates=["phase_time"]) + events = pd.read_csv(data_path / "catalog_gamma.csv", delimiter="\t", parse_dates=["time"]) + + events = events[events["event_index"] < 100] + picks = picks[picks["event_index"] < 100] + + # %% + proj = Proj(f"+proj=sterea +lon_0={config['center'][0]} +lat_0={config['center'][1]} +units=km") + stations[["x_km", "y_km"]] = stations.apply( + lambda x: pd.Series(proj(longitude=x.longitude, latitude=x.latitude)), axis=1 + ) + stations["z_km"] = stations["elevation(m)"].apply(lambda x: -x / 1e3) + starttime = events["time"].min() + events["time"] = (events["time"] - starttime).dt.total_seconds() + picks["phase_time"] = (picks["phase_time"] - starttime).dt.total_seconds() + events[["x_km", "y_km"]] = events.apply( + lambda x: pd.Series(proj(longitude=x.longitude, latitude=x.latitude)), axis=1 + ) + events["z_km"] = events["depth(m)"].apply(lambda x: x / 1e3) + + # %% + num_event = len(events) + num_station = len(stations) + vp = 6.0 + vs = vp / 1.73 + + stations.reset_index(inplace=True, drop=True) + stations["index"] = stations.index.values + stations.set_index("station", inplace=True) + station_loc = stations[["x_km", "y_km", "z_km"]].values + station_dt = None + + events.reset_index(inplace=True, drop=True) + events["index"] = events.index.values + event_loc = events[["x_km", "y_km", "z_km"]].values + event_time = events["time"].values[:, np.newaxis] + + # %% + plt.figure() + plt.scatter(stations["x_km"], stations["y_km"], s=10, marker="^") + plt.scatter(events["x_km"], events["y_km"], s=1) + plt.axis("scaled") + plt.savefig(figure_path / "station_event_v2.png", dpi=300, bbox_inches="tight") + + utils.init_distributed_mode(args) + print(args) + + phase_dataset = PhaseDataset(picks, events, stations) + + if args.distributed: + sampler = torch.utils.data.distributed.DistributedSampler(phase_dataset, shuffle=False) + else: + sampler = torch.utils.data.SequentialSampler(phase_dataset) + + data_loader = DataLoader(phase_dataset, batch_size=None, sampler=sampler, num_workers=args.workers, collate_fn=None) + + ##################################### + # %% + event_index = [] + station_index = [] + phase_score = [] + phase_time = [] + phase_type = [] + + for i in range(len(events)): + phase_time.append(picks[picks["event_index"] == events.loc[i, 
"event_index"]]["phase_time"].values) + phase_score.append(picks[picks["event_index"] == events.loc[i, "event_index"]]["phase_score"].values) + phase_type.extend(picks[picks["event_index"] == events.loc[i, "event_index"]]["phase_type"].values.tolist()) + event_index.extend([i] * len(picks[picks["event_index"] == events.loc[i, "event_index"]])) + station_index.append( + stations.loc[picks[picks["event_index"] == events.loc[i, "event_index"]]["station_id"], "index"].values + ) + + phase_time = np.concatenate(phase_time) + phase_score = np.concatenate(phase_score) + phase_type = np.array([{"P": 0, "S": 1}[x.upper()] for x in phase_type]) + event_index = np.array(event_index) + station_index = np.concatenate(station_index) + + # %% + station_index = torch.tensor(station_index, dtype=torch.long) + event_index = torch.tensor(event_index, dtype=torch.long) + phase_weight = torch.tensor(phase_score, dtype=torch.float32) + phase_time = torch.tensor(phase_time[:, np.newaxis], dtype=torch.float32) + phase_type = torch.tensor(phase_type, dtype=torch.long) + + ##################################### + + travel_time = TravelTime(num_event, num_station, station_loc, event_time=event_time, velocity={"P": vp, "S": vs}) + tt = travel_time(station_index, event_index, phase_type, phase_weight=phase_weight)["phase_time"] + print("Loss using init location", F.mse_loss(tt, phase_time)) + init_event_loc = travel_time.event_loc.weight.clone().detach().numpy() + init_event_time = travel_time.event_time.weight.clone().detach().numpy() + + # optimizer = optim.LBFGS(params=travel_time.parameters(), max_iter=1000, line_search_fn="strong_wolfe") + optimizer = optim.Adam(params=travel_time.parameters(), lr=0.1) + + epoch = 1000 + for i in range(epoch): + optimizer.zero_grad() + + # for meta in tqdm(data_loader, desc=f"Epoch {i}"): + for meta in data_loader: + station_index = meta["station_index"] + event_index = meta["event_index"] + phase_time = meta["phase_time"] + phase_type = meta["phase_type"] + phase_weight = meta["phase_weight"] + + # def closure(): + # loss = travel_time(station_index, event_index, phase_type, phase_time, phase_weight)["loss"] + # loss.backward() + # return loss + # optimizer.step(closure) + + loss = travel_time(station_index, event_index, phase_type, phase_time, phase_weight)["loss"] + loss.backward() + + if i % 100 == 0: + print(f"Loss: {loss.item()}") + + # optimizer.step(closure) + optimizer.step() + + # %% + tt = travel_time(station_index, event_index, phase_type, phase_weight=phase_weight)["phase_time"] + print("Loss using invert location", F.mse_loss(tt, phase_time)) + station_dt = travel_time.station_dt.weight.clone().detach().numpy() + print(f"station_dt: max = {np.max(station_dt)}, min = {np.min(station_dt)}, mean = {np.mean(station_dt)}") + invert_event_loc = travel_time.event_loc.weight.clone().detach().numpy() + invert_event_time = travel_time.event_time.weight.clone().detach().numpy() + invert_station_dt = travel_time.station_dt.weight.clone().detach().numpy() + + # %% + plt.figure() + # plt.scatter(station_loc[:,0], station_loc[:,1], c=tp[idx_event,:]) + plt.plot(event_loc[:, 0], event_loc[:, 1], "x", markersize=1, color="blue", label="True locations") + plt.scatter(station_loc[:, 0], station_loc[:, 1], c=station_dt[:, 0], marker="o", linewidths=0, alpha=0.6) + plt.scatter(station_loc[:, 0], station_loc[:, 1] + 2, c=station_dt[:, 1], marker="o", linewidths=0, alpha=0.6) + plt.axis("scaled") + plt.colorbar() + xlim = plt.xlim() + ylim = plt.ylim() + plt.plot(init_event_loc[:, 0], 
init_event_loc[:, 1], "x", markersize=1, color="green", label="Initial locations") + plt.plot(invert_event_loc[:, 0], invert_event_loc[:, 1], "x", markersize=1, color="red", label="Inverted locations") + # plt.xlim(xlim) + # plt.ylim(ylim) + plt.legend() + plt.savefig(figure_path / "invert_location_v2.png", dpi=300, bbox_inches="tight") + + +if __name__ == "__main__": + args = get_args_parser().parse_args() + main(args)
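
Usage notes (supplementary sketches, not applied by the patch): the snippets below exercise the helpers added in adloc/seismic_ops.py under stated assumptions. In particular, `from adloc.seismic_ops import ...` assumes the repository root is on PYTHONPATH and that the adloc directory is importable as a package; the velocity model and station geometry are illustrative values, not values from this repository.

# Sketch 1: constant-velocity helpers (eikonal=None). calc_amp() and calc_mag()
# implement inverse mappings of the same Picozzi et al. (2018) relation, so a
# magnitude pushed through calc_amp() should be recovered by calc_mag().
import numpy as np
from adloc.seismic_ops import calc_time, calc_amp, calc_mag, initialize_eikonal

# One event at (x, y, z, t0) = (5 km, 5 km, 10 km, 1.0 s) and three stations (km).
event_loc = np.array([[5.0, 5.0, 10.0, 1.0]])
station_loc = np.array([[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [0.0, 10.0, 0.0]])
phase_type = np.array(["p", "p", "s"])
weight = np.ones((3, 1))

tt = calc_time(event_loc, station_loc, phase_type)  # homogeneous vp = 6.0, vs = vp / 1.75
print("straight-ray arrival times (s):", tt.ravel())

log_amp = calc_amp(3.0, event_loc, station_loc)  # synthetic log-amplitudes for M 3.0
mag = calc_mag(log_amp, event_loc, station_loc, weight)
print("recovered magnitude:", mag)  # approximately 3.0

# Sketch 2: eikonal-based travel times. The config keys below are the ones read by
# initialize_eikonal(): horizontal extents "xlim"/"ylim" and depth range "zlim" in km,
# grid spacing "h" in km, and a 1-D velocity model "vel" with depths "z" and
# velocities "p"/"s" in km/s. The solver caches its timetable under ./eikonal/ via shelve.
config = {
    "xlim": [0.0, 100.0],
    "ylim": [0.0, 100.0],
    "zlim": [0.0, 30.0],
    "h": 1.0,
    "vel": {"z": [0.0, 5.0, 15.0, 30.0], "p": [5.5, 6.0, 6.5, 8.0], "s": [3.2, 3.5, 3.8, 4.6]},
}
config = initialize_eikonal(config)
tt_eik = calc_time(event_loc, station_loc, phase_type, eikonal=config)
print("eikonal arrival times (s):", tt_eik.ravel())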