@@ -894,31 +894,45 @@ def inverse_kinematics_multilink(
894
894
if n_links == 0 :
895
895
gs .raise_exception ("Target link not provided." )
896
896
897
- if len (poss ) == n_links :
898
- if self ._solver .n_envs > 0 :
899
- if poss [0 ].shape [0 ] != self ._solver .n_envs :
900
- gs .raise_exception ("First dimension of elements in `poss` must be equal to scene.n_envs." )
901
- elif len (poss ) == 0 :
902
- if self ._solver .n_envs == 0 :
903
- poss = [gu .zero_pos ()] * n_links
904
- else :
905
- poss = [self ._solver ._batch_array (gu .zero_pos (), True )] * n_links
897
+ if len (poss ) == 0 :
898
+ poss = [None ] * n_links
906
899
pos_mask = [False , False , False ]
907
- else :
900
+ elif len ( poss ) != n_links :
908
901
gs .raise_exception ("Accepting only `poss` with length equal to `links` or empty list." )
909
902
910
- if len (quats ) == n_links :
911
- if self ._solver .n_envs > 0 :
912
- if quats [0 ].shape [0 ] != self ._solver .n_envs :
913
- gs .raise_exception ("First dimension of elements in `quats` must be equal to scene.n_envs." )
914
- elif len (quats ) == 0 :
915
- if self ._solver .n_envs == 0 :
916
- quats = [gu .identity_quat ()] * n_links
917
- else :
918
- quats = [self ._solver ._batch_array (gu .identity_quat (), True )] * n_links
903
+ if len (quats ) == 0 :
904
+ quats = [None ] * n_links
919
905
rot_mask = [False , False , False ]
920
- else :
921
- gs .raise_exception ("Accepting only `quats` with length equal to `links` or empty list." )
906
+ elif len (quats ) != n_links :
907
+ gs .raise_exception ("Accepting only `quatss` with length equal to `links` or empty list." )
908
+
909
+ link_pos_mask = []
910
+ link_rot_mask = []
911
+ for i in range (n_links ):
912
+ if poss [i ] is None and quats [i ] is None :
913
+ gs .raise_exception ("At least one of `poss` or `quats` must be provided." )
914
+ if poss [i ] is not None :
915
+ link_pos_mask .append (True )
916
+ if self ._solver .n_envs > 0 :
917
+ if poss [i ].shape [0 ] != self ._solver .n_envs :
918
+ gs .raise_exception ("First dimension of elements in `poss` must be equal to scene.n_envs." )
919
+ else :
920
+ link_pos_mask .append (False )
921
+ if self ._solver .n_envs == 0 :
922
+ poss [i ] = gu .zero_pos ()
923
+ else :
924
+ poss [i ] = self ._solver ._batch_array (gu .zero_pos (), True )
925
+ if quats [i ] is not None :
926
+ link_rot_mask .append (True )
927
+ if self ._solver .n_envs > 0 :
928
+ if quats [i ].shape [0 ] != self ._solver .n_envs :
929
+ gs .raise_exception ("First dimension of elements in `quats` must be equal to scene.n_envs." )
930
+ else :
931
+ link_rot_mask .append (False )
932
+ if self ._solver .n_envs == 0 :
933
+ quats [i ] = gu .identity_quat ()
934
+ else :
935
+ quats [i ] = self ._solver ._batch_array (gu .identity_quat (), True )
922
936
923
937
if init_qpos is not None :
924
938
init_qpos = torch .as_tensor (init_qpos , dtype = gs .tc_float )
@@ -947,6 +961,8 @@ def inverse_kinematics_multilink(
947
961
gs .raise_exception ("You can only align 0, 1 axis or all 3 axes." )
948
962
else :
949
963
pass # nothing needs to change for 0 or 3 axes
964
+ link_pos_mask = torch .as_tensor (link_pos_mask , dtype = gs .tc_int , device = gs .device )
965
+ link_rot_mask = torch .as_tensor (link_rot_mask , dtype = gs .tc_int , device = gs .device )
950
966
951
967
links_idx = torch .as_tensor ([link .idx for link in links ], dtype = gs .tc_int , device = gs .device )
952
968
poss = torch .stack (
@@ -992,6 +1008,8 @@ def inverse_kinematics_multilink(
992
1008
rot_tol ,
993
1009
pos_mask ,
994
1010
rot_mask ,
1011
+ link_pos_mask ,
1012
+ link_rot_mask ,
995
1013
max_step_size ,
996
1014
respect_joint_limit ,
997
1015
)
@@ -1032,6 +1050,8 @@ def _kernel_inverse_kinematics(
1032
1050
rot_tol : ti .f32 ,
1033
1051
pos_mask_ : ti .types .ndarray (),
1034
1052
rot_mask_ : ti .types .ndarray (),
1053
+ link_pos_mask : ti .types .ndarray (),
1054
+ link_rot_mask : ti .types .ndarray (),
1035
1055
max_step_size : ti .f32 ,
1036
1056
respect_joint_limit : ti .i32 ,
1037
1057
):
@@ -1067,7 +1087,7 @@ def _kernel_inverse_kinematics(
1067
1087
tgt_pos_i = ti .Vector ([poss [i_ee , i_b , 0 ], poss [i_ee , i_b , 1 ], poss [i_ee , i_b , 2 ]])
1068
1088
err_pos_i = tgt_pos_i - self ._solver .links_state [i_l_ee , i_b ].pos
1069
1089
for k in range (3 ):
1070
- err_pos_i [k ] *= pos_mask [k ]
1090
+ err_pos_i [k ] *= pos_mask [k ] * link_pos_mask [ i_ee ]
1071
1091
if err_pos_i .norm () > pos_tol :
1072
1092
solved = False
1073
1093
@@ -1080,7 +1100,7 @@ def _kernel_inverse_kinematics(
1080
1100
)
1081
1101
)
1082
1102
for k in range (3 ):
1083
- err_rot_i [k ] *= rot_mask [k ]
1103
+ err_rot_i [k ] *= rot_mask [k ] * link_rot_mask [ i_ee ]
1084
1104
if err_rot_i .norm () > rot_tol :
1085
1105
solved = False
1086
1106
@@ -1150,7 +1170,7 @@ def _kernel_inverse_kinematics(
1150
1170
tgt_pos_i = ti .Vector ([poss [i_ee , i_b , 0 ], poss [i_ee , i_b , 1 ], poss [i_ee , i_b , 2 ]])
1151
1171
err_pos_i = tgt_pos_i - self ._solver .links_state [i_l_ee , i_b ].pos
1152
1172
for k in range (3 ):
1153
- err_pos_i [k ] *= pos_mask [k ]
1173
+ err_pos_i [k ] *= pos_mask [k ] * link_pos_mask [ i_ee ]
1154
1174
if err_pos_i .norm () > pos_tol :
1155
1175
solved = False
1156
1176
@@ -1163,7 +1183,7 @@ def _kernel_inverse_kinematics(
1163
1183
)
1164
1184
)
1165
1185
for k in range (3 ):
1166
- err_rot_i [k ] *= rot_mask [k ]
1186
+ err_rot_i [k ] *= rot_mask [k ] * link_rot_mask [ i_ee ]
1167
1187
if err_rot_i .norm () > rot_tol :
1168
1188
solved = False
1169
1189
0 commit comments