-
Notifications
You must be signed in to change notification settings - Fork 0
/
calculate_threshold.py
109 lines (71 loc) · 3.57 KB
/
calculate_threshold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import sys
import os
import re
import pickle
import argparse
import warnings
import numpy as np
import pandas as pd
import holoviews as hv
from natsort import natsorted
from data import nested_dict, process_data
from clustering import get_ssc_thresh
# Silence every warning category for the rest of the process.
warnings.simplefilter('ignore')
# Load the HoloViews plotting backends; matplotlib is listed first.
# NOTE(review): presumably this makes matplotlib the renderer used by hv.save — confirm.
hv.extension('matplotlib','bokeh')
# Seed NumPy's global random generator with a fixed value.
# NOTE(review): presumably so the threshold clustering is reproducible — confirm.
np.random.seed(seed=250)
class Logger(object):
    """Write-only stream that duplicates every message to two destinations.

    Messages go to the ``sys.stdout`` object that was active when the
    instance was created and are appended to
    ``{save_dir}/{source}_thresh_log.log``.
    """

    def __init__(self, save_dir, source):
        """Capture the current ``sys.stdout`` and open the log file.

        Args:
            save_dir: Directory holding the log file; it must already exist
                because the file is opened directly in append mode.
            source: Prefix used for the log file name.
        """
        self.terminal = sys.stdout
        self.log = open(f'{save_dir}/{source}_thresh_log.log', 'a')

    def write(self, message):
        """Write *message* to the captured stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both destinations so buffered text actually reaches them.

        The log file is skipped once it has been closed, so a late flush
        (for example at interpreter shutdown) does not raise.
        """
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the captured terminal stream is left open."""
        if not self.log.closed:
            self.log.close()
def run(CSV, mfname, pos_idx, save_dir, floor):
    """Compute per-scene thresholds for one marker column and save the results.

    For the marker-file column selected by ``pos_idx`` (``pos``), every
    non-null entry of that column (``neg``) is paired with it. For each pair
    and each distinct value of the ``scene`` column of the CSV, the two
    columns are passed through ``process_data`` and ``get_ssc_thresh``; the
    resulting threshold is stored and the clusters are plotted with the
    threshold drawn as a vertical line.

    Outputs written under ``save_dir``:
        * ``img/threshs/{source}/{source}_{pos}_{neg}.png`` — one figure per pair.
        * ``thresh_dicts/{source}/{source}_{pos}_thresh_dict.pkl`` — pickled
          nested mapping ``[source][scene][pos][neg] -> [thresh]``.
        * ``{source}_thresh_log.log`` — copy of everything printed.

    Side effect: ``sys.stdout`` and ``sys.stderr`` are replaced by ``Logger``
    instances and are not restored when the function returns.

    Args:
        CSV: Path to the input CSV; it must contain a ``scene`` column and
            the marker columns named in the marker file.
        mfname: Path to the marker CSV.
        pos_idx: Position (anything accepted by ``int``) of the marker-file
            column to process.
        save_dir: Root directory for all outputs.
        floor: Value forwarded unchanged to ``process_data``.
    """
    source_df = pd.read_csv(CSV)

    # Source name = CSV file name without its (case-insensitive) ".csv" suffix.
    source = os.path.basename(os.path.normpath(CSV))
    source = re.sub(r'(.*)\.[cC][sS][vV]$', r'\1', source)

    source_threshs_dir = f'{save_dir}/img/threshs/{source}'
    os.makedirs(source_threshs_dir, exist_ok=True)

    # Mirror everything printed from here on into the per-source log file.
    # NOTE(review): both loggers echo to the stdout that is current at this
    # point, so stderr text also appears on stdout; the streams stay replaced
    # after this function returns.
    sys.stderr = Logger(save_dir, source)
    sys.stdout = Logger(save_dir, source)

    marker_table = pd.read_csv(mfname)
    me_markers = {col: marker_table[col].dropna().to_list() for col in marker_table.columns}
    pos = list(me_markers.keys())[int(pos_idx)]
    thresh_dict = nested_dict()

    # Loop invariants: the scene order and the x-axis range depend only on
    # the full table and on ``pos``, so compute them once.
    scenes = natsorted(set(source_df.scene))
    xlim = (source_df[pos].quantile(0.001),
            source_df[pos].quantile(0.999))

    for neg in me_markers[pos]:
        print(pos, neg)
        # The y-axis range depends on ``neg`` only, not on the scene.
        ylim = (source_df[neg].quantile(0.001),
                source_df[neg].quantile(0.999))
        figs = []
        for scene in scenes:
            print(source, scene)
            tmp_df = source_df[source_df.scene==scene][[pos,neg]]
            tmp_df = process_data(tmp_df,pos,neg,floor)
            thresh, clusters = get_ssc_thresh(tmp_df)
            thresh_dict[source][scene][pos][neg] = [thresh]

            # One scatter per cluster, all sharing the same axis limits.
            scatters = hv.Overlay([hv.Scatter(i).opts(xlabel=pos,
                                                      ylabel=neg,
                                                      xlim=xlim,
                                                      ylim=ylim,
                                                      alpha=0.1) for i in clusters])
            thresh_line = hv.VLine(thresh).opts(color='black')
            fig = hv.Overlay([scatters, thresh_line],label=f'scene {str(scene)}')
            figs.append(fig)

        # One image per (pos, neg) pair, scenes laid out four per row.
        layout = hv.Layout(figs).opts(title=source,sublabel_format='').cols(4)
        hv.save(layout,f'{source_threshs_dir}/{source}_{pos}_{neg}.png')

    thresh_dir = f'{save_dir}/thresh_dicts/{source}'
    os.makedirs(thresh_dir,exist_ok=True)
    # Use a context manager so the pickle file is flushed and closed.
    with open(f'{thresh_dir}/{source}_{pos}_thresh_dict.pkl','wb') as pkl_file:
        pickle.dump(thresh_dict, pkl_file)
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to run().
    parser = argparse.ArgumentParser(description='RESTPRE calculate threshold')
    # The three inputs below have no usable default, so require them up front
    # instead of failing later with an obscure error on a None value.
    parser.add_argument('--CSV', type=str, required=True, metavar='S', help='which CSV file to process')
    parser.add_argument('--mfname', type=str, required=True, metavar='S', help='which marker file to process')
    parser.add_argument('--pos_idx', type=str, required=True, metavar='S', help='which marker to normalize')
    parser.add_argument('--save_dir', type=str, default='~/', metavar='S', help='where to save')
    parser.add_argument('--floor', type=float, default=0.0, metavar='S', help='set floor of intensity')
    args = parser.parse_args()
    # Expand "~" explicitly: the path is used with open()/os.makedirs, which
    # would otherwise create a literal "~" directory in the working directory.
    save_dir = os.path.expanduser(args.save_dir)
    run(args.CSV, args.mfname, args.pos_idx, save_dir, args.floor)