train_dataset.py
import os
import random

import cv2
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


def center_crop(image):
    """Crop a 256x256 patch from the centre of an HWC image.

    Images with either side shorter than 512 px are resized to 256x256 instead.
    """
    center = image.shape[0] / 2, image.shape[1] / 2
    if center[1] < 256 or center[0] < 256:
        return cv2.resize(image, (256, 256))
    x = center[1] - 128
    y = center[0] - 128
    return image[int(y):int(y + 256), int(x):int(x + 256)]
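
# Illustrative behaviour check (hypothetical arrays, not part of the original file):
# inputs with a side shorter than 512 px take the resize branch, larger inputs take
# the centre-crop branch; both end up 256x256.
#     center_crop(np.zeros((300, 480, 3), np.float32)).shape   # -> (256, 256, 3) via resize
#     center_crop(np.zeros((720, 1280, 3), np.float32)).shape  # -> (256, 256, 3) via crop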
class MyCustomDataset(Dataset):
    """Dataset of ground-truth images loaded from a single directory."""

    def __init__(self, path_gt, device='cpu'):
        self._items = []   # full paths to the ground-truth images
        self._index = 0    # cursor used by next_data()
        self.device = device
        # Collect every file in the ground-truth directory in a deterministic
        # order, then shuffle once so next_data() iterates in random order.
        for img_name in sorted(os.listdir(path_gt)):
            self._items.append(os.path.join(path_gt, img_name))
        random.shuffle(self._items)

    def __len__(self):
        return len(self._items)

    def _load_image(self, gt_path):
        """Read an image, crop/resize it to 256x256 and return a CHW tensor in [0, 1]."""
        image = Image.open(gt_path).convert('RGB')
        image = np.array(image).astype(np.float32)
        image = center_crop(image)
        # ToTensor() does not rescale float32 arrays, so normalise explicitly.
        image = image / 255.
        image = transforms.ToTensor()(image)
        return image.to(self.device)

    def next_data(self):
        """Sequential iterator: return the next image, reshuffling after a full pass."""
        gt_path = self._items[self._index]
        self._index += 1
        if self._index == len(self._items):
            self._index = 0
            random.shuffle(self._items)
        return self._load_image(gt_path)

    def __getitem__(self, index):
        return self._load_image(self._items[index])
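

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): the folder "data/gt"
# and the DataLoader settings below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MyCustomDataset(path_gt="data/gt", device="cpu")  # hypothetical path
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    for batch in loader:
        # Each batch is a float tensor of shape (B, 3, 256, 256) with values in [0, 1].
        print(batch.shape)
        break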