-
Notifications
You must be signed in to change notification settings - Fork 1
/
util.py
222 lines (181 loc) · 9.16 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import torch
import shutil
import logging
from typing import Type, List
from argparse import Namespace
from cosface_loss import MarginCosineProduct
def move_to_device(optimizer: torch.optim.Optimizer, device: str):
    """Move, in place, every tensor stored in the optimizer state to ``device``.

    The optimizer state (e.g. momentum buffers) is held separately from the
    model parameters, so moving the model does not move it.

    Args:
        optimizer: optimizer *instance* whose ``state`` is updated in place.
        device: target device, anything accepted by ``Tensor.to`` (e.g. "cuda").
    """
    for param_state in optimizer.state.values():
        for key, value in param_state.items():
            if torch.is_tensor(value):
                # Overwriting an existing key does not resize the dict, so this
                # is safe while iterating over its items
                param_state[key] = value.to(device)
def save_checkpoint(state: dict, is_best: bool, output_folder: str,
                    ckpt_filename: str = "last_checkpoint.pth"):
    """Persist the training checkpoint and, on a new best, the model weights.

    ``state`` is written to ``{output_folder}/{ckpt_filename}``. When
    ``is_best`` is true, ``state["model_state_dict"]`` is additionally written
    to ``{output_folder}/best_model.pth``.
    """
    # TODO it would be better to move weights to cpu before saving
    torch.save(state, f"{output_folder}/{ckpt_filename}")
    if not is_best:
        return
    best_model_path = f"{output_folder}/best_model.pth"
    torch.save(state["model_state_dict"], best_model_path)
def resume_train(args: Namespace, output_folder: str, model: torch.nn.Module,
                 model_optimizer: torch.optim.Optimizer, classifiers: List[MarginCosineProduct],
                 classifiers_optimizers: List[torch.optim.Optimizer]):
    """Load model, optimizer, and other training parameters from a checkpoint.

    Restores in place the state of ``model``, ``model_optimizer``, every
    classifier and every classifier optimizer from ``args.resume_train``. It
    also copies the ``best_model.pth`` located next to that checkpoint into
    ``output_folder``.

    Args:
        args: parsed arguments; ``resume_train`` (checkpoint path), ``device``
            and ``groups_num`` are read.
        output_folder: folder that receives a copy of ``best_model.pth``.
        model: network whose weights are restored; it is moved to ``args.device``.
        model_optimizer: optimizer of ``model``.
        classifiers: one classifier per group; they are left on CPU.
        classifiers_optimizers: one optimizer per classifier.

    Returns:
        (model, model_optimizer, classifiers, classifiers_optimizers,
         best_val_recall1, start_epoch_num)

    Raises:
        AssertionError: if the number of groups, classifiers or optimizers does
            not match the checkpoint.
    """
    import os

    logging.info(f"Loading checkpoint: {args.resume_train}")
    # Load on CPU: load_state_dict() copies the values into the modules, and
    # Optimizer.load_state_dict() casts state tensors to the device of their
    # parameters. This also allows resuming a checkpoint saved from a GPU
    # that is not available now.
    checkpoint = torch.load(args.resume_train, map_location="cpu")
    start_epoch_num = checkpoint["epoch_num"]
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.to(args.device)
    model_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    assert args.groups_num == len(classifiers) == len(classifiers_optimizers) == \
        len(checkpoint["classifiers_state_dict"]) == len(checkpoint["optimizers_state_dict"]), \
        (f"{args.groups_num}, {len(classifiers)}, {len(classifiers_optimizers)}, "
         f"{len(checkpoint['classifiers_state_dict'])}, {len(checkpoint['optimizers_state_dict'])}")

    for classifier, state_dict in zip(classifiers, checkpoint["classifiers_state_dict"]):
        # Move classifiers to GPU before loading their optimizers, so the optimizer
        # state gets cast to the same device as the parameters (Module.to is in-place)
        classifier.to(args.device)
        classifier.load_state_dict(state_dict)
    for optimizer, state_dict in zip(classifiers_optimizers, checkpoint["optimizers_state_dict"]):
        optimizer.load_state_dict(state_dict)
    for classifier in classifiers:
        # Move classifiers back to CPU to save some GPU memory
        classifier.cpu()

    best_val_recall1 = checkpoint["best_val_recall1"]
    # Copy best model to current output_folder. It is stored next to the checkpoint.
    # Join on the directory instead of str.replace(), which would rewrite every
    # occurrence in the path. For a checkpoint not named "last_checkpoint.pth",
    # replace() would also copy the checkpoint itself.
    best_model_path = os.path.join(os.path.dirname(args.resume_train), "best_model.pth")
    shutil.copy(best_model_path, output_folder)
    return model, model_optimizer, classifiers, classifiers_optimizers, best_val_recall1, start_epoch_num
import os
import re
import utm
import cv2
import math
import time
import shutil
import requests
from tqdm import tqdm
# Base of the exponential backoff used by download_heavy_file: after failed
# attempt number k (0-based) it waits RETRY_SECONDS**k seconds.
RETRY_SECONDS = 2
def get_distance(coords_A, coords_B):
    """Return the Euclidean distance between two 2D points.

    Each point is an indexable pair whose items may be numbers or numeric
    strings; they are converted with float().
    """
    delta_first = float(coords_B[0]) - float(coords_A[0])
    delta_second = float(coords_B[1]) - float(coords_A[1])
    return math.sqrt(delta_first ** 2 + delta_second ** 2)
def download_heavy_file(url, output_path, max_attempts=10, timeout=60):
    """Download the file at ``url`` into ``output_path``, showing a progress bar.

    Data is streamed into a temporary file under ``./tmp``. It is moved to
    ``output_path`` only after a complete download, so a failed download never
    leaves a truncated file at ``output_path``. If ``output_path`` already
    exists, nothing is downloaded. After failed attempt number k (0-based) the
    function waits ``RETRY_SECONDS**k`` seconds before retrying.

    Args:
        url: URL of the file to download.
        output_path: destination path; missing parent folders are created.
        max_attempts: number of attempts before giving up (default 10).
        timeout: seconds forwarded to ``requests.get``. Without it, a stalled
            connection would block forever instead of being retried.

    Raises:
        RuntimeError: if every attempt failed.
    """
    if os.path.exists(output_path):
        print(f"File {output_path} already exists, I won't download it again")
        return
    os.makedirs("tmp", exist_ok=True)
    tmp_filename = os.path.join("tmp", f"tmp_{int(time.time()*1000)}")
    for attempt_num in range(max_attempts):
        try:
            # "with" releases the streamed connection even if an exception is raised
            with requests.get(url, stream=True, timeout=timeout) as req:
                # Otherwise an HTTP error page (404, 500, ...) would be saved as the file
                req.raise_for_status()
                total_size = int(req.headers.get('content-length', 0))  # Total size in bytes
                block_size = 1024  # 1 KB
                with tqdm(total=total_size, desc=os.path.basename(output_path),
                          unit='iB', unit_scale=True, ncols=100) as tqdm_bar:
                    with open(tmp_filename, 'wb') as f:
                        for data in req.iter_content(block_size):
                            tqdm_bar.update(len(data))
                            f.write(data)
            if total_size != 0 and tqdm_bar.n != total_size:
                print(tqdm_bar.n)
                print(total_size)
                raise RuntimeError("ERROR, something went wrong during download")
            break
        except Exception as e:  # RuntimeError is already a subclass of Exception
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
            print(e)
            if attempt_num == max_attempts - 1:
                # Last attempt: fail right away instead of sleeping before giving up
                raise RuntimeError(f"I tried {max_attempts} times and I couldn't download {output_path} from {url}") from e
            print(f"I'll try again to download {output_path} in {RETRY_SECONDS**attempt_num} seconds")
            time.sleep(RETRY_SECONDS**attempt_num)
    else:
        # Only reachable when max_attempts <= 0 (no attempt was made)
        raise RuntimeError(f"I tried {max_attempts} times and I couldn't download {output_path} from {url}")
    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
    shutil.move(tmp_filename, output_path)
def is_valid_timestamp(timestamp):
    """Return True if it's a valid timestamp, in format YYYYMMDD_hhmmss,
    with all fields from left to right optional.
    >>> is_valid_timestamp('')
    True
    >>> is_valid_timestamp('201901')
    True
    >>> is_valid_timestamp('20190101_123000')
    True
    >>> is_valid_timestamp('20190101123000')
    False
    """
    # Raw string: "\d" in a normal string literal is an invalid escape
    # (DeprecationWarning, SyntaxWarning since Python 3.12).
    # fullmatch instead of "^...$", because "$" also matches before a trailing newline.
    return re.fullmatch(r"(\d{4}(\d{2}(\d{2}(_(\d{2})(\d{2})?(\d{2})?)?)?)?)?", timestamp) is not None
def format_coord(num, left=2, right=5):
    """Return the formatted number as a string with (left) int digits
    (including sign '-' for negatives) and (right) float digits.
    Extra decimal digits are truncated (not rounded); missing ones are zero-padded.
    >>> format_coord(1.1, 3, 3)
    '001.100'
    >>> format_coord(-0.123, 3, 3)
    '-00.123'
    >>> format_coord(0.00001, 2, 5)
    '00.00001'
    """
    from decimal import Decimal
    sign = "-" if float(num) < 0 else ""
    # str(float) switches to scientific notation for very small/large values
    # (str(1e-05) == '1e-05'), which breaks the int() parsing below. Formatting
    # the Decimal of the shortest repr with "f" yields the same digits in plain
    # positional notation.
    num = format(Decimal(repr(abs(float(num)))), "f") + "."
    integer, frac = num.split(".")[:2]
    left -= len(sign)
    return f"{sign}{int(integer):0{left}d}.{frac[:right]:<0{right}}"
import doctest
import sys
# Run the doctests of this module (format_coord, is_valid_timestamp, ...) when it is loaded.
# The module must be passed explicitly: doctest.testmod() without arguments
# tests the __main__ module, i.e. the importing script rather than this file.
doctest.testmod(sys.modules[__name__])
def format_location_info(latitude, longitude):
    """Convert a latitude/longitude pair into zero-padded strings for file names.

    Returns (easting, northing, zone_number, zone_letter, latitude, longitude):
    - easting and northing come from utm.from_latlon and use 7 integer and
      2 decimal digits;
    - latitude uses 3 integer and 5 decimal digits;
    - longitude uses 4 integer and 5 decimal digits;
    - the sign counts as an integer digit;
    - zone_number and zone_letter are returned as produced by utm.from_latlon.
    """
    utm_coords = utm.from_latlon(float(latitude), float(longitude))
    utm_easting, utm_northing, zone_number, zone_letter = utm_coords
    return (
        format_coord(utm_easting, 7, 2),
        format_coord(utm_northing, 7, 2),
        zone_number,
        zone_letter,
        format_coord(latitude, 3, 5),
        format_coord(longitude, 4, 5),
    )
def get_dst_image_name(latitude, longitude, pano_id=None, tile_num=None, heading=None,
                       pitch=None, roll=None, height=None, timestamp=None, note=None, extension=".jpg"):
    """Build an image filename that encodes location and metadata as '@'-separated fields.

    Layout: "@easting@northing@zone_number@zone_letter@latitude@longitude@pano_id"
    "@tile_num@heading@pitch@roll@height@timestamp@note@extension".

    - UTM and lat/lon fields come from format_location_info().
    - zone_number is zero-padded to 2 digits.
    - tile_num is truncated to int and zero-padded to 2 digits; heading and
      pitch are zero-padded to 3 digits.
    - tile_num, heading, pitch, timestamp and note become "" when None.
    - roll and height must be None and are always empty.

    NOTE(review): pano_id is inserted as-is, so pano_id=None yields the literal
    "None" instead of an empty field. Confirm this is intended.

    Raises:
        AssertionError: if timestamp is not a (possibly truncated) YYYYMMDD_hhmmss string.
        NotImplementedError: if roll or height is not None.
    """
    easting, northing, zone_number, zone_letter, latitude, longitude = format_location_info(latitude, longitude)

    def padded_int(value, spec):
        # int(float(...)) truncates, and also accepts numeric strings such as "12.7"
        return "" if value is None else format(int(float(value)), spec)

    tile_num = padded_int(tile_num, "02d")
    heading = padded_int(heading, "03d")
    pitch = padded_int(pitch, "03d")
    timestamp = "" if timestamp is None else f"{timestamp}"
    note = "" if note is None else f"{note}"
    assert is_valid_timestamp(timestamp), f"{timestamp} is not in YYYYMMDD_hhmmss format"
    if roll is not None or height is not None:
        raise NotImplementedError()
    fields = [easting, northing, f"{zone_number:02d}", f"{zone_letter}", latitude, longitude,
              f"{pano_id}", tile_num, heading, pitch, "", "", timestamp, note, f"{extension}"]
    return "@" + "@".join(fields)
class VideoReader:
    """Random-access frame reader for a video file, backed by cv2.VideoCapture.

    Frames are returned in RGB order (converted from OpenCV's BGR) and are
    optionally resized to ``size``.
    """

    def __init__(self, video_name, size=None):
        """
        Args:
            video_name: path of the video file.
            size: optional (height, width) to resize frames to. It is reversed
                before calling cv2.resize, which expects (width, height).

        Raises:
            FileNotFoundError: if ``video_name`` does not exist.
            ValueError: if OpenCV reports a non-positive FPS (e.g. the file
                cannot be decoded). Otherwise this would surface as a
                ZeroDivisionError below.
        """
        if not os.path.exists(video_name):
            raise FileNotFoundError(f"{video_name} does not exist")
        self.video_name = video_name
        self.size = size
        self.vc = cv2.VideoCapture(f"{video_name}")
        self.frames_per_second = self.vc.get(cv2.CAP_PROP_FPS)
        if not self.frames_per_second > 0:
            raise ValueError(f"Could not read the FPS of {video_name}, OpenCV may be unable to decode it")
        self.frame_duration_millis = 1000 / self.frames_per_second
        self.frames_num = int(self.vc.get(cv2.CAP_PROP_FRAME_COUNT))
        self.video_length_in_millis = int(self.frames_num * 1000 / self.frames_per_second)

    def get_time_at_frame(self, frame_num):
        """Return the timestamp (in milliseconds, truncated) of frame ``frame_num``."""
        return int(self.frame_duration_millis * frame_num)

    def get_frame_num_at_time(self, time):
        """Return the index of the frame shown at ``time``, capped at ``frames_num``.

        ``time`` is either a "MM:SS" / "HH:MM:SS" string (e.g. '21:59') or a
        number of milliseconds.
        NOTE(review): the cap is frames_num, one past the last readable index;
        confirm whether callers rely on it as an exclusive bound.
        """
        millis = self.str_to_millis(time) if isinstance(time, str) else time
        return min(int(millis / self.frame_duration_millis), self.frames_num)

    def get_frame_at_frame_num(self, frame_num):
        """Return frame ``frame_num`` as an RGB array, or None if it cannot be read."""
        self.vc.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        frame = self.vc.read()[1]
        if frame is None:
            return None  # In case of corrupt videos
        if self.size is not None:
            # interpolation must be passed by keyword: the third positional
            # argument of cv2.resize is dst, so INTER_CUBIC would be ignored
            frame = cv2.resize(frame, self.size[::-1], interpolation=cv2.INTER_CUBIC)
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return frame

    @staticmethod
    def str_to_millis(time_str):
        """Convert a colon-separated time ("MM:SS" or "HH:MM:SS") into milliseconds.

        The format is the same as the one produced by millis_to_str.
        """
        seconds = 0
        for field in time_str.split(":"):
            seconds = seconds * 60 + int(field)
        return seconds * 1000

    @staticmethod
    def millis_to_str(millis):
        """Format milliseconds as "MM:SS", or "HH:MM:SS" from one hour on."""
        minutes, seconds = divmod(int(millis // 1000), 60)
        if millis < 60*60*1000:
            return f"{minutes:02d}:{seconds:02d}"
        # Hours are not wrapped, so videos longer than 60 hours are also correct
        hours, minutes = divmod(minutes, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}"

    def __repr__(self):
        H, W = int(self.vc.get(cv2.CAP_PROP_FRAME_HEIGHT)), int(self.vc.get(cv2.CAP_PROP_FRAME_WIDTH))
        return (f"Video '{self.video_name}' has {self.frames_num} frames, " +
                f"with resolution {H}x{W}, " +
                f"and lasts {self.video_length_in_millis // 1000} seconds "
                f"({self.millis_to_str(self.video_length_in_millis)}), therefore "
                f"there's a frame every {int(self.frame_duration_millis)} millis")

    def __del__(self):
        # __init__ may have raised before self.vc was assigned (e.g. FileNotFoundError)
        vc = getattr(self, "vc", None)
        if vc is not None:
            vc.release()