Skip to content

Commit

Permalink
improve decode
Browse files Browse the repository at this point in the history
  • Loading branch information
sithu31296 committed Sep 18, 2021
1 parent dcb8602 commit b00da09
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 22 deletions.
6 changes: 3 additions & 3 deletions infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ def postprocess(self, pred, img1, img0):

for det in pred:
if len(det):
boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:])
boxes = scale_boxes(det[:, :4], img0.shape[:2], img1.shape[-2:]).cpu()
boxes = self.box_to_center_scale(boxes)
outputs = self.predict_poses(boxes, img0)

if 'simdr' in self.model_name:
coords = get_simdr_final_preds(*outputs, boxes[:, :2], boxes[:, 2:], self.patch_size)
coords = get_simdr_final_preds(*outputs, boxes, self.patch_size)
else:
coords = get_final_preds(outputs, boxes[:, :2].cpu().numpy(), boxes[:, 2:].cpu().numpy())
coords = get_final_preds(outputs, boxes)

draw_keypoints(img0, coords, self.coco_skeletons)

Expand Down
29 changes: 10 additions & 19 deletions pose/utils/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
from torch import Tensor


def get_simdr_final_preds(pred_x: Tensor, pred_y: Tensor, center, scale, image_size):
def get_simdr_final_preds(pred_x: Tensor, pred_y: Tensor, boxes: Tensor, image_size: tuple):
center, scale = boxes[:, :2].numpy(), boxes[:, 2:].numpy()

pred_x, pred_y = pred_x.softmax(dim=2), pred_y.softmax(dim=2)
pred_x, pred_y = pred_x.max(dim=2)[-1], pred_y.max(dim=2)[-1]
coords = torch.stack([pred_x / 2, pred_y / 2], dim=-1)
coords = torch.stack([pred_x / 2, pred_y / 2], dim=-1).cpu().numpy()

for i in range(coords.shape[0]):
coords[i] = transform_preds_tensor(coords[i], center[i], scale[i], image_size)
return coords.to(int).cpu().numpy()
coords[i] = transform_preds(coords[i], center[i], scale[i], image_size)
return coords.astype(int)


def get_final_preds(heatmaps, center, scale):
def get_final_preds(heatmaps: Tensor, boxes: Tensor):
center, scale = boxes[:, :2].numpy(), boxes[:, 2:].numpy()
heatmaps = heatmaps.cpu().numpy()
B, C, H, W = heatmaps.shape
coords = get_max_preds(heatmaps)
Expand All @@ -31,12 +34,10 @@ def get_final_preds(heatmaps, center, scale):
hm[py+1][px] - hm[py-1][px]
])
coords[n][p] += np.sign(diff) * .25
preds = coords.copy()

for i in range(B):
preds[i] = transform_preds(coords[i], center[i], scale[i], [W, H])

return preds.astype(int)
coords[i] = transform_preds(coords[i], center[i], scale[i], [W, H])
return coords.astype(int)


def get_max_preds(heatmaps: np.ndarray):
Expand All @@ -51,16 +52,6 @@ def get_max_preds(heatmaps: np.ndarray):
return preds


def transform_preds_tensor(coords, center, scale, output_size):
scale = scale * 200
scale_x = scale[0] / output_size[0]
scale_y = scale[1] / output_size[1]
target_coords = torch.ones_like(coords, device=coords.device)
target_coords[:, 0] = coords[:, 0] * scale_x + center[0] - scale[0] * 0.5
target_coords[:, 1] = coords[:, 1] * scale_y + center[1] - scale[1] * 0.5
return target_coords


def transform_preds(coords, center, scale, output_size):
scale = scale * 200
scale_x = scale[0] / output_size[0]
Expand Down

0 comments on commit b00da09

Please sign in to comment.