@@ -18,7 +18,7 @@ def parallel_coordinate_plot(
18
18
models_to_highlight_by_line = True ,
19
19
models_to_highlight_colors = None ,
20
20
models_to_highlight_labels = None ,
21
- models_to_highlight_markers = ['s' , 'o' , '^' , '*' ],
21
+ models_to_highlight_markers = ["s" , "o" , "^" , "*" ],
22
22
models_to_highlight_markers_size = 10 ,
23
23
fig = None ,
24
24
ax = None ,
@@ -56,7 +56,7 @@ def parallel_coordinate_plot(
56
56
- `data`: 2-d numpy array for metrics
57
57
- `metric_names`: list, names of metrics for individual vertical axes (axis=1)
58
58
- `model_names`: list, name of models for markers/lines (axis=0)
59
- - `models_to_highlight`: list, default=None, List of models to highlight as lines or marker
59
+ - `models_to_highlight`: list, default=None, List of models to highlight as lines or marker
60
60
- `models_to_highlight_by_line`: bool, default=True, highlight as lines. If False, as marker
61
61
- `models_to_highlight_colors`: list, default=None, List of colors for models to highlight as lines
62
62
- `models_to_highlight_labels`: list, default=None, List of string labels for models to highlight as lines
@@ -242,14 +242,19 @@ def parallel_coordinate_plot(
242
242
label = models_to_highlight_labels [mh_index ]
243
243
else :
244
244
label = model
245
-
245
+
246
246
if models_to_highlight_by_line :
247
247
ax .plot (range (N ), zs [j , :], "-" , c = color , label = label , lw = 3 )
248
248
else :
249
- ax .plot (range (N ), zs [j , :], models_to_highlight_markers [mh_index ],
250
- c = color , label = label ,
251
- markersize = models_to_highlight_markers_size )
252
-
249
+ ax .plot (
250
+ range (N ),
251
+ zs [j , :],
252
+ models_to_highlight_markers [mh_index ],
253
+ c = color ,
254
+ label = label ,
255
+ markersize = models_to_highlight_markers_size ,
256
+ )
257
+
253
258
mh_index += 1
254
259
else :
255
260
if identify_all_models :
@@ -300,20 +305,28 @@ def parallel_coordinate_plot(
300
305
interpolate = False ,
301
306
alpha = 0.5 ,
302
307
)
303
-
308
+
304
309
if arrow_between_lines :
305
310
# Add vertical arrows
306
311
for xi , yi1 , yi2 in zip (x , y1 , y2 ):
307
- if ( yi2 > yi1 ) :
312
+ if yi2 > yi1 :
308
313
arrow_color = arrow_between_lines_colors [0 ]
309
- elif ( yi2 < yi1 ) :
314
+ elif yi2 < yi1 :
310
315
arrow_color = arrow_between_lines_colors [1 ]
311
316
else :
312
317
arrow_color = None
313
318
arrow_length = yi2 - yi1
314
- ax .arrow (xi , yi1 , 0 , arrow_length , color = arrow_color ,
315
- length_includes_head = True ,
316
- alpha = arrow_alpha , width = 0.05 , head_width = 0.15 )
319
+ ax .arrow (
320
+ xi ,
321
+ yi1 ,
322
+ 0 ,
323
+ arrow_length ,
324
+ color = arrow_color ,
325
+ length_includes_head = True ,
326
+ alpha = arrow_alpha ,
327
+ width = 0.05 ,
328
+ head_width = 0.15 ,
329
+ )
317
330
318
331
ax .set_xlim (- 0.5 , N - 0.5 )
319
332
ax .set_xticks (range (N ))
0 commit comments