-
Notifications
You must be signed in to change notification settings - Fork 40
/
main.py
109 lines (95 loc) · 4.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import numpy as np
import os
import copy
from demos import demo_full
from lib import models, mesh_sampling
from lib.load_data import BodyData, load_graph_mtx
from config_parser import parse_config
from psbody.mesh import Mesh
# Parse the run configuration. Below, `args` is read by attribute and
# `args_dict` by key.
args, args_dict = parse_config()

# Seed NumPy's global random generator with the configured seed.
np.random.seed(args_dict['seed'])

# Every data path is resolved relative to the directory holding this script.
script_path = os.path.realpath(__file__)
project_dir = os.path.dirname(script_path)
datadir_root = os.path.join(project_dir, 'data', 'datasets')
data_dir = os.path.join(datadir_root, args.dataset)

# Template mesh: used for the pooling matrices and handed to the data loader.
reference_mesh_file = os.path.join(project_dir, 'data/template_mesh.obj')
reference_mesh = Mesh(filename=reference_mesh_file)
# Only the 'train' and 'test' modes load the dataset; in any other mode
# `bodydata` is never defined.
if args.mode in ('train', 'test'):
    print("Loading data from {} ..".format(data_dir))

    # File-name suffixes of the two condition arrays.
    cond1_name = args.pose_type
    cond2_name = 'clo_label'

    bodydata = BodyData(
        nVal=100,
        train_mesh_fn=data_dir + '/train/train_disp.npy',
        train_cond1_fn=data_dir + '/train/train_{}.npy'.format(cond1_name),
        train_cond2_fn=data_dir + '/train/train_{}.npy'.format(cond2_name),
        test_mesh_fn=data_dir + '/test/test_disp.npy',
        test_cond1_fn=data_dir + '/test/test_{}.npy'.format(cond1_name),
        test_cond2_fn=data_dir + '/test/test_{}.npy'.format(cond2_name),
        reference_mesh_file=reference_mesh_file,
    )
# One down-sampling factor per conv layer. Only 4, 6 and 8 layers are
# supported. Any other value is rejected here, before `ds_factors` is read,
# rather than failing with a NameError on the line that consumes it.
if args.num_conv_layers == 4:
    ds_factors = [1, args.ds_factor, 1, 1]
elif args.num_conv_layers == 6:
    ds_factors = [1, args.ds_factor, 1, args.ds_factor, 1, 1]
elif args.num_conv_layers == 8:
    ds_factors = [1, args.ds_factor, 1, args.ds_factor, 1, args.ds_factor, 1, 1]
else:
    raise NotImplementedError(
        "num_conv_layers must be 4, 6 or 8, got {}".format(args.num_conv_layers))

print("Pre-computing mesh pooling matrices ..")
# `M` and the fifth return value are not used anywhere else in this script.
M, A, D, U, _ = mesh_sampling.generate_transform_matrices(reference_mesh, ds_factors)

# First dimension of every matrix in A, recorded before the dtype cast.
p = [a.shape[0] for a in A]

# Cast all matrices to float32.
A = [a.astype('float32') for a in A]
D = [d.astype('float32') for d in D]
U = [u.astype('float32') for u in U]

# One normalized Laplacian per matrix in A.
L = [mesh_sampling.laplacian(a, normalized=True) for a in A]

# load pre-computed graph laplacian and pooling matrices for discriminator
L_ds2, D_ds2, U_ds2 = load_graph_mtx(project_dir)
# pass params and build model
# Work on a deep copy so the edits below never touch `args_dict` itself.
params = copy.deepcopy(args_dict)

# Flag-like options are coerced to real booleans.
params['restart'] = bool(args.restart)
params['use_res_block'] = bool(args.use_res_block)
params['use_res_block_dec'] = bool(args.use_res_block_dec)
params['nn_input_channel'] = 3

# Per-layer filter counts, keyed by the number of conv layers.
nf = args.nf
filters_by_depth = {
    4: [nf, 2*nf, 2*nf, nf],
    6: [nf, nf, 2*nf, 2*nf, 4*nf, 4*nf],
    8: [nf, nf, 2*nf, 2*nf, 4*nf, 4*nf, 8*nf, 8*nf],
}
if args.num_conv_layers not in filters_by_depth:
    raise NotImplementedError
params['F'] = filters_by_depth[args.num_conv_layers]

params.update({
    'K': [2] * args.num_conv_layers,
    'Kd': args.Kd,  # Chebyshev Polynomial orders.
    'p': p,
})

# The decay schedule depends on the training-set size, which is only
# available in 'train' mode; every other mode uses 1.
if args.mode == 'train':
    params['decay_steps'] = args.decay_every * len(bodydata.vertices_train) / params['batch_size']
else:
    params['decay_steps'] = 1

params.update({
    'cond_dim': 14*9,  # 14 clothing-related joints * 9 elements per rot matrix
    'cond2_dim': 4,
    'n_layer_cond': args.n_layer_cond,
    'cond_encoder': bool(args.cond_encoder),
    'reduce_dim': args.reduce_dim,
    'affine': bool(args.affine),
    'optimizer': args.optimizer,
    'lr_warmup': bool(args.lr_warmup),
    'optim_condnet': bool(args.optim_condnet),
})

# Remove the options listed here from `params` (which is later expanded into
# keyword arguments for the model). A missing key raises KeyError.
non_model_params = ['demo_n_sample', 'mode', 'dataset', 'num_conv_layers', 'ds_factor',
                    'nf', 'config', 'pose_type', 'decay_every', 'gender',
                    'save_obj', 'vis_demo', 'smpl_model_folder']
for key in non_model_params:
    del params[key]
print("Building model graph...")
model = models.CAPE(L=L, D=D, U=U, L_d=L_ds2, D_d=D_ds2, **params)

# start train or test/demo
# Only 'train' mode builds the training graph and fits the model first.
if args.mode == 'train':
    model.build_graph(model.input_num_verts, model.nn_input_channel, phase='train')
    loss, t_step = model.fit(bodydata)

# Every mode then builds the graph in the 'demo' phase and creates the demo
# driver (in 'train' mode this is the full test pipeline after training).
model.build_graph(model.input_num_verts, model.nn_input_channel, phase='demo')
demos = demo_full(model, args.name, args.gender, args.dataset, data_dir, datadir_root,
                  n_sample=args.demo_n_sample, save_obj=bool(args.save_obj), random_seed=args.seed,
                  vis=bool(args.vis_demo), smpl_model_folder=args.smpl_model_folder)

# train -> test_model then run; test -> test_model only; any other mode -> run only.
if args.mode in ('train', 'test'):
    demos.test_model(bodydata)
if args.mode != 'test':
    demos.run()