Skip to content

Commit

Permalink
allow removing filtering on the file duration setting min_duration/ma…
Browse files Browse the repository at this point in the history
…x_duration to None
  • Loading branch information
jeremyfix committed Nov 6, 2023
1 parent b54d84e commit 6426df4
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions LabsSolutions/02-pytorch-asr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def __init__(
self.valid_indices = [
i
for i, (w, r, _) in tqdm.tqdm(enumerate(ds))
if min_duration <= w.squeeze().shape[0] / r <= max_duration
if (not min_duration or min_duration <= (w.squeeze().shape[0] / r))
and (not max_duration or (w.squeeze().shape[0] / r) <= max_duration)
]
pickle.dump(self.valid_indices, open(cachepath, "wb"))
self.ds = ds
Expand Down Expand Up @@ -445,17 +446,21 @@ def get_dataloaders(
"""

def dataset_loader(fold, version, overwrite_index):
return DatasetFilter(
ds=load_dataset(
fold,
commonvoice_root=commonvoice_root,
commonvoice_version=commonvoice_version,
),
min_duration=min_duration,
max_duration=max_duration,
cachepath=Path(fold + "-" + version + ".idx"),
overwrite_index=overwrite_index,
ds = load_dataset(
fold,
commonvoice_root=commonvoice_root,
commonvoice_version=commonvoice_version,
)
if not min_duration and not max_duration:
return ds
else:
return DatasetFilter(
ds=ds,
min_duration=min_duration,
max_duration=max_duration,
cachepath=Path(fold + "-" + version + ".idx"),
overwrite_index=overwrite_index,
)

valid_dataset = dataset_loader("dev", commonvoice_version, overwrite_index)
train_dataset = dataset_loader("train", commonvoice_version, overwrite_index)
Expand Down Expand Up @@ -730,7 +735,7 @@ def ex_augmented_spectro():
cuda=False,
n_threads=4,
min_duration=1,
max_duration=3,
max_duration=None,
batch_size=batch_size,
train_augment=True,
normalize=False,
Expand Down Expand Up @@ -835,6 +840,6 @@ def test_spectro():
# test_spectro()
# ex_waveform_spectro()
# ex_spectro()
# ex_augmented_spectro()
ex_augmented_spectro()
pass
# SOL@

0 comments on commit 6426df4

Please sign in to comment.