-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
146 lines (112 loc) · 4.36 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import matplotlib.pyplot as plt
import numpy as np
import wfdb
from scipy.io import loadmat
import setting_path as PATH
def _to_list(values):
    """Return *values* as a plain Python list (objects with ``tolist`` use it)."""
    return values.tolist() if hasattr(values, "tolist") else list(values)


def _per_class_rate(target, pred, by_pred, average):
    """Shared core of ``recall_score`` and ``pv_score``.

    Each (target, pred) pair is grouped under one label: the target label when
    *by_pred* is False (recall) or the predicted label when *by_pred* is True
    (positive predictive value). For every label, the returned rate is the
    fraction of pairs in its group where target == pred.

    Raises:
        ValueError: if *target* and *pred* have different lengths.
    """
    target = _to_list(target)
    pred = _to_list(pred)
    if len(target) != len(pred):
        raise ValueError(
            "target and pred must have the same length "
            "({} != {})".format(len(target), len(pred))
        )
    keys = pred if by_pred else target
    # Sort labels so the output order is deterministic: set iteration order
    # depends on hashing, which is randomized per process for e.g. str labels.
    try:
        labels = sorted(set(keys))
    except TypeError:  # labels are not mutually orderable; use first-seen order
        labels = list(dict.fromkeys(keys))
    hits = dict.fromkeys(labels, 0)
    totals = dict.fromkeys(labels, 0)
    for key, tar, pre in zip(keys, target, pred):
        totals[key] += 1
        if tar == pre:
            hits[key] += 1
    rates = [hits[label] / totals[label] for label in labels]
    if average is None:
        return np.array(rates)
    if not rates:  # no samples: the mean is undefined
        return np.array(np.nan)
    return np.array(sum(rates) / len(rates))


def recall_score(target, pred, average=None):
    """Compute per-class recall (sensitivity).

    For every label present in *target*, this is the fraction of samples with
    that true label whose prediction equals it.

    Args:
        target: true labels (list, tuple, numpy array, or anything with
            ``tolist``).
        pred: predicted labels, same length as *target*.
        average: ``None`` returns one value per label, ordered by sorted label.
            Any other value returns the unweighted (macro) mean of those values
            as a 0-d array; NaN if there are no samples.

    Returns:
        numpy.ndarray of per-class recalls, or a 0-d array with their mean.

    Raises:
        ValueError: if *target* and *pred* have different lengths.
    """
    return _per_class_rate(target, pred, by_pred=False, average=average)


def pv_score(target, pred, average=None):
    """Compute per-class positive predictive value (precision).

    For every label present in *pred*, this is the fraction of samples
    predicted as that label whose true label equals it. Labels that never
    occur in *pred* have no entry.

    Args:
        target: true labels (list, tuple, numpy array, or anything with
            ``tolist``).
        pred: predicted labels, same length as *target*.
        average: ``None`` returns one value per predicted label, ordered by
            sorted label. Any other value returns the unweighted (macro) mean
            as a 0-d array; NaN if there are no samples.

    Returns:
        numpy.ndarray of per-class PPVs, or a 0-d array with their mean.

    Raises:
        ValueError: if *target* and *pred* have different lengths.
    """
    return _per_class_rate(target, pred, by_pred=True, average=average)
def get_graph(base_record, n_samples=360, fs=360, n_lc_samples=98):
    """Plot a record's uniformly sampled trace with its level-crossing samples.

    Plots the first *n_samples* values of the signal returned by ``load_data``
    against time ``i / fs``. It then overlays, as markers, the first
    *n_lc_samples* rows (column 0) of the first two arrays returned by
    ``get_ls_signal`` (used as values and times respectively), and shows the
    figure.

    Args:
        base_record: record identifier passed to ``load_data`` and
            ``get_ls_signal``.
        n_samples: number of uniformly sampled points to plot (default 360).
        fs: samples per second used to convert sample index to seconds
            (default 360).
        n_lc_samples: number of level-crossing samples to overlay (default 98).
            NOTE(review): presumably chosen to span the same time window as
            *n_samples*; confirm.
    """
    raw_data = load_data(base_record)[0][:n_samples]
    # Size the time axis from the data actually loaded so short records
    # do not cause an x/y length mismatch in plot().
    raw_time = np.arange(len(raw_data)) / fs

    lc = get_ls_signal(base_record)
    lc_data = lc[0][:n_lc_samples, 0]
    lc_time = lc[1][:n_lc_samples, 0]

    fig = plt.figure(facecolor="white")
    ax = fig.add_subplot(111, xlabel="time(s)", ylabel="ECG(mv)")
    ax.plot(raw_time, raw_data, label="normal ADC")
    ax.plot(lc_time, lc_data, marker="o", markersize=3, label="level-cross ADC")
    ax.legend()
    plt.show()
def load_data(base_record, channel=0):
    """Read one channel of a record plus its "atr" annotations with wfdb.

    Args:
        base_record: record name under ``PATH.mit_path``; ``str()`` is
            applied, so integer ids work too.
        channel: column index into the sample matrix returned by
            ``wfdb.rdsamp``. Expected values are 0 or 1 (presumably the
            record's two leads; confirm).

    Returns:
        Tuple ``(channel_samples, annotation_symbols, annotation_samples)``.
    """
    record_path = os.path.join(PATH.mit_path, str(base_record))
    all_channels, _fields = wfdb.rdsamp(record_path)
    atr = wfdb.rdann(record_path, "atr")
    return all_channels[:, channel], atr.symbol, atr.sample
def get_ls_signal(base_record, channel=0):
    """Load a level-crossing record stored as an ``edECG`` MATLAB struct.

    Reads ``Rec<base_record>_ED_ch<channel + 1>`` from ``PATH.mit_lc_path``.
    ``scipy.io.loadmat`` appends ``.mat`` by default.

    Args:
        base_record: record id; ``str()`` is applied, so ints work as they do
            for ``load_data``.
        channel: zero-based channel index; the file suffix uses ``channel + 1``.

    Returns:
        Tuple ``(field0, field1, rr, counter, ann)``:

        - ``field0`` and ``field1`` are the struct's first two fields, taken
          by position. Callers index them as ``[:, 0]`` and use them as signal
          and time; confirm against the .mat layout.
        - ``rr`` and ``counter`` are the ``RR`` and ``counter`` fields.
        - ``ann`` is ``[(ann, anntype)]`` when both fields exist, else ``[]``.

    Raises:
        KeyError: if the struct has no ``RR`` or ``counter`` field.
    """
    file_name = "Rec" + str(base_record) + "_ED_ch" + str(channel + 1)
    path = os.path.join(PATH.mit_lc_path, file_name)
    record = loadmat(path)["edECG"][0][0]
    field_names = record.dtype.names

    missing = [name for name in ("RR", "counter") if name not in field_names]
    if missing:
        raise KeyError(
            "edECG struct in {} has no field(s) {}".format(path, missing)
        )
    rr = record["RR"]
    counter = record["counter"]

    # Look fields up by name: comparing a field name to a one-element list
    # (e.g. name == ["ann"]) is always False and would drop the annotations.
    ann_time = [record["ann"]] if "ann" in field_names else []
    ann_type = [record["anntype"]] if "anntype" in field_names else []
    ann = list(zip(ann_time, ann_type))

    return record[0], record[1], rr, counter, ann
def minmax(signal):
    """Return the extremes of *signal* as ``(maximum, minimum)``.

    Note the order: despite the name, the maximum comes first.
    """
    highest = max(signal)
    lowest = min(signal)
    return highest, lowest
def main():
    """Print a rank table of distinct level-crossing values, then plot record 101.

    For each record in ``record_list``, every distinct value in column 0 of
    the first array returned by ``get_ls_signal`` is collected. The values are
    sorted, and a dict mapping each value to its rank (0-based) is printed.
    Finally ``get_graph("101")`` is called.
    """
    record_list = ["101"]  # PATH.record_list
    # A set keeps dedup O(1) per value (list membership tests were O(n)),
    # and each record is loaded only once.
    levels = set()
    for record in record_list:
        sig = get_ls_signal(record)[0]
        levels.update(sig[:, 0])
    level_index = {value: rank for rank, value in enumerate(sorted(levels))}
    print(level_index)
    get_graph("101")
if __name__ == "__main__":
    # main()  # uncomment to print the level table and draw the comparison plot
    signals, _symbols, _positions = load_data("101")

    plt.rcParams.update(
        {
            "font.family": "sans-serif",  # font family to use
            # tick marks point inward ('in'), outward ('out'), or both ('inout')
            "xtick.direction": "in",
            "ytick.direction": "in",
            "xtick.major.width": 1.0,  # line width of x-axis major ticks
            "ytick.major.width": 1.0,  # line width of y-axis major ticks
            "font.size": 8,  # font size
            "axes.linewidth": 1.0,  # axes edge line width (frame thickness)
        }
    )

    # First 360 samples, plotted at 360 samples per second.
    t = np.arange(360) / 360
    plt.xlabel("time (s)")
    plt.ylabel("ECG (mV)")
    plt.plot(t, signals[:360])
    plt.show()