-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconfig.py
67 lines (57 loc) · 1.32 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from time import strftime, gmtime
# --- Hyperparameters for the input. ---
dyn = 'spring'
size = 10
# 3 for kuramoto and 4 for spring and charged.
dim = 4
edge_type = 2

# --- Hyperparameters for the model. ---
enc = 'GNNENC'
dec = 'GNNDEC'
n_hid = 2 ** 8
input_emb_hid = 2 ** 7
# Dimension of hidden layers of attention mechanisms.
emb_hid = 2 ** 8
att_hid = 2 ** 8
skip = False
# Soft constraint for symmetry.
reg = 0.
no_reg = False
# Hard constraint for symmetry.
sym = False

# --- Hyperparameters for training. ---
# Scale the loss to avoid gradient explosion.
scale = 1e-5
epochs = 500
lr = 2.5e-4
# Presumably the LR-scheduler step size (in epochs), with ``gamma`` as the
# multiplicative decay factor -- confirm against the training loop.
lr_decay = 200
gamma = 0.5
batch_size = 2 ** 7
M = 10
gpu = True

# --- Hyperparameters for data generation. ---
base = 10 ** 4
train = 5 * base
test = base
val = base
timesteps = 50
train_steps = 49
test_steps = 99
temp = 0.5
# NOTE: 10 for kuramoto, 100 for spring and charged.
# NOTE(review): ``dyn`` is 'spring' above, yet this evaluates to 10 --
# confirm which value is intended.
interval = 10 ** 1
samples = 10 ** 2

# --- Others ---
# Log file path: logs/<dyn>_<size>_<UTC timestamp>/<enc>_<dec>.txt.
# The timestamp is rendered on its own and spliced in afterwards, so a '%'
# in any of the other values can never be misread as a strftime directive.
log = 'logs/{}_{}_{}/{}_{}.txt'.format(
    dyn, size, strftime('%m-%d_%H:%M:%S', gmtime()), enc, dec)
# Run n rounds of the code.
rounds = 1
def init_args(args):
    """Overwrite module-level settings with values supplied by the caller.

    Copies ``enc``, ``dec``, ``dyn``, ``size``, ``sym`` and ``epochs`` from
    *args* onto this module's globals, then rebuilds ``log`` so the log path
    reflects the new settings and the current UTC time.

    Args:
        args: Any object exposing the attributes ``enc``, ``dec``, ``dyn``,
            ``size``, ``sym`` and ``epochs`` (presumably an
            ``argparse.Namespace`` -- confirm against the caller).

    Raises:
        AttributeError: If *args* lacks one of the attributes listed above.
    """
    global enc, dec, size, dyn, log, sym, epochs
    enc = args.enc
    dec = args.dec
    dyn = args.dyn
    size = args.size
    sym = args.sym
    epochs = args.epochs
    # NOTE(review): the module-level comments describe ``dim`` and
    # ``interval`` as depending on ``dyn``, but neither is recomputed
    # here -- confirm callers account for that.
    # Render the timestamp separately and splice it in afterwards: putting
    # caller-supplied values into the strftime pattern first would let any
    # '%' in them be expanded as a date/time directive.
    stamp = strftime('%m-%d_%H:%M:%S', gmtime())
    log = 'logs/{}_{}_{}/{}_{}.txt'.format(dyn, size, stamp, enc, dec)