Skip to content

Commit f9adb8d

Browse files
committed
fix(dgs_corpus): load custom splits docs
1 parent 11c5027 commit f9adb8d

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

sign_language_datasets/datasets/dgs_corpus/dgs_corpus.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import tensorflow_datasets as tfds
1111

1212
from os import path
13-
from typing import Dict, Any, Set, Optional
13+
from typing import Dict, Any, Set, Optional, List
1414
from pose_format.utils.openpose import load_openpose, OpenPoseFrames
1515
from pose_format.pose import Pose
1616

@@ -102,24 +102,28 @@ def get_openpose(openpose_path: str, fps: int, people: Optional[Set] = None,
102102
return poses
103103

104104

105-
def load_split(split_name: str) -> Dict[str, str]:
105+
def load_split(split_name: str) -> Dict[str, List[str]]:
106106
"""
107+
Loads a split from the file system. What is loaded must be a JSON object with the following structure:
107108
108-
:param split_name:
109-
:return:
109+
{"train": ..., "dev": ..., "test": ...}
110+
111+
:param split_name: An identifier for a predefined split or a filepath to a custom split file.
112+
:return: The split loaded as a dictionary.
110113
"""
111114
if split_name not in _KNOWN_SPLITS.keys():
115+
# assume that the supplied string is a path on the file system
112116
if not path.exists(split_name):
113-
raise ValueError("Split '%s' is not a known data split identifier and does not exist as a file either." % split_name)
117+
raise ValueError("Split '%s' is not a known data split identifier and does not exist as a file either.\n"
118+
"Known split identifiers are: %s" % (split_name, str(_KNOWN_SPLITS)))
114119

115-
# assume that the supplied string is a path on the file system
116120
split_path = split_name
117121
else:
118122
# the supplied string is an identifier for a predefined split
119123
split_path = _KNOWN_SPLITS[split_name]
120124

121125
with open(split_path) as infile:
122-
split = json.load(infile) # type: Dict[str, str]
126+
split = json.load(infile) # type: Dict[str, List[str]]
123127

124128
return split
125129

0 commit comments

Comments
 (0)