7
7
from da_grf_test_set_0 import cols_to_unmask , dset_to_skip , drop_frame_num_range
8
8
from data .addb_dataset import MotionDataset
9
9
from matplotlib import rc , lines
10
- from fig_utils import FONT_DICT_SMALL , FONT_SIZE_SMALL , format_axis , LINE_WIDTH
10
+ from fig_utils import FONT_DICT_SMALL , FONT_SIZE_SMALL , format_axis , LINE_WIDTH , FONT_DICT_X_SMALL
11
11
from scipy .stats import friedmanchisquare , wilcoxon
12
12
13
13
@@ -248,7 +248,7 @@ def get_all_the_metrics(model_key):
248
248
# plt.plot(true_concat[within_gait_cycle, param_col_loc])
249
249
# plt.plot(pred_concat[within_gait_cycle, param_col_loc])
250
250
# plt.title(dset_short + ' ' + str(metric_dset[param]))
251
- # plt.show()
251
+ # plt.show()
252
252
253
253
# for param_col, metric_list in metric_all_dsets.items():
254
254
# if param_col == 'dset_short':
@@ -259,13 +259,18 @@ def get_all_the_metrics(model_key):
259
259
260
260
def draw_fig_2 (fast_run = False ):
261
261
def format_ticks (ax ):
262
- ax .set_ylabel ('Mean Absolute Error (\% Body Weight)' , fontdict = FONT_DICT_SMALL )
262
+ ax .text (- 0.7 , 5 , 'Mean Absolute Error (\% BW)' , rotation = 90 , fontdict = FONT_DICT_SMALL , verticalalignment = 'center' )
263
+ ax .text (- 0.8 , 12.5 , 'Better' , rotation = 90 , fontdict = FONT_DICT_SMALL , color = 'green' , verticalalignment = 'center' )
264
+ ax .annotate ('' , xy = (- 0.08 , 0.78 ), xycoords = 'axes fraction' , xytext = (- 0.08 , 0.98 ),
265
+ arrowprops = dict (arrowstyle = "->" , color = 'green' ))
266
+
263
267
ax .set_yticks (range (0 , 15 , 2 ))
264
268
ax .set_yticklabels (range (0 , 15 , 2 ), fontdict = FONT_DICT_SMALL )
265
- # ax.set_xticks([0.25, 1.25, 2.25, 3.25])
269
+ ax .set_xlim ([- 0.3 , 3.8 ])
270
+
266
271
ax .set_xticks ([])
267
272
for i in range (4 ):
268
- ax .text (i + 0.25 , - 1.1 , list (params_name_formal_name_pairs .values ())[i ], fontdict = FONT_DICT_SMALL , ha = 'center' )
273
+ ax .text (i + 0.25 , - 1.6 , list (params_name_formal_name_pairs .values ())[i ], fontdict = FONT_DICT_SMALL , ha = 'center' )
269
274
270
275
colors = [np .array (x ) / 255 for x in [[70 , 130 , 180 ], [207 , 154 , 130 ], [177 , 124 , 90 ]]] # [207, 154, 130], [100, 155, 227]
271
276
folder = 'fast' if fast_run else 'full'
@@ -274,14 +279,16 @@ def format_ticks(ax):
274
279
metric_sugainet = get_all_the_metrics (model_key = f'/{ folder } /sugainet_none_diffusion_filling' )
275
280
276
281
params_name_formal_name_pairs = {
277
- 'calcn_l_force_vy_max' : r'$f_v$ - Peak' , 'calcn_l_force_vy' : r'$f_v$ - Profile' ,
278
- 'calcn_l_force_vx' : r'$f_{ap}$ - Profile' , 'calcn_l_force_vz' : r'$f_{ml}$ - Profile' }
282
+ 'calcn_l_force_vy' : 'Vertical\n Force (Profile)' ,
283
+ 'calcn_l_force_vx' : 'Anterior-Posterior\n Force (Profile)' ,
284
+ 'calcn_l_force_vz' : 'Medial-Lateral\n Force (Profile)' ,
285
+ 'calcn_l_force_vy_max' : 'Vertical\n Force (Peak)' }
279
286
params_of_interest = list (params_name_formal_name_pairs .keys ())
280
287
281
288
rc ('text' , usetex = True )
282
289
plt .rc ('font' , family = 'Helvetica' )
283
290
284
- fig = plt .figure (figsize = (5 , 3 .5 ))
291
+ fig = plt .figure (figsize = (7.7 , 4 .5 ))
285
292
print ('Parameter\t \t All\t \t 1-2\t \t 1-3\t \t 2-3' )
286
293
for i_axis , param in enumerate (params_of_interest ):
287
294
bar_locs = [i_axis , i_axis + 0.25 , i_axis + 0.5 ]
@@ -301,25 +308,35 @@ def format_ticks(ax):
301
308
print ()
302
309
303
310
# From "Comparison of different machine learning models to enhance sacral acceleration-based estimations of running stride temporal variables and peak vertical ground reaction force"
304
- line0 , = plt .plot ([- 0.2 , 0.7 ], [13 , 13 ], ':' , linewidth = 2 , color = [0.0 , 0.0 , 0.0 ], alpha = 0.5 )
311
+ line0 , = plt .plot ([2.8 , 3.7 ], [13 , 13 ], '--' , linewidth = 2 , color = [0.0 , 0.0 , 0.0 ], alpha = 0.5 )
312
+ plt .text (4.1 , 14 , 'Running MDC - Healthy [25]' , fontdict = FONT_DICT_SMALL , color = [0.0 , 0.0 , 0.0 ], va = 'center' )
313
+ plt .annotate ('' , xytext = (4.05 , 14 ), xycoords = 'data' , xy = (3.75 , 13 ), arrowprops = dict (arrowstyle = "->" ))
314
+
305
315
# From "Intra-rater repeatability of gait parameters in healthy adults during self-paced treadmill-based virtual reality walking"
306
- line1 , = plt .plot ([- 0.2 , 0.7 ], [10.18 , 10.18 ], '--' , linewidth = 2 , color = [0.0 , 0.0 , 0.0 ], alpha = 0.5 )
316
+ line1 , = plt .plot ([2.8 , 3.7 ], [10.18 , 10.18 ], '--' , linewidth = 2 , color = [0.0 , 0.0 , 0.0 ], alpha = 0.5 )
317
+ plt .text (4.1 , 11.18 , 'Walking MDC - Healthy [26]' , fontdict = FONT_DICT_SMALL , color = [0.0 , 0.0 , 0.0 ], va = 'center' )
318
+ plt .annotate ('' , xytext = (4.05 , 11.18 ), xycoords = 'data' , xy = (3.75 , 10.18 ), arrowprops = dict (arrowstyle = "->" ))
319
+
307
320
# and "Minimal detectable change for gait variables collected during treadmill walking in individuals post-stroke"
308
- line2 , = plt .plot ([- 0.2 , 0.7 ], [4.65 , 4.65 ], '-' , linewidth = 2 , color = [0.0 , 0.0 , 0.0 ], alpha = 0.5 )
321
+ line2 , = plt .plot ([2.8 , 3.7 ], [4.65 , 4.65 ], '--' , linewidth = 2 , color = [0.0 , 0.0 , 0.0 ], alpha = 0.5 )
322
+ plt .text (4.1 , 5.65 , 'Walking MDC - Stroke [27]' , fontdict = FONT_DICT_SMALL , color = [0.0 , 0.0 , 0.0 ], va = 'center' )
323
+ plt .annotate ('' , xytext = (4.05 , 5.65 ), xycoords = 'data' , xy = (3.75 , 4.65 ), arrowprops = dict (arrowstyle = "->" ))
309
324
310
325
format_axis (plt .gca ())
311
326
format_ticks (plt .gca ())
312
- plt .tight_layout (rect = [0. , - 0.01 , 1 , 1.01 ])
313
- plt .legend (list (bars ) + [line0 , line1 , line2 ], [
314
- 'GaitDynamics' , 'GroundLink [XX]' , 'SugaiNet [XX]' , 'Running MDC - Healthy [XX]' , 'Walking MDC - Healthy [XX]' , 'Walking MDC - Stroke [XX]' ],
315
- frameon = False , fontsize = FONT_SIZE_SMALL , bbox_to_anchor = (0.4 , 1.05 )) # fontsize=font_size,
327
+ plt .tight_layout (rect = [- 0.03 , 0. , 1.03 , 0.88 ])
328
+ plt .legend (list (bars ), ['GaitDynamics' , 'Convolutional Neural Network [19]' , 'Recurrent Neural Network [20]' ],
329
+ frameon = False , fontsize = FONT_SIZE_SMALL , bbox_to_anchor = (0.7 , 1.2 ), ncols = 1 )
316
330
plt .savefig (f'exports/da_grf.png' , dpi = 300 , bbox_inches = 'tight' )
317
331
plt .show ()
318
332
319
333
320
334
def draw_fig_3 (fast_run = False ):
321
335
def format_ticks (ax_plt ):
322
- ax_plt .set_ylabel ('Mean Absolute Error of Peak $f_v$ (\% Body Weight)' , fontdict = FONT_DICT_SMALL )
336
+ ax_plt .text (- 0.7 , 20 , 'Mean Absolute Error of Vertical Force Estimation (\% BW)' , rotation = 90 , fontdict = FONT_DICT_SMALL , verticalalignment = 'center' )
337
+ ax_plt .text (- 1. , 37 , 'Better' , rotation = 90 , fontdict = FONT_DICT_SMALL , color = 'green' , verticalalignment = 'center' )
338
+ ax_plt .annotate ('' , xy = (- 0.11 , 0.8 ), xycoords = 'axes fraction' , xytext = (- 0.11 , 1. ),
339
+ arrowprops = dict (arrowstyle = "->" , color = 'green' ))
323
340
ax_plt .set_yticks ([0 , 10 , 20 , 30 , 40 ])
324
341
ax_plt .set_yticklabels ([0 , 10 , 20 , 30 , 40 ], fontdict = FONT_DICT_SMALL )
325
342
ax_plt .set_ylim ([0 , 40 ])
@@ -336,17 +353,17 @@ def format_ticks(ax_plt):
336
353
masked_segments = test_name .split ('_' )
337
354
for i_segment , segment in enumerate (segment_list ):
338
355
if segment in masked_segments or segment [:- 1 ] in masked_segments :
339
- ax_text .text (i_test + 0.14 , 7.9 - i_segment * 1.1 , segment , fontdict = FONT_DICT_SMALL , color = [0.8 , 0.8 , 0.8 ], ha = 'center' )
356
+ ax_text .text (i_test * 0.96 + 0.35 , 7.9 - i_segment * 1.1 , segment , fontdict = FONT_DICT_SMALL , color = [0.8 , 0.8 , 0.8 ], ha = 'center' )
340
357
else :
341
- ax_text .text (i_test + 0.14 , 7.9 - i_segment * 1.1 , segment , fontdict = FONT_DICT_SMALL , ha = 'center' )
358
+ ax_text .text (i_test * 0.96 + 0.35 , 7.9 - i_segment * 1.1 , segment , fontdict = FONT_DICT_SMALL , ha = 'center' )
342
359
343
360
colors = [np .array (x ) / 255 for x in [[70 , 130 , 180 ], [207 , 154 , 130 ]]] # [207, 154, 130], [100, 155, 227]
344
361
folder = 'fast' if fast_run else 'full'
345
362
param_of_interest = 'calcn_l_force_vy_max'
346
- fig = plt .figure (figsize = (5.5 , 4.2 ))
363
+ fig = plt .figure (figsize = (7.7 , 4.8 ))
347
364
rc ('text' , usetex = True )
348
365
plt .rc ('font' , family = 'Helvetica' )
349
- ax_plt = fig .add_axes ([0.1 , 0.25 , 0.87 , 0.62 ])
366
+ ax_plt = fig .add_axes ([0.14 , 0.25 , 0.83 , 0.66 ])
350
367
351
368
full_input = get_all_the_metrics (model_key = f'/{ folder } /tf_none_diffusion_filling' )[param_of_interest ]
352
369
line_1 , = plt .plot ([- 0.3 , 7.6 ], [np .mean (full_input ), np .mean (full_input )], color = np .array ([70 , 130 , 180 ])/ 255 , linewidth = LINE_WIDTH , linestyle = '--' )
@@ -366,8 +383,8 @@ def format_ticks(ax_plt):
366
383
format_axis (plt .gca ())
367
384
format_ticks (ax_plt )
368
385
ax_plt .legend (list (bars ) + [line_1 ], [
369
- 'Partial-Body Kinematics with Inpainting Filling' , 'Partial-Body Kinematics with Median Filling' , 'Full-Body Kinematics' ],
370
- frameon = False , fontsize = FONT_SIZE_SMALL , bbox_to_anchor = (0.05 , 0.88 ), loc = 'lower left' )
386
+ 'Partial-Body Kinematics with Inpainting Filling (GaitDynamics) ' , 'Partial-Body Kinematics with Median Filling' , 'Full-Body Kinematics (GaitDynamics) ' ],
387
+ frameon = False , fontsize = FONT_SIZE_SMALL , bbox_to_anchor = (0. , 0.88 ), loc = 'lower left' )
371
388
plt .savefig (f'exports/da_segment_filling.png' , dpi = 300 , bbox_inches = 'tight' )
372
389
plt .show ()
373
390
0 commit comments