Commit 9a6b9e4 (1 parent: 6296480). Showing 3 changed files with 768 additions and 1 deletion.
@@ -0,0 +1,374 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import re
import tempfile
import time
import zipfile
from collections import defaultdict
from functools import partial

import mmcv
import numpy as np
import torch
from mmcv.ops import nms_rotated
from mmdet.datasets.custom import CustomDataset

from mmrotate.core import eval_rbbox_map, obb2poly_np, poly2obb_np
from .builder import ROTATED_DATASETS


@ROTATED_DATASETS.register_module()
class DF2023Dataset(CustomDataset):
    """DF2023 dataset for rotated object detection (DOTA-format annotations).

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        version (str, optional): Angle representation. Defaults to 'oc'.
        difficulty (int, optional): Difficulty threshold of GT; annotations
            with a higher difficulty are ignored. Defaults to 100.
    """
    # CLASSES = ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    #            'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    #            'basketball-court', 'storage-tank', 'soccer-ball-field',
    #            'roundabout', 'harbor', 'swimming-pool', 'helicopter')

    # CLASSES = (
    #     'S1', 'P1', 'T4', 'P3', 'B2', 'T3', 'P2', 'D1', 'C1', 'W1', 'H1',
    #     'F1', 'S2', 'T2', 'S4', 'C3', 'A4', 'C2', 'B1', 'M4', 'T1', 'S3',
    #     'M1', 'E1', 'T5', 'M3', 'L2', 'T6', 'M2', 'A2', 'E2', 'R2', 'A3',
    #     'R3', 'A1', 'E3', 'R1', 'A5', 'L1')

||
CLASSES = ( | ||
'F5', 'P7', 'F2', 'W1', 'S4', 'T1', 'C14', 'B3', 'A7', 'A8', 'C2', 'P3', 'F8', 'C8', 'W2', 'S7', 'C13', 'T7', 'L3', | ||
'Y1', 'M2', 'S5', 'V1', 'T2', 'S6', 'C10', 'S1', 'R2', 'D2', 'V2', 'C9', 'P2', 'H1', 'U2', 'H3', 'N1', 'T5', 'A9', | ||
'D1', 'C6', 'C5', 'T8', 'P5', 'K2', 'P4', 'H2', 'A3', 'B1', 'E2', 'K3', 'C12', 'C15', 'L4', 'S2', 'R1', 'W3', 'T9', | ||
'C11', 'M5', 'E4', 'R3', 'F7', 'U1', 'C3', 'K1', 'M1', 'A6', 'F3', 'E3', 'C1', 'B2', 'T6', 'P1', 'K5', 'K4', 'A4', | ||
'L2', 'C16', 'S3', 'C4', 'A5', 'I1', 'A1', 'E1', 'P6', 'F6', 'C7', 'M4', 'F1', 'T10', 'T3', 'L1', 'Z1', 'A2', 'T4', | ||
'M3', 'R4', 'T11') | ||
|
||
PALETTE = [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), | ||
(138, 43, 226), (255, 128, 0), (255, 0, 255), (0, 255, 255), | ||
(255, 193, 193), (0, 51, 153), (255, 250, 205), (0, 139, 139), | ||
(255, 255, 0), (147, 116, 116), (0, 0, 255)] | ||
|
||
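    # NOTE: PALETTE lists 15 colours, fewer than the number of entries in
    # CLASSES above; visualisation code that picks one colour per class may
    # need an extended palette.
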
    def __init__(self,
                 ann_file,
                 pipeline,
                 version='oc',
                 difficulty=100,
                 **kwargs):
        self.version = version
        self.difficulty = difficulty

        super(DF2023Dataset, self).__init__(ann_file, pipeline, **kwargs)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_folder):
        """Load annotation from the annotation folder.

        Args:
            ann_folder (str): Folder that contains DOTA v1 annotation txt
                files.
        """
        cls_map = {c: i
                   for i, c in enumerate(self.CLASSES)
                   }  # in mmdet v2.0 label is 0-based
        ann_files = glob.glob(ann_folder + '/*.txt')
        data_infos = []
        if not ann_files:  # test phase
            ann_files = glob.glob(ann_folder + '/*.png')
            for ann_file in ann_files:
                data_info = {}
                img_id = osp.split(ann_file)[1][:-4]
                img_name = img_id + '.png'
                data_info['filename'] = img_name
                data_info['ann'] = {}
                data_info['ann']['bboxes'] = []
                data_info['ann']['labels'] = []
                data_infos.append(data_info)
        else:
            for ann_file in ann_files:
                data_info = {}
                img_id = osp.split(ann_file)[1][:-4]
                img_name = img_id + '.png'
                data_info['filename'] = img_name
                data_info['ann'] = {}
                gt_bboxes = []
                gt_labels = []
                gt_polygons = []
                gt_bboxes_ignore = []
                gt_labels_ignore = []
                gt_polygons_ignore = []

                if os.path.getsize(ann_file) == 0 and self.filter_empty_gt:
                    continue

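                # Each non-empty line is expected to follow the DOTA txt
                # format parsed below:
                #   x1 y1 x2 y2 x3 y3 x4 y4 class_name difficulty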
                with open(ann_file) as f:
                    s = f.readlines()
                    for si in s:
                        bbox_info = si.split()
                        poly = np.array(bbox_info[:8], dtype=np.float32)
                        try:
                            x, y, w, h, a = poly2obb_np(poly, self.version)
                        except:  # noqa: E722
                            continue
                        cls_name = bbox_info[8]
                        difficulty = int(bbox_info[9])
                        label = cls_map[cls_name]
                        if difficulty > self.difficulty:
                            pass
                        else:
                            gt_bboxes.append([x, y, w, h, a])
                            gt_labels.append(label)
                            gt_polygons.append(poly)

                if gt_bboxes:
                    data_info['ann']['bboxes'] = np.array(
                        gt_bboxes, dtype=np.float32)
                    data_info['ann']['labels'] = np.array(
                        gt_labels, dtype=np.int64)
                    data_info['ann']['polygons'] = np.array(
                        gt_polygons, dtype=np.float32)
                else:
                    data_info['ann']['bboxes'] = np.zeros((0, 5),
                                                          dtype=np.float32)
                    data_info['ann']['labels'] = np.array([], dtype=np.int64)
                    data_info['ann']['polygons'] = np.zeros((0, 8),
                                                            dtype=np.float32)

                if gt_polygons_ignore:
                    data_info['ann']['bboxes_ignore'] = np.array(
                        gt_bboxes_ignore, dtype=np.float32)
                    data_info['ann']['labels_ignore'] = np.array(
                        gt_labels_ignore, dtype=np.int64)
                    data_info['ann']['polygons_ignore'] = np.array(
                        gt_polygons_ignore, dtype=np.float32)
                else:
                    data_info['ann']['bboxes_ignore'] = np.zeros(
                        (0, 5), dtype=np.float32)
                    data_info['ann']['labels_ignore'] = np.array(
                        [], dtype=np.int64)
                    data_info['ann']['polygons_ignore'] = np.zeros(
                        (0, 8), dtype=np.float32)

                data_infos.append(data_info)

        self.img_ids = [*map(lambda x: x['filename'][:-4], data_infos)]
        return data_infos

    def _filter_imgs(self):
        """Filter images without ground truths."""
        valid_inds = []
        for i, data_info in enumerate(self.data_infos):
            if (not self.filter_empty_gt
                    or data_info['ann']['labels'].size > 0):
                valid_inds.append(i)
        return valid_inds

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        All set to 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None,
                 nproc=4):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. It must be a float
                when evaluating mAP, and can be a list when evaluating recall.
                Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating
                mAP. Default: None.
            nproc (int): Processes used for computing TP and FP.
                Default: 4.
        """
        nproc = min(nproc, os.cpu_count())
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = {}
        if metric == 'mAP':
            assert isinstance(iou_thr, float)
            mean_ap, _ = eval_rbbox_map(
                results,
                annotations,
                scale_ranges=scale_ranges,
                iou_thr=iou_thr,
                dataset=self.CLASSES,
                logger=logger,
                nproc=nproc)
            eval_results['mAP'] = mean_ap
        else:
            raise NotImplementedError

        return eval_results

    def merge_det(self, results, nproc=4):
        """Merging patch bboxes into full image.

        Args:
            results (list): Testing results of the dataset.
            nproc (int): number of process. Default: 4.
        """
        collector = defaultdict(list)
        for idx in range(len(self)):
            result = results[idx]
            img_id = self.img_ids[idx]
            splitname = img_id.split('__')
            oriname = splitname[0]
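            # img_id is assumed to encode the patch offset as
            # '...__<x>___<y>' (DOTA split convention); recover (x, y) so the
            # boxes can be shifted back to full-image coordinates.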
            pattern1 = re.compile(r'__\d+___\d+')
            x_y = re.findall(pattern1, img_id)
            x_y_2 = re.findall(r'\d+', x_y[0])
            x, y = int(x_y_2[0]), int(x_y_2[1])
            new_result = []
            for i, dets in enumerate(result):
                bboxes, scores = dets[:, :-1], dets[:, [-1]]
                ori_bboxes = bboxes.copy()
                ori_bboxes[..., :2] = ori_bboxes[..., :2] + np.array(
                    [x, y], dtype=np.float32)
                labels = np.zeros((bboxes.shape[0], 1)) + i
                new_result.append(
                    np.concatenate([labels, ori_bboxes, scores], axis=1))

            new_result = np.concatenate(new_result, axis=0)
            collector[oriname].append(new_result)

        merge_func = partial(_merge_func, CLASSES=self.CLASSES, iou_thr=0.1)
        if nproc <= 1:
            print('Single processing')
            merged_results = mmcv.track_iter_progress(
                (map(merge_func, collector.items()), len(collector)))
        else:
            print('Multiple processing')
            merged_results = mmcv.track_parallel_progress(
                merge_func, list(collector.items()), nproc)

        return zip(*merged_results)

    def _results2submission(self, id_list, dets_list, out_folder=None):
        """Generate the submission of full images.

        Args:
            id_list (list): Id of images.
            dets_list (list): Detection results per class.
            out_folder (str, optional): Folder of submission.
        """
        if osp.exists(out_folder):
            raise ValueError(f'The out_folder should be a non-existing path, '
                             f'but {out_folder} already exists')
        os.makedirs(out_folder)

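        # One Task1_<CLASS>.txt file per class; each written line is
        #   <img_id> <score> <x1> <y1> <x2> <y2> <x3> <y3> <x4> <y4>
        # which matches the standard DOTA Task1 submission layout.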
        files = [
            osp.join(out_folder, 'Task1_' + cls + '.txt')
            for cls in self.CLASSES
        ]
        file_objs = [open(f, 'w') for f in files]
        for img_id, dets_per_cls in zip(id_list, dets_list):
            for f, dets in zip(file_objs, dets_per_cls):
                if dets.size == 0:
                    continue
                bboxes = obb2poly_np(dets, self.version)
                for bbox in bboxes:
                    txt_element = [img_id, str(bbox[-1])
                                   ] + [f'{p:.2f}' for p in bbox[:-1]]
                    f.writelines(' '.join(txt_element) + '\n')

        for f in file_objs:
            f.close()

        target_name = osp.split(out_folder)[-1]
        with zipfile.ZipFile(
                osp.join(out_folder, target_name + '.zip'), 'w',
                zipfile.ZIP_DEFLATED) as t:
            for f in files:
                t.write(f, osp.split(f)[-1])

        return files

    def format_results(self, results, submission_dir=None, nproc=4, **kwargs):
        """Format the results to submission text (standard format for DOTA
        evaluation).

        Args:
            results (list): Testing results of the dataset.
            submission_dir (str, optional): The folder that contains
                submission files. If not specified, a temp folder will be
                created. Default: None.
            nproc (int, optional): Number of processes. Default: 4.

        Returns:
            tuple:

                - result_files (list): paths of the generated submission
                  files.
                - tmp_dir (tempfile.TemporaryDirectory | None): the temporary
                  directory created for saving submission files when
                  submission_dir is not specified.
        """
        nproc = min(nproc, os.cpu_count())
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            f'The length of results is not equal to '
            f'the dataset len: {len(results)} != {len(self)}')
        if submission_dir is None:
            # keep a handle on the temp dir so the caller can clean it up,
            # and write the submission into a not-yet-existing sub-folder
            # (_results2submission refuses an existing path)
            tmp_dir = tempfile.TemporaryDirectory()
            submission_dir = osp.join(tmp_dir.name, 'submission')
        else:
            tmp_dir = None

        print('\nMerging patch bboxes into full image!!!')
        start_time = time.time()
        id_list, dets_list = self.merge_det(results, nproc)
        stop_time = time.time()
        print(f'Used time: {(stop_time - start_time):.1f} s')

        result_files = self._results2submission(id_list, dets_list,
                                                submission_dir)

        return result_files, tmp_dir


def _merge_func(info, CLASSES, iou_thr):
    """Merging patch bboxes into full image.

    Args:
        info (tuple): (img_id, list of label/bbox arrays per patch).
        CLASSES (list): Label category.
        iou_thr (float): Threshold of IoU.
    """
    img_id, label_dets = info
    label_dets = np.concatenate(label_dets, axis=0)

    labels, dets = label_dets[:, 0], label_dets[:, 1:]

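    # Per-class rotated NMS; boxes are moved to the GPU when one is available
    # and fall back to CPU otherwise (the bare except below covers the
    # no-CUDA case).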
    big_img_results = []
    for i in range(len(CLASSES)):
        if len(dets[labels == i]) == 0:
            big_img_results.append(dets[labels == i])
        else:
            try:
                cls_dets = torch.from_numpy(dets[labels == i]).cuda()
            except:  # noqa: E722
                cls_dets = torch.from_numpy(dets[labels == i])
            nms_dets, keep_inds = nms_rotated(cls_dets[:, :5], cls_dets[:, -1],
                                              iou_thr)
            big_img_results.append(nms_dets.cpu().numpy())
    return img_id, big_img_results