-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
58 lines (46 loc) · 1.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from config import params
from models.dcgan import Discriminator, Generator
from trainer import Trainer
from checkpoints import restore_model
from utils import init_random_seed, check_dirs, get_data_loader
def _build_models():
    """Build the generator and discriminator on the configured GPUs.

    Both networks are wrapped in ``nn.DataParallel`` over ``params.gpu_id`` and
    moved with ``.cuda()`` to the current CUDA device. When
    ``params.generator_restored`` / ``params.discriminator_restored`` is set,
    the corresponding network is passed to ``restore_model`` together with that
    value, and the result replaces it.

    Returns:
        tuple: ``(generator, discriminator)``.
    """
    # Construct both networks first and keep this order stable: module
    # construction typically draws from the seeded RNG for weight init.
    generator = Generator()
    discriminator = Discriminator()
    # .cuda() with no argument targets the current device (set by main() to
    # params.gpu_id[0]); DataParallel expects the wrapped module's parameters
    # on device_ids[0].
    generator = nn.DataParallel(generator, device_ids=params.gpu_id).cuda()
    discriminator = nn.DataParallel(discriminator, device_ids=params.gpu_id).cuda()

    # Optionally restore previously trained weights.
    if params.generator_restored:
        generator = restore_model(generator, params.generator_restored)
    if params.discriminator_restored:
        discriminator = restore_model(discriminator, params.discriminator_restored)
    return generator, discriminator


def _generate_images(trainer):
    """Generate ``params.num_images`` images with *trainer* into ``params.save_root``."""
    print('=== Generate {} images, saving in {} ==='.format(params.num_images, params.save_root))
    trainer.generate_images(params.num_images, params.save_root)


def main():
    """Train the DCGAN and/or generate images, as configured by ``config.params``.

    ``params.mode`` selects the workflow:

    * ``'train'`` -- load ``params.dataset``, train, then generate images.
    * ``'test'``  -- generate images only; requires ``params.generator_restored``.

    Raises:
        ValueError: if ``params.mode`` is neither ``'train'`` nor ``'test'``, or
            if ``'test'`` mode is requested without a restored generator.
    """
    # Validate the configuration before seeding, creating directories or
    # building models on the GPU. Explicit raises (rather than ``assert``)
    # keep these checks active under ``python -O``.
    if params.mode not in ('train', 'test'):
        raise ValueError("[*]mode must be 'train' or 'test'!")
    if params.mode == 'test' and not params.generator_restored:
        raise ValueError('[*]load Generator model first!')

    init_random_seed(params.manual_seed)
    # Check the directories required by the config.
    check_dirs()

    # Let cuDNN benchmark and cache the fastest convolution algorithms;
    # assumes input sizes stay fixed across iterations, where this pays off.
    cudnn.benchmark = True
    # Make the first configured GPU the current device so the models built
    # below land on it.
    torch.cuda.set_device(params.gpu_id[0])

    print('=== Build model ===')
    generator, discriminator = _build_models()
    # Container that drives training and image generation.
    trainer = Trainer(generator, discriminator)

    if params.mode == 'train':
        print('=== Load data ===')
        train_dataloader = get_data_loader(params.dataset)
        print('=== Begin training ===')
        trainer.train(train_dataloader)

    # Both modes end by generating images; test mode was validated above to
    # have a restored generator.
    _generate_images(trainer)
if __name__ == '__main__':
    # Script entry point: run only when executed directly, never on import.
    main()