Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 45 additions & 25 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,31 +894,45 @@ def inverse_kinematics_multilink(
if n_links == 0:
gs.raise_exception("Target link not provided.")

if len(poss) == n_links:
if self._solver.n_envs > 0:
if poss[0].shape[0] != self._solver.n_envs:
gs.raise_exception("First dimension of elements in `poss` must be equal to scene.n_envs.")
elif len(poss) == 0:
if self._solver.n_envs == 0:
poss = [gu.zero_pos()] * n_links
else:
poss = [self._solver._batch_array(gu.zero_pos(), True)] * n_links
if len(poss) == 0:
poss = [None] * n_links
pos_mask = [False, False, False]
else:
elif len(poss) != n_links:
gs.raise_exception("Accepting only `poss` with length equal to `links` or empty list.")

if len(quats) == n_links:
if self._solver.n_envs > 0:
if quats[0].shape[0] != self._solver.n_envs:
gs.raise_exception("First dimension of elements in `quats` must be equal to scene.n_envs.")
elif len(quats) == 0:
if self._solver.n_envs == 0:
quats = [gu.identity_quat()] * n_links
else:
quats = [self._solver._batch_array(gu.identity_quat(), True)] * n_links
if len(quats) == 0:
quats = [None] * n_links
rot_mask = [False, False, False]
else:
gs.raise_exception("Accepting only `quats` with length equal to `links` or empty list.")
elif len(quats) != n_links:
gs.raise_exception("Accepting only `quatss` with length equal to `links` or empty list.")

link_pos_mask = []
link_rot_mask = []
for i in range(n_links):
if poss[i] is None and quats[i] is None:
gs.raise_exception("At least one of `poss` or `quats` must be provided.")
if poss[i] is not None:
link_pos_mask.append(True)
if self._solver.n_envs > 0:
if poss[i].shape[0] != self._solver.n_envs:
gs.raise_exception("First dimension of elements in `poss` must be equal to scene.n_envs.")
else:
link_pos_mask.append(False)
if self._solver.n_envs == 0:
poss[i] = gu.zero_pos()
else:
poss[i] = self._solver._batch_array(gu.zero_pos(), True)
if quats[i] is not None:
link_rot_mask.append(True)
if self._solver.n_envs > 0:
if quats[i].shape[0] != self._solver.n_envs:
gs.raise_exception("First dimension of elements in `quats` must be equal to scene.n_envs.")
else:
link_rot_mask.append(False)
if self._solver.n_envs == 0:
quats[i] = gu.identity_quat()
else:
quats[i] = self._solver._batch_array(gu.identity_quat(), True)

if init_qpos is not None:
init_qpos = torch.as_tensor(init_qpos, dtype=gs.tc_float)
Expand Down Expand Up @@ -947,6 +961,8 @@ def inverse_kinematics_multilink(
gs.raise_exception("You can only align 0, 1 axis or all 3 axes.")
else:
pass # nothing needs to change for 0 or 3 axes
link_pos_mask = torch.as_tensor(link_pos_mask, dtype=gs.tc_int, device=gs.device)
link_rot_mask = torch.as_tensor(link_rot_mask, dtype=gs.tc_int, device=gs.device)

links_idx = torch.as_tensor([link.idx for link in links], dtype=gs.tc_int, device=gs.device)
poss = torch.stack(
Expand Down Expand Up @@ -992,6 +1008,8 @@ def inverse_kinematics_multilink(
rot_tol,
pos_mask,
rot_mask,
link_pos_mask,
link_rot_mask,
max_step_size,
respect_joint_limit,
)
Expand Down Expand Up @@ -1032,6 +1050,8 @@ def _kernel_inverse_kinematics(
rot_tol: ti.f32,
pos_mask_: ti.types.ndarray(),
rot_mask_: ti.types.ndarray(),
link_pos_mask: ti.types.ndarray(),
link_rot_mask: ti.types.ndarray(),
max_step_size: ti.f32,
respect_joint_limit: ti.i32,
):
Expand Down Expand Up @@ -1067,7 +1087,7 @@ def _kernel_inverse_kinematics(
tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]])
err_pos_i = tgt_pos_i - self._solver.links_state[i_l_ee, i_b].pos
for k in range(3):
err_pos_i[k] *= pos_mask[k]
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
if err_pos_i.norm() > pos_tol:
solved = False

Expand All @@ -1080,7 +1100,7 @@ def _kernel_inverse_kinematics(
)
)
for k in range(3):
err_rot_i[k] *= rot_mask[k]
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
if err_rot_i.norm() > rot_tol:
solved = False

Expand Down Expand Up @@ -1150,7 +1170,7 @@ def _kernel_inverse_kinematics(
tgt_pos_i = ti.Vector([poss[i_ee, i_b, 0], poss[i_ee, i_b, 1], poss[i_ee, i_b, 2]])
err_pos_i = tgt_pos_i - self._solver.links_state[i_l_ee, i_b].pos
for k in range(3):
err_pos_i[k] *= pos_mask[k]
err_pos_i[k] *= pos_mask[k] * link_pos_mask[i_ee]
if err_pos_i.norm() > pos_tol:
solved = False

Expand All @@ -1163,7 +1183,7 @@ def _kernel_inverse_kinematics(
)
)
for k in range(3):
err_rot_i[k] *= rot_mask[k]
err_rot_i[k] *= rot_mask[k] * link_rot_mask[i_ee]
if err_rot_i.norm() > rot_tol:
solved = False

Expand Down