@@ -107,6 +107,139 @@ def cluster_scatter(
107
107
ell .set_transform (transform + ax .transData )
108
108
ax .add_patch (ell )
109
109
110
+ def array_scatter_5_features (
111
+ labels ,
112
+ geom ,
113
+ x ,
114
+ z ,
115
+ maxptp ,
116
+ trough_val ,
117
+ trough_time ,
118
+ tip_val ,
119
+ zlim = (- 50 , 3900 ),
120
+ xlim = None ,
121
+ ptplim = None ,
122
+ maxptp_c = None ,
123
+ title = None ,
124
+ axes = None ,
125
+ do_ellipse = True ,
126
+ figsize = (15 , 15 ),
127
+ s_dot = 1 ,
128
+ s_size_geom = 3 ,
129
+ ):
130
+ fig = None
131
+ if axes is None :
132
+ fig , axes = plt .subplots (1 , 6 , sharey = True , figsize = figsize )
133
+
134
+ if title is not None :
135
+ fig .suptitle (title )
136
+
137
+ excluded_ids = {- 1 }
138
+ if not do_ellipse :
139
+ excluded_ids = np .unique (labels )
140
+
141
+ cluster_scatter (
142
+ x ,
143
+ z ,
144
+ labels ,
145
+ ax = axes [0 ],
146
+ s = s_dot ,
147
+ alpha = 0.1 ,
148
+ excluded_ids = excluded_ids ,
149
+ do_ellipse = do_ellipse ,
150
+ )
151
+ axes [0 ].scatter (* geom .T , c = "orange" , marker = "s" , s = s_size_geom )
152
+ axes [0 ].scatter (geom [0 , 0 ], geom [0 , 1 ], c = "orange" , marker = "s" , s = 10 , label = 'Channel Locations' )
153
+ axes [0 ].set_ylabel ("Registered Depth (um)" , fontsize = 14 )
154
+ axes [0 ].set_xlabel ("x (um)" , fontsize = 14 )
155
+ axes [0 ].legend (fontsize = 14 , loc = 'upper left' )
156
+ axes [0 ].tick_params (axis = 'x' , labelsize = 14 )
157
+ axes [0 ].tick_params (axis = 'y' , labelsize = 14 )
158
+
159
+ if maxptp_c is None :
160
+ maxptp_c = np .clip (maxptp , 3 , 15 )
161
+ axes [1 ].scatter (
162
+ x ,
163
+ z ,
164
+ c = maxptp_c ,
165
+ s = s_dot ,
166
+ alpha = 0.1 ,
167
+ marker = "." ,
168
+ cmap = plt .cm .jet ,
169
+ )
170
+ axes [1 ].scatter (* geom .T , c = "orange" , marker = "s" , s = s_size_geom )
171
+ axes [1 ].set_title ("colored by ptps" )
172
+
173
+ cluster_scatter (
174
+ maxptp ,
175
+ z ,
176
+ labels ,
177
+ ax = axes [2 ],
178
+ s = s_dot ,
179
+ alpha = 0.1 ,
180
+ excluded_ids = excluded_ids ,
181
+ do_ellipse = do_ellipse ,
182
+ )
183
+ axes [2 ].set_xlabel ("Amplitude (s.u.)" , fontsize = 16 )
184
+ axes [2 ].tick_params (axis = 'x' , labelsize = 16 )
185
+
186
+ cluster_scatter (
187
+ trough_val ,
188
+ z ,
189
+ labels ,
190
+ ax = axes [3 ],
191
+ s = s_dot ,
192
+ alpha = 0.1 ,
193
+ excluded_ids = excluded_ids ,
194
+ do_ellipse = do_ellipse ,
195
+ )
196
+ axes [3 ].set_xlabel ("Trough Val (s.u.)" , fontsize = 16 )
197
+ axes [3 ].tick_params (axis = 'x' , labelsize = 16 )
198
+
199
+ cluster_scatter (
200
+ trough_time ,
201
+ z ,
202
+ labels ,
203
+ ax = axes [4 ],
204
+ s = s_dot ,
205
+ alpha = 0.1 ,
206
+ excluded_ids = excluded_ids ,
207
+ do_ellipse = do_ellipse ,
208
+ )
209
+ axes [4 ].set_xlabel ("Trough Time (1/30ms)" , fontsize = 16 )
210
+ axes [4 ].tick_params (axis = 'x' , labelsize = 16 )
211
+
212
+ cluster_scatter (
213
+ tip_val ,
214
+ z ,
215
+ labels ,
216
+ ax = axes [5 ],
217
+ s = s_dot ,
218
+ alpha = 0.1 ,
219
+ excluded_ids = excluded_ids ,
220
+ do_ellipse = do_ellipse ,
221
+ )
222
+ axes [5 ].set_xlabel ("Tip Val (s.u.)" , fontsize = 16 )
223
+ axes [5 ].tick_params (axis = 'x' , labelsize = 16 )
224
+
225
+ axes [0 ].set_ylim (zlim )
226
+ axes [1 ].set_ylim (zlim )
227
+ axes [2 ].set_ylim (zlim )
228
+ axes [3 ].set_ylim (zlim )
229
+ axes [4 ].set_ylim (zlim )
230
+ axes [5 ].set_ylim (zlim )
231
+
232
+ if xlim is not None :
233
+ axes [0 ].set_xlim (xlim )
234
+ axes [1 ].set_xlim (xlim )
235
+ if ptplim is not None :
236
+ axes [2 ].set_xlim (ptplim )
237
+
238
+ # if fig is not None:
239
+ # plt.tight_layout()
240
+
241
+ return fig , axes
242
+
110
243
111
244
# %%
112
245
def array_scatter (
0 commit comments