update mmrotate/datasets
Challyfilio committed Mar 18, 2024
1 parent 6296480 commit 9a6b9e4
Showing 3 changed files with 768 additions and 1 deletion.
4 changes: 3 additions & 1 deletion mmrotate/datasets/__init__.py
@@ -4,5 +4,7 @@
 from .hrsc import HRSCDataset  # noqa: F401, F403
 from .pipelines import *  # noqa: F401, F403
 from .sar import SARDataset  # noqa: F401, F403
+from .fair import FairDataset
+from .df2023 import DF2023Dataset
 
-__all__ = ['SARDataset', 'DOTADataset', 'build_dataset', 'HRSCDataset']
+__all__ = ['SARDataset', 'DOTADataset', 'build_dataset', 'HRSCDataset', 'FairDataset', 'DF2023Dataset']
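
With these imports in place, DF2023Dataset is registered via ROTATED_DATASETS (see df2023.py below) and can be selected by name in an mmrotate config. A minimal, hypothetical config fragment for illustration; the data paths, angle version and empty pipeline are placeholders, not part of this commit:

# hypothetical config fragment -- paths and pipeline are placeholders
dataset_type = 'DF2023Dataset'
data_root = 'data/DF2023/'
train_dataset = dict(
    type=dataset_type,
    ann_file=data_root + 'train/annfiles/',
    img_prefix=data_root + 'train/images/',
    version='le90',  # must match the angle version used by the model config
    pipeline=[])     # replace with the real training pipeline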
374 changes: 374 additions & 0 deletions mmrotate/datasets/df2023.py
@@ -0,0 +1,374 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import re
import tempfile
import time
import zipfile
from collections import defaultdict
from functools import partial

import mmcv
import numpy as np
import torch
from mmcv.ops import nms_rotated
from mmdet.datasets.custom import CustomDataset

from mmrotate.core import eval_rbbox_map, obb2poly_np, poly2obb_np
from .builder import ROTATED_DATASETS


@ROTATED_DATASETS.register_module()
class DF2023Dataset(CustomDataset):
"""DOTA dataset for detection.
Args:
ann_file (str): Annotation file path.
pipeline (list[dict]): Processing pipeline.
version (str, optional): Angle representations. Defaults to 'oc'.
difficulty (bool, optional): The difficulty threshold of GT.
"""
# CLASSES = ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
# 'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
# 'basketball-court', 'storage-tank', 'soccer-ball-field',
# 'roundabout', 'harbor', 'swimming-pool', 'helicopter')

# CLASSES = (
# 'S1', 'P1', 'T4', 'P3', 'B2', 'T3', 'P2', 'D1', 'C1', 'W1', 'H1', 'F1', 'S2', 'T2', 'S4', 'C3', 'A4', 'C2', 'B1',
# 'M4', 'T1', 'S3', 'M1', 'E1', 'T5', 'M3', 'L2', 'T6', 'M2', 'A2', 'E2', 'R2', 'A3', 'R3', 'A1', 'E3', 'R1', 'A5',
# 'L1')

CLASSES = (
'F5', 'P7', 'F2', 'W1', 'S4', 'T1', 'C14', 'B3', 'A7', 'A8', 'C2', 'P3', 'F8', 'C8', 'W2', 'S7', 'C13', 'T7', 'L3',
'Y1', 'M2', 'S5', 'V1', 'T2', 'S6', 'C10', 'S1', 'R2', 'D2', 'V2', 'C9', 'P2', 'H1', 'U2', 'H3', 'N1', 'T5', 'A9',
'D1', 'C6', 'C5', 'T8', 'P5', 'K2', 'P4', 'H2', 'A3', 'B1', 'E2', 'K3', 'C12', 'C15', 'L4', 'S2', 'R1', 'W3', 'T9',
'C11', 'M5', 'E4', 'R3', 'F7', 'U1', 'C3', 'K1', 'M1', 'A6', 'F3', 'E3', 'C1', 'B2', 'T6', 'P1', 'K5', 'K4', 'A4',
'L2', 'C16', 'S3', 'C4', 'A5', 'I1', 'A1', 'E1', 'P6', 'F6', 'C7', 'M4', 'F1', 'T10', 'T3', 'L1', 'Z1', 'A2', 'T4',
'M3', 'R4', 'T11')

PALETTE = [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
(138, 43, 226), (255, 128, 0), (255, 0, 255), (0, 255, 255),
(255, 193, 193), (0, 51, 153), (255, 250, 205), (0, 139, 139),
(255, 255, 0), (147, 116, 116), (0, 0, 255)]
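    # Note: PALETTE is carried over from the 15-class DOTA dataset and only
    # provides colors for the first 15 of the classes listed above; it is
    # used for visualization only.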

def __init__(self,
ann_file,
pipeline,
version='oc',
difficulty=100,
**kwargs):
self.version = version
self.difficulty = difficulty

super(DF2023Dataset, self).__init__(ann_file, pipeline, **kwargs)

def __len__(self):
"""Total number of samples of data."""
return len(self.data_infos)

def load_annotations(self, ann_folder):
"""
Args:
ann_folder: folder that contains DOTA v1 annotations txt files
"""
cls_map = {c: i
for i, c in enumerate(self.CLASSES)
} # in mmdet v2.0 label is 0-based
ann_files = glob.glob(ann_folder + '/*.txt')
data_infos = []
if not ann_files: # test phase
ann_files = glob.glob(ann_folder + '/*.png')
for ann_file in ann_files:
data_info = {}
img_id = osp.split(ann_file)[1][:-4]
img_name = img_id + '.png'
data_info['filename'] = img_name
data_info['ann'] = {}
data_info['ann']['bboxes'] = []
data_info['ann']['labels'] = []
data_infos.append(data_info)
else:
for ann_file in ann_files:
data_info = {}
img_id = osp.split(ann_file)[1][:-4]
img_name = img_id + '.png'
data_info['filename'] = img_name
data_info['ann'] = {}
gt_bboxes = []
gt_labels = []
gt_polygons = []
gt_bboxes_ignore = []
gt_labels_ignore = []
gt_polygons_ignore = []

if os.path.getsize(ann_file) == 0 and self.filter_empty_gt:
continue

with open(ann_file) as f:
s = f.readlines()
for si in s:
bbox_info = si.split()
poly = np.array(bbox_info[:8], dtype=np.float32)
try:
x, y, w, h, a = poly2obb_np(poly, self.version)
except: # noqa: E722
continue
cls_name = bbox_info[8]
difficulty = int(bbox_info[9])
label = cls_map[cls_name]
                        if difficulty <= self.difficulty:
                            gt_bboxes.append([x, y, w, h, a])
                            gt_labels.append(label)
                            gt_polygons.append(poly)

if gt_bboxes:
data_info['ann']['bboxes'] = np.array(
gt_bboxes, dtype=np.float32)
data_info['ann']['labels'] = np.array(
gt_labels, dtype=np.int64)
data_info['ann']['polygons'] = np.array(
gt_polygons, dtype=np.float32)
else:
data_info['ann']['bboxes'] = np.zeros((0, 5),
dtype=np.float32)
data_info['ann']['labels'] = np.array([], dtype=np.int64)
data_info['ann']['polygons'] = np.zeros((0, 8),
dtype=np.float32)

if gt_polygons_ignore:
data_info['ann']['bboxes_ignore'] = np.array(
gt_bboxes_ignore, dtype=np.float32)
data_info['ann']['labels_ignore'] = np.array(
gt_labels_ignore, dtype=np.int64)
data_info['ann']['polygons_ignore'] = np.array(
gt_polygons_ignore, dtype=np.float32)
else:
data_info['ann']['bboxes_ignore'] = np.zeros(
(0, 5), dtype=np.float32)
data_info['ann']['labels_ignore'] = np.array(
[], dtype=np.int64)
data_info['ann']['polygons_ignore'] = np.zeros(
(0, 8), dtype=np.float32)

data_infos.append(data_info)

self.img_ids = [*map(lambda x: x['filename'][:-4], data_infos)]
return data_infos

def _filter_imgs(self):
"""Filter images without ground truths."""
valid_inds = []
for i, data_info in enumerate(self.data_infos):
if (not self.filter_empty_gt
or data_info['ann']['labels'].size > 0):
valid_inds.append(i)
return valid_inds

def _set_group_flag(self):
"""Set flag according to image aspect ratio.
All set to 0.
"""
self.flag = np.zeros(len(self), dtype=np.uint8)

def evaluate(self,
results,
metric='mAP',
logger=None,
proposal_nums=(100, 300, 1000),
iou_thr=0.5,
scale_ranges=None,
nproc=4):
"""Evaluate the dataset.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
proposal_nums (Sequence[int]): Proposal number used for evaluating
recalls, such as recall@100, recall@1000.
Default: (100, 300, 1000).
iou_thr (float | list[float]): IoU threshold. It must be a float
when evaluating mAP, and can be a list when evaluating recall.
Default: 0.5.
scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
Default: None.
nproc (int): Processes used for computing TP and FP.
Default: 4.
"""
nproc = min(nproc, os.cpu_count())
if not isinstance(metric, str):
assert len(metric) == 1
metric = metric[0]
allowed_metrics = ['mAP']
if metric not in allowed_metrics:
raise KeyError(f'metric {metric} is not supported')
annotations = [self.get_ann_info(i) for i in range(len(self))]
eval_results = {}
if metric == 'mAP':
assert isinstance(iou_thr, float)
mean_ap, _ = eval_rbbox_map(
results,
annotations,
scale_ranges=scale_ranges,
iou_thr=iou_thr,
dataset=self.CLASSES,
logger=logger,
nproc=nproc)
eval_results['mAP'] = mean_ap
else:
raise NotImplementedError

return eval_results

def merge_det(self, results, nproc=4):
"""Merging patch bboxes into full image.
Args:
results (list): Testing results of the dataset.
            nproc (int): Number of processes. Default: 4.
"""
collector = defaultdict(list)
for idx in range(len(self)):
result = results[idx]
img_id = self.img_ids[idx]
splitname = img_id.split('__')
oriname = splitname[0]
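            # The trailing '__<x>___<y>' suffix in img_id (added by the image
            # splitting step) encodes the patch's top-left offset in the
            # original image; it is recovered with the regex below.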
pattern1 = re.compile(r'__\d+___\d+')
x_y = re.findall(pattern1, img_id)
x_y_2 = re.findall(r'\d+', x_y[0])
x, y = int(x_y_2[0]), int(x_y_2[1])
new_result = []
for i, dets in enumerate(result):
bboxes, scores = dets[:, :-1], dets[:, [-1]]
ori_bboxes = bboxes.copy()
ori_bboxes[..., :2] = ori_bboxes[..., :2] + np.array(
[x, y], dtype=np.float32)
labels = np.zeros((bboxes.shape[0], 1)) + i
new_result.append(
np.concatenate([labels, ori_bboxes, scores], axis=1))

new_result = np.concatenate(new_result, axis=0)
collector[oriname].append(new_result)

merge_func = partial(_merge_func, CLASSES=self.CLASSES, iou_thr=0.1)
if nproc <= 1:
print('Single processing')
merged_results = mmcv.track_iter_progress(
(map(merge_func, collector.items()), len(collector)))
else:
print('Multiple processing')
merged_results = mmcv.track_parallel_progress(
merge_func, list(collector.items()), nproc)

return zip(*merged_results)

def _results2submission(self, id_list, dets_list, out_folder=None):
"""Generate the submission of full images.
Args:
            id_list (list): Ids of the full images.
            dets_list (list): Per-class detection results for each image.
            out_folder (str, optional): Folder to write the submission
                files to.
"""
        if osp.exists(out_folder):
            raise ValueError(f'out_folder must be a non-existent path, '
                             f'but {out_folder} already exists')
os.makedirs(out_folder)

files = [
osp.join(out_folder, 'Task1_' + cls + '.txt')
for cls in self.CLASSES
]
file_objs = [open(f, 'w') for f in files]
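        # Each submission line follows the DOTA Task1 format:
        #   <img_id> <score> <x1> <y1> <x2> <y2> <x3> <y3> <x4> <y4>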
for img_id, dets_per_cls in zip(id_list, dets_list):
for f, dets in zip(file_objs, dets_per_cls):
if dets.size == 0:
continue
bboxes = obb2poly_np(dets, self.version)
for bbox in bboxes:
txt_element = [img_id, str(bbox[-1])
] + [f'{p:.2f}' for p in bbox[:-1]]
f.writelines(' '.join(txt_element) + '\n')

for f in file_objs:
f.close()

target_name = osp.split(out_folder)[-1]
with zipfile.ZipFile(
osp.join(out_folder, target_name + '.zip'), 'w',
zipfile.ZIP_DEFLATED) as t:
for f in files:
t.write(f, osp.split(f)[-1])

return files

def format_results(self, results, submission_dir=None, nproc=4, **kwargs):
"""Format the results to submission text (standard format for DOTA
evaluation).
Args:
results (list): Testing results of the dataset.
submission_dir (str, optional): The folder that contains submission
files. If not specified, a temp folder will be created.
Default: None.
            nproc (int, optional): Number of processes. Default: 4.

        Returns:
            tuple:
                - result_files (list): Paths of the generated submission
                  txt files.
                - tmp_dir (TemporaryDirectory | None): The temporary
                  directory created for saving submission files when
                  ``submission_dir`` is not specified.
        """
nproc = min(nproc, os.cpu_count())
assert isinstance(results, list), 'results must be a list'
assert len(results) == len(self), (
f'The length of results is not equal to '
f'the dataset len: {len(results)} != {len(self)}')
        if submission_dir is None:
            tmp_dir = tempfile.TemporaryDirectory()
            submission_dir = osp.join(tmp_dir.name, 'results')
        else:
            tmp_dir = None

print('\nMerging patch bboxes into full image!!!')
start_time = time.time()
id_list, dets_list = self.merge_det(results, nproc)
stop_time = time.time()
print(f'Used time: {(stop_time - start_time):.1f} s')

result_files = self._results2submission(id_list, dets_list,
submission_dir)

return result_files, tmp_dir


def _merge_func(info, CLASSES, iou_thr):
"""Merging patch bboxes into full image.
Args:
CLASSES (list): Label category.
iou_thr (float): Threshold of IoU.
"""
img_id, label_dets = info
label_dets = np.concatenate(label_dets, axis=0)

labels, dets = label_dets[:, 0], label_dets[:, 1:]

big_img_results = []
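    # For each class, run rotated NMS over the detections collected from all
    # patches so that duplicate boxes from overlapping crops are suppressed.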
for i in range(len(CLASSES)):
if len(dets[labels == i]) == 0:
big_img_results.append(dets[labels == i])
else:
try:
cls_dets = torch.from_numpy(dets[labels == i]).cuda()
except: # noqa: E722
cls_dets = torch.from_numpy(dets[labels == i])
nms_dets, keep_inds = nms_rotated(cls_dets[:, :5], cls_dets[:, -1],
iou_thr)
big_img_results.append(nms_dets.cpu().numpy())
return img_id, big_img_results
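
For reference, a small standalone sketch (not part of the commit; the paths are placeholders and an installed mmrotate is assumed) that builds the new dataset through the registry; the empty pipeline only exercises annotation loading:

# standalone sketch -- placeholder paths, assumes mmrotate is installed
from mmrotate.datasets import build_dataset

dataset = build_dataset(
    dict(
        type='DF2023Dataset',
        ann_file='data/DF2023/train/annfiles/',
        img_prefix='data/DF2023/train/images/',
        pipeline=[],      # only annotation loading is exercised here
        version='le90'))  # or 'oc' / 'le135', matching the model config
print(f'{len(dataset)} images, {len(dataset.CLASSES)} classes')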