-
Notifications
You must be signed in to change notification settings - Fork 8
/
loader_checkpoint.py
33 lines (26 loc) · 1.28 KB
/
loader_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torchvision
from generator import GeneratorResnet
import pandas as pd
def load_gan(args, domain, checkpoint_dir='saved_models', map_location='cpu'):
    """Build a ResNet generator and load a trained checkpoint into it.

    The checkpoint is read from
    ``<checkpoint_dir>/<args.model_type>/netG_<suffix>_<args.epochs>.pth``.
    ``<suffix>`` is ``'BIA'``, extended with ``'+RN'`` when ``args.RN`` is
    truthy and then ``'+DA'`` when ``args.DA`` is truthy. The possible
    suffixes are therefore ``'BIA'``, ``'BIA+RN'``, ``'BIA+DA'`` and
    ``'BIA+RN+DA'``.

    Args:
        args: Parsed options object. Only ``model_type``, ``RN``, ``DA`` and
            ``epochs`` are read here.
        domain: Model name string. When it ends with ``'incv3'``, the
            generator is constructed with ``inception=True``. Presumably this
            names the model the generator targets; confirm against callers.
        checkpoint_dir: Root directory containing the saved generators.
            Defaults to ``'saved_models'``, resolved against the current
            working directory, which matches the previously hard-coded path.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so a
            checkpoint saved on a GPU can be loaded on a CPU-only machine.
            The returned model's device does not depend on it, because
            ``load_state_dict`` copies the tensors into ``netG``'s existing,
            freshly constructed parameters.

    Returns:
        The ``GeneratorResnet`` instance with the checkpoint's state dict
        loaded. No ``.eval()`` call or device move is performed here.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        RuntimeError: If the checkpoint's keys or shapes do not match the
            constructed generator (``load_state_dict`` is strict by default).
    """
    # Inception-v3 targets need the generator built with inception=True
    # (presumably to match that network's input resolution; confirm in
    # generator.py). The no-argument call is kept as-is because the
    # constructor's default is not visible here.
    if domain.endswith('incv3'):
        netG = GeneratorResnet(inception=True)
    else:
        netG = GeneratorResnet()

    # The checkpoint name encodes which training options were enabled.
    # The order is always RN before DA, matching the files written at
    # training time.
    save_checkpoint_suffix = 'BIA'
    if args.RN:
        save_checkpoint_suffix += '+RN'
    if args.DA:
        save_checkpoint_suffix += '+DA'

    # NOTE: despite the "Saving instance" wording, this reports the epoch
    # number of the checkpoint being *loaded*.
    print('Substitute Model: {} \t RN: {} \t DA: {} \tSaving instance: {}'.format(args.model_type,
                                                                                args.RN,
                                                                                args.DA,
                                                                                args.epochs))

    checkpoint_path = '{}/{}/netG_{}_{}.pth'.format(checkpoint_dir,
                                                    args.model_type,
                                                    save_checkpoint_suffix,
                                                    args.epochs)
    # map_location avoids a hard failure when a GPU-saved checkpoint is
    # opened on a machine without CUDA.
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    netG.load_state_dict(state_dict)
    return netG