
Commit 40aa4fc

Merge pull request #15 from bricksdont/dgs_split

Dgs split

2 parents e261edb + f9adb8d

File tree

6 files changed: +566 −4 lines changed

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -9,4 +9,5 @@ sign_language_datasets/datasets/ncslgr
 .coverage
 build/
 dist/
-sign_language_datasets.egg-info/
+sign_language_datasets.egg-info/
+.DS_Store

setup.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 setup(
     name="sign-language-datasets",
     packages=packages,
-    version="0.0.12",
+    version="0.0.13",
     description="TFDS Datasets for sign language",
     author="Amit Moryossef",
     author_email="amitmoryossef@gmail.com",

sign_language_datasets/datasets/config.py

Lines changed: 3 additions & 0 deletions

@@ -14,6 +14,7 @@ def __init__(
         include_pose: Optional[str] = None,
         fps: Optional[float] = None,
         resolution: Optional[Tuple[int, int]] = None,
+        split: Optional[str] = None,
         extra: dict = {},
         **kwargs,
     ):
@@ -24,6 +25,7 @@ def __init__(
           include_pose: str, what pose data to include.
           fps: float, what frame rate to use for videos and poses.
           resolution: (int, int), what resolution of videos to load.
+          split: str, an identifier of a predefined data split or a path to a custom split file (optional).
           **kwargs: keyword arguments forwarded to super.
         """
         super(SignDatasetConfig, self).__init__(**kwargs)
@@ -33,6 +35,7 @@ def __init__(
 
         self.fps = fps
         self.resolution = resolution
+        self.split = split
         self.extra = extra
 
     def ffmpeg_args(self):
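
For orientation, a minimal sketch of the new argument in use (not part of this diff; the config name here is arbitrary, and "3.0.0-uzh-document" is the predefined identifier registered in dgs_corpus.py below):

    from sign_language_datasets.datasets.config import SignDatasetConfig

    # the `split` value is simply stored on the config; dataset builders
    # such as dgs_corpus decide how to interpret it
    config = SignDatasetConfig(name="annotations-split", version="1.0.0",
                               include_video=False, split="3.0.0-uzh-document")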

sign_language_datasets/datasets/dgs_corpus/dgs_corpus.py

Lines changed: 44 additions & 2 deletions

@@ -10,7 +10,7 @@
 import tensorflow_datasets as tfds
 
 from os import path
-from typing import Dict, Any, Set, Optional
+from typing import Dict, Any, Set, Optional, List
 from pose_format.utils.openpose import load_openpose, OpenPoseFrames
 from pose_format.pose import Pose
 
@@ -44,6 +44,10 @@
     "openpose": path.join(path.dirname(path.realpath(__file__)), "openpose.poseheader"),
 }
 
+_KNOWN_SPLITS = {
+    "3.0.0-uzh-document": path.join(path.dirname(path.realpath(__file__)), "splits", "split.3.0.0-uzh-document.json"),
+}
+
 
 def convert_dgs_dict_to_openpose_frames(input_dict: Dict[str, Any]) -> OpenPoseFrames:
     """
@@ -98,6 +102,32 @@ def get_openpose(openpose_path: str, fps: int, people: Optional[Set] = None,
     return poses
 
 
+def load_split(split_name: str) -> Dict[str, List[str]]:
+    """
+    Loads a split from the file system. What is loaded must be a JSON object with the following structure:
+
+    {"train": ..., "dev": ..., "test": ...}
+
+    :param split_name: An identifier for a predefined split or a filepath to a custom split file.
+    :return: The split loaded as a dictionary.
+    """
+    if split_name not in _KNOWN_SPLITS.keys():
+        # assume that the supplied string is a path on the file system
+        if not path.exists(split_name):
+            raise ValueError("Split '%s' is not a known data split identifier and does not exist as a file either.\n"
+                             "Known split identifiers are: %s" % (split_name, str(_KNOWN_SPLITS)))
+
+        split_path = split_name
+    else:
+        # the supplied string is an identifier for a predefined split
+        split_path = _KNOWN_SPLITS[split_name]
+
+    with open(split_path) as infile:
+        split = json.load(infile)  # type: Dict[str, List[str]]
+
+    return split
+
+
 class DgsCorpus(tfds.core.GeneratorBasedBuilder):
     """DatasetBuilder for dgs_corpus dataset."""
 
@@ -193,7 +223,19 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager):
             _id: {k: local_paths[v] if v is not None else None for k, v in datum.items()} for _id, datum in index_data.items()
         }
 
-        return [tfds.core.SplitGenerator(name=tfds.Split.TRAIN, gen_kwargs={"data": processed_data})]
+        if self._builder_config.split is not None:
+            split = load_split(self._builder_config.split)
+
+            train_data = {key: value for key, value in processed_data.items() if key in split["train"]}
+            dev_data = {key: value for key, value in processed_data.items() if key in split["dev"]}
+            test_data = {key: value for key, value in processed_data.items() if key in split["test"]}
+
+            return [tfds.core.SplitGenerator(name=tfds.Split.TRAIN, gen_kwargs={"data": train_data}),
+                    tfds.core.SplitGenerator(name=tfds.Split.VALIDATION, gen_kwargs={"data": dev_data}),
+                    tfds.core.SplitGenerator(name=tfds.Split.TEST, gen_kwargs={"data": test_data})]
+
+        else:
+            return [tfds.core.SplitGenerator(name=tfds.Split.TRAIN, gen_kwargs={"data": processed_data})]
 
     def _generate_examples(self, data):
         """ Yields examples. """
Lines changed: 103 additions & 0 deletions

@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+"""dgs_document_split.ipynb
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+    https://colab.research.google.com/drive/19pHmLuIEAKFn4BqVr7cwNRaVxQHWKNI7
+"""
+
+# ! pip install sign-language-datasets==0.0.12
+
+import json
+
+import numpy as np
+
+import tensorflow_datasets as tfds
+import sign_language_datasets.datasets
+from sign_language_datasets.datasets.config import SignDatasetConfig
+from sign_language_datasets.datasets.dgs_corpus.dgs_utils import get_elan_sentences
+
+from typing import Tuple
+
+np.random.seed(1)
+
+# Videos 1177918 and 1432043 have 25 fps, start and end frame won't match
+INCORRECT_FRAMERATE = ["1432043", "1177918"]
+
+
+def get_split_indexes(total_size: int, dev_size: int, test_size: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    train_indexes = np.arange(total_size, dtype=np.int32)
+
+    np.random.shuffle(train_indexes)
+
+    # draw dev and test indexes without replacement
+    dev_indexes = np.random.choice(train_indexes, size=(dev_size,), replace=False)
+
+    remaining_train_indexes = np.asarray([i for i in train_indexes if i not in dev_indexes])
+
+    test_indexes = np.random.choice(remaining_train_indexes, size=(test_size,), replace=False)
+
+    remaining_train_indexes = np.asarray([i for i in remaining_train_indexes if i not in test_indexes])
+
+    return remaining_train_indexes, dev_indexes, test_indexes
+
+
+config = SignDatasetConfig(name="only-annotations", version="1.0.0", include_video=False, include_pose=None)
+dgs_corpus = tfds.load('dgs_corpus', builder_kwargs=dict(config=config))
+
+
+def get_split(dev_size: int, test_size: int):
+    ids = np.array([datum["id"].numpy().decode("utf-8") for datum in dgs_corpus["train"]
+                    if datum["id"].numpy().decode("utf-8") not in INCORRECT_FRAMERATE])
+
+    train_indexes, dev_indexes, test_indexes = get_split_indexes(len(ids), dev_size=dev_size, test_size=test_size)
+
+    print("Number of entire files in each split:")
+    print(str({"train": len(train_indexes), "dev": len(dev_indexes), "test": len(test_indexes)}))
+
+    return {"dgs_corpus_version": "3.0.0",
+            "train": list(ids[train_indexes]),
+            "dev": list(ids[dev_indexes]),
+            "test": list(ids[test_indexes])}
+
+
+split = get_split(dev_size=10, test_size=10)
+
+with open('split.json', 'w') as outfile:
+    json.dump(split, outfile, indent=4)
+
+# ! cat split.json
+
+"""## compute sentence statistics for this split"""
+
+
+def get_split_name_from_id(_id: str) -> str:
+    for key in split.keys():
+        if _id in split[key]:
+            return key
+
+    return "none"
+
+
+sentences_found = {"train": 0, "dev": 0, "test": 0, "none": 0}
+
+for datum in dgs_corpus["train"]:
+    _id = datum["id"].numpy().decode('utf-8')
+    split_name = get_split_name_from_id(_id)
+
+    elan_path = datum["paths"]["eaf"].numpy().decode('utf-8')
+    sentences = get_elan_sentences(elan_path)
+
+    for sentence in sentences:
+        gloss_sequence = " ".join([s["gloss"] for s in sentence["glosses"]])
+        german_sentence = sentence["german"]
+
+        if gloss_sequence != "" and german_sentence != "":
+            sentences_found[split_name] += 1
+
+print(sentences_found)
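
The split.json written above has the structure load_split expects; roughly (the id values here are placeholders, and the extra dgs_corpus_version key is kept as metadata but not used by the split generators):

    {
        "dgs_corpus_version": "3.0.0",
        "train": ["<document id>", "..."],
        "dev": ["<document id>", "..."],
        "test": ["<document id>", "..."]
    }

Since load_split also accepts file paths, such a file can be passed back in as a custom split, e.g. SignDatasetConfig(..., split="split.json").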
