diff --git a/lib/utils.py b/lib/utils.py
index 82dd538..e620c2c 100644
--- a/lib/utils.py
+++ b/lib/utils.py
@@ -186,7 +186,7 @@ def load_dataset(dataset_dir, batch_size, test_batch_size=None, **kwargs):
     for category in ['train', 'val', 'test']:
         data['x_' + category][..., 0] = scaler.transform(data['x_' + category][..., 0])
         data['y_' + category][..., 0] = scaler.transform(data['y_' + category][..., 0])
-    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
+    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=False)
     data['val_loader'] = DataLoader(data['x_val'], data['y_val'], test_batch_size, shuffle=False)
     data['test_loader'] = DataLoader(data['x_test'], data['y_test'], test_batch_size, shuffle=False)
     data['scaler'] = scaler
diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index c2a34d4..53f6a33 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -174,12 +174,23 @@ def _train(self, base_lr,
 
             self.dcrnn_model = self.dcrnn_model.train()
 
+
+            # shuffle the batches
             train_iterator = self._data['train_loader'].get_iterator()
+            all_train = np.array([(x,y) for _, (x, y) in enumerate(train_iterator)])
+            permutation = np.random.permutation(all_train.shape[0])
+            all_train = all_train[permutation]
+
+
             losses = []
 
-            start_time = time.time()
 
-            for _, (x, y) in enumerate(train_iterator):
+            for batch in all_train:
+                x = batch[0]
+                y = batch[1]
+
+
+
                 optimizer.zero_grad()
 
                 x, y = self._prepare_data(x, y)
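
For context, the patch turns off the loader-level shuffle and instead re-shuffles the order of the pre-built (x, y) batches at the start of every epoch inside _train(). The sketch below illustrates that same idea on its own; it is not code from this repository, and the shuffled_batches helper and the dummy batch shapes are illustrative assumptions only. It also keeps the batches in a Python list rather than packing them into a NumPy array of tuples, which avoids object-array issues when the last batch has a different size.

import numpy as np

def shuffled_batches(batch_iterator, rng=None):
    # Materialize the iterator of (x, y) batch pairs, then yield them
    # in a fresh random order. Call once per epoch to get a new order.
    rng = rng or np.random.default_rng()
    batches = list(batch_iterator)
    order = rng.permutation(len(batches))
    return [batches[i] for i in order]

# Example usage with arbitrary dummy batches standing in for the pairs
# yielded by train_loader.get_iterator().
dummy = [(np.zeros((8, 12, 5, 2)), np.zeros((8, 12, 5, 2))) for _ in range(3)]
for epoch in range(2):
    for x, y in shuffled_batches(dummy):
        pass  # forward / backward pass would go here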