-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
147 lines (118 loc) · 3.94 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import logging
import time
from datetime import timedelta
import pickle
import pandas as pd
import numpy as np
import torch
class LogFormatter(logging.Formatter):
    """
    Format log records as ``LEVEL - date time - elapsed - message``.

    ``elapsed`` is the time between ``start_time`` (set at construction and
    re-settable from outside) and the moment the record was created.
    Continuation lines of multi-line messages are indented so they line up
    after the prefix. Exception tracebacks (``logger.exception(...)`` /
    ``exc_info=...``) and stack info (``stack_info=True``) are appended after
    the message. A record whose rendered text is empty produces "".
    """

    def __init__(self):
        super().__init__()
        # reference point for the elapsed-time column
        self.start_time = time.time()

    def format(self, record):
        elapsed_seconds = round(record.created - self.start_time)
        prefix = "%s - %s - %s" % (
            record.levelname,
            # stamp with the record's creation time, not the formatting time,
            # so it agrees with the elapsed column computed from record.created
            time.strftime("%x %X", time.localtime(record.created)),
            timedelta(seconds=elapsed_seconds),
        )
        message = record.getMessage()
        # " - " separator is 3 chars: align continuation lines with the message
        message = message.replace("\n", "\n" + " " * (len(prefix) + 3))

        # getMessage() alone omits tracebacks and stack info; append them the
        # same way logging.Formatter.format does (exc_text is cached on the record)
        extras = []
        if record.exc_info and not record.exc_text:
            record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            extras.append(record.exc_text)
        if record.stack_info:
            extras.append(self.formatStack(record.stack_info))
        if extras:
            message = "\n".join(([message] if message else []) + extras)

        return "%s - %s" % (prefix, message) if message else ""
def create_logger(filepath):
    """
    Create a logger.
    Use a different log file for each process.

    Configures and returns the *root* logger: handlers already attached to it
    are removed and closed, its level is set to DEBUG, and it gets
      - a FileHandler appending DEBUG-and-above records to ``filepath``
        (omitted when ``filepath`` is None),
      - a StreamHandler (default stream, stderr) emitting INFO-and-above records.
    Both handlers share one LogFormatter. The returned logger gains a
    ``reset_time()`` attribute that restarts that formatter's elapsed clock.
    """
    # create log formatter
    log_formatter = LogFormatter()

    # create file handler and set level to debug
    file_handler = None
    if filepath is not None:
        file_handler = logging.FileHandler(filepath, "a")
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(log_formatter)

    # create console handler and set level to info
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(log_formatter)

    # create logger and set level to debug
    logger = logging.getLogger()
    # Detach AND close previous handlers: merely emptying the handler list
    # would leave a FileHandler from an earlier call holding its file open.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    logger.setLevel(logging.DEBUG)
    logger.propagate = False
    if file_handler is not None:
        logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    # reset logger elapsed time
    def reset_time():
        log_formatter.start_time = time.time()

    logger.reset_time = reset_time
    return logger
class PD_Stats(object):
    """
    Log stuff with pandas library.

    Rows are appended to a DataFrame that is pickled to ``path``. If ``path``
    already exists, the previously saved DataFrame is reloaded so new rows are
    appended after the existing ones.
    """

    def __init__(self, path, columns):
        """
        path: pickle file backing the statistics table.
        columns: column names; must match (same order) the columns of an
            existing file at ``path``.

        Raises ValueError if an existing file has different columns.
        """
        self.path = path
        columns = list(columns)
        # reload path stats
        # NOTE: read_pickle unpickles the file; only point this at files you trust
        if os.path.isfile(self.path):
            self.stats = pd.read_pickle(self.path)
            # check that columns are the same; an explicit raise (unlike
            # `assert`) is not stripped when Python runs with -O
            if list(self.stats.columns) != columns:
                raise ValueError(
                    "PD_Stats: columns of existing stats file %s (%s) do not "
                    "match the requested columns (%s)"
                    % (self.path, list(self.stats.columns), columns)
                )
        else:
            self.stats = pd.DataFrame(columns=columns)

    def update(self, row, save=True):
        """
        Append ``row`` (one value per column) as a new last row, labelled with
        the current row count. If ``save``, re-pickle the whole table to ``path``.
        """
        self.stats.loc[len(self.stats.index)] = row
        # save the statistics
        if save:
            self.stats.to_pickle(self.path)
def fix_random_seeds(seed=777):
    """
    Seed the random number generators used by this code base with ``seed``.

    Seeds torch (CPU and all CUDA devices), NumPy's global generator and
    Python's built-in ``random`` module, so code drawing from any of them
    produces the same sequence for the same seed.
    """
    import random

    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
def initialize_exp(params, *args, dump_params=True):
    """
    Initialize the experiment:
    - dump parameters
    - create checkpoint repo
    - create a logger
    - create a panda object to keep track of the training statistics

    Args:
        params: namespace-like object (anything ``vars()`` accepts) with a
            ``dump_path`` attribute. ``params.dump_checkpoints`` is set to
            ``<dump_path>/checkpoints`` as a side effect.
        *args: column names of the training-statistics table.
        dump_params: if True, pickle ``params`` to ``<dump_path>/params.pkl``.

    Returns:
        (logger, training_stats): the logger from ``create_logger`` writing to
        ``<dump_path>/train.log`` and the ``PD_Stats`` table backed by
        ``<dump_path>/stats.pkl``.
    """
    # create dump_path if not exists
    os.makedirs(params.dump_path, exist_ok=True)

    # dump parameters (before dump_checkpoints is added to params); the `with`
    # block closes the file so the pickle is flushed deterministically
    if dump_params:
        with open(os.path.join(params.dump_path, "params.pkl"), "wb") as f:
            pickle.dump(params, f)

    # create repo to store checkpoints; exist_ok avoids the isdir/mkdir race
    # if another process creates the directory concurrently
    params.dump_checkpoints = os.path.join(params.dump_path, "checkpoints")
    os.makedirs(params.dump_checkpoints, exist_ok=True)

    # create a panda object to log loss and acc
    training_stats = PD_Stats(
        os.path.join(params.dump_path, "stats.pkl"), args
    )

    # create a logger
    logger = create_logger(
        os.path.join(params.dump_path, "train.log")
    )
    logger.info("============ Initialized logger ============")
    logger.info(
        "\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(params)).items()))
    )
    logger.info("The experiment will be stored in %s\n" % params.dump_path)
    logger.info("")
    return logger, training_stats
class AverageMeter(object):
"""computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count