Added Ikala #128

Merged 3 commits on Jul 11, 2024
193 changes: 193 additions & 0 deletions basic_pitch/data/datasets/ikala.py
@@ -0,0 +1,193 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import os
import random
import sys
import time
from typing import Any, Dict, List, Tuple, Optional

import apache_beam as beam
import mirdata

from basic_pitch.data import commandline, pipeline


class IkalaInvalidTracks(beam.DoFn):
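    """Route (track_id, split) pairs to per-split tagged outputs.

    Tagged outputs let the pipeline write the train and validation sets
    separately; every element is passed through, so no iKala tracks are
    currently treated as invalid.
    """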
def process(self, element: Tuple[str, str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
track_id, split = element
yield beam.pvalue.TaggedOutput(split, track_id)


class IkalaToTfExample(beam.DoFn):
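    """Convert batches of iKala track ids into serialized tf.Examples.

    Each track's audio and annotation files are copied from remote storage
    into a local temp dir, the vocal channel is bounced to a wav at the
    training sample rate, and the pYIN note / f0 contour annotations are
    converted to sparse indices on a fixed time grid before serialization.
    """
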
DOWNLOAD_ATTRIBUTES = ["audio_path", "notes_pyin_path", "f0_path"]

def __init__(self, source: str, download: bool) -> None:
self.source = source
self.download = download

def setup(self) -> None:
import apache_beam as beam
import os
import mirdata

self.ikala_remote = mirdata.initialize("ikala", data_home=os.path.join(self.source, "iKala"))
self.filesystem = beam.io.filesystems.FileSystems() # TODO: replace with fsspec
if self.download:
self.ikala_remote.download()

def process(self, element: List[str], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> List[Any]:
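        """Download, transform, and serialize one batch of iKala track ids."""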
import tempfile

import numpy as np
import sox

from basic_pitch.constants import (
AUDIO_N_CHANNELS,
AUDIO_SAMPLE_RATE,
FREQ_BINS_CONTOURS,
FREQ_BINS_NOTES,
ANNOTATION_HOP,
N_FREQ_BINS_CONTOURS,
N_FREQ_BINS_NOTES,
)
from basic_pitch.data import tf_example_serialization

logging.info(f"Processing {element}")
batch = []

for track_id in element:
track_remote = self.ikala_remote.track(track_id)
with tempfile.TemporaryDirectory() as local_tmp_dir:
ikala_local = mirdata.initialize("ikala", local_tmp_dir)
track_local = ikala_local.track(track_id)

                # Copy each remote audio/annotation file into the local temp
                # dir so sox and the mirdata loaders can read it from disk.
                for attr in self.DOWNLOAD_ATTRIBUTES:
source = getattr(track_remote, attr)
dest = getattr(track_local, attr)
os.makedirs(os.path.dirname(dest), exist_ok=True)
with self.filesystem.open(source) as s, open(dest, "wb") as d:
d.write(s.read())

local_wav_path = "{}_tmp.wav".format(track_local.audio_path)

tfm = sox.Transformer()
tfm.rate(AUDIO_SAMPLE_RATE)
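                # iKala's stereo files carry accompaniment and singing voice on
                # separate channels; keep only the vocal channel (output
                # channel 1 is taken from input channel 2).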
tfm.remix({1: [2]})
tfm.channels(AUDIO_N_CHANNELS)
tfm.build(track_local.audio_path, local_wav_path)

duration = sox.file_info.duration(local_wav_path)
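                # Annotation time grid: one frame every ANNOTATION_HOP seconds
                # across the full duration of the clip.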
time_scale = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
n_time_frames = len(time_scale)

if track_local.notes_pyin is not None:
note_indices, note_values = track_local.notes_pyin.to_sparse_index(
time_scale, "s", FREQ_BINS_NOTES, "hz"
)
onset_indices, onset_values = track_local.notes_pyin.to_sparse_index(
time_scale, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
)
note_shape = (n_time_frames, N_FREQ_BINS_NOTES)
                else:
                    # No pYIN note annotations: emit empty indices/values and a
                    # zero-sized note shape.
                    note_indices = []
                    onset_indices = []
                    note_values = []
                    onset_values = []
                    note_shape = (0, 0)

contour_indices, contour_values = track_local.f0.to_sparse_index(
time_scale, "s", FREQ_BINS_CONTOURS, "hz"
)

batch.append(
tf_example_serialization.to_transcription_tfexample(
track_id,
"ikala",
local_wav_path,
note_indices,
note_values,
onset_indices,
onset_values,
contour_indices,
contour_values,
note_shape,
(n_time_frames, N_FREQ_BINS_CONTOURS),
)
)
return [batch]


def create_input_data(train_percent: float, seed: Optional[int] = None) -> List[Tuple[str, str]]:
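    """Randomly assign each iKala track id to a "train" or "validation" split.

    A track lands in "train" with probability train_percent and in
    "validation" otherwise; pass seed for a reproducible assignment.
    """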
assert train_percent < 1.0, "Don't over allocate the data!"

    # Draws below train_percent go to "train"; the rest go to "validation".
    # No test split is produced here.
    train_bound = train_percent

    if seed is not None:
        random.seed(seed)

    def determine_split() -> str:
        partition = random.uniform(0, 1)
        if partition < train_bound:
            return "train"
        return "validation"

ikala = mirdata.initialize("ikala")

return [(track_id, determine_split()) for track_id in ikala.track_ids]


def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
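    """Assemble Beam pipeline options and run the iKala TFRecord pipeline."""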
time_created = int(time.time())
destination = commandline.resolve_destination(known_args, time_created)

pipeline_options = {
"runner": known_args.runner,
"job_name": f"ikala-tfrecords-{time_created}",
"machine_type": "e2-standard-4",
"num_workers": 25,
"disk_size_gb": 128,
"experiments": ["use_runner_v2", "no_use_multiple_sdk_containers"],
"save_main_session": True,
"sdk_container_image": known_args.sdk_container_image,
"job_endpoint": known_args.job_endpoint,
"environment_type": "DOCKER",
"environment_config": known_args.sdk_container_image,
}
input_data = create_input_data(known_args.train_percent, known_args.split_seed)
pipeline.run(
pipeline_options,
pipeline_args,
input_data,
IkalaToTfExample(known_args.source, download=True),
        IkalaInvalidTracks(),
destination,
known_args.batch_size,
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
commandline.add_default(parser, os.path.basename(os.path.splitext(__file__)[0]))
commandline.add_split(parser)
known_args, pipeline_args = parser.parse_known_args(sys.argv)

main(known_args, pipeline_args)
5 changes: 2 additions & 3 deletions basic_pitch/data/download.py
@@ -19,13 +19,12 @@

from basic_pitch.data import commandline
from basic_pitch.data.datasets.guitarset import main as guitarset_main
from basic_pitch.data.datasets.ikala import main as ikala_main

logger = logging.getLogger()
logger.setLevel(logging.INFO)

-DATASET_DICT = {
-    "guitarset": guitarset_main,
-}
+DATASET_DICT = {"guitarset": guitarset_main, "ikala": ikala_main}


def main() -> None:
68 changes: 68 additions & 0 deletions tests/data/test_ikala.py
@@ -0,0 +1,68 @@
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os

import apache_beam as beam
import pytest

from apache_beam.testing.test_pipeline import TestPipeline

from basic_pitch.data.datasets.ikala import (
IkalaInvalidTracks,
create_input_data,
)


# TODO: Create test_ikala_to_tf_example (a possible sketch appears at the end of this file)


def test_ikala_invalid_tracks(tmpdir: str) -> None:
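    """Each element should come out of the output tagged with its split label."""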
split_labels = ["train", "validation"]
input_data = [(str(i), split) for i, split in enumerate(split_labels)]
with TestPipeline() as p:
splits = (
p
| "Create PCollection" >> beam.Create(input_data)
| "Tag it" >> beam.ParDo(IkalaInvalidTracks()).with_outputs(*split_labels)
)

for split in split_labels:
(
getattr(splits, split)
| f"Write {split} to text"
>> beam.io.WriteToText(os.path.join(tmpdir, f"output_{split}.txt"), shard_name_template="")
)

for i, split in enumerate(split_labels):
with open(os.path.join(tmpdir, f"output_{split}.txt"), "r") as fp:
assert fp.read().strip() == str(i)


def test_create_input_data() -> None:
data = create_input_data(train_percent=0.5)
data.sort(key=lambda el: el[1]) # sort by split
tolerance = 0.05
    for _, group in itertools.groupby(data, lambda el: el[1]):
assert (0.5 - tolerance) * len(data) <= len(list(group)) <= (0.5 + tolerance) * len(data)


def test_create_input_data_overallocate() -> None:
    with pytest.raises(AssertionError):
        create_input_data(train_percent=1.1)
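

A possible starting point for the test_ikala_to_tf_example TODO above, sketched under stated assumptions: the fixture directory, track id, and checked-in assets below are hypothetical placeholders, and IkalaToTfExample would need to be imported alongside the names above.

from basic_pitch.data.datasets.ikala import IkalaToTfExample


def test_ikala_to_tf_example(tmpdir: str) -> None:
    # Hypothetical fixture: a mirdata-style iKala tree under
    # tests/resources/ikala/iKala (the DoFn joins "iKala" onto its source
    # path) containing one track's audio, notes_pyin, and f0 files.
    fixture_dir = "tests/resources/ikala"  # placeholder, not a real asset
    track_id = "10161_chorus"  # placeholder iKala-style track id
    with TestPipeline() as p:
        (
            p
            | "Create" >> beam.Create([[track_id]])
            | "To tf.Example" >> beam.ParDo(IkalaToTfExample(fixture_dir, download=False))
        )
        # With real assets checked in, the output PCollection could be
        # asserted on via apache_beam.testing.util.assert_that.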