
Commit 1403b5f

redo medley_db, ikala, and guitarset create_input_data functions to have more stable dataset division.
1 parent 7d78385 commit 1403b5f

4 files changed: 21 additions, 28 deletions

basic_pitch/data/datasets/guitarset.py

Lines changed: 8 additions & 6 deletions
@@ -136,17 +136,19 @@ def create_input_data(
     if seed:
         random.seed(seed)

-    def determine_split() -> str:
-        partition = random.uniform(0, 1)
-        if partition < validation_bound:
+    def determine_split(index: int) -> str:
+        if index < len(track_ids) * validation_bound:
             return "train"
-        if partition < test_bound:
+        elif index < len(track_ids) * test_bound:
             return "validation"
-        return "test"
+        else:
+            return "test"

     guitarset = mirdata.initialize("guitarset")
+    track_ids = guitarset.track_ids
+    random.shuffle(track_ids)

-    return [(track_id, determine_split()) for track_id in guitarset.track_ids]
+    return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]


 def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
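The new guitarset split shuffles the track list once and then assigns each track to a partition by its position against fixed index boundaries, so the train/validation/test proportions are exact for a given list length instead of fluctuating with independent per-track random.uniform draws as before. A minimal standalone sketch of the same idea (the helper name, the seed default, and the 80/10/10 boundary values are illustrative, not from the repository):

    import random
    from typing import List, Tuple

    def split_by_index(
        track_ids: List[str],
        validation_bound: float,  # fraction of tracks labelled "train" (naming follows the diff)
        test_bound: float,        # cumulative fraction labelled "train" + "validation"
        seed: int = 42,           # illustrative default
    ) -> List[Tuple[str, str]]:
        # Shuffle a copy once, then cut it at fixed index boundaries.
        ids = list(track_ids)
        random.seed(seed)
        random.shuffle(ids)
        n = len(ids)
        pairs = []
        for i, track_id in enumerate(ids):
            if i < n * validation_bound:
                split = "train"
            elif i < n * test_bound:
                split = "validation"
            else:
                split = "test"
            pairs.append((track_id, split))
        return pairs

    # e.g. an 80/10/10 split over whatever IDs the caller supplies:
    # pairs = split_by_index(guitarset.track_ids, validation_bound=0.8, test_bound=0.9)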

basic_pitch/data/datasets/ikala.py

Lines changed: 6 additions & 10 deletions
@@ -138,21 +138,17 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str
 def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
     assert train_percent < 1.0, "Don't over allocate the data!"

-    # Test percent is 1 - train - validation
-    validation_bound = train_percent
-
     if seed:
         random.seed(seed)

-    def determine_split() -> str:
-        partition = random.uniform(0, 1)
-        if partition < validation_bound:
-            return "train"
-        return "validation"
-
     ikala = mirdata.initialize("ikala")
+    track_ids = ikala.track_ids
+    random.shuffle(track_ids)
+
+    def determine_split(index: int) -> str:
+        return "train" if index < len(track_ids) * train_percent else "validation"

-    return [(track_id, determine_split()) for track_id in ikala.track_ids]
+    return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]


 def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
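For ikala the same idea reduces to a single boundary: shuffle once, then every index below len(track_ids) * train_percent is "train" and the rest is "validation". The practical gain (the "more stable dataset division" of the commit message) is that group sizes are fixed by the list length alone rather than by how the per-track uniform draws happen to fall. A quick self-contained check of that property (the two_way_split name and the 100 hypothetical track IDs are illustrative):

    import random
    from typing import List, Optional, Tuple

    def two_way_split(track_ids: List[str], train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
        # Mirror of the new index-based train/validation assignment.
        ids = list(track_ids)
        if seed is not None:
            random.seed(seed)
        random.shuffle(ids)
        return [
            (tid, "train" if i < len(ids) * train_percent else "validation")
            for i, tid in enumerate(ids)
        ]

    ids = ["track_%03d" % i for i in range(100)]  # hypothetical IDs
    for seed in (0, 1, 2):
        pairs = two_way_split(ids, train_percent=0.5, seed=seed)
        # Membership changes with the seed, but the group sizes never do.
        assert sum(1 for _, split in pairs if split == "train") == 50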

basic_pitch/data/datasets/medleydb_pitch.py

Lines changed: 6 additions & 11 deletions
@@ -136,22 +136,17 @@ def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str
 def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
     assert train_percent < 1.0, "Don't over allocate the data!"

-    # Test percent is 1 - train - validation
-    validation_bound = train_percent
-
     if seed:
         random.seed(seed)

-    def determine_split() -> str:
-        partition = random.uniform(0, 1)
-        if partition < validation_bound:
-            return "train"
-        return "validation"
-
     medleydb_pitch = mirdata.initialize("medleydb_pitch")
-    medleydb_pitch.download()
+    track_ids = medleydb_pitch.track_ids
+    random.shuffle(track_ids)
+
+    def determine_split(index: int) -> str:
+        return "train" if index < len(track_ids) * train_percent else "validation"

-    return [(track_id, determine_split()) for track_id in medleydb_pitch.track_ids]
+    return [(track_id, determine_split(i)) for i, track_id in enumerate(track_ids)]


 def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
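The medleydb_pitch change is the same two-way rewrite, and it additionally drops the medleydb_pitch.download() call, so building the split list no longer triggers a download. Because the shuffle is seeded, identical seeds should yield identical assignments across runs; a sketch of how that could be checked, assuming basic_pitch is installed and mirdata can resolve the medleydb_pitch track index locally:

    # Reproducibility sketch (assumptions noted above; not part of the commit).
    from basic_pitch.data.datasets.medleydb_pitch import create_input_data

    first = create_input_data(train_percent=0.5, seed=42)
    second = create_input_data(train_percent=0.5, seed=42)
    assert first == second  # same seed, same shuffle, same split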

tests/data/test_medleydb_pitch.py

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ def test_medleydb_pitch_invalid_tracks(tmpdir: str) -> None:
 def test_medleydb_create_input_data() -> None:
     data = create_input_data(train_percent=0.5)
     data.sort(key=lambda el: el[1])  # sort by split
-    tolerance = 0.05
+    tolerance = 0.01
     for _, group in itertools.groupby(data, lambda el: el[1]):
         assert (0.5 - tolerance) * len(data) <= len(list(group)) <= (0.5 + tolerance) * len(data)
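Tightening the tolerance from 0.05 to 0.01 is possible because the index-based split makes the two groups differ by at most one track, so a 1% margin holds for any dataset of roughly a hundred tracks or more. The arithmetic, on an illustrative (made-up) track count:

    # Illustrative check of the tightened 1% tolerance; n is not the real dataset size.
    n = 103
    train = sum(1 for i in range(n) if i < n * 0.5)  # index-based "train" group
    validation = n - train
    tolerance = 0.01
    assert (0.5 - tolerance) * n <= train <= (0.5 + tolerance) * n
    assert (0.5 - tolerance) * n <= validation <= (0.5 + tolerance) * n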
