Skip to content

Commit 1891628

Browse files
authored
Merge pull request #5 from cleong110/feature/save_cropped_poses
Add option to save off .pose files for each segment
2 parents 46017c9 + f4110c1 commit 1891628

File tree

1 file changed

+48
-17
lines changed
  • sign_language_segmentation

1 file changed

+48
-17
lines changed

sign_language_segmentation/bin.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python
2+
from pathlib import Path
23
import argparse
34
import os
45

@@ -11,7 +12,7 @@
1112
from sign_language_segmentation.src.utils.probs_to_segments import probs_to_segments
1213

1314

14-
def add_optical_flow(pose: Pose):
15+
def add_optical_flow(pose: Pose)->None:
1516
from pose_format.numpy.representation.distance import DistanceRepresentation
1617
from pose_format.utils.optical_flow import OpticalFlowCalculator
1718

@@ -25,7 +26,7 @@ def add_optical_flow(pose: Pose):
2526
pose.body.data = np.concatenate([pose.body.data, flow], axis=-1).astype(np.float32)
2627

2728

28-
def process_pose(pose: Pose, optical_flow=False, hand_normalization=False):
29+
def process_pose(pose: Pose, optical_flow=False, hand_normalization=False) -> Pose:
2930
pose = pose.get_components(["POSE_LANDMARKS", "LEFT_HAND_LANDMARKS", "RIGHT_HAND_LANDMARKS"])
3031

3132
normalization_info = pose_normalization_info(pose.header)
@@ -57,40 +58,59 @@ def predict(model, pose: Pose):
5758
return model(pose_data)
5859

5960

61+
def save_pose_segments(tiers: dict, tier_id: str, input_file_path: Path) -> None:
    """Write one cropped .pose file for every segment of the selected tier.

    The input file is read from disk again instead of reusing an already
    processed pose, so the cropped files are cut from the data as stored in
    the input file.

    Args:
        tiers: mapping from tier name to a sequence of segments. Each segment
            is indexed with "start" and "end"; both values are converted to
            int and used as frame indices.
        tier_id: key in ``tiers`` naming the tier whose segments are saved.
        input_file_path: path of the source .pose file. Output files are
            written to the same directory, named
            ``<input stem>_<tier_id>_<segment index>.pose``.
    """
    original_pose = Pose.read(input_file_path.read_bytes())
    output_dir = input_file_path.parent
    base_name = input_file_path.stem

    for index, segment in enumerate(tiers[tier_id]):
        out_path = output_dir / f"{base_name}_{tier_id}_{index}.pose"
        start_frame, end_frame = int(segment["start"]), int(segment["end"])
        # Every cropped pose shares the header of the source pose; only the
        # body is sliced to the segment's frame range.
        cropped_pose = Pose(header=original_pose.header, body=original_pose.body[start_frame:end_frame])

        print(f"Saving cropped pose with start {start_frame} and end {end_frame} to {out_path}")
        with out_path.open("wb") as out_file:
            cropped_pose.write(out_file)
75+
76+
6077
def get_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the segmentation script.

    Args:
        argv: optional list of argument strings to parse. When left as
            ``None`` argparse falls back to ``sys.argv[1:]``, which is the
            behaviour of calling ``get_args()`` with no arguments.

    Returns:
        The parsed namespace with the attributes ``pose`` (a ``Path``),
        ``elan``, ``save_segments`` (``"SENTENCE"``, ``"SIGN"`` or ``None``),
        ``video``, ``subtitles``, ``model`` and ``no_pose_link``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pose", required=True, type=Path, help="path to input pose file")
    parser.add_argument("--elan", required=True, type=str, help="path to output elan file")
    # The option value selects a tier rather than acting as an on/off flag,
    # so the help text names what the choice means.
    parser.add_argument(
        "--save-segments",
        type=str,
        choices=["SENTENCE", "SIGN"],
        help="tier whose segments are saved as cropped .pose files (omit to save none)",
    )
    parser.add_argument("--video", default=None, required=False, type=str, help="path to video file")
    parser.add_argument("--subtitles", default=None, required=False, type=str, help="path to subtitle file")
    parser.add_argument("--model", default="model_E1s-1.pth", required=False, type=str, help="path to model file")
    parser.add_argument("--no-pose-link", action="store_true", help="whether to link the pose file")

    return parser.parse_args(argv)
7090

7191

7292
def main():
7393
args = get_args()
7494

75-
print('Loading pose ...')
95+
print("Loading pose ...")
7696
with open(args.pose, "rb") as f:
7797
pose = Pose.read(f.read())
78-
if 'E4' in args.model:
98+
if "E4" in args.model:
7999
pose = process_pose(pose, optical_flow=True, hand_normalization=True)
80100
else:
81101
pose = process_pose(pose)
82102

83-
print('Loading model ...')
103+
print("Loading model ...")
84104
install_dir = str(os.path.dirname(os.path.abspath(__file__)))
85105
model = load_model(os.path.join(install_dir, "dist", args.model))
86106

87-
print('Estimating segments ...')
107+
print("Estimating segments ...")
88108
probs = predict(model, pose)
89109

90110
sign_segments = probs_to_segments(probs["sign"], 60, 50)
91111
sentence_segments = probs_to_segments(probs["sentence"], 90, 90)
92112

93-
print('Building ELAN file ...')
113+
print("Building ELAN file ...")
94114
tiers = {
95115
"SIGN": sign_segments,
96116
"SENTENCE": sentence_segments,
@@ -111,20 +131,31 @@ def main():
111131
for tier_id, segments in tiers.items():
112132
eaf.add_tier(tier_id)
113133
for segment in segments:
114-
eaf.add_annotation(tier_id, int(segment["start"] / fps * 1000), int(segment["end"] / fps * 1000))
134+
# convert frame numbers to millisecond timestamps, for Elan
135+
start_time_ms = int(segment["start"] / fps * 1000)
136+
end_time_ms = int(segment["end"] / fps * 1000)
137+
eaf.add_annotation(tier_id, start_time_ms, end_time_ms)
138+
139+
if args.save_segments:
140+
print(f"Saving {args.save_segments} cropped .pose files")
141+
save_pose_segments(tiers, tier_id=args.save_segments, input_file_path=args.pose)
115142

116143
if args.subtitles and os.path.exists(args.subtitles):
117144
import srt
145+
118146
eaf.add_tier("SUBTITLE")
119-
with open(args.subtitles, "r") as infile:
147+
# open with an explicit encoding ("utf-8-sig" also skips a leading byte-order mark),
148+
# as directed in https://github.com/cdown/srt/blob/master/srt_tools/utils.py#L155-L160
149+
# see also https://github.com/cdown/srt/issues/67, https://github.com/cdown/srt/issues/36
150+
with open(args.subtitles, "r", encoding="utf-8-sig") as infile:
120151
for subtitle in srt.parse(infile):
121152
start = subtitle.start.total_seconds()
122153
end = subtitle.end.total_seconds()
123154
eaf.add_annotation("SUBTITLE", int(start * 1000), int(end * 1000), subtitle.content)
124155

125-
print('Saving to disk ...')
156+
print("Saving .eaf to disk ...")
126157
eaf.to_file(args.elan)
127158

128159

129-
if __name__ == '__main__':
160+
if __name__ == "__main__":
130161
main()

0 commit comments

Comments
 (0)