forked from ChrisWu1997/2D-Motion-Retargeting
-
Notifications
You must be signed in to change notification settings - Fork 0
/
common.py
97 lines (79 loc) · 3.24 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
from functional import utils
import torch
import numpy as np
class Config:
    """Hyper-parameter and path configuration for motion retargeting.

    Class attributes hold the defaults. ``initialize`` overrides a subset of
    them from parsed command-line arguments and derives:

    - the experiment directories;
    - the torch device;
    - the mean/std pose statistics paths;
    - the per-encoder channel lists for the selected network variant (``name``).
    """

    name = None    # network variant / experiment name; set by initialize()
    device = None  # torch.device; set by initialize()

    # data paths
    data_dir = './mixamo_data'
    meanpose_path = None  # set by initialize() according to the variant
    stdpose_path = None   # set by initialize() according to the variant

    # training paths
    save_dir = './train_log'
    exp_dir = None    # <save_dir>/exp_<name>
    log_dir = None    # <exp_dir>/log/
    model_dir = None  # <exp_dir>/model/

    # data info
    img_size = (512, 512)
    unit = 128  # NOTE(review): meaning of this constant is not visible in this file — confirm
    nr_joints = 15
    # 2 * (nr_joints - 1) == 28; presumably 2D coordinates for every joint
    # except the root — confirm against the data loader.
    len_joints = 2 * nr_joints - 2
    # Rotation triples from -pi/2 to pi/2 in pi/6 steps, only the last
    # component varying. Cleared to None for the 'skeleton' variant.
    view_angles = [(0, 0, -np.pi / 2),
                   (0, 0, -np.pi / 3),
                   (0, 0, -np.pi / 6),
                   (0, 0, 0),
                   (0, 0, np.pi / 6),
                   (0, 0, np.pi / 3),
                   (0, 0, np.pi / 2)]

    # network channels (filled in by initialize() for the chosen variant)
    mot_en_channels = None
    body_en_channels = None
    view_en_channels = None
    de_channels = None

    # training settings
    use_triplet = True
    triplet_margin = 1
    triplet_weight = 1
    use_footvel_loss = False
    # Presumably indices into the flattened joint-coordinate vector selecting
    # the foot joints used by the foot-velocity loss — confirm.
    foot_idx = [20, 21, 26, 27]
    footvel_loss_weight = 0.1
    nr_epochs = 300
    batch_size = 64
    num_workers = 4
    lr = 1e-3
    save_frequency = 50
    val_frequency = 100
    visualize_frequency = 500

    def initialize(self, args):
        """Populate run-specific settings from parsed command-line arguments.

        Args:
            args: namespace-like object (e.g. ``argparse.Namespace``). Each of
                ``name``, ``disable_triplet``, ``use_footvel_loss`` and
                ``gpu_ids`` is read only when present. A missing or ``None``
                ``name`` falls back to ``'full'``. Missing flags keep the
                current (class-default) value.

        Side effects:
            - Sets ``CUDA_VISIBLE_DEVICES`` in the process environment when
              ``args.gpu_ids`` is present.
            - Calls ``utils.ensure_dirs`` on ``log_dir`` and ``model_dir``,
              which presumably creates them on disk.
        """
        name = getattr(args, 'name', None)
        self.name = 'full' if name is None else name
        # Keep the class defaults (True / False) when a flag is absent instead
        # of overwriting them with None, which would silently read as False.
        if hasattr(args, 'disable_triplet'):
            self.use_triplet = not args.disable_triplet
        self.use_footvel_loss = getattr(args, 'use_footvel_loss', self.use_footvel_loss)

        if hasattr(args, 'gpu_ids'):
            # CUDA consults this variable when it initialises in the process,
            # so it only takes effect if CUDA has not been used yet.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)
        # "cuda:0" is the first device that remains visible after the mask above.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.exp_dir = os.path.join(self.save_dir, 'exp_' + self.name)
        self.log_dir = os.path.join(self.exp_dir, 'log/')
        self.model_dir = os.path.join(self.exp_dir, 'model/')
        utils.ensure_dirs([self.log_dir, self.model_dir])

        self._configure_network()

    def _configure_network(self):
        """Set channel lists and pose-statistics paths for ``self.name``.

        Variants:
            'skeleton': motion + body encoders; ``view_angles`` is cleared and
                the plain ``meanpose.npy`` / ``stdpose.npy`` statistics are used.
            'view': motion + view encoders; ``*_with_view.npy`` statistics.
            anything else (default 'full'): motion + body + view encoders;
                ``*_with_view.npy`` statistics.
        """
        # Shared by every variant: the motion encoder takes len_joints + 2
        # input channels and the decoder reconstructs the same width.
        motion_dim = self.len_joints + 2
        self.mot_en_channels = [motion_dim, 64, 96, 128]

        if self.name == 'skeleton':
            self.body_en_channels = [self.len_joints, 32, 48, 64]
            latent_dim = self.mot_en_channels[-1] + self.body_en_channels[-1]
            self.view_angles = None
            stats_suffix = ''
        elif self.name == 'view':
            self.view_en_channels = [self.len_joints, 64, 96, 128, 32]
            latent_dim = self.mot_en_channels[-1] + self.view_en_channels[-1]
            stats_suffix = '_with_view'
        else:
            self.body_en_channels = [self.len_joints, 32, 48, 64, 16]
            self.view_en_channels = [self.len_joints, 32, 48, 64, 8]
            latent_dim = (self.mot_en_channels[-1]
                          + self.body_en_channels[-1]
                          + self.view_en_channels[-1])
            stats_suffix = '_with_view'

        # Decoder input is the concatenation of the encoder outputs.
        self.de_channels = [latent_dim, 128, 64, motion_dim]
        # Built from data_dir so overriding it also relocates the statistics files.
        self.meanpose_path = os.path.join(self.data_dir, 'meanpose' + stats_suffix + '.npy')
        self.stdpose_path = os.path.join(self.data_dir, 'stdpose' + stats_suffix + '.npy')
# Module-level instance, presumably imported by the training / inference
# scripts. It carries only the class defaults until
# ``config.initialize(args)`` is called.
config: Config = Config()