Skip to content

Commit

Permalink
fix for splitting out valid dataset when dataset is directly passed into soundstream trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 10, 2023
1 parent 05c23d9 commit 31b0b7e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
25 changes: 12 additions & 13 deletions audiolm_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,25 +276,24 @@ def __init__(
seq_len_multiple_of = soundstream.seq_len_multiple_of
)

# split for validation

if valid_frac > 0:
train_size = int((1 - valid_frac) * len(self.ds))
valid_size = len(self.ds) - train_size
dataset, val_dataset = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
else:
val_dataset = dataset
self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

assert len(dataset) >= batch_size, 'dataset must have sufficient samples for training'
assert len(val_dataset) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training'

if exists(dataset):
assert not exists(train_dataloader)
assert not exists(val_dataloader)

val_dataset = default(val_dataset, dataset)
# maybe split for validation

if valid_frac > 0:
train_size = int((1 - valid_frac) * len(dataset))
valid_size = len(dataset) - train_size
dataset, val_dataset = random_split(dataset, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
self.print(f'training with dataset of {len(dataset)} samples and validating with randomly splitted {len(val_dataset)} samples')
else:
val_dataset = dataset
self.print(f'training with shared training and valid dataset of {len(dataset)} samples')

assert len(val_dataset) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(val_dataset)}) for training'

train_dataloader = get_dataloader(dataset, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True, drop_last = dataloader_drop_last)
val_dataloader = get_dataloader(val_dataset, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True, drop_last = dataloader_drop_last)
Expand Down
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.8.6'
__version__ = '1.8.7'

0 comments on commit 31b0b7e

Please sign in to comment.