
Commit 2cb2233

committing changes before lockout
1 parent db59439 commit 2cb2233

8 files changed, +17 -17 lines changed

8 files changed

+17
-17
lines changed

basic_pitch/dataset/commandline.py

Lines changed: 2 additions & 3 deletions
@@ -16,10 +16,8 @@
 # limitations under the License.

 import argparse
-import inspect
 import os
 import os.path as op
-import pdb


 def add_default(parser: argparse.ArgumentParser, dataset_name: str):
@@ -33,7 +31,8 @@ def add_default(parser: argparse.ArgumentParser, dataset_name: str):
                         help="If passed, the dataset will be put into a timestamp directory instead of 'splits'")
     parser.add_argument("--batch-size", default=5, type=int, help="Number of examples per tfrecord")
     parser.add_argument("--worker-harness-container-image", default="",
-                        help="Container image to run dataset generation job with. Required due to non-python dependencies")
+                        help="Container image to run dataset generation job with. \
+                        Required due to non-python dependencies.")


 def resolve_destination(namespace: argparse.Namespace, dataset: str, time_created: int) -> str:
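
Two things worth noting about the re-wrapped help string. First, a backslash continuation inside a string literal keeps the next line's leading whitespace in the string itself, although argparse's default help formatter re-wraps help text and collapses that whitespace when rendering --help, so the displayed output is unaffected. Second, a minimal sketch of how the two flags visible in this hunk parse (the rest of add_default's arguments are omitted):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", default=5, type=int,
                    help="Number of examples per tfrecord")
parser.add_argument("--worker-harness-container-image", default="",
                    help="Container image to run dataset generation job with. \
                    Required due to non-python dependencies.")

args = parser.parse_args(["--batch-size", "10"])
print(args.batch_size)                       # 10
print(args.worker_harness_container_image)   # "" (the default)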

basic_pitch/dataset/download.py

Lines changed: 4 additions & 3 deletions
@@ -7,17 +7,18 @@
 from basic_pitch.dataset.medleydb_pitch import main as medleydb_pitch_main
 from basic_pitch.dataset.slakh import main as slakh_main

-dataset_dict = {
+DATASET_DICT = {
     'guitarset': guitarset_main,
     'ikala': ikala_main,
     'maestro': maestro_main,
     'medleydb_pitch': medleydb_pitch_main,
     'slakh': slakh_main
 }

+
 def main():
     dataset_parser = argparse.ArgumentParser()
-    dataset_parser.add_argument("dataset", choices=list(dataset_dict.keys()), help="The dataset to download / process.")
+    dataset_parser.add_argument("dataset", choices=list(DATASET_DICT.keys()), help="The dataset to download / process.")
     dataset = dataset_parser.parse_args().dataset

     print(f'got the arg: {dataset}')
@@ -26,7 +27,7 @@ def main():
     commandline.add_split(cl_parser)
     known_args, pipeline_args = cl_parser.parse_known_args()  # sys.argv)

-    dataset_dict[dataset](known_args, pipeline_args)
+    DATASET_DICT[dataset](known_args, pipeline_args)


 if __name__ == '__main__':
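
A note on the parse_known_args call above: flags the parser defines land in known_args, while everything unrecognized passes through as pipeline_args, which main forwards to the chosen dataset entry point; this is the usual way to hand options on to an Apache Beam pipeline (Beam appears under the new training extra in setup.cfg below). A self-contained sketch; the --runner and --region flags are illustrative Beam-style options, not something this diff defines:

import argparse

cl_parser = argparse.ArgumentParser()
cl_parser.add_argument("--batch-size", type=int, default=5)

# Known flags are consumed; unknown ones pass through untouched.
known_args, pipeline_args = cl_parser.parse_known_args(
    ["--batch-size", "8", "--runner", "DataflowRunner", "--region", "us-east1"]
)
print(known_args.batch_size)  # 8
print(pipeline_args)          # ['--runner', 'DataflowRunner', '--region', 'us-east1']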

basic_pitch/dataset/guitarset.py

Lines changed: 0 additions & 1 deletion
@@ -20,7 +20,6 @@
 import os
 import os.path as op
 import random
-import sys
 import time
 from typing import List, Tuple, Optional


basic_pitch/dataset/slakh.py

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@
 import logging
 import os
 import os.path as op
-import sys
 import time
 from typing import List, Tuple


basic_pitch/dataset/tf_example_deserialization.py

Lines changed: 6 additions & 7 deletions
@@ -39,13 +39,13 @@


 def prepare_datasets(
-    datasets_base_path,
-    training_shuffle_buffer_size,
-    batch_size,
-    validation_steps,
+    datasets_base_path: str,
+    training_shuffle_buffer_size: int,
+    batch_size: int,
+    validation_steps: int,
     datasets_to_use: List[str],
     dataset_sampling_frequency: np.ndarray,
-):
+) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
     """
     Return a training and a testing dataset.

@@ -177,7 +177,6 @@ def sample_datasets(

     ds_list = []

-
     file_generator, random_seed = transcription_file_generator(
         split,
         datasets,
@@ -213,7 +212,7 @@ def sample_datasets(
     choice_dataset = tf.data.Dataset.range(
         n_datasets
     ).repeat()  # this repeat is critical! if not, only n_dataset points will be sampled!!
-    return tf.data.experimental.choose_from_datasets(ds_list, choice_dataset)
+    return tf.data.Dataset.choose_from_datasets(ds_list, choice_dataset)


 def transcription_file_generator(
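
The in-code comment about .repeat() deserves a quick illustration: choose_from_datasets consumes one index from the choice dataset per emitted element, so a non-repeating range would cap the whole pipeline at n_datasets elements. A toy sketch using the TF >= 2.7 spelling from the + line (tf.data.experimental.choose_from_datasets on the removed line is the older home of the same function):

import tensorflow as tf

ds_list = [
    tf.data.Dataset.from_tensor_slices([0, 0, 0]),
    tf.data.Dataset.from_tensor_slices([1, 1, 1]),
]

# Without .repeat(), range(2) supplies two indices and sampling stops after
# two elements; with it, interleaving continues until a source runs dry.
choice_dataset = tf.data.Dataset.range(2).repeat()
mixed = tf.data.Dataset.choose_from_datasets(ds_list, choice_dataset)
print(list(mixed.as_numpy_iterator()))  # [0, 1, 0, 1, 0, 1]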

basic_pitch/dataset/tf_example_serialization.py

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ def to_transcription_tfexample(
     contours_values: List[float],
     notes_onsets_shape: Tuple[int, int],
     contours_shape: Tuple[int, int],
-):
+) -> tf.train.Example:
     """
     - `file_id` string
     - `source` string (e.g., "maestro")
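
For reference, the object the new return annotation names: a tf.train.Example wraps a dictionary of named, typed features. A minimal sketch building one from the first two fields in the docstring above (the helper and the values are illustrative, and the real serializer packs many more fields):

import tensorflow as tf

def _bytes_feature(value: bytes) -> tf.train.Feature:
    # Wrap raw bytes in the protobuf Feature type tf.train.Example expects.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "file_id": _bytes_feature(b"track_00001"),
            "source": _bytes_feature(b"maestro"),
        }
    )
)
serialized = example.SerializeToString()  # bytes, ready to write to a TFRecord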

basic_pitch/train.py

Lines changed: 2 additions & 1 deletion
@@ -220,7 +220,8 @@ def console_entry_point():

     args = parser.parse_args()
     datasets_to_use = [
-        dataset.lower() for dataset in DATASET_SAMPLING_FREQUENCY.keys() if getattr(args, dataset.lower().replace("-", "_"))
+        dataset.lower() for dataset in DATASET_SAMPLING_FREQUENCY.keys()
+        if getattr(args, dataset.lower().replace("-", "_"))
     ]
     dataset_sampling_frequency = [
         frequency
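
The comprehension being re-wrapped maps per-dataset CLI flags back to dataset names; the replace("-", "_") is needed because argparse converts dashes in flag names to underscores in attribute names. A self-contained sketch (this DATASET_SAMPLING_FREQUENCY is a placeholder; only its keys matter here):

import argparse

DATASET_SAMPLING_FREQUENCY = {"guitarset": 0.5, "medleydb-pitch": 0.5}  # illustrative

parser = argparse.ArgumentParser()
for name in DATASET_SAMPLING_FREQUENCY:
    parser.add_argument(f"--{name}", action="store_true")

args = parser.parse_args(["--medleydb-pitch"])  # stored as args.medleydb_pitch
datasets_to_use = [
    dataset.lower() for dataset in DATASET_SAMPLING_FREQUENCY.keys()
    if getattr(args, dataset.lower().replace("-", "_"))
]
print(datasets_to_use)  # ['medleydb-pitch']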

setup.cfg

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ console_scripts =
     download-data = basic_pitch.dataset.download:main

 [options.extras_require]
+training =
+    apache_beam
 test =
     coverage>=5.0.2
     pytest>=6.1.1
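
With this extra in place, the Beam dependency installs by name, e.g. pip install -e ".[training]" from a checkout (invocation illustrative; the published distribution name may differ).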
