-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathconfig.py
71 lines (57 loc) · 3.01 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import tensorflow as tf
import numpy as np
parser = argparse.ArgumentParser()
# For margin loss
parser.add_argument('--m_plus', default=0.9, help='m+ parameter')
parser.add_argument('--m_minus', default=0.1, help='m- parameter')
parser.add_argument('--lambda_val', default=0.5, help='Down-weighting parameter for the absent class')
# For reconstruction loss
parser.add_argument('--alpha', default=0.0005, help='Regularization coefficient to scale down the reconstruction loss')
# For training
parser.add_argument('--mode', default='train', help='train, test, visualize, or adv_attack')
parser.add_argument('--batch_size', default=128, help='Batch size')
parser.add_argument('--epoch', default=50, help='Total number of training epochs')
parser.add_argument('--iter_routing', default=3, help='Number of routing iterations')
parser.add_argument('--stddev', default=0.01, help='std for W initializer')
# Data set info.
parser.add_argument('--dataset', default='mnist', help='dataset name, mnist or fashion-mnist')
parser.add_argument('--n_cls', default=10, help='Total number of classes')
parser.add_argument('--img_w', default=28)
parser.add_argument('--img_h', default=28)
parser.add_argument('--n_ch', default=1, help='Number of input image channels')
# Environment and result saving setting
parser.add_argument('--restore_training', default=False, help='Restores the last trained model to resume training')
parser.add_argument('--checkpoint_path', default='./saved_model/', help='path for saving the model checkpoints')
parser.add_argument('--log_dir', default='./log_dir/', help='logs directory (to save graph and summaries)')
parser.add_argument('--results', default='./results/', help='path for saving the results')
parser.add_argument('--tr_disp_sum', default=100, help='The frequency of saving train results (step)')
# Visualize mode parameters
parser.add_argument('--n_samples', default=5, help='Number of sample images to be saved in visualize mode')
# Adversarial attack mode parameters (using the trained model)
parser.add_argument('--max_iter', default=3, help='Number of iterations for basic iteration adversarial attack')
parser.add_argument('--max_eps', default=np.array(range(0, 100, 5)) / 100.,
help='Maximum epsilon values to be used in FGSM adversarial attack mode')
args = parser.parse_args()
# Model hyper-parameters. Values taken from the command line arrive as strings
# unless argparse declares a ``type=``, so the image geometry and class count
# are coerced to int here before being used in arithmetic.
_img_h = int(args.img_h)
_img_w = int(args.img_w)


def _valid_conv_out(size, kernel_size, strides):
    """Return the output length along one axis of a 'valid'-padded convolution.

    With no padding, a window of width ``kernel_size`` moved in steps of
    ``strides`` fits ``(size - kernel_size) // strides + 1`` times.
    """
    return (size - kernel_size) // strides + 1


# Parameters of Conv1_layer
conv1_params = {"filters": 256,
                "kernel_size": 9,
                "strides": 1,
                "padding": "valid",
                "activation": tf.nn.relu}
# Parameters of PrimaryCaps_layer
caps1_n_maps = 32
caps1_n_dims = 8
conv2_params = {"filters": caps1_n_maps * caps1_n_dims,  # 256 convolutional filters
                "kernel_size": 9,
                "strides": 2,
                "padding": "valid",
                "activation": tf.nn.relu}
# Number of primary capsules = capsule maps x spatial grid after the second
# convolution. The grid is derived from the input size rather than hard-coded
# to 6x6, so changing --img_h/--img_w keeps this value consistent. For the
# default 28x28 input: 28 -> 20 (conv1) -> 6 (conv2), i.e. 32 * 6 * 6 = 1152.
# NOTE(review): assumes the PrimaryCaps conv is applied directly to the Conv1
# output, as the original hard-coded 6x6 implies — confirm in the model code.
_caps1_grid_h = _valid_conv_out(
    _valid_conv_out(_img_h, conv1_params["kernel_size"], conv1_params["strides"]),
    conv2_params["kernel_size"], conv2_params["strides"])
_caps1_grid_w = _valid_conv_out(
    _valid_conv_out(_img_w, conv1_params["kernel_size"], conv1_params["strides"]),
    conv2_params["kernel_size"], conv2_params["strides"])
caps1_n_caps = caps1_n_maps * _caps1_grid_h * _caps1_grid_w
# Parameters of DigitCaps_layer
# One output capsule per class, so the count follows --n_cls (10 by default,
# matching the previous hard-coded value).
caps2_n_caps = int(args.n_cls)
caps2_n_dims = 16
# Parameters of the Decoder
n_hidden1 = 512
n_hidden2 = 1024
# Length of the flattened reconstruction produced by the decoder.
# NOTE(review): does not include args.n_ch, so it only matches the input for
# single-channel images — confirm against the decoder / reconstruction loss.
n_output = _img_w * _img_h