Skip to content

Commit 5d83899

Browse files
authored
[FEATURE] Add link-wise mask for poss and quats in multilink IK (#499)
* add link-wise mask for poss and quats * add mask in recompute final error
1 parent 0a5e8a8 commit 5d83899

File tree

1 file changed

+45
-25
lines changed

1 file changed

+45
-25
lines changed

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -894,31 +894,45 @@ def inverse_kinematics_multilink(
894894
if n_links == 0:
895895
gs.raise_exception("Target link not provided.")
896896

897-
if len(poss) == n_links:
898-
if self._solver.n_envs > 0:
899-
if poss[0].shape[0] != self._solver.n_envs:
900-
gs.raise_exception("First dimension of elements in `poss` must be equal to scene.n_envs.")
901-
elif len(poss) == 0:
902-
if self._solver.n_envs == 0:
903-
poss = [gu.zero_pos()] * n_links
904-
else:
905-
poss = [self._solver._batch_array(gu.zero_pos(), True)] * n_links
897+
if len(poss) == 0:
898+
poss = [None] * n_links
906899
pos_mask = [False, False, False]
907-
else:
900+
elif len(poss) != n_links:
908901
gs.raise_exception("Accepting only `poss` with length equal to `links` or empty list.")
909902

910-
if len(quats) == n_links:
911-
if self._solver.n_envs > 0:
912-
if quats[0].shape[0] != self._solver.n_envs:
913-
gs.raise_exception("First dimension of elements in `quats` must be equal to scene.n_envs.")
914-
elif len(quats) == 0:
915-
if self._solver.n_envs == 0:
916-
quats = [gu.identity_quat()] * n_links
917-
else:
918-
quats = [self._solver._batch_array(gu.identity_quat(), True)] * n_links
903+
if len(quats) == 0:
904+
quats = [None] * n_links
919905
rot_mask = [False, False, False]
920-
else:
921-
gs.raise_exception("Accepting only `quats` with length equal to `links` or empty list.")
906+
elif len(quats) != n_links:
907+
gs.raise_exception("Accepting only `quatss` with length equal to `links` or empty list.")
908+
909+
link_pos_mask = []
910+
link_rot_mask = []
911+
for i in range(n_links):
912+
if poss[i] is None and quats[i] is None:
913+
gs.raise_exception("At least one of `poss` or `quats` must be provided.")
914+
if poss[i] is not None:
915+
link_pos_mask.append(True)
916+
if self._solver.n_envs > 0:
917+
if poss[i].shape[0] != self._solver.n_envs:
918+
gs.raise_exception("First dimension of elements in `poss` must be equal to scene.n_envs.")
919+
else:
920+
link_pos_mask.append(False)
921+
if self._solver.n_envs == 0:
922+
poss[i] = gu.zero_pos()
923+
else:
924+
poss[i] = self._solver._batch_array(gu.zero_pos(), True)
925+
if quats[i] is not None:
926+
link_rot_mask.append(True)
927+
if self._solver.n_envs > 0:
928+
if quats[i].shape[0] != self._solver.n_envs:
929+
gs.raise_exception("First dimension of elements in `quats` must be equal to scene.n_envs.")
930+
else:
931+
link_rot_mask.append(False)
932+
if self._solver.n_envs == 0:
933+
quats[i] = gu.identity_quat()
934+
else:
935+
quats[i] = self._solver._batch_array(gu.identity_quat(), True)
922936

923937
if init_qpos is not None:
924938
init_qpos = torch.as_tensor(init_qpos, dtype=gs.tc_float)
@@ -947,6 +961,8 @@ def inverse_kinematics_multilink(
947961
gs.raise_exception("You can only align 0, 1 axis or all 3 axes.")
948962
else:
949963
pass # nothing needs to change for 0 or 3 axes
964+
link_pos_mask = torch.as_tensor(link_pos_mask, dtype=gs.tc_int, device=gs.device)
965+
link_rot_mask = torch.as_tensor(link_rot_mask, dtype=gs.tc_int, device=gs.device)
950966

951967
links_idx = torch.as_tensor([link.idx for link in links], dtype=gs.tc_int, device=gs.device)
952968
poss = torch.stack(
@@ -992,6 +1008,8 @@ def inverse_kinematics_multilink(
9921008
rot_tol,
9931009
pos_mask,
9941010
rot_mask,
1011+
link_pos_mask,
1012+
link_rot_mask,
9951013
max_step_size,
9961014
respect_joint_limit,
9971015
)
@@ -1032,6 +1050,8 @@ def _kernel_inverse_kinematics(
10321050
rot_tol: ti.f32,
10331051
pos_mask_: ti.types.ndarray(),
10341052
rot_mask_: ti.types.ndarray(),
1053+
link_pos_mask: ti.types.ndarray(),
1054+
link_rot_mask: ti.types.ndarray(),
10351055
max_step_size: ti.f32,
10361056
respect_joint_limit: ti.i32,
10371057
):
@@ -1067,7 +1087,7 @@ def _kernel_inverse_kinematics(
10671087
tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]])
10681088
err_pos_i = tgt_pos_i - self._solver.links_state[i_l_ee, i_b].pos
10691089
for k in range(3):
1070-
err_pos_i[k] *= pos_mask[k]
1090+
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
10711091
if err_pos_i.norm() > pos_tol:
10721092
solved = False
10731093

@@ -1080,7 +1100,7 @@ def _kernel_inverse_kinematics(
10801100
)
10811101
)
10821102
for k in range(3):
1083-
err_rot_i[k] *= rot_mask[k]
1103+
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
10841104
if err_rot_i.norm() > rot_tol:
10851105
solved = False
10861106

@@ -1150,7 +1170,7 @@ def _kernel_inverse_kinematics(
11501170
tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]])
11511171
err_pos_i = tgt_pos_i - self._solver.links_state[i_l_ee, i_b].pos
11521172
for k in range(3):
1153-
err_pos_i[k] *= pos_mask[k]
1173+
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
11541174
if err_pos_i.norm() > pos_tol:
11551175
solved = False
11561176

@@ -1163,7 +1183,7 @@ def _kernel_inverse_kinematics(
11631183
)
11641184
)
11651185
for k in range(3):
1166-
err_rot_i[k] *= rot_mask[k]
1186+
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
11671187
if err_rot_i.norm() > rot_tol:
11681188
solved = False
11691189

0 commit comments

Comments
 (0)