Skip to content
14 changes: 8 additions & 6 deletions examples/collision/tower.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import os

import genesis as gs

Expand All @@ -9,12 +10,13 @@ def main():
parser.add_argument("-v", "--vis", action="store_true", default=False)
args = parser.parse_args()
object_type = args.object
horizon = 50 if "PYTEST_VERSION" in os.environ else 1000

gs.init(backend=gs.cpu, precision="32")

scene = gs.Scene(
sim_options=gs.options.SimOptions(
dt=0.005,
dt=0.004,
),
rigid_options=gs.options.RigidOptions(
max_collision_pairs=200,
Expand All @@ -27,7 +29,7 @@ def main():
show_viewer=args.vis,
)

plane = scene.add_entity(gs.morphs.Plane())
scene.add_entity(gs.morphs.Plane())

# create pyramid of boxes
box_width, box_length, box_height = 0.25, 2.0, 0.1
Expand All @@ -51,12 +53,12 @@ def main():

# Drop a huge mesh
if object_type == "duck":
duck_scale = 1.0
duck = scene.add_entity(
duck_scale = 0.8
scene.add_entity(
morph=gs.morphs.Mesh(
file="meshes/duck.obj",
scale=duck_scale,
pos=(0, 0, num_stacks * box_height + 10 * duck_scale),
pos=(0, -0.1, num_stacks * box_height + 10 * duck_scale),
),
)
elif object_type == "sphere":
Expand All @@ -78,7 +80,7 @@ def main():
)

scene.build()
for i in range(600):
for i in range(horizon):
scene.step()


Expand Down
152 changes: 72 additions & 80 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,10 @@ def _build(self):
self._n_free_verts = len(self._free_verts_idx_local)
self._n_fixed_verts = len(self._fixed_verts_idx_local)

self._dofs_idx = torch.arange(
self._dof_start, self._dof_start + self._n_dofs, dtype=gs.tc_int, device=gs.device
)

self._geoms = self.geoms
self._vgeoms = self.vgeoms

Expand Down Expand Up @@ -1493,6 +1497,7 @@ def _kernel_forward_kinematics(
# ------------------------------------------------------------------------------------
# --------------------------------- motion planing -----------------------------------
# ------------------------------------------------------------------------------------

@gs.assert_built
def plan_path(
self,
Expand Down Expand Up @@ -1623,6 +1628,50 @@ def plan_path(
# ---------------------------------- control & io ------------------------------------
# ------------------------------------------------------------------------------------

def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False):
# Handling default argument and special cases
if idx_local is None:
if unsafe:
idx_global = slice(idx_global_start, idx_local_max + idx_global_start)
else:
idx_global = range(idx_global_start, idx_local_max + idx_global_start)
elif isinstance(idx_local, (range, slice)):
idx_global = range(
(idx_local.start or 0) + idx_global_start,
(idx_local.stop if idx_local.stop is not None else idx_local_max) + idx_global_start,
idx_local.step or 1,
)
elif isinstance(idx_local, (int, np.integer)):
idx_global = idx_local + idx_global_start
elif isinstance(idx_local, (list, tuple)):
try:
idx_global = [i + idx_global_start for i in idx_local]
except TypeError:
gs.raise_exception("Expecting a sequence of integers for `idx_local`.")
else:
# Increment may be slow when dealing with heterogenuous data, so it must be avoided if possible
if idx_global_start > 0:
idx_global = idx_local + idx_global_start
else:
idx_global = idx_local

# Early return if unsafe
if unsafe:
return idx_global

# Perform a bunch of sanity checks
_idx_global = torch.as_tensor(idx_global, dtype=gs.tc_int, device=gs.device).contiguous()
if _idx_global is not idx_global:
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
idx_global = torch.atleast_1d(_idx_global)

if idx_global.ndim != 1:
gs.raise_exception("Expecting a 1D tensor for `idx_local`.")
if (idx_global < 0).any() or (idx_global >= idx_global_start + idx_local_max).any():
gs.raise_exception("`idx_local` exceeds valid range.")

return idx_global

def get_joint(self, name=None, uid=None):
"""
Get a RigidJoint object by name or uid.
Expand Down Expand Up @@ -1949,7 +1998,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None, *, unsafe=Fal
return self._solver.get_links_invweight(links_idx, envs_idx, unsafe=unsafe)

@gs.assert_built
def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
def set_pos(self, pos, envs_idx=None, *, relative=False, unsafe=False):
"""
Set position of the entity's base link.

Expand All @@ -1971,19 +2020,13 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
if _pos is not pos:
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
pos = _pos
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
self._solver.set_base_links_pos(
pos.unsqueeze(-2),
self._base_links_idx_,
envs_idx,
relative=relative,
unsafe=unsafe,
skip_forward=zero_velocity,
pos.unsqueeze(-2), self._base_links_idx_, envs_idx, relative=relative, unsafe=unsafe
)
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)

@gs.assert_built
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
def set_quat(self, quat, envs_idx=None, *, relative=False, unsafe=False):
"""
Set quaternion of the entity's base link.

Expand All @@ -2005,16 +2048,10 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
if _quat is not quat:
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
quat = _quat
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
self._solver.set_base_links_quat(
quat.unsqueeze(-2),
self._base_links_idx_,
envs_idx,
relative=relative,
unsafe=unsafe,
skip_forward=zero_velocity,
quat.unsqueeze(-2), self._base_links_idx_, envs_idx, relative=relative, unsafe=unsafe
)
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)

@gs.assert_built
def get_verts(self):
Expand Down Expand Up @@ -2061,52 +2098,8 @@ def get_verts(self):
tensor = tensor[0]
return tensor

def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False):
# Handling default argument and special cases
if idx_local is None:
if unsafe:
idx_global = slice(idx_global_start, idx_local_max + idx_global_start)
else:
idx_global = range(idx_global_start, idx_local_max + idx_global_start)
elif isinstance(idx_local, (range, slice)):
idx_global = range(
(idx_local.start or 0) + idx_global_start,
(idx_local.stop if idx_local.stop is not None else idx_local_max) + idx_global_start,
idx_local.step or 1,
)
elif isinstance(idx_local, (int, np.integer)):
idx_global = idx_local + idx_global_start
elif isinstance(idx_local, (list, tuple)):
try:
idx_global = [i + idx_global_start for i in idx_local]
except TypeError:
gs.raise_exception("Expecting a sequence of integers for `idx_local`.")
else:
# Increment may be slow when dealing with heterogenuous data, so it must be avoided if possible
if idx_global_start > 0:
idx_global = idx_local + idx_global_start
else:
idx_global = idx_local

# Early return if unsafe
if unsafe:
return idx_global

# Perform a bunch of sanity checks
_idx_global = torch.as_tensor(idx_global, dtype=gs.tc_int, device=gs.device).contiguous()
if _idx_global is not idx_global:
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
idx_global = torch.atleast_1d(_idx_global)

if idx_global.ndim != 1:
gs.raise_exception("Expecting a 1D tensor for `idx_local`.")
if (idx_global < 0).any() or (idx_global >= idx_global_start + idx_local_max).any():
gs.raise_exception("`idx_local` exceeds valid range.")

return idx_global

@gs.assert_built
def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False):
def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, skip_forward=False, unsafe=False):
"""
Set the entity's qpos.

Expand All @@ -2122,9 +2115,9 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True
Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
"""
qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True)
self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity)
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
self._solver.set_qpos(qpos, qs_idx, envs_idx, skip_forward=skip_forward, unsafe=unsafe)

@gs.assert_built
def set_dofs_kp(self, kp, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
Expand Down Expand Up @@ -2203,37 +2196,37 @@ def set_dofs_damping(self, damping, dofs_idx_local=None, envs_idx=None, *, unsaf
self._solver.set_dofs_damping(damping, dofs_idx, envs_idx, unsafe=unsafe)

@gs.assert_built
def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
"""
Set the entity's dofs' velocity.

Set the entity's dofs' friction loss.
Parameters
----------
velocity : array_like | None
The velocity to set. Zero if not specified.
frictionloss : array_like
The friction loss values to set.
dofs_idx_local : None | array_like, optional
The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None.
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe)
self._solver.set_dofs_frictionloss(frictionloss, dofs_idx, envs_idx, unsafe=unsafe)

@gs.assert_built
def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, skip_forward=False, unsafe=False):
"""
Set the entity's dofs' friction loss.
Set the entity's dofs' velocity.

Parameters
----------
frictionloss : array_like
The friction loss values to set.
velocity : array_like | None
The velocity to set. Zero if not specified.
dofs_idx_local : None | array_like, optional
The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None.
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_frictionloss(frictionloss, dofs_idx, envs_idx, unsafe=unsafe)
self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=skip_forward, unsafe=unsafe)

@gs.assert_built
def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False):
Expand All @@ -2252,9 +2245,9 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer
Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
"""
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
self._solver.set_dofs_position(position, dofs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity)
if zero_velocity:
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
self._solver.set_dofs_position(position, dofs_idx, envs_idx, unsafe=unsafe)

@gs.assert_built
def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
Expand Down Expand Up @@ -2570,8 +2563,7 @@ def zero_all_dofs_velocity(self, envs_idx=None, *, unsafe=False):
envs_idx : None | array_like, optional
The indices of the environments. If None, all environments will be considered. Defaults to None.
"""
dofs_idx_local = torch.arange(self.n_dofs, dtype=gs.tc_int, device=gs.device)
self.set_dofs_velocity(None, dofs_idx_local, envs_idx, unsafe=unsafe)
self.set_dofs_velocity(None, self._dofs_idx, envs_idx, unsafe=unsafe)

@gs.assert_built
def detect_collision(self, env_idx=0):
Expand Down
5 changes: 3 additions & 2 deletions genesis/engine/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,8 +1406,9 @@ def _sanitize_envs_idx(self, envs_idx, *, unsafe=False):
if _envs_idx.ndim != 1:
gs.raise_exception("Expecting a 1D tensor for `envs_idx`.")

if (_envs_idx < 0).any() or (_envs_idx >= self.n_envs).any():
gs.raise_exception("`envs_idx` exceeds valid range.")
# FIXME: This check is too expensive
# if (_envs_idx < 0).any() or (_envs_idx >= self.n_envs).any():
# gs.raise_exception("`envs_idx` exceeds valid range.")

return _envs_idx

Expand Down
35 changes: 27 additions & 8 deletions genesis/engine/solvers/rigid/collider_decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,15 +296,30 @@ def _init_terrain_state(self):
self._collider_info.terrain_scale.from_numpy(scale)
self._collider_info.terrain_xyz_maxmin.from_numpy(xyz_maxmin)

def reset(self, envs_idx: npt.NDArray[np.int32] | None = None) -> None:
def reset(self, envs_idx: npt.NDArray[np.int32] | None = None, cache_only: bool = False) -> None:
self._contacts_info_cache.clear()
if gs.use_zerocopy:
mask = () if envs_idx is None else envs_idx
if not cache_only:
first_time = ti_to_torch(self._collider_state.first_time, copy=False)
if isinstance(envs_idx, torch.Tensor):
first_time.scatter_(0, envs_idx, True)
else:
first_time[mask] = True
i_va_ws = ti_to_torch(self._collider_state.contact_cache.i_va_ws, copy=False)
normal = ti_to_torch(self._collider_state.contact_cache.normal, copy=False)
if isinstance(envs_idx, torch.Tensor):
n_geoms = i_va_ws.shape[0]
i_va_ws.scatter_(2, envs_idx[None, None].expand((n_geoms, n_geoms, -1)), -1)
normal.scatter_(2, envs_idx[None, None, :, None].expand((n_geoms, n_geoms, -1, 3)), 0.0)
else:
i_va_ws[mask] = -1
normal[mask] = 0.0
return

if envs_idx is None:
envs_idx = self._solver._scene._envs_idx
collider_kernel_reset(
envs_idx,
self._solver._static_rigid_sim_config,
self._collider_state,
)
self._contacts_info_cache.clear()
collider_kernel_reset(envs_idx, self._solver._static_rigid_sim_config, self._collider_state, cache_only)

def clear(self, envs_idx=None):
if envs_idx is None:
Expand Down Expand Up @@ -548,13 +563,17 @@ def collider_kernel_reset(
envs_idx: ti.types.ndarray(),
static_rigid_sim_config: ti.template(),
collider_state: array_class.ColliderState,
cache_only: ti.template(),
):
n_geoms = collider_state.active_buffer.shape[0]

ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_b_ in range(envs_idx.shape[0]):
i_b = envs_idx[i_b_]
collider_state.first_time[i_b] = 1

if ti.static(not cache_only):
collider_state.first_time[i_b] = True

for i_ga, i_gb in ti.ndrange(n_geoms, n_geoms):
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1
Expand Down
Loading