Commit

add transforms and full config

LareinaM committed Jul 10, 2023
1 parent fcf1ec6 commit 0d3b47a
Showing 4 changed files with 83 additions and 8 deletions.
@@ -7,11 +7,16 @@
     type='Pose3dLocalVisualizer', vis_backends=vis_backends, name='visualizer')
 
 # runtime
-train_cfg = None
+train_cfg = dict(max_epochs=120, val_interval=10)
 
 # optimizer
 optim_wrapper = dict(
     optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.01))
 
+# learning policy
+param_scheduler = [
+    dict(type='ExponentialLR', gamma=0.99, end=120, by_epoch=True)
+]
+
 auto_scale_lr = dict(base_batch_size=512)
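Note: the ExponentialLR policy above multiplies the learning rate by gamma once per epoch, so over the 120-epoch run it decays to roughly 30% of the base value. A quick plain-Python check of the implied schedule (assuming the standard per-epoch multiplicative decay; values rounded):

base_lr, gamma = 2e-4, 0.99
for epoch in (0, 60, 119):
    print(epoch, base_lr * gamma**epoch)
# 0   -> 2.0e-04
# 60  -> ~1.1e-04
# 119 -> ~6.1e-05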

@@ -57,6 +62,18 @@
 data_root = 'data/h36m/'
 
 # pipelines
+train_pipeline = [
+    dict(
+        type='RandomFlipAroundRoot',
+        keypoints_flip_cfg={},
+        target_flip_cfg={},
+        flip_image=True),
+    dict(type='GenerateTarget', encoder=codec),
+    dict(
+        type='PackPoseInputs',
+        meta_keys=('id', 'category_id', 'target_img_path', 'flip_indices',
+                   'factor', 'camera_param'))
+]
 val_pipeline = [
     dict(type='GenerateTarget', encoder=codec),
     dict(
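Note: train_pipeline is a list of transform configs that the dataset applies in order to each sample's results dict (RandomFlipAroundRoot for flip augmentation, GenerateTarget to encode supervision targets from codec, PackPoseInputs to pack tensors plus the listed meta keys). A minimal sketch of that compose pattern, assuming each config has already been built into a callable (mmpose does this through its TRANSFORMS registry):

def apply_pipeline(results, transforms):
    # Apply each built transform in order; returning None drops the sample.
    for t in transforms:
        results = t(results)
        if results is None:
            return None
    return results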
@@ -66,9 +83,27 @@
 ]
 
 # data loaders
+train_dataloader = dict(
+    batch_size=32,
+    prefetch_factor=4,
+    pin_memory=True,
+    num_workers=2,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        ann_file='annotation_body3d/fps50/h36m_train.npz',
+        seq_len=1,
+        multiple_target=243,
+        multiple_target_step=81,
+        camera_param_file='annotation_body3d/cameras.pkl',
+        data_root=data_root,
+        data_prefix=dict(img='images/'),
+        pipeline=train_pipeline,
+    ))
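Note: these loader options map onto the standard PyTorch DataLoader arguments; a rough raw-torch equivalent of the settings above (sketch only; mmengine builds the real loader from this config, and train_dataset is a placeholder for the built Human36mDataset):

from torch.utils.data import DataLoader

loader = DataLoader(
    train_dataset,            # placeholder for the built dataset
    batch_size=32,
    shuffle=True,             # DefaultSampler with shuffle=True
    num_workers=2,
    prefetch_factor=4,        # batches prefetched per worker
    pin_memory=True,          # page-locked memory for faster GPU transfer
    persistent_workers=True,  # keep worker processes alive across epochs
)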

 val_dataloader = dict(
     batch_size=32,
     shuffle=False,
     prefetch_factor=4,
     pin_memory=True,
     num_workers=2,
@@ -78,8 +113,8 @@
         type=dataset_type,
         ann_file='annotation_body3d/fps50/h36m_test.npz',
         seq_len=1,
-        multiple_target=243,
         seq_step=1,
+        multiple_target=243,
         camera_param_file='annotation_body3d/cameras.pkl',
         data_root=data_root,
         data_prefix=dict(img='images/'),
12 changes: 10 additions & 2 deletions mmpose/datasets/datasets/body3d/h36m_dataset.py
@@ -47,6 +47,8 @@ class Human36mDataset(BaseMocapDataset):
             Default: 1.
         multiple_target (int): If larger than 0, merge every
             ``multiple_target`` sequence together. Default: 0.
+        multiple_target_step (int): The interval for merging sequences. Only
+            valid when ``multiple_target`` is larger than 0. Default: 0.
         pad_video_seq (bool): Whether to pad the video so that poses will be
             predicted for every frame in the video. Default: ``False``.
         causal (bool): If set to ``True``, the rightmost input frame will be
@@ -110,6 +112,7 @@ def __init__(self,
                  seq_len: int = 1,
                  seq_step: int = 1,
                  multiple_target: int = 0,
+                 multiple_target_step: int = 0,
                  pad_video_seq: bool = False,
                  causal: bool = True,
                  subset_frac: float = 1.0,
@@ -151,6 +154,10 @@ def __init__(self,
         assert exists(factor_file), 'Annotation file does not exist.'
         self.factor_file = factor_file
 
+        if multiple_target > 0 and multiple_target_step == 0:
+            multiple_target_step = multiple_target
+        self.multiple_target_step = multiple_target_step
+
         super().__init__(
             ann_file=ann_file,
             seq_len=seq_len,
@@ -191,8 +198,9 @@ def get_sequence_indices(self) -> List[List[int]]:
                 n_frame = len(_indices)
                 seqs_from_video = [
                     _indices[i:(i + self.multiple_target):_step]
-                    for i in range(0, n_frame, self.multiple_target)
-                ][:n_frame // self.multiple_target]
+                    for i in range(0, n_frame, self.multiple_target_step)
+                ][:(n_frame + self.multiple_target_step -
+                    self.multiple_target) // self.multiple_target_step]
                 sequence_indices.extend(seqs_from_video)
 
             else:
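Note: with this change, windows of multiple_target frames start every multiple_target_step frames (overlapping) rather than every multiple_target frames (disjoint), and the slice cap keeps only windows that fit entirely inside the clip. A small self-contained check using the config's numbers (the clip length is illustrative, and seq_step subsampling is ignored here):

n_frame, target, step = 1000, 243, 81
starts = [i for i in range(0, n_frame, step) if i + target <= n_frame]
# Same count as the [(n_frame + step - target) // step] cap in the code:
assert len(starts) == (n_frame + step - target) // step
print(len(starts), starts[:3], starts[-1])  # 10 [0, 81, 162] 729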
18 changes: 15 additions & 3 deletions mmpose/datasets/transforms/pose3d_transforms.py
@@ -25,6 +25,8 @@ class RandomFlipAroundRoot(BaseTransform):
         flip_prob (float): Probability of flip. Default: 0.5.
         flip_camera (bool): Whether to flip horizontal distortion coefficients.
             Default: ``False``.
+        flip_image (bool): Whether to flip keypoints horizontally according
+            to image size. Default: ``False``.
 
     Required keys:
         keypoints
@@ -39,14 +41,16 @@ def __init__(self,
                  keypoints_flip_cfg,
                  target_flip_cfg,
                  flip_prob=0.5,
-                 flip_camera=False):
+                 flip_camera=False,
+                 flip_image=False):
         self.keypoints_flip_cfg = keypoints_flip_cfg
         self.target_flip_cfg = target_flip_cfg
         self.flip_prob = flip_prob
         self.flip_camera = flip_camera
+        self.flip_image = flip_image
 
     def transform(self, results: Dict) -> dict:
-        """The transform function of :class:`ZeroCenterPose`.
+        """The transform function of :class:`RandomFlipAroundRoot`.
 
         See ``transform()`` method of :class:`BaseTransform` for details.
Expand Down Expand Up @@ -76,6 +80,15 @@ def transform(self, results: Dict) -> dict:
flip_indices = results['flip_indices']

# flip joint coordinates
_camera_param = deepcopy(results['camera_param'])
if self.flip_image:
assert 'camera_param' in results, \
'Camera parameters are missing.'
assert 'w' in _camera_param
w = _camera_param['w'] / 2
self.keypoints_flip_cfg['center_x'] = w
self.target_flip_cfg['center_x'] = w

keypoints, keypoints_visible = flip_keypoints_custom_center(
keypoints, keypoints_visible, flip_indices,
**self.keypoints_flip_cfg)
@@ -92,7 +105,6 @@ def transform(self, results: Dict) -> dict:
             if self.flip_camera:
                 assert 'camera_param' in results, \
                     'Camera parameters are missing.'
-                _camera_param = deepcopy(results['camera_param'])
 
                 assert 'c' in _camera_param
                 _camera_param['c'][0] *= -1
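Note: with flip_image=True the transform sets the flip center to half the image width, so every x-coordinate is mirrored about the vertical centerline (x' = 2*(w/2) - x = w - x) before left/right joints are swapped via flip_indices. A self-contained numpy sketch of that math (a toy two-joint layout is assumed; this is not mmpose's flip_keypoints_custom_center itself):

import numpy as np

def flip_around_image_center(keypoints, flip_indices, w):
    # Mirror x about the image centerline, then swap left/right pairs.
    flipped = keypoints.copy()
    flipped[..., 0] = w - flipped[..., 0]
    return flipped[..., flip_indices, :]

kpts = np.array([[100., 50.], [300., 80.]])  # (num_joints, xy)
print(flip_around_image_center(kpts, [1, 0], w=640))
# [[340.  80.]
#  [540.  50.]]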
20 changes: 20 additions & 0 deletions tests/test_datasets/test_transforms/test_pose3d_transforms.py
@@ -153,3 +153,23 @@ def test_transform(self):
                     -self.data_info['camera_param']['p'][0],
                     camera2['p'][0],
                     atol=4.))
+
+        # test flipping w.r.t. image
+        transform = RandomFlipAroundRoot({}, {}, flip_prob=1, flip_image=True)
+        results = deepcopy(self.data_info)
+        results = transform(results)
+        kpts2 = results['keypoints']
+        tar2 = results['lifting_target']
+
+        camera_param = results['camera_param']
+        for left, right in enumerate(flip_indices):
+            self.assertTrue(
+                np.allclose(
+                    camera_param['w'] - kpts1[0][left][:1],
+                    kpts2[0][right][:1],
+                    atol=4.))
+            self.assertTrue(
+                np.allclose(kpts1[0][left][1:], kpts2[0][right][1:], atol=4.))
+            self.assertTrue(
+                np.allclose(
+                    tar1[..., left, 1:], tar2[..., right, 1:], atol=4.))
