Skip to content

Commit

Permalink
[add] add profiling test and fix svh retargeting bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yzqin committed Jul 15, 2024
1 parent 6ef0df7 commit 630f457
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,5 @@ retargeting:
finger_tip_link_names: [ "thtip", "fftip", "mftip", "rftip", "lftip" ]
scaling_factor: 1.2

# Source refers to the retargeting input, which usually corresponds to the human hand
# The joint indices of human hand joint which corresponds to each link in the target_link_names
target_link_human_indices: [ [ 0, 0, 0, 0, 0 ], [ 4, 8, 12, 16, 20, ] ]

# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
low_pass_alpha: 0.2
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,5 @@ retargeting:
finger_tip_link_names: [ "thtip", "fftip", "mftip", "rftip", "lftip" ]
scaling_factor: 1.2

# Source refers to the retargeting input, which usually corresponds to the human hand
# The joint indices of human hand joint which corresponds to each link in the target_link_names
target_link_human_indices: [ [ 0, 0, 0, 0, 0 ], [ 4, 8, 12, 16, 20, ] ]

# A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
low_pass_alpha: 0.2
62 changes: 62 additions & 0 deletions example/profiling/generate_human_data_from_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pickle
from pathlib import Path

import cv2
import tqdm
import tyro

from dex_retargeting.constants import HandType
from example.vector_retargeting.single_hand_detector import SingleHandDetector


def detect_video(video_path: str, output_path: str, is_right: bool):
cap = cv2.VideoCapture(video_path)

if not cap.isOpened():
print("Error: Could not open video file.")
else:
data = []
detector = SingleHandDetector(hand_type="Right", selfie=False)
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
with tqdm.tqdm(total=length) as pbar:
while cap.isOpened():
ret, frame = cap.read()

if not ret:
break

rgb = frame[..., ::-1]
_, joint_pos, _, _ = detector.detect(rgb)
data.append(joint_pos)
pbar.update(1)

output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as f:
pickle.dump(data, f)

cap.release()
cv2.destroyAllWindows()


def main(
video_path: str,
output_path: str,
hand_type: HandType = HandType.right,
):
"""
Detects the human hand pose from a video and translates the human pose trajectory into a robot pose trajectory.
Args:
video_path: The file path for the input video in .mp4 format.
output_path: The file path for the output data in .pickle format.
hand_type: Specifies which hand is being tracked, either left or right.
Please note that retargeting is specific to the same type of hand: a left robot hand can only be retargeted
to another left robot hand, and the same applies for the right hand.
"""

detect_video(video_path, output_path, is_right=HandType.right == hand_type)


if __name__ == "__main__":
tyro.cli(main)
Binary file added example/profiling/human_joint_right.pkl
Binary file not shown.
69 changes: 69 additions & 0 deletions example/profiling/profile_online_retargeting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import time
from pathlib import Path
from typing import List

import numpy as np

from dex_retargeting.constants import (
get_default_config_path,
RetargetingType,
HandType,
ROBOT_NAMES,
ROBOT_NAME_MAP,
)
from dex_retargeting.retargeting_config import RetargetingConfig
from dex_retargeting.seq_retarget import SeqRetargeting


def profile_retargeting(retargeting: SeqRetargeting, data: List[np.ndarray]):
retargeting_type = retargeting.optimizer.retargeting_type
indices = retargeting.optimizer.target_link_human_indices

total_time = 0
for i, joint_pos in enumerate(data):
if retargeting_type == "POSITION":
indices = indices
ref_value = joint_pos[indices, :]
else:
origin_indices = indices[0, :]
task_indices = indices[1, :]
ref_value = joint_pos[task_indices, :] - joint_pos[origin_indices, :]
tic = time.perf_counter()
qpos = retargeting.retarget(ref_value)
tac = time.perf_counter()
total_time += tac - tic

return total_time


def main():

robot_dir = Path(__file__).absolute().parent.parent.parent / "assets" / "robots" / "hands"
RetargetingConfig.set_default_urdf_dir(str(robot_dir))

# Load data
joint_data = np.load("human_joint_right.pkl", allow_pickle=True)
data_len = len(joint_data)

# Vector retargeting
print(f"Being retargeting profiling with a trajectory of {data_len} hand poses.")
for robot_name in ROBOT_NAMES:
config_path = get_default_config_path(robot_name, RetargetingType.vector, HandType.right)
retargeting = RetargetingConfig.load_from_file(config_path).build()
total_time = profile_retargeting(retargeting, joint_data)
print(
f"Vector retargeting of {ROBOT_NAME_MAP[robot_name]} take {total_time}s in total, fps: {data_len/total_time}hz "
)

# DexPilot retargeting
for robot_name in ROBOT_NAMES:
config_path = get_default_config_path(robot_name, RetargetingType.dexpilot, HandType.right)
retargeting = RetargetingConfig.load_from_file(config_path).build()
total_time = profile_retargeting(retargeting, joint_data)
print(
f"DexPilot retargeting of {ROBOT_NAME_MAP[robot_name]} take {total_time}s in total, fps: {data_len/total_time}hz "
)


if __name__ == "__main__":
main()
20 changes: 10 additions & 10 deletions example/vector_retargeting/detect_from_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ def retarget_video(retargeting: SeqRetargeting, video_path: str, output_path: st
ref_value = joint_pos[task_indices, :] - joint_pos[origin_indices, :]
qpos = retargeting.retarget(ref_value)
data.append(qpos)
pbar.update(1)

meta_data = dict(
config_path=config_path,
dof=len(retargeting.optimizer.robot.dof_joint_names),
joint_names=retargeting.optimizer.robot.dof_joint_names,
)
meta_data = dict(
config_path=config_path,
dof=len(retargeting.optimizer.robot.dof_joint_names),
joint_names=retargeting.optimizer.robot.dof_joint_names,
)

output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as f:
pickle.dump(dict(data=data, meta_data=meta_data), f)
pbar.update(1)
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as f:
pickle.dump(dict(data=data, meta_data=meta_data), f)

retargeting.verbose()
cap.release()
Expand Down

0 comments on commit 630f457

Please sign in to comment.