forked from open-mmlab/mmaction2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: lfb_infer_head.py
145 lines (118 loc) · 5.03 KB
/
lfb_infer_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os.path as osp
import mmcv
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import get_dist_info
try:
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
mmdet_imported = False
class LFBInferHead(nn.Module):
    """Long-Term Feature Bank Infer Head.

    This head is used to derive and save the LFB without affecting the input.

    Args:
        lfb_prefix_path (str): The prefix path to store the lfb.
        dataset_mode (str, optional): Which dataset to be inferred. Choices are
            'train', 'val' or 'test'. Default: 'train'.
        use_half_precision (bool, optional): Whether to store the
            half-precision roi features. Default: True.
        temporal_pool_type (str): The temporal pool type. Choices are 'avg' or
            'max'. Default: 'avg'.
        spatial_pool_type (str): The spatial pool type. Choices are 'avg' or
            'max'. Default: 'max'.
    """

    def __init__(self,
                 lfb_prefix_path,
                 dataset_mode='train',
                 use_half_precision=True,
                 temporal_pool_type='avg',
                 spatial_pool_type='max'):
        super().__init__()
        rank, world_size = get_dist_info()
        # Only rank 0 creates the output folder (avoids a mkdir race across
        # ranks); other ranks first touch the folder in __del__, after the
        # barrier, so it exists by then.
        if rank == 0:
            if not osp.exists(lfb_prefix_path):
                print(f'lfb prefix path {lfb_prefix_path} does not exist. '
                      f'Creating the folder...')
                mmcv.mkdir_or_exist(lfb_prefix_path)
            print('\nInferring LFB...')

        assert temporal_pool_type in ['max', 'avg']
        assert spatial_pool_type in ['max', 'avg']
        self.lfb_prefix_path = lfb_prefix_path
        self.dataset_mode = dataset_mode
        self.use_half_precision = use_half_precision

        # Pool by default
        # Temporal pool collapses the T axis to 1; spatial pool collapses
        # H and W to 1, so each roi feature ends up as [C, 1, 1, 1].
        if temporal_pool_type == 'avg':
            self.temporal_pool = nn.AdaptiveAvgPool3d((1, None, None))
        else:
            self.temporal_pool = nn.AdaptiveMaxPool3d((1, None, None))
        if spatial_pool_type == 'avg':
            self.spatial_pool = nn.AdaptiveAvgPool3d((None, 1, 1))
        else:
            self.spatial_pool = nn.AdaptiveMaxPool3d((None, 1, 1))

        # Features/metadata accumulated across all forward() calls of this
        # rank; dumped to disk in __del__. Kept in lockstep: index i of
        # all_features corresponds to index i of all_metadata.
        self.all_features = []
        self.all_metadata = []

    def init_weights(self, pretrained=None):
        """Initialize weights (no-op).

        LFBInferHead has no parameters to be initialized; the pooling layers
        are parameter-free.
        """
        # LFBInferHead has no parameters to be initialized.
        pass

    def forward(self, x, rois, img_metas):
        """Pool the roi features and record them; pass ``x`` through unchanged.

        Args:
            x (torch.Tensor): The extracted roi features, pooled below to
                shape [N, C, 1, 1, 1].
            rois (torch.Tensor): The regions of interest; column 0 is the
                batch index used to look up ``img_metas``.
            img_metas (list[dict]): Per-sample meta info; each entry must
                contain an 'img_key' string (used later as
                'video_id,timestamp' — see ``__del__``).

        Returns:
            torch.Tensor: The input ``x``, untouched.
        """
        # [N, C, 1, 1, 1]
        features = self.temporal_pool(x)
        features = self.spatial_pool(features)
        if self.use_half_precision:
            features = features.half()

        # rois[:, 0] holds each roi's batch index; map every roi to the
        # 'img_key' of the sample it came from.
        inds = rois[:, 0].type(torch.int64)
        for ind in inds:
            self.all_metadata.append(img_metas[ind]['img_key'])
        # list(features) splits the batched tensor into N per-roi tensors.
        self.all_features += list(features)

        # Return the input directly and doesn't affect the input.
        return x

    def __del__(self):
        """Save the accumulated LFB to disk when the head is destroyed.

        Each rank first dumps its partial bank to
        ``_lfb_{dataset_mode}_{rank}.pkl``; after a barrier, rank 0 merges
        all partial files into ``lfb_{dataset_mode}.pkl``.
        """
        assert len(self.all_features) == len(self.all_metadata), (
            'features and metadata are not equal in length!')

        rank, world_size = get_dist_info()
        if world_size > 1:
            dist.barrier()

        # Group this rank's features as _lfb[video_id][timestamp] -> list of
        # squeezed [C] feature tensors. 'img_key' is 'video_id,timestamp'.
        _lfb = {}
        for feature, metadata in zip(self.all_features, self.all_metadata):
            video_id, timestamp = metadata.split(',')
            timestamp = int(timestamp)

            if video_id not in _lfb:
                _lfb[video_id] = {}
            if timestamp not in _lfb[video_id]:
                _lfb[video_id][timestamp] = []
            _lfb[video_id][timestamp].append(torch.squeeze(feature))

        # Every rank writes its own partial file, tagged with its rank id.
        _lfb_file_path = osp.normpath(
            osp.join(self.lfb_prefix_path,
                     f'_lfb_{self.dataset_mode}_{rank}.pkl'))
        torch.save(_lfb, _lfb_file_path)
        print(f'{len(self.all_features)} features from {len(_lfb)} videos '
              f'on GPU {rank} have been stored in {_lfb_file_path}.')

        # Synchronizes all processes to make sure all gpus have stored their
        # roi features
        if world_size > 1:
            dist.barrier()
        # Only rank 0 performs the merge below; other ranks are done.
        if rank > 0:
            return

        print('Gathering all the roi features...')
        lfb = {}
        for rank_id in range(world_size):
            _lfb_file_path = osp.normpath(
                osp.join(self.lfb_prefix_path,
                         f'_lfb_{self.dataset_mode}_{rank_id}.pkl'))

            # Since each frame will only be distributed to one GPU,
            # the roi features on the same timestamp of the same video are all
            # on the same GPU
            _lfb = torch.load(_lfb_file_path)

            for video_id in _lfb:
                if video_id not in lfb:
                    lfb[video_id] = _lfb[video_id]
                else:
                    # Same video seen by several ranks: merge its
                    # timestamp -> features maps (timestamps don't collide,
                    # per the note above).
                    lfb[video_id].update(_lfb[video_id])

        lfb_file_path = osp.normpath(
            osp.join(self.lfb_prefix_path, f'lfb_{self.dataset_mode}.pkl'))
        torch.save(lfb, lfb_file_path)
        print(f'LFB has been constructed in {lfb_file_path}!')
# Register LFBInferHead with mmdet's SHARED_HEADS registry (only when mmdet
# is installed) so mmdet-based configs can build this head by name.
if mmdet_imported:
    MMDET_SHARED_HEADS.register_module()(LFBInferHead)