Skip to content

Commit 075d70a

Browse files
authored
[FEATURE] Add support of zero-copy to 'set_qpos' and 'set_dofs_velocity'. (#2025)
* Improve numerical stability of tower. * Do not clear error code in 'set_qpos' and 'set_dofs_position'. * Add support of zero-copy to 'set_dofs_velocity'. * Skip velocity update if not necessary. * Add support of zero-copy to 'set_qpos'. * Do not expose 'skip_forward' option for setters that are discouraged in hot path. * Add 'set_qpos', 'set_dofs_velocity' to accessor benchmark.
1 parent ea4c7dd commit 075d70a

File tree

8 files changed

+302
-202
lines changed

8 files changed

+302
-202
lines changed

examples/collision/tower.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import os
23

34
import genesis as gs
45

@@ -9,12 +10,13 @@ def main():
910
parser.add_argument("-v", "--vis", action="store_true", default=False)
1011
args = parser.parse_args()
1112
object_type = args.object
13+
horizon = 50 if "PYTEST_VERSION" in os.environ else 1000
1214

1315
gs.init(backend=gs.cpu, precision="32")
1416

1517
scene = gs.Scene(
1618
sim_options=gs.options.SimOptions(
17-
dt=0.005,
19+
dt=0.004,
1820
),
1921
rigid_options=gs.options.RigidOptions(
2022
max_collision_pairs=200,
@@ -27,7 +29,7 @@ def main():
2729
show_viewer=args.vis,
2830
)
2931

30-
plane = scene.add_entity(gs.morphs.Plane())
32+
scene.add_entity(gs.morphs.Plane())
3133

3234
# create pyramid of boxes
3335
box_width, box_length, box_height = 0.25, 2.0, 0.1
@@ -51,12 +53,12 @@ def main():
5153

5254
# Drop a huge mesh
5355
if object_type == "duck":
54-
duck_scale = 1.0
55-
duck = scene.add_entity(
56+
duck_scale = 0.8
57+
scene.add_entity(
5658
morph=gs.morphs.Mesh(
5759
file="meshes/duck.obj",
5860
scale=duck_scale,
59-
pos=(0, 0, num_stacks * box_height + 10 * duck_scale),
61+
pos=(0, -0.1, num_stacks * box_height + 10 * duck_scale),
6062
),
6163
)
6264
elif object_type == "sphere":
@@ -78,7 +80,7 @@ def main():
7880
)
7981

8082
scene.build()
81-
for i in range(600):
83+
for i in range(horizon):
8284
scene.step()
8385

8486

genesis/engine/entities/rigid_entity/rigid_entity.py

Lines changed: 72 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,10 @@ def _build(self):
580580
self._n_free_verts = len(self._free_verts_idx_local)
581581
self._n_fixed_verts = len(self._fixed_verts_idx_local)
582582

583+
self._dofs_idx = torch.arange(
584+
self._dof_start, self._dof_start + self._n_dofs, dtype=gs.tc_int, device=gs.device
585+
)
586+
583587
self._geoms = self.geoms
584588
self._vgeoms = self.vgeoms
585589

@@ -1493,6 +1497,7 @@ def _kernel_forward_kinematics(
14931497
# ------------------------------------------------------------------------------------
14941498
# --------------------------------- motion planing -----------------------------------
14951499
# ------------------------------------------------------------------------------------
1500+
14961501
@gs.assert_built
14971502
def plan_path(
14981503
self,
@@ -1623,6 +1628,50 @@ def plan_path(
16231628
# ---------------------------------- control & io ------------------------------------
16241629
# ------------------------------------------------------------------------------------
16251630

1631+
def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False):
1632+
# Handling default argument and special cases
1633+
if idx_local is None:
1634+
if unsafe:
1635+
idx_global = slice(idx_global_start, idx_local_max + idx_global_start)
1636+
else:
1637+
idx_global = range(idx_global_start, idx_local_max + idx_global_start)
1638+
elif isinstance(idx_local, (range, slice)):
1639+
idx_global = range(
1640+
(idx_local.start or 0) + idx_global_start,
1641+
(idx_local.stop if idx_local.stop is not None else idx_local_max) + idx_global_start,
1642+
idx_local.step or 1,
1643+
)
1644+
elif isinstance(idx_local, (int, np.integer)):
1645+
idx_global = idx_local + idx_global_start
1646+
elif isinstance(idx_local, (list, tuple)):
1647+
try:
1648+
idx_global = [i + idx_global_start for i in idx_local]
1649+
except TypeError:
1650+
gs.raise_exception("Expecting a sequence of integers for `idx_local`.")
1651+
else:
1652+
# Increment may be slow when dealing with heterogenuous data, so it must be avoided if possible
1653+
if idx_global_start > 0:
1654+
idx_global = idx_local + idx_global_start
1655+
else:
1656+
idx_global = idx_local
1657+
1658+
# Early return if unsafe
1659+
if unsafe:
1660+
return idx_global
1661+
1662+
# Perform a bunch of sanity checks
1663+
_idx_global = torch.as_tensor(idx_global, dtype=gs.tc_int, device=gs.device).contiguous()
1664+
if _idx_global is not idx_global:
1665+
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
1666+
idx_global = torch.atleast_1d(_idx_global)
1667+
1668+
if idx_global.ndim != 1:
1669+
gs.raise_exception("Expecting a 1D tensor for `idx_local`.")
1670+
if (idx_global < 0).any() or (idx_global >= idx_global_start + idx_local_max).any():
1671+
gs.raise_exception("`idx_local` exceeds valid range.")
1672+
1673+
return idx_global
1674+
16261675
def get_joint(self, name=None, uid=None):
16271676
"""
16281677
Get a RigidJoint object by name or uid.
@@ -1949,7 +1998,7 @@ def get_links_invweight(self, links_idx_local=None, envs_idx=None, *, unsafe=Fal
19491998
return self._solver.get_links_invweight(links_idx, envs_idx, unsafe=unsafe)
19501999

19512000
@gs.assert_built
1952-
def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
2001+
def set_pos(self, pos, envs_idx=None, *, relative=False, unsafe=False):
19532002
"""
19542003
Set position of the entity's base link.
19552004
@@ -1971,19 +2020,13 @@ def set_pos(self, pos, envs_idx=None, *, relative=False, zero_velocity=True, uns
19712020
if _pos is not pos:
19722021
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
19732022
pos = _pos
2023+
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
19742024
self._solver.set_base_links_pos(
1975-
pos.unsqueeze(-2),
1976-
self._base_links_idx_,
1977-
envs_idx,
1978-
relative=relative,
1979-
unsafe=unsafe,
1980-
skip_forward=zero_velocity,
2025+
pos.unsqueeze(-2), self._base_links_idx_, envs_idx, relative=relative, unsafe=unsafe
19812026
)
1982-
if zero_velocity:
1983-
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
19842027

19852028
@gs.assert_built
1986-
def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, unsafe=False):
2029+
def set_quat(self, quat, envs_idx=None, *, relative=False, unsafe=False):
19872030
"""
19882031
Set quaternion of the entity's base link.
19892032
@@ -2005,16 +2048,10 @@ def set_quat(self, quat, envs_idx=None, *, relative=False, zero_velocity=True, u
20052048
if _quat is not quat:
20062049
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
20072050
quat = _quat
2051+
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
20082052
self._solver.set_base_links_quat(
2009-
quat.unsqueeze(-2),
2010-
self._base_links_idx_,
2011-
envs_idx,
2012-
relative=relative,
2013-
unsafe=unsafe,
2014-
skip_forward=zero_velocity,
2053+
quat.unsqueeze(-2), self._base_links_idx_, envs_idx, relative=relative, unsafe=unsafe
20152054
)
2016-
if zero_velocity:
2017-
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
20182055

20192056
@gs.assert_built
20202057
def get_verts(self):
@@ -2061,52 +2098,8 @@ def get_verts(self):
20612098
tensor = tensor[0]
20622099
return tensor
20632100

2064-
def _get_idx(self, idx_local, idx_local_max, idx_global_start=0, *, unsafe=False):
2065-
# Handling default argument and special cases
2066-
if idx_local is None:
2067-
if unsafe:
2068-
idx_global = slice(idx_global_start, idx_local_max + idx_global_start)
2069-
else:
2070-
idx_global = range(idx_global_start, idx_local_max + idx_global_start)
2071-
elif isinstance(idx_local, (range, slice)):
2072-
idx_global = range(
2073-
(idx_local.start or 0) + idx_global_start,
2074-
(idx_local.stop if idx_local.stop is not None else idx_local_max) + idx_global_start,
2075-
idx_local.step or 1,
2076-
)
2077-
elif isinstance(idx_local, (int, np.integer)):
2078-
idx_global = idx_local + idx_global_start
2079-
elif isinstance(idx_local, (list, tuple)):
2080-
try:
2081-
idx_global = [i + idx_global_start for i in idx_local]
2082-
except TypeError:
2083-
gs.raise_exception("Expecting a sequence of integers for `idx_local`.")
2084-
else:
2085-
# Increment may be slow when dealing with heterogenuous data, so it must be avoided if possible
2086-
if idx_global_start > 0:
2087-
idx_global = idx_local + idx_global_start
2088-
else:
2089-
idx_global = idx_local
2090-
2091-
# Early return if unsafe
2092-
if unsafe:
2093-
return idx_global
2094-
2095-
# Perform a bunch of sanity checks
2096-
_idx_global = torch.as_tensor(idx_global, dtype=gs.tc_int, device=gs.device).contiguous()
2097-
if _idx_global is not idx_global:
2098-
gs.logger.debug(ALLOCATE_TENSOR_WARNING)
2099-
idx_global = torch.atleast_1d(_idx_global)
2100-
2101-
if idx_global.ndim != 1:
2102-
gs.raise_exception("Expecting a 1D tensor for `idx_local`.")
2103-
if (idx_global < 0).any() or (idx_global >= idx_global_start + idx_local_max).any():
2104-
gs.raise_exception("`idx_local` exceeds valid range.")
2105-
2106-
return idx_global
2107-
21082101
@gs.assert_built
2109-
def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False):
2102+
def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True, skip_forward=False, unsafe=False):
21102103
"""
21112104
Set the entity's qpos.
21122105
@@ -2122,9 +2115,9 @@ def set_qpos(self, qpos, qs_idx_local=None, envs_idx=None, *, zero_velocity=True
21222115
Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
21232116
"""
21242117
qs_idx = self._get_idx(qs_idx_local, self.n_qs, self._q_start, unsafe=True)
2125-
self._solver.set_qpos(qpos, qs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity)
21262118
if zero_velocity:
2127-
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
2119+
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
2120+
self._solver.set_qpos(qpos, qs_idx, envs_idx, skip_forward=skip_forward, unsafe=unsafe)
21282121

21292122
@gs.assert_built
21302123
def set_dofs_kp(self, kp, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
@@ -2203,37 +2196,37 @@ def set_dofs_damping(self, damping, dofs_idx_local=None, envs_idx=None, *, unsaf
22032196
self._solver.set_dofs_damping(damping, dofs_idx, envs_idx, unsafe=unsafe)
22042197

22052198
@gs.assert_built
2206-
def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
2199+
def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
22072200
"""
2208-
Set the entity's dofs' velocity.
2209-
2201+
Set the entity's dofs' friction loss.
22102202
Parameters
22112203
----------
2212-
velocity : array_like | None
2213-
The velocity to set. Zero if not specified.
2204+
frictionloss : array_like
2205+
The friction loss values to set.
22142206
dofs_idx_local : None | array_like, optional
22152207
The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None.
22162208
envs_idx : None | array_like, optional
22172209
The indices of the environments. If None, all environments will be considered. Defaults to None.
22182210
"""
22192211
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
2220-
self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=False, unsafe=unsafe)
2212+
self._solver.set_dofs_frictionloss(frictionloss, dofs_idx, envs_idx, unsafe=unsafe)
22212213

22222214
@gs.assert_built
2223-
def set_dofs_frictionloss(self, frictionloss, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
2215+
def set_dofs_velocity(self, velocity=None, dofs_idx_local=None, envs_idx=None, *, skip_forward=False, unsafe=False):
22242216
"""
2225-
Set the entity's dofs' friction loss.
2217+
Set the entity's dofs' velocity.
2218+
22262219
Parameters
22272220
----------
2228-
frictionloss : array_like
2229-
The friction loss values to set.
2221+
velocity : array_like | None
2222+
The velocity to set. Zero if not specified.
22302223
dofs_idx_local : None | array_like, optional
22312224
The indices of the dofs to set. If None, all dofs will be set. Note that here this uses the local `q_idx`, not the scene-level one. Defaults to None.
22322225
envs_idx : None | array_like, optional
22332226
The indices of the environments. If None, all environments will be considered. Defaults to None.
22342227
"""
22352228
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
2236-
self._solver.set_dofs_frictionloss(frictionloss, dofs_idx, envs_idx, unsafe=unsafe)
2229+
self._solver.set_dofs_velocity(velocity, dofs_idx, envs_idx, skip_forward=skip_forward, unsafe=unsafe)
22372230

22382231
@gs.assert_built
22392232
def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zero_velocity=True, unsafe=False):
@@ -2252,9 +2245,9 @@ def set_dofs_position(self, position, dofs_idx_local=None, envs_idx=None, *, zer
22522245
Whether to zero the velocity of all the entity's dofs. Defaults to True. This is a safety measure after a sudden change in entity pose.
22532246
"""
22542247
dofs_idx = self._get_idx(dofs_idx_local, self.n_dofs, self._dof_start, unsafe=True)
2255-
self._solver.set_dofs_position(position, dofs_idx, envs_idx, unsafe=unsafe, skip_forward=zero_velocity)
22562248
if zero_velocity:
2257-
self.zero_all_dofs_velocity(envs_idx, unsafe=unsafe)
2249+
self._solver.set_dofs_velocity(None, self._dofs_idx, envs_idx, skip_forward=True, unsafe=unsafe)
2250+
self._solver.set_dofs_position(position, dofs_idx, envs_idx, unsafe=unsafe)
22582251

22592252
@gs.assert_built
22602253
def control_dofs_force(self, force, dofs_idx_local=None, envs_idx=None, *, unsafe=False):
@@ -2570,8 +2563,7 @@ def zero_all_dofs_velocity(self, envs_idx=None, *, unsafe=False):
25702563
envs_idx : None | array_like, optional
25712564
The indices of the environments. If None, all environments will be considered. Defaults to None.
25722565
"""
2573-
dofs_idx_local = torch.arange(self.n_dofs, dtype=gs.tc_int, device=gs.device)
2574-
self.set_dofs_velocity(None, dofs_idx_local, envs_idx, unsafe=unsafe)
2566+
self.set_dofs_velocity(None, self._dofs_idx, envs_idx, unsafe=unsafe)
25752567

25762568
@gs.assert_built
25772569
def detect_collision(self, env_idx=0):

genesis/engine/scene.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,8 +1406,9 @@ def _sanitize_envs_idx(self, envs_idx, *, unsafe=False):
14061406
if _envs_idx.ndim != 1:
14071407
gs.raise_exception("Expecting a 1D tensor for `envs_idx`.")
14081408

1409-
if (_envs_idx < 0).any() or (_envs_idx >= self.n_envs).any():
1410-
gs.raise_exception("`envs_idx` exceeds valid range.")
1409+
# FIXME: This check is too expensive
1410+
# if (_envs_idx < 0).any() or (_envs_idx >= self.n_envs).any():
1411+
# gs.raise_exception("`envs_idx` exceeds valid range.")
14111412

14121413
return _envs_idx
14131414

genesis/engine/solvers/rigid/collider_decomp.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -296,15 +296,30 @@ def _init_terrain_state(self):
296296
self._collider_info.terrain_scale.from_numpy(scale)
297297
self._collider_info.terrain_xyz_maxmin.from_numpy(xyz_maxmin)
298298

299-
def reset(self, envs_idx: npt.NDArray[np.int32] | None = None) -> None:
299+
def reset(self, envs_idx: npt.NDArray[np.int32] | None = None, cache_only: bool = False) -> None:
300+
self._contacts_info_cache.clear()
301+
if gs.use_zerocopy:
302+
mask = () if envs_idx is None else envs_idx
303+
if not cache_only:
304+
first_time = ti_to_torch(self._collider_state.first_time, copy=False)
305+
if isinstance(envs_idx, torch.Tensor):
306+
first_time.scatter_(0, envs_idx, True)
307+
else:
308+
first_time[mask] = True
309+
i_va_ws = ti_to_torch(self._collider_state.contact_cache.i_va_ws, copy=False)
310+
normal = ti_to_torch(self._collider_state.contact_cache.normal, copy=False)
311+
if isinstance(envs_idx, torch.Tensor):
312+
n_geoms = i_va_ws.shape[0]
313+
i_va_ws.scatter_(2, envs_idx[None, None].expand((n_geoms, n_geoms, -1)), -1)
314+
normal.scatter_(2, envs_idx[None, None, :, None].expand((n_geoms, n_geoms, -1, 3)), 0.0)
315+
else:
316+
i_va_ws[mask] = -1
317+
normal[mask] = 0.0
318+
return
319+
300320
if envs_idx is None:
301321
envs_idx = self._solver._scene._envs_idx
302-
collider_kernel_reset(
303-
envs_idx,
304-
self._solver._static_rigid_sim_config,
305-
self._collider_state,
306-
)
307-
self._contacts_info_cache.clear()
322+
collider_kernel_reset(envs_idx, self._solver._static_rigid_sim_config, self._collider_state, cache_only)
308323

309324
def clear(self, envs_idx=None):
310325
if envs_idx is None:
@@ -548,13 +563,17 @@ def collider_kernel_reset(
548563
envs_idx: ti.types.ndarray(),
549564
static_rigid_sim_config: ti.template(),
550565
collider_state: array_class.ColliderState,
566+
cache_only: ti.template(),
551567
):
552568
n_geoms = collider_state.active_buffer.shape[0]
553569

554570
ti.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
555571
for i_b_ in range(envs_idx.shape[0]):
556572
i_b = envs_idx[i_b_]
557-
collider_state.first_time[i_b] = 1
573+
574+
if ti.static(not cache_only):
575+
collider_state.first_time[i_b] = True
576+
558577
for i_ga, i_gb in ti.ndrange(n_geoms, n_geoms):
559578
collider_state.contact_cache.i_va_ws[i_ga, i_gb, i_b] = -1
560579
collider_state.contact_cache.i_va_ws[i_gb, i_ga, i_b] = -1

0 commit comments

Comments
 (0)