diff --git a/examples/collision/tower.py b/examples/collision/tower.py index 6f4e4fe39..9bb3f76f7 100644 --- a/examples/collision/tower.py +++ b/examples/collision/tower.py @@ -1,4 +1,5 @@ import argparse +import os import genesis as gs @@ -9,12 +10,13 @@ def main(): parser.add_argument("-v", "--vis", action="store_true", default=False) args = parser.parse_args() object_type = args.object + horizon = 50 if "PYTEST_VERSION" in os.environ else 1000 gs.init(backend=gs.cpu, precision="32") scene = gs.Scene( sim_options=gs.options.SimOptions( - dt=0.005, + dt=0.004, ), rigid_options=gs.options.RigidOptions( max_collision_pairs=200, @@ -27,7 +29,7 @@ def main(): show_viewer=args.vis, ) - plane = scene.add_entity(gs.morphs.Plane()) + scene.add_entity(gs.morphs.Plane()) # create pyramid of boxes box_width, box_length, box_height = 0.25, 2.0, 0.1 @@ -51,12 +53,12 @@ def main(): # Drop a huge mesh if object_type == "duck": - duck_scale = 1.0 - duck = scene.add_entity( + duck_scale = 0.8 + scene.add_entity( morph=gs.morphs.Mesh( file="meshes/duck.obj", scale=duck_scale, - pos=(0, 0, num_stacks * box_height + 10 * duck_scale), + pos=(0, -0.1, num_stacks * box_height + 10 * duck_scale), ), ) elif object_type == "sphere": @@ -78,7 +80,7 @@ def main(): ) scene.build() - for i in range(600): + for i in range(horizon): scene.step() diff --git a/genesis/engine/entities/rigid_entity/rigid_entity.py b/genesis/engine/entities/rigid_entity/rigid_entity.py index 2d35fe665..37ae16589 100644 --- a/genesis/engine/entities/rigid_entity/rigid_entity.py +++ b/genesis/engine/entities/rigid_entity/rigid_entity.py @@ -580,6 +580,10 @@ def _build(self): self._n_free_verts = len(self._free_verts_idx_local) self._n_fixed_verts = len(self._fixed_verts_idx_local) + self._dofs_idx = torch.arange( + self._dof_start, self._dof_start + self._n_dofs, dtype=gs.tc_int, device=gs.device + ) + self._geoms = self.geoms self._vgeoms = self.vgeoms @@ -1493,6 +1497,7 @@ def _kernel_forward_kinematics( # ------------------------------------------------------------------------------------ # --------------------------------- motion planing ----------------------------------- # ------------------------------------------------------------------------------------ + @gs.assert_built def plan_path( self, @@ -1623,6 +1628,50 @@ def plan_path( # ---------------------------------- control & io ------------------------------------ # ------------------------------------------------------------------------------------ + def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False): + # Handling default argument and special cases + if idx_local is None: + if unsafe: + idx_global = slice(idx_global_start, idx_local_max + idx_global_start) + else: + idx_global = range(idx_global_start, idx_local_max + idx_global_start) + elif isinstance(idx_local, (range, slice)): + idx_global = range( + (idx_local.start or 0) + idx_global_start, + (idx_local.stop if idx_local.stop is not None else idx_local_max) + idx_global_start, + idx_local.step or 1, + ) + elif isinstance(idx_local, (int, np.integer)): + idx_global = idx_local + idx_global_start + elif isinstance(idx_local, (list, tuple)): + try: + idx_global = [i + idx_global_start for i in idx_local] + except TypeError: + gs.raise_exception("Expecting a sequence of integers for `idx_local`.") + else: + # Increment may be slow when dealing with heterogenuous data, so it must be avoided if possible + if idx_global_start > 0: + idx_global = idx_local + idx_global_start + else: + idx_global = idx_local + + # Early return if unsafe + if unsafe: + return idx_global + + # Perform a bunch of sanity checks + _idx_global = torch.as_tensor(idx_global, dtype=gs.tc_int, device=gs.device).contiguous() + if _idx_global is not idx_global: + gs.logger.debug(ALLOCATE_TENSOR_WARNING) + idx_global = torch.atleast_1d(_idx_global) + + if idx_global.ndim != 1: + gs.raise_exception("Expecting a 1D tensor for `idx_local`.") + if (idx_global < 0).any() or (idx_global >= idx_global_start + idx_local_max).any(): + gs.raise_exception("`idx_local` exceeds valid range.") + + return idx_global + def get_joint(self, name=None, uid=None): """ Get a RigidJoint object by name or uid. @@ -1949,7 +1998,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None, *, unsafe=Fal return self._solver.get_links_invweight(links_idx, envs_idx, unsafe=unsafe) @gs.assert_built - def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): + def set_pos(self, pos, envs_idx=None, *, relative=False, unsafe=False): """ Set position of the entity's base link. @@ -1971,19 +2020,13 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns if _pos is not pos: gs.logger.debug(ALLOCATE_TENSOR_WARNING) pos = _pos + self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe) self._solver.set_base_links_pos( - pos.unsqueeze(-2), - self._base_links_idx_, - envs_idx, - relative=relative, - unsafe=unsafe, - skip_forward=zero_velocity, + pos.unsqueeze(-2), self._base_links_idx_, envs_idx, relative=relative, unsafe=unsafe ) - if zero_velocity: - self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) @gs.assert_built - def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False): + def set_quat(self, quat, envs_idx=None, *, relative=False, unsafe=False): """ Set quaternion of the entity's base link. @@ -2005,16 +2048,10 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u if _quat is not quat: gs.logger.debug(ALLOCATE_TENSOR_WARNING) quat = _quat + self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe) self._solver.set_base_links_quat( - quat.unsqueeze(-2), - self._base_links_idx_, - envs_idx, - relative=relative, - unsafe=unsafe, - skip_forward=zero_velocity, + quat.unsqueeze(-2), self._base_links_idx_, envs_idx, relative=relative, unsafe=unsafe ) - if zero_velocity: - self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) @gs.assert_built def get_verts(self): @@ -2061,52 +2098,8 @@ def get_verts(self): tensor = tensor[0] return tensor - def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False): - # Handling default argument and special cases - if idx_local is None: - if unsafe: - idx_global = slice(idx_global_start, idx_local_max + idx_global_start) - else: - idx_global = range(idx_global_start, idx_local_max + idx_global_start) - elif isinstance(idx_local, (range, slice)): - idx_global = range( - (idx_local.start or 0) + idx_global_start, - (idx_local.stop if idx_local.stop is not None else idx_local_max) + idx_global_start, - idx_local.step or 1, - ) - elif isinstance(idx_local, (int, np.integer)): - idx_global = idx_local + idx_global_start - elif isinstance(idx_local, (list, tuple)): - try: - idx_global = [i + idx_global_start for i in idx_local] - except TypeError: - gs.raise_exception("Expecting a sequence of integers for `idx_local`.") - else: - # Increment may be slow when dealing with heterogenuous data, so it must be avoided if possible - if idx_global_start > 0: - idx_global = idx_local + idx_global_start - else: - idx_global = idx_local - - # Early return if unsafe - if unsafe: - return idx_global - - # Perform a bunch of sanity checks - _idx_global = torch.as_tensor(idx_global, dtype=gs.tc_int, device=gs.device).contiguous() - if _idx_global is not idx_global: - gs.logger.debug(ALLOCATE_TENSOR_WARNING) - idx_global = torch.atleast_1d(_idx_global) - - if idx_global.ndim != 1: - gs.raise_exception("Expecting a 1D tensor for `idx_local`.") - if (idx_global < 0).any() or (idx_global >= idx_global_start + idx_local_max).any(): - gs.raise_exception("`idx_local` exceeds valid range.") - - return idx_global - @gs.assert_built - def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False): + def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, skip_forward=False, unsafe=False): """ Set the entity's qpos. @@ -2122,9 +2115,9 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose. """ qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True) - self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity) if zero_velocity: - self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) + self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe) + self._solver.set_qpos(qpos, qs_idx, envs_idx, skip_forward=skip_forward, unsafe=unsafe) @gs.assert_built def set_dofs_kp(self, kp, dofs_idx_local=None, envs_idx=None, *, unsafe=False): @@ -2203,37 +2196,37 @@ def set_dofs_damping(self, damping, dofs_idx_local=None, envs_idx=None, *, unsaf self._solver.set_dofs_damping(damping, dofs_idx, envs_idx, unsafe=unsafe) @gs.assert_built - def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False): + def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False): """ - Set the entity's dofs' velocity. - + Set the entity's dofs' friction loss. Parameters ---------- - velocity : array_like | None - The velocity to set. Zero if not specified. + frictionloss : array_like + The friction loss values to set. dofs_idx_local : None | array_like, optional The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) - self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe) + self._solver.set_dofs_frictionloss(frictionloss, dofs_idx, envs_idx, unsafe=unsafe) @gs.assert_built - def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False): + def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, skip_forward=False, unsafe=False): """ - Set the entity's dofs' friction loss. + Set the entity's dofs' velocity. + Parameters ---------- - frictionloss : array_like - The friction loss values to set. + velocity : array_like | None + The velocity to set. Zero if not specified. dofs_idx_local : None | array_like, optional The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None. envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) - self._solver.set_dofs_frictionloss(frictionloss, dofs_idx, envs_idx, unsafe=unsafe) + self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=skip_forward, unsafe=unsafe) @gs.assert_built def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False): @@ -2252,9 +2245,9 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose. """ dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True) - self._solver.set_dofs_position(position, dofs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity) if zero_velocity: - self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe) + self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe) + self._solver.set_dofs_position(position, dofs_idx, envs_idx, unsafe=unsafe) @gs.assert_built def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None, *, unsafe=False): @@ -2570,8 +2563,7 @@ def zero_all_dofs_velocity(self, envs_idx=None, *, unsafe=False): envs_idx : None | array_like, optional The indices of the environments. If None, all environments will be considered. Defaults to None. """ - dofs_idx_local = torch.arange(self.n_dofs, dtype=gs.tc_int, device=gs.device) - self.set_dofs_velocity(None, dofs_idx_local, envs_idx, unsafe=unsafe) + self.set_dofs_velocity(None, self._dofs_idx, envs_idx, unsafe=unsafe) @gs.assert_built def detect_collision(self, env_idx=0): diff --git a/genesis/engine/scene.py b/genesis/engine/scene.py index 596e1b79c..225a1c635 100644 --- a/genesis/engine/scene.py +++ b/genesis/engine/scene.py @@ -1406,8 +1406,9 @@ def _sanitize_envs_idx(self, envs_idx, *, unsafe=False): if _envs_idx.ndim != 1: gs.raise_exception("Expecting a 1D tensor for `envs_idx`.") - if (_envs_idx < 0).any() or (_envs_idx >= self.n_envs).any(): - gs.raise_exception("`envs_idx` exceeds valid range.") + # FIXME: This check is too expensive + # if (_envs_idx < 0).any() or (_envs_idx >= self.n_envs).any(): + # gs.raise_exception("`envs_idx` exceeds valid range.") return _envs_idx diff --git a/genesis/engine/solvers/rigid/collider_decomp.py b/genesis/engine/solvers/rigid/collider_decomp.py index c48dac9a5..8e133a28f 100644 --- a/genesis/engine/solvers/rigid/collider_decomp.py +++ b/genesis/engine/solvers/rigid/collider_decomp.py @@ -296,15 +296,30 @@ def _init_terrain_state(self): self._collider_info.terrain_scale.from_numpy(scale) self._collider_info.terrain_xyz_maxmin.from_numpy(xyz_maxmin) - def reset(self, envs_idx: npt.NDArray[np.int32] | None = None) -> None: + def reset(self, envs_idx: npt.NDArray[np.int32] | None = None, cache_only: bool = False) -> None: + self._contacts_info_cache.clear() + if gs.use_zerocopy: + mask = () if envs_idx is None else envs_idx + if not cache_only: + first_time = ti_to_torch(self._collider_state.first_time, copy=False) + if isinstance(envs_idx, torch.Tensor): + first_time.scatter_(0, envs_idx, True) + else: + first_time[mask] = True + i_va_ws = ti_to_torch(self._collider_state.contact_cache.i_va_ws, copy=False) + normal = ti_to_torch(self._collider_state.contact_cache.normal, copy=False) + if isinstance(envs_idx, torch.Tensor): + n_geoms = i_va_ws.shape[0] + i_va_ws.scatter_(2, envs_idx[None, None].expand((n_geoms, n_geoms, -1)), -1) + normal.scatter_(2, envs_idx[None, None, :, None].expand((n_geoms, n_geoms, -1, 3)), 0.0) + else: + i_va_ws[mask] = -1 + normal[mask] = 0.0 + return + if envs_idx is None: envs_idx = self._solver._scene._envs_idx - collider_kernel_reset( - envs_idx, - self._solver._static_rigid_sim_config, - self._collider_state, - ) - self._contacts_info_cache.clear() + collider_kernel_reset(envs_idx, self._solver._static_rigid_sim_config, self._collider_state, cache_only) def clear(self, envs_idx=None): if envs_idx is None: @@ -548,13 +563,17 @@ def collider_kernel_reset( envs_idx: ti.types.ndarray(), static_rigid_sim_config: ti.template(), collider_state: array_class.ColliderState, + cache_only: ti.template(), ): n_geoms = collider_state.active_buffer.shape[0] ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_b_ in range(envs_idx.shape[0]): i_b = envs_idx[i_b_] - collider_state.first_time[i_b] = 1 + + if ti.static(not cache_only): + collider_state.first_time[i_b] = True + for i_ga, i_gb in ti.ndrange(n_geoms, n_geoms): collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1 collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1 diff --git a/genesis/engine/solvers/rigid/constraint_solver_decomp.py b/genesis/engine/solvers/rigid/constraint_solver_decomp.py index e89afab9b..cbafc65db 100644 --- a/genesis/engine/solvers/rigid/constraint_solver_decomp.py +++ b/genesis/engine/solvers/rigid/constraint_solver_decomp.py @@ -12,6 +12,7 @@ import genesis.engine.solvers.rigid.rigid_solver_decomp as rigid_solver import genesis.engine.solvers.rigid.constraint_noslip as constraint_noslip from genesis.engine.solvers.rigid.contact_island import ContactIsland +from genesis.utils.misc import ti_to_torch if TYPE_CHECKING: from genesis.engine.solvers.rigid.rigid_solver_decomp import RigidSolver @@ -107,6 +108,7 @@ def clear(self, envs_idx: npt.NDArray[np.int32] | None = None, cache_only: bool self._eq_const_info_cache.clear() if cache_only: return + if envs_idx is None: envs_idx = self._solver._scene._envs_idx constraint_solver_kernel_clear( @@ -115,8 +117,26 @@ def clear(self, envs_idx: npt.NDArray[np.int32] | None = None, cache_only: bool static_rigid_sim_config=self._solver._static_rigid_sim_config, ) - def reset(self, envs_idx=None): + def reset(self, envs_idx=None, clear_contraints_info=True): self._eq_const_info_cache.clear() + + if gs.use_zerocopy and not clear_contraints_info: + n_constraints = ti_to_torch(self.constraint_state.n_constraints, copy=False) + n_constraints_equality = ti_to_torch(self.constraint_state.n_constraints_equality, copy=False) + n_constraints_frictionloss = ti_to_torch(self.constraint_state.n_constraints_frictionloss, copy=False) + qacc_ws = ti_to_torch(self.constraint_state.qacc_ws, copy=False) + if isinstance(envs_idx, torch.Tensor): + n_constraints.scatter_(0, envs_idx, 0) + n_constraints_equality.scatter_(0, envs_idx, 0) + n_constraints_frictionloss.scatter_(0, envs_idx, 0) + qacc_ws.scatter_(1, envs_idx[None].expand((qacc_ws.shape[0], -1)), 0.0) + else: + n_constraints[envs_idx] = 0 + n_constraints_equality[envs_idx] = 0 + n_constraints_frictionloss[envs_idx] = 0 + qacc_ws[:, envs_idx] = 0.0 + return + if envs_idx is None: envs_idx = self._solver._scene._envs_idx constraint_solver_kernel_reset( diff --git a/genesis/engine/solvers/rigid/rigid_solver_decomp.py b/genesis/engine/solvers/rigid/rigid_solver_decomp.py index f2274ab09..783ae0547 100644 --- a/genesis/engine/solvers/rigid/rigid_solver_decomp.py +++ b/genesis/engine/solvers/rigid/rigid_solver_decomp.py @@ -914,6 +914,7 @@ def check_errno(self): ) def _kernel_detect_collision(self): + self.collider.reset(cache_only=True) self.collider.clear() self.collider.detection() @@ -929,10 +930,12 @@ def detect_collision(self, env_idx=0): return collision_pairs def _func_constraint_force(self): - self.constraint_solver.clear(cache_only=not self._use_contact_island) - - if not self._disable_constraint and not self._use_contact_island: - self.constraint_solver.add_equality_constraints() + if not self._disable_constraint: + if self._use_contact_island: + self.constraint_solver.clear() + else: + self.constraint_solver.clear(cache_only=True) + self.constraint_solver.add_equality_constraints() if self._enable_collision: self.collider.detection() @@ -1294,11 +1297,10 @@ def set_state(self, f, state, envs_idx=None): ) self._errno[None] = 0 - self.collider.reset(envs_idx) + self.collider.reset(envs_idx, cache_only=False) self.collider.clear(envs_idx) if self.constraint_solver is not None: self.constraint_solver.reset(envs_idx) - self.constraint_solver.clear(envs_idx) self._links_state_cache.clear() self._cur_step = -1 @@ -1501,12 +1503,10 @@ def _sanitize_2D_io_variables( def _get_qs_idx(self, qs_idx_local=None): return self._get_qs_idx_local(qs_idx_local) + self._q_start - def set_links_pos(self, pos, links_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): + def set_links_pos(self, pos, links_idx=None, envs_idx=None, *, unsafe=False): raise DeprecationError("This method has been removed. Please use 'set_base_links_pos' instead.") - def set_base_links_pos( - self, pos, links_idx=None, envs_idx=None, *, relative=False, skip_forward=False, unsafe=False - ): + def set_base_links_pos(self, pos, links_idx=None, envs_idx=None, *, relative=False, unsafe=False): if links_idx is None: links_idx = self._base_links_idx pos, links_idx, envs_idx = self._sanitize_2D_io_variables( @@ -1542,28 +1542,25 @@ def set_base_links_pos( ) self._links_state_cache.clear() - if not skip_forward: - kernel_forward_kinematics_links_geoms( - envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_state=self.joints_state, - joints_info=self.joints_info, - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - ) + kernel_forward_kinematics_links_geoms( + envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) - def set_links_quat(self, quat, links_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): + def set_links_quat(self, quat, links_idx=None, envs_idx=None, *, unsafe=False): raise DeprecationError("This method has been removed. Please use 'set_base_links_quat' instead.") - def set_base_links_quat( - self, quat, links_idx=None, envs_idx=None, *, relative=False, skip_forward=False, unsafe=False - ): + def set_base_links_quat(self, quat, links_idx=None, envs_idx=None, *, relative=False, unsafe=False): if links_idx is None: links_idx = self._base_links_idx quat, links_idx, envs_idx = self._sanitize_2D_io_variables( @@ -1594,21 +1591,20 @@ def set_base_links_quat( ) self._links_state_cache.clear() - if not skip_forward: - kernel_forward_kinematics_links_geoms( - envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_state=self.joints_state, - joints_info=self.joints_info, - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - ) + kernel_forward_kinematics_links_geoms( + envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) def set_links_mass_shift(self, mass, links_idx=None, envs_idx=None, *, unsafe=False): mass, links_idx, envs_idx = self._sanitize_1D_io_variables( @@ -1658,20 +1654,31 @@ def set_geoms_friction_ratio(self, friction_ratio, geoms_idx=None, envs_idx=None ) def set_qpos(self, qpos, qs_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): - qpos, qs_idx, envs_idx = self._sanitize_1D_io_variables( - qpos, qs_idx, self.n_qs, envs_idx, idx_name="qs_idx", skip_allocation=True, unsafe=unsafe - ) - if self.n_envs == 0: - qpos = qpos.unsqueeze(0) - kernel_set_qpos(qpos, qs_idx, envs_idx, self._rigid_global_info, self._static_rigid_sim_config) + if gs.use_zerocopy: + mask = (0, *indices_to_mask(qs_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, qs_idx) + data = ti_to_torch(self._rigid_global_info.qpos, transpose=True, copy=False) + data[mask] = torch.as_tensor(qpos, dtype=gs.tc_float, device=gs.device) + if mask and isinstance(mask[0], torch.Tensor): + envs_idx = mask[0] + else: + qpos, qs_idx, envs_idx = self._sanitize_1D_io_variables( + qpos, qs_idx, self.n_qs, envs_idx, idx_name="qs_idx", skip_allocation=True, unsafe=unsafe + ) + if self.n_envs == 0: + qpos = qpos.unsqueeze(0) + kernel_set_qpos(qpos, qs_idx, envs_idx, self._rigid_global_info, self._static_rigid_sim_config) - self._errno[None] = 0 - self.collider.reset(envs_idx) - self.collider.clear(envs_idx) - if self.constraint_solver is not None: - self.constraint_solver.reset(envs_idx) - self.constraint_solver.clear(envs_idx) self._links_state_cache.clear() + self.collider.reset(envs_idx, cache_only=True) + if not isinstance(envs_idx, torch.Tensor): + envs_idx = self._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe) + if not skip_forward: + self.collider.clear(envs_idx) + if self.constraint_solver is not None: + if self._use_contact_island: + self.constraint_solver.reset(envs_idx) + else: + self.constraint_solver.reset(envs_idx, clear_contraints_info=not skip_forward) if not skip_forward: kernel_forward_kinematics_links_geoms( envs_idx, @@ -1848,35 +1855,42 @@ def set_dofs_limit(self, lower, upper, dofs_idx=None, envs_idx=None, *, unsafe=F self._set_dofs_info([lower, upper], dofs_idx, "limit", envs_idx, unsafe=unsafe) def set_dofs_velocity(self, velocity, dofs_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): - velocity, dofs_idx, envs_idx = self._sanitize_1D_io_variables( - velocity, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe - ) - - if velocity is None: - kernel_set_dofs_zero_velocity(dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + if gs.use_zerocopy: + vel = ti_to_torch(self.dofs_state.vel, transpose=True, copy=False) + if velocity is None and isinstance(dofs_idx, slice) and isinstance(envs_idx, torch.Tensor): + (vel := vel[:, dofs_idx]).scatter_(0, envs_idx[:, None].expand((-1, vel.shape[1])), 0.0) + else: + mask = (0, *indices_to_mask(dofs_idx)) if self.n_envs == 0 else indices_to_mask(envs_idx, dofs_idx) + vel[mask] = 0.0 if velocity is None else torch.as_tensor(velocity, dtype=gs.tc_float, device=gs.device) + if mask and isinstance(mask[0], torch.Tensor): + envs_idx = mask[0] + elif not isinstance(envs_idx, torch.Tensor): + envs_idx = self._scene._sanitize_envs_idx(envs_idx, unsafe=unsafe) else: - if self.n_envs == 0: - velocity = velocity.unsqueeze(0) - kernel_set_dofs_velocity(velocity, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + velocity, dofs_idx, envs_idx = self._sanitize_1D_io_variables( + velocity, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe + ) + if velocity is None: + kernel_set_dofs_zero_velocity(dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) + else: + if self.n_envs == 0: + velocity = velocity.unsqueeze(0) + kernel_set_dofs_velocity(velocity, dofs_idx, envs_idx, self.dofs_state, self._static_rigid_sim_config) self._links_state_cache.clear() if not skip_forward: - kernel_forward_kinematics_links_geoms( + kernel_forward_velocity( envs_idx, links_state=self.links_state, links_info=self.links_info, - joints_state=self.joints_state, joints_info=self.joints_info, dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, entities_info=self.entities_info, rigid_global_info=self._rigid_global_info, static_rigid_sim_config=self._static_rigid_sim_config, ) - def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forward=False, unsafe=False): + def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, unsafe=False): position, dofs_idx, envs_idx = self._sanitize_1D_io_variables( position, dofs_idx, self.n_dofs, envs_idx, skip_allocation=True, unsafe=unsafe ) @@ -1894,28 +1908,25 @@ def set_dofs_position(self, position, dofs_idx=None, envs_idx=None, *, skip_forw self._static_rigid_sim_config, ) - self._errno[None] = 0 - self.collider.reset(envs_idx) + self._links_state_cache.clear() + self.collider.reset(envs_idx, cache_only=True) self.collider.clear(envs_idx) if self.constraint_solver is not None: self.constraint_solver.reset(envs_idx) - self.constraint_solver.clear(envs_idx) - self._links_state_cache.clear() - if not skip_forward: - kernel_forward_kinematics_links_geoms( - envs_idx, - links_state=self.links_state, - links_info=self.links_info, - joints_state=self.joints_state, - joints_info=self.joints_info, - dofs_state=self.dofs_state, - dofs_info=self.dofs_info, - geoms_state=self.geoms_state, - geoms_info=self.geoms_info, - entities_info=self.entities_info, - rigid_global_info=self._rigid_global_info, - static_rigid_sim_config=self._static_rigid_sim_config, - ) + kernel_forward_kinematics_links_geoms( + envs_idx, + links_state=self.links_state, + links_info=self.links_info, + joints_state=self.joints_state, + joints_info=self.joints_info, + dofs_state=self.dofs_state, + dofs_info=self.dofs_info, + geoms_state=self.geoms_state, + geoms_info=self.geoms_info, + entities_info=self.entities_info, + rigid_global_info=self._rigid_global_info, + static_rigid_sim_config=self._static_rigid_sim_config, + ) def control_dofs_force(self, force, dofs_idx=None, envs_idx=None, *, unsafe=False): if gs.use_zerocopy: @@ -2324,7 +2335,7 @@ def get_equality_constraints(self, as_tensor: bool = True, to_torch: bool = True def clear_external_force(self): if gs.use_zerocopy: for tensor in (self.links_state.cfrc_applied_ang, self.links_state.cfrc_applied_vel): - out = ti_to_python(tensor, copy=False) + out = ti_to_torch(tensor, copy=False) out.zero_() else: kernel_clear_external_force(self.links_state, self._rigid_global_info, self._static_rigid_sim_config) @@ -3696,17 +3707,6 @@ def func_update_cartesian_space( rigid_global_info=rigid_global_info, static_rigid_sim_config=static_rigid_sim_config, ) - func_forward_velocity( - i_b, - entities_info=entities_info, - links_info=links_info, - links_state=links_state, - joints_info=joints_info, - dofs_state=dofs_state, - rigid_global_info=rigid_global_info, - static_rigid_sim_config=static_rigid_sim_config, - ) - func_update_geoms( i_b=i_b, entities_info=entities_info, @@ -3754,6 +3754,16 @@ def kernel_step_1( static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, ) + func_forward_velocity( + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) func_forward_dynamics( links_state=links_state, @@ -3920,6 +3930,16 @@ def kernel_step_2( static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=False, ) + func_forward_velocity( + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @ti.kernel(fastcache=gs.use_fastcache) @@ -3955,6 +3975,41 @@ def kernel_forward_kinematics_links_geoms( static_rigid_sim_config=static_rigid_sim_config, force_update_fixed_geoms=True, ) + func_forward_velocity( + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) + + +@ti.kernel(fastcache=gs.use_fastcache) +def kernel_forward_velocity( + envs_idx: ti.types.ndarray(), + links_state: array_class.LinksState, + links_info: array_class.LinksInfo, + joints_info: array_class.JointsInfo, + dofs_state: array_class.DofsState, + entities_info: array_class.EntitiesInfo, + rigid_global_info: array_class.RigidGlobalInfo, + static_rigid_sim_config: ti.template(), +): + for i_b_ in range(envs_idx.shape[0]): + i_b = envs_idx[i_b_] + func_forward_velocity( + i_b=i_b, + entities_info=entities_info, + links_info=links_info, + links_state=links_state, + joints_info=joints_info, + dofs_state=dofs_state, + rigid_global_info=rigid_global_info, + static_rigid_sim_config=static_rigid_sim_config, + ) @ti.func @@ -5920,7 +5975,6 @@ def kernel_set_links_pos( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): i_b = envs_idx[i_b_] @@ -5955,7 +6009,6 @@ def kernel_set_links_quat( rigid_global_info: array_class.RigidGlobalInfo, static_rigid_sim_config: ti.template(), ): - ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) for i_l_, i_b_ in ti.ndrange(links_idx.shape[0], envs_idx.shape[0]): i_b = envs_idx[i_b_] diff --git a/genesis/utils/array_class.py b/genesis/utils/array_class.py index d8dcb07f7..23c54bddb 100644 --- a/genesis/utils/array_class.py +++ b/genesis/utils/array_class.py @@ -547,7 +547,7 @@ def get_collider_state( prism=V_VEC(3, dtype=gs.ti_float, shape=(6, _B)), n_contacts=V(dtype=gs.ti_int, shape=(_B,)), n_contacts_hibernated=V(dtype=gs.ti_int, shape=(_B,)), - first_time=V(dtype=gs.ti_int, shape=(_B,)), + first_time=V(dtype=gs.ti_bool, shape=(_B,)), contact_cache=get_contact_cache(solver), broad_collision_pairs=V_VEC(2, dtype=gs.ti_int, shape=(max(max_collision_pairs_broad, 1), _B)), contact_data=get_contact_data(solver, max_contact_pairs, requires_grad), diff --git a/tests/test_rigid_benchmarks.py b/tests/test_rigid_benchmarks.py index 17b5b4c0b..5abb14ee2 100644 --- a/tests/test_rigid_benchmarks.py +++ b/tests/test_rigid_benchmarks.py @@ -420,14 +420,27 @@ def _batched_franka(solver, n_envs, gjk, is_collision_free, accessors): if n_envs > 0: ctrl = torch.tile(ctrl, (n_envs, 1)) if is_collision_free: + franka.set_qpos(ctrl) franka.control_dofs_position(ctrl) + vel0 = torch.zeros((franka.n_qs,), dtype=gs.tc_float, device=gs.device) + if n_envs > 0: + n_reset_envs = int(0.02 * n_envs) + reset_envs_idx = torch.randperm(n_envs)[:n_reset_envs] + vel0 = torch.tile(vel0, (n_reset_envs, 1)) + qpos0 = ctrl[reset_envs_idx] + else: + reset_envs_idx = None + qpos0 = ctrl + num_steps = 0 is_recording = False time_start = time.time() while True: scene.step() if accessors: + franka.set_qpos(qpos0, envs_idx=reset_envs_idx, zero_velocity=False, skip_forward=True) + franka.set_dofs_velocity(vel0, envs_idx=reset_envs_idx, skip_forward=True) franka.get_ang() franka.get_vel() franka.get_dofs_position()