Skip to content

Commit f2f9a39

Browse files
Redid the MAESTRO tests to mock all data and removed the MIDI and WAV files from the repo. Also removed the metadata CSV file, as it wasn't being used in the code.
1 parent 62634f9 commit f2f9a39

9 files changed

+62
-1297
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ test = [
6464
"coverage>=5.0.2",
6565
"pytest>=6.1.1",
6666
"pytest-mock",
67+
"wave",
68+
"mido"
6769
]
6870
tf = [
6971
"tensorflow>=2.4.1,<2.15.1; platform_system != 'Darwin'",

tests/data/test_maestro.py

Lines changed: 60 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
import numpy as np
1919
import os
2020
import pathlib
21-
from typing import List
2221
import wave
2322

23+
from mido import MidiFile, MidiTrack, Message
24+
from typing import List
25+
2426
import apache_beam as beam
2527
from apache_beam.testing.test_pipeline import TestPipeline
2628

@@ -37,8 +39,7 @@
3739
TRAIN_TRACK_ID = "2004/MIDI-Unprocessed_SMF_05_R1_2004_01_ORIG_MID--AUDIO_05_R1_2004_03_Track03_wav"
3840
VALID_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_06_Track06_wav"
3941
TEST_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_08_Track08_wav"
40-
41-
MOCK_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"
42+
GT_15M_TRACK_ID = "2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav"
4243

4344

4445
def create_mock_wav(output_fpath: str, duration_min: int) -> None:
@@ -61,31 +62,75 @@ def create_mock_wav(output_fpath: str, duration_min: int) -> None:
6162
logging.info(f"Mock {duration_min}-minute WAV file '{output_fpath}' created successfully.")
6263

6364

65+
def create_mock_midi(output_fpath: str) -> None:
66+
# Create a new MIDI file with one track
67+
mid = MidiFile()
68+
track = MidiTrack()
69+
mid.tracks.append(track)
70+
71+
# Define a sequence of notes (time, type, note, velocity)
72+
notes = [
73+
(0, "note_on", 60, 64), # C4
74+
(500, "note_off", 60, 64),
75+
(0, "note_on", 62, 64), # D4
76+
(500, "note_off", 62, 64),
77+
]
78+
79+
# Add the notes to the track
80+
for time, type, note, velocity in notes:
81+
track.append(Message(type, note=note, velocity=velocity, time=time))
82+
83+
# Save the MIDI file
84+
mid.save(output_fpath)
85+
86+
logging.info(f"Mock MIDI file '{output_fpath}' created successfully.")
87+
88+
6489
def test_maestro_to_tf_example(tmpdir: str) -> None:
90+
mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
91+
mock_maestro_ext = mock_maestro_home / "2004"
92+
mock_maestro_ext.mkdir(parents=True, exist_ok=True)
93+
94+
create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3)
95+
create_mock_midi(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".midi")))
96+
97+
output_dir = pathlib.Path(tmpdir) / "outputs"
98+
output_dir.mkdir(parents=True, exist_ok=True)
99+
65100
input_data: List[str] = [TRAIN_TRACK_ID]
66101
with TestPipeline() as p:
67102
(
68103
p
69104
| "Create PCollection of track IDs" >> beam.Create([input_data])
70-
| "Create tf.Example" >> beam.ParDo(MaestroToTfExample(str(MAESTRO_TEST_DATA_PATH), download=False))
71-
| "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(tmpdir))
105+
| "Create tf.Example" >> beam.ParDo(MaestroToTfExample(str(mock_maestro_home), download=False))
106+
| "Write to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(str(output_dir)))
72107
)
73108

74-
assert len(os.listdir(tmpdir)) == 1
75-
assert os.path.splitext(os.listdir(tmpdir)[0])[-1] == ".tfrecord"
76-
with open(os.path.join(tmpdir, os.listdir(tmpdir)[0]), "rb") as fp:
109+
assert len(os.listdir(str(output_dir))) == 1
110+
print("PASSED THIS POINT")
111+
assert os.path.splitext(os.listdir(str(output_dir))[0])[-1] == ".tfrecord"
112+
print("PASSED THIS OTHER POINT")
113+
with open(os.path.join(str(output_dir), os.listdir(str(output_dir))[0]), "rb") as fp:
77114
data = fp.read()
78115
assert len(data) != 0
79116

80117

81118
def test_maestro_invalid_tracks(tmpdir: str) -> None:
119+
mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
120+
mock_maestro_ext = mock_maestro_home / "2004"
121+
mock_maestro_ext.mkdir(parents=True, exist_ok=True)
122+
123+
create_mock_wav(str(mock_maestro_ext / (TRAIN_TRACK_ID.split("/")[1] + ".wav")), 3)
124+
create_mock_wav(str(mock_maestro_ext / (VALID_TRACK_ID.split("/")[1] + ".wav")), 3)
125+
create_mock_wav(str(mock_maestro_ext / (TEST_TRACK_ID.split("/")[1] + ".wav")), 3)
126+
82127
input_data = [(TRAIN_TRACK_ID, "train"), (VALID_TRACK_ID, "validation"), (TEST_TRACK_ID, "test")]
83128
split_labels = set([e[1] for e in input_data])
84129
with TestPipeline() as p:
85130
splits = (
86131
p
87132
| "Create PCollection" >> beam.Create(input_data)
88-
| "Tag it" >> beam.ParDo(MaestroInvalidTracks(str(MAESTRO_TEST_DATA_PATH))).with_outputs(*split_labels)
133+
| "Tag it" >> beam.ParDo(MaestroInvalidTracks(str(mock_maestro_home))).with_outputs(*split_labels)
89134
)
90135

91136
for split in split_labels:
@@ -106,16 +151,19 @@ def test_maestro_invalid_tracks_over_15_min(tmpdir: str) -> None:
106151
not to store a large file in git, hence the variable name.
107152
"""
108153

109-
mock_fpath = MAESTRO_TEST_DATA_PATH / "2004" / (MOCK_15M_TRACK_ID.split("/")[1] + ".wav")
154+
mock_maestro_home = pathlib.Path(tmpdir) / "maestro"
155+
mock_maestro_ext = mock_maestro_home / "2004"
156+
mock_maestro_ext.mkdir(parents=True, exist_ok=True)
157+
mock_fpath = mock_maestro_ext / (GT_15M_TRACK_ID.split("/")[1] + ".wav")
110158
create_mock_wav(str(mock_fpath), 16)
111159

112-
input_data = [(MOCK_15M_TRACK_ID, "train")]
160+
input_data = [(GT_15M_TRACK_ID, "train")]
113161
split_labels = set([e[1] for e in input_data])
114162
with TestPipeline() as p:
115163
splits = (
116164
p
117165
| "Create PCollection" >> beam.Create(input_data)
118-
| "Tag it" >> beam.ParDo(MaestroInvalidTracks(str(MAESTRO_TEST_DATA_PATH))).with_outputs(*split_labels)
166+
| "Tag it" >> beam.ParDo(MaestroInvalidTracks(str(mock_maestro_home))).with_outputs(*split_labels)
119167
)
120168

121169
for split in split_labels:
@@ -129,8 +177,6 @@ def test_maestro_invalid_tracks_over_15_min(tmpdir: str) -> None:
129177
with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
130178
assert fp.read().strip() == ""
131179

132-
os.remove(mock_fpath)
133-
134180

135181
def test_maestro_create_input_data() -> None:
136182
data = create_input_data(str(MAESTRO_TEST_DATA_PATH))

0 commit comments

Comments
 (0)