# video_dataset.py (forked from vt-vl-lab/video-data-aug)
import os.path as osp

import torch
from mmcv.utils import print_log

from ..core import mean_class_accuracy, top_k_accuracy
from .base import BaseDataset
from .registry import DATASETS

@DATASETS.register_module()
class VideoDataset(BaseDataset):
    """Video dataset for action recognition.

    The dataset loads raw videos and applies the specified transforms to
    return a dict containing the frame tensors and other information.

    The ann_file is a text file with multiple lines, and each line indicates
    a sample video with its filepath and label, separated by a whitespace.
    Example of an annotation file:

    .. code-block:: txt

        some/path/000.mp4 1
        some/path/001.mp4 1
        some/path/002.mp4 2
        some/path/003.mp4 2
        some/path/004.mp4 3
        some/path/005.mp4 3

    Args:
        ann_file (str): Path to the annotation file.
        pipeline (list[dict | callable]): A sequence of data transforms.
        start_index (int): Specify a start index for frames in consideration
            of different filename formats. However, when taking videos as
            input, it should be set to 0, since frames loaded from videos
            count from 0. Default: 0.
        **kwargs: Keyword arguments for ``BaseDataset``.
    """

    def __init__(self, ann_file, pipeline, start_index=0, **kwargs):
        super().__init__(ann_file, pipeline, start_index=start_index, **kwargs)

    def load_annotations(self):
        """Load annotation file to get video information."""
        if self.ann_file.endswith('.json'):
            return self.load_json_annotations()

        video_infos = []
        with open(self.ann_file, 'r') as fin:
            for line in fin:
                line_split = line.strip().split()
                if self.multi_class:
                    assert self.num_classes is not None
                    filename, label = line_split[0], line_split[1:]
                    label = list(map(int, label))
                    onehot = torch.zeros(self.num_classes)
                    onehot[label] = 1.0
                else:
                    filename, label = line_split
                    label = int(label)
                if self.data_prefix is not None:
                    filename = osp.join(self.data_prefix, filename)
                video_infos.append(
                    dict(
                        filename=filename,
                        label=onehot if self.multi_class else label))
        return video_infos
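
    # Illustration (not part of the original file; paths are hypothetical):
    # given annotation lines like those in the class docstring,
    # ``load_annotations`` above produces entries such as
    #   'some/path/000.mp4 1'    ->  dict(filename='<data_prefix>/some/path/000.mp4', label=1)
    # and, with ``multi_class=True`` and ``num_classes=4``, a line like
    #   'some/path/000.mp4 1 3'  ->  dict(filename=..., label=tensor([0., 1., 0., 1.]))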

    def evaluate(self,
                 results,
                 metrics='top_k_accuracy',
                 topk=(1, 5),
                 logger=None):
        """Evaluate the dataset with the given results.

        Args:
            results (list): Output results.
            metrics (str | sequence[str]): Metrics to be performed.
                Default: 'top_k_accuracy'.
            topk (int | tuple[int]): K values for the ``top_k_accuracy``
                metric. Default: (1, 5).
            logger (logging.Logger | None): Logger for recording.
                Default: None.

        Returns:
            dict: Evaluation results dict.
        """
        if not isinstance(results, list):
            raise TypeError(f'results must be a list, but got {type(results)}')
        assert len(results) == len(self), (
            f'The length of results is not equal to the dataset len: '
            f'{len(results)} != {len(self)}')

        if not isinstance(topk, (int, tuple)):
            raise TypeError(
                f'topk must be int or tuple of int, but got {type(topk)}')
        if isinstance(topk, int):
            # Normalize a single int to a tuple so it can be zipped with the
            # per-k accuracies below.
            topk = (topk, )

        metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
        allowed_metrics = ['top_k_accuracy', 'mean_class_accuracy']
        for metric in metrics:
            if metric not in allowed_metrics:
                raise KeyError(f'metric {metric} is not supported')

        eval_results = {}
        gt_labels = [ann['label'] for ann in self.video_infos]

        for metric in metrics:
            msg = f'Evaluating {metric}...'
            if logger is None:
                msg = '\n' + msg
            print_log(msg, logger=logger)

            if metric == 'top_k_accuracy':
                top_k_acc = top_k_accuracy(results, gt_labels, topk)
                log_msg = []
                for k, acc in zip(topk, top_k_acc):
                    eval_results[f'top{k}_acc'] = acc
                    log_msg.append(f'\ntop{k}_acc\t{acc:.4f}')
                log_msg = ''.join(log_msg)
                print_log(log_msg, logger=logger)
                continue

            if metric == 'mean_class_accuracy':
                mean_acc = mean_class_accuracy(results, gt_labels)
                eval_results['mean_class_accuracy'] = mean_acc
                log_msg = f'\nmean_acc\t{mean_acc:.4f}'
                print_log(log_msg, logger=logger)
                continue

        return eval_results
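

# ----------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hedged example of
# constructing this dataset directly and exercising ``load_annotations`` and
# ``evaluate``. It assumes the mmaction2-style ``BaseDataset`` signature used
# above (an empty pipeline is accepted) and two action classes; the annotation
# lines and video paths are placeholders, and the videos are never decoded
# because the dataset is only evaluated, not indexed. Run it as a module
# (e.g. ``python -m <package>.datasets.video_dataset``) so the relative
# imports at the top of this file resolve.
if __name__ == '__main__':
    import tempfile

    import numpy as np

    # Write a tiny two-line annotation file: "<video path> <label>" per line.
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write('some/path/000.mp4 0\nsome/path/001.mp4 1\n')
        ann_file = f.name

    dataset = VideoDataset(ann_file, pipeline=[])

    # One score vector per video, in dataset order; random scores here simply
    # exercise the metric code and are not meaningful predictions.
    results = [np.random.rand(2) for _ in range(len(dataset))]
    print(dataset.evaluate(results,
                           metrics=['top_k_accuracy', 'mean_class_accuracy']))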