@@ -580,6 +580,10 @@ def _build(self):
580580 self ._n_free_verts = len (self ._free_verts_idx_local )
581581 self ._n_fixed_verts = len (self ._fixed_verts_idx_local )
582582
583+ self ._dofs_idx = torch .arange (
584+ self ._dof_start , self ._dof_start + self ._n_dofs , dtype = gs .tc_int , device = gs .device
585+ )
586+
583587 self ._geoms = self .geoms
584588 self ._vgeoms = self .vgeoms
585589
@@ -1493,6 +1497,7 @@ def _kernel_forward_kinematics(
14931497 # ------------------------------------------------------------------------------------
14941498 # --------------------------------- motion planing -----------------------------------
14951499 # ------------------------------------------------------------------------------------
1500+
14961501 @gs .assert_built
14971502 def plan_path (
14981503 self ,
@@ -1623,6 +1628,50 @@ def plan_path(
16231628 # ---------------------------------- control & io ------------------------------------
16241629 # ------------------------------------------------------------------------------------
16251630
1631+ def _get_idx (self , idx_local , idx_local_max , idx_global_start = 0 , * , unsafe = False ):
1632+ # Handling default argument and special cases
1633+ if idx_local is None :
1634+ if unsafe :
1635+ idx_global = slice (idx_global_start , idx_local_max + idx_global_start )
1636+ else :
1637+ idx_global = range (idx_global_start , idx_local_max + idx_global_start )
1638+ elif isinstance (idx_local , (range , slice )):
1639+ idx_global = range (
1640+ (idx_local .start or 0 ) + idx_global_start ,
1641+ (idx_local .stop if idx_local .stop is not None else idx_local_max ) + idx_global_start ,
1642+ idx_local .step or 1 ,
1643+ )
1644+ elif isinstance (idx_local , (int , np .integer )):
1645+ idx_global = idx_local + idx_global_start
1646+ elif isinstance (idx_local , (list , tuple )):
1647+ try :
1648+ idx_global = [i + idx_global_start for i in idx_local ]
1649+ except TypeError :
1650+ gs .raise_exception ("Expecting a sequence of integers for `idx_local`." )
1651+ else :
1652+ # Increment may be slow when dealing with heterogenuous data, so it must be avoided if possible
1653+ if idx_global_start > 0 :
1654+ idx_global = idx_local + idx_global_start
1655+ else :
1656+ idx_global = idx_local
1657+
1658+ # Early return if unsafe
1659+ if unsafe :
1660+ return idx_global
1661+
1662+ # Perform a bunch of sanity checks
1663+ _idx_global = torch .as_tensor (idx_global , dtype = gs .tc_int , device = gs .device ).contiguous ()
1664+ if _idx_global is not idx_global :
1665+ gs .logger .debug (ALLOCATE_TENSOR_WARNING )
1666+ idx_global = torch .atleast_1d (_idx_global )
1667+
1668+ if idx_global .ndim != 1 :
1669+ gs .raise_exception ("Expecting a 1D tensor for `idx_local`." )
1670+ if (idx_global < 0 ).any () or (idx_global >= idx_global_start + idx_local_max ).any ():
1671+ gs .raise_exception ("`idx_local` exceeds valid range." )
1672+
1673+ return idx_global
1674+
16261675 def get_joint (self , name = None , uid = None ):
16271676 """
16281677 Get a RigidJoint object by name or uid.
@@ -1949,7 +1998,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None, *, unsafe=Fal
19491998 return self ._solver .get_links_invweight (links_idx , envs_idx , unsafe = unsafe )
19501999
19512000 @gs .assert_built
1952- def set_pos (self , pos , envs_idx = None , * , relative = False , zero_velocity = True , unsafe = False ):
2001+ def set_pos (self , pos , envs_idx = None , * , relative = False , unsafe = False ):
19532002 """
19542003 Set position of the entity's base link.
19552004
@@ -1971,19 +2020,13 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
19712020 if _pos is not pos :
19722021 gs .logger .debug (ALLOCATE_TENSOR_WARNING )
19732022 pos = _pos
2023+ self ._solver .set_dofs_velocity (None , self ._dofs_idx , envs_idx , skip_forward = True , unsafe = unsafe )
19742024 self ._solver .set_base_links_pos (
1975- pos .unsqueeze (- 2 ),
1976- self ._base_links_idx_ ,
1977- envs_idx ,
1978- relative = relative ,
1979- unsafe = unsafe ,
1980- skip_forward = zero_velocity ,
2025+ pos .unsqueeze (- 2 ), self ._base_links_idx_ , envs_idx , relative = relative , unsafe = unsafe
19812026 )
1982- if zero_velocity :
1983- self .zero_all_dofs_velocity (envs_idx , unsafe = unsafe )
19842027
19852028 @gs .assert_built
1986- def set_quat (self , quat , envs_idx = None , * , relative = False , zero_velocity = True , unsafe = False ):
2029+ def set_quat (self , quat , envs_idx = None , * , relative = False , unsafe = False ):
19872030 """
19882031 Set quaternion of the entity's base link.
19892032
@@ -2005,16 +2048,10 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
20052048 if _quat is not quat :
20062049 gs .logger .debug (ALLOCATE_TENSOR_WARNING )
20072050 quat = _quat
2051+ self ._solver .set_dofs_velocity (None , self ._dofs_idx , envs_idx , skip_forward = True , unsafe = unsafe )
20082052 self ._solver .set_base_links_quat (
2009- quat .unsqueeze (- 2 ),
2010- self ._base_links_idx_ ,
2011- envs_idx ,
2012- relative = relative ,
2013- unsafe = unsafe ,
2014- skip_forward = zero_velocity ,
2053+ quat .unsqueeze (- 2 ), self ._base_links_idx_ , envs_idx , relative = relative , unsafe = unsafe
20152054 )
2016- if zero_velocity :
2017- self .zero_all_dofs_velocity (envs_idx , unsafe = unsafe )
20182055
20192056 @gs .assert_built
20202057 def get_verts (self ):
@@ -2061,52 +2098,8 @@ def get_verts(self):
20612098 tensor = tensor [0 ]
20622099 return tensor
20632100
2064- def _get_idx (self , idx_local , idx_local_max , idx_global_start = 0 , * , unsafe = False ):
2065- # Handling default argument and special cases
2066- if idx_local is None :
2067- if unsafe :
2068- idx_global = slice (idx_global_start , idx_local_max + idx_global_start )
2069- else :
2070- idx_global = range (idx_global_start , idx_local_max + idx_global_start )
2071- elif isinstance (idx_local , (range , slice )):
2072- idx_global = range (
2073- (idx_local .start or 0 ) + idx_global_start ,
2074- (idx_local .stop if idx_local .stop is not None else idx_local_max ) + idx_global_start ,
2075- idx_local .step or 1 ,
2076- )
2077- elif isinstance (idx_local , (int , np .integer )):
2078- idx_global = idx_local + idx_global_start
2079- elif isinstance (idx_local , (list , tuple )):
2080- try :
2081- idx_global = [i + idx_global_start for i in idx_local ]
2082- except TypeError :
2083- gs .raise_exception ("Expecting a sequence of integers for `idx_local`." )
2084- else :
2085- # Increment may be slow when dealing with heterogenuous data, so it must be avoided if possible
2086- if idx_global_start > 0 :
2087- idx_global = idx_local + idx_global_start
2088- else :
2089- idx_global = idx_local
2090-
2091- # Early return if unsafe
2092- if unsafe :
2093- return idx_global
2094-
2095- # Perform a bunch of sanity checks
2096- _idx_global = torch .as_tensor (idx_global , dtype = gs .tc_int , device = gs .device ).contiguous ()
2097- if _idx_global is not idx_global :
2098- gs .logger .debug (ALLOCATE_TENSOR_WARNING )
2099- idx_global = torch .atleast_1d (_idx_global )
2100-
2101- if idx_global .ndim != 1 :
2102- gs .raise_exception ("Expecting a 1D tensor for `idx_local`." )
2103- if (idx_global < 0 ).any () or (idx_global >= idx_global_start + idx_local_max ).any ():
2104- gs .raise_exception ("`idx_local` exceeds valid range." )
2105-
2106- return idx_global
2107-
21082101 @gs .assert_built
2109- def set_qpos (self , qpos , qs_idx_local = None , envs_idx = None , * , zero_velocity = True , unsafe = False ):
2102+ def set_qpos (self , qpos , qs_idx_local = None , envs_idx = None , * , zero_velocity = True , skip_forward = False , unsafe = False ):
21102103 """
21112104 Set the entity's qpos.
21122105
@@ -2122,9 +2115,9 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True
21222115 Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
21232116 """
21242117 qs_idx = self ._get_idx (qs_idx_local , self .n_qs , self ._q_start , unsafe = True )
2125- self ._solver .set_qpos (qpos , qs_idx , envs_idx , unsafe = unsafe , skip_forward = zero_velocity )
21262118 if zero_velocity :
2127- self .zero_all_dofs_velocity (envs_idx , unsafe = unsafe )
2119+ self ._solver .set_dofs_velocity (None , self ._dofs_idx , envs_idx , skip_forward = True , unsafe = unsafe )
2120+ self ._solver .set_qpos (qpos , qs_idx , envs_idx , skip_forward = skip_forward , unsafe = unsafe )
21282121
21292122 @gs .assert_built
21302123 def set_dofs_kp (self , kp , dofs_idx_local = None , envs_idx = None , * , unsafe = False ):
@@ -2203,37 +2196,37 @@ def set_dofs_damping(self, damping, dofs_idx_local=None, envs_idx=None, *, unsaf
22032196 self ._solver .set_dofs_damping (damping , dofs_idx , envs_idx , unsafe = unsafe )
22042197
22052198 @gs .assert_built
2206- def set_dofs_velocity (self , velocity = None , dofs_idx_local = None , envs_idx = None , * , unsafe = False ):
2199+ def set_dofs_frictionloss (self , frictionloss , dofs_idx_local = None , envs_idx = None , * , unsafe = False ):
22072200 """
2208- Set the entity's dofs' velocity.
2209-
2201+ Set the entity's dofs' friction loss.
22102202 Parameters
22112203 ----------
2212- velocity : array_like | None
2213- The velocity to set. Zero if not specified .
2204+ frictionloss : array_like
2205+ The friction loss values to set .
22142206 dofs_idx_local : None | array_like, optional
22152207 The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None.
22162208 envs_idx : None | array_like, optional
22172209 The indices of the environments. If None, all environments will be considered. Defaults to None.
22182210 """
22192211 dofs_idx = self ._get_idx (dofs_idx_local , self .n_dofs , self ._dof_start , unsafe = True )
2220- self ._solver .set_dofs_velocity ( velocity , dofs_idx , envs_idx , skip_forward = False , unsafe = unsafe )
2212+ self ._solver .set_dofs_frictionloss ( frictionloss , dofs_idx , envs_idx , unsafe = unsafe )
22212213
22222214 @gs .assert_built
2223- def set_dofs_frictionloss (self , frictionloss , dofs_idx_local = None , envs_idx = None , * , unsafe = False ):
2215+ def set_dofs_velocity (self , velocity = None , dofs_idx_local = None , envs_idx = None , * , skip_forward = False , unsafe = False ):
22242216 """
2225- Set the entity's dofs' friction loss.
2217+ Set the entity's dofs' velocity.
2218+
22262219 Parameters
22272220 ----------
2228- frictionloss : array_like
2229- The friction loss values to set.
2221+ velocity : array_like | None
2222+ The velocity to set. Zero if not specified .
22302223 dofs_idx_local : None | array_like, optional
22312224 The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None.
22322225 envs_idx : None | array_like, optional
22332226 The indices of the environments. If None, all environments will be considered. Defaults to None.
22342227 """
22352228 dofs_idx = self ._get_idx (dofs_idx_local , self .n_dofs , self ._dof_start , unsafe = True )
2236- self ._solver .set_dofs_frictionloss ( frictionloss , dofs_idx , envs_idx , unsafe = unsafe )
2229+ self ._solver .set_dofs_velocity ( velocity , dofs_idx , envs_idx , skip_forward = skip_forward , unsafe = unsafe )
22372230
22382231 @gs .assert_built
22392232 def set_dofs_position (self , position , dofs_idx_local = None , envs_idx = None , * , zero_velocity = True , unsafe = False ):
@@ -2252,9 +2245,9 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer
22522245 Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
22532246 """
22542247 dofs_idx = self ._get_idx (dofs_idx_local , self .n_dofs , self ._dof_start , unsafe = True )
2255- self ._solver .set_dofs_position (position , dofs_idx , envs_idx , unsafe = unsafe , skip_forward = zero_velocity )
22562248 if zero_velocity :
2257- self .zero_all_dofs_velocity (envs_idx , unsafe = unsafe )
2249+ self ._solver .set_dofs_velocity (None , self ._dofs_idx , envs_idx , skip_forward = True , unsafe = unsafe )
2250+ self ._solver .set_dofs_position (position , dofs_idx , envs_idx , unsafe = unsafe )
22582251
22592252 @gs .assert_built
22602253 def control_dofs_force (self , force , dofs_idx_local = None , envs_idx = None , * , unsafe = False ):
@@ -2570,8 +2563,7 @@ def zero_all_dofs_velocity(self, envs_idx=None, *, unsafe=False):
25702563 envs_idx : None | array_like, optional
25712564 The indices of the environments. If None, all environments will be considered. Defaults to None.
25722565 """
2573- dofs_idx_local = torch .arange (self .n_dofs , dtype = gs .tc_int , device = gs .device )
2574- self .set_dofs_velocity (None , dofs_idx_local , envs_idx , unsafe = unsafe )
2566+ self .set_dofs_velocity (None , self ._dofs_idx , envs_idx , unsafe = unsafe )
25752567
25762568 @gs .assert_built
25772569 def detect_collision (self , env_idx = 0 ):
0 commit comments