-
Notifications
You must be signed in to change notification settings - Fork 15
/
test_data.py
33 lines (28 loc) · 1.1 KB
/
test_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
from PIL import Image
import torchvision.transforms as transforms
class test_dataset:
    """Sequentially load (image, ground-truth) pairs of ``.png`` files.

    The sample list is taken from the ``.png`` files in ``gt_root``. For each
    sample name, ``<image_root>/<name>.png`` and ``<gt_root>/<name>.png`` are
    both loaded and converted to single-channel PIL images (mode ``'L'``).

    NOTE(review): ``image_root`` is loaded with the same grayscale loader as the
    masks, so it presumably holds predicted maps being evaluated -- confirm
    against callers.
    """

    def __init__(self, image_root, gt_root):
        """Index the ``.png`` files found in ``gt_root``.

        Args:
            image_root: directory expected to contain one ``<name>.png`` for
                every ``.png`` file in ``gt_root``.
            gt_root: directory of ground-truth ``.png`` files; it defines the
                set of sample names.
        """
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so repeated runs visit the samples in the same order.
        self.img_list = sorted(
            os.path.splitext(f)[0] for f in os.listdir(gt_root) if f.endswith('.png')
        )
        self.image_root = image_root
        self.gt_root = gt_root
        # These transforms are not applied by load_data(); they are kept as
        # attributes, presumably for callers to use on the returned images.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.img_list)
        self.index = 0  # cursor into img_list, advanced by load_data()

    def load_data(self):
        """Load the next (image, gt) pair and advance the cursor.

        Returns:
            A tuple ``(image, gt)`` of PIL images, both converted to mode ``'L'``.

        Raises:
            IndexError: when called more than ``self.size`` times; the cursor
                does not wrap around.
        """
        name = self.img_list[self.index]
        image = self.binary_loader(os.path.join(self.image_root, name + '.png'))
        gt = self.binary_loader(os.path.join(self.gt_root, name + '.png'))
        self.index += 1
        return image, gt

    def rgb_loader(self, path):
        """Open the image at ``path`` and return it converted to RGB."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to be read while f is still open.
            return img.convert('RGB')

    def binary_loader(self, path):
        """Open the image at ``path`` and return it converted to grayscale (mode 'L')."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            # convert() forces the pixel data to be read while f is still open.
            return img.convert('L')