diff --git a/audiolm_pytorch/trainer.py b/audiolm_pytorch/trainer.py index 703f782..d296dd5 100644 --- a/audiolm_pytorch/trainer.py +++ b/audiolm_pytorch/trainer.py @@ -276,25 +276,24 @@ def __init__( seq_len_multiple_of = soundstream.seq_len_multiple_of ) - # split for validation - - if valid_frac > 0: - train_size = int((1 - valid_frac) * len(self.ds)) - valid_size = len(self.ds) - train_size - dataset, val_dataset = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) - self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') - else: - val_dataset = dataset - self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') - assert len(dataset) >= batch_size, 'dataset must have sufficient samples for training' - assert len(val_dataset) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(self.valid_ds)}) for training' if exists(dataset): assert not exists(train_dataloader) assert not exists(val_dataloader) - val_dataset = default(val_dataset, dataset) + # maybe split for validation + + if valid_frac > 0: + train_size = int((1 - valid_frac) * len(dataset)) + valid_size = len(dataset) - train_size + dataset, val_dataset = random_split(dataset, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) + self.print(f'training with dataset of {len(dataset)} samples and validating with randomly splitted {len(val_dataset)} samples') + else: + val_dataset = dataset + self.print(f'training with shared training and valid dataset of {len(dataset)} samples') + + assert len(val_dataset) >= batch_size, f'validation dataset must have sufficient number of samples (currently {len(val_dataset)}) for training' train_dataloader = get_dataloader(dataset, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True, drop_last = dataloader_drop_last) val_dataloader = get_dataloader(val_dataset, batch_size = batch_size, num_workers = dl_num_workers, shuffle = True, drop_last = dataloader_drop_last) diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py index 871921a..655be52 100644 --- a/audiolm_pytorch/version.py +++ b/audiolm_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.8.6' +__version__ = '1.8.7'