Skip to content

Commit c411020

Browse files
committed
first submission
1 parent b3582af commit c411020

25 files changed

+405
-91
lines changed

README.md

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,29 @@
1-
# MotionDiffusionModel
1+
2+
# GaitDynamics: A Foundation Model for Analyzing Gait Dynamics
3+
By Tian Tan, Tom Van Wouwe, Keenon F. Werling, Scott L. Delp, Jennifer L. Hicks, C. Karen Liu, and Akshay S. Chaudhari
4+
5+
## Exclusive Summary
6+
GaitDynamics is a generative foundation model for general-purpose gait dynamics prediction.
7+
We illustrate in three diverse tasks with different inputs, outputs, and clinical impacts: i) estimating
8+
external forces from kinematics, ii) predicting the influence of gait modifications on knee loading without human
9+
experiments, and iii) predicting comprehensive kinematics and kinetic changes that occur with increasing running
10+
speeds.
11+
12+
## Corresponding Publication
13+
This repository includes the code and models for an [abstract](./figures/readme_fig/Tan_ASB2024.pdf).
14+
Full-length preprint is coming soon.
15+
16+
## Environment
17+
Our code is developed under the following environment. Versions different from ours may still work.
18+
19+
Python 3.9.16; Pytorch 1.13.1; Cuda 11.6; Cudnn 8.3.2; numpy 1.23.5;
20+
21+
## Dataset
22+
[AddBiomechanics Dataset](https://addbiomechanics.org/download_data.html)
23+
24+
## Example code
25+
[A Google Colab notebook](https://colab.research.google.com/drive/1n6kH3gnwLdQ2DH5krigbkiO06NjDtyxI?usp=sharing)
26+
is provided for the downstream tasks 1 – force estimation using flexible combinations of kinematic inputs.
27+
By executing the code, and example .mot file with joint angles of an OpenSim skeletal model will be imported from GitHub.
28+
Users can upload their own .mot files to the Colab notebook to obtain force predictions.
29+
To use a reduced kinematic input combinations, simply delete the corresponding columns in the .mot file.

data/addb_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040
restrict_contact_bodies=True,
4141
use_camargo_lumbar_reconstructed=False,
4242
check_cop_to_calcn_distance=True,
43+
wrong_cop_ratio=0.002
4344
):
4445
self.data_path = data_path
4546
self.trial_start_num = trial_start_num
@@ -54,6 +55,7 @@ def __init__(
5455
self.restrict_contact_bodies = restrict_contact_bodies
5556
self.use_camargo_lumbar_reconstructed = use_camargo_lumbar_reconstructed
5657
self.check_cop_to_calcn_distance = check_cop_to_calcn_distance
58+
self.wrong_cop_ratio = wrong_cop_ratio
5759
self.skels = {}
5860
self.num_of_excluded_trials = {'contact_body_num': 0, 'trial_length': 0, 'lumbar_rotation': 0, 'wrong_cop': 0,
5961
'large_moving_direction_change': 0, 'jittery_sample': 0}
@@ -513,7 +515,7 @@ def load_addb(self, opt, max_trial_num):
513515
# f' flipped {len(grf_flag_counts)} times, thus setting all to True.', end='')
514516
probably_missing = [False] * len(probably_missing)
515517

516-
states = norm_cops(skel, states, opt, weight_kg, height_m, self.check_cop_to_calcn_distance)
518+
states = norm_cops(skel, states, opt, weight_kg, height_m, self.check_cop_to_calcn_distance, self.wrong_cop_ratio)
517519
if states is False:
518520
print(f'{sub_and_trial_name} has CoP far away from foot, skipping')
519521
self.num_of_excluded_trials['wrong_cop'] += 1

data/check_b3d.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import nimblephysics as nimble
2+
3+
4+
b3d_file = '/mnt/d/Downloads/AB06_split0.b3d'
5+
subject = nimble.biomechanics.SubjectOnDisk(b3d_file)
6+
for trial_id in range(subject.getNumTrials()):
7+
trial_name = subject.getTrialName(trial_id)
8+
print(trial_name)
9+
10+
missing_grf_labels = subject.getMissingGRF(trial_id)
11+
12+
probably_missing = [reason != nimble.biomechanics.MissingGRFReason.notMissingGRF for reason
13+
in subject.getMissingGRF(trial_id)]
14+
15+
16+
17+
18+
19+
20+
21+
22+
23+

example_usage/compress_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def compress_model(checkpoint_path, model_class, new_model_name):
133133
compress_model(checkpoint_path, DanceDecoder, 'GaitDynamicsDiffusion')
134134

135135
# # full-body tf
136-
# checkpoint_path = os.getcwd() + '/../trained_models/train-2560_tf.pt'
136+
# checkpoint_path = os.getcwd() + '/../trained_models/train-7680_tf.pt'
137137
# compress_model(checkpoint_path, TransformerEncoderArchitecture, 'GaitDynamicsRefinement')
138138

139139
# # hip-knee tf

example_usage/gait_dynamics.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,7 @@ def convertDfToGRFMot(df, out_folder, dt, max_time=None):
12081208
out_file.write('\t' + str(0))
12091209
out_file.write('\n')
12101210
out_file.close()
1211+
print('GRF file exported to ' + out_folder)
12111212

12121213

12131214
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:

figures/da_grf_test_set_0.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def load_baseline_model(opt, model_to_test):
7272
elif model_to_test == 3:
7373
model_architecture_class = SugaiNetArchitecture
7474
if opt.use_server:
75-
opt.checkpoint_bl = opt.data_path_parent + f"/../code/runs/train/{'SugaiNetArchitecture_ema999'}/weights/train-{'2000_sugainet'}.pt"
75+
opt.checkpoint_bl = opt.data_path_parent + f"/../code/runs/train/{'SugaiNetArchitecture_ema999'}/weights/train-{'2020_sugainet'}.pt"
7676
else:
77-
opt.checkpoint_bl = os.path.dirname(os.path.realpath(__file__)) + f"/../trained_models/train-{'2000_sugainet'}.pt"
77+
opt.checkpoint_bl = os.path.dirname(os.path.realpath(__file__)) + f"/../trained_models/train-{'2020_sugainet'}.pt"
7878
model_key = 'sugainet'
7979

8080
set_with_arm_opt(opt, False)
@@ -403,6 +403,10 @@ def load_test_dataset_dict():
403403
for dset in DATASETS_NO_ARM:
404404
if dset in dset_to_skip:
405405
continue
406+
if 'vanderZee2022' in dset:
407+
wrong_cop_ratio = 0.01
408+
else:
409+
wrong_cop_ratio = 0.002
406410
print(dset)
407411
test_dataset = MotionDataset(
408412
data_path=opt.data_path_test,
@@ -414,6 +418,7 @@ def load_test_dataset_dict():
414418
include_trials_shorter_than_window_len=True,
415419
restrict_contact_bodies=False,
416420
max_trial_num=max_trial_num,
421+
wrong_cop_ratio=wrong_cop_ratio
417422
)
418423
test_dataset_dict[dset] = test_dataset
419424
return test_dataset_dict

figures/da_grf_test_set_1.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from da_grf_test_set_0 import cols_to_unmask, dset_to_skip, drop_frame_num_range
88
from data.addb_dataset import MotionDataset
99
from matplotlib import rc, lines
10-
from fig_utils import FONT_DICT_SMALL, FONT_SIZE_SMALL, format_axis, LINE_WIDTH
10+
from fig_utils import FONT_DICT_SMALL, FONT_SIZE_SMALL, format_axis, LINE_WIDTH, FONT_DICT_X_SMALL
1111
from scipy.stats import friedmanchisquare, wilcoxon
1212

1313

@@ -248,7 +248,7 @@ def get_all_the_metrics(model_key):
248248
# plt.plot(true_concat[within_gait_cycle, param_col_loc])
249249
# plt.plot(pred_concat[within_gait_cycle, param_col_loc])
250250
# plt.title(dset_short + ' ' + str(metric_dset[param]))
251-
# plt.show()
251+
# plt.show()
252252

253253
# for param_col, metric_list in metric_all_dsets.items():
254254
# if param_col == 'dset_short':
@@ -259,13 +259,18 @@ def get_all_the_metrics(model_key):
259259

260260
def draw_fig_2(fast_run=False):
261261
def format_ticks(ax):
262-
ax.set_ylabel('Mean Absolute Error (\% Body Weight)', fontdict=FONT_DICT_SMALL)
262+
ax.text(-0.7, 5, 'Mean Absolute Error (\% BW)', rotation=90, fontdict=FONT_DICT_SMALL, verticalalignment='center')
263+
ax.text(-0.8, 12.5, 'Better', rotation=90, fontdict=FONT_DICT_SMALL, color='green', verticalalignment='center')
264+
ax.annotate('', xy=(-0.08, 0.78), xycoords='axes fraction', xytext=(-0.08, 0.98),
265+
arrowprops=dict(arrowstyle="->", color='green'))
266+
263267
ax.set_yticks(range(0, 15, 2))
264268
ax.set_yticklabels(range(0, 15, 2), fontdict=FONT_DICT_SMALL)
265-
# ax.set_xticks([0.25, 1.25, 2.25, 3.25])
269+
ax.set_xlim([-0.3, 3.8])
270+
266271
ax.set_xticks([])
267272
for i in range(4):
268-
ax.text(i+0.25, -1.1, list(params_name_formal_name_pairs.values())[i], fontdict=FONT_DICT_SMALL, ha='center')
273+
ax.text(i+0.25, -1.6, list(params_name_formal_name_pairs.values())[i], fontdict=FONT_DICT_SMALL, ha='center')
269274

270275
colors = [np.array(x) / 255 for x in [[70, 130, 180], [207, 154, 130], [177, 124, 90]]] # [207, 154, 130], [100, 155, 227]
271276
folder = 'fast' if fast_run else 'full'
@@ -274,14 +279,16 @@ def format_ticks(ax):
274279
metric_sugainet = get_all_the_metrics(model_key=f'/{folder}/sugainet_none_diffusion_filling')
275280

276281
params_name_formal_name_pairs = {
277-
'calcn_l_force_vy_max': r'$f_v$ - Peak', 'calcn_l_force_vy': r'$f_v$ - Profile',
278-
'calcn_l_force_vx': r'$f_{ap}$ - Profile', 'calcn_l_force_vz': r'$f_{ml}$ - Profile'}
282+
'calcn_l_force_vy': 'Vertical\nForce (Profile)',
283+
'calcn_l_force_vx': 'Anterior-Posterior\nForce (Profile)',
284+
'calcn_l_force_vz': 'Medial-Lateral\nForce (Profile)',
285+
'calcn_l_force_vy_max': 'Vertical\nForce (Peak)'}
279286
params_of_interest = list(params_name_formal_name_pairs.keys())
280287

281288
rc('text', usetex=True)
282289
plt.rc('font', family='Helvetica')
283290

284-
fig = plt.figure(figsize=(5, 3.5))
291+
fig = plt.figure(figsize=(7.7, 4.5))
285292
print('Parameter\t\tAll\t\t1-2\t\t1-3\t\t2-3')
286293
for i_axis, param in enumerate(params_of_interest):
287294
bar_locs = [i_axis, i_axis + 0.25, i_axis + 0.5]
@@ -301,25 +308,35 @@ def format_ticks(ax):
301308
print()
302309

303310
# From "Comparison of different machine learning models to enhance sacral acceleration-based estimations of running stride temporal variables and peak vertical ground reaction force"
304-
line0, = plt.plot([-0.2, 0.7], [13, 13], ':', linewidth=2, color=[0.0, 0.0, 0.0], alpha=0.5)
311+
line0, = plt.plot([2.8, 3.7], [13, 13], '--', linewidth=2, color=[0.0, 0.0, 0.0], alpha=0.5)
312+
plt.text(4.1, 14, 'Running MDC - Healthy [25]', fontdict=FONT_DICT_SMALL, color=[0.0, 0.0, 0.0], va='center')
313+
plt.annotate('', xytext=(4.05, 14), xycoords='data', xy=(3.75, 13), arrowprops=dict(arrowstyle="->"))
314+
305315
# From "Intra-rater repeatability of gait parameters in healthy adults during self-paced treadmill-based virtual reality walking"
306-
line1, = plt.plot([-0.2, 0.7], [10.18, 10.18], '--', linewidth=2, color=[0.0, 0.0, 0.0], alpha=0.5)
316+
line1, = plt.plot([2.8, 3.7], [10.18, 10.18], '--', linewidth=2, color=[0.0, 0.0, 0.0], alpha=0.5)
317+
plt.text(4.1, 11.18, 'Walking MDC - Healthy [26]', fontdict=FONT_DICT_SMALL, color=[0.0, 0.0, 0.0], va='center')
318+
plt.annotate('', xytext=(4.05, 11.18), xycoords='data', xy=(3.75, 10.18), arrowprops=dict(arrowstyle="->"))
319+
307320
# and "Minimal detectable change for gait variables collected during treadmill walking in individuals post-stroke"
308-
line2, = plt.plot([-0.2, 0.7], [4.65, 4.65], '-', linewidth=2, color=[0.0, 0.0, 0.0], alpha=0.5)
321+
line2, = plt.plot([2.8, 3.7], [4.65, 4.65], '--', linewidth=2, color=[0.0, 0.0, 0.0], alpha=0.5)
322+
plt.text(4.1, 5.65, 'Walking MDC - Stroke [27]', fontdict=FONT_DICT_SMALL, color=[0.0, 0.0, 0.0], va='center')
323+
plt.annotate('', xytext=(4.05, 5.65), xycoords='data', xy=(3.75, 4.65), arrowprops=dict(arrowstyle="->"))
309324

310325
format_axis(plt.gca())
311326
format_ticks(plt.gca())
312-
plt.tight_layout(rect=[0., -0.01, 1, 1.01])
313-
plt.legend(list(bars) + [line0, line1, line2], [
314-
'GaitDynamics', 'GroundLink [XX]', 'SugaiNet [XX]', 'Running MDC - Healthy [XX]', 'Walking MDC - Healthy [XX]', 'Walking MDC - Stroke [XX]'],
315-
frameon=False, fontsize=FONT_SIZE_SMALL, bbox_to_anchor=(0.4, 1.05)) # fontsize=font_size,
327+
plt.tight_layout(rect=[-0.03, 0., 1.03, 0.88])
328+
plt.legend(list(bars), ['GaitDynamics', 'Convolutional Neural Network [19]', 'Recurrent Neural Network [20]'],
329+
frameon=False, fontsize=FONT_SIZE_SMALL, bbox_to_anchor=(0.7, 1.2), ncols=1)
316330
plt.savefig(f'exports/da_grf.png', dpi=300, bbox_inches='tight')
317331
plt.show()
318332

319333

320334
def draw_fig_3(fast_run=False):
321335
def format_ticks(ax_plt):
322-
ax_plt.set_ylabel('Mean Absolute Error of Peak $f_v$ (\% Body Weight)', fontdict=FONT_DICT_SMALL)
336+
ax_plt.text(-0.7, 20, 'Mean Absolute Error of Vertical Force Estimation (\% BW)', rotation=90, fontdict=FONT_DICT_SMALL, verticalalignment='center')
337+
ax_plt.text(-1., 37, 'Better', rotation=90, fontdict=FONT_DICT_SMALL, color='green', verticalalignment='center')
338+
ax_plt.annotate('', xy=(-0.11, 0.8), xycoords='axes fraction', xytext=(-0.11, 1.),
339+
arrowprops=dict(arrowstyle="->", color='green'))
323340
ax_plt.set_yticks([0, 10, 20, 30, 40])
324341
ax_plt.set_yticklabels([0, 10, 20, 30, 40], fontdict=FONT_DICT_SMALL)
325342
ax_plt.set_ylim([0, 40])
@@ -336,17 +353,17 @@ def format_ticks(ax_plt):
336353
masked_segments = test_name.split('_')
337354
for i_segment, segment in enumerate(segment_list):
338355
if segment in masked_segments or segment[:-1] in masked_segments:
339-
ax_text.text(i_test+0.14, 7.9 - i_segment*1.1, segment, fontdict=FONT_DICT_SMALL, color=[0.8, 0.8, 0.8], ha='center')
356+
ax_text.text(i_test*0.96+0.35, 7.9 - i_segment*1.1, segment, fontdict=FONT_DICT_SMALL, color=[0.8, 0.8, 0.8], ha='center')
340357
else:
341-
ax_text.text(i_test+0.14, 7.9 - i_segment*1.1, segment, fontdict=FONT_DICT_SMALL, ha='center')
358+
ax_text.text(i_test*0.96+0.35, 7.9 - i_segment*1.1, segment, fontdict=FONT_DICT_SMALL, ha='center')
342359

343360
colors = [np.array(x) / 255 for x in [[70, 130, 180], [207, 154, 130]]] # [207, 154, 130], [100, 155, 227]
344361
folder = 'fast' if fast_run else 'full'
345362
param_of_interest = 'calcn_l_force_vy_max'
346-
fig = plt.figure(figsize=(5.5, 4.2))
363+
fig = plt.figure(figsize=(7.7, 4.8))
347364
rc('text', usetex=True)
348365
plt.rc('font', family='Helvetica')
349-
ax_plt = fig.add_axes([0.1, 0.25, 0.87, 0.62])
366+
ax_plt = fig.add_axes([0.14, 0.25, 0.83, 0.66])
350367

351368
full_input = get_all_the_metrics(model_key=f'/{folder}/tf_none_diffusion_filling')[param_of_interest]
352369
line_1, = plt.plot([-0.3, 7.6], [np.mean(full_input), np.mean(full_input)], color=np.array([70, 130, 180])/255, linewidth=LINE_WIDTH, linestyle='--')
@@ -366,8 +383,8 @@ def format_ticks(ax_plt):
366383
format_axis(plt.gca())
367384
format_ticks(ax_plt)
368385
ax_plt.legend(list(bars) + [line_1], [
369-
'Partial-Body Kinematics with Inpainting Filling', 'Partial-Body Kinematics with Median Filling', 'Full-Body Kinematics'],
370-
frameon=False, fontsize=FONT_SIZE_SMALL, bbox_to_anchor=(0.05, 0.88), loc='lower left')
386+
'Partial-Body Kinematics with Inpainting Filling (GaitDynamics)', 'Partial-Body Kinematics with Median Filling', 'Full-Body Kinematics (GaitDynamics)'],
387+
frameon=False, fontsize=FONT_SIZE_SMALL, bbox_to_anchor=(0., 0.88), loc='lower left')
371388
plt.savefig(f'exports/da_segment_filling.png', dpi=300, bbox_inches='tight')
372389
plt.show()
373390

figures/da_guided_ts_uhlrich_0.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,12 @@ def loop_all(opt):
138138
sub_ts_true_original_rate[test_name].append(true_original_rate)
139139
sub_ts_pred_original_rate[test_name].append(pred_original_rate)
140140

141-
height_m_all[test_name] = trial_of_this_win.height_m
142-
weight_kg_all[test_name] = trial_of_this_win.weight_kg
141+
if test_name not in height_m_all.keys():
142+
height_m_all[test_name] = []
143+
weight_kg_all[test_name] = []
144+
if ('walking' in trial_of_this_win.sub_and_trial_name) and ('ts' not in trial_of_this_win.sub_and_trial_name.lower()):
145+
height_m_all[test_name].append(trial_of_this_win.height_m)
146+
weight_kg_all[test_name].append(trial_of_this_win.weight_kg)
143147

144148
# if x_times_lumbar_bending > 1:
145149
# name_states_dict = {'true': true_val, 'pred': state_pred.detach().numpy()}
@@ -156,8 +160,8 @@ def loop_all(opt):
156160

157161
if __name__ == "__main__":
158162
opt = parse_opt()
159-
opt.n_guided_steps = 3
160-
opt.guidance_lr = 0.01
163+
opt.n_guided_steps = 5
164+
opt.guidance_lr = 0.02
161165
opt.guide_x_start_the_beginning_step = 1000
162166
opt.guide_x_start_the_end_step = 0
163167
opt.checkpoint = os.path.dirname(os.path.realpath(__file__)) + f"/../trained_models/train-{'2560_diffusion'}.pt"

0 commit comments

Comments
 (0)