Skip to content

Commit 62634f9

Browse files
updated tolerance for ikala, test names for ikala and guitarset, added data and tests + uploaded download.py for maestro, added test data for maestro, updated Manifest for wav and midi files in test
1 parent 91d220b commit 62634f9

15 files changed

+1659
-11
lines changed

MANIFEST.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
include *.txt tox.ini *.rst *.md LICENSE
22
include catalog-info.yaml
33
include Dockerfile .dockerignore
4-
recursive-include tests *.py *.wav *.npz *.jams *.zip
4+
recursive-include tests *.py *.wav *.npz *.jams *.zip *.midi *.csv *.json
55
recursive-include basic_pitch *.py *.md
66
recursive-include basic_pitch/saved_models *.index *.pb variables.data* *.mlmodel *.json *.onnx *.tflite *.bin

basic_pitch/data/datasets/maestro.py

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
#!/usr/bin/env python
2+
# encoding: utf-8
3+
#
4+
# Copyright 2024 Spotify AB
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
18+
import argparse
19+
import logging
20+
import os
21+
import sys
22+
import tempfile
23+
import time
24+
from typing import Any, Dict, List, TextIO, Tuple
25+
26+
import apache_beam as beam
27+
import mirdata
28+
29+
from basic_pitch.data import commandline, pipeline
30+
31+
32+
def read_in_chunks(file_object: TextIO, chunk_size: int = 1024) -> Any:
    """Yield successive pieces of an open file until it is exhausted.

    Reads at most ``chunk_size`` characters/bytes per call (default 1 KiB)
    so arbitrarily large files can be streamed without loading them whole.
    """
    chunk = file_object.read(chunk_size)
    while chunk:
        yield chunk
        chunk = file_object.read(chunk_size)
40+
41+
42+
class MaestroInvalidTracks(beam.DoFn):
    """Beam DoFn that filters out MAESTRO tracks unsuitable for processing.

    Each element is a ``(track_id, split)`` pair. The track's audio is
    copied into a scratch directory, and any track whose audio runs 15
    minutes or longer is dropped; surviving track ids are re-emitted as a
    tagged output named after their split.
    """

    # Track attributes whose backing files must be fetched locally
    # before the duration check can run.
    DOWNLOAD_ATTRIBUTES = ["audio_path"]

    def __init__(self, source: str) -> None:
        self.source = source

    def setup(self) -> None:
        # Oddly enough we don't want to include the gcs bucket uri.
        # Just the path within the bucket
        self.maestro_remote = mirdata.initialize("maestro", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()

    def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
        import tempfile
        import sox

        track_id, split = element
        logging.info(f"Processing (track_id, split): ({track_id}, {split})")

        remote_track = self.maestro_remote.track(track_id)
        with tempfile.TemporaryDirectory() as scratch_dir:
            local_dataset = mirdata.initialize("maestro", scratch_dir)
            local_track = local_dataset.track(track_id)

            # Stream each required file from the remote filesystem into the
            # scratch data home, chunk by chunk, to bound memory usage.
            for attr_name in self.DOWNLOAD_ATTRIBUTES:
                remote_path = getattr(remote_track, attr_name)
                local_path = getattr(local_track, attr_name)
                os.makedirs(os.path.dirname(local_path), exist_ok=True)
                with self.filesystem.open(remote_path) as src, open(local_path, "wb") as dst:
                    for chunk in read_in_chunks(src):
                        dst.write(chunk)

            # Reject anything at or above 15 minutes (15 minutes * 60 seconds/minute).
            if sox.file_info.duration(local_track.audio_path) >= 15 * 60:
                return None

        yield beam.pvalue.TaggedOutput(split, track_id)
79+
80+
81+
class MaestroToTfExample(beam.DoFn):
    """Beam DoFn that converts batches of MAESTRO track ids into tf.Example protos.

    For each track id, the audio and MIDI files are copied from the remote
    ``source`` into a temporary local mirdata data home, the audio is
    resampled/remixed with sox, and the note/onset/contour annotations are
    rasterized onto a fixed time grid before serialization.
    """

    # Track attributes whose backing files must be copied locally
    # before the track can be processed.
    DOWNLOAD_ATTRIBUTES = ["audio_path", "midi_path"]

    def __init__(self, source: str, download: bool):
        # source: path (within the bucket/filesystem) of the remote
        #   MAESTRO data home.
        # download: when True, setup() triggers a full mirdata download
        #   of the remote dataset.
        self.source = source
        self.download = download

    def setup(self) -> None:
        """Per-worker initialization of the remote dataset handle and filesystem."""
        import apache_beam as beam
        import mirdata

        # Oddly enough we don't want to include the gcs bucket uri.
        # Just the path within the bucket
        self.maestro_remote = mirdata.initialize("maestro", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()
        if self.download:
            self.maestro_remote.download()

    def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
        """Download, resample, and serialize each track id in ``element``.

        Returns a single-element list wrapping the batch of serialized
        examples, so downstream stages receive one batch per input bundle.
        """
        import tempfile

        import numpy as np
        import sox

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
            FREQ_BINS_CONTOURS,
            FREQ_BINS_NOTES,
            ANNOTATION_HOP,
            N_FREQ_BINS_NOTES,
            N_FREQ_BINS_CONTOURS,
        )
        from basic_pitch.data import tf_example_serialization

        logging.info(f"Processing {element}")
        batch: List[Any] = []

        for track_id in element:
            track_remote = self.maestro_remote.track(track_id)
            with tempfile.TemporaryDirectory() as local_tmp_dir:
                maestro_local = mirdata.initialize("maestro", local_tmp_dir)
                track_local = maestro_local.track(track_id)

                # Stream audio + midi from the remote filesystem into the
                # local scratch data home, chunk by chunk, to bound memory.
                for attribute in self.DOWNLOAD_ATTRIBUTES:
                    source = getattr(track_remote, attribute)
                    destination = getattr(track_local, attribute)
                    os.makedirs(os.path.dirname(destination), exist_ok=True)
                    with self.filesystem.open(source) as s, open(destination, "wb") as d:
                        for piece in read_in_chunks(s):
                            d.write(piece)

                # Convert audio to the model's expected sample rate and
                # channel count. NOTE(review): local_wav_path lives inside the
                # temp dir — assumes to_transcription_tfexample reads the wav
                # eagerly before the directory is cleaned up; confirm.
                local_wav_path = f"{track_local.audio_path}_tmp.wav"

                tfm = sox.Transformer()
                tfm.rate(AUDIO_SAMPLE_RATE)
                tfm.channels(AUDIO_N_CHANNELS)
                tfm.build(track_local.audio_path, local_wav_path)

                # Annotation time grid: one frame every ANNOTATION_HOP seconds
                # spanning the full duration of the converted audio.
                duration = sox.file_info.duration(local_wav_path)
                time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
                n_time_frames = len(time_scale)

                # Rasterize the MIDI-derived note annotations onto the time
                # grid as sparse (indices, values) pairs.
                note_indices, note_values = track_local.notes.to_sparse_index(time_scale, "s", FREQ_BINS_NOTES, "hz")
                onset_indices, onset_values = track_local.notes.to_sparse_index(
                    time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
                )
                contour_indices, contour_values = track_local.notes.to_sparse_index(
                    time_scale, "s", FREQ_BINS_CONTOURS, "hz"
                )

                batch.append(
                    tf_example_serialization.to_transcription_tfexample(
                        track_local.track_id,
                        "maestro",
                        local_wav_path,
                        note_indices,
                        note_values,
                        onset_indices,
                        onset_values,
                        contour_indices,
                        contour_values,
                        (n_time_frames, N_FREQ_BINS_NOTES),
                        (n_time_frames, N_FREQ_BINS_CONTOURS),
                    )
                )
        return [batch]
169+
170+
171+
def create_input_data(source: str) -> List[Tuple[str, str]]:
    """Enumerate every MAESTRO track id together with its dataset split.

    Copies the v2.0.0 metadata file from ``source`` into a throwaway
    mirdata data home so the track list can be loaded without fetching
    any audio.
    """
    import apache_beam as beam

    filesystem = beam.io.filesystems.FileSystems()

    with tempfile.TemporaryDirectory() as data_home:
        maestro = mirdata.initialize("maestro", data_home=data_home)
        metadata_path = maestro._index["metadata"]["maestro-v2.0.0"][0]
        remote_metadata = os.path.join(source, metadata_path)
        local_metadata = os.path.join(data_home, metadata_path)
        with filesystem.open(remote_metadata) as remote_file:
            with open(local_metadata, "wb") as local_file:
                local_file.write(remote_file.read())

        return [(track_id, track.split) for track_id, track in maestro.load_tracks().items()]
185+
186+
187+
def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
    """Configure and launch the Beam job that renders MAESTRO as tfrecords.

    ``known_args`` carries the parsed commandline flags (source, runner,
    batch size, container image, ...); ``pipeline_args`` is forwarded
    verbatim to the Beam pipeline.
    """
    launch_time = int(time.time())
    destination = commandline.resolve_destination(known_args, launch_time)

    # TODO: Remove or abstract for foss
    pipeline_options = dict(
        runner=known_args.runner,
        job_name=f"maestro-tfrecords-{launch_time}",
        machine_type="e2-highmem-4",
        num_workers=25,
        disk_size_gb=128,
        experiments=["use_runner_v2", "no_use_multiple_sdk_containers"],
        save_main_session=True,
        sdk_container_image=known_args.sdk_container_image,
        job_endpoint=known_args.job_endpoint,
        environment_type="DOCKER",
        environment_config=known_args.sdk_container_image,
    )
    input_data = create_input_data(known_args.source)
    pipeline.run(
        pipeline_options,
        pipeline_args,
        input_data,
        MaestroToTfExample(known_args.source, download=True),
        MaestroInvalidTracks(known_args.source),
        destination,
        known_args.batch_size,
    )
215+
216+
217+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
    commandline.add_split(parser)
    # Parse the default sys.argv[1:]. Passing sys.argv here (as before) lets
    # argv[0] — the script path — fall through into pipeline_args, since the
    # parser defines only options, and it would then be forwarded to Beam as
    # a bogus pipeline argument.
    known_args, pipeline_args = parser.parse_known_args()

    main(known_args, pipeline_args)

basic_pitch/data/download.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@
2020
from basic_pitch.data import commandline
2121
from basic_pitch.data.datasets.guitarset import main as guitarset_main
2222
from basic_pitch.data.datasets.ikala import main as ikala_main
23+
from basic_pitch.data.datasets.maestro import main as maestro_main
2324

2425
logger = logging.getLogger()
2526
logger.setLevel(logging.INFO)
2627

27-
DATASET_DICT = {"guitarset": guitarset_main, "ikala": ikala_main}
28+
DATASET_DICT = {"guitarset": guitarset_main, "ikala": ikala_main, "maestro": maestro_main}
2829

2930

3031
def main() -> None:

tests/data/test_guitarset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
TRACK_ID = "00_BN1-129-Eb_comp"
3434

3535

36-
def test_guitar_set_to_tf_example(tmpdir: str) -> None:
36+
def test_guitarset_to_tf_example(tmpdir: str) -> None:
3737
input_data: List[str] = [TRACK_ID]
3838
with TestPipeline() as p:
3939
(
@@ -51,7 +51,7 @@ def test_guitar_set_to_tf_example(tmpdir: str) -> None:
5151
assert len(data) != 0
5252

5353

54-
def test_guitar_set_invalid_tracks(tmpdir: str) -> None:
54+
def test_guitarset_invalid_tracks(tmpdir: str) -> None:
5555
split_labels = ["train", "test", "validation"]
5656
input_data = [(str(i), split) for i, split in enumerate(split_labels)]
5757
with TestPipeline() as p:
@@ -73,15 +73,15 @@ def test_guitar_set_invalid_tracks(tmpdir: str) -> None:
7373
assert fp.read().strip() == str(i)
7474

7575

76-
def test_create_input_data() -> None:
76+
def test_guitarset_create_input_data() -> None:
7777
data = create_input_data(train_percent=0.33, validation_percent=0.33)
7878
data.sort(key=lambda el: el[1]) # sort by split
7979
tolerance = 0.1
8080
for key, group in itertools.groupby(data, lambda el: el[1]):
8181
assert (0.33 - tolerance) * len(data) <= len(list(group)) <= (0.33 + tolerance) * len(data)
8282

8383

84-
def test_create_input_data_overallocate() -> None:
84+
def test_guitarset_create_input_data_overallocate() -> None:
8585
try:
8686
create_input_data(train_percent=0.6, validation_percent=0.6)
8787
except AssertionError:

tests/data/test_ikala.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,15 @@ def test_ikala_invalid_tracks(tmpdir: str) -> None:
5151
assert fp.read().strip() == str(i)
5252

5353

54-
def test_create_input_data() -> None:
54+
def test_ikala_create_input_data() -> None:
5555
data = create_input_data(train_percent=0.5)
5656
data.sort(key=lambda el: el[1]) # sort by split
57-
tolerance = 0.05
58-
for key, group in itertools.groupby(data, lambda el: el[1]):
57+
tolerance = 0.1
58+
for _, group in itertools.groupby(data, lambda el: el[1]):
5959
assert (0.5 - tolerance) * len(data) <= len(list(group)) <= (0.5 + tolerance) * len(data)
6060

6161

62-
def test_create_input_data_overallocate() -> None:
62+
def test_ikala_create_input_data_overallocate() -> None:
6363
try:
6464
create_input_data(train_percent=1.1)
6565
except AssertionError:

0 commit comments

Comments
 (0)