-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
35 lines (28 loc) · 1.01 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import os

import torch
from torchvision import transforms, datasets, models
from torchvision.utils import save_image

from model import Encoder, Decoder
# Build the encoder/decoder pair and restore their pretrained weights.
# NOTE(review): input_nc / output_nc are presumably channel counts (1 would fit
# grayscale MNIST) -- confirm against the constructors in model.py.
input_nc = 1
output_nc = 1
enc = Encoder(input_nc, output_nc)
dec = Decoder(input_nc, output_nc)

# map_location="cpu" lets checkpoints that were saved from a CUDA device load on
# a CPU-only machine; the models are never moved off the CPU in this script.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
enc.load_state_dict(torch.load("pretrained/enc.pth", map_location="cpu"))
dec.load_state_dict(torch.load("pretrained/dec.pth", map_location="cpu"))

# Switch both modules to evaluation mode before running inference.
enc.eval()
dec.eval()
# Preprocessing pipeline: convert each sample to a tensor, then normalize it
# with mean 0.5 and std 0.5.
_preprocess_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(_preprocess_steps)

# MNIST split selected by train=False, stored under ./data (download=True allows
# torchvision to fetch it).
dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Shuffled loader yielding batches of 96 samples.
dataset_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=96,
    shuffle=True,
)
# Push one batch through the encoder and decoder and save the input, the
# encoded representation, and the reconstruction as image grids.

# save_image does not create missing directories, so make sure the target exists.
os.makedirs("output", exist_ok=True)

# Only the first batch is needed; the loop exits after one iteration.
for image, _ in dataset_loader:
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        encoded = enc(image)
        decoded = dec(encoded)

    # Lay each sample's code out as a 1x4x8 image so it can be saved as a grid.
    # The batch dimension is taken from the input instead of being hard-coded,
    # so this keeps working if the loader's batch size changes.
    encoded = encoded.view(image.size(0), 1, 4, 8)

    save_image(image, "output/A_orig.png")
    save_image(encoded, "output/A_encoded.png")
    save_image(decoded, "output/A_decoded100.png")
    break