Skip to content

Commit 91d220b

Browse files
authored
Merge pull request #128 from spotify/bgenchel/add-ikala.py
Added Ikala
2 parents 8af1ad0 + 2abf927 commit 91d220b

File tree

3 files changed

+263
-3
lines changed

3 files changed

+263
-3
lines changed

basic_pitch/data/datasets/ikala.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
#!/usr/bin/env python
2+
# encoding: utf-8
3+
#
4+
# Copyright 2024 Spotify AB
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import argparse
19+
import logging
20+
import os
21+
import random
22+
import sys
23+
import time
24+
from typing import Any, Dict, List, Tuple, Optional
25+
26+
import apache_beam as beam
27+
import mirdata
28+
29+
from basic_pitch.data import commandline, pipeline
30+
31+
32+
class IkalaInvalidTracks(beam.DoFn):
    """Beam DoFn that routes each track id to the tagged output named after its split."""

    def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
        """Emit the track id of ``element`` on the output tagged with its split.

        Args:
            element: A ``(track_id, split)`` pair.

        Yields:
            A ``TaggedOutput`` whose tag is the split name and whose value is
            the track id.
        """
        identifier, split_name = element
        yield beam.pvalue.TaggedOutput(split_name, identifier)
36+
37+
38+
class IkalaToTfExample(beam.DoFn):
    """Beam DoFn that converts batches of iKala track ids into transcription examples.

    For every track id, the files named by ``DOWNLOAD_ATTRIBUTES`` are copied
    from the source dataset into a temporary local mirdata layout, the audio is
    converted with sox, and the note / onset / contour annotations are turned
    into sparse indices and passed to
    ``tf_example_serialization.to_transcription_tfexample``.
    """

    # mirdata track attributes that hold the paths of the files copied locally.
    DOWNLOAD_ATTRIBUTES = ["audio_path", "notes_pyin_path", "f0_path"]

    def __init__(self, source: str, download: bool) -> None:
        """Store the configuration; the dataset handles are created in ``setup``.

        Args:
            source: Root location under which the ``iKala`` dataset directory lives.
            download: Whether ``setup`` should ask mirdata to download the dataset.
        """
        self.source = source
        self.download = download

    def setup(self) -> None:
        """Create the mirdata dataset handle and the Beam filesystem accessor."""
        # Imported locally so the names are available wherever this DoFn runs.
        import apache_beam as beam
        import os
        import mirdata

        self.ikala_remote = mirdata.initialize("ikala", data_home=os.path.join(self.source, "iKala"))
        self.filesystem = beam.io.filesystems.FileSystems()  # TODO: replace with fsspec
        if self.download:
            self.ikala_remote.download()

    def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
        """Convert one batch of track ids into serialized examples.

        Args:
            element: Track ids to process.

        Returns:
            A single-element list wrapping the list of examples built for
            ``element`` (one example per track id, in input order).
        """
        import shutil
        import tempfile

        import numpy as np
        import sox

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
            FREQ_BINS_CONTOURS,
            FREQ_BINS_NOTES,
            ANNOTATION_HOP,
            N_FREQ_BINS_CONTOURS,
            N_FREQ_BINS_NOTES,
        )
        from basic_pitch.data import tf_example_serialization

        logging.info(f"Processing {element}")
        batch = []

        for track_id in element:
            track_remote = self.ikala_remote.track(track_id)
            # All work on the local copies happens inside this block: the
            # directory (and the converted wav) is removed when it exits.
            with tempfile.TemporaryDirectory() as local_tmp_dir:
                ikala_local = mirdata.initialize("ikala", local_tmp_dir)
                track_local = ikala_local.track(track_id)

                for attr in self.DOWNLOAD_ATTRIBUTES:
                    source = getattr(track_remote, attr)
                    dest = getattr(track_local, attr)
                    os.makedirs(os.path.dirname(dest), exist_ok=True)
                    with self.filesystem.open(source) as s, open(dest, "wb") as d:
                        # Stream in chunks instead of holding the whole file in memory.
                        shutil.copyfileobj(s, d)

                local_wav_path = "{}_tmp.wav".format(track_local.audio_path)

                tfm = sox.Transformer()
                tfm.rate(AUDIO_SAMPLE_RATE)
                # Output channel 1 is taken from input channel 2.
                # NOTE(review): presumably the channel carrying the vocals - confirm.
                tfm.remix({1: [2]})
                tfm.channels(AUDIO_N_CHANNELS)
                tfm.build(track_local.audio_path, local_wav_path)

                # Frame times from 0 up to the audio duration, spaced ANNOTATION_HOP apart.
                duration = sox.file_info.duration(local_wav_path)
                time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
                n_time_frames = len(time_scale)

                if track_local.notes_pyin is not None:
                    note_indices, note_values = track_local.notes_pyin.to_sparse_index(
                        time_scale, "s", FREQ_BINS_NOTES, "hz"
                    )
                    onset_indices, onset_values = track_local.notes_pyin.to_sparse_index(
                        time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
                    )
                    note_shape = (n_time_frames, N_FREQ_BINS_NOTES)
                # if there are no notes, use empty note/onset indices and values
                else:
                    note_indices = []
                    onset_indices = []
                    note_values = []
                    onset_values = []
                    note_shape = (0, 0)

                contour_indices, contour_values = track_local.f0.to_sparse_index(
                    time_scale, "s", FREQ_BINS_CONTOURS, "hz"
                )

                batch.append(
                    tf_example_serialization.to_transcription_tfexample(
                        track_id,
                        "ikala",
                        local_wav_path,
                        note_indices,
                        note_values,
                        onset_indices,
                        onset_values,
                        contour_indices,
                        contour_values,
                        note_shape,
                        (n_time_frames, N_FREQ_BINS_CONTOURS),
                    )
                )
        return [batch]
136+
137+
138+
def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
    """Randomly assign every iKala track id to the "train" or "validation" split.

    Args:
        train_percent: Probability that a track is labelled "train"; must be
            strictly below 1.0. Every other track is labelled "validation".
        seed: Optional seed for the module-level ``random`` generator. Any
            integer, including 0, makes the assignment reproducible; ``None``
            leaves the generator state untouched.

    Returns:
        One ``(track_id, split)`` tuple per id in mirdata's ikala ``track_ids``,
        in that order.

    Raises:
        AssertionError: If ``train_percent`` is not below 1.0.
    """
    # Raised explicitly rather than with `assert` so the check is not stripped
    # under `python -O`; the exception type callers see is unchanged.
    if not train_percent < 1.0:
        raise AssertionError("Don't over allocate the data!")

    # Only two splits exist: draws below this bound are "train", the rest "validation".
    validation_bound = train_percent

    # `is not None` rather than truthiness, so that a seed of 0 is honoured.
    if seed is not None:
        random.seed(seed)

    def determine_split() -> str:
        partition = random.uniform(0, 1)
        if partition < validation_bound:
            return "train"
        return "validation"

    ikala = mirdata.initialize("ikala")

    return [(track_id, determine_split()) for track_id in ikala.track_ids]
156+
157+
158+
def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
    """Build the iKala split assignment and run the conversion pipeline.

    Args:
        known_args: Parsed command-line options; this function reads
            ``runner``, ``sdk_container_image``, ``job_endpoint``,
            ``train_percent``, ``split_seed``, ``source`` and ``batch_size``.
        pipeline_args: Remaining command-line arguments, passed through to
            ``pipeline.run``.
    """
    # A single timestamp is shared by the destination and the job name.
    created_at = int(time.time())
    destination = commandline.resolve_destination(known_args, created_at)

    pipeline_options = dict(
        runner=known_args.runner,
        job_name=f"ikala-tfrecords-{created_at}",
        machine_type="e2-standard-4",
        num_workers=25,
        disk_size_gb=128,
        experiments=["use_runner_v2", "no_use_multiple_sdk_containers"],
        save_main_session=True,
        sdk_container_image=known_args.sdk_container_image,
        job_endpoint=known_args.job_endpoint,
        environment_type="DOCKER",
        environment_config=known_args.sdk_container_image,
    )

    input_data = create_input_data(known_args.train_percent, known_args.split_seed)
    to_tf_example = IkalaToTfExample(known_args.source, download=True)
    split_tagger = IkalaInvalidTracks(known_args.source)

    pipeline.run(
        pipeline_options,
        pipeline_args,
        input_data,
        to_tf_example,
        split_tagger,
        destination,
        known_args.batch_size,
    )
185+
186+
187+
if __name__ == "__main__":
    # The dataset name handed to the CLI defaults is this file's base name
    # without its extension.
    dataset_name = os.path.basename(os.path.splitext(__file__)[0])

    parser = argparse.ArgumentParser()
    commandline.add_default(parser, dataset_name)
    commandline.add_split(parser)

    # NOTE(review): sys.argv is passed whole, so argv[0] (the script path) is
    # parsed too; presumably it ends up in pipeline_args - confirm this is intended.
    known_args, pipeline_args = parser.parse_known_args(sys.argv)

    main(known_args, pipeline_args)

basic_pitch/data/download.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@
1919

2020
from basic_pitch.data import commandline
2121
from basic_pitch.data.datasets.guitarset import main as guitarset_main
22+
from basic_pitch.data.datasets.ikala import main as ikala_main
2223

2324
logger = logging.getLogger()
2425
logger.setLevel(logging.INFO)
2526

26-
DATASET_DICT = {
27-
"guitarset": guitarset_main,
28-
}
27+
DATASET_DICT = {"guitarset": guitarset_main, "ikala": ikala_main}
2928

3029

3130
def main() -> None:

tests/data/test_ikala.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#!/usr/bin/env python
2+
# encoding: utf-8
3+
#
4+
# Copyright 2024 Spotify AB
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
import apache_beam as beam
18+
import itertools
19+
import os
20+
21+
from apache_beam.testing.test_pipeline import TestPipeline
22+
23+
from basic_pitch.data.datasets.ikala import (
24+
IkalaInvalidTracks,
25+
create_input_data,
26+
)
27+
28+
29+
# TODO: Create test_ikala_to_tf_example
30+
31+
32+
def test_ikala_invalid_tracks(tmpdir: str) -> None:
    """IkalaInvalidTracks sends each track id to the output tagged with its split."""
    split_labels = ["train", "validation"]
    # Each track id is the position of its split label, rendered as a string.
    input_data = [(str(index), label) for index, label in enumerate(split_labels)]
    output_paths = {label: os.path.join(tmpdir, f"output_{label}.txt") for label in split_labels}

    with TestPipeline() as p:
        tagged = (
            p
            | "Create PCollection" >> beam.Create(input_data)
            | "Tag it" >> beam.ParDo(IkalaInvalidTracks()).with_outputs(*split_labels)
        )

        for label in split_labels:
            # An empty shard template writes exactly the path given, with no shard suffix.
            writer = beam.io.WriteToText(output_paths[label], shard_name_template="")
            _ = getattr(tagged, label) | f"Write {label} to text" >> writer

    # Each split's file must hold only the id that was tagged with that split.
    for index, label in enumerate(split_labels):
        with open(output_paths[label], "r") as fp:
            assert fp.read().strip() == str(index)
52+
53+
54+
def test_create_input_data() -> None:
    """With train_percent=0.5, every split holds 50% of the tracks, give or take 5%."""
    # NOTE(review): no seed is passed, so the split is random on every run and
    # this check is probabilistic - consider passing a fixed seed.
    data = create_input_data(train_percent=0.5)
    tolerance = 0.05
    total = len(data)
    lower_bound = (0.5 - tolerance) * total
    upper_bound = (0.5 + tolerance) * total

    # groupby only merges adjacent items, so order the pairs by split first.
    ordered = sorted(data, key=lambda el: el[1])
    for _, members in itertools.groupby(ordered, lambda el: el[1]):
        assert lower_bound <= sum(1 for _ in members) <= upper_bound
60+
61+
62+
def test_create_input_data_overallocate() -> None:
    """A train_percent above 1.0 must be rejected with an AssertionError."""
    rejected = False
    try:
        create_input_data(train_percent=1.1)
    except AssertionError:
        rejected = True
    assert rejected

0 commit comments

Comments
 (0)