-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_utils.py
98 lines (75 loc) · 3.6 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from os import listdir
from os.path import basename, join

from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
def is_image_file(filename):
    """Return True if *filename* ends in a recognised image extension.

    The comparison is case-insensitive, so mixed-case suffixes such as
    ``.Jpg`` or ``.PnG`` are accepted in addition to the all-lower and
    all-upper spellings.

    Args:
        filename: A file name or path; only its suffix is inspected.

    Returns:
        ``True`` when the name ends in ``.png``, ``.jpg`` or ``.jpeg``
        (any letter case), otherwise ``False``.
    """
    # str.endswith accepts a tuple, so a single call checks every suffix.
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round *crop_size* down to the nearest multiple of *upscale_factor*.

    This guarantees that ``result // upscale_factor`` divides back into
    ``result`` exactly, so HR and LR patch sizes stay consistent.

    Args:
        crop_size: The requested crop size.
        upscale_factor: The super-resolution scale factor.

    Returns:
        The largest multiple of ``upscale_factor`` that is ``<= crop_size``
        (for non-negative integer inputs).
    """
    excess = crop_size % upscale_factor
    valid_size = crop_size - excess
    return valid_size
def train_hr_transform(crop_size):
    """Build the high-resolution training transform.

    The returned pipeline takes a random ``crop_size`` crop of a PIL image
    and converts the crop to a tensor with ``ToTensor``.

    Args:
        crop_size: Size passed to ``RandomCrop``.

    Returns:
        A ``Compose`` of ``RandomCrop(crop_size)`` followed by ``ToTensor()``.
    """
    pipeline = [RandomCrop(crop_size), ToTensor()]
    return Compose(pipeline)
def train_lr_transform(crop_size, upscale_factor):
    """Build the transform that derives a low-resolution input from an HR tensor.

    The pipeline converts the tensor back to a PIL image, resizes it with
    ``Resize(crop_size // upscale_factor)`` using bicubic interpolation, and
    converts the result to a tensor again.

    Args:
        crop_size: HR patch size (expected to be a multiple of
            ``upscale_factor``; see ``calculate_valid_crop_size``).
        upscale_factor: Ratio between HR and LR sizes.

    Returns:
        A ``Compose`` of ``ToPILImage``, bicubic ``Resize`` and ``ToTensor``.
    """
    lr_size = crop_size // upscale_factor
    to_pil = ToPILImage()
    downscale = Resize(lr_size, interpolation=Image.BICUBIC)
    to_tensor = ToTensor()
    return Compose([to_pil, downscale, to_tensor])
def display_transform():
    """Return the transform used to prepare image tensors for display.

    A tensor is converted to a PIL image, passed through ``Resize(400)``,
    center-cropped with ``CenterCrop(400)``, and converted back to a tensor.

    Returns:
        A ``Compose`` of ``ToPILImage``, ``Resize``, ``CenterCrop`` and
        ``ToTensor``.
    """
    side = 400
    steps = [
        ToPILImage(),
        Resize(side),
        CenterCrop(side),
        ToTensor(),
    ]
    return Compose(steps)
class TrainDatasetFromFolder(Dataset):
    """Training set of ``(lr_image, hr_image)`` tensor pairs built from a folder.

    Every file directly inside ``dataset_dir`` accepted by ``is_image_file``
    is one sample. For each sample, a random HR patch is cropped and an LR
    copy is derived from it by bicubic downscaling.

    Args:
        dataset_dir: Directory containing the training images.
        crop_size: Requested HR patch size; rounded down to a multiple of
            ``upscale_factor`` via ``calculate_valid_crop_size``.
        upscale_factor: Ratio between HR and LR patch sizes.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.image_filenames = sorted(
            join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)
        )
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        """Return the ``(lr_image, hr_image)`` tensors for sample *index*."""
        # Convert to 3-channel RGB: non-RGB modes (e.g. L, P, RGBA) would
        # otherwise give tensors with a different channel count that cannot
        # be batched with RGB samples. The ``with`` block closes the file
        # handle that Image.open keeps open for lazy loading.
        with Image.open(self.image_filenames[index]) as image:
            hr_image = self.hr_transform(image.convert('RGB'))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Return the number of image files found in ``dataset_dir``."""
        return len(self.image_filenames)
class ValDatasetFromFolder(Dataset):
    """Validation set yielding ``(lr_image, hr_restore_img, hr_image)`` tensors.

    Every file directly inside ``dataset_dir`` accepted by ``is_image_file``
    is one sample. Whole images are used (center-cropped), not random patches.

    Args:
        dataset_dir: Directory containing the validation images.
        upscale_factor: Ratio between HR and LR sizes.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.image_filenames = sorted(
            join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)
        )

    def __getitem__(self, index):
        """Return ``(lr_image, hr_restore_img, hr_image)`` tensors for *index*.

        The HR image is center-cropped to a square whose side is
        ``min(w, h)`` rounded down to a multiple of ``upscale_factor``.
        The LR image is a bicubic downscale of that crop, and
        ``hr_restore_img`` is the LR image bicubically upscaled back to the
        crop size.
        """
        # Convert to RGB so every sample has 3 channels; the ``with`` block
        # closes the file handle Image.open keeps open for lazy loading.
        with Image.open(self.image_filenames[index]) as image:
            hr_image = image.convert('RGB')
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        """Return the number of image files found in ``dataset_dir``."""
        return len(self.image_filenames)
class TestDatasetFromFolder(Dataset):
    """Paired LR/HR test images read from ``<dataset_dir>/SRF_<k>/{data,target}/``.

    LR inputs are read from ``.../data/`` and HR targets from ``.../target/``,
    where ``k`` is ``upscale_factor``. Files are paired by position after
    sorting each directory listing.

    Args:
        dataset_dir: Root directory of the test set.
        upscale_factor: Scale factor; selects the ``SRF_<k>`` subdirectory.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        # os.listdir returns entries in arbitrary order, so the unsorted lists
        # could pair an LR file with the wrong HR target. Sorting both keeps
        # the pairing consistent. NOTE(review): assumes corresponding LR/HR
        # files sort into the same order (identical or parallel names).
        self.lr_filenames = sorted(
            join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)
        )
        self.hr_filenames = sorted(
            join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)
        )

    def __getitem__(self, index):
        """Return ``(image_name, lr_image, hr_restore_img, hr_image)`` for *index*.

        ``image_name`` is the LR file's base name. ``hr_restore_img`` is the
        LR image bicubically resized to ``upscale_factor`` times its height
        and width.
        """
        image_name = basename(self.lr_filenames[index])
        # Convert to RGB so all tensors have 3 channels; ``with`` closes the
        # file handles Image.open keeps open for lazy loading.
        with Image.open(self.lr_filenames[index]) as image:
            lr_image = image.convert('RGB')
        with Image.open(self.hr_filenames[index]) as image:
            hr_image = image.convert('RGB')
        w, h = lr_image.size
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        """Return the number of LR image files."""
        return len(self.lr_filenames)