Skip to content

Commit

Permalink
add humanart: hrenet-w32/48 vitpose-h/l rtmpose-t
Browse files Browse the repository at this point in the history
  • Loading branch information
juxuan27 committed Jun 26, 2023
1 parent 5b143d9 commit ec26fca
Show file tree
Hide file tree
Showing 11 changed files with 1,104 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
_base_ = ['../../../_base_/default_runtime.py']

# runtime
max_epochs = 420
stage2_num_epochs = 30
base_lr = 4e-3

train_cfg = dict(max_epochs=max_epochs, val_interval=10)
randomness = dict(seed=21)

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.),
paramwise_cfg=dict(
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
dict(
type='LinearLR',
start_factor=1.0e-5,
by_epoch=False,
begin=0,
end=1000),
dict(
# use cosine lr from 210 to 420 epoch
type='CosineAnnealingLR',
eta_min=base_lr * 0.05,
begin=max_epochs // 2,
end=max_epochs,
T_max=max_epochs // 2,
by_epoch=True,
convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=1024)

# codec settings
codec = dict(
type='SimCCLabel',
input_size=(192, 256),
sigma=(4.9, 5.66),
simcc_split_ratio=2.0,
normalize=False,
use_dark=False)

# model settings
model = dict(
type='TopdownPoseEstimator',
data_preprocessor=dict(
type='PoseDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True),
backbone=dict(
_scope_='mmdet',
type='CSPNeXt',
arch='P5',
expand_ratio=0.5,
deepen_factor=0.167,
widen_factor=0.375,
out_indices=(4, ),
channel_attention=True,
norm_cfg=dict(type='SyncBN'),
act_cfg=dict(type='SiLU'),
init_cfg=dict(
type='Pretrained',
prefix='backbone.',
checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
'rtmpose/cspnext-tiny_udp-aic-coco_210e-256x192-cbed682d_20230130.pth' # noqa
)),
head=dict(
type='RTMCCHead',
in_channels=384,
out_channels=17,
input_size=codec['input_size'],
in_featuremap_size=(6, 8),
simcc_split_ratio=codec['simcc_split_ratio'],
final_layer_kernel_size=7,
gau_cfg=dict(
hidden_dims=256,
s=128,
expansion_factor=2,
dropout_rate=0.,
drop_path=0.,
act_fn='SiLU',
use_rel_bias=False,
pos_enc=False),
loss=dict(
type='KLDiscretLoss',
use_target_weight=True,
beta=10.,
label_softmax=True),
decoder=codec),
test_cfg=dict(flip_test=True))

# base dataset settings
dataset_type = 'HumanArtDataset'
data_mode = 'topdown'
data_root = 'data/'

backend_args = dict(backend='local')
# backend_args = dict(
# backend='petrel',
# path_mapping=dict({
# f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
# f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
# }))

# pipelines
train_pipeline = [
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(
type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(
type='Albumentation',
transforms=[
dict(type='Blur', p=0.1),
dict(type='MedianBlur', p=0.1),
dict(
type='CoarseDropout',
max_holes=1,
max_height=0.4,
max_width=0.4,
min_holes=1,
min_height=0.2,
min_width=0.2,
p=1.),
]),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]
val_pipeline = [
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='PackPoseInputs')
]

train_pipeline_stage2 = [
dict(type='LoadImage', backend_args=backend_args),
dict(type='GetBBoxCenterScale'),
dict(type='RandomFlip', direction='horizontal'),
dict(type='RandomHalfBody'),
dict(
type='RandomBBoxTransform',
shift_factor=0.,
scale_factor=[0.75, 1.25],
rotate_factor=60),
dict(type='TopdownAffine', input_size=codec['input_size']),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(
type='Albumentation',
transforms=[
dict(type='Blur', p=0.1),
dict(type='MedianBlur', p=0.1),
dict(
type='CoarseDropout',
max_holes=1,
max_height=0.4,
max_width=0.4,
min_holes=1,
min_height=0.2,
min_width=0.2,
p=0.5),
]),
dict(type='GenerateTarget', encoder=codec),
dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
batch_size=256,
num_workers=10,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='HumanArt/annotations/training_humanart_coco.json',
data_prefix=dict(img=''),
pipeline=train_pipeline,
))
val_dataloader = dict(
batch_size=64,
num_workers=10,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_mode=data_mode,
ann_file='HumanArt/annotations/validation_humanart.json',
# bbox_file=f'{data_root}HumanArt/person_detection_results/'
# 'HumanArt_validation_detections_AP_H_56_person.json',
data_prefix=dict(img=''),
test_mode=True,
pipeline=val_pipeline,
))
test_dataloader = val_dataloader

# hooks
default_hooks = dict(
checkpoint=dict(save_best='coco/AP', rule='greater', max_keep_ckpts=1))

custom_hooks = [
# Turn off EMA while training the tiny model
# dict(
# type='EMAHook',
# ema_type='ExpMomentumEMA',
# momentum=0.0002,
# update_buffers=True,
# priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=max_epochs - stage2_num_epochs,
switch_pipeline=train_pipeline_stage2)
]

# evaluators
val_evaluator = dict(
type='CocoMetric',
ann_file=data_root + 'HumanArt/annotations/validation_humanart.json')
test_evaluator = val_evaluator
7 changes: 7 additions & 0 deletions configs/body_2d_keypoint/rtmpose/humanart/rtmpose_humanart.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ Results on Human-Art validation dataset with detector having human AP of 56.2 on

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
| [rtmpose-t-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 0.161 | 0.283 | 0.154 | 0.221 | 0.373 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-tiny_simcc-coco_pt-aic-coco_420e-256x192-e613ba3f_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-tiny_simcc-coco_pt-aic-coco_420e-256x192-e613ba3f_20230127.json) |
| [rtmpose-t-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-t_8xb256-420e_humanart-256x192.py) | 256x192 | 0.249 | 0.395 | 0.256 | 0.323 | 0.485 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.json) |
| [rtmpose-s-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-s_8xb256-420e_coco-256x192.py) | 256x192 | 0.199 | 0.328 | 0.198 | 0.261 | 0.418 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.json) |
| [rtmpose-s-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-s_8xb256-420e_humanart-256x192.py) | 256x192 | 0.311 | 0.462 | 0.323 | 0.381 | 0.540 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.json) |
| [rtmpose-m-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 0.239 | 0.372 | 0.243 | 0.302 | 0.455 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-coco_pt-aic-coco_420e-256x192-d8dd5ca4_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-coco_pt-aic-coco_420e-256x192-d8dd5ca4_20230127.json) |
Expand All @@ -83,6 +85,8 @@ Results on Human-Art validation dataset with ground-truth bounding-box

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
| [rtmpose-t-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 0.444 | 0.725 | 0.453 | 0.488 | 0.750 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-tiny_simcc-coco_pt-aic-coco_420e-256x192-e613ba3f_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-tiny_simcc-coco_pt-aic-coco_420e-256x192-e613ba3f_20230127.json) |
| [rtmpose-t-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-t_8xb256-420e_humanart-256x192.py) | 256x192 | 0.655 | 0.872 | 0.720 | 0.693 | 0.890 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.json) |
| [rtmpose-s-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-s_8xb256-420e_coco-256x192.py) | 256x192 | 0.480 | 0.739 | 0.498 | 0.521 | 0.763 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.json) |
| [rtmpose-s-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-s_8xb256-420e_humanart-256x192.py) | 256x192 | 0.698 | 0.893 | 0.768 | 0.732 | 0.903 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.json) |
| [rtmpose-m-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 0.532 | 0.765 | 0.563 | 0.571 | 0.789 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-coco_pt-aic-coco_420e-256x192-d8dd5ca4_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-coco_pt-aic-coco_420e-256x192-d8dd5ca4_20230127.json) |
Expand All @@ -94,6 +98,8 @@ Results on COCO val2017 with detector having human AP of 56.4 on COCO val2017 da

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
| [rtmpose-t-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-t_8xb256-420e_coco-256x192.py) | 256x192 | 0.682 | 0.883 | 0.759 | 0.736 | 0.920 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-tiny_simcc-coco_pt-aic-coco_420e-256x192-e613ba3f_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-tiny_simcc-coco_pt-aic-coco_420e-256x192-e613ba3f_20230127.json) |
| [rtmpose-t-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-t_8xb256-420e_humanart-256x192.py) | 256x192 | 0.665 | 0.875 | 0.739 | 0.721 | 0.916 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.json) |
| [rtmpose-s-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-s_8xb256-420e_coco-256x192.py) | 256x192 | 0.716 | 0.892 | 0.789 | 0.768 | 0.929 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_simcc-coco_pt-aic-coco_420e-256x192-8edcf0d7_20230127.json) |
| [rtmpose-s-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-s_8xb256-420e_humanart-256x192.py) | 256x192 | 0.706 | 0.888 | 0.780 | 0.759 | 0.928 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.json) |
| [rtmpose-m-coco](/configs/body_2d_keypoint/rtmpose/coco/rtmpose-m_8xb256-420e_coco-256x192.py) | 256x192 | 0.746 | 0.899 | 0.817 | 0.795 | 0.935 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-coco_pt-aic-coco_420e-256x192-d8dd5ca4_20230127.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_simcc-coco_pt-aic-coco_420e-256x192-d8dd5ca4_20230127.json) |
Expand All @@ -105,6 +111,7 @@ Results on COCO val2017 with ground-truth bounding box

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
| [rtmpose-t-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-t_8xb256-420e_humanart-256x192.py) | 256x192 | 0.679 | 0.895 | 0.755 | 0.710 | 0.907 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.json) |
| [rtmpose-s-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-s_8xb256-420e_humanart-256x192.py) | 256x192 | 0.725 | 0.916 | 0.798 | 0.753 | 0.925 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.json) |
| [rtmpose-m-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-m_8xb256-420e_humanart-256x192.py) | 256x192 | 0.744 | 0.916 | 0.818 | 0.770 | 0.930 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-m_8xb256-420e_humanart-256x192-8430627b_20230611.json) |
| [rtmpose-l-humanart-coco](/configs/body_2d_keypoint/rtmpose/humanart/rtmpose-l_8xb256-420e_humanart-256x192.py) | 256x192 | 0.770 | 0.927 | 0.840 | 0.794 | 0.939 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-l_8xb256-420e_humanart-256x192-389f2cb0_20230611.json) |
32 changes: 32 additions & 0 deletions configs/body_2d_keypoint/rtmpose/humanart/rtmpose_humanart.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,35 @@ Models:
AR@0.5: 0.903
Task: Body 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-s_8xb256-420e_humanart-256x192-5a3ac943_20230611.pth
- Config: configs/body_2d_keypoint/rtmpose/humanart/rtmpose-t_8xb256-420e_humanart-256x192.py
In Collection: RTMPose
Metadata:
Architecture: *id001
Training Data: *id002
Name: rtmpose-t_8xb256-420e_humanart-256x192
Results:
- Dataset: COCO
Metrics:
AP: 0.665
AP@0.5: 0.875
AP@0.75: 0.739
AR: 0.721
AR@0.5: 0.916
Task: Body 2D Keypoint
- Dataset: Human-Art
Metrics:
AP: 0.249
AP@0.5: 0.395
AP@0.75: 0.256
AR: 0.323
AR@0.5: 0.485
Task: Body 2D Keypoint
- Dataset: Human-Art(GT)
Metrics:
AP: 0.655
AP@0.5: 0.872
AP@0.75: 0.720
AR: 0.693
AR@0.5: 0.890
Task: Body 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/rtmpose-t_8xb256-420e_humanart-256x192-60b68c98_20230612.pth
Loading

0 comments on commit ec26fca

Please sign in to comment.