-
Notifications
You must be signed in to change notification settings - Fork 8
/
utils.py
62 lines (50 loc) · 2.03 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import datetime
import logging
import os
import re
import time
def epoch_time(start_time, end_time):
    """Split the span between two timestamps into whole minutes and seconds.

    Args:
        start_time: start timestamp, in seconds (e.g. from ``time.time()``).
        end_time: end timestamp, in the same unit as ``start_time``.

    Returns:
        tuple[int, int]: ``(minutes, seconds)``. Both parts are truncated
        toward zero by ``int``, so a negative span yields non-positive parts.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    # Whatever is left after removing the whole minutes, truncated to an int.
    return minutes, int(duration - 60 * minutes)
def gen_checkpoint_id(args):
    """Build a checkpoint identifier of the form ``<encoder_name>_<YYYYmmddHHMM>``.

    Args:
        args: namespace-like object exposing an ``encoder_name`` string.

    Returns:
        str: the encoder name joined with the current local time
        (minute resolution) by an underscore.
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M")
    return "_".join((args.encoder_name, stamp))
def get_logger(args):
    """Configure and return the root logger for a training run.

    A log file ``train_instance_<k>_info.log`` is requested under
    ``<args.checkpoint>/info/``, where ``k`` is the number of entries already
    in that directory, so successive runs sharing a checkpoint directory do
    not overwrite each other's logs. INFO-and-above records are also echoed
    to the console, and every attribute of ``args`` is logged as a header.

    Args:
        args: namespace-like object (e.g. from ``argparse``) with at least a
            ``checkpoint`` attribute naming the checkpoint directory.

    Returns:
        logging.Logger: the root logger, with its level set to DEBUG.
    """
    log_path = f"{args.checkpoint}/info/"
    # makedirs (unlike mkdir) also creates a missing checkpoint directory and
    # does not fail if the directory already exists.
    os.makedirs(log_path, exist_ok=True)
    train_instance_count = len(os.listdir(log_path))

    # NOTE: basicConfig does nothing if the root logger already has handlers,
    # so only the first call in a process attaches this file handler.
    logging.basicConfig(
        filename=f'{log_path}train_instance_{train_instance_count}_info.log',
        filemode='w',
        format="%(asctime)s | %(filename)15s | %(levelname)7s | %(funcName)10s | %(message)s",
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)

    # Console output. The handler is named so that repeated calls reuse it
    # instead of stacking duplicates (which would print each record N times).
    console_name = "get_logger_console"
    if not any(h.get_name() == console_name for h in logger.handlers):
        stream_handler = logging.StreamHandler()
        stream_handler.set_name(console_name)
        stream_handler.setLevel(logging.INFO)
        formatter = logging.Formatter("%(asctime)s;[%(levelname)7s];%(message)s",
                                      "%Y-%m-%d %H:%M:%S")
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    # Record the full run configuration at the top of the log.
    logger.info("-"*40)
    for arg in vars(args):
        logger.info(f"{arg}: {getattr(args, arg)}")
    logger.info("-"*40)
    return logger
def checkpoint_count(checkpoint):
    """Return the highest checkpoint number saved directly in ``checkpoint``.

    Only regular files at the top level of ``checkpoint`` (not
    sub-directories) are considered. A file counts when its name contains
    ``saved_checkpoint_<N>``, where ``N`` is the full run of digits following
    that prefix. This is used to find the most recent checkpoint when
    training is resumed.

    Args:
        checkpoint: path of the directory to scan.

    Returns:
        int: the largest ``N`` found, or 0 when there is no matching file or
        the directory does not exist.
    """
    # os.walk yields nothing for a missing/unreadable directory; fall back to
    # an empty listing instead of letting StopIteration escape.
    _, _, files = next(os.walk(checkpoint), (checkpoint, [], []))
    # Capture the whole digit run right after the prefix, not just its first
    # character, and ignore names without a number there.
    pattern = re.compile(r"saved_checkpoint_([0-9]+)")
    numbers = [int(m.group(1)) for m in map(pattern.search, files) if m]
    return max(numbers, default=0)