-
Notifications
You must be signed in to change notification settings - Fork 0
/
还原图像.py
79 lines (65 loc) · 2.48 KB
/
还原图像.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch.utils.data as data
import numpy as np
import lmdb
import os
import io
from PIL import Image
def num_samples(dataset, train):
    """Return the hard-coded sample count for one split of a known dataset.

    Args:
        dataset: Dataset name ('celeba', 'celeba64', 'imagenet-oord' or 'ffhq').
        train: Truthy for the training split, falsy for the validation split.

    Returns:
        The number of samples in the requested split.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the names above.
    """
    # (name, train-split size, validation-split size), checked in order.
    split_table = (
        ('celeba', 27000, 3000),
        ('celeba64', 162770, 19867),
        ('imagenet-oord', 1281147, 50000),
        ('ffhq', 63000, 7000),
    )
    for known_name, train_count, valid_count in split_table:
        if dataset == known_name:
            return train_count if train else valid_count
    raise NotImplementedError('dataset %s is unknown' % dataset)
class LMDBDataset(data.Dataset):
    """Image dataset whose samples are stored in an LMDB database.

    Sample ``i`` is stored under the key ``str(i).encode()``. Each record is
    either an encoded image file (``is_encoded=True``, decoded with PIL and
    converted to RGB) or raw uint8 pixels of a square RGB image
    (``is_encoded=False``).

    Args:
        root: Directory holding ``train.lmdb`` and/or ``validation.lmdb``.
        name: Dataset name; only used by ``__len__`` (via ``num_samples``).
        train: Read ``train.lmdb`` if True, otherwise ``validation.lmdb``.
        transform: Optional callable applied to the PIL image.
        is_encoded: Whether records are encoded image files rather than raw pixels.
    """

    def __init__(self, root, name='', train=True, transform=None, is_encoded=False):
        self.train = train
        self.name = name
        self.transform = transform
        lmdb_path = os.path.join(root, 'train.lmdb' if self.train else 'validation.lmdb')
        # Read-only environment without locking; this class never writes to it.
        # NOTE(review): the environment is opened in the constructing process.
        # If the dataset is used from forked DataLoader workers, consider
        # opening it lazily per process (LMDB discourages using an env across
        # fork) -- confirm against how this dataset is consumed.
        self.data_lmdb = lmdb.open(lmdb_path, readonly=True, max_readers=1,
                                   lock=False, readahead=False, meminit=False)
        self.is_encoded = is_encoded

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        ``target`` is always the placeholder ``[0]``; no labels are read.

        Raises:
            IndexError: If the database has no record for ``index``.
        """
        target = [0]
        key = str(index).encode()
        # With buffers=True, txn.get returns a buffer that is only valid while
        # the transaction is open, so all decoding happens inside this block.
        with self.data_lmdb.begin(write=False, buffers=True) as txn:
            raw = txn.get(key)
            if raw is None:
                # Otherwise a missing key surfaces as an opaque PIL/numpy error;
                # IndexError is also what ends plain ``for x in dataset`` loops.
                raise IndexError('no record for index %r in LMDB' % (index,))
            if self.is_encoded:
                img = Image.open(io.BytesIO(raw)).convert('RGB')
            else:
                pixels = np.asarray(raw, dtype=np.uint8)
                # Raw records are assumed to be square RGB: side = sqrt(len / 3).
                side = int(np.sqrt(len(pixels) / 3))
                pixels = np.reshape(pixels, (side, side, 3))
                # A (H, W, 3) uint8 array is inferred as RGB; the explicit
                # ``mode=`` argument is deprecated in recent Pillow releases.
                img = Image.fromarray(pixels)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the sample count from the ``num_samples`` table (not read from LMDB)."""
        return num_samples(self.name, self.train)
def _data_transforms_celeba64(size):
    """Build the (train, validation) preprocessing pipelines for CelebA-64.

    Both pipelines crop with ``CropCelebA64``, resize to ``size`` and convert
    to a tensor. The training pipeline also applies a random horizontal flip
    between the resize and the tensor conversion.

    NOTE(review): ``transforms`` (presumably ``torchvision.transforms``) and
    ``CropCelebA64`` are neither imported nor defined in this file, so calling
    this raises NameError unless they are provided -- TODO add them.
    """
    def _pipeline(augment):
        # Construct the steps in the same order for both splits; only the
        # training split gets the flip.
        steps = [CropCelebA64(), transforms.Resize(size)]
        if augment:
            steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ToTensor())
        return transforms.Compose(steps)

    return _pipeline(True), _pipeline(False)
# Script section: build the CelebA-64 train/validation datasets from ./pic.
resize = 64
# NOTE(review): num_classes is not referenced anywhere else in this file.
num_classes = 40
train_transform, valid_transform = _data_transforms_celeba64(resize)
# Arguments shared by both splits; ./pic is expected to hold train.lmdb and
# validation.lmdb (see LMDBDataset).
_celeba64_lmdb_kwargs = dict(root='./pic', name='celeba64', is_encoded=True)
train_data = LMDBDataset(train=True, transform=train_transform, **_celeba64_lmdb_kwargs)
valid_data = LMDBDataset(train=False, transform=valid_transform, **_celeba64_lmdb_kwargs)