diff --git a/trainer/dataset.py b/trainer/dataset.py index a6485eea3a9..958d3c9fe42 100644 --- a/trainer/dataset.py +++ b/trainer/dataset.py @@ -98,12 +98,12 @@ def get_batch(self): for i, data_loader_iter in enumerate(self.dataloader_iter_list): try: - image, text = data_loader_iter.next() + image, text = next(data_loader_iter) balanced_batch_images.append(image) balanced_batch_texts += text except StopIteration: self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) - image, text = self.dataloader_iter_list[i].next() + image, text = next(self.dataloader_iter_list[i]) balanced_batch_images.append(image) balanced_batch_texts += text except ValueError: diff --git a/trainer/train.py b/trainer/train.py index e0066f3d078..d2a51289f06 100644 --- a/trainer/train.py +++ b/trainer/train.py @@ -14,7 +14,7 @@ from utils import CTCLabelConverter, AttnLabelConverter, Averager from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset from model import Model -from test import validation +from validate import validation device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def count_parameters(model): diff --git a/trainer/test.py b/trainer/validate.py similarity index 100% rename from trainer/test.py rename to trainer/validate.py