Skip to content

Commit ba231c6

Browse files
committed
handles other pose models better
1 parent 5d11e96 commit ba231c6

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

Sports2D/Sports2D.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,9 @@ def base_params(config_dict):
138138
if video_dir == '': video_dir = os.getcwd()
139139
video_files = config_dict.get('project').get('video_files')
140140
if isinstance(video_files, str):
141-
video_files = [Path(video_files)]
141+
video_files = [Path(os.getcwd()) / video_files]
142142
else:
143-
video_files = [Path(v) for v in video_files]
143+
video_files = [Path(os.getcwd()) / v for v in video_files]
144144
result_dir = Path(config_dict.get('project').get('result_dir')).resolve()
145145
if result_dir == '': result_dir = os.getcwd()
146146

Sports2D/compute_angles.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def compute_angles_fun(config_dict, video_file):
402402

403403
# Find csv position files in video_dir, search pose_model and video_file.stem
404404
logging.info(f'Retrieving csv position files in {result_dir}...')
405-
csv_paths = list(result_dir.glob(f'*{video_file.stem}_{pose_model}_*points*.csv'))
405+
csv_paths = list(result_dir.glob(f'*{video_file.stem}_*points*.csv'))
406406
logging.info(f'{len(csv_paths)} persons found.')
407407

408408
# Compute angles
@@ -476,9 +476,9 @@ def compute_angles_fun(config_dict, video_file):
476476
# Add angles to vid and img
477477
if show_angles_img or show_angles_vid:
478478
video_base = Path(video_dir / video_file)
479-
img_pose = result_dir / (video_base.stem + '_' + pose_model + '_img')
480-
video_pose = result_dir / (video_base.stem + '_' + pose_model + '.mp4')
481-
video_pose2 = result_dir / (video_base.stem + '_' + pose_model + '2.mp4')
479+
img_pose = result_dir / (video_base.stem + '_img')
480+
video_pose = result_dir / (video_base.stem + '.mp4')
481+
video_pose2 = result_dir / (video_base.stem + '2.mp4')
482482

483483
if show_angles_vid:
484484
logging.info(f'Saving video in {str(video_pose)}.')
@@ -514,7 +514,7 @@ def compute_angles_fun(config_dict, video_file):
514514
frame = overlay_angles(frame, df_angles_list_frame)
515515
if show_angles_img:
516516
if frame_nb==0: img_pose.mkdir(parents=True, exist_ok=True)
517-
cv2.imwrite(str(img_pose / (video_base.stem + '_' + pose_model + '.' + str(frame_nb).zfill(5)+'.png')), frame)
517+
cv2.imwrite(str(img_pose / (video_base.stem + '_' + '.' + str(frame_nb).zfill(5)+'.png')), frame)
518518
if show_angles_vid:
519519
writer.write(frame)
520520
frame_nb+=1

Sports2D/detect_pose.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,6 @@ def json_to_csv(json_path, frame_rate, pose_model, interp_gap_smaller_than, filt
286286
model = eval(pose_model)
287287
keypoints_ids = [node.id for _, _, node in RenderTree(model) if node.id!=None]
288288
keypoints_names = [node.name for _, _, node in RenderTree(model) if node.id!=None]
289-
keypoints_names_rearranged = [y for x,y in sorted(zip(keypoints_ids,keypoints_names))]
290289
keypoints_nb = len(keypoints_ids)
291290

292291
# Retrieve coordinates
@@ -300,7 +299,8 @@ def json_to_csv(json_path, frame_rate, pose_model, interp_gap_smaller_than, filt
300299
keypt = []
301300
# Retrieve coords for this frame
302301
for ppl in range(len(json_file['people'])): # for each detected person
303-
keypt += [np.asarray(json_file['people'][ppl]['pose_keypoints_2d']).reshape(-1,3)]
302+
keypt_all = np.asarray(json_file['people'][ppl]['pose_keypoints_2d']).reshape(-1,3)[keypoints_ids]
303+
keypt += [keypt_all]
304304
keypt = np.array(keypt)
305305
# Make sure keypt is as large as the number of persons that need to be detected
306306
if len(keypt) < nb_persons_to_detect:
@@ -321,7 +321,7 @@ def json_to_csv(json_path, frame_rate, pose_model, interp_gap_smaller_than, filt
321321
# Prepare csv header
322322
scorer = ['DavidPagnon']*(keypoints_nb*3+1)
323323
individuals = [f'person{i}']*(keypoints_nb*3+1)
324-
bodyparts = [[p]*3 for p in keypoints_names_rearranged]
324+
bodyparts = [[p]*3 for p in keypoints_names]
325325
bodyparts = ['Time']+[item for sublist in bodyparts for item in sublist]
326326
coords = ['seconds']+['x', 'y', 'likelihood']*keypoints_nb
327327
tuples = list(zip(scorer, individuals, bodyparts, coords))
@@ -604,7 +604,7 @@ def detect_pose_fun(config_dict, video_file):
604604

605605
if pose_algo == 'OPENPOSE':
606606
pose_model = config_dict.get('pose').get('OPENPOSE').get('openpose_model')
607-
json_path = result_dir / '_'.join((video_file_stem,pose_model,'json'))
607+
json_path = result_dir / '_'.join((video_file_stem,'json'))
608608

609609
# Pose detection skipped if load existing json files
610610
if load_pose and len(list(json_path.glob('*')))>0:
@@ -633,8 +633,8 @@ def detect_pose_fun(config_dict, video_file):
633633
elif platform == "linux" or platform=="linux2":
634634
run_openpose_linux(video_path, json_path, pose_model)
635635
os.chdir(root_dir)
636-
637-
636+
637+
638638
elif pose_algo == 'BLAZEPOSE':
639639
pose_model = pose_algo
640640
json_path = result_dir / '_'.join((video_file_stem,pose_algo,'json'))

0 commit comments

Comments
 (0)