-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathauto_ml_training.py
105 lines (79 loc) · 3.97 KB
/
auto_ml_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from seisbench.models import EQTransformer
# Pretrained EQTransformer model, using the 'original' weight set.
model = EQTransformer.from_pretrained('original')

# Every window passed to the model has `in_channels` rows of `in_samples` samples.
in_channels, in_samples = 3, 6000

# Load the seismogram table; the absolute-time column is parsed into timestamps.
df = pd.read_csv(
    'xa.s12.00.mhz.1970-01-19HR00_evid00002.csv',
    parse_dates=['time_abs(%Y-%m-%dT%H:%M:%S.%f)'],
)
class SeismicDataset(Dataset):
    """Sliding-window view over a single-channel seismogram table.

    Item ``idx`` covers rows ``idx`` .. ``idx + in_samples - 1`` of ``df``
    (a stride of one row), so consecutive items overlap almost entirely.

    Args:
        df: Table with the columns ``velocity(m/s)``,
            ``time_abs(%Y-%m-%dT%H:%M:%S.%f)`` (parsed timestamps) and
            ``time_rel(sec)``.
        in_channels: Number of identical channels emitted per window.
        in_samples: Window length in rows.
    """

    def __init__(self, df, in_channels, in_samples):
        self.df = df
        self.in_channels = in_channels
        self.in_samples = in_samples

    def __len__(self):
        """Return the number of full windows (0 when ``df`` is too short)."""
        # Clamp at zero: a negative value would make len() itself raise.
        return max(0, len(self.df) - self.in_samples + 1)

    def __getitem__(self, idx):
        """Return one normalised window and the time of its last row.

        Args:
            idx: Window index; negative values count from the end.

        Returns:
            Tuple ``(waveform, time_abs, time_rel)``: a float32 tensor of
            shape ``(in_channels, in_samples)``, the POSIX timestamp of the
            window's last row, and that row's ``time_rel(sec)`` value.

        Raises:
            IndexError: If ``idx`` does not address a full window.
        """
        n_windows = len(self)
        if idx < 0:
            idx = idx + n_windows
        if not 0 <= idx < n_windows:
            # Without this check a slice past the end would silently yield a
            # short (or empty) window instead of failing.
            raise IndexError(f'window index out of range for {n_windows} windows')

        window = self.df['velocity(m/s)'].iloc[idx:idx + self.in_samples].values

        # Zero-mean / unit-variance normalisation. A perfectly flat window has
        # zero spread; dividing by it would fill the waveform with NaN, so the
        # centred (all-zero) window is kept as is in that case.
        centered = window - np.mean(window)
        spread = np.std(window)
        waveform_data = centered / spread if spread > 0 else centered

        # Replicate the single channel so the result has `in_channels` rows.
        waveform = np.tile(waveform_data, (self.in_channels, 1))
        waveform = torch.tensor(waveform, dtype=torch.float32)

        last_row = idx + self.in_samples - 1
        time_abs = self.df['time_abs(%Y-%m-%dT%H:%M:%S.%f)'].iloc[last_row]
        time_rel = self.df['time_rel(sec)'].iloc[last_row]
        return waveform, time_abs.timestamp(), time_rel
def custom_collate(batch):
    """Combine ``(waveform, time_abs, time_rel)`` items into batch tensors.

    Args:
        batch: Non-empty sequence of ``(waveform, time_abs, time_rel)``
            tuples, where every waveform tensor has the same shape.

    Returns:
        Tuple ``(waveforms, time_abs, time_rel)``: the stacked waveforms, and
        two 1-D float64 tensors holding the per-item times.
    """
    waveforms, time_abs, time_rel = zip(*batch)
    # float64 is required here: torch's default float32 has a 24-bit mantissa
    # and cannot represent Unix timestamps to sub-second precision.
    return (
        torch.stack(waveforms),
        torch.tensor(time_abs, dtype=torch.float64),
        torch.tensor(time_rel, dtype=torch.float64),
    )
# Batched, in-order iteration over every sliding window of the loaded table.
batch_size = 32
dataset = SeismicDataset(df, in_channels=in_channels, in_samples=in_samples)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=custom_collate,
    shuffle=False,
)
def detect_seismic_activity(predicted_activity, threshold=0.3):
    """Locate the span of samples whose score exceeds ``threshold``.

    Args:
        predicted_activity: 1-D tensor of per-sample scores.
        threshold: Scores strictly greater than this count as active.

    Returns:
        Tuple ``(start_idx, end_idx, length)`` of plain Python ints, where
        ``start_idx`` / ``end_idx`` are the first and last active sample
        indices and ``length == end_idx - start_idx + 1`` (inactive samples
        between them are included in the span). Returns ``(None, None, 0)``
        when no sample is active.
    """
    # nonzero yields the active indices directly from the boolean mask, so no
    # argmax over a bool tensor is needed.
    active = torch.nonzero(predicted_activity > threshold, as_tuple=False).flatten()
    if active.numel() == 0:
        return None, None, 0

    # Plain ints (not 0-dim tensors) so callers can pass the values straight
    # to APIs that only accept Python numbers, e.g. pd.Timedelta(seconds=...).
    start_idx = int(active[0])
    end_idx = int(active[-1])
    return start_idx, end_idx, end_idx - start_idx + 1
# Sampling rate used to convert sample offsets into seconds.
# NOTE(review): 100 Hz is the value the original script assumed; confirm it
# against the spacing of `time_rel(sec)` in the loaded CSV.
SAMPLING_RATE_HZ = 100

# Only the first MAX_BATCHES batches are inspected in the detailed report.
MAX_BATCHES = 10


def _samples_to_timedelta(n_samples):
    """Convert a sample count/offset into a pandas Timedelta."""
    # int() accepts both Python ints and 0-dim tensors; pd.Timedelta itself
    # only accepts plain int/float keyword values.
    return pd.Timedelta(seconds=int(n_samples) / SAMPLING_RATE_HZ)


model.eval()
total_events = 0
total_samples = 0
max_detection_prob = 0
with torch.no_grad():
    for batch_idx, batch in enumerate(dataloader):
        waveforms, time_abs, time_rel = batch
        outputs = model(waveforms)

        # SeisBench's EQTransformer returns (detection, P, S) in that order.
        # NOTE(review): confirm against the installed seisbench version.
        detection_prob, p_prob, s_prob = outputs
        max_detection_prob = max(max_detection_prob, detection_prob.max().item())

        for i in range(detection_prob.shape[0]):
            total_samples += 1
            start_idx, end_idx, duration = detect_seismic_activity(detection_prob[i])
            if start_idx is None or end_idx is None:
                continue
            total_events += 1

            # `time_abs[i]` is the timestamp of the window's LAST sample,
            # while every index below is an offset from its FIRST sample, so
            # step back to the window start before adding offsets.
            window_end = pd.Timestamp(time_abs[i].item(), unit='s')
            window_start = window_end - _samples_to_timedelta(waveforms.shape[-1] - 1)
            start_time = window_start + _samples_to_timedelta(start_idx)
            end_time = window_start + _samples_to_timedelta(end_idx)

            print(f'Batch {batch_idx}, Sample {i}:')
            print(f' Predicted start: {start_time}')
            print(f' Predicted end: {end_time}')
            print(f' Duration: {int(duration) / SAMPLING_RATE_HZ:.4f} seconds')
            print(f' Max detection probability: {detection_prob[i].max().item():.4f}')

            p_arrival = window_start + _samples_to_timedelta(p_prob[i].argmax().item())
            s_arrival = window_start + _samples_to_timedelta(s_prob[i].argmax().item())
            print(f' P-wave arrival: {p_arrival}')
            print(f' S-wave arrival: {s_arrival}')
            print()

        if batch_idx + 1 >= MAX_BATCHES:
            break

print(f"Processing complete. Processed {total_samples} samples.")
print(f"Detected {total_events} potential seismic events.")
print(f"Maximum detection probability: {max_detection_prob:.4f}")

# NOTE: this pass runs the model over the ENTIRE dataset, not just the first
# MAX_BATCHES batches reported above. It must run under no_grad: otherwise
# the outputs track gradients and `.numpy()` below raises.
with torch.no_grad():
    detection_probs = torch.cat([model(batch[0])[0] for batch in dataloader])

detection_values = detection_probs.numpy()
print("Detection probability statistics:")
print(f" Mean: {detection_probs.mean().item():.4f}")
print(f" Median: {detection_probs.median().item():.4f}")
print(f" 95th percentile: {np.percentile(detection_values, 95):.4f}")
print(f" 99th percentile: {np.percentile(detection_values, 99):.4f}")