-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
44 lines (29 loc) · 1009 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from options import Options
from model.learner import Decoder
import torch
from torchvision.utils import save_image
import os
def denorm(x):
    """Rescale values from the [-1, 1] range to [0, 1].

    Computes ``(x + 1) / 2`` and clamps the result to ``[0, 1]`` so that
    anything outside the expected input range is saturated. The input is not
    modified: the arithmetic produces a fresh tensor, and only that tensor
    is clamped in place.

    NOTE(review): presumably the decoder emits values in [-1, 1] (e.g. a tanh
    output layer) — confirm against model.learner.Decoder.

    Args:
        x: tensor of values, nominally in [-1, 1].

    Returns:
        A new tensor with values in [0, 1].
    """
    shifted = x + 1
    rescaled = shifted / 2
    rescaled.clamp_(0, 1)
    return rescaled
if __name__ == '__main__':
    # Build the runtime options, then force inference (non-training) mode.
    opt = Options()
    opt.build()
    opt.isTrain = False

    # Output directory for generated samples. makedirs(exist_ok=True) avoids
    # the isdir/mkdir check-then-create race and also creates missing parents.
    save_root = 'results/'
    os.makedirs(save_root, exist_ok=True)

    device = torch.device("cpu")
    path_to_chkpt = 'checkpoints/mt_vae_results/FineTuning_chkpt.tar'

    netDec = Decoder(latent_dim=opt.latent_dim)
    # Load the fine-tuned decoder weights when the checkpoint exists; otherwise
    # keep the randomly initialised decoder (the previous behaviour, where the
    # loading code was commented out and path_to_chkpt was never used).
    # torch.load unpickles the file, so only point this at trusted checkpoints.
    if os.path.isfile(path_to_chkpt):
        print('Loading model ...')
        state_dict = torch.load(path_to_chkpt, map_location='cpu')
        netDec.load_state_dict(state_dict['Dec_state_dict'])
        print('...Done loading model')
    else:
        print(f'Checkpoint not found at {path_to_chkpt}; '
              'sampling from an untrained decoder')
    netDec.to(device)
    netDec.eval()

    batch_size = 5
    # Sample latent codes from a standard normal distribution, on the same
    # device as the decoder.
    q_z = torch.empty(batch_size, opt.latent_dim).normal_(mean=0, std=1).to(device)

    # Pure inference: no_grad skips building the autograd graph, so the output
    # can be moved and saved directly without reaching for `.data`.
    with torch.no_grad():
        p_x = netDec(q_z)

    xpath = os.path.join(save_root, 'tmp_img.jpg')
    save_image(denorm(p_x.cpu()), xpath)