-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathconfig.py
192 lines (158 loc) · 5.82 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import configparser
import copy
import logging
from datetime import datetime
from typing import Optional
import typing
from typing import List
import torch
# Import-time side effect: if the root logger has no handlers yet, install a
# default one and set the root level to INFO ("INFO" is resolved by setLevel).
logging.basicConfig(level="INFO")
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class Config(object):
    """Configuration object for the experiments.

    Every configurable field is an annotated class attribute whose value is the
    default. Assigning to an annotated field (directly, via ``setattr``, or via
    ``load_from_dict`` / ``load_from_file``) casts the value to the annotated
    type; assigning to a name that is not annotated raises ``AttributeError``.
    """

    # Default is None despite the ``str`` annotation.
    name: str = None
    # Evaluated once when the class body runs (i.e. at import time), so every
    # instance that does not override it shares the same timestamp.
    date: str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    workspace: str = "../workspace"
    dataset: str = "mesh"
    input_directory: str = ""
    max_trajectory_length: int = 120
    min_trajectory_length: int = 6
    k_closest_nodes: int = 5
    extract_coord_features: bool = True
    device: torch.device = torch.device("cpu")
    optimizer: str = "Adam"
    loss: str = "nll_loss"
    lr: float = 0.01
    momentum: float = 0.5
    batch_size: int = 5
    overfit1: bool = False
    shuffle_samples: bool = True
    seed: int = 0
    number_epoch: int = 10
    train_test_ratio: str = "0.8/0.2"
    patience: int = 1000
    number_observations: int = 5
    self_loop_weight: float = 0.01
    self_loop_deadend_only: bool = True
    diffusion_k_hops: int = 60
    diffusion_hidden_dimension: int = 1
    parametrized_diffusion: bool = False
    target_prediction: str = "next"
    latent_transformer_see_target: bool = False
    rw_max_steps: int = -1
    rw_edge_weight_see_number_step: bool = False
    rw_expected_steps: bool = True
    with_interpolation: bool = False
    initial_edge_transformer: bool = False
    use_shortest_path_distance: bool = False
    double_way_diffusion: bool = False
    rw_non_backtracking: bool = True
    diffusion_self_loops: bool = False
    print_per_epoch: int = 10
    checkpoint_directory: str = "chkpt"
    enable_checkpointing: bool = True
    chechpoint_every_num_epoch: int = 5
    restore_from_checkpoint: bool = False
    compute_baseline: bool = True

    # Names of fields kept only so old config files still load. Reading any of
    # them logs a warning (see __getattribute__), and they are excluded from
    # to_dict() and __str__().
    _DEPRECATED: List[str] = [
        "tensorboard_logdir",
        "log_tensorboard",
        "rw_restart_coef",
        "rw_tolerance",
        "rw_walk_or_die",
        "print_every",
    ]
    tensorboard_logdir: str = "logdir"
    log_tensorboard: bool = True
    rw_restart_coef: float = 0.0
    rw_tolerance: float = 0.0
    rw_walk_or_die: bool = True
    print_every: int = 10

    def load_from_file(self, filename: str):
        """Load configuration fields from an INI-style file.

        Options from every section are merged into one mapping (a key in a
        later section overrides the same key in an earlier one) and passed to
        ``load_from_dict``. ConfigParser lower-cases option names by default,
        and all values arrive as strings, to be cast by ``__setattr__``.

        Args:
            filename (str): file name

        Raises:
            FileNotFoundError: the file could not be opened/read
            NotImplementedError: the file contains an unknown configuration key
        """
        config = configparser.ConfigParser()
        # ConfigParser.read silently skips files it cannot open and returns the
        # list of files it actually read; an empty list means nothing loaded.
        if not config.read(filename):
            raise FileNotFoundError(f'Could not read configuration file "{filename}"')
        fields = {
            key: value
            for section in config.sections()
            for key, value in config[section].items()
        }
        self.load_from_dict(fields)

    def load_from_dict(self, fields: dict):
        """Load all fields from the dictionary.

        Args:
            fields (dict): configuration (key, value)'s

        Raises:
            NotImplementedError: Unknown configuration keys
        """
        for key, value in fields.items():
            # hasattr goes through __getattribute__, so loading a deprecated
            # key also logs the deprecation warning.
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                raise NotImplementedError(f'Unknown configuration field "{key}"')

    def __getattribute__(self, k):
        """Return attribute ``k``, logging a warning if it is a deprecated field."""
        deprecated = super(Config, self).__getattribute__("_DEPRECATED")
        if k in deprecated:
            logger.warning(f"Accessed to deprecated config field '{k}'")
        return super(Config, self).__getattribute__(k)

    def __setattr__(self, k, v):
        """Update a configuration key, casting the value to the field's type.

        Casting rules, applied only when ``type(v)`` is not exactly the
        annotated type:
          * ``int`` / ``float``: ``int(v)`` / ``float(v)``
          * ``bool``: strings are true iff they are "yes", "true" or "1"
            (case-insensitive); any other value uses ``bool(v)``
          * ``str``: ``str(v)``; ``None`` is kept as-is
          * ``torch.device``: ``torch.device(v)``

        Args:
            k (str): configuration
            v: value

        Raises:
            NotImplementedError: Unkown type of field for casting
            AttributeError: Unknown attribute
        """
        # Resolve hints on the class so annotations of base classes (if any
        # subclass extends Config) are merged in via the MRO.
        type_annotations = typing.get_type_hints(type(self))
        if k not in type_annotations:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{k}'")
        typ = type_annotations[k]
        if type(v) is not typ:
            if typ is int:
                v = int(v)
            elif typ is float:
                v = float(v)
            elif typ is bool:
                # Strings typically come from a config file; other values
                # (e.g. 0/1 from a grid search) use Python truthiness.
                if isinstance(v, str):
                    v = v.lower() in ["yes", "true", "1"]
                else:
                    v = bool(v)
            elif typ is str:
                # Keep None so fields such as ``name`` can be reset to their default.
                if v is not None:
                    v = str(v)
            elif typ is torch.device:
                v = torch.device(v)
            else:
                raise NotImplementedError(f"Unknown config type '{typ}' for field '{k}'")
        self.__dict__[k] = v

    def to_dict(self):
        """Return the fields assigned on this instance as a dict.

        Only values stored in the instance ``__dict__`` are included; class
        defaults that were never assigned are not. Private (``_``-prefixed)
        and deprecated fields are excluded.
        """
        return {k: v for k, v in self.__dict__.items() if k[0] != "_" and k not in self._DEPRECATED}

    def save_to_file(self, filename):
        """Save this configuration object to a file.

        Writes a single ``[config]`` section with one ``key = value`` line per
        entry of ``to_dict()``, in a format ``load_from_file`` can read back.

        Args:
            filename (str): path of the file
        """
        with open(filename, "w") as f:
            f.write("[config]\n")
            for k, v in self.to_dict().items():
                f.write(f"{k} = {v}\n")

    def __str__(self):
        """Render every public, non-deprecated field (defaults included) in INI form."""
        type_annotations = typing.get_type_hints(type(self))
        deprecated = self._DEPRECATED
        # Skipping deprecated fields avoids one warning per deprecated field
        # on every str() call and keeps the output consistent with to_dict().
        lines = ["[config]"] + [
            f"{k} = {getattr(self, k)}"
            for k in type_annotations
            if k[0] != "_" and k not in deprecated
        ]
        return "\n".join(lines)
def config_generator(config: Config, parameters: list, selected_params=None):
    """Generate alternative Config objects for grid search.

    Yields one configuration per combination of the Cartesian product of the
    parameter values. Each level of recursion shallow-copies the incoming
    config before setting its parameter, so the caller's ``config`` is only
    modified (and yielded itself) when ``parameters`` is empty. The last entry
    of ``parameters`` varies slowest.

    Each yielded config gets ``"name1:value1-name2:value2..."`` appended
    directly (no separator) to its ``name``; the pairs are listed from the
    last entry of ``parameters`` to the first. A ``None`` name is treated as
    empty, and the name is left untouched when no parameter was selected.

    Args:
        config (Config): the initial configuration
        parameters (list): [('param_name_1', [values...]), ... ('param_name_n', [values...])]
        selected_params (list, optional): (name, value) pairs already fixed by
            outer recursion levels. Defaults to None.

    Yields:
        Config: one configuration per parameter combination
    """
    selected_params = selected_params or []
    if not parameters:
        suffix = "-".join("{}:{}".format(k, v) for k, v in selected_params)
        if suffix:
            # ``name`` defaults to None; appending to None would raise TypeError.
            config.name = (config.name or "") + suffix
        yield config
        return

    curr_param, values = parameters[-1]
    for v in values:
        new_config = copy.copy(config)
        setattr(new_config, curr_param, v)
        yield from config_generator(new_config, parameters[:-1], selected_params + [(curr_param, v)])